[Pytorch] gather 함수 설명 (특정 인덱스만 추출하기)
우리가 자주 쓰는 방식은 위와 같은 방식이지만, 실제로 우리가 각각에 대해서 특정 인덱스를 뽑고자 하는 경우가 있다.
최근에 동료 연구원이 이것에 대해서 질문을 하였을 때, 즉각적으로 생각이 안 나서, 시간을 소비하였고, 유용하면서도 헷갈리는 것 같아 정리해보려고 한다.
torch.gather(input, dim, index, out=None, sparse_grad=False)
→ TensorGathers values along an axis specified by dim.

위와 같이 특정 인덱스를 뽑으려고 하면 처음 접근 방식은 loop를 생각하지만, torch에서는 gather 함수를 제공하여 쉽게 indexing을 할 수 있다.
그리고 loop 방식은 차원이 커질 수록 일반화된 방식으로 처리하기가 난해하지만, gather 함수를 쓰면 2차원이든, 3차원이든지 간에 일관된 방법으로 사용 가능하다.
예제 코드를 쓰면 다음과 같다.
어떤 2차원 matrix가 있고 우리가 각 행에서 특정 인덱스만 뽑아내고 싶은 경우의 케이스를 만들어봤다.
뒤에서 말하겠지만, indices의 차원의 수와 matrix의 차원의 수를 맞춰줘야 하는 것에 유념해야 한다.
import torch
matrix = torch.range(0,99).reshape(10,10)
indices = [0,1,2,3,4,5,6,7,8,9]
indices = torch.tensor(indices).unsqueeze(axis=-1)
print(matrix)
print(indices)
print(torch.gather(matrix, 1, indices ))



만약 왼쪽 그림처럼 특정 빨간색 부분만 뽑고 싶다면?



gather 함수를 사용하기 위해서는 크게 3가지 파라미터를 정의해야 한다.
- input — input tensor
- dim — dimension along to collect values
- index — tensor with indices of values to collect
여기서 중요한 것은 input과 index의 dimension이 동일해야 한다.
만약 인풋이 4x10x15이고 , dim =0 이면 index는 Nx10x15
즉 dim인 부분을 제외하고 나머지 차원은 동일해야 한다.
이제 3D에서 해보자.
만약 우리가 RNN을 한다고 했을 때, 우리가 하고 싶은 것은 각 seq에서 마지막 요소만 모은다고 하자. (hidden state의 모든 feature에서)
인풋 데이터는 다음의 tensor를 따른다고 하자. (batchsize x max seq len x hdden_state) ( 8 x 9 x 6 )
batch_size = 8
max_seq_len = 9
hidden_size = 6
x = torch.empty(batch_size, max_seq_len, hidden_size)
for i in range(batch_size):
for j in range(max_seq_len):
for k in range(hidden_size):
x[i,j,k] = i + j*10 + k*100
숫자의 의미는 다음과 같다 123이면 1 번째 배치 2번째 seq 3번째 hidden state라는 의미다.
x[:,4,:]
>tensor([[ 40., 140., 240., 340., 440., 540.],
[ 41., 141., 241., 341., 441., 541.],
[ 42., 142., 242., 342., 442., 542.],
[ 43., 143., 243., 343., 443., 543.],
[ 44., 144., 244., 344., 444., 544.],
[ 45., 145., 245., 345., 445., 545.],
[ 46., 146., 246., 346., 446., 546.],
[ 47., 147., 247., 347., 447., 547.]])
우리는 여기서 각 seq에 마지막 element를 알고 있다. (패딩을 제외함을 의미함)
lens = torch.LongTensor([5,6,1,8,3,7,3,4])
인풋의 차원은 8x9x6(batch size x max seq len x hidden state)
여기서 우리가 원하는 seq의 마지막 차원을 뽑고 싶으니까 index의 shape는 (8 x 1 x 6)이 되는 것이다.
그러면 총 우리는 일단 index를 하기 위해서 8 x 1 x 6 = 48개를 채워야 한다.
핵심은 우리가 6 개의 은닉 상태를 가지고 있고 그것들을 모두 수집하고 싶다는 것을 이해하는 것입니다. – 우리가 42 개의 가치를 기대한다는 것은 분명하다! (8 개의 예에서 6 개의 숨겨진 상태). 해결책은 매우 간단하다.
len를 6번 반복하면 됩니다.
lens = torch.LongTensor([5,6,1,8,3,7,3,4])
# add one trailing dimension
lens = lens.unsqueeze(-1)
print(lens.shape)
아래 결과를 보면 각 seq len가 5 ,6 ,1, 8 ,3 ,7 , 3, 4를 뽑은 것을 알 수 있다.



만약 우리가 3차원에 그림에서 적용된 상태를 보면 다음과 같다.
만약 seq len을 2,2,2,4,4,4,6,7로 하면 아래와 같은 형태로 뽑히는 것이다. 이
형태를 만들어주기 위해서 반복되는 수가 나오는 것이다.
[0,2,0] , [0,2,1]... [0,2,6]...

이러한 포맷은 4차원이든 3차원이든 2차원이든 다 적용될 것이다!
참고)
medium.com/analytics-vidhya/understanding-indexing-with-pytorch-gather-33717a84ebc4