RL) Double DQN 알아보기

2021. 5. 9. 11:13관심있는 주제/RL

728x90

일단 double dqn을 보기에 앞서 기존의 q-learning과 deep q-learning을 간략하게 그림으로 보여주고 시작하겠다.

간략히 나온 배경에 대해 말하자면, 기존의 있던 deep q learning 같은 경우 action value에 대한 overestimate가 문제였다. 그래서 이 double dqn은 이러한 문제를 해결하기 위해 나온 것이다. 

 

간략하게 핵심만 살펴보자.

 

  • Double Q-Learning이 무엇인지?
  • Double Q-Learning 알고리즘
  • Double Deep Q Network(Double DQN)
  • 구현

Double Q-Learning이 무엇인지?

 

double q learning은 [1] H. van Hasselt 2010 이 기존 q-learning에서 action value의 가한 추정 문제를 해결하기 위해서 제안했다.

 

간단히 말해서, 과대평가의 문제는 에이전트가 최대 Q- 값을 가지고 있기 때문에 주어진 상태에서 항상 최적이 아닌 동작을 선택한다는 것입니다. 

 

기본 q-learning에서는 에이전트의 최적 정책이 주어진 상태 내에서 최고의 액션을 선택하는 것이다.

이 아이디어의 가정은 최대 기대 혹은 추정된 q-value가 최고의 액션이라는 것이 있기 때문이다.

그러나 에이전트는 초기에 환경에 대해서 아무것도 모르기 때문에, 각 iteration에서 그들을 업데이트하거나 처음 q(s, a)를 추정할 필요가 있다.

이러한 q-value는 잡음을 가지게 되고, 개발자는 최대 기대 / 예상 Q- 값을 가진 행동이 정말 최고의 행동인지 확신할 수 없게 된다.

 

불행하게도, 최선의 행동은 대부분의 경우 최적이 아닌 행동에 비해 Q- 값이 더 작습니다.

기존 q-learning에서 최적 정책에 따라서, 에이전트는 오로지 최대 q-value에 의해 주어진 상태에서 최적이 아닌 행동을 하는 경향이 있게 됩니다. 

이러한 문제를 action value(q value)의 과추정이라고 합니다.

 

이러한 문제가 발생할 때, 추정된 q-value로부터 잡음은 업데이터 과정에서 큰 편향을 유발하게 되고, 결과적으로 학습과정은 복잡해지고 지저분해질 것입니다.

 

 

아래를 보게 되면 $q_{best}$ 같은 경우 noisy가 있는 현재의 $Q(s, a)$를 사용하여 추정되게 된다.

다른 잡음으로 인해서 best와 current의 $Q(s, a)$간의 차이는 지저분해지고, 편향이 생기게 된다.

 

 

