[Paper][RL][ToDo] Mutual Information State Intrinsic Control Review

2022. 5. 19. 21:20 · Topics of Interest / Paper

https://arxiv.org/abs/2103.08107

 


 

 

This paper tries to remove the dependence on well-shaped rewards by defining an intrinsic reward function.

Motivated by the concept of self-consciousness in psychology, the authors make the natural assumption that the agent knows what constitutes itself, and propose a new intrinsic objective that encourages the agent to have maximum control over it.

They mathematically formalize this reward as the mutual information between the agent state and the surrounding state under the current agent policy.
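
Jotting down the objective as I understand it (the notation here is mine, transcribed from the abstract, so it may not match the paper exactly): the intrinsic reward is the mutual information between the surrounding state and the agent state induced by the current policy,

\[
r_{\text{MUSIC}} \;\propto\; I(S^{\text{surr}};\, S^{\text{agent}})
  \;=\; \mathbb{E}_{p^{\pi}(s^{\text{surr}},\, s^{\text{agent}})}\!\left[
        \log \frac{p^{\pi}(s^{\text{surr}},\, s^{\text{agent}})}{p^{\pi}(s^{\text{surr}})\, p^{\pi}(s^{\text{agent}})}
      \right]
\]

This is intractable to compute directly, which is why the code below estimates it with a neural estimator (the state_mi network further down looks like a MINE / Donsker-Varadhan style lower bound).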

 

https://www.youtube.com/watch?v=AUCwc9RThpk&feature=youtu.be&ab_channel=RuiZhao 

Code

The code doesn't fully click for me yet, so I'm just collecting the key snippets below and will come back to it later...

# generate roll outs
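# (this looks like an excerpt from the rollout worker's generate_rollouts() loop, assuming the
#  Baselines HER structure; u holds the batch of actions already sampled from the policy for this
#  step, and obs/zs/acts/goals/achieved_goals/info_values are accumulators filled in the full loop)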
o_new = np.empty((self.rollout_batch_size, self.dims['o']))
ag_new = np.empty((self.rollout_batch_size, self.dims['g']))
success = np.zeros(self.rollout_batch_size)
# compute new states and observations
for i in range(self.rollout_batch_size):
    try:
        # We fully ignore the reward here because it will have to be re-computed
        curr_o_new, _, _, info = self.envs[i].step(u[i])
        if 'is_success' in info:
            success[i] = info['is_success']
        o_new[i] = curr_o_new['observation']
        ag_new[i] = curr_o_new['achieved_goal']
        for idx, key in enumerate(self.info_keys):
            info_values[idx][t, i] = info[key]
    except MujocoException:
        # on a simulator error, throw this batch away and regenerate the rollouts
        # (needs `from mujoco_py import MujocoException`, as in the Baselines HER rollout worker)
        return self.generate_rollouts()

# collect the episode; z presumably holds the skill/latent variable used by the skill discriminator
episode = dict(o=obs,
               z=zs,
               u=acts,
               g=goals,
               ag=achieved_goals,)
for key, value in zip(self.info_keys, info_values):
    episode['info_{}'.format(key)] = value

 


def _grads(self):
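    # note: the second fetch is self.main.Q_pi_tf, so what comes back as "actor_loss" is really
    # the policy's Q-value (same naming convention as the Baselines HER DDPG code)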
    critic_loss, actor_loss, Q_grad, pi_grad, neg_logp_pi, e_w = self.sess.run([
        self.Q_loss_tf,
        self.main.Q_pi_tf,
        self.Q_grad_tf,
        self.pi_grad_tf,
        self.main.neg_logp_pi_tf,
        self.e_w_tf,
    ])
    return critic_loss, actor_loss, Q_grad, pi_grad, neg_logp_pi, e_w

def train(self, t, stage=True):
    if not self.buffer.current_size==0:
        if stage:
            self.stage_batch(ir=True, t=t)
        critic_loss, actor_loss, Q_grad, pi_grad, neg_logp_pi, e_w = self._grads()
        self._update(Q_grad, pi_grad)
        # keep a history of the clipped entropy-regularization reward term for logging
        et_r = np.clip(self.et_r_scale * neg_logp_pi, -1, 0) * e_w
        self.et_r_history.extend(et_r.tolist())
        return critic_loss, actor_loss
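
The et_r_history line above is easier to see with concrete numbers. A minimal standalone sketch (the values are made up, not from the repo):

import numpy as np

# made-up example values: neg_logp_pi is -log pi(a|s) from the (SAC-style) policy,
# e_w is the per-sample weight of the entropy term
et_r_scale = 0.1
neg_logp_pi = np.array([-4.0, -0.5, 2.0])
e_w = np.array([1.0, 1.0, 0.0])

# same computation as in train(): scale, clip to [-1, 0], then weight
et_r = np.clip(et_r_scale * neg_logp_pi, -1, 0) * e_w
print(et_r)  # -> [-0.4, -0.05, 0.0]  (only negative values, i.e. high-probability actions, survive the clip)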

 

 
# intrinsic reward (ir) network for mutual information
with tf.variable_scope('ir') as vs:
    if reuse:
        vs.reuse_variables()
    self.main_ir = self.create_discriminator(batch_tf, net_type='ir', **self.__dict__)
    vs.reuse_variables()

# loss functions
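# (my reading) two separate MpiAdam optimizers are built below: one for the MI-estimator
# variables under 'ir/state_mi' and one for the skill discriminator under 'ir/skill_ds';
# the gradients are flattened so they can be averaged across MPI workers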

mi_grads_tf = tf.gradients(tf.reduce_mean(self.main_ir.mi_tf), self._vars('ir/state_mi'))
assert len(self._vars('ir/state_mi')) == len(mi_grads_tf)
self.mi_grads_vars_tf = zip(mi_grads_tf, self._vars('ir/state_mi'))
self.mi_grad_tf = flatten_grads(grads=mi_grads_tf, var_list=self._vars('ir/state_mi'))
self.mi_adam = MpiAdam(self._vars('ir/state_mi'), scale_grad_by_procs=False)

sk_grads_tf = tf.gradients(tf.reduce_mean(self.main_ir.sk_tf), self._vars('ir/skill_ds'))
assert len(self._vars('ir/skill_ds')) == len(sk_grads_tf)
self.sk_grads_vars_tf = zip(sk_grads_tf, self._vars('ir/skill_ds'))
self.sk_grad_tf = flatten_grads(grads=sk_grads_tf, var_list=self._vars('ir/skill_ds'))
self.sk_adam = MpiAdam(self._vars('ir/skill_ds'), scale_grad_by_procs=False)

target_Q_pi_tf = self.target.Q_pi_tf
clip_range = (-self.clip_return, self.clip_return if self.clip_pos_returns else np.inf)

self.e_w_tf = batch_tf['e_w']

if not self.sac:
    self.main.neg_logp_pi_tf = tf.zeros(1)

# TD target (my reading): extrinsic reward + three clipped intrinsic terms + discounted bootstrap
target_tf = tf.clip_by_value(
    self.r_scale * batch_tf['r'] * batch_tf['r_w']                                        # extrinsic (task) reward
    + (tf.clip_by_value(self.mi_r_scale * batch_tf['m'], 0, 1)
       - (1 if not self.mi_r_scale == 0 else 0)) * batch_tf['m_w']                         # mutual-information reward
    + tf.clip_by_value(self.sk_r_scale * batch_tf['s'], -1, 0) * batch_tf['s_w']           # skill-discriminator reward
    + tf.clip_by_value(self.et_r_scale * self.main.neg_logp_pi_tf, -1, 0) * self.e_w_tf    # entropy (SAC) term
    + self.gamma * target_Q_pi_tf,
    *clip_range)

self.td_error_tf = tf.stop_gradient(target_tf) - self.main.Q_tf
self.errors_tf = tf.square(self.td_error_tf)
self.errors_tf = tf.reduce_mean(batch_tf['w'] * self.errors_tf)
self.Q_loss_tf = tf.reduce_mean(self.errors_tf)

self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
assert len(self._vars('main/Q')) == len(Q_grads_tf)
assert len(self._vars('main/pi')) == len(pi_grads_tf)
self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
self.Q_grad_tf = flatten_grads(grads=Q_grads_tf, var_list=self._vars('main/Q'))
self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=self._vars('main/pi'))
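
Transcribing the target construction above into math, just as my own reading of the code (the λ's are the *_scale hyperparameters and the w's are the per-term weights coming from batch_tf):

\[
y \;=\; \operatorname{clip}\Big( \lambda_r\, r\, w_r
  \;+\; \big(\operatorname{clip}(\lambda_{mi}\, m,\, 0, 1) - \mathbf{1}[\lambda_{mi} \neq 0]\big)\, w_m
  \;+\; \operatorname{clip}(\lambda_{sk}\, s,\, -1, 0)\, w_s
  \;+\; \operatorname{clip}\big(\lambda_{et}\,(-\log \pi(a \mid s)),\, -1, 0\big)\, w_e
  \;+\; \gamma\, Q'(s', \pi'(s')),\ \text{clip\_range} \Big)
\]

\[
L_Q \;=\; \mathbb{E}\big[\, w\, (y - Q(s,a))^2 \,\big], \qquad
L_\pi \;=\; -\,\mathbb{E}\big[\, Q(s, \pi(s)) \,\big]
        \;+\; c_{\text{action\_l2}}\, \mathbb{E}\big[\, \lVert \pi(s) / u_{\max} \rVert^2 \,\big]
\]

So the extrinsic reward and the three intrinsic terms (MI, skill discriminator, entropy) just get summed into a single TD target, each with its own scale, clip range and weight.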


class Discriminator:
    @store_args
    def __init__(self, inputs_tf, dimo, dimz, dimg, dimu, max_u, o_stats, g_stats, hidden, layers, env_name, **kwargs):
        """The discriminator network and related training code.
        Args:
            inputs_tf (dict of tensors): all necessary inputs for the network: the
                observation (o), the skill/latent variable (z), the goal (g), and the
                action (u)
            dimo (int): the dimension of the observations
            dimz (int): the dimension of the skill/latent variable z
            dimg (int): the dimension of the goals
            dimu (int): the dimension of the actions
            max_u (float): the maximum magnitude of actions; action outputs will be scaled
                accordingly
            o_stats (baselines.her.Normalizer): normalizer for observations
            g_stats (baselines.her.Normalizer): normalizer for goals
            hidden (int): number of hidden units that should be used in hidden layers
            layers (int): number of hidden layers
        """

        self.o_tf = tf.placeholder(tf.float32, shape=(None, self.dimo))
        self.z_tf = tf.placeholder(tf.float32, shape=(None, self.dimz))
        self.g_tf = tf.placeholder(tf.float32, shape=(None, self.dimg))
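        # note: self.o_tau_tf used below (a batch of observation *trajectories* for the MI
        # estimator) is assumed to be set up elsewhere in the class, e.g. as a 3-D placeholder
        # of shape (batch, time, dimo); it is not shown in this excerpt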

        obs_tau_excludes_goal, obs_tau_achieved_goal = split_observation_tf(self.env_name, self.o_tau_tf)

        obs_excludes_goal, obs_achieved_goal = split_observation_tf(self.env_name, self.o_tf)

        # Discriminator networks

        with tf.variable_scope('state_mi'):
            # Mutual Information Neural Estimation
            # shuffle and concatenate
            x_in = obs_tau_excludes_goal
            y_in = obs_tau_achieved_goal
            y_in_tran = tf.transpose(y_in, perm=[1, 0, 2])
            y_shuffle_tran = tf.random_shuffle(y_in_tran)
            y_shuffle = tf.transpose(y_shuffle_tran, perm=[1, 0, 2])
            x_conc = tf.concat([x_in, x_in], axis=-2)
            y_conc = tf.concat([y_in, y_shuffle], axis=-2)

            # propagate the forward pass
            layerx = tf_layers.linear(x_conc, int(self.hidden/2))
            layery = tf_layers.linear(y_conc, int(self.hidden/2))
            layer2 = tf.nn.relu(layerx + layery)
            output = tf_layers.linear(layer2, 1)
            output = tf.nn.tanh(output)
            
            # split in T_xy and T_x_y predictions
            N_samples = tf.shape(x_in)[-2]
            T_xy = output[:,:N_samples,:]
            T_x_y = output[:,N_samples:,:]
            
            # compute the negative loss (maximise loss == minimise -loss)
            mean_exp_T_x_y = tf.reduce_mean(tf.math.exp(T_x_y), axis=-2)
            neg_loss = -(tf.reduce_mean(T_xy, axis=-2) - tf.math.log(mean_exp_T_x_y))
            neg_loss = tf.check_numerics(neg_loss, 'check_numerics caught bad neg_loss')
            self.mi_tf = neg_loss

        with tf.variable_scope('skill_ds'):
            self.logits_tf = nn(obs_achieved_goal, [int(self.hidden/2)] * self.layers + [self.dimz])
            self.sk_tf = tf.nn.softmax_cross_entropy_with_logits(labels=self.z_tf, logits=self.logits_tf)
            self.sk_r_tf = -1 * self.sk_tf
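
For my own reference: the state_mi scope appears to implement the Donsker-Varadhan (MINE-style) lower bound on mutual information, with the shuffled y playing the role of samples from the product of marginals,

\[
I(X; Y) \;\ge\; \mathbb{E}_{p(x,y)}\big[ T_\theta(x, y) \big]
        \;-\; \log \mathbb{E}_{p(x)\,p(y)}\big[ e^{T_\theta(x, y)} \big]
\]

mi_tf holds the negative of this bound, so minimizing it with mi_adam tightens the MI estimate. The skill_ds scope looks like a DIAYN-style skill discriminator: sk_tf is the cross-entropy, i.e. roughly -log q(z | s_agent), and sk_r_tf = -sk_tf is then used as the skill reward.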

 

 

 

 

https://github.com/ruizhaogit/music

 

GitHub - ruizhaogit/music: Mutual Information State Intrinsic Control (ICLR 2021 Spotlight)

 
