분석 Python/Pytorch
[TIP / Pytorch] Linear NN Model base Architecture
데이터분석뉴비
2020. 10. 23. 20:38
728x90
PyTorch는 (`state_dict`로 저장할 경우) 가중치만 저장되고 architecture는 저장되지 않으니, 본 판때기(base class)를 잘 만들어서 하나의 구조에 여러 가지 파라미터 조합을 넣을 수 있도록 해야 한다.
여기서는 본 판때기에 대한 base를 만들어봄.
from torch import nn
class Net(nn.Module):
    """Configurable fully connected network built from a list of layer widths.

    The whole architecture is derived from the constructor arguments, so the
    same class can be re-instantiated with the same arguments to rebuild a
    model (e.g. before loading saved weights).

    Args:
        layers: Layer widths, starting with the input size,
            e.g. ``[150, 10, 5, 10]``. A final ``Linear`` to ``output_size``
            units is always appended. The caller's list is not modified.
        activation: Activation name (case-insensitive). Only ``"selu"`` is
            supported.
        bn: If true, add ``BatchNorm1d`` after each hidden ``Linear``.
        dropout: If true, add ``AlphaDropout`` after each hidden ``Linear``.
        dropout_p: Drop probability passed to ``AlphaDropout``
            (default 0.8, the value previously hard-coded).
        output_size: Width of the final ``Linear`` layer (default 1).
    """

    # Activation names accepted by ``make_model``, mapped to module classes.
    _ACTIVATIONS = {"selu": nn.SELU}

    def __init__(self, layers, activation, bn, dropout, *, dropout_p=0.8, output_size=1):
        super().__init__()
        self.model = self.make_model(
            layers, activation, bn, dropout,
            dropout_p=dropout_p, output_size=output_size,
        )

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        return self.model(x)

    def make_model(self, layers, activation, bn, dropout, *, dropout_p=0.8, output_size=1):
        """Build the ``nn.Sequential`` stack described by the arguments.

        Each hidden block is ``Linear -> [AlphaDropout] -> [BatchNorm1d] ->
        activation``; the last ``Linear`` (to ``output_size``) has nothing
        after it.

        Raises:
            ValueError: If ``activation`` is not a supported name.
        """
        try:
            act_cls = self._ACTIVATIONS[activation.lower()]
        except KeyError:
            raise ValueError(
                f"unsupported activation {activation!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            ) from None

        # Build a new list instead of appending to the caller's, so repeated
        # construction with the same list does not keep growing it.
        widths = [*layers, output_size]
        last_idx = len(widths) - 2  # index of the output Linear
        modules = []
        for idx, (in_f, out_f) in enumerate(zip(widths, widths[1:])):
            modules.append(nn.Linear(in_f, out_f))
            if idx == last_idx:
                break  # output layer: no dropout / BN / activation
            if dropout:
                modules.append(nn.AlphaDropout(dropout_p))
            if bn:
                modules.append(nn.BatchNorm1d(out_f))
            modules.append(act_cls())
        return nn.Sequential(*modules)
# Shared hyperparameters for the example networks.
input_size = 150
activation = "selu"
dropout = True
bn = True

# Two example configurations. Each list starts with the input width; Net adds
# the single output unit itself. Only the first one produces the printout below.
layers = [input_size, 10, 5, 10]
deep_layers = [input_size, 100, 100, 100, 100, 5, 10]

# print() so the architecture is shown outside a notebook as well.
print(Net(layers, activation, bn, dropout))
Net(
(model): Sequential(
(0): Linear(in_features=150, out_features=10, bias=True)
(1): AlphaDropout(p=0.8, inplace=False)
(2): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): SELU()
(4): Linear(in_features=10, out_features=5, bias=True)
(5): AlphaDropout(p=0.8, inplace=False)
(6): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): SELU()
(8): Linear(in_features=5, out_features=10, bias=True)
(9): AlphaDropout(p=0.8, inplace=False)
(10): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): SELU()
(12): Linear(in_features=10, out_features=1, bias=True)
)
)