Created
June 10, 2022 19:42
-
-
Save shaabhishek/005f5a5c8ca7363331953e6f073f4df6 to your computer and use it in GitHub Desktop.
Revisions
-
shaabhishek created this gist
Jun 10, 2022 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,141 @@ class InverseGamma(TransformedDistribution): """ https://en.wikipedia.org/wiki/Inverse-gamma_distribution Creates an inverse-gamma distribution parameterized by `concentration` and `rate`. X ~ Gamma(concentration, rate) Y = 1/X ~ InverseGamma(concentration, rate) :param torch.Tensor concentration: the concentration parameter (i.e. alpha on wikipedia). :param torch.Tensor rate: the rate parameter (i.e. beta on wikipedia). """ arg_constraints = {"concentration": constraints.positive, "rate": constraints.positive} support = constraints.positive has_rsample = True def __init__(self, concentration, rate, validate_args=None): base_distribution = Gamma(concentration, rate) super().__init__(base_distribution, PowerTransform(-torch.ones_like(concentration)), validate_args=validate_args) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(InverseGamma, _instance) return super().expand(batch_shape, _instance=new) def entropy(self): """ https://en.wikipedia.org/wiki/Inverse-gamma_distribution alpha + log(beta * Gamma(alpha)) - (alpha + 1) * Digamma(alpha) """ return (self.concentration + torch.log(self.rate) + torch.lgamma(self.concentration) - (1.0 + self.concentration) * torch.digamma(self.concentration)) @property def concentration(self): return self.base_dist.concentration @property def rate(self): return self.base_dist.rate class NormalGamma(Distribution): """ https://en.wikipedia.org/wiki/Normal-gamma_distribution Creates an normal-gamma distribution parameterized by `mu`, `lambd`, `concentration` and `rate`. P ~ Gamma(concentration, rate) X ~ Normal(mu, variance=(lambd*P)^-1) => (X,P) ~ NormalGamma(mu, lambd, concentration, rate) :param torch.Tensor mu: the mean parameter for the normal distribution (i.e. mu on wikipedia). :param torch.Tensor lambd: the scaling of precision for the normal distribution (i.e. lambda on wikipedia). :param torch.Tensor concentration: the concentration parameter for the gamma distribution (i.e. alpha on wikipedia). :param torch.Tensor rate: the rate parameter for the gamma distribution (i.e. beta on wikipedia). """ arg_constraints = {"mu": constraints.real, "lambd": constraints.positive, "concentration": constraints.positive, "rate": constraints.positive} support = constraints.positive #TODO has_rsample = True def __init__(self, mu, lambd, concentration, rate, validate_args=None): self._gamma = Gamma(concentration, rate, validate_args=validate_args) self._mu = mu self._lambd = lambd batch_shape = self.mu.size() event_shape = torch.Size([2]) super().__init__(batch_shape, event_shape=event_shape, validate_args=validate_args) def expand(self, batch_shape, _instance=None): # new = self._get_checked_instance(NormalGamma, _instance) # return super().expand(batch_shape, _instance=new) new = self._get_checked_instance(NormalGamma, _instance) batch_shape = torch.Size(batch_shape) new._gamma.concentration = self.concentration.expand(batch_shape) new._gamma.rate = self.rate.expand(batch_shape) new.lambd = self.lambd.expand(batch_shape) new._mu = self.mu.expand(batch_shape) super(NormalGamma, new).__init__(batch_shape, event_shape=self._event_shape, validate_args=False) new._validate_args = self._validate_args return new def sample(self, sample_shape=()): precision = self._gamma.sample(sample_shape) mu = Normal(self.mu, (self.lambd*precision).sqrt().reciprocal()).sample() return torch.stack([mu, precision], dim=-1) def rsample(self, sample_shape=()): precision = self._gamma.rsample(sample_shape) mu = Normal(self.mu, (self.lambd*precision).sqrt().reciprocal()).rsample() return torch.stack([mu, precision], dim=-1) def entropy(self): """ https://en.wikipedia.org/wiki/Normal-gamma_distribution alpha + log(beta * Gamma(alpha)) - (alpha + 1) * Digamma(alpha) """ return (self.concentration + torch.log(self.rate) + torch.lgamma(self.concentration) - (1.0 + self.concentration) * torch.digamma(self.concentration)) @property def concentration(self): return self._gamma.concentration @property def rate(self): return self._gamma.rate @property def lambd(self): return self._lambd @property def mu(self): return self._mu @property def mean(self): return torch.stack([self.mu, self._gamma.mean], dim=-1) @property def variance(self): variance_mean = self.rate / (self.lambd * (self.concentration - 1)) variance_precision = self._gamma.variance return torch.stack([variance_mean, variance_precision], dim=-1) def log_prob(self, value): value = torch.as_tensor(value, dtype=self.mu.dtype, device=self.mu.device) mean = value[..., 0] precision = value[..., 1] sq_dist = (mean - self.mu) ** 2 return (self.concentration * torch.log(self.rate) + 0.5 * torch.log(self.lambd) + (self.concentration - 0.5) * torch.log(precision) - self.rate * precision - 0.5 * self.lambd * precision * sq_dist - torch.lgamma(self.concentration) - math.log(math.sqrt(2 * math.pi)))