안녕하세요 대머리 낙지입니다.
오늘은 GAN에 대해서 말씀드리겠습니다.
GAN이란 방법은 이미지 생성 or 새로운 특정한 형태의 데이터의 생성에 쓰입니다. GAN 안에는 2가지 모델이 동시에 존재합니다.
1. Discriminator
- Discriminator는 Generator에서 생성된 결과물을 실제 데이터와 비교해서 진짜/가짜를 판별하는데, Generator에서 생성된 결과물을 가짜로 판별할 수 있도록 학습됩니다.
2. Generator
- Generator는 말 그대로 새로운 데이터를 생성하는데, Discriminator를 속일 수 있도록 학습됩니다.
직접 Generator와 Discriminator를 모델링하고 MNIST(손글씨) 데이터셋을 활용해서 GAN을 학습시켜 보겠습니다.
사용할 모듈을 불러옵니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as T
import numpy as np
from torchvision import datasets
from torch.utils.data import DataLoader
import os
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MNIST 데이터를 받고 확인합니다.
# load datasets
batch_size = 32
transform = T.Compose([
T.ToTensor(),
])
trainset = datasets.MNIST(root='../data/', train=True, transform=transform, download=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)
img = trainset[0][0].numpy()
img_plot = img.transpose(1,2,0)
plt.imshow(img_plot, 'gray')
plt.show()
Discriminator를 생성합니다. 간단한 CNN 모델을 만들었습니다.
# Descriminator
class Discriminator(nn.Module):
def __init__(self,):
super(Discriminator, self).__init__()
self.conv1 = nn.Conv2d(1, 8, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(8)
self.conv2 = nn.Conv2d(8, 16, 3, 1, 1)
self.bn2 = nn.BatchNorm2d(16)
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(16, 1)
def forward(self, x):
x = F.leaky_relu(self.bn1(self.conv1(x)))
x = F.leaky_relu(self.bn2(self.conv2(x)))
x = self.avgpool(x)
x = x.squeeze()
x = F.sigmoid(self.fc(x))
return x
model = Discriminator()
x = torch.randn([4, 1, 28, 28])
model(x).shape
Generator를 생성합니다. 100개의 random noise로부터 이미지 형태를 만드는 구조입니다.
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.gen = nn.Sequential(
nn.Linear(100, 128),
nn.BatchNorm1d(128),
nn.LeakyReLU(),
nn.Linear(128, 256),
nn.BatchNorm1d(256),
nn.LeakyReLU(),
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.LeakyReLU(),
nn.Linear(512, 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(),
nn.Linear(1024, 28*28),
)
def forward(self, x):
x = self.gen(x)
x = x.reshape(x.shape[0], 1, 28, 28)
x = F.sigmoid(x)
return x
model = Generator()
x = torch.randn([4, 100])
model(x).shape
학습을 진행합니다. Discriminator의 경우 실제 이미지가 들어간 경우와, 생성된 이미지가 들어간 경우 2가지에 대해 loss를 계산하고 합칩니다. Generator의 경우 생성된 이미지가 Discriminator에 들어가서 실제 이미지로 판별되도록 loss를 계산합니다.
epoch = 50
D = Discriminator().to(DEVICE)
G = Generator().to(DEVICE)
criterion = nn.BCELoss()
opt_D = torch.optim.Adam(D.parameters(), lr=0.001)
opt_G = torch.optim.Adam(G.parameters(), lr=0.001)
real = torch.ones((batch_size, 1)).to(DEVICE) # 정답지
fake = torch.zeros((batch_size, 1)).to(DEVICE) # 오답지
for ep in range(epoch):
g_loss = 0
d_loss = 0
for img, _ in trainloader:
img = img.to(DEVICE)
#=== Discriminator ===
# real part
D_x_out = D(img)
D_x_loss = criterion(D_x_out, real)
# fake part
z = torch.randn((batch_size, 100)).to(DEVICE)
D_z_out = D(G(z))
D_z_loss = criterion(D_z_out, fake)
D_loss = D_x_loss + D_z_loss
d_loss += D_loss.item() # for monitoring
opt_D.zero_grad()
D_loss.backward()
opt_D.step()
#=== Generator ===
z_g = torch.randn((batch_size, 100)).to(DEVICE)
D_g_out = D(G(z_g))
G_loss = criterion(D_g_out, real)
g_loss += G_loss.item() # for monitoring
opt_G.zero_grad()
G_loss.backward()
opt_G.step()
print('{}: D_loss:{:.4f}, G_loss:{:.4f}'.format(ep, d_loss, g_loss))
D_loss와 G_loss 모두 동시에 줄어드는 것을 볼 수 있습니다.
이번에는 Generator를 이용해 손글씨 이미지를 생성해 보겠습니다.
import matplotlib.pyplot as plt
z_g = torch.randn((4, 100)).to(DEVICE)
G.eval()
out = G(z_g)
out = out.to('cpu')
out = out.detach().numpy()
out = out.transpose(0,2,3,1)
plt.figure()
plt.subplot(121)
plt.imshow(out[0], 'gray')
plt.subplot(122)
plt.imshow(out[1], 'gray')
plt.show()
7도 아니고 6도 아니고 8도 아닌 요상한 글씨가 나왔습니다. 이런 이유는 모델의 튜닝과 학습이 적절히 진행되지 않았기 때문입니다. 모델을 교체하시거나 epoch을 더 늘려보시는 방법 등으로 학습을 최적화하시면 좀 더 좋은 결과물을 얻을 수 있습니다.
감사합니다.
구독과 좋아요는 큰 힘이됩니다 : )
'취미 낙지 > Python' 카테고리의 다른 글
[딥러닝 / Pytorch / 이미지 분석] Classification 예제 정리3 - ResNet (0) | 2023.06.26 |
---|---|
[딥러닝 / Pytorch / 이미지 분석] Classification 예제 정리2 - VGGNet (0) | 2023.06.26 |
[딥러닝 / Pytorch / 이미지 분석] Classification 예제 정리1 - 기본 사용법 (0) | 2023.06.25 |