Using Deep Q-Learning in the Classification of an Imbalanced Dataset - 리뷰

2020. 1. 7. 21:17관심있는 주제/RL

728x90

불균형 문제는 머신러닝을 사용할 때 직면하는 흔한 문제이다.
이 문제를 해결하기 위해 알고리즘의 수준 또는 데이터 수준에 관계없이 여러 가지 접근법이 사용되어 왔다.
알고리즘 수준에서는 class weight를 조정해 얼마 없는 class에 대해서 더 많은 가중치를 주는 cost function을 바꾸는 방법이 있다.
반면에 데이터 수준에서는 재 표본 기법이 있다. 얼마 없는 class에 대해서 upsampling을 하거나 많이 있는 class에 대해서 downsampling을 한다. 
해당 글에서는 딥 Q-러닝 뒤에 있는 개념이 어떻게 불균형한 데이터 세트의 문제를 해결하기 위해 활용될 수 있는지 볼 것이다.


Dataset:

DCIS(도관암)이라는 질병을 의학적 스캔으로 검출하는 데이터셋을 사용했다

세포나 계층의 출현 사례는 건강한 조직에 속하는 다른 계층의 출현보다 현저히 낮다고 한다.
이러한 점은 머신 러닝을 학습할 때 종종 겪는 문제이다. 
그리고 얼마 없는 class를 찾는 것은 어려워지고, 상당한 cost가 발생한다.


Formalize the Process:

Deep Q Network 이론 아래 URL 참고 

https://medium.com/@jonathan_hui/rl-dqn-deep-q-network-e207751f7ae4

 

만약 DQN을 분류 문제로 고려한다면, 추측 하는 게임으로 볼 수 있다.
양의 Reward는 만약 추축이 맞을 때 받게 돼, 틀리면, 음의 Reward를 받는 형태로 볼 수 있다.

이런 학습 프로세스를 걸쳐서, agent는 누적 보상을 최대화하여 전체 정확하게 분류된 샘플을 최대화하기 위한 최적의 정책을 배울 것이다.
DQN는 Markov Decision Process를 따르기 때문에, agent와 환경 사이의 상호작용 과정이 순차적인 proces로 변형되어야 한다.
그러나 만약 서로가 관련 없는 이미지를 분류하는 문제에 쓴다고하면, 이건 흔한 접근은 아니다.

제 분균형 이미지 분류를 위해 deep Q Learning의 적용이 얼마나 실제로 잘 작동하는지 알아보자.

  1. State 
    • 환경 안에 있는 state s 는 우리의 경우 이미지 샘플이 training sample이다.
  2. Action
    • agent의 행동 a 는 training sample의 label이다. 간단히 말해서, binary problem으로 다룰 수 있다. 
    • agent는 A = {0,1}에서 한 가지를 선택한다 (0 : majority class  , 1: minority class)
  3. Reward
    • reward r 은 state s를 올바르게 분류하는 성공을 측정하기 위해 환경이 agent에게 되돌려 주는 피드백이다.
    • agent가 불균형 데이터 셋일 때 최적의 분류 정책을 학습하기 위해서 agent에게 input state가 minority class에 속하면 더 높은 reward를 주고 majority class에 속할 때는 적은 reward를 준다.
  4. Discount factor
    • gamma은 [0,1] 이것은 미래 보상에 대한 중요성을 나타낸다. 
    • 그때 이미지 분류에 적용하기 위해서 연속적인 샘플들은 상관성이 없고 각 이미지는 분류가 올바르게 됐는지 확인할 필요가 있다. 그래서 gamma는 낮은 값은 주는 게 나은 선택이다.
  5. Exploration rate 
    • epsilon [0,1] 만약 1로 설정하면, 취하는 행동은 완전히 exploration에 기초하여 한다는 뜻이 된다.
    • 반면에 만약 0으로 하면 agent의 knowledge의 exploitation을 통해 행동한다.
  6. Transition Probability
    • ICMDP에서 p(st+1|st, at) 는 deterministic 하다. 
    • agent는 training data set의 샘플 순서에 따라서 현재 state에서 다음 state로 움직인다. 
  7.  Policy
    • 여기서 policy는 분류기 판단하는 확률을 의미한다.

 


