[Review] CURL: Contrastive Unsupervised Representations for Reinforcement Learning

2021. 2. 13. 23:20관심있는 주제/RL

728x90

논문 리뷰)

CURL : Contrastive Unsupervised Representations for Reinforcement Learning.

 

Abstract

CURL이라는 것은 강화학습에서 Contrastive를 사용한 비지도 방법을 의미한다.

CURL은 constrastive learning을 사용하여 원래의 픽셀로부터 고차원의 피처를 뽑아내고, 추출된 피처로 off-policy control를 수행한다.

CURL 은 복잡한 테스크(DeepMind Control Suite and Atari Games)에서 기존의 pixel based를 사용한 방법론(model-based , model-free)을 뛰어넘은 성능을 보였다고 한다. (1.9, 1.2배)


Introduction

  • 강화학습은 고차원의 데이터를 사용하는데, 이 부분에서 샘플 비효율성이 발생함. 
  • 랜더링 하거나 데이터 모으는데 많은 시간이 걸리므로 샘플 효율성을 향상하는 것은 중요함.
  • 그래서 연구가 2개의 흐름으로 있다고 함.
    • (i) Auxiliary tasks on the agent’s sensory observations;
    • (ii) World models that predict the future. While the former class of methods use auxiliary self-supervision tasks to accelerate the learning progress of model-free RL method
  • CURL은 이 중에서 1번째에 속함. 
    • 가정
If an agent learns a useful semantic representation from high dimensional observations, control algorithms built on top of those representations should be significantly more data-efficient.
(에이전트가 고차원 데이터로부터 유의미한 표현을 학습하면, 그 기반으로 알고리즘이 데이터 효율성을 가질 것이다)

최근에 masked language modeling이나 contrastive learning과 같이 self-supervised representation learning이 발전하고 있다.

하지만 저자가 강화학습에서 이러한 방법을 접근하려면 2가지 차이가 있다고 한다.

  • 에이전트가 interaction으로부터 데이터를 모으기 전에는 unlabeled dataset이 없다 
  • 에이전트는 하위 테스크를 위해서 학습된 모델을 사용하는 파인 튜닝과는 반대로 비지도와 강화학습을 동시에 해야 한다. 

그래서 저자는 이 차이들로 인해, 어떻게 contrastive learning을 사용할 때, online interaction에서 control 효율성과 데이터 효율성을 학습시킬 수 있을까를 고민했다고 한다.

 

이러한 고민을 통해 낸 논문이 CURL이다.

CURL은 contrastive learning의 한 형태를 사용합니다.
각 관측치는 시간적으로 순차적인 프레임의 스택인 동일한 관측치의 증강(augmented)된 버전 간의 일치를 최대화합니다.

이럼으로써, CURL이 기존 pixel based 방법보다 샘플 효율성을 향상했다고 한다. 

 

key contributions

  • simple framework that integrates contrastive learning with model-free RL with minimal changes to the architecture and training pipeline
  • a contrastive objective is the preferred self-supervised auxiliary task for achieving sample-efficiency compared to reconstruction based methods, and enables model-free methods to outperform state-of-the-art model-based methods in terms of data-efficiency.

 

 


CURL

그림1. Contrastive Unsupervised Representations for Reinforcement Learning (CURL) combines instance contrastive learning and reinforcement learning

CURL은 visual representation encoder를 훈련시킨다. 

data-augmentated versions($o$의 $o_q$ 와 $o_k$ )의 임베딩을 얻기 위해서 contrastive loss를 사용한다.

$o_k$가 positive와 negative를 포함할 동안, $o_q$는 anchor로써 다루어진다. 

여기서 positive와 negative 그리고 anchor라는 말이 나오는 데 이 부분은 contrastive learning이다.
positive은 anchor의 과거 정보/ 나머지(이미지)는 negative를 의미함. 
그래서 anchor는 positive와는 가깝게 하면서, negative와는 멀어지게 하는 consine 유사도를 사용한다고 함.

key들은 query encode의 모멘텀 평균 버전과 함께 인코딩 되고, 강화 학습의 policy와 value function은 query encoder로 구성된다. 

그래서 강화학습의 목적함수와 contrstive를 같이 훈련시킨다. 

CURL은 고차원 이미지로부터 representation learnin에 아무런 강화학습 알고리즘을 합칠 수 있는 일반화된 프레임워크라고 한다. 

 

 

  • query observation : $o_q$
  • key observation : $o_k$

