[Pytorch] torch에서 모델 summary 확인하는 방법

2020. 8. 25. 23:30분석 Python/Pytorch

728x90

pytorch에서  keras처럼 summary를 정리해주는 함수가 있어서 공유한다.

찾다 보면 좋은 툴이 많은 것 같다(굳굳)

Keras처럼 파라미터 개수랑 용량을 제공해준다!

 

import torch
from torch import nn
from torchsummary import summary as summary_
from torch.nn import functional as F


class MnistModel(nn.Module):
    def __init__(self):
        super(MnistModel, self).__init__()
        # input is 28x28
        # padding=2 for same padding
        self.conv1 = nn.Conv2d(1, 32, 5, padding=2)
        # feature map size is 14*14 by pooling
        # padding=2 for same padding
        self.conv2 = nn.Conv2d(32, 64, 5, padding=2)
        # feature map size is 7*7 by pooling
        self.fc1 = nn.Linear(64*7*7, 1024)
        self.fc2 = nn.Linear(1024, 10)
        
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 64*7*7)   # reshape Variable
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)
    
model = MnistModel()

summary_(model,(1,28,28),batch_size=10)

 

 

https://tensor flow.blog/2017/01/26/pytorch-mnist-example/

 

PyTorch MNIST Example

파이토치(PyTorch)로 텐서플로우 튜토리얼에 있는 MNIST 예제를 재현해 보았습니다. 이 코드는 파이토치의 MNIST 예제를 참고했으며 주피터 노트북으로 작성되어 깃허브에 올려져 있습니다. 당연하��

tensorflow.blog

 

728x90