Set the Reward Function

이전에 언급하였듯이, reward system은 불균형 데이터셋에서 minority class에 속하는 값을 올바르게 분류하는 것이 어렵다는 아이디어에서 설정해야 한다. 
그러므로 minority class 분류에서 높은 reward 값이나 오분류시 큰 페널티를 주는 식으로 해야 한다.

보상의 가치를 선택할 때의 주먹구구식 규칙은 minority class 요소에 대한 majority class 요소의 비율을 활용하는 것으로, 이를 ρ라고 부르기로 한다.


Set the Memory:

불균형 데이터 세트의 문제를 해결하기 위한 열쇠는 메모리 크기를 서로 다른 클래스 간에 sub-memory로 동일하게 나누는 것이다.(?)  
이 단계가 수행되면, 각각의 sub-memory는 거의 나타나지 않기 때문에 다수층에 속하는 샘플로 모든 소수 계급 샘플을 덮어쓰는 대신 해당 클래스에 의해 추가될 것이다. (파파고 번역)
각 memroy마다 균등하게 class가 유지되어 있다를 말하는 듯

this step is done, each sub-memory would be appended by its corresponding class instead of overwriting all the minority class samples by the ones that belong to the majority class since they rarely appear.

이 방법은 agent를 훈련하기 위해서 메모리로부터 랜덤 샘플들을 뽑을 때 샘플들이 training set의 다른 클래스들 사이에서 균형 있게 뽑히는 것을 보장한다고 한다.
(어떻게 하위 메모리를 만드는 것인지? oversample을 하라는 것인지?)
(논문에서는 sub memory에 대한 언급이 없다)


Episode Termination:

에피소드 e는 agent가 minority class로부터의 샘플들의 특정 숫자 x를 넘을 시 오분류될 시 끝난다.
또는 training set의 모든 샘플을 훈련하게 되면 끝난다.


Set the DQN:

Value Function은 환경에서 특정 state로부터 누적된 기대 보상이다. 
그리고 그것은 agent가 수행하기 위한 행동을 선택하는 것에 달려있다.

Q-Value Function은 일반 value function 보다 더 복잡한 단계이다. 
왜냐하면 Q-Value Function은 state와 action 둘 다 고려하기 때문이다.
이것은 agent가 수행하기 위한 액션을 선택에 달려있는 policy에 의존한다.

최적의 Q-Value Function은 Q∗(s, a) = Q(s, a, θ)로 명명되어 있다.
Q∗에 근사하기 위해서 deep Q-Learning 방법이 도입되었다. 그리고 새로운 텀 θ이 또 고려된다.

 Q∗에 근사하는 데 사용되는 깊은 신경망의 모든 가중치를 나타낸다.

각 time step에서 agent는 DQN을 훈련하기 위해서 사용된 이벤트들의 몇 개를 기억한다(replay memory? 말하는 듯)
이 이벤트들은 주어진 time step과 reward에서 state와 action의 combination이다.

DQN 아키텍처는 데이터의 복잡성에 의존한다. 
선형 활성화 기능은 DQN이 주어진 상태와 액션에서 Q 함수 추정기의 역할을 수행한 후 현재 상태로부터 시작되는 미래 예상 누적 보상을 출력하기 때문에 선택된다.
이것은 Q function은 실제 값이고 그리고 그 값은 softmax 값과 같은 활성홤수 처럼 값이 제한되어있지 않다는 것을 의미한다. loss function은 mse 가 되고 adam을 사용했다.

 


Conclusion

cost function의 weight를 더한다던지 데이터를 oversampling 하는 전통적인 방법 대신에 불균형 이미지 데이터셋의 문제를 극복하는데 활용할 수 있다.
매우 불균형 데이터셋의 경우에서 아주 좋은 결과를 달성했고, 미래에는 다중 분류에서도 강화 학습으로 적용해보려고 한다.

 

728x90