RL) DuelingDQN 알아보기

2021. 6. 3. 20:39관심있는 주제/RL

비교 장표

기존의 Double DQN이라는 것이 있었는데, 이 논문에선 advantage function을 도입하여 성능을 더 향상한 논문이라 할 수 있습니다.

가장 큰 특징

Dueling DQN 알고리즘은 Q-VALUE을 값 함수 V (s)와 이점 함수 A (s, a)의 두 부분으로 나눕니다.

여기서는 advatange function과 value function이 동시에 존재한다해서 dueling이라고 한다고 합니다.

 

직관적으로, dueling 아키텍처는 각 state에 대한 각 작업의 효과를 학습할 필요 없이 어떤 상태가 가치 있는지(또는 가치가 없는지를) 학습할 수 있습니다.

이는 특히 해당 action이 environment에 관련되는 어떤 방식으로도 영향을 미치지 않는 상태에서 유용합니다.

배경

논문의 예시는 다음과 같습니다.

크게 위와 아래 그림을 나눠보면 위는 장애물이 앞에 없는 상황 아래는 장애물이 있는 상황입니다.

해당 게임의 목적은 목적지를 향해가는 것이고, 그때 value는 결국 목적지에 다가가는 것이 중요하기 때문에 지평선 근처를 집중하게 됩니다.  

advantage에서는 장애물이 없는 상황에서는 전방에 대해서 크게 신경을 쓰지 않습니다. 즉 value라는 것만 보게 되고, 어떤 행동을 할지에 따라 영향이 안 받기 때문입니다. 

그러나 아래 그림을 보게 되면, Advantage 관점에서 앞에 장애물을 신경쓰게 되는 것을 알 수 있습니다. 

즉 현재 베이스라인과 비교하였을 때 해당 장애물이 영향을 받는 다라는 것을 학습을 한 것이죠.

그래서 두 번째 단계(가장 오른쪽 이미지 쌍)에서 장점 스트림은 바로 앞에 차가 있기 때문에 주의를 기울이며, 그 동작 선택은 매우 관련성이 있게 됩니다.

 

실험에서, 우리는 무관한 행동이나 유사한 행동이 학습 문제에 추가됨에 따라 dueling 아키텍처가 정책 평가 중에 올바른 동작을 더 빨리 식별할 수 있음을 보여줍니다. 그래서 이러한 부분을 신경 쓰기 위해서 advantage와 value function을 나눠서 estimator 하고자 하는 것 같습니다. 

 

시행 방법

 

VALUE FUNCTION V(s)는 state s 로부터 모아진 얼마나 많은 보상을 받을지 알려줍니다.

Advantage Function A(s,a)는 다른 액션에 비해 해당 액션이 얼마나 더 나은지를 말해줍니다.

 

Value와 Advantage를 결합하는 것이 본 논문에서 제안하는 Q 함수가 됩니다. 

 

Dueling DQN에서 제안하고자 하는 것은 같은 네트워크를 통과해서 마지막 단에서 stat value function과 advantage function을 estimate 하는 것입니다. 

그리고 이것을 결합하여 하나의 q-value를 만듭니다. 

 

이렇게 할 경우 때떄로 행동에 대해서 정확한 값을 알 필요가 없이, state value function을 배우는 것으로도 충분할 수 있습니다. 

 

이때 이슈가 생기게 된다고 합니다.

 

단순히 V와 A를 더했을 때 Q가 얼마나 어떤 V와 A에게 영향을 받았는지 알 수 없습니다.(unidentifiable)

 

unidentifiable 해결 방법은 다음과 같습니다.

 

일단 단순히 value function과 advantage function을 합쳐서 훈련시키는 것은 불가능합니다.

그리고 Q=V+A에서는 함수 Q가 주어지면 식별할 수 없는 V 및 A의 값을 확인할 수 없습니다

만약 Q가 20이라 한다면 V+A가 20이라는 것인데, 그럴 때 V와 A의 경우의 수는 거의 무한에 가깝게 나올 수 있는 것이죠

이 문제를 해결하기 위해서, 논문에서는 다음과 같은 트릭을 제안합니다.

