
[Pytorch] gather 함수 설명 (특정 인덱스만 추출하기)
·
분석 Python/Pytorch
우리가 자주 쓰는 방식은 위와 같은 방식이지만, 실제로 우리가 각각에 대해서 특정 인덱스를 뽑고자 하는 경우가 있다. 최근에 동료 연구원이 이것에 대해서 질문을 하였을 때, 즉각적으로 생각이 안 나서, 시간을 소비하였고, 유용하면서도 헷갈리는 것 같아 정리해보려고 한다. torch.gather(input, dim, index, out=None, sparse_grad=False) → TensorGathers values along an axis specified by dim. 위와 같이 특정 인덱스를 뽑으려고 하면 처음 접근 방식은 loop를 생각하지만, torch에서는 gather 함수를 제공하여 쉽게 indexing을 할 수 있다. 그리고 loop 방식은 차원이 커질 수록 일반화된 방식으로 처리하기가 ..