Skip to content

Commit

Permalink
Fix deep_q_network last 4 images bug
Browse files Browse the repository at this point in the history
  • Loading branch information
yenchenlin committed Mar 26, 2016
1 parent 944bb47 commit fac4ba7
Show file tree
Hide file tree
Showing 13 changed files with 16 additions and 12 deletions.
16 changes: 9 additions & 7 deletions deep_q_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
ACTIONS = 2 # number of valid actions
GAMMA = 0.99 # decay rate of past observations
OBSERVE = 100000. # timesteps to observe before training
EXPLORE = 150000. # frames over which to anneal epsilon
FINAL_EPSILON = 0.0 # final value of epsilon
INITIAL_EPSILON = 0.0 # starting value of epsilon
EXPLORE = 2000000. # frames over which to anneal epsilon
FINAL_EPSILON = 0.0001 # final value of epsilon
INITIAL_EPSILON = 0.0001 # starting value of epsilon
REPLAY_MEMORY = 50000 # number of previous transitions to remember
BATCH = 32 # size of minibatch
FRAME_PER_ACTION = 1
Expand Down Expand Up @@ -79,7 +79,7 @@ def trainNetwork(s, readout, h_fc1, sess):
# define the cost function
a = tf.placeholder("float", [None, ACTIONS])
y = tf.placeholder("float", [None])
readout_action = tf.reduce_sum(tf.mul(readout, a), reduction_indices = 1)
readout_action = tf.reduce_sum(tf.mul(readout, a), reduction_indices=1)
cost = tf.reduce_mean(tf.square(y - readout_action))
train_step = tf.train.AdamOptimizer(1e-6).minimize(cost)

Expand All @@ -99,7 +99,7 @@ def trainNetwork(s, readout, h_fc1, sess):
x_t, r_0, terminal = game_state.frame_step(do_nothing)
x_t = cv2.cvtColor(cv2.resize(x_t, (80, 80)), cv2.COLOR_BGR2GRAY)
ret, x_t = cv2.threshold(x_t,1,255,cv2.THRESH_BINARY)
s_t = np.stack((x_t, x_t, x_t, x_t), axis = 2)
s_t = np.stack((x_t, x_t, x_t, x_t), axis=2)

# saving and loading networks
saver = tf.train.Saver()
Expand All @@ -111,11 +111,12 @@ def trainNetwork(s, readout, h_fc1, sess):
else:
print("Could not find old network weights")

# start training
epsilon = INITIAL_EPSILON
t = 0
while "flappy bird" != "angry bird":
# choose an action epsilon greedily
readout_t = readout.eval(feed_dict = {s : [s_t]})[0]
readout_t = readout.eval(feed_dict={s : [s_t]})[0]
a_t = np.zeros([ACTIONS])
action_index = 0
if t % FRAME_PER_ACTION == 0:
Expand All @@ -138,7 +139,8 @@ def trainNetwork(s, readout, h_fc1, sess):
x_t1 = cv2.cvtColor(cv2.resize(x_t1_colored, (80, 80)), cv2.COLOR_BGR2GRAY)
ret, x_t1 = cv2.threshold(x_t1, 1, 255, cv2.THRESH_BINARY)
x_t1 = np.reshape(x_t1, (80, 80, 1))
s_t1 = np.append(x_t1, s_t[:,:,1:], axis = 2)
#s_t1 = np.append(x_t1, s_t[:,:,1:], axis = 2)
s_t1 = np.append(x_t1, s_t[:, :, :3], axis=2)

# store the transition in D
D.append((s_t, a_t, r_t, s_t1, terminal))
Expand Down
2 changes: 1 addition & 1 deletion game/wrapped_flappy_bird.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self):
self.playerMaxVelY = 10 # max vel along Y, max descend speed
self.playerMinVelY = -8 # min vel along Y, max ascend speed
self.playerAccY = 1 # players downward accleration
self.playerFlapAcc = -7 # players speed on flapping
self.playerFlapAcc = -9 # players speed on flapping
self.playerFlapped = False # True when player flaps

def frame_step(self, input_actions):
Expand Down
Binary file added saved_networks/bird-dqn-2880000
Binary file not shown.
Binary file added saved_networks/bird-dqn-2880000.meta
Binary file not shown.
Binary file added saved_networks/bird-dqn-2890000
Binary file not shown.
Binary file added saved_networks/bird-dqn-2890000.meta
Binary file not shown.
Binary file added saved_networks/bird-dqn-2900000
Binary file not shown.
Binary file added saved_networks/bird-dqn-2900000.meta
Binary file not shown.
Binary file added saved_networks/bird-dqn-2910000
Binary file not shown.
Binary file added saved_networks/bird-dqn-2910000.meta
Binary file not shown.
Binary file added saved_networks/bird-dqn-2920000
Binary file not shown.
Binary file added saved_networks/bird-dqn-2920000.meta
Binary file not shown.
10 changes: 6 additions & 4 deletions saved_networks/checkpoint
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
model_checkpoint_path: "bird-dqn-30000"
all_model_checkpoint_paths: "bird-dqn-10000"
all_model_checkpoint_paths: "bird-dqn-20000"
all_model_checkpoint_paths: "bird-dqn-30000"
model_checkpoint_path: "bird-dqn-2920000"
all_model_checkpoint_paths: "bird-dqn-2880000"
all_model_checkpoint_paths: "bird-dqn-2890000"
all_model_checkpoint_paths: "bird-dqn-2900000"
all_model_checkpoint_paths: "bird-dqn-2910000"
all_model_checkpoint_paths: "bird-dqn-2920000"

0 comments on commit fac4ba7

Please sign in to comment.