Skip to content
This repository has been archived by the owner on Oct 19, 2023. It is now read-only.

Commit

Permalink
update dqn example
Browse files Browse the repository at this point in the history
  • Loading branch information
keon committed Sep 10, 2019
1 parent baec02a commit c17d678
Showing 1 changed file with 91 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,13 @@
"outputs": [],
"source": [
"import gym\n",
"from gym import wrappers\n",
"import random\n",
"import math\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.autograd import Variable\n",
"import torch.nn.functional as F\n",
"import matplotlib.pyplot as plt\n",
"from collections import deque\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## OpenAI Gym을 이용하여 게임환경 구축하기\n",
"\n",
"\n",
"강화학습 예제들을 보면 항상 게임과 연관되어 있습니다. 원래 우리가 궁극적으로 원하는 목표는 어디서든 적응할 수 있는 인공지능이지만, 너무 복잡한 문제이기도 하고 가상 환경을 설계하기도 어렵기 때문에 일단 게임이라는 환경을 사용해 하는 것입니다.\n",
"\n",
"대부분의 게임은 점수 혹은 목표가 있습니다. 점수가 오르거나 목표에 도달하면 일종의 리워드를 받고 원치 않은 행동을 할때는 마이너스 리워드를 주는 경우도 있습니다. 아까 비유를 들었던 달리기를 배울때의 경우를 예로 들면 총 나아간 길이 혹은 목표 도착지 도착 여부로 리워드를 주고 넘어질때 패널티를 줄 수 있을 것입니다. \n",
"\n",
"게임중에서도 가장 간단한 카트폴이라는 환경을 구축하여 강화학습을 배울 토대를 마련해보겠습니다."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"env = gym.make('CartPole-v1')"
"from collections import deque"
]
},
{
Expand All @@ -60,18 +33,18 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# 하이퍼파라미터\n",
"EPISODES = 50 # 에피소드 반복 횟수\n",
"EPS_START = 0.9 # e-greedy threshold 시작 값\n",
"EPS_END = 0.05 # e-greedy threshold 최종 값\n",
"EPS_DECAY = 200 # e-greedy threshold decay\n",
"GAMMA = 0.8 # \n",
"LR = 0.001 # NN optimizer learning rate\n",
"BATCH_SIZE = 64 # Q-learning batch size"
"EPISODES = 50 # 애피소드 반복횟수\n",
"EPS_START = 0.9 # 학습 시작시 에이전트가 무작위로 행동할 확률\n",
"EPS_END = 0.05 # 학습 막바지에 에이전트가 무작위로 행동할 확률\n",
"EPS_DECAY = 200 # 학습 진행시 에이전트가 무작위로 행동할 확률을 감소시키는 값\n",
"GAMMA = 0.8 # 할인계수\n",
"LR = 0.001 # 학습률\n",
"BATCH_SIZE = 64 # 배치 크기"
]
},
{
Expand All @@ -94,9 +67,15 @@
" nn.ReLU(),\n",
" nn.Linear(256, 2)\n",
" )\n",
" self.memory = deque(maxlen=10000)\n",
" self.optimizer = optim.Adam(self.model.parameters(), LR)\n",
" self.steps_done = 0\n",
" self.memory = deque(maxlen=10000)\n",
"\n",
" def memorize(self, state, action, reward, next_state):\n",
" self.memory.append((state,\n",
" action,\n",
" torch.FloatTensor([reward]),\n",
" torch.FloatTensor([next_state])))\n",
" \n",
" def act(self, state):\n",
" eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * self.steps_done / EPS_DECAY)\n",
Expand All @@ -105,12 +84,6 @@
" return self.model(state).data.max(1)[1].view(1, 1)\n",
" else:\n",
" return torch.LongTensor([[random.randrange(2)]])\n",
"\n",
" def memorize(self, state, action, reward, next_state):\n",
" self.memory.append((state,\n",
" action,\n",
" torch.FloatTensor([reward]),\n",
" torch.FloatTensor([next_state])))\n",
" \n",
" def learn(self):\n",
" \"\"\"Experience Replay\"\"\"\n",
Expand Down Expand Up @@ -140,10 +113,9 @@
"source": [
"## 학습 준비하기\n",
"\n",
"드디어 만들어둔 DQNAgent를 인스턴스화 합니다.\n",
"그리고 `gym`을 이용하여 `CartPole-v0`환경도 준비합니다.\n",
"자, 이제 `agent` 객체를 이용하여 `CartPole-v0` 환경과 상호작용을 통해 게임을 배우도록 하겠습니다.\n",
"학습 진행을 기록하기 위해 `score_history` 리스트를 이용하여 점수를 저장하겠습니다."
"`gym`을 이용하여 `CartPole-v0`환경을 준비하고 앞서 만들어둔 DQNAgent를 agent로 인스턴스화 합니다.\n",
"\n",
"자, 이제 `agent` 객체를 이용하여 `CartPole-v0` 환경과 상호작용을 통해 게임을 배우도록 하겠습니다."
]
},
{
Expand All @@ -152,8 +124,8 @@
"metadata": {},
"outputs": [],
"source": [
"agent = DQNAgent()\n",
"env = gym.make('CartPole-v0')\n",
"agent = DQNAgent()\n",
"score_history = []"
]
},
Expand All @@ -173,56 +145,56 @@
"name": "stdout",
"output_type": "stream",
"text": [
"에피소드:1 점수: 11\n",
"에피소드:2 점수: 32\n",
"에피소드:3 점수: 10\n",
"에피소드:4 점수: 36\n",
"에피소드:5 점수: 13\n",
"에피소드:6 점수: 17\n",
"에피소드:7 점수: 9\n",
"에피소드:8 점수: 13\n",
"에피소드:9 점수: 11\n",
"에피소드:10 점수: 28\n",
"에피소드:11 점수: 11\n",
"에피소드:12 점수: 12\n",
"에피소드:1 점수: 21\n",
"에피소드:2 점수: 28\n",
"에피소드:3 점수: 21\n",
"에피소드:4 점수: 51\n",
"에피소드:5 점수: 12\n",
"에피소드:6 점수: 20\n",
"에피소드:7 점수: 8\n",
"에피소드:8 점수: 9\n",
"에피소드:9 점수: 10\n",
"에피소드:10 점수: 12\n",
"에피소드:11 점수: 14\n",
"에피소드:12 점수: 11\n",
"에피소드:13 점수: 9\n",
"에피소드:14 점수: 20\n",
"에피소드:15 점수: 11\n",
"에피소드:16 점수: 11\n",
"에피소드:17 점수: 10\n",
"에피소드:18 점수: 15\n",
"에피소드:19 점수: 13\n",
"에피소드:20 점수: 11\n",
"에피소드:21 점수: 13\n",
"에피소드:22 점수: 22\n",
"에피소드:23 점수: 26\n",
"에피소드:24 점수: 59\n",
"에피소드:14 점수: 9\n",
"에피소드:15 점수: 10\n",
"에피소드:16 점수: 26\n",
"에피소드:17 점수: 11\n",
"에피소드:18 점수: 9\n",
"에피소드:19 점수: 11\n",
"에피소드:20 점수: 25\n",
"에피소드:21 점수: 12\n",
"에피소드:22 점수: 19\n",
"에피소드:23 점수: 12\n",
"에피소드:24 점수: 27\n",
"에피소드:25 점수: 30\n",
"에피소드:26 점수: 22\n",
"에피소드:27 점수: 26\n",
"에피소드:28 점수: 25\n",
"에피소드:29 점수: 57\n",
"에피소드:30 점수: 83\n",
"에피소드:31 점수: 62\n",
"에피소드:32 점수: 45\n",
"에피소드:33 점수: 62\n",
"에피소드:34 점수: 80\n",
"에피소드:35 점수: 88\n",
"에피소드:36 점수: 57\n",
"에피소드:37 점수: 52\n",
"에피소드:38 점수: 45\n",
"에피소드:39 점수: 49\n",
"에피소드:40 점수: 63\n",
"에피소드:41 점수: 61\n",
"에피소드:42 점수: 75\n",
"에피소드:43 점수: 52\n",
"에피소드:44 점수: 81\n",
"에피소드:45 점수: 98\n",
"에피소드:46 점수: 129\n",
"에피소드:47 점수: 153\n",
"에피소드:48 점수: 169\n",
"에피소드:49 점수: 120\n",
"에피소드:50 점수: 144\n"
"에피소드:26 점수: 66\n",
"에피소드:27 점수: 24\n",
"에피소드:28 점수: 45\n",
"에피소드:29 점수: 47\n",
"에피소드:30 점수: 35\n",
"에피소드:31 점수: 35\n",
"에피소드:32 점수: 40\n",
"에피소드:33 점수: 44\n",
"에피소드:34 점수: 34\n",
"에피소드:35 점수: 57\n",
"에피소드:36 점수: 52\n",
"에피소드:37 점수: 70\n",
"에피소드:38 점수: 124\n",
"에피소드:39 점수: 118\n",
"에피소드:40 점수: 33\n",
"에피소드:41 점수: 128\n",
"에피소드:42 점수: 55\n",
"에피소드:43 점수: 178\n",
"에피소드:44 점수: 88\n",
"에피소드:45 점수: 103\n",
"에피소드:46 점수: 101\n",
"에피소드:47 점수: 120\n",
"에피소드:48 점수: 140\n",
"에피소드:49 점수: 113\n",
"에피소드:50 점수: 85\n"
]
}
],
Expand Down Expand Up @@ -252,19 +224,38 @@
" break"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[21, 28, 21, 51, 12, 20, 8, 9, 10, 12, 14, 11, 9, 9, 10, 26, 11, 9, 11, 25, 12, 19, 12, 27, 30, 66, 24, 45, 47, 35, 35, 40, 44, 34, 57, 52, 70, 124, 118, 33, 128, 55, 178, 88, 103, 101, 120, 140, 113, 85]\n"
]
}
],
"source": [
"print(score_history)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"import matplotlib"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "pytorch",
"language": "python",
"name": "python3"
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -276,7 +267,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.7.3"
}
},
"nbformat": 4,
Expand Down

0 comments on commit c17d678

Please sign in to comment.