This repository has been archived by the owner on Oct 19, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 244
/
Copy pathconditional_gan.py
168 lines (134 loc) · 5.04 KB
/
conditional_gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#!/usr/bin/env python
# coding: utf-8
# # cGAN으로 생성 제어하기
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
# 하이퍼파라미터
EPOCHS = 300
BATCH_SIZE = 100
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
print("Using Device:", DEVICE)
# Fashion MNIST 데이터셋
trainset = datasets.FashionMNIST(
'./.data',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
)
train_loader = torch.utils.data.DataLoader(
dataset = trainset,
batch_size = BATCH_SIZE,
shuffle = True
)
# 생성자 (Generator)
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.embed = nn.Embedding(10, 10)
self.model = nn.Sequential(
nn.Linear(110, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(1024, 784),
nn.Tanh()
)
def forward(self, z, labels):
c = self.embed(labels)
x = torch.cat([z, c], 1)
return self.model(x)
# 판별자 (Discriminator)
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.embed = nn.Embedding(10, 10)
self.model = nn.Sequential(
nn.Linear(794, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x, labels):
c = self.embed(labels)
x = torch.cat([x, c], 1)
return self.model(x)
# 모델 인스턴스를 만들고 모델의 가중치를 지정한 장치로 보내기
D = Discriminator().to(DEVICE)
G = Generator().to(DEVICE)
# 이진 교차 엔트로피 함수와
# 생성자와 판별자를 최적화할 Adam 모듈
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr =0.0002)
g_optimizer = optim.Adam(G.parameters(), lr =0.0002)
total_step = len(train_loader)
for epoch in range(EPOCHS):
for i, (images, labels) in enumerate(train_loader):
images = images.reshape(BATCH_SIZE, -1).to(DEVICE)
# '진짜'와 '가짜' 레이블 생성
real_labels = torch.ones(BATCH_SIZE, 1).to(DEVICE)
fake_labels = torch.zeros(BATCH_SIZE, 1).to(DEVICE)
# 판별자가 진짜 이미지를 진짜로 인식하는 오차 계산 (데이터셋 레이블 입력)
labels = labels.to(DEVICE)
outputs = D(images, labels)
d_loss_real = criterion(outputs, real_labels)
real_score = outputs
# 무작위 텐서와 무작위 레이블을 생성자에 입력해 가짜 이미지 생성
z = torch.randn(BATCH_SIZE, 100).to(DEVICE)
g_label = torch.randint(0, 10, (BATCH_SIZE,)).to(DEVICE)
fake_images = G(z, g_label)
# 판별자가 가짜 이미지를 가짜로 인식하는 오차 계산
outputs = D(fake_images, g_label)
d_loss_fake = criterion(outputs, fake_labels)
fake_score = outputs
# 진짜와 가짜 이미지를 갖고 낸 오차를 더해서 판별자의 오차 계산
d_loss = d_loss_real + d_loss_fake
# 역전파 알고리즘으로 판별자 모델의 학습을 진행
d_optimizer.zero_grad()
g_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
# 생성자가 판별자를 속였는지에 대한 오차 계산(무작위 레이블 입력)
fake_images = G(z, g_label)
outputs = D(fake_images, g_label)
g_loss = criterion(outputs, real_labels)
# 역전파 알고리즘으로 생성자 모델의 학습을 진행
d_optimizer.zero_grad()
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
print('이폭 [{}/{}] d_loss:{:.4f} g_loss: {:.4f} D(x):{:.2f} D(G(z)):{:.2f}'
.format(epoch,
EPOCHS,
d_loss.item(),
g_loss.item(),
real_score.mean().item(),
fake_score.mean().item()))
# 만들고 싶은 아이템 생성하고 시각화하기
item_number = 9 # 아이템 번호
z = torch.randn(1, 100).to(DEVICE) # 배치 크기 1
g_label = torch.full((1,), item_number, dtype=torch.long).to(DEVICE)
sample_images = G(z, g_label)
sample_images_img = np.reshape(sample_images.data.cpu().numpy()
[0],(28, 28))
plt.imshow(sample_images_img, cmap = 'gray')
plt.show()