[ Python ] gumbel softmax 알아보기

2019. 9. 14. 22:50분석 Python/Tensorflow

728x90

도움이 되셨다면, 광고 한 번만 눌러주세요.  블로그 관리에 큰 힘이 됩니다 :)

 

예전에 gumbel softmax 관련 영상을 보고 관련된 자료도 찾아봤자만, 이해가 안 됐고 당시에 코드도 Tensorflow로 많이 없어서 포기했다가, 최근에 다시 봐야 할 것 같아서 여러 코드 구현물을 찾고 내용을 다시 이해해보려고 한다.

일반적으로 DNN을 학습시키는 것은 모델을 구성하고 Loss를 정의한 다음에 gradient에 따라 점진적인 학습을 한다.

그러나 때 때로 이것은 랜덤 구성요소를 규합시키는 아키텍처에서는 쉽지 않다. forwad pass는 더 이상 인풋과 가중치들의 deterministic function이 아니다. 랜덤 구성요소는 샘플링하는 수단으로 stochasticity를 도입한다.

샘플링한 것을 Backpropagation 할 경우에는 일반적인 방법으로는 미분 가능하지 않게 된다.  그리고 적분도 할 수 없는 꼴이 된다. 그래서 보통 이런 문제를 해결하기 위해 Reparmeterization trick을 쓰긴 한다. 샘플링된 랜덤 변수를 파라미터가 없는 랜덤 변수의 deterministic  매개변수로 대체한다. 간략하게 말한다면, stochastic 부분과 deterministic부분을 분해시켜서 deterministic 한 부분으로 backpropagation을 흐르게 한다 

 

여기서 deterministic 한 부분은 vae에서는 평균과 분산 stochastic 한 부분은 epsilon이라고 생각한다.

여러 유형의 연속 분포에 대해 reparameterization trick을 수행할 수 있다.

 

그러나 만약에 값들의 discrete set에 대한 분포가 필요하다면 어떻게 해야 할까?


In the following sections you'll learn:

  • what the Gumbel distribution is
  • how it is used for sampling from a discrete distribution
  • how the weights that affect the distribution's parameters can be trained
  • how to use all of that in a toy example (with code)

The Gumbel distribution

$\mu, \beta$의 두 파라미터를 가진다. 

표준 Gumbel distribution은 $\mu = 0 , \beta = 1$을 가진다. 

만약 $logit (\alpha_i = i ~ k )$ discrete random variable이 있다고 고려할 때, logits은 학습에 필요한 인풋과 가중치의 함수이다. 

 

Gumbel softmax trick을 사용하여 discrete distribution을 샘플링한다.

이런 식으로 하게 되면 분포를 파라미터가 없는 분포의 deterministic 변형으로 대체할 수 있다.

그래서 모델에 연결하면 gradient는 logits의 가중치에 잘 전파될 수 있다.

hard one hot vector를 사용하는 대신에 softmax를 사용하여 approximate 한다고 한다.

이 과정은 위에 있는 argmax 대신에 softmax 대체해서 사용한다는 것이다. 

τ라는 것으로 나눔으로써, 근사치가 argmax에 얼마나 가깝게 할지 제어할 수 있다. 

τ 가 0에 가까우면 1에 가까워진다 argmax처럼 된다.

τ가 무한에 가까우 uniform처럼 된다. 

 τ를 작게 할수록, 더 좋은 근사치를 얻을 수 있다. 

τ를 작은 값으로 설정하는 것의 문제점은 Gradient 분산이 너무 높다는 것이다. 즉 학습시키기가 어렵다. 

그래서 일단 크게 한 다음 작게 해봐야 한다.

 

Tensorflow에서 gumbel softmax를 쉽게 구현한 것이 있어서 공유한다.

def sample_gumbel(shape, eps=1e-20):
    U = tf.random_uniform(shape, minval=0, maxval=1)
    return -tf.log(-tf.log(U + eps) + eps)

def gumbel_softmax(logits, temperature, hard=False):
    gumbel_softmax_sample = logits + sample_gumbel(tf.shape(logits))
    y = tf.nn.softmax(gumbel_softmax_sample / temperature)

    if hard:
        k = tf.shape(logits)[-1]
        y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)),
                         y.dtype)
        y = tf.stop_gradient(y_hard - y) + y

    return y

 

아래에는 toy example abcde라는 랜덤 하게 있는 한 discrete 분포를 학습해보는 것이다.

일반적인 방법으로는 빈도를 개선할 수 있지만 여기서는 GAN을 사용해서 해보는 것을 하려고 한다.  

그렇게 했을 때 단숨 빈도를 통해서 확률적으로 샘플링하는 것이 아닌 GAN을 통해 생성을 할 시, 빈도가 작은 부분에 대해서 생성을 많이 하게 된다면 Loss를 많이 줌으로써 그쪽 부분에 대한 생성을 줄이게 할 수 있다. 

 

