Pytorch) 모델 가중치 로드 시 테스트 (전체 모델에서 서브 모델 가중치만 가져오기)

2023. 9. 15. 23:28분석 Python/Pytorch

상황

조금 더 일반화된 학습을 하기 위해 멀티 타겟에 대한 일반화된 모델을 만들고, 그 모델에서 부분적인 타겟에 대해서 추출할 때 가중치를 잘 가져오는지에 대한 테스트를 수행해봄.

기대 효과

공유하는 네트워크(Shared Network)가 일반화되게 잘 학습이 되고, 부분적으로 학습시킬 때 좋은 인풋으로써의 기능을 할 수 있지 않을까 함.

방법

각 타겟을 Dict의 키로 관리하고, 나중에 load 시 strict=False를 하면, 파라미터가 매칭이 안 되어도 알아서 들어갈 것이라는 생각으로 시작

엄밀하게 제거하는 작업도 있지만, 제거하지 않아도 자동으로 맵핑되는 지 보고 싶었음

코드

중간에 가중치를 임의로 지정하여 테스트

굳이 forward까지 구현하지 않아도 되므로 패스하고 진행한다. 

import torch 
from torch import nn 

# Define model
class Net(nn.Module):
    """Toy multi-target network: one shared layer plus per-target heads.

    Each name in ``output_names`` becomes a key in two ``nn.ModuleDict``s
    (a hidden head and a 1-unit output head). Because ModuleDict keys appear
    verbatim in ``state_dict()`` key names, a sub-model built with a subset
    of names shares key names with the full model and can absorb its weights
    via ``load_state_dict(..., strict=False)``.

    Args:
        input_size: in_features of the shared ``fc1`` layer.
        hidden_size: width of the shared layer and of each hidden head.
        output_names: iterable of target names; one head pair per name.
    """

    def __init__(self, input_size, hidden_size, output_names):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.output_dict = nn.ModuleDict()  # per-target hidden heads
        self.output_fc = nn.ModuleDict()    # per-target 1-unit output heads
        for idx, output_col in enumerate(output_names):
            # Fill each head with a distinctive constant (0.01 + idx) so
            # weight transfer between models is easy to verify by eye.
            fill_value = 0.01 + idx
            self.output_dict[output_col] = self._filled_linear(
                hidden_size, hidden_size, fill_value)
            self.output_fc[output_col] = self._filled_linear(
                hidden_size, 1, fill_value)

        self.relu = nn.ReLU()

    @staticmethod
    def _filled_linear(in_features, out_features, value):
        """Return an ``nn.Linear`` with weight and bias filled with *value*."""
        layer = nn.Linear(in_features, out_features)
        layer.weight.data.fill_(value)
        layer.bias.data.fill_(value)
        return layer

    def forward(self, x):
        # Intentionally unimplemented (returns None): this class exists only
        # to test state_dict key matching, so no forward pass is required.
        pass

모델 정의

