2022. 3. 26. 17:03ㆍ분석 Python/Pytorch
jax 예시로 나온 것처럼 소규모 네트워크를 빠르게 학습시키는 방법에 대해 공유합니다.
소규모 네트워크를 훈련하는 경우 병렬화의 근본적인 한계에 부딪힙니다. 확실히 2계층 MLP는 ResNet-50보다 훨씬 빠르게 실행됩니다. 그러나 ResNet에는 약 4B의 곱셈 누산 연산이 있는 반면 MLP에는 100K만 있습니다.1 우리가 원하는 대로 MLP는 ResNet보다 40,000배 더 빠르게 훈련하지 않으며 GPU 사용률을 검사하면 그 이유를 알 수 있습니다. . GPU의 ~100%를 사용하는 ResNet과 달리 MLP는 2-3%만 사용할 수 있습니다.
더 많은 컴퓨팅을 병렬로 사용하는 한 가지 방법은 배치 크기를 늘리는 것입니다. 예를 들어 128개 요소의 배치를 사용하는 대신 GPU를 채울 때까지 이를 올릴 수 있습니다. 사실, 전체 데이터 세트를 하나의 배치로 사용하고 모든 요소에 대해 병렬화하는 것이 어떻습니까!
그리고 GPU 사용량을 위해 큰 배치를 해도 실제로 효율성이 떨어지기 때문에 그렇기 좋지 않은 방법이라고 합니다.
최소한 exotic한 트릭 없이는 현재 하드웨어에서 소규모 네트워크를 훨씬 빠르게 훈련할 수 없습니다. 데이터 로딩이 빠르도록 주의를 기울이면 2계층 MLP를 ResNet-50보다 400배 더 빠르게 훈련할 수 있습니다. 그리고 그것은 꽤 빠릅니다! 그러나 최소한 ResNet이 40,000배 더 많은 컴퓨팅을 사용한다는 우리 야구장의 추정에 따르면 여전히 100배의 추가 개선이 남아 있습니다.
-> 아무튼 작은 네트워크에서도 GPU를 잘 쓰는 방법에 대해서 소개하는 것 같다.
이번에는 타이타닉 데이터를 전처리부터 torchfunc을 이용하여 모델을 여러개 만들고 한번에 훈련시키는 것을 해보려고 한다.
아직은 초반이다보니, 아쉬운 부분이 많지만, 충분히 매력적인 것 같아서 주기적으로 라이브러리 업데이트 상태를 확인하면 좋을 것 같다
Implementation
Library Load
from sklearn.preprocessing import OneHotEncoder , StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.compose import ColumnTransformer
import pandas as pd
from sklearn.utils._estimator_html_repr import *
from IPython.core.display import display, HTML
import numpy as np
def onehot_numpy(arr_1d) :
num = np.unique(arr_1d, axis=0)
num = num.shape[0]
encoding = np.eye(num)[arr_1d]
return encoding
Data Load
Read Data
titanic = pd.read_csv("https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv")
Remove Columns
titanic = titanic.drop(['PassengerId','Name','Ticket'],axis=1)
Split X , Y
y = titanic.pop("Survived")
x = titanic
Preprocessing
categori_cols = ['Pclass','Sex','SibSp','Parch','Cabin','Embarked']
numeric_cols = ["Age"]
numeric_features = numeric_cols
numeric_transformer = Pipeline(steps=[
('imputer', SimpleImputer(strategy='median')),
('scaler', StandardScaler())])
categorical_features = categori_cols
categorical_transformer = Pipeline(steps=[
('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
('onehot', OneHotEncoder(handle_unknown='ignore'))])
preprocessor = ColumnTransformer(
transformers=[
('num', numeric_transformer, numeric_features),
('cat', categorical_transformer, categorical_features)])
clf = Pipeline(steps=[('preprocessor', preprocessor),])
clf.fit(x)
html = estimator_html_repr(clf)
display(HTML(html))
result = clf.transform(x)
x_arr = result.toarray()
y_arr = y.values
y_arr = onehot_numpy(y_arr)
x_arr.shape , y_arr.shape
# ((891, 172), (891, 2))
Modeling
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import make_functional, grad_and_value, vmap, combine_state_for_ensemble
class MLPClassifier(nn.Module):
def __init__(self, hidden_dim=32, n_classes=2,seed=1234):
super().__init__()
torch.manual_seed(seed)
self.hidden_dim = hidden_dim
self.n_classes = n_classes
self.fc1 = nn.Linear(172, self.hidden_dim)
self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
return x
DEVICE = "cuda:1"
func_model, weights = make_functional(MLPClassifier().to(DEVICE))
loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
def train_step_fn(weights, batch, targets, lr=0.2):
def compute_loss(weights, batch, targets):
output = func_model(weights, batch)
loss = loss_fn(output, targets)
return loss
grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)
new_weights = []
with torch.no_grad():
for grad_weight, weight in zip(grad_weights, weights):
new_weights.append(weight - grad_weight * lr)
return loss, new_weights
def init_fn(num_models ):
models = [MLPClassifier(hidden_dim=30 , seed=np.random.randint(1,1000)).to(DEVICE) for _ in range(num_models)]
_, params, _ = combine_state_for_ensemble(models)
return params
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
class TabularDataSet(Dataset) :
def __init__(self, X , Y) :
self._X = np.float32(X)
self._Y = Y
def __len__(self,) :
return len(self._Y)
def __getitem__(self,idx) :
return self._X[idx], self._Y[idx]
dataset = TabularDataSet(x_arr ,y_arr)
batch_size = 32
train_dl = DataLoader(dataset, batch_size=batch_size,shuffle=True)
Train
parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
batched_weights = init_fn(num_models=5)
loss_result = []
import matplotlib.pyplot as plt
from IPython import display
for i in range(2000):
loss_collection =[]
for x , y in train_dl :
x = x.to(DEVICE)
y = y.to(DEVICE)
loss, batched_weights = parallel_train_step_fn(batched_weights, x, y)
loss_collection.append(loss)
print(",".join([f"{i:.3f}" for i in loss.detach().cpu().numpy()]),end="\r")
else :
losses = torch.stack(loss_collection).mean(axis=0).detach().cpu().numpy()
loss_result.append(losses)
if i % 10 == 0 :
display.clear_output(wait=True)
plt.plot(loss_result)
plt.show()
다만 아쉬운 점은 여기서 각각의 모델에 대한 베스트 모델이 다를 것 같은데, 아직은 naive한 버전밖에 없는 상황이다.
그래서 각 모델을 뽑기 위한 유틸 함수가 추가로 필요해보인다.
for i in batched_weights :
print(i.size())
torch.Size([5, 30, 172])
torch.Size([5, 30])
torch.Size([5, 2, 30])
torch.Size([5, 2])
Prediction
def prediction_step_fn(weights, batch ):
output = func_model(weights, batch)
return output
parallel_prediction_step_fn = vmap(prediction_step_fn, in_dims=(0, None))
x_tensor = torch.FloatTensor(x_arr).to(DEVICE)
print(x_tensor.shape)
# torch.Size([891, 172])
result = parallel_prediction_step_fn(batched_weights , x_tensor )
print(result.shape)
#torch.Size([5, 891, 2])
idx = 0
one_row_prediction = result[:,idx,:]
one_result = one_row_prediction.detach().cpu().numpy()
num_models = one_result.shape[0]
pd.DataFrame(one_result , index = [f"model_{i}" for i in range(num_models)])
https://github.com/pytorch/functorch/blob/main/examples/ensembling/parallel_train.py
http://willwhitney.com/parallel-training-jax.html
'분석 Python > Pytorch' 카테고리의 다른 글
Pytorch) 모델 가중치 로드 시 테스트 (전체 모델에서 서브 모델 가중치만 가져오기) (0) | 2023.09.15 |
---|---|
Pytorch) multioutput Regression 구현해보기 (4) | 2022.03.26 |
Pytorch 1.11 이후) functorch 알아보기 (0) | 2022.03.14 |
Pytorch 1.11 이후) torchdata 알아보기 (0) | 2022.03.14 |
pytorch 1.8.0 램 메모리 누수 현상 발견 및 해결 (0) | 2021.12.18 |