Pytorch 1.11 이후) functorch 알아보기
2022. 3. 14. 20:21ㆍ분석 Python/Pytorch
22년 3월 14일 기준 현재까지는 Beta 버전입니다.
Google JAX에서 크게 영감을 받은 functorch는 구성 가능한 함수 변환을 PyTorch에 추가하는 라이브러리입니다.
PyTorch 모듈 및 PyTorch autograd와 함께 작동하는 구성 가능한 vmap(벡터화) 및 autodiff 변환을 우수한 eager-mode 성능으로 제공하는 것을 목표로 합니다.
구성 가능한 함수 변환은 오늘날 PyTorch에서 수행하기 어려운 여러 사용 사례에 도움이 될 수 있습니다.
- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians as well as batched ones
여기서 가장 실용적이라 생각한 것은 ensemble 할 때 기존의 loop로 했던 방식을 개선할 수 있을 것 같다는 생각이 들었습니다.
Implementation
라이브러리 로드
import argparse
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import make_functional, grad_and_value, vmap, combine_state_for_ensemble
데이터 생성기
def make_spirals(n_samples, noise_std=0., rotations=1.):
ts = torch.linspace(0, 1, n_samples, device=DEVICE)
rs = ts ** 0.5
thetas = rs * rotations * 2 * math.pi
signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1
labels = (signs > 0).to(torch.long).to(DEVICE)
xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std
ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std
points = torch.stack([xs, ys], dim=1)
return points, labels
DEVICE = "cpu"
points, labels = make_spirals(100, noise_std=0.05)
네트워크 정의
class MLPClassifier(nn.Module):
def __init__(self, hidden_dim=32, n_classes=2):
super().__init__()
self.hidden_dim = hidden_dim
self.n_classes = n_classes
self.fc1 = nn.Linear(2, self.hidden_dim)
self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.log_softmax(x, -1)
return x
loss_fn = nn.NLLLoss()
모델을 functional로 만들고, 훈련 함수를 정의
func_model, weights = make_functional(MLPClassifier().to(DEVICE))
def train_step_fn(weights, batch, targets, lr=0.2):
def compute_loss(weights, batch, targets):
output = func_model(weights, batch)
loss = loss_fn(output, targets)
return loss
grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)
# NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
# so we are going to re-implement SGD here.
new_weights = []
with torch.no_grad():
for grad_weight, weight in zip(grad_weights, weights):
new_weights.append(weight - grad_weight * lr)
return loss, new_weights
def step4():
global weights
for i in range(2000):
loss, weights = train_step_fn(weights, points, labels)
if i % 100 == 0:
print(loss)
step4()
여러 개의 모델을 준비하면, 파라미터로 모든 가중치를 반환
def init_fn(num_models):
models = [MLPClassifier().to(DEVICE) for _ in range(num_models)]
_, params, _ = combine_state_for_ensemble(models)
return
모델 병렬 학습
아래 같은 코드를 사용하면 병렬 학습이 가능하다고 합니다.
def step6():
parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
batched_weights = init_fn(num_models=2)
for i in range(2000):
loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels)
if i % 200 == 0:
print(loss)
step6()
개인적으로 jax를 공부할까 말까 고민을 많이 했는데, pytorch에서도 functorch라는 것을 이번에 제시해서 좋았고,
기존에 loop로 하는 것을 좀 더 빠르게 하는 기능들이 나와서 속도 개선에 도움이 많이 될 것 같다는 생각이 들었습니다.
Reference
https://github.com/pytorch/functorch/blob/main/examples/ensembling/parallel_train.py http://willwhitney.com/parallel-training-jax.html
728x90
'분석 Python > Pytorch' 카테고리의 다른 글
Pytorch) multioutput Regression 구현해보기 (4) | 2022.03.26 |
---|---|
torchfunc) titanic data에 model parallel training 해보기 (0) | 2022.03.26 |
Pytorch 1.11 이후) torchdata 알아보기 (0) | 2022.03.14 |
pytorch 1.8.0 램 메모리 누수 현상 발견 및 해결 (0) | 2021.12.18 |
pytorch) dataloader sampler (4) | 2021.04.19 |