Pytorch) Testing weight loading (pulling only a sub-model's weights out of a full model)
Situation
To train in a more generalized way, I built a single model generalized over multiple targets, then tested whether the weights carry over correctly when extracting a model for only a subset of those targets.
Expected benefit
The hope is that the shared network (Shared Network) learns a well-generalized representation, so it might serve as a good starting point when training on an individual target.
Method
Manage each target's layers under its name as a key in an nn.ModuleDict, and later load with strict=False, on the assumption that parameters without a matching key will simply be skipped while the matching ones load.
Strictly speaking, one could filter the extra keys out of the state_dict first, but I wanted to see whether the mapping happens automatically without removing them (a sketch of that explicit-filtering alternative follows).
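For reference, a minimal sketch of the explicit-filtering alternative, assuming model is the full model and model2 the sub-model built below; only the keys the sub-model actually owns are kept, so strict=True then succeeds.
# Sketch: keep only entries whose keys exist in the sub-model, then load strictly.
full_sd = model.state_dict()
sub_keys = set(model2.state_dict().keys())
filtered = {k: v for k, v in full_sd.items() if k in sub_keys}
model2.load_state_dict(filtered, strict=True)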
Code
The weights are set by hand to fixed per-target constants for the test.
Since forward does not need to be implemented for this test, it is left as a pass.
import torch
from torch import nn


# Define model
class Net(nn.Module):
    def __init__(self, input_size, hidden_size, output_names):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)  # shared layer
        # self.fc2 = nn.Linear(hidden_size, output_size)
        # One hidden layer and one output head per target, keyed by target name.
        self.output_dict = nn.ModuleDict()
        self.output_fc = nn.ModuleDict()
        for idx, output_col in enumerate(output_names):
            w1 = nn.Linear(hidden_size, hidden_size)
            # Fill each target's parameters with a distinct constant
            # (0.01, 1.01, 2.01, ...) so we can tell whose weights get loaded.
            w1.weight.data.fill_(0.01 + idx)
            w1.bias.data.fill_(0.01 + idx)
            self.output_dict[output_col] = w1
            w2 = nn.Linear(hidden_size, 1)
            w2.weight.data.fill_(0.01 + idx)
            w2.bias.data.fill_(0.01 + idx)
            self.output_fc[output_col] = w2
        self.relu = nn.ReLU()

    def forward(self, x):
        # Not needed for this weight-loading test.
        pass
Model definition
model = Net(10, 10, ['a', 'b', 'c'])
Net(
  (fc1): Linear(in_features=10, out_features=10, bias=True)
  (output_dict): ModuleDict(
    (a): Linear(in_features=10, out_features=10, bias=True)
    (b): Linear(in_features=10, out_features=10, bias=True)
    (c): Linear(in_features=10, out_features=10, bias=True)
  )
  (output_fc): ModuleDict(
    (a): Linear(in_features=10, out_features=1, bias=True)
    (b): Linear(in_features=10, out_features=1, bias=True)
    (c): Linear(in_features=10, out_features=1, bias=True)
  )
  (relu): ReLU()
)
model.state_dict()
OrderedDict([('fc1.weight',
tensor([[ 0.0017, -0.1278, -0.0266, -0.0183, -0.0126, 0.2083, 0.1939, 0.1057,
-0.1191, 0.2231],
[ 0.0284, -0.2354, 0.0346, 0.0863, 0.0670, 0.2026, 0.2626, -0.2846,
0.1729, -0.1541],
[-0.2734, -0.1182, -0.0646, -0.1695, -0.0969, 0.1581, 0.2259, 0.2693,
0.1707, 0.2501],
[ 0.0553, 0.1460, -0.0346, -0.3149, -0.1398, -0.1654, -0.2747, -0.1863,
-0.2407, -0.1953],
[-0.2127, -0.0775, -0.2025, -0.2823, 0.1191, 0.1548, -0.0751, -0.1676,
0.1808, -0.2031],
[-0.0907, -0.2605, 0.0703, -0.1875, 0.0004, -0.1114, 0.1550, 0.2570,
-0.2923, 0.1357],
[-0.0060, 0.1363, 0.1011, 0.0492, 0.2389, -0.2457, 0.2297, 0.1049,
-0.0022, 0.0798],
[-0.2751, 0.0437, 0.0444, -0.1880, 0.2614, -0.0529, 0.1194, 0.1280,
-0.0151, -0.2442],
[-0.2918, -0.1730, -0.1653, 0.0335, 0.1010, -0.0237, 0.2398, 0.0183,
-0.2276, -0.3015],
[ 0.1253, -0.1207, 0.2565, 0.2333, -0.1534, -0.0165, 0.1279, 0.1267,
-0.2052, -0.1208]])),
('fc1.bias',
tensor([ 0.2301, -0.2335, -0.3057, 0.1267, -0.1215, -0.0412, -0.0494, 0.2126,
-0.1698, 0.1222])),
('output_dict.a.weight',
tensor([[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100]])),
('output_dict.a.bias',
tensor([0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100])),
('output_dict.b.weight',
tensor([[1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
1.0100],
[1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
1.0100],
[1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
1.0100],
[1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
1.0100],
[1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
1.0100],
[1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
1.0100],
[1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
1.0100],
[1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
1.0100],
[1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
1.0100],
[1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
1.0100]])),
('output_dict.b.bias',
tensor([1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
1.0100])),
('output_dict.c.weight',
tensor([[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100]])),
('output_dict.c.bias',
tensor([2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100])),
('output_fc.a.weight',
tensor([[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100]])),
('output_fc.a.bias', tensor([0.0100])),
('output_fc.b.weight',
tensor([[1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
1.0100]])),
('output_fc.b.bias', tensor([1.0100])),
('output_fc.c.weight',
tensor([[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100]])),
('output_fc.c.bias', tensor([2.0100]))])
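A quicker way to see the naming scheme than the full dump is to list the keys; these names are exactly what strict=False matches on:
print(list(model.state_dict().keys()))
# ['fc1.weight', 'fc1.bias',
#  'output_dict.a.weight', 'output_dict.a.bias',
#  'output_dict.b.weight', 'output_dict.b.bias',
#  'output_dict.c.weight', 'output_dict.c.bias',
#  'output_fc.a.weight', 'output_fc.a.bias',
#  'output_fc.b.weight', 'output_fc.b.bias',
#  'output_fc.c.weight', 'output_fc.c.bias']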
Building the partial model and loading the weights
model2 = Net(10, 10, ['c'])
Net(
  (fc1): Linear(in_features=10, out_features=10, bias=True)
  (output_dict): ModuleDict(
    (c): Linear(in_features=10, out_features=10, bias=True)
  )
  (output_fc): ModuleDict(
    (c): Linear(in_features=10, out_features=1, bias=True)
  )
  (relu): ReLU()
)
model2.state_dict()
OrderedDict([('fc1.weight',
tensor([[ 0.1946, 0.2048, 0.1365, -0.1957, 0.1567, 0.0072, -0.0784, 0.1737,
-0.2101, 0.0298],
[-0.2536, 0.3019, -0.0925, 0.0198, -0.1691, -0.3024, -0.2791, 0.1541,
-0.0515, 0.2403],
[-0.2729, -0.0287, -0.1100, -0.2106, -0.0440, -0.2403, -0.2801, -0.2588,
-0.2976, 0.0337],
[-0.1317, -0.2659, -0.1828, -0.0062, 0.1688, -0.0673, -0.3125, -0.1765,
-0.1147, -0.3059],
[-0.0957, 0.1289, -0.3119, 0.2531, -0.3111, -0.1319, 0.0591, -0.2888,
0.0585, 0.0587],
[ 0.1130, -0.0909, 0.2883, 0.0791, 0.1900, 0.0951, 0.2858, -0.0830,
0.2342, 0.0112],
[-0.0021, 0.0022, 0.1160, 0.2418, -0.1260, 0.2176, 0.0009, 0.0213,
-0.2691, 0.2210],
[ 0.0776, -0.0344, -0.1799, 0.0740, -0.1348, -0.2958, 0.0050, -0.0732,
0.0202, 0.1260],
[-0.0207, 0.2465, -0.0502, 0.1314, -0.1049, -0.2610, -0.2597, 0.0308,
-0.2391, 0.2963],
[ 0.0236, 0.1121, 0.1393, 0.2172, -0.0822, 0.0677, -0.2918, 0.1014,
0.2830, 0.2161]])),
('fc1.bias',
tensor([-0.1148, 0.0394, 0.1593, -0.1153, 0.2548, -0.2897, -0.0852, 0.0883,
0.0784, -0.1738])),
('output_dict.c.weight',
tensor([[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100],
[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100]])),
('output_dict.c.bias',
tensor([0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100])),
('output_fc.c.weight',
tensor([[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
0.0100]])),
('output_fc.c.bias', tensor([0.0100]))])
model2.load_state_dict(model.state_dict(), strict=False)
_IncompatibleKeys(missing_keys=[], unexpected_keys=['output_dict.a.weight', 'output_dict.a.bias', 'output_dict.b.weight', 'output_dict.b.bias', 'output_fc.a.weight', 'output_fc.a.bias', 'output_fc.b.weight', 'output_fc.b.bias'])
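The return value can also be checked in code rather than by reading the printout; a minimal sketch:
# strict=False returns a namedtuple of missing_keys / unexpected_keys.
# Empty missing_keys means every parameter model2 needs was found.
result = model2.load_state_dict(model.state_dict(), strict=False)
assert not result.missing_keys, result.missing_keys
print("ignored keys from the full model:", result.unexpected_keys)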
As the output above shows, only the keys for the 'a' and 'b' heads come back as unexpected_keys, and missing_keys is empty, so everything model2 needs was loaded. Checking model2.state_dict() afterwards confirms the 'c' weights (the 2.01 constants) now come from the full model:
model2.state_dict()
OrderedDict([('fc1.weight',
tensor([[ 0.0017, -0.1278, -0.0266, -0.0183, -0.0126, 0.2083, 0.1939, 0.1057,
-0.1191, 0.2231],
[ 0.0284, -0.2354, 0.0346, 0.0863, 0.0670, 0.2026, 0.2626, -0.2846,
0.1729, -0.1541],
[-0.2734, -0.1182, -0.0646, -0.1695, -0.0969, 0.1581, 0.2259, 0.2693,
0.1707, 0.2501],
[ 0.0553, 0.1460, -0.0346, -0.3149, -0.1398, -0.1654, -0.2747, -0.1863,
-0.2407, -0.1953],
[-0.2127, -0.0775, -0.2025, -0.2823, 0.1191, 0.1548, -0.0751, -0.1676,
0.1808, -0.2031],
[-0.0907, -0.2605, 0.0703, -0.1875, 0.0004, -0.1114, 0.1550, 0.2570,
-0.2923, 0.1357],
[-0.0060, 0.1363, 0.1011, 0.0492, 0.2389, -0.2457, 0.2297, 0.1049,
-0.0022, 0.0798],
[-0.2751, 0.0437, 0.0444, -0.1880, 0.2614, -0.0529, 0.1194, 0.1280,
-0.0151, -0.2442],
[-0.2918, -0.1730, -0.1653, 0.0335, 0.1010, -0.0237, 0.2398, 0.0183,
-0.2276, -0.3015],
[ 0.1253, -0.1207, 0.2565, 0.2333, -0.1534, -0.0165, 0.1279, 0.1267,
-0.2052, -0.1208]])),
('fc1.bias',
tensor([ 0.2301, -0.2335, -0.3057, 0.1267, -0.1215, -0.0412, -0.0494, 0.2126,
-0.1698, 0.1222])),
('output_dict.c.weight',
tensor([[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100],
[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100]])),
('output_dict.c.bias',
tensor([2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100])),
('output_fc.c.weight',
tensor([[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
2.0100]])),
('output_fc.c.bias', tensor([2.0100]))])
model.output_dict['c'].weight.data.numpy()
array([[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01]],
dtype=float32)
model2.output_dict['c'].weight.data.numpy()
array([[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01]],
dtype=float32)
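Rather than eyeballing the arrays, the same verification can be run over every parameter the sub-model owns; a minimal sketch:
import torch

# Every key in the sub-model's state_dict should now hold the full model's values.
full_sd = model.state_dict()
for name, tensor in model2.state_dict().items():
    assert torch.equal(tensor, full_sd[name]), f"mismatch at {name}"
print("all sub-model parameters match the full model")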
Conclusion
In conclusion, with ModuleDict the sub-model's weights map over correctly from the full model.
A model structured exactly like this may not come up often, but it shows that a part can be pulled out of the whole; it also suggests a workflow of focusing training on a target that is learning poorly and then swapping its head back into the full model, sketched below.
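A rough sketch of that idea: freeze the shared layer, fine-tune only the 'c' head, and copy it back (the training loop is omitted and the optimizer settings are illustrative, not from the original post):
# Freeze the shared layer so only the 'c' head gets updated.
for p in model2.fc1.parameters():
    p.requires_grad = False

optimizer = torch.optim.Adam(
    (p for p in model2.parameters() if p.requires_grad), lr=1e-3
)
# ... train model2 on target 'c' here ...

# Copy the fine-tuned head back into the full model; strict=False
# leaves the other targets' heads untouched.
model.load_state_dict(model2.state_dict(), strict=False)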
'분석 Python > Pytorch' 카테고리의 다른 글
TimeSeries) [MultiHead Self Attention] multi target 예측 (0) | 2023.09.23 |
---|---|
Pytorch) multioutput Regression 구현해보기 (4) | 2022.03.26 |
torchfunc) titanic data에 model parallel training 해보기 (0) | 2022.03.26 |
Pytorch 1.11 이후) functorch 알아보기 (0) | 2022.03.14 |
Pytorch 1.11 이후) torchdata 알아보기 (0) | 2022.03.14 |