[Pytorch] gather 함수 설명 (특정 인덱스만 추출하기)

2021. 3. 17. 20:19분석 Python/Pytorch

728x90

우리가 자주 쓰는 방식은 위와 같은 방식이지만, 실제로 우리가 각각에 대해서 특정 인덱스를 뽑고자 하는 경우가 있다.

최근에 동료 연구원이 이것에 대해서 질문을 하였을 때, 즉각적으로 생각이 안 나서, 시간을 소비하였고, 유용하면서도 헷갈리는 것 같아 정리해보려고 한다.

torch.gather(input, dim, index, out=None, sparse_grad=False)
→ TensorGathers values along an axis specified by dim.

 

 

위와 같이 특정 인덱스를 뽑으려고 하면 처음 접근 방식은 loop를 생각하지만, torch에서는 gather 함수를 제공하여 쉽게 indexing을 할 수 있다. 

그리고 loop 방식은 차원이 커질 수록 일반화된 방식으로 처리하기가 난해하지만, gather 함수를 쓰면 2차원이든, 3차원이든지 간에 일관된 방법으로 사용 가능하다.


예제 코드를 쓰면 다음과 같다.

어떤 2차원 matrix가 있고 우리가 각 행에서 특정 인덱스만 뽑아내고 싶은 경우의 케이스를 만들어봤다. 

뒤에서 말하겠지만, indices의 차원의 수와 matrix의 차원의 수를 맞춰줘야 하는 것에 유념해야 한다. 

import torch

matrix = torch.range(0,99).reshape(10,10)
indices = [0,1,2,3,4,5,6,7,8,9]
indices = torch.tensor(indices).unsqueeze(axis=-1)
print(matrix)
print(indices)
print(torch.gather(matrix, 1, indices ))

 

matrix / indices / result

 

만약 왼쪽 그림처럼 특정 빨간색 부분만 뽑고 싶다면?

matrix / indices / result

gather 함수를 사용하기 위해서는 크게 3가지 파라미터를 정의해야 한다.

  • input — input tensor
  • dim — dimension along to collect values
  • index — tensor with indices of values to collect

여기서 중요한 것은 input과 index의 dimension이 동일해야 한다.

만약 인풋이 4x10x15이고  , dim =0 이면 index는 Nx10x15 

즉 dim인 부분을 제외하고 나머지 차원은 동일해야 한다.

 


이제 3D에서 해보자.

만약 우리가 RNN을 한다고 했을 때, 우리가 하고 싶은 것은 각 seq에서 마지막 요소만 모은다고 하자. (hidden state의 모든 feature에서)

 

인풋 데이터는 다음의 tensor를 따른다고 하자. (batchsize x max seq len x hdden_state)  ( 8 x 9 x 6 )

 

batch_size = 8
max_seq_len = 9
hidden_size = 6
x = torch.empty(batch_size, max_seq_len, hidden_size)
for i in range(batch_size):
  for j in range(max_seq_len):
    for k in range(hidden_size):
      x[i,j,k] = i + j*10 + k*100

숫자의 의미는 다음과 같다 123이면 1 번째 배치 2번째 seq 3번째 hidden state라는 의미다.

 

x[:,4,:]
>tensor([[ 40., 140., 240., 340., 440., 540.],
        [ 41., 141., 241., 341., 441., 541.],
        [ 42., 142., 242., 342., 442., 542.],
        [ 43., 143., 243., 343., 443., 543.],
        [ 44., 144., 244., 344., 444., 544.],
        [ 45., 145., 245., 345., 445., 545.],
        [ 46., 146., 246., 346., 446., 546.],
        [ 47., 147., 247., 347., 447., 547.]])

 

우리는 여기서 각 seq에 마지막 element를 알고 있다. (패딩을 제외함을 의미함)

lens = torch.LongTensor([5,6,1,8,3,7,3,4])

인풋의 차원은 8x9x6(batch size x max seq len x hidden state) 

여기서 우리가 원하는 seq의 마지막 차원을 뽑고 싶으니까 index의 shape는 (8 x 1 x 6)이 되는 것이다. 

그러면 총 우리는 일단 index를 하기 위해서 8 x 1 x 6 = 48개를 채워야 한다. 

핵심은 우리가 6 개의 은닉 상태를 가지고 있고 그것들을 모두 수집하고 싶다는 것을 이해하는 것입니다. – 우리가 42 개의 가치를 기대한다는 것은 분명하다! (8 개의 예에서 6 개의 숨겨진 상태). 해결책은 매우 간단하다.

len를 6번 반복하면 됩니다.

lens = torch.LongTensor([5,6,1,8,3,7,3,4])
# add one trailing dimension
lens = lens.unsqueeze(-1)
print(lens.shape)

아래 결과를 보면 각 seq len가 5 ,6 ,1, 8 ,3 ,7 , 3, 4를 뽑은 것을 알 수 있다.

 

만약 우리가 3차원에 그림에서 적용된 상태를 보면 다음과 같다.

만약 seq len을 2,2,2,4,4,4,6,7로 하면 아래와 같은 형태로 뽑히는 것이다. 이

형태를 만들어주기 위해서 반복되는 수가 나오는 것이다. 

[0,2,0] , [0,2,1]... [0,2,6]... 

 

index=2,2,2,4,4,4,6,7 / batch_size = 8. / hidden_size=6

 

이러한 포맷은 4차원이든 3차원이든 2차원이든 다 적용될 것이다!

 

 

참고)

medium.com/analytics-vidhya/understanding-indexing-with-pytorch-gather-33717a84ebc4

 

728x90