torchfunc) titanic data에 model parallel training 해보기

2022. 3. 26. 17:03분석 Python/Pytorch

jax 예시로 나온 것처럼 소규모 네트워크를 빠르게 학습시키는 방법에 대해 공유합니다.

 

소규모 네트워크를 훈련하는 경우 병렬화의 근본적인 한계에 부딪힙니다. 확실히 2계층 MLP는 ResNet-50보다 훨씬 빠르게 실행됩니다. 그러나 ResNet에는 약 4B의 곱셈 누산 연산이 있는 반면 MLP에는 100K만 있습니다.1 우리가 원하는 대로 MLP는 ResNet보다 40,000배 더 빠르게 훈련하지 않으며 GPU 사용률을 검사하면 그 이유를 알 수 있습니다. . GPU의 ~100%를 사용하는 ResNet과 달리 MLP는 2-3%만 사용할 수 있습니다.

 

더 많은 컴퓨팅을 병렬로 사용하는 한 가지 방법은 배치 크기를 늘리는 것입니다. 예를 들어 128개 요소의 배치를 사용하는 대신 GPU를 채울 때까지 이를 올릴 수 있습니다. 사실, 전체 데이터 세트를 하나의 배치로 사용하고 모든 요소에 대해 병렬화하는 것이 어떻습니까!

 

그리고 GPU 사용량을 위해 큰 배치를 해도 실제로 효율성이 떨어지기 때문에 그렇기 좋지 않은 방법이라고 합니다.

 

최소한 exotic한 트릭 없이는 현재 하드웨어에서 소규모 네트워크를 훨씬 빠르게 훈련할 수 없습니다. 데이터 로딩이 빠르도록 주의를 기울이면 2계층 MLP를 ResNet-50보다 400배 더 빠르게 훈련할 수 있습니다. 그리고 그것은 꽤 빠릅니다! 그러나 최소한 ResNet이 40,000배 더 많은 컴퓨팅을 사용한다는 우리 야구장의 추정에 따르면 여전히 100배의 추가 개선이 남아 있습니다.

 

-> 아무튼 작은 네트워크에서도 GPU를 잘 쓰는 방법에 대해서 소개하는 것 같다.

 

이번에는 타이타닉 데이터를 전처리부터 torchfunc을 이용하여 모델을 여러개 만들고 한번에 훈련시키는 것을 해보려고 한다. 

아직은 초반이다보니, 아쉬운 부분이 많지만, 충분히 매력적인 것 같아서 주기적으로 라이브러리 업데이트 상태를 확인하면 좋을 것 같다

 

Implementation

Library Load

from sklearn.preprocessing import OneHotEncoder , StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.compose import ColumnTransformer
import pandas as pd

from sklearn.utils._estimator_html_repr import *
from IPython.core.display import display, HTML
import numpy as np
def onehot_numpy(arr_1d) :
    num = np.unique(arr_1d, axis=0)
    num = num.shape[0]
    encoding = np.eye(num)[arr_1d]
    return encoding

 

Data Load

Read Data

titanic = pd.read_csv("https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv")

Remove Columns

titanic = titanic.drop(['PassengerId','Name','Ticket'],axis=1)

Split X , Y

y = titanic.pop("Survived")
x = titanic

Preprocessing

 

categori_cols = ['Pclass','Sex','SibSp','Parch','Cabin','Embarked']
numeric_cols = ["Age"]
numeric_features = numeric_cols 
numeric_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler())])
categorical_features = categori_cols
categorical_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
    ('onehot', OneHotEncoder(handle_unknown='ignore'))])
preprocessor = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, numeric_features),
        ('cat', categorical_transformer, categorical_features)])
clf = Pipeline(steps=[('preprocessor', preprocessor),])
clf.fit(x)
html = estimator_html_repr(clf)
display(HTML(html))

 

 

result = clf.transform(x)
x_arr = result.toarray()
y_arr = y.values
y_arr = onehot_numpy(y_arr)
x_arr.shape , y_arr.shape
# ((891, 172), (891, 2))

Modeling

 

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import make_functional, grad_and_value, vmap, combine_state_for_ensemble