그런데 모든 Q- value의 noisy가 uniform 분포를 가지고 있다면, 보다 구체적으로 Q- 값이 똑같이 과대평가되면 이러한 잡음이 Q (s ', a) 간의 차이에 영향을 미치지 않기 때문에 과대평가는 문제가 되지 않습니다. 및 Q (s, a). 자세한 내용은 [1] H. van Hasselt 2010, 섹션 2에 있습니다.

 

샘플에 대해 최대 값을 취하여 여러 랜덤 변수의 기대치 최대 값을 추정하려고 할 때마다 이러한 현상이 발생하게 된다.

예를 들어, 그림에서 세 가지 선택 중 가장 좋은 값의 추정치가 모두 동일한 제로 평균 분포에서 샘플링되었지만 양의 편향이 있음을 보여준다. 직관적으로 3개 대신에 무한히 많은 샘플링을 한다면, 분포의 오른쪽 꼬리로 수렴할 수 있다.  

 

마찬가지로 강화 학습에서 action-value에 대한 적절한 함수 정의(Q)는 maximum expected total reward이다.

일반 q-learning은 전체 보상의 실제 분포가 아닌 maximum total reward의 분포를 사용하여 추정하게 되는데, 이때 잘못된 과추정을 하게 되는 것이라는 설명도 있다.

 

즉 간단히 정리하면 원인으로는 다음과 같다.

  • Insufficiently flexible function approximation
  • Noise or Stochasticity (in rewards and/or environmen

Double Q-Learning 알고리즘

 

그래서 기존의 과추정되는 문제를 풀기 위해서 해당 논문에서는 $Q$와 $Q'$라는 2개의 다른 action value function을 사용한다.

Q와 Q' 이 둘 다 잡음이 있다 할지라도, 그들의 잡음은 uniform distribution으로 봐서 action value의 과추정 문제를 해결해줄 수 있다고 한다. proof is in [1] H. van Hasselt 2010, Section 3.

 

 

Q 함수는 다음 상태의 최대 Q 값을 가진 최상의 동작을 선택하기 위한 것입니다.

 

Q' 함수는 위에서 선택된 행동 a를 사용하여 기댓값 Q-value를 계산하기 위한 것입니다.

 

 

Q' 함수의 기댓값 Q-value를 사용하여 Q 함수를 업데이트합니다.

 

Completed Pseudocode

 

[1]  H. van Hasselt. Double Q-learning. NIPS, 2010 .

그래서 2개의 estimator를 이용해서 랜덤 하게 각자를 업데이트하는 방식으로 진행할 때 2개의 차이를 이용해서 최대한 잡음을 빼버리고 업데이트하는 전략을 취하는 것 같습니다.

 

그러나 위에서 제안된 방식은 표 형식 및 행렬 방식이므로 큰 스테이트를 처리할 수 없게 되고, 이제 그래서 Deep Neural Network가 필요해집니다.

 

 


Double Deep Q Network (Double DQN)

Double Q-Learning implementation with Deep Neural Network is called Double Deep Q Network (Double DQN).

 

Double DQN은 [2] H. van Hasselt, 2016에 의해서 제안되었습니다. 

Dobule Q-Learning에서 영감을 받아서 Double DQN은 2개의 다른 네트워크(DQN과 Target Network)를 사용합니다. 

 

 

Deep Q Network의 매개 변수를 업데이트하는 최적화 단계에서 사용되므로 Q 값을 업데이트할 때 학습률 α가 없습니다.

 

 

  • Deep Q Network
    • 다음 상태의 최대 Q- 값으로 최상의 행동의 선택합니다.

 

 

  • Target Network
    • 선택된 행동 한에서 추정된 q-value를 계산합니다.

  • Target로부터 추정된 q-value를 기반으로 DQN을 업데이트합니다.

  • 규칙적으로 DQN의 파라미터를 기반으로 Target Network 파라미터를 업데이트합니다.

 

 

여기까지 봤을 때 사실 기존의 DQN과는 큰 차이는 없다.

기존의 DQN에서 replay memory에다가 데이터를 쌓아두고, 시작하는 것은 동일하다.

여기서는 Qnet를 통해서 일단 메모리 버퍼에 쌓아 놓는다.

그래서 아래에서 dqn이냐 아니냐에 따라서, q net의 최댓값을 쓸 것인지, 아니면 target net의 최댓값을 쓰는지 나뉘게 된다.

이러한 점을 빼고는 나머지 코드와 거의 유사하다는 것을 알 수 있다. 

    if self.args['double_dqn']:
        _, next_state_actions = self.Qnet(non_final_next_states).max(1, keepdim=True)
        V_next_state[non_final_mask] = self.target_Qnet(non_final_next_states).gather(1, next_state_actions)
    else:
        V_next_state[non_final_mask] = self.target_Qnet(non_final_next_states).max(1)[0]

 

def replay(self, batch_size):
    if len(self.memory) < batch_size:
        return
    batch = self.memory.sample(batch_size)
    state_batch = Variable(batch.state / 255)
    action_batch = Variable(batch.action)
    reward_batch = Variable(batch.reward)
    non_final = LongTensor([i for i, done in enumerate(batch.done) if not done])
    non_final_mask = 1 - batch.done
    # To prevent backprop through the target action values, set volatile=False (also sets requires_grad=False)
    non_final_next_states = Variable(batch.next_state.index_select(0, non_final) / 255, volatile=True)

    # Compute Q(s_t, a)
    Q_state_action = self.Qnet(state_batch).gather(1, action_batch)

    # Double DQN - Compute V(s_{t+1}) for all next states.
    V_next_state = Variable(torch.zeros(batch_size).type(Tensor))
    if self.args['double_dqn']:
        _, next_state_actions = self.Qnet(non_final_next_states).max(1, keepdim=True)
        V_next_state[non_final_mask] = self.target_Qnet(non_final_next_states).gather(1, next_state_actions)
    else:
        V_next_state[non_final_mask] = self.target_Qnet(non_final_next_states).max(1)[0]

    # Remove Volatile as it sets all variables computed from them volatile.
    # The Variable will just have requires_grad=False.
    V_next_state.volatile = False

    # Compute the target Q values
    target_Q_state_action = reward_batch + (self.gamma * V_next_state)

    # Compute loss
    loss = F.smooth_l1_loss(Q_state_action, target_Q_state_action)
    # td_error = target_Q_state_action - Q_state_action
    # clipped_error = td_error.clamp(-1, 1)

    # Optimize the model
    self.optimizer.zero_grad()
    loss.backward()
    for param in self.Qnet.parameters():
        param.grad.data.clamp_(-1, 1)
    self.optimizer.step()

깔끔하게 핵심 부분만 정리한 코드도 있어서 공유한다.

import torch
from torch import nn


def select_greedy_actions(states: torch.Tensor, q_network: nn.Module) -> torch.Tensor:
    """Select the greedy action for the current state given some Q-network."""
    _, actions = q_network(states).max(dim=1, keepdim=True)
    return actions


def evaluate_selected_actions(states: torch.Tensor,
                              actions: torch.Tensor,
                              rewards: torch.Tensor,
                              dones: torch.Tensor,
                              gamma: float,
                              q_network: nn.Module) -> torch.Tensor:
    """Compute the Q-values by evaluating the actions given the current states and Q-network."""
    next_q_values = q_network(states).gather(dim=1, index=actions)        
    q_values = rewards + (gamma * next_q_values * (1 - dones))
    return q_values


def q_learning_update(states: torch.Tensor,
                      rewards: torch.Tensor,
                      dones: torch.Tensor,
                      gamma: float,
                      q_network: nn.Module) -> torch.Tensor:
    """Q-Learning update with explicitly decoupled action selection and evaluation steps."""
    actions = select_greedy_actions(states, q_network)
    q_values = evaluate_selected_actions(states, actions, rewards, dones, gamma, q_network)
    return q_values

def double_q_learning_update(states: torch.Tensor,
                             rewards: torch.Tensor,
                             dones: torch.Tensor,
                             gamma: float,
                             q_network_1: nn.Module,
                             q_network_2: nn.Module) -> torch.Tensor:
    """Double Q-Learning uses Q-network 1 to select actions and Q-network 2 to evaluate the selected actions."""
    actions = select_greedy_actions(states, q_network_1)
    q_values = evaluate_selected_actions(states, actions, rewards, dones, gamma, q_network_2)
    return q_values

 

후기 

 

이러한 작은 차이로도 큰 성능개선을 할 수 있다는 것이 핵심을 잘 파악했다는 생각이 들었던 것 같고, q-value에 대해서 더 잘 알게 된 것 같다.

 

 

 

arxiv.org/pdf/1509.06461.pdf

medium.com/@qempsil0914/deep-q-learning-part2-double-deep-q-network-double-dqn-b8fc9212bbb2

 

Deep Q-Learning, Part2: Double Deep Q Network, (Double DQN)

An introduction and implementation tutorial with python3 and Tensorflow

medium.com

medium.com/@ameetsd97/deep-double-q-learning-why-you-should-use-it-bedf660d5295

 

Deep Double Q-Learning — Why you should use it

It is an exception rather than the norm to not use Neural Networks (ANNs) for training Reinforcement Learning (RL)Agents because of the…

medium.com

www.quora.com/Why-does-standard-Q-learning-tend-to-overestimate-q-values

 

Why does standard Q-learning tend to overestimate q-values?

Answer (1 of 2): Q(s, a) = r + gamma * maxQ(s', a') over all actions Since Q values are very noisy, when you take the max over all actions, you're probably getting an overestimated value. Think like this, the expected value of a dice roll is 3.5, but if yo

www.quora.com

medium.com/@qempsil0914/deep-q-learning-part2-double-deep-q-network-double-dqn-b8fc9212bbb2

 

Deep Q-Learning, Part2: Double Deep Q Network, (Double DQN)

An introduction and implementation tutorial with python3 and Tensorflow

medium.com

github.com/Shivanshu-Gupta/Pytorch-Double-DQN/blob/master/agent.py

 

Shivanshu-Gupta/Pytorch-Double-DQN

Pytorch Implementation of Double DQN Algorithm. Contribute to Shivanshu-Gupta/Pytorch-Double-DQN development by creating an account on GitHub.

github.com

davidrpugh.github.io/stochastic-expatriate-descent/pytorch/deep-reinforcement-learning/deep-q-networks/2020/04/11/double-dqn.html

 

Improving the DQN algorithm using Double Q-Learning

Notes on improving the DQN algorithm using Double Q-learning.

davidrpugh.github.io

 

728x90