[skorch] Implementing a VAE with skorch


base.py (inside the vae_models folder)

from .types_ import *
from torch import nn
from abc import abstractmethod

class BaseVAE(nn.Module):
    
    def __init__(self) -> None:
        super(BaseVAE, self).__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        raise NotImplementedError

    def decode(self, input: Tensor) -> Any:
        raise NotImplementedError

    def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
        raise NotImplementedError

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        raise NotImplementedError

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        pass

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        pass

types_.py (inside the vae_models folder)

 

Using typing like this is one of the things worth learning here.

from typing import List, Callable, Union, Any, TypeVar, Tuple
# from torch import tensor as Tensor

Tensor = TypeVar('torch.tensor')
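Note that Tensor here is just a TypeVar placeholder for readability: every model file can annotate signatures with Tensor through a single star-import, but nothing is actually enforced by a type checker. A minimal sketch of how it gets used, assuming the package is importable as vae_models:

from vae_models.types_ import *   # provides Tensor, List, Any, ...
import torch

def kl_divergence(mu: Tensor, log_var: Tensor) -> Tensor:
    # closed-form KL(N(mu, sigma^2) || N(0, 1)) per sample, averaged over the batch
    return torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)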

Vanilla VAE

import torch
import torch.nn as nn
import torch.nn.functional as F

# assuming this file also lives in the vae_models package, next to base.py and types_.py
from .base import BaseVAE
from .types_ import *

def make_hidden_set(hiddens) :
    hidden_set = []
    for i in range(len(hiddens)-1) :
        hidden_set.append([hiddens[i],hiddens[i+1]])
    return hidden_set
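make_hidden_set just pairs up consecutive layer sizes, for example:

print(make_hidden_set([784, 120, 60, 30]))
# [[784, 120], [120, 60], [60, 30]]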
 
class VanillaVAE(BaseVAE):
    def __init__(self,
                 inp : int,
                 latent_dim: int,
                 hidden_dims: List = None, 
                 **kwargs) -> None:
        super().__init__()

        self.latent_dim = latent_dim
        
        hiddens = [inp] + hidden_dims
        hidden_set = make_hidden_set(hiddens)
        modules = []
        for idx, (in_features , out_features) in enumerate(hidden_set) :
            fclayer = nn.Linear(in_features, out_features)
            modules.append(fclayer)
            modules.append(nn.SELU())
        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)
        
        modules = []
        hidden_dims = hidden_dims[::-1]
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[0] )
        hidden_set = make_hidden_set(hidden_dims)
        modules.append(self.decoder_input)
        for idx, (in_features , out_features) in enumerate(hidden_set) :
            fclayer = nn.Linear(in_features, out_features)
            modules.append(fclayer)
            modules.append(nn.SELU())
        self.decoder_output = nn.Linear(hidden_dims[-1], inp)
        modules.append(self.decoder_output)
        self.decoder = nn.Sequential(*modules)
        
    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x D]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
#         result = torch.flatten(result, start_dim=1)
        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]
    
    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        """
        result = self.decoder(z)
        return result
    
    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) using N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu
    
    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, mu, log_var]
    
    def loss_function(self,*args, **kwargs) -> dict:
        """
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args:
        :param kwargs:
        :return:
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}
    
    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        decoded outputs.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples, self.latent_dim)
#         z = z.to(current_device)
        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input x, returns its reconstruction
        :param x: (Tensor) [B x D]
        :return: (Tensor) [B x D]
        """
        return self.forward(x)[0]
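Before wiring this into skorch, a quick standalone sanity check (a sketch with made-up sizes; M_N is just the KL weight):

model = VanillaVAE(inp=20, latent_dim=5, hidden_dims=[16, 8])
x = torch.randn(4, 20)
recons, inp, mu, log_var = model(x)
print(model.loss_function(recons, inp, mu, log_var, M_N=0.005))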

Applying skorch

import numpy as np

from skorch import NeuralNetRegressor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from skorch.callbacks import Checkpoint, TrainEndCheckpoint, LoadInitState
from skorch.callbacks import LRScheduler, EarlyStopping, EpochScoring

Applying functools.partial!

from functools import partial
inp_size = train_one_hot.shape[1]
hidden_sizes = [120,60,30]
VanillaVAE = partial(VanillaVAE, 
                          inp = inp_size,
                          latent_dim = 10, 
                          hidden_dims =  hidden_sizes)
VanillaVAE.__name__ = "VanillaVAE"
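functools.partial fixes the constructor arguments up front so skorch can instantiate the module as-is. An alternative (a sketch, skipping the partial above) is to let skorch forward the constructor arguments itself via the module__ prefix:

# alternative: skorch passes module__* kwargs straight to VanillaVAE.__init__
net = NeuralNetRegressor(
    VanillaVAE,
    module__inp=inp_size,
    module__latent_dim=10,
    module__hidden_dims=hidden_sizes,
)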

Designing a custom loss function using the skorch Regressor!

class VAENet(NeuralNetRegressor):
    def get_loss(self, y_pred, y_true, *args, **kwargs):
        # unpack the list returned by VanillaVAE.forward
        recons, input, mu, log_var = y_pred
        recons_loss = super().get_loss(recons, input, *args, **kwargs)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
        kld_weight = 50
        loss = recons_loss + kld_weight * kld_loss
        return loss
    
cp = Checkpoint(dirname='VAE_skorch')
train_end_cp = TrainEndCheckpoint(dirname='VAE_skorch')
load_state = LoadInitState(cp)
net = VAENet(
    VanillaVAE,
    criterion=torch.nn.MSELoss,
    max_epochs=10000,
    optimizer=torch.optim.AdamW,
    optimizer__weight_decay=1e-5,
    lr=0.0001,
    batch_size= 50,
    callbacks=[cp , load_state ,
               ('earlystopping', 
                EarlyStopping(monitor="valid_loss", patience=200,
                              lower_is_better=True)),
        
    ]
)
net.fit(train_one_hot.values.astype(np.float32),
        train_one_hot.values.astype(np.float32))
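After training, the fitted module is available as net.module_, so reconstructions and new samples can be pulled out directly (a sketch, reusing train_one_hot from above):

X = torch.as_tensor(train_one_hot.values.astype(np.float32)[:5])
net.module_.eval()
with torch.no_grad():
    recons, _, mu, log_var = net.module_(X)                      # reconstruct the first 5 rows
    new_rows = net.module_.sample(num_samples=3, current_device=0)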

Nice, this is very convenient.

Someday I should also take a look at PyTorch Lightning.

 

Thanks to the author who organized all of this so well.

The GitHub author seems to have implemented it with PyTorch Lightning.

There looks to be a lot to learn from it.

 

 

VAE Models

AntixK/PyTorch-VAE: A Collection of Variational Autoencoders (VAE) in PyTorch (github.com)

GMVAE

jariasf/GMVAE: Implementation of Gaussian Mixture Variational Autoencoder (GMVAE) for Unsupervised Clustering (github.com)

 
