2021. 3. 17. 20:19ㆍ분석 Python/Pytorch
우리가 자주 쓰는 방식은 위와 같은 방식이지만, 실제로 우리가 각각에 대해서 특정 인덱스를 뽑고자 하는 경우가 있다.
최근에 동료 연구원이 이것에 대해서 질문을 하였을 때, 즉각적으로 생각이 안 나서, 시간을 소비하였고, 유용하면서도 헷갈리는 것 같아 정리해보려고 한다.
torch.gather(input, dim, index, out=None, sparse_grad=False)
→ TensorGathers values along an axis specified by dim.
위와 같이 특정 인덱스를 뽑으려고 하면 처음 접근 방식은 loop를 생각하지만, torch에서는 gather 함수를 제공하여 쉽게 indexing을 할 수 있다.
그리고 loop 방식은 차원이 커질 수록 일반화된 방식으로 처리하기가 난해하지만, gather 함수를 쓰면 2차원이든, 3차원이든지 간에 일관된 방법으로 사용 가능하다.
예제 코드를 쓰면 다음과 같다.
어떤 2차원 matrix가 있고 우리가 각 행에서 특정 인덱스만 뽑아내고 싶은 경우의 케이스를 만들어봤다.
뒤에서 말하겠지만, indices의 차원의 수와 matrix의 차원의 수를 맞춰줘야 하는 것에 유념해야 한다.
import torch
matrix = torch.range(0,99).reshape(10,10)
indices = [0,1,2,3,4,5,6,7,8,9]
indices = torch.tensor(indices).unsqueeze(axis=-1)
print(matrix)
print(indices)
print(torch.gather(matrix, 1, indices ))
만약 왼쪽 그림처럼 특정 빨간색 부분만 뽑고 싶다면?
gather 함수를 사용하기 위해서는 크게 3가지 파라미터를 정의해야 한다.
- input — input tensor
- dim — dimension along to collect values
- index — tensor with indices of values to collect
여기서 중요한 것은 input과 index의 dimension이 동일해야 한다.
만약 인풋이 4x10x15이고 , dim =0 이면 index는 Nx10x15
즉 dim인 부분을 제외하고 나머지 차원은 동일해야 한다.
이제 3D에서 해보자.
만약 우리가 RNN을 한다고 했을 때, 우리가 하고 싶은 것은 각 seq에서 마지막 요소만 모은다고 하자. (hidden state의 모든 feature에서)
인풋 데이터는 다음의 tensor를 따른다고 하자. (batchsize x max seq len x hdden_state) ( 8 x 9 x 6 )
batch_size = 8
max_seq_len = 9
hidden_size = 6
x = torch.empty(batch_size, max_seq_len, hidden_size)
for i in range(batch_size):
for j in range(max_seq_len):
for k in range(hidden_size):
x[i,j,k] = i + j*10 + k*100
숫자의 의미는 다음과 같다 123이면 1 번째 배치 2번째 seq 3번째 hidden state라는 의미다.
x[:,4,:]
>tensor([[ 40., 140., 240., 340., 440., 540.],
[ 41., 141., 241., 341., 441., 541.],
[ 42., 142., 242., 342., 442., 542.],
[ 43., 143., 243., 343., 443., 543.],
[ 44., 144., 244., 344., 444., 544.],
[ 45., 145., 245., 345., 445., 545.],
[ 46., 146., 246., 346., 446., 546.],
[ 47., 147., 247., 347., 447., 547.]])
우리는 여기서 각 seq에 마지막 element를 알고 있다. (패딩을 제외함을 의미함)
lens = torch.LongTensor([5,6,1,8,3,7,3,4])
인풋의 차원은 8x9x6(batch size x max seq len x hidden state)
여기서 우리가 원하는 seq의 마지막 차원을 뽑고 싶으니까 index의 shape는 (8 x 1 x 6)이 되는 것이다.
그러면 총 우리는 일단 index를 하기 위해서 8 x 1 x 6 = 48개를 채워야 한다.
핵심은 우리가 6 개의 은닉 상태를 가지고 있고 그것들을 모두 수집하고 싶다는 것을 이해하는 것입니다. – 우리가 42 개의 가치를 기대한다는 것은 분명하다! (8 개의 예에서 6 개의 숨겨진 상태). 해결책은 매우 간단하다.
len를 6번 반복하면 됩니다.
lens = torch.LongTensor([5,6,1,8,3,7,3,4])
# add one trailing dimension
lens = lens.unsqueeze(-1)
print(lens.shape)
아래 결과를 보면 각 seq len가 5 ,6 ,1, 8 ,3 ,7 , 3, 4를 뽑은 것을 알 수 있다.
만약 우리가 3차원에 그림에서 적용된 상태를 보면 다음과 같다.
만약 seq len을 2,2,2,4,4,4,6,7로 하면 아래와 같은 형태로 뽑히는 것이다. 이
형태를 만들어주기 위해서 반복되는 수가 나오는 것이다.
[0,2,0] , [0,2,1]... [0,2,6]...
이러한 포맷은 4차원이든 3차원이든 2차원이든 다 적용될 것이다!
참고)
medium.com/analytics-vidhya/understanding-indexing-with-pytorch-gather-33717a84ebc4
'분석 Python > Pytorch' 카테고리의 다른 글
pytorch 1.8.0 램 메모리 누수 현상 발견 및 해결 (0) | 2021.12.18 |
---|---|
pytorch) dataloader sampler (4) | 2021.04.19 |
PyGAD + Pytorch + Skorch+ torch jit (0) | 2021.01.30 |
[Pytorch] Error : Leaf variable has been moved into the graph interior 해결 방법 공유 (0) | 2021.01.16 |
[Pytorch] How to Apply the Weight Initialization (Code) (0) | 2020.12.17 |