PyTorch Lightning + Ray Tune
I tried running the PyTorch Lightning example posted with Ray Tune, but I ran into problems, so I'm sharing a version with the broken parts commented out as a stopgap.
So far I have confirmed that two things don't work:
tune.with_parameters
TuneReportCallback
Below is the code with those parts removed.
TuneReportCallback is one problem, and dropping tune.with_parameters probably means giving up a useful feature, but I'm sharing it anyway.
I trust these bugs will all get fixed eventually, so for now it's enough just to know about them!
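For reference, this is roughly how the two pieces are meant to be used; it mirrors the parts that are commented out in the code further down, so take it as a sketch of the intended usage rather than something I could get to run:

from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback

# TuneReportCallback reports PL-logged metrics back to Tune at the end
# of every validation loop.
callback = TuneReportCallback(
    {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"},
    on="validation_end")

# tune.with_parameters wraps the trainable so that large or constant
# arguments are passed once instead of through the trial config.
trainable = tune.with_parameters(
    train_mnist_tune,  # the training function defined below
    num_epochs=10,
    num_gpus=0)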
Package installation
pip install "ray[tune]" # 1.0.0
pip install "pytorch-lightning>=1.0"
pip install "pytorch-lightning-bolts>=0.2.5"
Code
Loading libraries
import torch
from torch.nn import functional as F
import pytorch_lightning as pl
from pl_bolts.datamodules import MNISTDataModule
import os
from ray.tune.integration.pytorch_lightning import TuneReportCallback
import tempfile
from ray import tune
Architecture
class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self, config, data_dir=None):
        super(LightningMNISTClassifier, self).__init__()
        self.data_dir = data_dir or os.getcwd()
        self.lr = config["lr"]
        layer_1, layer_2 = config["layer_1"], config["layer_2"]
        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
        self.layer_2 = torch.nn.Linear(layer_1, layer_2)
        self.layer_3 = torch.nn.Linear(layer_2, 10)
        self.accuracy = pl.metrics.Accuracy()

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)
def train_mnist_tune(config, num_epochs=10, num_gpus=0):
    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
    model = LightningMNISTClassifier(config, data_dir)
    dm = MNISTDataModule(
        data_dir=data_dir, num_workers=1, batch_size=config["batch_size"])
    # metrics for the (currently broken) TuneReportCallback below
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=num_gpus,
        progress_bar_refresh_rate=0,
        # callbacks=[TuneReportCallback(metrics, on="validation_end")]
    )
    trainer.fit(model, dm)
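The function can also be called directly with a concrete config, which is handy for debugging a single configuration outside of Tune (hypothetical values; MNIST gets downloaded on the first run):

train_mnist_tune(
    {"lr": 1e-3, "layer_1": 64, "layer_2": 128, "batch_size": 64},
    num_epochs=1)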
def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0):
    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
    # Download data
    MNISTDataModule(data_dir=data_dir).prepare_data()
    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }
    # trainable = tune.with_parameters(
    #     train_mnist_tune,
    #     data_dir=data_dir,
    #     num_epochs=num_epochs,
    #     num_gpus=gpus_per_trial)
    tune.run(
        train_mnist_tune,
        resources_per_trial={
            "cpu": 1,
            "gpu": gpus_per_trial
        },
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        name="tune_mnist")
Execution code
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()

# run a single short trial when smoke testing, the full search otherwise
if args.smoke_test:
    tune_mnist(num_samples=1, num_epochs=1, gpus_per_trial=0)
else:
    tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0)
Things keep getting simpler, which is convenient, but it also leaves a somewhat bittersweet feeling...
github.com/ray-project/ray/blob/releases/1.0.1/python/ray/tune/examples/mnist_ptl_mini.py
towardsdatascience.com/how-to-tune-pytorch-lightning-hyperparameters-80089a281646