[RL] PPO 학습 중에 nan 나오는 특이한 경우

2022. 5. 12. 23:50관심있는 주제/RL

728x90

강화 학습 학습 도중에 에러가 나는 경우를 공유한다.

바로 액션 공간이 좀 큰 상황에서 특정 값이 너무 작게 나오는 데 선택하는 경우이다.

 

아래처럼 예시를 만들면 다음과 같다.

특정 logit들은 엄청 크게 나오지만, 한 개의 logit은 엄청 작게 나오는 데 그것을 선택하는 경우 문제가 발생한다.

바로 아래와 같은 경우이다.

logit = torch.randint(low=1000000,high=2000000,size=(1,128*128*2))
logit[0,0] = 1e-4567
dist = Categorical(logits=logit)
log_prob = dist.log_prob(torch.tensor([0]))

## tensor([-1999934.])

PPO에서는 아시다시피 old log prob와 current log prob을 빼주고 exponential을 하는 부분이 있다.

바로 여기서 exp에서 문제가 발생한다.

바로 torch.exp 에서는 최대 88까지만 가능하기 때문이다 (현재 기준)

>>> torch.exp(torch.tensor(88))
tensor(1.6516e+38)
>>> torch.exp(torch.tensor(89))
tensor(inf)

그러므로 ratio를 구할 때 2개를 빼고 exp를 구하는데, 이때 log prob 자체 값이 저런 식으로 튀게 큰 값이 나오게 나오면

infinity로 계산하게 되서 학습이 안되게 된다.

(current log prob - old log prob).sum(1).exp()

액션을 저렇게 선택했다는 것부터 이상하긴 하지만, 큰 공간에서는 그럴 수 있다고 생각한다.

그러므로 나는 다음과 같은 트릭을 사용했다.

 

Solution

ratio = ( current_log_probs - old_log_probs).sum(1, keepdim=True)
ratio = ratio.clamp_(max=88).exp()

 

다른 분들은 삽질을 덜하시고 찾아내시길 ㅠ

728x90