From 8855eda8dd63bbecd39c6367057ba1767f95ec89 Mon Sep 17 00:00:00 2001
From: keon
Date: Sun, 14 Jul 2019 00:34:08 -0700
Subject: [PATCH] recompile python

---
 .../02-sequence-to-sequence-lstm.ipynb" | 254 ------------------
 .../02-sequence-to-sequence-lstm.py"    | 105 --------
 .../02-sequence-to-sequence.py"         |  16 +-
 .../03-seq2seq_gru.py"                  | 112 --------
 .../assets/encoder_decoder.png"         | Bin 7483 -> 0 bytes
 .../assets/pics"                        |   1 -
 .../01-fgsm-attack.py"                  |  14 +-
 .../02-conditional-gan.py"              | 109 +++-----
 .../01-cartpole-dqn.py"                 | 148 +---------
 README.md                               |   1 -
 10 files changed, 50 insertions(+), 710 deletions(-)
 delete mode 100644 "07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/02-sequence-to-sequence-lstm.ipynb"
 delete mode 100644 "07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/02-sequence-to-sequence-lstm.py"
 delete mode 100644 "07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/03-seq2seq_gru.py"
 delete mode 100644 "07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/assets/encoder_decoder.png"
 delete mode 100644 "07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/assets/pics"

diff --git "a/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/02-sequence-to-sequence-lstm.ipynb" "b/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/02-sequence-to-sequence-lstm.ipynb"
deleted file mode 100644
index bed9c07..0000000
--- "a/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/02-sequence-to-sequence-lstm.ipynb"
+++ /dev/null
@@ -1,254 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# Seq2Seq Machine Translation (LSTM)\n",
-    "\n",
-    "In this project we deliberately simplify the Seq2Seq model quite a bit.\n",
-    "Instead of a large model that translates sentences from one language into another,\n",
-    "we will implement a mini Seq2Seq model that translates the English character string (\"hello\") into the Spanish character string (\"hola\")."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import torch\n",
-    "import torch.nn as nn\n",
-    "import torch.nn.functional as F\n",
-    "\n",
-    "%matplotlib inline"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "hello ->  [104, 101, 108, 108, 111]\n",
-      "hola ->  [104, 111, 108, 97]\n"
-     ]
-    }
-   ],
-   "source": [
-    "vocab_size = 256  # total number of ASCII codes\n",
-    "x_ = list(map(ord, \"hello\"))  # convert to a list of ASCII codes\n",
-    "y_ = list(map(ord, \"hola\"))   # convert to a list of ASCII codes\n",
-    "print(\"hello -> \", x_)\n",
-    "print(\"hola -> \", y_)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "x = torch.LongTensor(x_)\n",
-    "y = torch.LongTensor(y_)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "class Seq2Seq(nn.Module):\n",
-    "    def __init__(self, vocab_size, hidden_size):\n",
-    "        super(Seq2Seq, self).__init__()\n",
-    "        self.n_layers = 1\n",
-    "        self.hidden_size = hidden_size\n",
-    "        self.embedding = nn.Embedding(vocab_size, hidden_size)\n",
-    "        self.encoder = nn.LSTM(hidden_size, hidden_size)\n",
-    "        self.decoder = nn.LSTM(hidden_size, hidden_size)\n",
-    "        self.project = nn.Linear(hidden_size, vocab_size)\n",
-    "\n",
-    "    def forward(self, inputs, targets):\n",
-    "        # input that goes into the encoder\n",
-    "        initial_state = self._init_state()\n",
-    "        embedding = self.embedding(inputs).unsqueeze(1)\n",
-    "        # embedding = [seq_len, batch_size, embedding_size]\n",
-    "        \n",
-    "        # Encoder\n",
-    "        encoder_output, encoder_state = self.encoder(embedding, initial_state)\n",
-    "        # encoder_output = [seq_len, batch_size, hidden_size]\n",
-    "        # encoder_state: (h, c) tuple, each [n_layers, batch_size, hidden_size]\n",
-    "\n",
-    "        # input that goes into the decoder\n",
-    "        decoder_state = encoder_state\n",
-    "        decoder_input = torch.LongTensor([[0]])\n",
-    "        \n",
-    "        # Decoder\n",
-    "        outputs = []\n",
-    "        for i in range(targets.size()[0]): \n",
-    "            decoder_input = self.embedding(decoder_input)\n",
-    "            decoder_output, decoder_state = self.decoder(decoder_input, decoder_state)\n",
-    "            \n",
-    "            # predict the next character from the decoder output\n",
-    "            projection = self.project(decoder_output.view(1, -1))  # batch x vocab_size\n",
-    "            prediction = F.softmax(projection, dim=1)  # batch x vocab_size\n",
-    "            outputs.append(prediction)\n",
-    "            \n",
-    "            # feed the top prediction back in as the next decoder input\n",
-    "            _, top_i = prediction.data.topk(1)  # 1 x 1\n",
-    "            decoder_input = top_i\n",
-    "\n",
-    "        outputs = torch.stack(outputs).squeeze()\n",
-    "        return outputs\n",
-    "    \n",
-    "    def _init_state(self, batch_size=1):\n",
-    "        weight = next(self.parameters()).data\n",
-    "        return (\n",
-    "            weight.new(self.n_layers, batch_size, self.hidden_size).zero_(),\n",
-    "            weight.new(self.n_layers, batch_size, self.hidden_size).zero_()\n",
-    "        )"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Seq2Seq(\n",
-      "  (embedding): Embedding(256, 16)\n",
-      "  (encoder): LSTM(16, 16)\n",
-      "  (decoder): LSTM(16, 16)\n",
-      "  (project): Linear(in_features=16, out_features=256, bias=True)\n",
-      ")\n"
-     ]
-    }
-   ],
-   "source": [
-    "seq2seq = Seq2Seq(vocab_size, 16)\n",
-    "print(seq2seq)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "criterion = nn.CrossEntropyLoss()\n",
-    "optimizer = torch.optim.Adam(seq2seq.parameters(), lr=1e-3)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n",
-      " iteration:0 loss: 5.5456109046936035\n",
-      "['B', 'J', 'È', 'ä']\n",
-      "\n",
-      " iteration:100 loss: 5.445922374725342\n",
-      "['h', 'o', 'l', 'l']\n",
-      "\n",
-      " iteration:200 loss: 4.915400505065918\n",
-      "['h', 'o', 'l', 'l']\n",
-      "\n",
-      " iteration:300 loss: 4.660857677459717\n",
-      "['h', 'o', 'l', 'a']\n",
-      "\n",
-      " iteration:400 loss: 4.599147319793701\n",
-      "['h', 'o', 'l', 'a']\n",
-      "\n",
-      " iteration:500 loss: 4.580879211425781\n",
-      "['h', 'o', 'l', 'a']\n",
-      "\n",
-      " iteration:600 loss: 4.571601867675781\n",
-      "['h', 'o', 'l', 'a']\n",
-      "\n",
-      " iteration:700 loss: 4.566298484802246\n",
-      "['h', 'o', 'l', 'a']\n",
-      "\n",
-      " iteration:800 loss: 4.563146114349365\n",
-      "['h', 'o', 'l', 'a']\n",
-      "\n",
-      " iteration:900 loss: 4.5610032081604\n",
-      "['h', 'o', 'l', 'a']\n"
-     ]
-    }
-   ],
-   "source": [
-    "log = []\n",
-    "for i in range(1000):\n",
-    "    prediction = seq2seq(x, y)\n",
-    "    loss = criterion(prediction, y)\n",
-    "    optimizer.zero_grad()\n",
-    "    loss.backward()\n",
-    "    optimizer.step()\n",
-    "    loss_val = loss.data\n",
-    "    log.append(loss_val)\n",
-    "    if i % 100 == 0:\n",
-    "        print(\"\\n iteration:%d loss: %s\" % (i, loss_val.item()))\n",
-    "        _, top1 = prediction.data.topk(1, 1)\n",
-    "        print([chr(c) for c in top1.squeeze().numpy().tolist()])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "image/png": "[base64-encoded PNG of the falling cross-entropy training-loss curve omitted]",
-      "text/plain": [
-       "<Figure size 432x288 with 1 Axes>"
-      ]
-     },
-     "metadata": {
-      "needs_background": "light"
-     },
-     "output_type": "display_data"
-    }
-   ],
-   "source": [
-    "import matplotlib.pyplot as plt\n",
-    "plt.plot(log)\n",
-    "plt.ylabel('cross entropy loss')\n",
-    "plt.show()"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.7.0"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "plt.plot(log)\n", - "plt.ylabel('cross entropy loss')\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git "a/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/02-sequence-to-sequence-lstm.py" "b/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/02-sequence-to-sequence-lstm.py" deleted file mode 100644 index 311f13e..0000000 --- "a/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/02-sequence-to-sequence-lstm.py" +++ /dev/null @@ -1,105 +0,0 @@ -#!/usr/bin/env python -# coding: utf-8 - -# # Seq2Seq 기계 번역 (LSTM) -# 이번 프로젝트에선 임의로 Seq2Seq 모델을 아주 간단화 시켰습니다. -# 한 언어로 된 문장을 다른 언어로 된 문장으로 번역하는 덩치가 큰 모델이 아닌 -# 영어 알파벳 문자열("hello")을 스페인어 알파벳 문자열("hola")로 번역하는 Mini Seq2Seq 모델을 같이 구현해 보겠습니다. - -import torch -import torch.nn as nn -import torch.nn.functional as F - -get_ipython().run_line_magic('matplotlib', 'inline') - - -vocab_size = 256 # 총 아스키 코드 개수 -x_ = list(map(ord, "hello")) # 아스키 코드 리스트로 변환 -y_ = list(map(ord, "hola")) # 아스키 코드 리스트로 변환 -print("hello -> ", x_) -print("hola -> ", y_) - - -x = torch.LongTensor(x_) -y = torch.LongTensor(y_) - - -class Seq2Seq(nn.Module): - def __init__(self, vocab_size, hidden_size): - super(Seq2Seq, self).__init__() - self.n_layers = 1 - self.hidden_size = hidden_size - self.embedding = nn.Embedding(vocab_size, hidden_size) - self.encoder = nn.LSTM(hidden_size, hidden_size) - self.decoder = nn.LSTM(hidden_size, hidden_size) - self.project = nn.Linear(hidden_size, vocab_size) - - def forward(self, inputs, targets): - # 인코더에 들어갈 입력 - initial_state = self._init_state() - embedding = self.embedding(inputs).unsqueeze(1) - # embedding = [seq_len, batch_size, embedding_size] - - # 인코더 (Encoder) - encoder_output, encoder_state = self.encoder(embedding, initial_state) - # encoder_output = [seq_len, batch_size, hidden_size] - # encoder_state = [n_layers, seq_len, hidden_size] - - # 디코더에 들어갈 입력 - decoder_state = encoder_state - decoder_input = torch.LongTensor([[0]]) - - # 디코더 (Decoder) - outputs = [] - for i in range(targets.size()[0]): - decoder_input = self.embedding(decoder_input) - decoder_output, decoder_state = self.decoder(decoder_input, decoder_state) - - # 디코더의 출력값으로 다음 글자 예측하기 - projection = self.project(decoder_output.view(1, -1)) # batch x vocab_size - prediction = F.softmax(projection, dim=1) # batch x vocab_size - outputs.append(prediction) - - # 디코더 입력 갱신 - _, top_i = prediction.data.topk(1) # 1 x 1 - decoder_input = top_i - - outputs = torch.stack(outputs).squeeze() - return outputs - - def _init_state(self, batch_size=1): - weight = next(self.parameters()).data - return ( - weight.new(self.n_layers, batch_size, self.hidden_size).zero_(), - weight.new(self.n_layers, batch_size, 
diff --git "a/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/02-sequence-to-sequence.py" "b/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/02-sequence-to-sequence.py"
index 370254c..60c90eb 100644
--- "a/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/02-sequence-to-sequence.py"
+++ "b/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/02-sequence-to-sequence.py"
@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import random
 
 get_ipython().run_line_magic('matplotlib', 'inline')
 
@@ -51,18 +52,15 @@ def forward(self, inputs, targets):
 
         # Decoder
         outputs = []
-        for i in range(targets.size()[0]):
+
+        for i in range(targets.size()[0]):
             decoder_input = self.embedding(decoder_input)
             decoder_output, decoder_state = self.decoder(decoder_input, decoder_state)
+            projection = self.project(decoder_output)
+            outputs.append(projection)
 
-            # predict the next character from the decoder output
-            projection = self.project(decoder_output.view(1, -1))  # batch x vocab_size
-            prediction = F.softmax(projection, dim=1)  # batch x vocab_size
-            outputs.append(prediction)
-
-            # feed the top prediction back in as the next decoder input
-            _, top_i = prediction.data.topk(1)  # 1 x 1
-            decoder_input = top_i
+            # use teacher forcing
+            decoder_input = torch.LongTensor([[targets[i]]])
 
         outputs = torch.stack(outputs).squeeze()
         return outputs
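The hunk above replaces the decoder's greedy self-feedback with teacher forcing, always feeding the ground-truth character back in, and it also adds `import random`, the usual ingredient for mixing the two strategies. A runnable sketch of such a mixed loop, mirroring the tutorial's shapes (hidden size 16, vocabulary 256); the `teacher_forcing_ratio` knob is an assumption, not something defined in the patch:

```python
import random
import torch
import torch.nn as nn

embedding = nn.Embedding(256, 16)   # stand-ins mirroring the tutorial's model
decoder = nn.LSTM(16, 16)
project = nn.Linear(16, 256)

targets = torch.LongTensor([104, 111, 108, 97])  # "hola"
decoder_input = torch.LongTensor([[0]])          # start token, as in the patch
decoder_state = (torch.zeros(1, 1, 16), torch.zeros(1, 1, 16))
teacher_forcing_ratio = 0.5                      # hypothetical knob, not in the patch

outputs = []
for i in range(targets.size(0)):
    decoder_output, decoder_state = decoder(embedding(decoder_input), decoder_state)
    projection = project(decoder_output.view(1, -1))
    outputs.append(projection)
    if random.random() < teacher_forcing_ratio:
        decoder_input = targets[i].view(1, 1)               # teacher forcing: feed the ground truth
    else:
        decoder_input = projection.argmax(1, keepdim=True)  # free running: feed the model's own guess
```

Teacher forcing converges much faster on toy data like this, at the cost of exposure bias at inference time, which is why the two are often mixed.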
diff --git "a/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/03-seq2seq_gru.py" "b/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/03-seq2seq_gru.py"
deleted file mode 100644
index 4ed663d..0000000
--- "a/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/03-seq2seq_gru.py"
+++ /dev/null
@@ -1,112 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-# # Seq2Seq Machine Translation
-# In this project we deliberately simplify the Seq2Seq model quite a bit.
-# Instead of a large model that translates sentences from one language into another,
-# we will implement a mini Seq2Seq model that translates the English character string ("hello") into the Spanish character string ("hola").
-
-import numpy as np
-import torch as th
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.autograd import Variable
-from torch import optim
-
-
-vocab_size = 256  # total number of ASCII codes
-x_ = list(map(ord, "hello"))  # convert to a list of ASCII codes
-y_ = list(map(ord, "hola"))   # convert to a list of ASCII codes
-print("hello -> ", x_)
-print("hola -> ", y_)
-
-
-x = torch.LongTensor(x_)
-y = torch.LongTensor(y_)
-
-
-'''
-A mini GRU model.
-'''
-class Seq2Seq_GRU(nn.Module):
-    def __init__(self, vocab_size, hidden_size):
-        super(Seq2Seq_GRU, self).__init__()
-
-        self.n_layers = 1
-        self.hidden_size = hidden_size
-        self.embedding = nn.Embedding(vocab_size, hidden_size)
-        self.encoder = nn.GRU(hidden_size, hidden_size)
-        self.decoder = nn.GRU(hidden_size * 2, hidden_size)
-        self.project = nn.Linear(hidden_size, vocab_size)
-
-    def forward(self, inputs, targets):
-        # Encoder inputs and states
-        initial_state = self._init_state()
-        embedding = self.embedding(inputs).unsqueeze(1)
-        encoder_output, encoder_state = self.encoder(embedding, initial_state)
-        outputs = []
-
-        decoder_state = encoder_state
-        for i in range(targets.size()[0]):
-            decoder_input = self.embedding(targets)[i].view(1, -1, self.hidden_size)
-            decoder_input = th.cat((decoder_input, encoder_state), 2)
-            decoder_output, decoder_state = self.decoder(decoder_input, decoder_state)
-            projection = self.project(decoder_output)  # .unsqueeze(0))
-            outputs.append(projection)
-
-            # _, top_i = prediction.data.topk(1)
-
-        outputs = th.stack(outputs, 1).squeeze()
-
-        return outputs
-
-    def _init_state(self, batch_size=1):
-        weight = next(self.parameters()).data
-        return Variable(weight.new(self.n_layers, batch_size, self.hidden_size).zero_())
-
-
-model = Seq2Seq_GRU(vocab_size, 16)
-pred = model(x, y)
-
-
-criterion = nn.CrossEntropyLoss()
-optimizer = th.optim.Adam(model.parameters(), lr=1e-3)
-
-
-y_.append(3)
-y_label = Variable(th.LongTensor(y_[1:]))
-
-
-print(y_label.shape)
-print(y_label)
-
-
-log = []
-for i in range(10000):
-    prediction = model(x, y)
-    loss = criterion(prediction, y)
-    optimizer.zero_grad()
-    loss.backward()
-    optimizer.step()
-    loss_val = loss.data
-    log.append(loss_val)
-    if i % 100 == 0:
-        print("%d loss: %s" % (i, loss_val.item()))
-        _, top1 = prediction.data.topk(1, 1)
-        for c in top1.squeeze().numpy().tolist():
-            print(chr(c), end=" ")
-        print()
-
-
-import matplotlib.pyplot as plt
-get_ipython().run_line_magic('matplotlib', 'inline')
-plt.plot(log)
-plt.xlim(0, 150)
-plt.ylim(0, 15)
-plt.ylabel('cross entropy loss')
-plt.show()
-
-
-
diff --git "a/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/assets/encoder_decoder.png" "b/07-\354\210\234\354\260\250\354\240\201\354\235\270_\353\215\260\354\235\264\355\204\260\353\245\274_\354\262\230\353\246\254\355\225\230\353\212\224_RNN/assets/encoder_decoder.png"
deleted file mode 100644
index 1cd7ce9e714e65a47429333ba7d42c891bee6d94..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

[base85-encoded binary patch data for the deleted 7483-byte encoder_decoder.png omitted]

diff --git "a/08-\353\224\245\353\237\254\353\213\235_\355\225\264\355\202\271\355\225\230\352\270\260/01-fgsm-attack.py" "b/08-\353\224\245\353\237\254\353\213\235_\355\225\264\355\202\271\355\225\230\352\270\260/01-fgsm-attack.py"
--- "a/08-\353\224\245\353\237\254\353\213\235_\355\225\264\355\202\271\355\225\230\352\270\260/01-fgsm-attack.py"
+++ "b/08-\353\224\245\353\237\254\353\213\235_\355\225\264\355\202\271\355\225\230\352\270\260/01-fgsm-attack.py"
@@ ... @@
-original_img_view = img_tensor.squeeze(0)  # [1, 3, 244, 244] -> [3, 244, 244]
-original_img_view = original_img_view.transpose(0,2).transpose(0,1).detach().numpy()
+original_img_view = img_tensor.squeeze(0).detach()  # [1, 3, 244, 244] -> [3, 244, 244]
+original_img_view = original_img_view.transpose(0,2).transpose(0,1).numpy()
 
 # visualize the tensor
 plt.imshow(original_img_view)
 
@@ -86,7 +85,7 @@ def fgsm_attack(image, epsilon, gradient):
 # ## Generating the adversarial example
 
 # set the image to require gradients
-img_tensor.requires_grad = True
+img_tensor.requires_grad_(True)
 
 # pass the image through the model
 output = model(img_tensor)
@@ -121,7 +120,8 @@ def fgsm_attack(image, epsilon, gradient):
 
 # convert to a NumPy array for visualization
-perturbed_data_view = perturbed_data.squeeze(0).transpose(0,2).transpose(0,1).detach().numpy()
+perturbed_data_view = perturbed_data.squeeze(0).detach()
+perturbed_data_view = perturbed_data_view.transpose(0,2).transpose(0,1).numpy()
 
 plt.imshow(perturbed_data_view)
 
@@ -130,11 +130,11 @@
 
 f, a = plt.subplots(1, 2, figsize=(10, 10))
 
-# original photo
+# original
 a[0].set_title(prediction_name)
 a[0].imshow(original_img_view)
 
-# restored photo
+# adversarial example
 a[1].set_title(perturbed_prediction_name)
 a[1].imshow(perturbed_data_view)
diff --git "a/09-\352\262\275\354\237\201\354\235\204_\355\206\265\355\225\264_\355\225\231\354\212\265\355\225\230\353\212\224_GAN/02-conditional-gan.py" "b/09-\352\262\275\354\237\201\354\235\204_\355\206\265\355\225\264_\355\225\231\354\212\265\355\225\230\353\212\224_GAN/02-conditional-gan.py"
index c05f55f..f803c01 100644
--- "a/09-\352\262\275\354\237\201\354\235\204_\355\206\265\355\225\264_\355\225\231\354\212\265\355\225\230\353\212\224_GAN/02-conditional-gan.py"
+++ "b/09-\352\262\275\354\237\201\354\235\204_\355\206\265\355\225\264_\355\225\231\354\212\265\355\225\230\353\212\224_GAN/02-conditional-gan.py"
@@ -11,16 +11,17 @@
 from torchvision import transforms, datasets
 from torchvision.utils import save_image
 import matplotlib.pyplot as plt
+import numpy as np
 
 torch.manual_seed(1)    # reproducible
 
 # Hyper Parameters
-EPOCHS = 100
+EPOCHS = 300
 BATCH_SIZE = 100
-USE_CDA = torch.cuda.is_available()
-DEVICE = -1#torch.device("cuda" if USE_CUDA else "cpu")
+USE_CUDA = torch.cuda.is_available()
+DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
 print("Using Device:", DEVICE)
 
@@ -38,9 +39,6 @@
                                           shuffle = True)
 
 
-
-
-
 def one_hot_embedding(labels, num_classes):
     y = torch.eye(num_classes)
     return y[labels]
@@ -68,8 +66,8 @@
 
 # Device setting
-# D = D.to(DEVICE)
-# G = G.to(DEVICE)
+D = D.to(DEVICE)
+G = G.to(DEVICE)
 
 # Binary cross entropy loss and optimizer
 criterion = nn.BCELoss()
 d_optimizer = optim.Adam(D.parameters(), lr=0.0002)
 g_optimizer = optim.Adam(G.parameters(), lr=0.0002)
 
@@ -77,23 +75,20 @@ def one_hot_embedding(labels, num_classes):
 
-
-
-
 total_step = len(train_loader)
 for epoch in range(EPOCHS):
     for i, (images, label) in enumerate(train_loader):
-        images = images.reshape(BATCH_SIZE, -1)#.to(DEVICE)
+        images = images.reshape(BATCH_SIZE, -1).to(DEVICE)
 
-        real_labels = torch.ones(BATCH_SIZE, 1)#.to(DEVICE)
-        fake_labels = torch.zeros(BATCH_SIZE, 1)#.to(DEVICE)
+        real_labels = torch.ones(BATCH_SIZE, 1).to(DEVICE)
+        fake_labels = torch.zeros(BATCH_SIZE, 1).to(DEVICE)
 
         outputs = D(images)
         d_loss_real = criterion(outputs, real_labels)
         real_score = outputs
 
-        class_label = one_hot_embedding(label, 10)
-        z = torch.randn(BATCH_SIZE, 64)#.to(DEVICE)
+        class_label = one_hot_embedding(label, 10).to(DEVICE)
+        z = torch.randn(BATCH_SIZE, 64).to(DEVICE)
 
         generator_input = torch.cat([z, class_label], 1)
 
@@ -125,69 +120,25 @@ def one_hot_embedding(labels, num_classes):
 
         if (i+1) % 200 == 0:
             print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
-                  .format(epoch, EPOCHS, i+1, total_step, d_loss.item(), g_loss.item(),
-                          real_score.mean().item(), fake_score.mean().item()))
-        if (epoch+1) % 10 == 0 and (i+1) % 100 == 0 :
-            fake_images = np.reshape(fake_images.data.numpy()[0],(28, 28))
-            plt.imshow(fake_images, cmap = 'gray')
-            plt.show()
-
-
-
-
-
-
+                  .format(epoch,
+                          EPOCHS,
+                          i+1,
+                          total_step,
+                          d_loss.item(),
+                          g_loss.item(),
+                          real_score.mean().item(),
+                          fake_score.mean().item()))
+
+
+for i in range(100):
+    label = torch.tensor([4])
+    class_label = one_hot_embedding(label, 10).to(DEVICE)
+    z = torch.randn(1, 64).to(DEVICE)
+    generator_input = torch.cat([z, class_label], 1)
+    fake_images = G(generator_input)
+    fake_images = np.reshape(fake_images.cpu().data.numpy()[0], (28, 28))
+    plt.imshow(fake_images, cmap='gray')
+    plt.show()
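The generation loop added above conditions the generator on a digit class by concatenating a one-hot vector to the noise. `one_hot_embedding` simply indexes rows of an identity matrix; a quick check of what it returns:

```python
import torch

def one_hot_embedding(labels, num_classes):
    y = torch.eye(num_classes)
    return y[labels]

print(one_hot_embedding(torch.tensor([4]), 10))
# tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])
# Concatenated with the 64-dim z, the generator input is 74-dimensional.
```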
diff --git "a/10-\354\243\274\354\226\264\354\247\204_\355\231\230\352\262\275\352\263\274_\354\203\201\355\230\270\354\236\221\354\232\251\354\235\204_\355\206\265\355\225\264_\355\225\231\354\212\265\355\225\230\353\212\224_DQN/01-cartpole-dqn.py" "b/10-\354\243\274\354\226\264\354\247\204_\355\231\230\352\262\275\352\263\274_\354\203\201\355\230\270\354\236\221\354\232\251\354\235\204_\355\206\265\355\225\264_\355\225\231\354\212\265\355\225\230\353\212\224_DQN/01-cartpole-dqn.py"
index b49cbb2..2221b92 100644
--- "a/10-\354\243\274\354\226\264\354\247\204_\355\231\230\352\262\275\352\263\274_\354\203\201\355\230\270\354\236\221\354\232\251\354\235\204_\355\206\265\355\225\264_\355\225\231\354\212\265\355\225\230\353\212\224_DQN/01-cartpole-dqn.py"
+++ "b/10-\354\243\274\354\226\264\354\247\204_\355\231\230\352\262\275\352\263\274_\354\203\201\355\230\270\354\236\221\354\232\251\354\235\204_\355\206\265\355\225\264_\355\225\231\354\212\265\355\225\230\353\212\224_DQN/01-cartpole-dqn.py"
@@ -2,23 +2,6 @@
 # coding: utf-8
 
 # # Mastering the CartPole game
-# To master a game means to get the best possible score,
-# so we can treat the game score as the reward.
-# The agent we are about to build will predict rewards
-# and learn to act so as to maximize that reward.
-# In CartPole, for example, the score grows the longer you keep the pole standing.
-# When the pole tilts to the right, which action can be expected to give the larger reward?
-# We have to move right to recover balance,
-# so we can predict that pressing the right button yields a larger reward than the left one.
-# Summarized in one line, that becomes the following line of code.
-# ```
-# target = reward + gamma * np.amax(model.predict(next_state))
-# ```
-# DQN can be summed up by its two most important features:
-# remembering (Remember) and replaying (Replay).
-# Both are simple ideas, yet they are the revolutionary methods that made neural networks usable for reinforcement learning.
-# We will look at the concepts and their implementation, in order.
 
 import gym
 from gym import wrappers
@@ -43,99 +26,16 @@
 
 # ### Hyperparameters
 
-# hyper parameters
-EPISODES = 50    # number of episodes
-EPS_START = 0.9  # e-greedy threshold start value
-EPS_END = 0.05   # e-greedy threshold end value
+# hyperparameters
+EPISODES = 50    # number of episodes to run
+EPS_START = 0.9  # starting value of the e-greedy threshold
+EPS_END = 0.05   # final value of the e-greedy threshold
 EPS_DECAY = 200  # e-greedy threshold decay
-GAMMA = 0.8      # Q-learning discount factor
-LR = 0.001       # NN optimizer learning rate
+GAMMA = 0.8      # Q-learning discount factor
+LR = 0.001       # NN optimizer learning rate
 BATCH_SIZE = 64  # Q-learning batch size
 
 # ## The DQN agent
-# We make a class called DQNAgent
-# ```python
-# class DQNAgent
-# ```
-# ### The DQN agent's brain: a neural net
-# ![dqn_net](./assets/dqn_net.png)
-# ```python
-# self.model = nn.Sequential(
-#     nn.Linear(4, 256),
-#     nn.ReLU(),
-#     nn.Linear(256, 2)
-# )
-# ```
-# ### Acting (Act)
-# ### Remembering past experiences (Remember)
-# The first problem encountered when applying neural networks to Q-learning
-# was that the network easily forgets, overwriting old experiences with new ones.
-# The solution is the Remember feature:
-# keep previous experiences in an array and keep re-training on them so the network does not forget.
-# Each experience must hold the state, the action, the reward, and so on.
-# Let's call the array holding past experiences `memory` and build it as follows.
-# ```python
-# self.memory = [(state, action, reward, next_state)...]
-# ```
-# Complex models sometimes implement a dedicated Memory class,
-# but in this example we use the simplest structure, a deque (double-ended queue).
-# When a Python `deque` created with a `maxlen` fills up,
-# the oldest elements are dropped first,
-# so it naturally plays the role of forgetting old memories.
-# ```python
-# self.memory = deque(maxlen=10000)
-# ```
-# Now let's write the remember function that appends a new experience to the memory array.
-# ```python
-# def memorize(self, state, action, reward, next_state):
-#     self.memory.append((state,
-#                         action,
-#                         torch.FloatTensor([reward]),
-#                         torch.FloatTensor([next_state])))
-# ```
-# ### Learning from experience (Experience Replay)
-# Once past experiences are collected, we must learn from them repeatedly.
-# Humans, too, are said to organize motor information such as driving a car or shooting a basketball during sleep,
-# passing short-term memories from the motor cortex to the temporal lobe to turn them into long-term memory.
-# Coincidentally, the way the DQN agent memorizes and recalls is a similar idea.
-# The `learn` function does exactly this: it trains the neural net we just built, `model`,
-# on the experiences accumulated in `memory`.
-# ```python
-# def learn(self):
-#     """Experience Replay"""
-#     if len(self.memory) < BATCH_SIZE:
-#         return
-#     batch = random.sample(self.memory, BATCH_SIZE)
-#     states, actions, rewards, next_states = zip(*batch)
-# ```
-# We draw a random batch of "experiences" from `self.memory`, as many as the batch size.
-# In this example the batch size is 64.
-# ```python
-#     states = torch.cat(states)
-#     actions = torch.cat(actions)
-#     rewards = torch.cat(rewards)
-#     next_states = torch.cat(next_states)
-# ```
-# Each experience holds a state (`states`), an action (`actions`),
-# the reward for that action (`rewards`), and the next state (`next_states`).
-# They are all lists of lists, so we merge each of them with `torch.cat()`.
-# `cat` is short for concatenate, meaning to join together.
-# ```python
-#     current_q = self.model(states).gather(1, actions)
-#     max_next_q = self.model(next_states).detach().max(1)[0]
-#     expected_q = rewards + (GAMMA * max_next_q)
-# ```
-# Compute the Q values.
-# ```python
-#     loss = F.mse_loss(current_q.squeeze(), expected_q)
-#     self.optimizer.zero_grad()
-#     loss.backward()
-#     self.optimizer.step()
-# ```
-# Train on them.
 
 class DQNAgent:
     def __init__(self):
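The EPS_* constants kept above parameterize the agent's exploration schedule, whose prose explanation the hunk removes. For reference, a sketch of the exponentially decaying epsilon-greedy rule such constants are conventionally plugged into; the file's actual `act()` body is not shown in the patch, so treat this as an assumption about how they are used:

```python
import math
import random

EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 200

def epsilon(steps_done):
    # decays from EPS_START toward EPS_END as experience accumulates
    return EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)

# explore with probability epsilon, otherwise exploit the Q-network's argmax
explore = random.random() < epsilon(steps_done=0)
```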
@@ -195,43 +95,7 @@ def learn(self):
 
 score_history = []
 
-# ## Starting the training
-# EPISODES is the hyperparameter that sets how many games to play.
-# ```
-# for e in range(1, EPISODES+1):
-#     state = env.reset()
-#     steps = 0
-# ```
-# The `done` variable holds whether the game has ended, as True or False.
-# ```
-# while True:
-#     env.render()
-#     state = torch.FloatTensor([state])
-#     action = agent.act(state)
-#     next_state, reward, done, _ = env.step(action.item())
-# ```
-# Our agent's action has produced a result!
-# Let it memorize this experience and learn from it.
-# ```
-# # negative reward when attempt ends
-# if done:
-#     reward = -1
-# agent.memorize(state, action, reward, next_state)
-# agent.learn()
-# state = next_state
-# steps += 1
-# ```
-# When the game ends, `done` becomes `True` and the code below runs.
-# Complex tools and code are often used to analyze games,
-# but here we simply print the episode number and the score.
-# We also append the score to the `score_history` list we made earlier.
-# Finally, since the game can no longer continue, we exit the infinite loop with `break`.
-# ```
-# if done:
-#     print("episode:{0} score: {1}".format(e, steps))
-#     score_history.append(steps)
-#     break
-# ```
+# ## Start training
 
 for e in range(1, EPISODES+1):
     state = env.reset()
diff --git a/README.md b/README.md
index 70bf549..a9d7e6d 100644
--- a/README.md
+++ b/README.md
@@ -47,7 +47,6 @@
 8. [Hacking deep learning](08-딥러닝_해킹하기) - Adversarial Attack
    * [Concept] What is an adversarial attack?
    * [Project 1] [FGSM attack](08-딥러닝_해킹하기/01-fgsm-attack.ipynb)
-   * [Project 2] [Attacking a chosen target](08-딥러닝_해킹하기/02-iterative-target-attack.ipynb)
    * More
 9. [GANs that grow through competition](09-경쟁을_통해_학습하는_GAN) - Use a GAN to create new fashion items.
    * [Concept] GAN basics