class InverseGamma(TransformedDistribution):
    """
    Creates an inverse-gamma distribution parameterized by `concentration` and `rate`,
    realized as the reciprocal of a gamma-distributed variable:

        X ~ Gamma(concentration, rate)
        Y = 1 / X ~ InverseGamma(concentration, rate)

    Reference: https://en.wikipedia.org/wiki/Inverse-gamma_distribution

    :param torch.Tensor concentration: the concentration parameter (i.e. alpha on wikipedia).
    :param torch.Tensor rate: the rate parameter (i.e. beta on wikipedia).
    """
    arg_constraints = {"concentration": constraints.positive, "rate": constraints.positive}
    support = constraints.positive
    has_rsample = True

    def __init__(self, concentration, rate, validate_args=None):
        base_distribution = Gamma(concentration, rate)
        super().__init__(
            base_distribution,
            # y = x ** -1 maps the Gamma base onto the inverse-gamma.
            PowerTransform(-torch.ones_like(concentration)),
            validate_args=validate_args,
        )

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(InverseGamma, _instance)
        return super().expand(batch_shape, _instance=new)

    def entropy(self):
        """
        Differential entropy of the inverse-gamma distribution:

            alpha + log(beta * Gamma(alpha)) - (1 + alpha) * digamma(alpha)
        """
        return (self.concentration + torch.log(self.rate) + torch.lgamma(self.concentration)
                - (1.0 + self.concentration) * torch.digamma(self.concentration))

    @property
    def concentration(self):
        return self.base_dist.concentration

    @property
    def rate(self):
        return self.base_dist.rate


class NormalGamma(Distribution):
    """
    Creates a normal-gamma distribution parameterized by `mu`, `lambd`,
    `concentration` and `rate`:

        P ~ Gamma(concentration, rate)
        X ~ Normal(mu, variance=(lambd * P)^-1)
        => (X, P) ~ NormalGamma(mu, lambd, concentration, rate)

    Events are pairs stacked along the last dimension: ``value[..., 0]`` is the
    mean component ``X`` (real) and ``value[..., 1]`` is the precision
    component ``P`` (positive).

    Reference: https://en.wikipedia.org/wiki/Normal-gamma_distribution

    :param torch.Tensor mu: the mean parameter for the normal distribution
        (i.e. mu on wikipedia).
    :param torch.Tensor lambd: the scaling of precision for the normal
        distribution (i.e. lambda on wikipedia).
    :param torch.Tensor concentration: the concentration parameter for the
        gamma distribution (i.e. alpha on wikipedia).
    :param torch.Tensor rate: the rate parameter for the gamma distribution
        (i.e. beta on wikipedia).
    """
    arg_constraints = {"mu": constraints.real,
                       "lambd": constraints.positive,
                       "concentration": constraints.positive,
                       "rate": constraints.positive}
    # The event is (mean, precision): the mean is unconstrained, the
    # precision must be positive.  (Was `constraints.positive #TODO`, which
    # wrongly rejected negative means under validation.)
    support = constraints.stack([constraints.real, constraints.positive], dim=-1)
    has_rsample = True

    def __init__(self, mu, lambd, concentration, rate, validate_args=None):
        self._gamma = Gamma(concentration, rate, validate_args=validate_args)
        self._mu = mu
        self._lambd = lambd
        # NOTE(review): all four parameters are assumed to share mu's shape;
        # broadcasting across differently-shaped parameters is not supported.
        batch_shape = self.mu.size()
        event_shape = torch.Size([2])
        super().__init__(batch_shape, event_shape=event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(NormalGamma, _instance)
        batch_shape = torch.Size(batch_shape)
        # Expand the wrapped Gamma as a whole rather than mutating its
        # attributes in place, and assign the private fields directly:
        # `lambd` and `mu` are read-only properties, so the previous
        # `new.lambd = ...` raised AttributeError.
        new._gamma = self._gamma.expand(batch_shape)
        new._lambd = self.lambd.expand(batch_shape)
        new._mu = self.mu.expand(batch_shape)
        super(NormalGamma, new).__init__(batch_shape, event_shape=self._event_shape,
                                         validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=()):
        """Draw (mean, precision) pairs; no gradient flows through the draw."""
        precision = self._gamma.sample(sample_shape)
        mu = Normal(self.mu, (self.lambd * precision).sqrt().reciprocal()).sample()
        return torch.stack([mu, precision], dim=-1)

    def rsample(self, sample_shape=()):
        """Reparameterized draw of (mean, precision) pairs."""
        precision = self._gamma.rsample(sample_shape)
        mu = Normal(self.mu, (self.lambd * precision).sqrt().reciprocal()).rsample()
        return torch.stack([mu, precision], dim=-1)

    def entropy(self):
        """
        Differential entropy of the joint distribution:

            H(X, P) = H(P) + E_P[H(X | P)]

        where H(P) is the Gamma entropy

            alpha - log(beta) + lgamma(alpha) + (1 - alpha) * digamma(alpha)

        and, using E[log P] = digamma(alpha) - log(beta),

            E_P[H(X | P)] = 0.5 * (log(2*pi) + 1 - log(lambd)
                                   - digamma(alpha) + log(beta)).

        (Previously this returned the inverse-gamma entropy, a copy-paste bug.)
        """
        gamma_entropy = (self.concentration - torch.log(self.rate)
                         + torch.lgamma(self.concentration)
                         + (1.0 - self.concentration) * torch.digamma(self.concentration))
        normal_entropy = 0.5 * (math.log(2 * math.pi) + 1.0 - torch.log(self.lambd)
                                - torch.digamma(self.concentration) + torch.log(self.rate))
        return gamma_entropy + normal_entropy

    @property
    def concentration(self):
        return self._gamma.concentration

    @property
    def rate(self):
        return self._gamma.rate

    @property
    def lambd(self):
        return self._lambd

    @property
    def mu(self):
        return self._mu

    @property
    def mean(self):
        # E[X] = mu, E[P] = alpha / beta.
        return torch.stack([self.mu, self._gamma.mean], dim=-1)

    @property
    def variance(self):
        # Var[X] = beta / (lambd * (alpha - 1)); finite only for alpha > 1.
        variance_mean = self.rate / (self.lambd * (self.concentration - 1))
        variance_precision = self._gamma.variance
        return torch.stack([variance_mean, variance_precision], dim=-1)

    def log_prob(self, value):
        """
        Log density of a (mean, precision) pair:

            alpha*log(beta) + 0.5*log(lambd) + (alpha - 0.5)*log(p)
            - beta*p - 0.5*lambd*p*(x - mu)^2 - lgamma(alpha) - 0.5*log(2*pi)
        """
        value = torch.as_tensor(value, dtype=self.mu.dtype, device=self.mu.device)
        if self._validate_args:
            # Previously skipped even with validate_args=True.
            self._validate_sample(value)
        mean = value[..., 0]
        precision = value[..., 1]
        sq_dist = (mean - self.mu) ** 2
        return (self.concentration * torch.log(self.rate)
                + 0.5 * torch.log(self.lambd)
                + (self.concentration - 0.5) * torch.log(precision)
                - self.rate * precision
                - 0.5 * self.lambd * precision * sq_dist
                - torch.lgamma(self.concentration)
                - math.log(math.sqrt(2 * math.pi)))