[skorch] Implementing and Applying a VAE
base.py (inside the vae_models folder)
from .types_ import *
from torch import nn
from abc import abstractmethod


class BaseVAE(nn.Module):

    def __init__(self) -> None:
        super(BaseVAE, self).__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        raise NotImplementedError

    def decode(self, input: Tensor) -> Any:
        raise NotImplementedError

    def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
        raise NotImplementedError

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        raise NotImplementedError

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        pass

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        pass
types_.py (inside the vae_models folder)
Using typing this way is one of the things worth learning here.
from typing import List, Callable, Union, Any, TypeVar, Tuple
# from torch import tensor as Tensor
Tensor = TypeVar('torch.tensor')
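One caveat worth knowing: TypeVar('torch.tensor') only creates a named placeholder for annotations; nothing about the actual torch tensor type is imported or checked at runtime. A minimal illustration:

from typing import TypeVar

Tensor = TypeVar('torch.tensor')

def identity(x: Tensor) -> Tensor:
    # The annotation is documentation only; any value passes through.
    return x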
Vanilla VAE
import torch
from torch import nn
from torch.nn import functional as F

from .base import BaseVAE
from .types_ import *  # List, Any, Tensor


def make_hidden_set(hiddens):
    # Pair up consecutive layer sizes into (in_features, out_features) tuples.
    hidden_set = []
    for i in range(len(hiddens) - 1):
        hidden_set.append([hiddens[i], hiddens[i + 1]])
    return hidden_set
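For example, make_hidden_set turns a flat list of layer sizes into consecutive (in_features, out_features) pairs:

>>> make_hidden_set([100, 120, 60, 30])
[[100, 120], [120, 60], [60, 30]]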
class VanillaVAE(BaseVAE):

    def __init__(self,
                 inp: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 **kwargs) -> None:
        super().__init__()
        self.latent_dim = latent_dim
        if hidden_dims is None:
            hidden_dims = [120, 60, 30]  # same sizes as the example further below
        # Encoder: a stack of Linear + SELU blocks down to the last hidden size.
        hiddens = [inp] + hidden_dims
        hidden_set = make_hidden_set(hiddens)
        modules = []
        for in_features, out_features in hidden_set:
            modules.append(nn.Linear(in_features, out_features))
            modules.append(nn.SELU())
        self.encoder = nn.Sequential(*modules)
        # Two heads map the last hidden layer to the latent mean and log-variance.
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)
        # Decoder: mirror the encoder, from the latent code back to the input size.
        modules = []
        hidden_dims = hidden_dims[::-1]
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[0])
        hidden_set = make_hidden_set(hidden_dims)
        modules.append(self.decoder_input)
        for in_features, out_features in hidden_set:
            modules.append(nn.Linear(in_features, out_features))
            modules.append(nn.SELU())
        self.decoder_output = nn.Linear(hidden_dims[-1], inp)
        modules.append(self.decoder_output)
        self.decoder = nn.Sequential(*modules)
    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing it through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to the encoder [B x input_dim]
        :return: (List[Tensor]) Mean and log-variance of the latent Gaussian
        """
        result = self.encoder(input)
        # Split the result into the mu and log-var components
        # of the latent Gaussian distribution.
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return [mu, log_var]
    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes back onto the input space.
        """
        result = self.decoder(z)
        return result
    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick: sample from N(mu, var) using
        samples from N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu
    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]
    def loss_function(self, *args, **kwargs) -> dict:
        r"""
        Computes the VAE loss function:
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        kld_weight = kwargs['M_N']  # Accounts for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}
    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        points in input space.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model on
        :return: (Tensor)
        """
        z = torch.randn(num_samples, self.latent_dim)
        z = z.to(current_device)
        samples = self.decode(z)
        return samples
    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input x, returns the reconstruction.
        :param x: (Tensor) [B x input_dim]
        :return: (Tensor) [B x input_dim]
        """
        return self.forward(x)[0]
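A quick smoke test on random data helps confirm the wiring before handing the model to skorch (a minimal sketch; the sizes here are arbitrary):

model = VanillaVAE(inp=100, latent_dim=10, hidden_dims=[120, 60, 30])
x = torch.randn(8, 100)                       # a fake batch of 8 rows
recons, inp_, mu, log_var = model(x)
print(recons.shape, mu.shape, log_var.shape)  # (8, 100), (8, 10), (8, 10)
losses = model.loss_function(recons, inp_, mu, log_var, M_N=1.0)
print(losses['loss'])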
Applying skorch
import numpy as np  # used for the float32 casts in net.fit below

from skorch import NeuralNetRegressor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from skorch.callbacks import Checkpoint, TrainEndCheckpoint, LoadInitState
from skorch.callbacks import LRScheduler, EarlyStopping, EpochScoring
Applying partial!
from functools import partial

inp_size = train_one_hot.shape[1]
hidden_sizes = [120, 60, 30]
VanillaVAE = partial(VanillaVAE,
                     inp=inp_size,
                     latent_dim=10,
                     hidden_dims=hidden_sizes)
VanillaVAE.__name__ = "VanillaVAE"  # partial objects have no __name__ of their own
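As an aside, skorch can also forward constructor arguments to the module through module__-prefixed parameters, so the partial wrapper above is optional. A sketch of the equivalent construction, using the original, unwrapped class (construction only; training still needs the custom loss defined in the next step):

net = NeuralNetRegressor(
    VanillaVAE,                      # the plain class, no partial needed
    module__inp=inp_size,
    module__latent_dim=10,
    module__hidden_dims=hidden_sizes,
)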
Designing a custom loss function with the skorch Regressor!
class VAENet(NeuralNetRegressor):
    def get_loss(self, y_pred, y_true, *args, **kwargs):
        # Unpack the list that was returned by forward.
        recons, input, mu, log_var = y_pred
        # Reconstruction term: the parent's criterion (MSE) between recons and input.
        recons_loss = super().get_loss(recons, input, *args, **kwargs)
        # KL term: closed-form KL divergence between N(mu, var) and N(0, 1).
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
        kld_weight = 50
        loss = recons_loss + kld_weight * kld_loss
        return loss
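If you want to tune the KL weight from the outside (e.g. in a grid search) instead of hard-coding 50, skorch lets a subclass expose extra constructor parameters. A minimal sketch, where the kld_weight parameter name is my own:

class VAENet(NeuralNetRegressor):
    def __init__(self, *args, kld_weight=50, **kwargs):
        super().__init__(*args, **kwargs)
        self.kld_weight = kld_weight

    def get_loss(self, y_pred, y_true, *args, **kwargs):
        recons, input, mu, log_var = y_pred
        recons_loss = super().get_loss(recons, input, *args, **kwargs)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
        return recons_loss + self.kld_weight * kld_loss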
cp = Checkpoint(dirname='VAE_skorch')
train_end_cp = TrainEndCheckpoint(dirname='VAE_skorch')
load_state = LoadInitState(cp)
net = VAENet(
    VanillaVAE,
    criterion=torch.nn.MSELoss,
    max_epochs=10000,
    optimizer=torch.optim.AdamW,
    optimizer__weight_decay=1e-5,
    lr=0.0001,
    batch_size=50,
    callbacks=[cp, load_state,
               ('earlystopping',
                EarlyStopping(monitor="valid_loss", patience=200,
                              lower_is_better=True)),
               ],
)
net.fit(train_one_hot.values.astype(np.float32),
        train_one_hot.values.astype(np.float32))
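After fitting, skorch exposes the trained module as net.module_, so reconstructions and latent codes can be pulled out by calling it directly (a sketch):

X = torch.from_numpy(train_one_hot.values.astype(np.float32))

with torch.no_grad():
    # The same four outputs that forward returns during training.
    recons, _, mu, log_var = net.module_(X)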
Nice, very convenient.
I should take a look at PyTorch Lightning someday too.
Thanks to the author who organized this so well.
The GitHub author seems to have implemented it with PyTorch Lightning.
There looks to be a lot to learn from it.