Showing 3 changed files with 376 additions and 280 deletions.
@@ -106,3 +106,4 @@ data/
 .data/

 tmp/
+snapshot/
@@ -0,0 +1,375 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import argparse\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchtext import data, datasets\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hyperparameters\n",
    "BATCH_SIZE = 64\n",
    "lr = 0.001\n",
    "EPOCHS = 40\n",
    "torch.manual_seed(42)\n",
    "USE_CUDA = torch.cuda.is_available()\n",
    "DEVICE = torch.device(\"cuda\" if USE_CUDA else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# class BasicRNN(nn.Module):\n",
    "#     \"\"\"\n",
    "#     Basic RNN\n",
    "#     \"\"\"\n",
    "#     def __init__(self, n_layers, hidden_dim, n_vocab, embed_dim, n_classes, dropout_p=0.2):\n",
    "#         super(BasicRNN, self).__init__()\n",
    "#         print(\"Building Basic RNN model...\")\n",
    "#         self.n_layers = n_layers\n",
    "#         self.hidden_dim = hidden_dim\n",
    "\n",
    "#         self.embed = nn.Embedding(n_vocab, embed_dim)\n",
    "#         self.dropout = nn.Dropout(dropout_p)\n",
    "#         self.rnn = nn.RNN(embed_dim, hidden_dim, n_layers,\n",
    "#                           dropout=dropout_p, batch_first=True)\n",
    "#         self.out = nn.Linear(self.hidden_dim, n_classes)\n",
    "\n",
    "#     def forward(self, x):\n",
    "#         embedded = self.embed(x)  # [b, i] -> [b, i, e]\n",
    "#         _, hidden = self.rnn(embedded)\n",
    "#         hidden = self.dropout(hidden)  # assign: dropout is not in-place\n",
    "#         hidden = hidden.squeeze()\n",
    "#         logit = self.out(hidden)  # [b, h] -> [b, o]\n",
    "#         return logit\n",
    "\n",
    "class BasicLSTM(nn.Module):\n",
    "    def __init__(self, n_layers, hidden_dim, n_vocab, embed_dim, n_classes, dropout_p=0.2):\n",
    "        super(BasicLSTM, self).__init__()\n",
    "        print(\"Building Basic LSTM model...\")\n",
    "        self.n_layers = n_layers\n",
    "        self.hidden_dim = hidden_dim\n",
    "\n",
    "        self.embed = nn.Embedding(n_vocab, embed_dim)\n",
    "        self.dropout = nn.Dropout(dropout_p)\n",
    "        self.lstm = nn.LSTM(embed_dim, self.hidden_dim,\n",
    "                            num_layers=self.n_layers,\n",
    "                            dropout=dropout_p,\n",
    "                            batch_first=True)\n",
    "        self.out = nn.Linear(self.hidden_dim, n_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.embed(x)  # [b, i] -> [b, i, e]\n",
    "        h_0 = self._init_state(batch_size=x.size(0))\n",
    "        x, _ = self.lstm(x, h_0)  # [b, i, h] with batch_first=True\n",
    "        h_t = x[:, -1, :]  # hidden state at the last time step\n",
    "        h_t = self.dropout(h_t)  # assign: dropout is not in-place\n",
    "        logit = self.out(h_t)  # [b, h] -> [b, o]\n",
    "        return logit\n",
    "\n",
    "    def _init_state(self, batch_size=1):\n",
    "        # zero-initialized (h_0, c_0) with the same dtype/device as the weights\n",
    "        weight = next(self.parameters()).data\n",
    "        return (\n",
    "            weight.new(self.n_layers, batch_size, self.hidden_dim).zero_(),\n",
    "            weight.new(self.n_layers, batch_size, self.hidden_dim).zero_()\n",
    "        )"
   ]
  },
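  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of the model's input/output shapes on a dummy batch. This cell is an illustrative sketch added for this writeup, not part of the original run; the sizes are arbitrary stand-ins."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: feed a dummy batch of token ids through the model.\n",
    "_toy = BasicLSTM(n_layers=1, hidden_dim=32, n_vocab=100, embed_dim=16, n_classes=2)\n",
    "_x = torch.randint(0, 100, (4, 20), dtype=torch.long)  # [batch=4, seq_len=20]\n",
    "print(_toy(_x).shape)  # expected: torch.Size([4, 2]) -- one logit per class\n",
    "del _toy, _x"
   ]
  },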
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, optimizer, train_iter):\n",
    "    model.train()\n",
    "    for b, batch in enumerate(train_iter):\n",
    "        x, y = batch.text.to(DEVICE), batch.label.to(DEVICE)\n",
    "        y = y - 1  # torchtext's IMDB labels are 1/2; shift to 0/1 for cross_entropy\n",
    "        optimizer.zero_grad()\n",
    "        logit = model(x)\n",
    "        loss = F.cross_entropy(logit, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "#         if b % 100 == 0:\n",
    "#             corrects = (logit.max(1)[1].view(y.size()).data == y.data).sum()\n",
    "#             accuracy = 100.0 * corrects / batch.batch_size\n",
    "#             sys.stdout.write(\n",
    "#                 '\\rBatch[%d] - loss: %.6f acc: %.2f' %\n",
    "#                 (b, loss.item(), accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, val_iter):\n",
    "    \"\"\"Evaluate the model on a held-out iterator.\"\"\"\n",
    "    model.eval()\n",
    "    corrects, avg_loss = 0, 0\n",
    "    with torch.no_grad():  # no gradients needed for evaluation\n",
    "        for batch in val_iter:\n",
    "            x, y = batch.text.to(DEVICE), batch.label.to(DEVICE)\n",
    "            y = y - 1  # torchtext's IMDB labels are 1/2; shift to 0/1\n",
    "            # sum the per-example losses, then divide by the dataset size below\n",
    "            # (reduction='sum' replaces the deprecated size_average=False)\n",
    "            logit = model(x)\n",
    "            loss = F.cross_entropy(logit, y, reduction='sum')\n",
    "            avg_loss += loss.item()\n",
    "            # .item() keeps the count a Python int, so the accuracy below is\n",
    "            # true float division rather than truncated integer-tensor math\n",
    "            corrects += (logit.max(1)[1].view(y.size()) == y).sum().item()\n",
    "    size = len(val_iter.dataset)\n",
    "    avg_loss = avg_loss / size\n",
    "    accuracy = 100.0 * corrects / size\n",
    "    return avg_loss, accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the IMDB dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loading data...\n"
     ]
    }
   ],
   "source": [
    "# load data\n",
    "print(\"\\nLoading data...\")\n",
    "TEXT = data.Field(sequential=True, batch_first=True, lower=True)\n",
    "LABEL = data.Field(sequential=False, batch_first=True)\n",
    "train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n",
    "TEXT.build_vocab(train_data, min_freq=5)\n",
    "LABEL.build_vocab(train_data)\n",
    "\n",
    "train_iter, test_iter = data.BucketIterator.splits(\n",
    "    (train_data, test_data), batch_size=BATCH_SIZE,\n",
    "    shuffle=True, repeat=False)\n",
    "\n",
    "vocab_size = len(TEXT.vocab)\n",
    "n_classes = len(LABEL.vocab) - 1  # subtract 1 for the <unk> entry in the label vocab"
   ]
  },
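  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick peek at one batch to confirm what the iterator yields. This cell is an illustrative sketch added for this writeup, not part of the original run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: inspect one batch from the training iterator.\n",
    "_batch = next(iter(train_iter))\n",
    "print(_batch.text.shape)   # [BATCH_SIZE, max_seq_len_in_batch] of token ids\n",
    "print(_batch.label[:5])    # raw labels are 1/2 before the 0/1 shift\n",
    "del _batch"
   ]
  },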
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[TRAIN]: 391 \t [TEST]: 391 \t [VOCAB] 46159 \t [CLASSES] 2\n"
     ]
    }
   ],
   "source": [
    "# len(iterator) is the number of batches per epoch (25000 / 64 -> 391)\n",
    "print(\"[TRAIN]: %d \\t [TEST]: %d \\t [VOCAB] %d \\t [CLASSES] %d\"\n",
    "      % (len(train_iter), len(test_iter), vocab_size, n_classes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Building Basic LSTM model...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/rnn.py:38: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n",
      "  \"num_layers={}\".format(dropout, num_layers))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BasicLSTM(\n",
      "  (embed): Embedding(46159, 128)\n",
      "  (dropout): Dropout(p=0.5)\n",
      "  (lstm): LSTM(128, 256, batch_first=True, dropout=0.5)\n",
      "  (out): Linear(in_features=256, out_features=2, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# single-layer LSTM: the warning below is expected, since the LSTM's\n",
    "# internal dropout only applies between stacked layers\n",
    "model = BasicLSTM(1, 256, vocab_size, 128, n_classes, 0.5).to(DEVICE)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "\n",
    "print(model)"
   ]
  },
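  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Counting trainable parameters makes the model's size concrete; the embedding table (46159 x 128) dominates. This cell is an illustrative sketch added for this writeup, not part of the original run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: total trainable parameter count.\n",
    "n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(\"%d trainable parameters\" % n_params)"
   ]
  },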
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/torchtext-0.2.0-py3.5.egg/torchtext/data/field.py:320: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[Epoch: 1] val_loss: 0.63 | acc:65.00\n",
      "\n",
      "[Epoch: 2] val_loss: 0.49 | acc:77.00\n",
      "\n",
      "[Epoch: 3] val_loss: 0.34 | acc:85.00\n",
      "\n",
      "[Epoch: 4] val_loss: 0.32 | acc:87.00\n",
      "\n",
      "[Epoch: 5] val_loss: 0.39 | acc:87.00\n",
      "\n",
      "[Epoch: 6] val_loss: 0.42 | acc:86.00\n",
      "\n",
      "[Epoch: 7] val_loss: 0.51 | acc:86.00\n",
      "\n",
      "[Epoch: 8] val_loss: 0.56 | acc:86.00\n",
      "\n",
      "[Epoch: 9] val_loss: 0.69 | acc:86.00\n",
      "\n",
      "[Epoch: 10] val_loss: 0.75 | acc:85.00\n",
      "\n",
      "[Epoch: 11] val_loss: 0.68 | acc:86.00\n",
      "\n",
      "[Epoch: 12] val_loss: 0.69 | acc:86.00\n",
      "\n",
      "[Epoch: 13] val_loss: 0.73 | acc:86.00\n",
      "\n",
      "[Epoch: 14] val_loss: 0.80 | acc:86.00\n",
      "\n",
      "[Epoch: 15] val_loss: 0.81 | acc:86.00\n",
      "\n",
      "[Epoch: 16] val_loss: 0.89 | acc:86.00\n",
      "\n",
      "[Epoch: 17] val_loss: 0.94 | acc:86.00\n",
      "\n",
      "[Epoch: 18] val_loss: 0.95 | acc:86.00\n",
      "\n",
      "[Epoch: 19] val_loss: 0.70 | acc:86.00\n",
      "\n",
      "[Epoch: 20] val_loss: 0.74 | acc:86.00\n",
      "\n",
      "[Epoch: 21] val_loss: 0.90 | acc:85.00\n",
      "\n",
      "[Epoch: 22] val_loss: 0.78 | acc:86.00\n",
      "\n",
      "[Epoch: 23] val_loss: 0.87 | acc:86.00\n",
      "\n",
      "[Epoch: 24] val_loss: 0.89 | acc:86.00\n",
      "\n",
      "[Epoch: 25] val_loss: 0.93 | acc:86.00\n",
      "\n",
      "[Epoch: 26] val_loss: 0.98 | acc:87.00\n",
      "\n",
      "[Epoch: 27] val_loss: 1.01 | acc:86.00\n",
      "\n",
      "[Epoch: 28] val_loss: 1.05 | acc:87.00\n",
      "\n",
      "[Epoch: 29] val_loss: 1.08 | acc:86.00\n",
      "\n",
      "[Epoch: 30] val_loss: 1.10 | acc:86.00\n",
      "\n",
      "[Epoch: 31] val_loss: 1.13 | acc:86.00\n",
      "\n",
      "[Epoch: 32] val_loss: 1.16 | acc:86.00\n",
      "\n",
      "[Epoch: 33] val_loss: 1.18 | acc:87.00\n",
      "\n",
      "[Epoch: 34] val_loss: 1.20 | acc:87.00\n",
      "\n",
      "[Epoch: 35] val_loss: 1.23 | acc:87.00\n",
      "\n",
      "[Epoch: 36] val_loss: 1.25 | acc:87.00\n",
      "\n",
      "[Epoch: 37] val_loss: 1.27 | acc:86.00\n",
      "\n",
      "[Epoch: 38] val_loss: 1.29 | acc:86.00\n",
      "\n",
      "[Epoch: 39] val_loss: 1.31 | acc:86.00\n",
      "\n",
      "[Epoch: 40] val_loss: 1.34 | acc:86.00\n"
     ]
    }
   ],
   "source": [
    "best_val_loss = None\n",
    "for e in range(1, EPOCHS+1):\n",
    "    train(model, optimizer, train_iter)\n",
    "    val_loss, val_accuracy = evaluate(model, test_iter)\n",
    "\n",
    "    print(\"\\n[Epoch: %d] val_loss:%5.2f | acc:%5.2f\" % (e, val_loss, val_accuracy))\n",
    "\n",
    "    # Save the model if the validation loss is the best we've seen so far.\n",
    "#     if not best_val_loss or val_loss < best_val_loss:\n",
    "#         if not os.path.isdir(\"snapshot\"):\n",
    "#             os.makedirs(\"snapshot\")\n",
    "#         torch.save(model.state_dict(), './snapshot/convcnn.pt')\n",
    "#         best_val_loss = val_loss"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}