BATCHS_IN_EPOCH = 100
BATCH_SIZE = 10
EPOCHS = 200  # the stream is infinite so one epoch will be defined as BATCHS_IN_EPOCH * BATCH_SIZE
GENERATOR_TRAINING_FACTOR = 10  # for every training of the disctiminator we'll train the generator 10 times
LEARNING_RATE = 0.0007
TEMPERATURE = 0.001 

import random , pandas as pd
tf.reset_default_graph()
data = pd.get_dummies(pd.DataFrame(np.random.choice(list("abcde") , 50, p = [0.1,0.2,0.2,0.2,0.3] , replace = True)))
#data = np.argmax(data.values , axis =1)
x = tf.placeholder(tf.float32 , shape = [None ,5])
z = tf.placeholder(tf.float32 , shape = [None , 5])

sample_z = np.random.normal(size = data.shape)

def generator(z):
    with tf.variable_scope('generator'):
        h = tf.layers.dense(z , 10 )
        logits = tf.layers.dense(h , len(list("abcde") ) )
        gumbel_dist = tf.contrib.distributions.RelaxedOneHotCategorical(TEMPERATURE, 
                                                                        logits=logits)
        probs = tf.nn.softmax(logits)
        generated = gumbel_dist.sample(BATCH_SIZE)
        return generated, probs

def discriminator(x):
    with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
        return tf.contrib.layers.fully_connected(x,
                                                 num_outputs=1,
                                                 activation_fn=None)

generated_outputs, generated_probs = generator(z)
discriminated_real = discriminator(x)
discriminated_generated = discriminator(generated_outputs)

d_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=discriminated_real,
                                            labels=tf.ones_like(discriminated_real)))
d_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=discriminated_generated,
                                            labels=tf.zeros_like(discriminated_generated)))
d_loss = d_loss_real + d_loss_fake

g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=discriminated_generated,
                                            labels=tf.ones_like(discriminated_generated)))

all_vars = tf.trainable_variables()
g_vars = [var for var in all_vars if var.name.startswith('generator')]
d_vars = [var for var in all_vars if var.name.startswith('discriminator')]

d_train_opt = tf.train.AdamOptimizer(LEARNING_RATE).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(LEARNING_RATE).minimize(g_loss, var_list=g_vars)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    learned_probs = []
    for i in range(EPOCHS):
        print(i)
        for _ in range(BATCHS_IN_EPOCH):
            sess.run(d_train_opt , feed_dict = {x : data.values , z : sample_z  })
            for _ in range(GENERATOR_TRAINING_FACTOR):
                sess.run(g_train_opt , feed_dict = {x : data.values , z : sample_z  })
        learned_probs.append(sess.run(generated_probs , feed_dict = {z : np.random.normal(size = (1,data.shape[1])) }))

plt.figure(figsize=(10, 2))
prob_errors = [np.array(learned_prob[0]) - np.array([0.1,0.2,0.2,0.2,0.3])
               for learned_prob in learned_probs]

plt.imshow(np.transpose(prob_errors) ,
           cmap='bwr',
           aspect='auto',
           vmin=-2,
           vmax=2)
plt.xlabel('epoch')
plt.ylabel('number')
plt.colorbar(aspect=10, ticks=[-2, 0, 2]);

 

Counter({2: 9, 4: 11, 3: 11, 1: 13, 0: 6})

다음에는 좀 더 다양한 변수와 함께 있을 때 어떻게 처리하는지를 알아봐야겠다.

 

 

https://hulk89.github.io/machine%20 learning/2017/11/20/reparametrization-trick/

 

Reparametrization Trick 정리 · Hulk의 개인 공부용 블로그

Reparametrization Trick 정리 20 Nov 2017 | ml reparametrization sampling 참고 0: 요걸 기본으로 해서 이 포스트를 작성! 참고 1: Eric jang의 블로그 Motivation deep learning에서, 보통 우리는 $x \sim p_\theta(x)$에서 draw한 sample들을 통해 gradient를 backpropagation을 하게 된다. 물론 여기서 $p_\theta(x)$는 lear

hulk89.github.io

https://www.facebook.com/groups/TensorFlowKR/permalink/612839115723817/

 

보안 확인 필요

메뉴를 열려면 alt + / 키 조합을 누르세요

www.facebook.com

https://www.youtube.com/watch?v=ty3SciyoIyk

 

http://anotherdatum.com/gumbel-gan.html

 

Neural Networks gone wild! They can sample from discrete distributions now!

Learn how to use Gumbel distribution to form a NN containing a discrete random component.

anotherdatum.com

https://blog.evjang.com/2016/11/tutorial-categorical-variational.html

 

Tutorial: Categorical Variational Autoencoders using Gumbel-Softmax

In this post, I discuss our recent paper, Categorical Reparameterization with Gumbel-Softmax , which introduces a simple technique for t...

blog.evjang.com

https://github.com/vithursant/VAE-Gumbel-Softmax/blob/master/vae_gumbel_softmax.py

 

vithursant/VAE-Gumbel-Softmax

An implementation of a Variational-Autoencoder using the Gumbel-Softmax reparametrization trick in TensorFlow (tested on r1.5 CPU and GPU) in ICLR 2017. - vithursant/VAE-Gumbel-Softmax

github.com

728x90