class MLPClassifier(nn.Module):
    def __init__(self, hidden_dim=32, n_classes=2,seed=1234):
        super().__init__()
        torch.manual_seed(seed)
        self.hidden_dim = hidden_dim
        self.n_classes = n_classes

        self.fc1 = nn.Linear(172, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

DEVICE = "cuda:1"
func_model, weights = make_functional(MLPClassifier().to(DEVICE))
loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
def train_step_fn(weights, batch, targets, lr=0.2):
    def compute_loss(weights, batch, targets):
        output = func_model(weights, batch)
        loss = loss_fn(output, targets)
        return loss
    grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)

    new_weights = []
    with torch.no_grad():
        for grad_weight, weight in zip(grad_weights, weights):
            new_weights.append(weight - grad_weight * lr)

    return loss, new_weights

def init_fn(num_models ):
    models = [MLPClassifier(hidden_dim=30 , seed=np.random.randint(1,1000)).to(DEVICE) for _ in range(num_models)]
    _, params, _ = combine_state_for_ensemble(models)
    return params
    
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np

class TabularDataSet(Dataset) :
    def __init__(self, X , Y) :
        self._X = np.float32(X)
        self._Y = Y

    def __len__(self,) :
        return len(self._Y)

    def __getitem__(self,idx) :
        return self._X[idx], self._Y[idx]
dataset = TabularDataSet(x_arr ,y_arr)
batch_size = 32
train_dl = DataLoader(dataset, batch_size=batch_size,shuffle=True)

Train


parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
batched_weights = init_fn(num_models=5)
loss_result = []
import matplotlib.pyplot as plt
from IPython import display 
for i in range(2000):
    loss_collection =[]
    for x , y in train_dl :
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        loss, batched_weights = parallel_train_step_fn(batched_weights, x, y)
        loss_collection.append(loss)
        print(",".join([f"{i:.3f}" for i in loss.detach().cpu().numpy()]),end="\r")
    else :
        losses = torch.stack(loss_collection).mean(axis=0).detach().cpu().numpy()
        loss_result.append(losses)
    if i % 10 == 0 :
        display.clear_output(wait=True)
        plt.plot(loss_result)
        plt.show()

다만 아쉬운 점은 여기서 각각의 모델에 대한 베스트 모델이 다를 것 같은데, 아직은 naive한 버전밖에 없는 상황이다.

그래서 각 모델을 뽑기 위한 유틸 함수가 추가로 필요해보인다.

 

for i in batched_weights :
    print(i.size())
torch.Size([5, 30, 172])
torch.Size([5, 30])
torch.Size([5, 2, 30])
torch.Size([5, 2])

Prediction

def prediction_step_fn(weights, batch ):
    output = func_model(weights, batch)
    return output
parallel_prediction_step_fn = vmap(prediction_step_fn, in_dims=(0, None))

x_tensor = torch.FloatTensor(x_arr).to(DEVICE)
print(x_tensor.shape)
# torch.Size([891, 172])
result = parallel_prediction_step_fn(batched_weights , x_tensor )
print(result.shape)
#torch.Size([5, 891, 2])

idx = 0
one_row_prediction = result[:,idx,:]

one_result = one_row_prediction.detach().cpu().numpy()
num_models = one_result.shape[0]
pd.DataFrame(one_result , index = [f"model_{i}" for i in range(num_models)])

 

 

 

 

https://github.com/pytorch/functorch/blob/main/examples/ensembling/parallel_train.py

 

GitHub - pytorch/functorch: functorch is JAX-like composable function transforms for PyTorch.

functorch is JAX-like composable function transforms for PyTorch. - GitHub - pytorch/functorch: functorch is JAX-like composable function transforms for PyTorch.

github.com

http://willwhitney.com/parallel-training-jax.html

 

Parallelizing neural networks on one GPU with JAX | Will Whitney

Parallelizing neural networks on one GPU with JAX How you can get a 100x speedup for training small neural networks by making the most of your accelerator. Most neural network libraries these days give amazing computational performance for training large n

willwhitney.com

 

728x90