[RL] PPO 학습 중에 nan 나오는 특이한 경우
2022. 5. 12. 23:50ㆍ관심있는 주제/RL
강화 학습 학습 도중에 에러가 나는 경우를 공유한다.
바로 액션 공간이 좀 큰 상황에서 특정 값이 너무 작게 나오는 데 선택하는 경우이다.
아래처럼 예시를 만들면 다음과 같다.
특정 logit들은 엄청 크게 나오지만, 한 개의 logit은 엄청 작게 나오는 데 그것을 선택하는 경우 문제가 발생한다.
바로 아래와 같은 경우이다.
logit = torch.randint(low=1000000,high=2000000,size=(1,128*128*2))
logit[0,0] = 1e-4567
dist = Categorical(logits=logit)
log_prob = dist.log_prob(torch.tensor([0]))
## tensor([-1999934.])
PPO에서는 아시다시피 old log prob와 current log prob을 빼주고 exponential을 하는 부분이 있다.
바로 여기서 exp에서 문제가 발생한다.
바로 torch.exp 에서는 최대 88까지만 가능하기 때문이다 (현재 기준)
>>> torch.exp(torch.tensor(88))
tensor(1.6516e+38)
>>> torch.exp(torch.tensor(89))
tensor(inf)
그러므로 ratio를 구할 때 2개를 빼고 exp를 구하는데, 이때 log prob 자체 값이 저런 식으로 튀게 큰 값이 나오게 나오면
infinity로 계산하게 되서 학습이 안되게 된다.
(current log prob - old log prob).sum(1).exp()
액션을 저렇게 선택했다는 것부터 이상하긴 하지만, 큰 공간에서는 그럴 수 있다고 생각한다.
그러므로 나는 다음과 같은 트릭을 사용했다.
Solution
ratio = ( current_log_probs - old_log_probs).sum(1, keepdim=True)
ratio = ratio.clamp_(max=88).exp()
다른 분들은 삽질을 덜하시고 찾아내시길 ㅠ
728x90
'관심있는 주제 > RL' 카테고리의 다른 글
Paper) Heuristic Algorithm-based Action Masking Reinforcement Learning (HAAM-RL) with Ensemble Inference Method 읽어보기 (0) | 2024.06.22 |
---|---|
논문 리뷰) [TODO] Online Decision Transformer (0) | 2022.05.25 |
진행중) Reverb: a framework for experience replay 알아보기 (0) | 2021.10.07 |
RL) MARL 자료 모음 (2) | 2021.09.25 |
Paper) Neural Combinatorial Optimization with Reinforcement Learning - Not Finished... (0) | 2021.09.14 |