PyTorch Lightning + Ray Tune

2020. 11. 7. 13:55 Analysis Python/Pytorch


I tried running the PyTorch Lightning example posted in the Ray Tune repository and ran into problems, so I am sharing code with a temporary workaround.

So far I have confirmed that two things do not work:

tune.with_parameters
TuneReportCallback

Below is the code with both of them removed.

TuneReportCallback is a problem, and although dropping tune.with_parameters may mean giving up some useful functionality, I am sharing this version for now.
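Without tune.with_parameters, arguments such as num_epochs and gpus_per_trial need another route into the trainable; otherwise the trainable's own defaults are used. A minimal sketch of one common workaround, assuming Tune's function API only ever calls the trainable as trainable(config): bind the extra arguments up front with the standard library's functools.partial. train_fn is a hypothetical stand-in for train_mnist_tune. Note that tune.with_parameters also places the bound objects in Ray's object store, which plain partial does not do.

```python
from functools import partial


def train_fn(config, data_dir=None, num_epochs=10, num_gpus=0):
    """Hypothetical stand-in for train_mnist_tune: it just echoes its inputs."""
    return {
        "lr": config["lr"],
        "data_dir": data_dir,
        "num_epochs": num_epochs,
        "num_gpus": num_gpus,
    }


# Bind everything except `config`; the result is callable as trainable(config),
# the calling convention Tune's function API uses for each trial.
trainable = partial(train_fn, data_dir="/tmp/mnist_data_", num_epochs=1, num_gpus=0)

# Simulate one trial: call the trainable with a single sampled config.
result = trainable({"lr": 1e-3})
print(result)
```

The same pattern applies to the real code by passing partial(train_mnist_tune, data_dir=..., num_epochs=..., num_gpus=...) to tune.run in place of train_mnist_tune.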

I trust these bugs will all be fixed eventually, so for now I am just making a note of it!


Installing packages

pip install "ray[tune]"    # 1.0.0
pip install "pytorch-lightning>=1.0" 
pip install "pytorch-lightning-bolts>=0.2.5"
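Since the failures may depend on the exact versions (the example linked at the end comes from Ray's releases/1.0.1 branch, while the comment above notes ray 1.0.0), it can help to record what is actually installed. A small sketch using only the standard library's importlib.metadata (Python 3.8+); installed_versions is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Hypothetical helper: map each distribution name to its version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


versions = installed_versions(["ray", "pytorch-lightning", "pytorch-lightning-bolts"])
for name, ver in versions.items():
    print(f"{name}: {ver or 'not installed'}")
```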

Code

Loading libraries

import torch
from torch.nn import functional as F
import pytorch_lightning as pl
from pl_bolts.datamodules import MNISTDataModule
import os
from ray.tune.integration.pytorch_lightning import TuneReportCallback
import tempfile
from ray import tune

Model architecture

class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self, config, data_dir=None):
        super(LightningMNISTClassifier, self).__init__()

        self.data_dir = data_dir or os.getcwd()
        self.lr = config["lr"]
        layer_1, layer_2 = config["layer_1"], config["layer_2"]

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
        self.layer_2 = torch.nn.Linear(layer_1, layer_2)
        self.layer_3 = torch.nn.Linear(layer_2, 10)
        self.accuracy = pl.metrics.Accuracy()

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)


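def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
    # Default to the same temp directory that tune_mnist() downloads MNIST into
    data_dir = data_dir or os.path.join(tempfile.gettempdir(), "mnist_data_")
    model = LightningMNISTClassifier(config, data_dir)
    dm = MNISTDataModule(
        data_dir=data_dir, num_workers=1, batch_size=config["batch_size"])
    # Only consumed by TuneReportCallback, which is disabled below
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=num_gpus,
        progress_bar_refresh_rate=0,
        # Disabled: TuneReportCallback did not work in this setup
        # callbacks=[TuneReportCallback(metrics, on="validation_end")]
    )
    trainer.fit(model, dm)


def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0):
    from functools import partial

    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
    # Download data once, before any trial starts
    MNISTDataModule(data_dir=data_dir).prepare_data()

    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }

    # Disabled: tune.with_parameters did not work in this setup
    # trainable = tune.with_parameters(
    #     train_mnist_tune,
    #     data_dir=data_dir,
    #     num_epochs=num_epochs,
    #     num_gpus=gpus_per_trial)

    # functools.partial binds the extra arguments instead; without it,
    # num_epochs and gpus_per_trial would never reach train_mnist_tune.
    # Note: with the callback disabled, "loss" is never reported to Tune.
    tune.run(
        partial(
            train_mnist_tune,
            data_dir=data_dir,
            num_epochs=num_epochs,
            num_gpus=gpus_per_trial),
        resources_per_trial={
            "cpu": 1,
            "gpu": gpus_per_trial
        },
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        name="tune_mnist")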
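With TuneReportCallback commented out, nothing in train_mnist_tune calls tune.report, so the validation metrics never reach Tune and metric="loss" has nothing to rank trials by. Below is a minimal sketch of a hand-rolled replacement. It assumes that PyTorch Lightning calls on_validation_end(trainer, pl_module) and exposes logged values in trainer.callback_metrics, and that tune.report(**metrics) exists in the installed Ray. ManualTuneReport is a hypothetical name. report_fn is injected so the logic can be checked without Ray. I have not verified that this avoids whatever broke TuneReportCallback.

```python
from types import SimpleNamespace

try:  # subclass the real Lightning callback when it is installed
    from pytorch_lightning import Callback as _CallbackBase
except ImportError:  # fallback so this sketch also runs without Lightning
    _CallbackBase = object


class ManualTuneReport(_CallbackBase):
    """Hypothetical stand-in for TuneReportCallback(metrics, on="validation_end")."""

    def __init__(self, metrics, report_fn):
        super().__init__()
        self.metrics = metrics      # {tune_name: name logged via self.log(...)}
        self.report_fn = report_fn  # e.g. ray.tune.report inside a real trial

    def on_validation_end(self, trainer, pl_module):
        # Skip Lightning's sanity-check validation pass (attribute name differs by version)
        if getattr(trainer, "running_sanity_check", False) or getattr(trainer, "sanity_checking", False):
            return
        logged = trainer.callback_metrics
        report = {
            tune_name: float(logged[pl_name])
            for tune_name, pl_name in self.metrics.items()
            if pl_name in logged
        }
        if report:
            self.report_fn(**report)


# Usage with a fake trainer; in train_mnist_tune it would be
# callbacks=[ManualTuneReport(metrics, report_fn=tune.report)]
reported = []
cb = ManualTuneReport({"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"},
                      report_fn=lambda **kw: reported.append(kw))
fake_trainer = SimpleNamespace(running_sanity_check=False,
                               callback_metrics={"ptl/val_loss": 0.25, "ptl/val_accuracy": 0.9})
cb.on_validation_end(fake_trainer, None)
print(reported)  # [{'loss': 0.25, 'acc': 0.9}]
```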

Running it

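import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()

# Use the parsed flag instead of ignoring it
if args.smoke_test:
    tune_mnist(num_samples=1, num_epochs=1, gpus_per_trial=0)
else:
    tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0)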
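With num_samples=1 as in the call above, Tune draws exactly one configuration from the search space in tune_mnist, so that run checks the pipeline rather than tuning anything. Conceptually, each trial draws one value per key. tune.choice picks uniformly from its list, and tune.loguniform samples uniformly in log space. The standard-library sketch below only illustrates that idea. It is not Ray's implementation, and sample_config is a hypothetical helper.

```python
import math
import random


def sample_config(rng):
    """Hypothetical helper: draw one config shaped like tune_mnist's search space."""
    low, high = 1e-4, 1e-1
    return {
        "layer_1": rng.choice([32, 64, 128]),
        "layer_2": rng.choice([64, 128, 256]),
        # log-uniform: uniform in log space, so 1e-4..1e-3 is as likely as 1e-2..1e-1
        "lr": math.exp(rng.uniform(math.log(low), math.log(high))),
        "batch_size": rng.choice([32, 64, 128]),
    }


rng = random.Random(0)  # fixed seed so the draws are reproducible
num_samples = 3
trials = [sample_config(rng) for _ in range(num_samples)]
for t in trials:
    print(t)
```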


It is convenient that things keep getting simpler, but it is also a little bittersweet...


github.com/ray-project/ray/blob/releases/1.0.1/python/ray/tune/examples/mnist_ptl_mini.py (ray-project/ray, GitHub)

towardsdatascience.com/how-to-tune-pytorch-lightning-hyperparameters-80089a281646 (How to tune Pytorch Lightning hyperparameters: "Use Ray Tune to optimize Pytorch Lightning hyperparameters in 30 lines of code!")