model = Net(10,10, ['a','b','c'])
Net(
  (fc1): Linear(in_features=10, out_features=10, bias=True)
  (output_dict): ModuleDict(
    (a): Linear(in_features=10, out_features=10, bias=True)
    (b): Linear(in_features=10, out_features=10, bias=True)
    (c): Linear(in_features=10, out_features=10, bias=True)
  )
  (output_fc): ModuleDict(
    (a): Linear(in_features=10, out_features=1, bias=True)
    (b): Linear(in_features=10, out_features=1, bias=True)
    (c): Linear(in_features=10, out_features=1, bias=True)
  )
  (relu): ReLU()
)
model.state_dict()
OrderedDict([('fc1.weight',
              tensor([[ 0.0017, -0.1278, -0.0266, -0.0183, -0.0126,  0.2083,  0.1939,  0.1057,
                       -0.1191,  0.2231],
                      [ 0.0284, -0.2354,  0.0346,  0.0863,  0.0670,  0.2026,  0.2626, -0.2846,
                        0.1729, -0.1541],
                      [-0.2734, -0.1182, -0.0646, -0.1695, -0.0969,  0.1581,  0.2259,  0.2693,
                        0.1707,  0.2501],
                      [ 0.0553,  0.1460, -0.0346, -0.3149, -0.1398, -0.1654, -0.2747, -0.1863,
                       -0.2407, -0.1953],
                      [-0.2127, -0.0775, -0.2025, -0.2823,  0.1191,  0.1548, -0.0751, -0.1676,
                        0.1808, -0.2031],
                      [-0.0907, -0.2605,  0.0703, -0.1875,  0.0004, -0.1114,  0.1550,  0.2570,
                       -0.2923,  0.1357],
                      [-0.0060,  0.1363,  0.1011,  0.0492,  0.2389, -0.2457,  0.2297,  0.1049,
                       -0.0022,  0.0798],
                      [-0.2751,  0.0437,  0.0444, -0.1880,  0.2614, -0.0529,  0.1194,  0.1280,
                       -0.0151, -0.2442],
                      [-0.2918, -0.1730, -0.1653,  0.0335,  0.1010, -0.0237,  0.2398,  0.0183,
                       -0.2276, -0.3015],
                      [ 0.1253, -0.1207,  0.2565,  0.2333, -0.1534, -0.0165,  0.1279,  0.1267,
                       -0.2052, -0.1208]])),
             ('fc1.bias',
              tensor([ 0.2301, -0.2335, -0.3057,  0.1267, -0.1215, -0.0412, -0.0494,  0.2126,
                      -0.1698,  0.1222])),
             ('output_dict.a.weight',
              tensor([[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100]])),
             ('output_dict.a.bias',
              tensor([0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                      0.0100])),
             ('output_dict.b.weight',
              tensor([[1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
                       1.0100],
                      [1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
                       1.0100],
                      [1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
                       1.0100],
                      [1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
                       1.0100],
                      [1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
                       1.0100],
                      [1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
                       1.0100],
                      [1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
                       1.0100],
                      [1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
                       1.0100],
                      [1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
                       1.0100],
                      [1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
                       1.0100]])),
             ('output_dict.b.bias',
              tensor([1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
                      1.0100])),
             ('output_dict.c.weight',
              tensor([[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100]])),
             ('output_dict.c.bias',
              tensor([2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                      2.0100])),
             ('output_fc.a.weight',
              tensor([[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100]])),
             ('output_fc.a.bias', tensor([0.0100])),
             ('output_fc.b.weight',
              tensor([[1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100, 1.0100,
                       1.0100]])),
             ('output_fc.b.bias', tensor([1.0100])),
             ('output_fc.c.weight',
              tensor([[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100]])),
             ('output_fc.c.bias', tensor([2.0100]))])

부분적인 모델 개발 후 모델 로드 하기

model2 = Net(10,10, ['c'])
Net(
  (fc1): Linear(in_features=10, out_features=10, bias=True)
  (output_dict): ModuleDict(
    (c): Linear(in_features=10, out_features=10, bias=True)
  )
  (output_fc): ModuleDict(
    (c): Linear(in_features=10, out_features=1, bias=True)
  )
  (relu): ReLU()
)
OrderedDict([('fc1.weight',
              tensor([[ 0.1946,  0.2048,  0.1365, -0.1957,  0.1567,  0.0072, -0.0784,  0.1737,
                       -0.2101,  0.0298],
                      [-0.2536,  0.3019, -0.0925,  0.0198, -0.1691, -0.3024, -0.2791,  0.1541,
                       -0.0515,  0.2403],
                      [-0.2729, -0.0287, -0.1100, -0.2106, -0.0440, -0.2403, -0.2801, -0.2588,
                       -0.2976,  0.0337],
                      [-0.1317, -0.2659, -0.1828, -0.0062,  0.1688, -0.0673, -0.3125, -0.1765,
                       -0.1147, -0.3059],
                      [-0.0957,  0.1289, -0.3119,  0.2531, -0.3111, -0.1319,  0.0591, -0.2888,
                        0.0585,  0.0587],
                      [ 0.1130, -0.0909,  0.2883,  0.0791,  0.1900,  0.0951,  0.2858, -0.0830,
                        0.2342,  0.0112],
                      [-0.0021,  0.0022,  0.1160,  0.2418, -0.1260,  0.2176,  0.0009,  0.0213,
                       -0.2691,  0.2210],
                      [ 0.0776, -0.0344, -0.1799,  0.0740, -0.1348, -0.2958,  0.0050, -0.0732,
                        0.0202,  0.1260],
                      [-0.0207,  0.2465, -0.0502,  0.1314, -0.1049, -0.2610, -0.2597,  0.0308,
                       -0.2391,  0.2963],
                      [ 0.0236,  0.1121,  0.1393,  0.2172, -0.0822,  0.0677, -0.2918,  0.1014,
                        0.2830,  0.2161]])),
             ('fc1.bias',
              tensor([-0.1148,  0.0394,  0.1593, -0.1153,  0.2548, -0.2897, -0.0852,  0.0883,
                       0.0784, -0.1738])),
             ('output_dict.c.weight',
              tensor([[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100],
                      [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100]])),
             ('output_dict.c.bias',
              tensor([0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                      0.0100])),
             ('output_fc.c.weight',
              tensor([[0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
                       0.0100]])),
             ('output_fc.c.bias', tensor([0.0100]))])

 

