Merge branch 'master' of https://github.com/Brandon-Rozek/GymHTTP
# Conflicts:
#	examples/example_dqn.py
commit 5923c9893c
1 changed file with 39 additions and 41 deletions
examples/example_dqn.py
@@ -1,4 +1,4 @@
-from gymclient import env
+from gymclient import Environment
 import numpy as np
 import random
 from collections import deque
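Note: the import fix above pairs with the unchanged `env = Environment("127.0.0.1", 5000)` line later in the file: the script now imports the `Environment` class and connects to the GymHTTP server explicitly instead of importing a pre-built `env` object. A minimal sketch of that usage, assuming only the calls that actually appear in this diff (`reset` and `step` with `preprocess = True`):

    from gymclient import Environment

    env = Environment("127.0.0.1", 5000)               # connect to the GymHTTP server
    state = env.reset(preprocess = True)               # same calls the training loop uses
    next_state, reward, done, _ = env.step(0, preprocess = True)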
@@ -51,6 +51,7 @@ class Value(nn.Module):
         self.advantage_fc = nn.Linear(384, 384)
         self.advantage = nn.Linear(384, action_size)
 
+
     def forward(self, x):
         x = x.float() / 255
         # Size changes from (batch_size, 4, 80, 70) to ()
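Note: `advantage_fc` and `advantage` look like the advantage branch of a dueling-style Q-network head. The rest of `Value.__init__` is outside this hunk, so the following is only a sketch of how such a branch is conventionally combined with a state-value branch; the `value_fc`/`value` names are hypothetical and not taken from this file:

    import torch
    import torch.nn as nn

    class DuelingHead(nn.Module):
        # Hypothetical sketch; only advantage_fc/advantage appear in the diff.
        def __init__(self, action_size):
            super().__init__()
            self.advantage_fc = nn.Linear(384, 384)
            self.advantage = nn.Linear(384, action_size)
            self.value_fc = nn.Linear(384, 384)   # assumed value branch
            self.value = nn.Linear(384, 1)

        def forward(self, x):
            a = self.advantage(torch.relu(self.advantage_fc(x)))
            v = self.value(torch.relu(self.value_fc(x)))
            # Standard dueling combination: Q = V + (A - mean(A))
            return v + a - a.mean(dim=1, keepdim=True)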
@@ -116,6 +117,7 @@ class DQNAgent:
         action = self.act_random() if (action_values[0] == action_values).all() else action_values.argmax().item()
         return action
 
+
     def replay(self, batch_size):
         minibatch = self.memory.sample(batch_size)
         state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*minibatch)
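Note: the `act` line above falls back to a random action whenever every Q-value is identical, since `argmax` would otherwise always return index 0. In `replay`, `zip(*minibatch)` transposes a sampled list of `(state, action, reward, next_state, done)` tuples into five per-field tuples, one per column. A standalone illustration with made-up transitions:

    import random

    memory = [("s0", 0, 1.0, "s1", False),
              ("s1", 1, 0.0, "s2", True)]
    minibatch = random.sample(memory, 2)
    states, actions, rewards, next_states, dones = zip(*minibatch)
    # Each name now holds one field across the whole minibatch,
    # e.g. rewards is a tuple like (1.0, 0.0), ready for batched tensor ops.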
@@ -230,51 +232,47 @@ class FrameStack(gym.Wrapper):
         assert len(self.frames) == self.k
         return LazyFrames(list(self.frames))
 
 env = Environment("127.0.0.1", 5000)
-# env = FrameStack(ProcessFrame(FireResetEnv(env)), 4)
-env = FrameStack(FireResetEnv(env), 4)
-# env.seed(SEED)
-state_size = [1, 4, 80, 70]
-action_size = env.action_space.n
-
-agent = DQNAgent(state_size, action_size)
-done = False
-batch_size = 32
-EPISODES = 100
-epsilon = 0.999
 
+def train():
+    global env
+    # env = FrameStack(ProcessFrame(FireResetEnv(env)), 4)
+    env = FrameStack(FireResetEnv(env), 4)
+    # env.seed(SEED)
+    state_size = [1, 4, 80, 70]
+    action_size = env.action_space.n
+
+    agent = DQNAgent(state_size, action_size)
+    replaySkip = 4
+    batch_size = batch_size * replaySkip
+    # Now that we have some experiences in our buffer, start training
+    for episode_num in range(EPISODES):
+        state = env.reset(preprocess = True)
+        total_reward = 0
+        done = False
+        batch_size = 32
+        EPISODES = 100
+        epsilon = 0.999
 
-replaySkip = 4
-batch_size = batch_size * replaySkip
-for episode_num in range(EPISODES):
-    state = env.reset(preprocess = True)
-    total_reward = 0
-    done = False
-    replaySkip = 4
-    while not done:
-        replaySkip = replaySkip - 1
-        if np.random.rand() > epsilon:
-            action = agent.act(state)
-        else:
-            action = agent.act_random()
-        epsilon = epsilon * 0.99997
-        next_state, reward, done, _ = env.step(action, preprocess = True)
+        while not done:
+            replaySkip = replaySkip - 1
+            if np.random.rand() > epsilon:
+                action = agent.act(state)
+            else:
+                action = agent.act_random()
+            epsilon = epsilon * 0.99997
+            next_state, reward, done, _ = env.step(action, preprocess = True)
 
-        agent.remember(state, action, reward, next_state, done)
-        total_reward = total_reward + reward
-        state = next_state
+            agent.remember(state, action, reward, next_state, done)
+            total_reward = total_reward + reward
+            state = next_state
 
-        if done:
-            print("episode: {}/{}, score: {}, epsilon: {}"
-                .format(episode_num, EPISODES, total_reward, epsilon))
-            break # We finished this episode
+            if done:
+                print("episode: {}/{}, score: {}, epsilon: {}"
+                    .format(episode_num, EPISODES, total_reward, epsilon))
+                break # We finished this episode
 
-        if len(agent.memory) > batch_size and replaySkip <= 0:
-            replaySkip = 4
-            agent.replay(batch_size)
 
+            if len(agent.memory) > batch_size and replaySkip <= 0:
+                replaySkip = 4
+                agent.replay(batch_size)
 
+train()
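Note on the training loop shared by both sides of this diff: epsilon decays multiplicatively (`epsilon = epsilon * 0.99997`) on every environment step, and `replay` only runs once every `replaySkip` (4) steps while `batch_size` is scaled up by the same factor, so the average number of replayed transitions per step stays roughly constant while the expensive replay call happens four times less often. A condensed sketch of just that scheduling, with the agent and environment stubbed out:

    replaySkip = 4
    batch_size = 32 * replaySkip       # 128, as in batch_size = batch_size * replaySkip

    memory = []
    for step in range(1, 301):         # stand-in for environment steps across episodes
        memory.append(step)            # stand-in for agent.remember(...)
        replaySkip = replaySkip - 1
        if len(memory) > batch_size and replaySkip <= 0:
            replaySkip = 4             # reset the countdown
            # agent.replay(batch_size) would run here, i.e. on every 4th step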