본문 바로가기
취미 낙지/Python

[딥러닝 / Pytorch / 이미지 분석] Classification 예제 정리3 - ResNet

by 대머리낙지 2023. 6. 26.
반응형

파이토치-로고

안녕하세요 대머리 낙지입니다.

 

VGG에 이어서 ResNet을 Pytorch로 직접 구현해 보는 시간을 갖겠습니다.

ResNet이란 Residual Block의 개념이 처음 도입된 network로써 등장 당시에 아주 큰 성능 향상과 함께 큰 파문을 일으켰습니다.

현재 사용되는 대부분의 Network들도 Residual Block이라는 개념이 들어가 있을 정도로 말입니다.

 

Residual Block이란?

아래 그림과 같이 단순하게 Block의 input이 conv layer를 지난 결과와 합쳐지는 것입니다.

이 과정을 통해 input의 conv layer를 반복적으로 거치며 처음의 특징이 사라지는 현상을 막을 수 있게 되었습니다.

즉, Network를 더 깊게 구성할 수 있게 되었습니다.

ResNet-구조
출처: ResNet paper

 

 

pytorch로 Residual Block을 구현하고, 그것으로 ResNet34를 구성해 보겠습니다.

ResNet34의 구조는 아래와 같습니다.

ResNet-구조
출처: ResNet paper

 

먼저 Residual Block을 구성합니다.

class ResBlock(nn.Module):
    def __init__(self, ch_in, ch_out, down_sample=False):
        super(ResBlock, self).__init__()

        self.ch_in = ch_in
        self.ch_out = ch_out
        self.down_sample = down_sample

        self.conv1 = nn.Conv2d(self.ch_in, self.ch_out, kernel_size=3, stride=1, padding=1, bias=False)
        if down_sample:
            self.conv1 = nn.Conv2d(self.ch_in, self.ch_out, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.ch_out)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(self.ch_out, self.ch_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.ch_out)
        self.relu2 = nn.ReLU()

        # 1x1 conv for spatial size
        self.conv1x1 = nn.Conv2d(self.ch_in, self.ch_out, kernel_size=1, stride=1, bias=False)
        if down_sample:
            self.conv1x1 = nn.Conv2d(self.ch_in, self.ch_out, kernel_size=1, stride=2, bias=False)

        self.relu3 = nn.ReLU()

    def forward(self, x):
        short_cut = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = x + self.conv1x1(short_cut)
        
        return self.relu3(x)


model = ResBlock(8, 16, down_sample=True)
x = torch.randn([16, 8, 256, 256])
out = model(x)

print('residual block test:')
print(x.shape)
print(out.shape)

여기서 특징은 stride=2로 channel down sampling이 될 때 1x1 conv.로 다시 input의 channel과 동일하게 맞춰주는 것입니다.

 

반응형

 

다음으로는 ResNet 34를 구현해 보겠습니다.

이해가 쉽도록 위 ResNet34 Structure 이미지에서 각 블록의 색상에 변수 이름을 매칭시켰습니다.

class ResNet(nn.Module):
    def __init__(self, n_class=1000):
        super(ResNet, self).__init__()
        self.n_class = n_class

        self.conv0 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False)

        self.resblock1 = ResBlock(64, 64, down_sample=False)
        self.resblock2 = ResBlock(64, 64, down_sample=False)
        self.resblock3 = ResBlock(64, 64, down_sample=False)

        self.purple = nn.Sequential(
            ResBlock(64, 64, down_sample=False),
            ResBlock(64, 64, down_sample=False),
            ResBlock(64, 64, down_sample=False),
        )

        self.green = nn.Sequential(
            ResBlock(64, 128, down_sample=True),
            ResBlock(128, 128, down_sample=False),
            ResBlock(128, 128, down_sample=False),
            ResBlock(128, 128, down_sample=False),
        )

        self.red = nn.Sequential(
            ResBlock(128, 256, down_sample=True),
            ResBlock(256, 256, down_sample=False),
            ResBlock(256, 256, down_sample=False),
            ResBlock(256, 256, down_sample=False),
            ResBlock(256, 256, down_sample=False),
        )


        self.blue = nn.Sequential(
            ResBlock(256, 512, down_sample=True),
            ResBlock(512, 512, down_sample=False),
            ResBlock(512, 512, down_sample=False),
        )

        self.fc = nn.Sequential(
            nn.Linear(512, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048, self.n_class),
        )
    


    def forward(self, x):
        x = self.conv0(x)

        x = self.purple(x)
        x = self.green(x)
        x = self.red(x)
        x = self.blue(x)
        x = F.adaptive_avg_pool2d(x, 1).squeeze()

        x = self.fc(x)


        out = x
        return out
    

model = ResNet(n_class=1000)
x = torch.randn([16, 3, 128, 128])
out = model(x)

print('residual block test:')
print(x.shape)
print(out.shape)

위 코드로 batch size 16의 128x128 크기 color 이미지에 대해 확인해 볼 수 있었습니다.

 

감사합니다.

 

학습 코드는 이전 글에서 모델과 input 수정 부탁드립니다.

[취미생활/Python] - [딥러닝 / Pytorch / 이미지 분석] Classification 예제 정리

 

[딥러닝 / Pytorch / 이미지 분석] Classification 예제 정리

안녕하세요 대머리 낙지입니다. Pytorch를 활용한 이미지 분류(Classification) 문제를 다루는 전반적인 flow를 기록해보려 합니다. Google colab환경에서 CIFAR10 이미지 데이터를 활용해 간단한 CNN을 돌려

nakzi-lab.tistory.com

 

반응형