본문 바로가기
취미 낙지/Python

[딥러닝 / Pytorch / 이미지 분석] GAN 기초

by 대머리낙지 2023. 6. 27.
반응형

파이토치-로고

안녕하세요 대머리 낙지입니다.

 

오늘은 GAN에 대해서 말씀드리겠습니다.

 

GAN이란 방법은 이미지 생성 or 새로운 특정한 형태의 데이터의 생성에 쓰입니다. GAN 안에는 2가지 모델이 동시에 존재합니다.

 

1. Discriminator

 - Discriminator는 Generator에서 생성된 결과물을 실제 데이터와 비교해서 진짜/가짜를 판별하는데, Generator에서 생성된 결과물을 가짜로 판별할 수 있도록 학습됩니다.

 

2. Generator

 - Generator는 말 그대로 새로운 데이터를 생성하는데, Discriminator를 속일 수 있도록 학습됩니다.

 

MNIST-손글씨
MNIST 데이터셋

직접 Generator와 Discriminator를 모델링하고 MNIST(손글씨) 데이터셋을 활용해서 GAN을 학습시켜 보겠습니다.

사용할 모듈을 불러옵니다.

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as T
import numpy as np
from torchvision import datasets
from torch.utils.data import DataLoader
import os
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

 

MNIST 데이터를 받고 확인합니다.

# load datasets
batch_size = 32
transform = T.Compose([
    T.ToTensor(),
])
trainset = datasets.MNIST(root='../data/', train=True, transform=transform, download=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)

img = trainset[0][0].numpy()
img_plot = img.transpose(1,2,0)
plt.imshow(img_plot, 'gray')
plt.show()

 

MNIST-5

Discriminator를 생성합니다. 간단한 CNN 모델을 만들었습니다.

# Descriminator
class Discriminator(nn.Module):
    def __init__(self,):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 16, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(16)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 1)

    def forward(self, x):
        x = F.leaky_relu(self.bn1(self.conv1(x)))
        x = F.leaky_relu(self.bn2(self.conv2(x)))
        x = self.avgpool(x)
        x = x.squeeze()
        x = F.sigmoid(self.fc(x))
        return x
    
model = Discriminator()
x = torch.randn([4, 1, 28, 28])
model(x).shape

 

Generator를 생성합니다. 100개의 random noise로부터 이미지 형태를 만드는 구조입니다.

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.gen = nn.Sequential(
        nn.Linear(100, 128),
        nn.BatchNorm1d(128),
        nn.LeakyReLU(),
        nn.Linear(128, 256),
        nn.BatchNorm1d(256),
        nn.LeakyReLU(),
        nn.Linear(256, 512),
        nn.BatchNorm1d(512),
        nn.LeakyReLU(),
        nn.Linear(512, 1024),
        nn.BatchNorm1d(1024),
        nn.LeakyReLU(),
        nn.Linear(1024, 28*28),
        )

    def forward(self, x):
        x = self.gen(x)
        x = x.reshape(x.shape[0], 1, 28, 28)
        x = F.sigmoid(x)
        return x

model = Generator()
x = torch.randn([4, 100])
model(x).shape

 

반응형

 

학습을 진행합니다. Discriminator의 경우 실제 이미지가 들어간 경우와, 생성된 이미지가 들어간 경우 2가지에 대해 loss를 계산하고 합칩니다. Generator의 경우 생성된 이미지가 Discriminator에 들어가서 실제 이미지로 판별되도록 loss를 계산합니다.

epoch = 50

D = Discriminator().to(DEVICE)
G = Generator().to(DEVICE)
criterion = nn.BCELoss()

opt_D = torch.optim.Adam(D.parameters(), lr=0.001)
opt_G = torch.optim.Adam(G.parameters(), lr=0.001)

real = torch.ones((batch_size, 1)).to(DEVICE) # 정답지
fake = torch.zeros((batch_size, 1)).to(DEVICE) # 오답지

for ep in range(epoch):

    g_loss = 0
    d_loss = 0

    for img, _ in trainloader:
        img = img.to(DEVICE)
        
        #=== Discriminator ===
        # real part
        D_x_out = D(img)
        D_x_loss = criterion(D_x_out, real)

        # fake part
        z = torch.randn((batch_size, 100)).to(DEVICE)
        D_z_out = D(G(z))
        D_z_loss = criterion(D_z_out, fake)
        D_loss = D_x_loss + D_z_loss

        d_loss += D_loss.item() # for monitoring

        opt_D.zero_grad()
        D_loss.backward()
        opt_D.step()

        #=== Generator ===
        z_g = torch.randn((batch_size, 100)).to(DEVICE)
        D_g_out = D(G(z_g))
        G_loss = criterion(D_g_out, real)

        g_loss += G_loss.item() # for monitoring

        opt_G.zero_grad()
        G_loss.backward()
        opt_G.step()
    print('{}: D_loss:{:.4f}, G_loss:{:.4f}'.format(ep, d_loss, g_loss))

D_loss와 G_loss 모두 동시에 줄어드는 것을 볼 수 있습니다.

 

이번에는 Generator를 이용해 손글씨 이미지를 생성해 보겠습니다.

import matplotlib.pyplot as plt
z_g = torch.randn((4, 100)).to(DEVICE)
G.eval()
out = G(z_g)
out = out.to('cpu')
out = out.detach().numpy()
out = out.transpose(0,2,3,1)

plt.figure()
plt.subplot(121)
plt.imshow(out[0], 'gray')
plt.subplot(122)
plt.imshow(out[1], 'gray')
plt.show()

실험결과

 

7도 아니고 6도 아니고 8도 아닌 요상한 글씨가 나왔습니다. 이런 이유는 모델의 튜닝과 학습이 적절히 진행되지 않았기 때문입니다. 모델을 교체하시거나 epoch을 더 늘려보시는 방법 등으로 학습을 최적화하시면 좀 더 좋은 결과물을 얻을 수 있습니다.

감사합니다.

 

구독과 좋아요는 큰 힘이됩니다 : )

반응형