model2.load_state_dict(model.state_dict(), strict=False)

 

_IncompatibleKeys(missing_keys=[], unexpected_keys=['output_dict.a.weight', 'output_dict.a.bias', 'output_dict.b.weight', 'output_dict.b.bias', 'output_fc.a.weight', 'output_fc.a.bias', 'output_fc.b.weight', 'output_fc.b.bias'])

 

그러면 위와 같이 미스 매치된 키들(unexpected_keys)이 나오고, 로드 후 model2의 state_dict은 아래와 같다.

 

OrderedDict([('fc1.weight',
              tensor([[ 0.0017, -0.1278, -0.0266, -0.0183, -0.0126,  0.2083,  0.1939,  0.1057,
                       -0.1191,  0.2231],
                      [ 0.0284, -0.2354,  0.0346,  0.0863,  0.0670,  0.2026,  0.2626, -0.2846,
                        0.1729, -0.1541],
                      [-0.2734, -0.1182, -0.0646, -0.1695, -0.0969,  0.1581,  0.2259,  0.2693,
                        0.1707,  0.2501],
                      [ 0.0553,  0.1460, -0.0346, -0.3149, -0.1398, -0.1654, -0.2747, -0.1863,
                       -0.2407, -0.1953],
                      [-0.2127, -0.0775, -0.2025, -0.2823,  0.1191,  0.1548, -0.0751, -0.1676,
                        0.1808, -0.2031],
                      [-0.0907, -0.2605,  0.0703, -0.1875,  0.0004, -0.1114,  0.1550,  0.2570,
                       -0.2923,  0.1357],
                      [-0.0060,  0.1363,  0.1011,  0.0492,  0.2389, -0.2457,  0.2297,  0.1049,
                       -0.0022,  0.0798],
                      [-0.2751,  0.0437,  0.0444, -0.1880,  0.2614, -0.0529,  0.1194,  0.1280,
                       -0.0151, -0.2442],
                      [-0.2918, -0.1730, -0.1653,  0.0335,  0.1010, -0.0237,  0.2398,  0.0183,
                       -0.2276, -0.3015],
                      [ 0.1253, -0.1207,  0.2565,  0.2333, -0.1534, -0.0165,  0.1279,  0.1267,
                       -0.2052, -0.1208]])),
             ('fc1.bias',
              tensor([ 0.2301, -0.2335, -0.3057,  0.1267, -0.1215, -0.0412, -0.0494,  0.2126,
                      -0.1698,  0.1222])),
             ('output_dict.c.weight',
              tensor([[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100],
                      [2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100]])),
             ('output_dict.c.bias',
              tensor([2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                      2.0100])),
             ('output_fc.c.weight',
              tensor([[2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100, 2.0100,
                       2.0100]])),
             ('output_fc.c.bias', tensor([2.0100]))])

 

model.output_dict['c'].weight.data.numpy()
array([[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01]],
      dtype=float32)
model2.output_dict['c'].weight.data.numpy()
array([[2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01],
       [2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01, 2.01]],
      dtype=float32)

 

결론

결론적으로 ModuleDict을 사용하면 잘 맵핑되는 것을 확인할 수 있다. 

사실 이런 모델 구조를 만들 일이 없을 수도 있지만, 이런 식으로 전체에 대한 부분을 가져올 수 있다는 것이고, 중간에 학습이 잘 안되는 모델에 대해서 집중적으로도 학습 시키고 또 바꿔주고.. 머 이런 생각이 들었다...

 

 

 

728x90