Related Work

  • Self-Supervised Learning
  • Contrastive Learning
  • Self-Supervised Learning for RL
  • World Models for sample-efficiency
  • Sample-efficient RL for image-based control

Background

CURL is a general framework for combining contrastive learning with RL. In principle, one could use any RL algorithm in the CURL pipeline, be it on-policy or off-policy.

그림 2. CURL Architecture

replay buffer로부터 batch transition을 샘플링하고, observations을 query와 , key로부터 2배만큼 data augmentation을 한다. 이때 encoder는 개별적으로 query encoder와  key encoder를 사용한다. 

queries는 RL 알고리즘에 사용되고, query-key pairs들은 모두 contrastive learning 에 사용한다.

gradient update 동안에, query만 update가 됩니다. key encoder의 weight는 MoCo와 유사하게 query weights의 moving average를 이용한다.(EMA)

그림2. CURL Architecture


Contrastive Learning

이 논문의 핵심 요소중에 하나는 고차원 데이터를 contrastive unsupervised learning을 사용하여, 풍부한 representation을 학습할 수 있는 능력에 있다.

필자도 이번에 보는 거라 잘 알지는 못하지만, 직관적으로 이해한 것은 embedding 값을 얻고 싶은데, 얻는 방법을 비지도 방법으로 하고 싶은 것이고, 비지도 방법 중에서 사용한 방법이 contrastive learning 방법이고, 이 방법은 유사한 것은 유사하게 하고, 먼 것은 더 멀게 하는 마치 word2 vec과 같은 방법을 의미한다.

그래서 실제로 학습시에 레이블링이 없이 이미지 데이터를 바탕으로 임베딩 값을 잘 학습하게 하는 방법론이다.

http://dmqm.korea.ac.kr/activity/seminar/308
MoCo

 


Query-Key Pair Generation

 

저자는 observation의 시간적 구조에 대한 정보를 유지하기 위해 배치 전체에 걸쳐 무작위 증가를 적용하지만 각 프레임 스택에 일관되게 적용합니다.

 

 

Experiments

Evaluation

  • sample-efficiency
  • performance (100k 500k)
    • ratio = $\frac{CURL의 episode retrruns}{T steps의 basline에서 최고일 때 episode return}$

Sample Efficiency

 

performance 

Conclusion

CURL은 지금까지 학습 세계 모델 및 (또는) 디코더 기반 목표에 의존하는 접근 방식이 지배하는 복잡한 작업에 대한 최첨단 성능을 보여주기 위해 최소한의 아키텍처 변경으로 대조 학습으로 가속화된 최초의 model-free RL 파이프이라고 한다.


실제로 이러한 학습 방법은 simulator와 real의 간격을 줄여줄 수 있다고도 생각이 든다.

왜냐하면 실제 pixel을 그대로 활용하다보면, 둘 사이에는 반드시 간격이 생길 거라 생각하고, 이 사이를 줄여줄 수 있는 어떤 것이 필요한 것 같은데, 위의 방식처럼 잘 인코딩을 할 수 있는 방법과 결합한다면, 빠르게 더 좋은 성능을 낼 수 있을까도 기대를 한다.

물론 이 부분에서 궁금한 것은 data augmentation을 할 때 어떻게 state 정보를 포함해서 도메인에서 유의미한 feature를 생성할 수 있는지가 궁금하다. 결국 임의로 뭔가 정보를 담고 있는 데이터를 생성시키는 것 같은데, 도메인마다 다를 수 있다고 생각이 든다(domain by domain)


Reference

github.com/MishaLaskin/curl

 

MishaLaskin/curl

CURL: Contrastive Unsupervised Representation Learning for Sample-Efficient Reinforcement Learning - MishaLaskin/curl

github.com

 

contrastive learning 참고

dmqm.korea.ac.kr/activity/seminar/308

 

고려대학교 DMQA 연구실

고려대학교 산업경영공학부 데이터마이닝 및 품질애널리틱스 연구실

dmqa.korea.ac.kr

lilianweng.github.io/lil-log/2019/11/10/self-supervised-learning.html

 

Self-Supervised Representation Learning

Self-supervised learning opens up a huge opportunity for better utilizing unlabelled data, while learning in a supervised learning manner. This post covers many interesting ideas of self-supervised learning tasks on images, videos, and control problems.

lilianweng.github.io

 

 
728x90