#!/usr/bin/env python
# coding: utf-8
# # Controlling Generation with a Conditional GAN
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
# Hyperparameters
EPOCHS = 300
BATCH_SIZE = 100
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
print("Using Device:", DEVICE)
# Fashion MNIST dataset, normalized to [-1, 1] to match the generator's Tanh output
trainset = datasets.FashionMNIST(
    './.data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ]))
train_loader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=BATCH_SIZE,
    shuffle=True)
def one_hot_embedding(labels, num_classes):
    """Convert integer class labels to one-hot vectors."""
    y = torch.eye(num_classes)
    return y[labels]
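# For illustration: the function indexes rows of an identity matrix, so a
# batch of integer labels becomes a batch of one-hot rows, e.g.
#   one_hot_embedding(torch.tensor([2, 0]), 4)
#   -> tensor([[0., 0., 1., 0.],
#              [1., 0., 0., 0.]])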
# Discriminator: maps a flattened 28x28 image (784 values) to a real/fake
# probability. (Note that in this script only the generator receives the
# class label; the discriminator sees the image alone.)
D = nn.Sequential(
    nn.Linear(784, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 1),
    nn.Sigmoid())
# Generator: maps a 64-dim noise vector concatenated with a 10-dim one-hot
# class label to a flattened 28x28 image in [-1, 1]
G = nn.Sequential(
    nn.Linear(64 + 10, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 784),
    nn.Tanh())
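# A quick shape check (illustrative, not part of the original script): a batch
# of noise concatenated with one-hot labels comes out as flattened images.
#   z = torch.randn(BATCH_SIZE, 64)
#   c = one_hot_embedding(torch.zeros(BATCH_SIZE, dtype=torch.long), 10)
#   G(torch.cat([z, c], 1)).shape  # -> torch.Size([100, 784]), values in (-1, 1)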
# Move the model weights to the chosen device
D = D.to(DEVICE)
G = G.to(DEVICE)
# Binary cross entropy (BCE) loss and Adam optimizers
# for the generator and the discriminator
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = optim.Adam(G.parameters(), lr=0.0002)
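# For reference, BCELoss on a sigmoid output p with target y computes
#   loss = -[ y * log(p) + (1 - y) * log(1 - p) ]
# so in the loop below, d_loss_real pushes D(real image) toward 1 and
# d_loss_fake pushes D(fake image) toward 0.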
total_step = len(train_loader)
for epoch in range(EPOCHS):
    for i, (images, label) in enumerate(train_loader):
        images = images.reshape(BATCH_SIZE, -1).to(DEVICE)

        # Target labels: 1 for real images, 0 for fakes
        real_labels = torch.ones(BATCH_SIZE, 1).to(DEVICE)
        fake_labels = torch.zeros(BATCH_SIZE, 1).to(DEVICE)

        # Discriminator loss on real images
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # Generate fake images conditioned on the batch's class labels
        class_label = one_hot_embedding(label, 10).to(DEVICE)
        z = torch.randn(BATCH_SIZE, 64).to(DEVICE)
        generator_input = torch.cat([z, class_label], 1)
        fake_images = G(generator_input)

        # Discriminator loss on fake images
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # Backprop and optimize the discriminator
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Train the generator: regenerate fake images (the first graph was
        # freed by d_loss.backward()) and use "real" targets so G learns to fool D
        fake_images = G(generator_input)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)

        # Backprop and optimize the generator; both gradient buffers are
        # zeroed because d_loss.backward() above also accumulated gradients in G
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
          .format(epoch,
                  EPOCHS,
                  d_loss.item(),
                  g_loss.item(),
                  real_score.mean().item(),
                  fake_score.mean().item()))
# Sample 100 images of class 4 ("Coat" in FashionMNIST) from the trained generator
for i in range(100):
    label = torch.tensor([4])
    class_label = one_hot_embedding(label, 10).to(DEVICE)
    z = torch.randn(1, 64).to(DEVICE)
    generator_input = torch.cat([z, class_label], 1)
    fake_images = G(generator_input)
    fake_images = np.reshape(fake_images.detach().cpu().numpy()[0], (28, 28))
    plt.imshow(fake_images, cmap='gray')
    plt.show()
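# save_image is imported above but never used; as a minimal sketch (the grid
# layout and the file name 'cgan_samples.png' are illustrative choices, not
# part of the original script), a single 10x10 grid with ten samples per class
# could be written out like this:
labels = torch.arange(10).repeat(10)                  # 0..9, repeated ten times
class_labels = one_hot_embedding(labels, 10).to(DEVICE)
z = torch.randn(100, 64).to(DEVICE)
samples = G(torch.cat([z, class_labels], 1)).reshape(-1, 1, 28, 28)
save_image(samples, 'cgan_samples.png', nrow=10, normalize=True)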