optimal action $a^*$를 택했을 때, 가장 높은 Q-VALUE(Q(s,$a^*$)을 V(s)와 같게 합니다.

그러므로 advantage function의 가장 높은 값을 0으로 만들고 다른 모든 값을 음으로 만듭니다.

이렇게 하면 V의 값을 정확하게 알 수 있으며, V의 장점을 모두 계산하여 문제를 해결할 수 있습니다. 

 

아래와 같이 구성하여 훈련을 할 수 있습니다.

식별할 수 없는 문제를 풀기 위한 트릭 적용 수식

 

그러나 해당 논문에서는 이 과정에서 조금 변화를 줬다고 합니다. 

max를 계산하는 대신에, 평균으로 대체하였습니다. 

식별할 수 없는 문제를 풀기 위한 트릭 적용 수식을 max에서 mean으로 변경한 수식

 

위의 주장한 것처럼 하게 되면 V와 A의 의미를 가지게 되지만, 평균으로 대체하는 경우에는 의미를 잃어버리게 됩니다.

그렇지만 이러한 경우 상수로 고정해 놓은 목표가 아니기 때문에, 최적화의 안정성이 증가하게 되죠. 

하지만 실제로 2개를 같이 실험해보면, max로 하는 방식도 값이 계속 변경되기 때문에 잘 학습이 된다고 하고, 2개는 유사한 결과가 나온다고 합니다.

 

 

네트워크는 이런 느낌으로 구성됩니다. 

네트워크에서 다른 것들 큰 차이를 보면 mean operator와 max operator를 해주는 것이 큰 차이라고 할 수 있다.

class QNetwork(nn.Module):
    def __init__(self):
        super(QNetwork, self).__init__()

        self.fc1 = nn.Linear(4, 64)
        self.relu = nn.ReLU()
        self.fc_value = nn.Linear(64, 256)
        self.fc_adv = nn.Linear(64, 256)

        self.value = nn.Linear(256, 1)
        self.adv = nn.Linear(256, 2)

    def forward(self, state):
        y = self.relu(self.fc1(state))
        value = self.relu(self.fc_value(y))
        adv = self.relu(self.fc_adv(y))

        value = self.value(value)
        adv = self.adv(adv)
		# dueling dqn #
        ## mean operator
        advAverage = torch.mean(adv, dim=1, keepdim=True)
        Q = value + adv - advAverage
        ## max operator
        #advMax = torch.max(adv, dim=1, keepdim=True)
        # Q = value + adv - advMax
        return Q
        
    def select_action(self, state):
        with torch.no_grad():
            Q = self.forward(state)
            action_index = torch.argmax(Q, dim=1)
        return action_index.item()

학습 코드 (일부)

loss를 계산하는 것에서는 dqn을 하든 dueling dqn을 하든 동일한 형태를 나타낸다. 

 

for epoch in count():

    state = env.reset()
    episode_reward = 0
    for time_steps in range(200):
        p = random.random()
        if p < epsilon:
            action = random.randint(0, 1)
        else:
            tensor_state = torch.FloatTensor(state).unsqueeze(0).to(device)
            action = onlineQNetwork.select_action(tensor_state)
        next_state, reward, done, _ = env.step(action)
        episode_reward += reward
        ## append trajectory
        memory_replay.add((state, next_state, action, reward, done))
        if memory_replay.size() > 128:
          batch = memory_replay.sample(BATCH, False)
          batch_state, batch_next_state, batch_action, batch_reward, batch_done = zip(*batch)
          ## get batch 
          batch_state = torch.FloatTensor(batch_state).to(device)
          batch_next_state = torch.FloatTensor(batch_next_state).to(device)
          batch_action = torch.FloatTensor(batch_action).unsqueeze(1).to(device)
          batch_reward = torch.FloatTensor(batch_reward).unsqueeze(1).to(device)
          batch_done = torch.FloatTensor(batch_done).unsqueeze(1).to(device)
	      ## update 
          with torch.no_grad():
              onlineQ_next = onlineQNetwork(batch_next_state)
              targetQ_next = targetQNetwork(batch_next_state)
              online_max_action = torch.argmax(onlineQ_next, dim=1, keepdim=True)
              y = batch_reward + (1 - batch_done) * GAMMA * targetQ_next.gather(1, online_max_action.long())
          loss = F.mse_loss(onlineQNetwork(batch_state).gather(1, batch_action.long()), y)
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

 

 

 

Reference

paper : https://arxiv.org/abs/1511.06581 

https://www.programmersought.com/article/66274662959/

https://markelsanz14.medium.com/introduction-to-reinforcement-learning-part-4-double-dqn-and-dueling-dqn-b349c9a61ea1 

https://taek-l.tistory.com/37 

https://github.com/qfettes/DeepRL-Tutorials (알고리즘별로 되어 있음(참고하기 좋아 보임))

728x90