SJ_Koding

Pytorch, 이미지 분류 코드 자세히 이해하기 (4편) - ResNet9 모델 본문

PyTorch Code/Pytorch

Pytorch, 이미지 분류 코드 자세히 이해하기 (4편) - ResNet9 모델

성지코딩 2023. 11. 8. 17:02

2023.11.08 - [Deep Learning/Pytorch] - Pytorch, 이미지 분류 코드 자세히 이해하기 (3편) - AutoAugment

 

Pytorch, 이미지 분류 코드 자세히 이해하기 (3편) - AutoAugment

2023.11.07 - [Deep Learning/Pytorch] - Pytorch, 이미지 분류 코드 자세히 이해하기 (2편) - Dataset Pytorch, 이미지 분류 코드 자세히 이해하기 (2편) - Dataset 2023.11.07 - [Deep Learning/Pytorch] - Pytorch, 이미지 분류 코드

sjkoding.tistory.com

*이전 글들과 이어지는 내용입니다.

 

이번 대회에서는 ResNet 9 모델을 사용했습니다. 기존에는 18, 50, 152가 익숙한 숫자이신 분들도 많을텐데 이미지 크기가 32 by 32이므로 기존 ResNet9 계열에서 레이어를 더 쌓은 ResNet11 모델을 사용했습니다.

ResNet9 출처: https://www.researchgate.net/figure/ResNet-9-architecture-A-convolutional-neural-net-with-9-layers-and-skip-connections_fig1_363585139

 

def conv_block(in_channels, out_channels, pool=False):
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(inplace=True)]
    if pool: layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)

class ResNet11(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True)
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))

        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True)
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))
        self.conv5 = conv_block(512, 1024, pool=True)
        self.res3 = nn.Sequential(conv_block(1024, 1024), conv_block(1024, 1024))

        self.classifier = nn.Sequential(nn.MaxPool2d(2), # 1024 x 1 x 1
                                        nn.Flatten(), # 1024
                                        nn.Linear(1024, num_classes)) # 1024 -> 30(클래스가 30개이므로)

    def forward(self, xb):
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out # skip connection!
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out # skip connection!
        out = self.conv5(out)
        out = self.res3(out) + out # skip connection!
        out = self.classifier(out)
        return out

주석으로 skip connection을 표시해놨다. 이는 입력과 출력의 잔차(residual)를 학습할 수 있도록  하는데, 이는 신경망의 기울기 소실(Vanishing gradient) 문제를 완화시키는 기능을 합니다. (마치 LSTM의 정보 전달 매커니즘과 비슷합니다.)

 

그리고 모델을 불러와 GPU 장치에 담아줍니다.

model = ResNet11(3, 30) # 컬러 이미지이므로 3채널, 클래스 30개
model = model.to('cuda')

 

 

ResNet에 대한 자세한 이론적 내용은 아래 블로그를 참고해주세요.

https://velog.io/@lighthouse97/ResNet%EC%9D%98-%EC%9D%B4%ED%95%B4#54-bottleneck

 

ResNet의 이해

참고1 참고2 참고3 1. 동기(Motivation) 2015년 ILSVRC(ImageNet Large Scale Visual Recognition Challenge)에서 우승을 차지한 ResNet에 대해서 소개하려고 한다. ResNet은 마이크로소프트에서 개발한 알

velog.io