Created January 12, 2018 12:25
ST-Gumbel-Softmax-Pytorch
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def sample_gumbel(shape, eps=1e-20):
    """Sample Gumbel(0, 1) noise via the inverse-CDF trick: -log(-log(U))."""
    U = torch.rand(shape).cuda()
    return -Variable(torch.log(-torch.log(U + eps) + eps))


def gumbel_softmax_sample(logits, temperature):
    """Draw a soft sample from the Gumbel-Softmax (Concrete) distribution."""
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)


def gumbel_softmax(logits, temperature):
    """
    input: [*, n_class]
    return: [*, n_class] a one-hot vector
    """
    y = gumbel_softmax_sample(logits, temperature)
    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Straight-through estimator: the forward pass uses the hard one-hot
    # sample, while the backward pass uses the gradient of the soft sample y.
    return (y_hard - y).detach() + y


if __name__ == '__main__':
    import math
    # Empirical class counts over 20000 samples should be close to
    # 20000 * (0.1, 0.4, 0.3, 0.2).
    print(gumbel_softmax(Variable(torch.cuda.FloatTensor(
        [[math.log(0.1), math.log(0.4), math.log(0.3), math.log(0.2)]] * 20000)), 0.8).sum(dim=0))
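For readers on current PyTorch (1.0+), here is a minimal sketch of the same straight-through Gumbel-Softmax without the deprecated Variable API and without hard-coded .cuda() calls; the name st_gumbel_softmax is illustrative, not part of the gist. Recent PyTorch also ships torch.nn.functional.gumbel_softmax(logits, tau, hard=True), which implements the same estimator.

import torch
import torch.nn.functional as F


def sample_gumbel(shape, eps=1e-20, device=None):
    # Gumbel(0, 1) noise via the inverse-CDF trick: -log(-log(U)).
    U = torch.rand(shape, device=device)
    return -torch.log(-torch.log(U + eps) + eps)


def st_gumbel_softmax(logits, temperature=1.0):
    # Soft sample: softmax of (logits + Gumbel noise) / temperature.
    gumbel_noise = sample_gumbel(logits.shape, device=logits.device)
    y_soft = F.softmax((logits + gumbel_noise) / temperature, dim=-1)
    # Hard sample: one-hot of the argmax, with gradients routed through
    # the soft sample (straight-through estimator).
    y_hard = F.one_hot(y_soft.argmax(dim=-1), num_classes=y_soft.shape[-1])
    y_hard = y_hard.to(y_soft.dtype)
    return (y_hard - y_soft).detach() + y_soft


if __name__ == '__main__':
    logits = torch.log(torch.tensor([0.1, 0.4, 0.3, 0.2])).repeat(20000, 1)
    # Empirical class counts should be close to 20000 * (0.1, 0.4, 0.3, 0.2).
    print(st_gumbel_softmax(logits, temperature=0.8).sum(dim=0))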
Hi, I am trying to apply this Gumbel-Softmax trick to a VAE for data synthesis. Here is my implementation. Am I doing something wrong? Thank you.
import logging
import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.contrib.distributions import (Bernoulli, OneHotCategorical,
                                              RelaxedOneHotCategorical,
                                              kl_divergence)
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.models import Model
logging.getLogger('tensorflow').disabled = True
class DiscreteVAE:
    def encoder(self, latent_dim, input_dim):
        encoder_input = layers.Input(shape=(input_dim, ), name='encoder_input')
        x = encoder_input
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_1')(x)
        x = layers.Dropout(0.3)(x)
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_2')(x)
        x = tf.keras.layers.Dense(latent_dim)(x)
        encoder_model = Model(inputs=encoder_input, outputs=x)
        encoder_model.summary()
        return encoder_model
    def decoder(self, latent_dim, input_dim):
        decoder_input = layers.Input(latent_dim, name='decoder_input')
        x = decoder_input
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_1')(x)
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_2')(x)
        decoded_input = layers.Dense(input_dim, name='decoded_input')(x)
        decoder_model = Model(decoder_input, decoded_input)
        decoder_model.summary()
        return decoder_model
    def sample_gumbel(self, shape, eps=1e-20):
        """Sample from Gumbel(0, 1)"""
        U = tf.random_uniform(shape, minval=0, maxval=1, dtype=tf.float32)
        return -tf.log(-tf.log(U + eps) + eps)
    def gumbel_softmax_sample(self, logits, temperature):
        """ Draw a sample from the Gumbel-Softmax distribution"""
        y = logits + self.sample_gumbel(tf.shape(logits))
        return tf.nn.softmax(y / temperature)
    def gumbel_softmax(self, args):
        """Sample from the Gumbel-Softmax distribution and optionally discretize.
        Args:
            logits: [batch_size, n_class] unnormalized log-probs
            temperature: non-negative scalar
            hard: if True, take argmax, but differentiate w.r.t. soft sample y
        Returns:
            [batch_size, n_class] sample from the Gumbel-Softmax distribution.
            If hard=True, then the returned sample will be one-hot, otherwise it will
            be a probability distribution that sums to 1 across classes
        """
        logits, temperature = args
        y = self.gumbel_softmax_sample(logits, temperature)
        # k = tf.shape(logits)[-1]
        # y_hard = tf.cast(tf.one_hot(tf.argmax(y, 1), k), y.dtype)
        y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)),
                         y.dtype)
        y = tf.stop_gradient(y_hard - y) + y
        return y
    def CatVAE_loss(self, encoded_input, decoded_input, z, x, tau, latent_dim):
        reconstruction_error = tf.reduce_sum(
            Bernoulli(logits=decoded_input).log_prob(x), 1)
        logits_pz = tf.ones_like(decoded_input) * (1. / latent_dim)
        q_cat_z = OneHotCategorical(logits=encoded_input)
        p_cat_z = OneHotCategorical(logits=logits_pz)
        KL_qp = kl_divergence(q_cat_z, p_cat_z)
        ELBO = tf.reduce_mean(reconstruction_error - KL_qp)
        loss = -ELBO
        return loss
    def build_vae(self, latent_dim, input_dim, opt, data):
        tau = 0.5
        input_x = layers.Input(shape=input_dim, name='vae_input')
        encoder_m = self.encoder(latent_dim, input_dim)
        logits_y = encoder_m(input_x)
        z = layers.Lambda(self.gumbel_softmax)([logits_y, tau])
        decoder_m = self.decoder(latent_dim, input_dim)
        decoded_input = decoder_m(z)
        # loss = self.vae_loss(input_x, input_dim, decoded_input, data)
        loss = self.CatVAE_loss(logits_y, decoded_input, z, input_x, tau,
                                latent_dim)
        vae = Model(input_x, decoded_input)
        vae.add_loss(loss)
        vae.compile(optimizer=opt)
        return vae, decoder_m, encoder_m
  
  
            
@ibrahim10h Right, it's okay in this case because it's passing the actual log of normalized probabilities. But in a general neural network we refer to the output of the network as logits, which could be the log of normalized probabilities plus an arbitrary offset. This standalone example is correct, but it could lead to errors for people who carelessly copy-paste it.
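As a hypothetical illustration of that remark (not part of the original thread): softmax, and therefore the Gumbel-Softmax sampling distribution, is invariant to adding a constant to the logits, so unnormalized logits define the same categorical distribution as the underlying log-probabilities; problems only arise when such logits are treated as normalized log-probabilities elsewhere.

import torch
import torch.nn.functional as F

log_probs = torch.log(torch.tensor([0.1, 0.4, 0.3, 0.2]))
logits = log_probs + 5.0  # same distribution, shifted by an arbitrary offset

# Both yield the same probabilities, so Gumbel-Softmax samples are unaffected.
print(F.softmax(log_probs, dim=-1))  # ~[0.1, 0.4, 0.3, 0.2]
print(F.softmax(logits, dim=-1))     # identical

# But exp(logits) no longer sums to 1, so treating logits as log-probabilities
# (e.g. exponentiating them directly) would be wrong.
print(torch.exp(logits).sum())       # far from 1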