Skip to content

Instantly share code, notes, and snippets.

@shaabhishek
Created June 10, 2022 19:42
Show Gist options
  • Select an option

  • Save shaabhishek/005f5a5c8ca7363331953e6f073f4df6 to your computer and use it in GitHub Desktop.

Select an option

Save shaabhishek/005f5a5c8ca7363331953e6f073f4df6 to your computer and use it in GitHub Desktop.

Revisions

  1. shaabhishek created this gist Jun 10, 2022.
    141 changes: 141 additions & 0 deletions pytorch_custom_distributions.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,141 @@
    class InverseGamma(TransformedDistribution):
        """Inverse-gamma distribution parameterized by ``concentration`` and ``rate``.

        The distribution is built as the reciprocal of a gamma variable::

            X ~ Gamma(concentration, rate)
            Y = 1 / X ~ InverseGamma(concentration, rate)

        See https://en.wikipedia.org/wiki/Inverse-gamma_distribution

        :param concentration: the concentration parameter (alpha on Wikipedia);
            a tensor or a Python number.
        :param rate: the rate of the underlying gamma (beta on Wikipedia);
            a tensor or a Python number.
        :param validate_args: forwarded to both the base ``Gamma`` and the
            transformed distribution.
        """

        arg_constraints = {
            "concentration": constraints.positive,
            "rate": constraints.positive,
        }
        support = constraints.positive
        has_rsample = True

        def __init__(self, concentration, rate, validate_args=None):
            # Forward validate_args so that validate_args=False really disables
            # parameter validation in the base distribution as well.
            base_distribution = Gamma(concentration, rate, validate_args=validate_args)
            # Build the exponent from the already-broadcast base parameter so
            # plain Python numbers work (torch.ones_like rejects non-tensors)
            # and the exponent inherits the right dtype and device.
            exponent = -torch.ones_like(base_distribution.concentration)
            super().__init__(
                base_distribution,
                PowerTransform(exponent),
                validate_args=validate_args,
            )

        def expand(self, batch_shape, _instance=None):
            """Return a new ``InverseGamma`` expanded to ``batch_shape``."""
            new = self._get_checked_instance(InverseGamma, _instance)
            return super().expand(batch_shape, _instance=new)

        def entropy(self):
            """Return the differential entropy of the distribution.

            Uses the closed form from
            https://en.wikipedia.org/wiki/Inverse-gamma_distribution ::

                alpha + log(beta * Gamma(alpha)) - (1 + alpha) * digamma(alpha)
            """
            alpha = self.concentration
            return (
                alpha
                + torch.log(self.rate)
                + torch.lgamma(alpha)
                - (1.0 + alpha) * torch.digamma(alpha)
            )

        @property
        def concentration(self):
            """Concentration (alpha) of the underlying gamma distribution."""
            return self.base_dist.concentration

        @property
        def rate(self):
            """Rate (beta) of the underlying gamma distribution."""
            return self.base_dist.rate


    class NormalGamma(Distribution):
        """Normal-gamma distribution over a ``(mean, precision)`` pair.

        Parameterized by ``mu``, ``lambd``, ``concentration`` and ``rate``::

            P ~ Gamma(concentration, rate)
            X ~ Normal(mu, variance=(lambd * P)^-1)
            => (X, P) ~ NormalGamma(mu, lambd, concentration, rate)

        Every sample has a trailing event dimension of size 2 holding
        ``[x, precision]``.

        See https://en.wikipedia.org/wiki/Normal-gamma_distribution

        :param mu: the mean parameter of the normal (mu on Wikipedia).
        :param lambd: the precision scaling of the normal (lambda on Wikipedia).
        :param concentration: the gamma concentration (alpha on Wikipedia).
        :param rate: the gamma rate (beta on Wikipedia).
        :param validate_args: whether to validate parameters and samples.

        All four parameters may be tensors or Python numbers; they are
        broadcast against each other to form the batch shape.
        """

        arg_constraints = {
            "mu": constraints.real,
            "lambd": constraints.positive,
            "concentration": constraints.positive,
            "rate": constraints.positive,
        }
        # Event layout is [x, precision]: x is unconstrained, precision is > 0.
        support = constraints.stack([constraints.real, constraints.positive], dim=-1)
        has_rsample = True

        def __init__(self, mu, lambd, concentration, rate, validate_args=None):
            # Function-scope import keeps this class self-contained.
            from torch.distributions.utils import broadcast_all

            # Broadcast everything up front so the batch shape is consistent
            # across the gamma and normal parts and Python numbers are accepted.
            mu, lambd, concentration, rate = broadcast_all(mu, lambd, concentration, rate)
            self._gamma = Gamma(concentration, rate, validate_args=validate_args)
            self._mu = mu
            self._lambd = lambd
            super().__init__(
                mu.size(),
                event_shape=torch.Size([2]),
                validate_args=validate_args,
            )

        def expand(self, batch_shape, _instance=None):
            """Return a new ``NormalGamma`` with parameters expanded to ``batch_shape``."""
            new = self._get_checked_instance(NormalGamma, _instance)
            batch_shape = torch.Size(batch_shape)
            # `new` is created without running __init__, so every private
            # attribute must be assigned here (the public names are read-only
            # properties and cannot be set).
            new._gamma = self._gamma.expand(batch_shape)
            new._mu = self._mu.expand(batch_shape)
            new._lambd = self._lambd.expand(batch_shape)
            super(NormalGamma, new).__init__(
                batch_shape, event_shape=self._event_shape, validate_args=False
            )
            new._validate_args = self._validate_args
            return new

        def sample(self, sample_shape=()):
            """Draw samples of shape ``sample_shape + batch_shape + (2,)`` without gradients."""
            with torch.no_grad():
                return self.rsample(sample_shape)

        def rsample(self, sample_shape=()):
            """Draw reparameterized samples of shape ``sample_shape + batch_shape + (2,)``."""
            precision = self._gamma.rsample(sample_shape)
            # Conditional on the precision draw, x ~ Normal(mu, 1 / sqrt(lambd * precision)).
            scale = (self.lambd * precision).sqrt().reciprocal()
            x = Normal(self.mu, scale).rsample()
            return torch.stack([x, precision], dim=-1)

        def entropy(self):
            """Return the joint differential entropy of ``(X, P)``.

            Computed by the chain rule ``H(X, P) = H(P) + E_P[H(X | P)]`` where
            ``H(P)`` is the gamma entropy and, for a normal with precision
            ``lambd * P``::

                E_P[H(X | P)] = 0.5 * (1 + log(2 * pi)) - 0.5 * log(lambd)
                                - 0.5 * (digamma(alpha) - log(beta))

            using ``E[log P] = digamma(alpha) - log(beta)``.
            """
            expected_log_precision = torch.digamma(self.concentration) - torch.log(self.rate)
            conditional_normal_entropy = (
                0.5 * (1.0 + math.log(2 * math.pi))
                - 0.5 * torch.log(self.lambd)
                - 0.5 * expected_log_precision
            )
            return self._gamma.entropy() + conditional_normal_entropy

        @property
        def concentration(self):
            """Concentration (alpha) of the gamma over the precision."""
            return self._gamma.concentration

        @property
        def rate(self):
            """Rate (beta) of the gamma over the precision."""
            return self._gamma.rate

        @property
        def lambd(self):
            """Precision scaling (lambda) of the conditional normal."""
            return self._lambd

        @property
        def mu(self):
            """Location (mu) of the conditional normal."""
            return self._mu

        @property
        def mean(self):
            """Mean of ``[x, precision]``, stacked along the last dimension."""
            return torch.stack([self.mu, self._gamma.mean], dim=-1)

        @property
        def variance(self):
            """Marginal variances of ``[x, precision]``, stacked along the last dimension.

            The marginal of ``x`` is a Student-t with ``2 * alpha`` degrees of
            freedom, so its variance is ``beta / (lambd * (alpha - 1))`` for
            ``alpha > 1``, infinite for ``0.5 < alpha <= 1`` and undefined
            (NaN) for ``alpha <= 0.5``.
            """
            alpha = self.concentration
            variance_mean = self.rate / (self.lambd * (alpha - 1))
            variance_mean = torch.where(
                alpha > 1, variance_mean, torch.full_like(variance_mean, math.inf)
            )
            variance_mean = torch.where(
                alpha > 0.5, variance_mean, torch.full_like(variance_mean, math.nan)
            )
            variance_precision = self._gamma.variance
            return torch.stack([variance_mean, variance_precision], dim=-1)

        def log_prob(self, value):
            """Return the joint log-density at ``value``.

            :param value: tensor (or array-like) whose last dimension holds
                ``[x, precision]``.
            """
            value = torch.as_tensor(value, dtype=self.mu.dtype, device=self.mu.device)
            if self._validate_args:
                self._validate_sample(value)
            x = value[..., 0]
            precision = value[..., 1]
            sq_dist = (x - self.mu) ** 2

            return (
                self.concentration * torch.log(self.rate)
                + 0.5 * torch.log(self.lambd)
                + (self.concentration - 0.5) * torch.log(precision)
                - self.rate * precision
                - 0.5 * self.lambd * precision * sq_dist
                - torch.lgamma(self.concentration)
                - 0.5 * math.log(2 * math.pi)
            )