Merge branch 'master' of https://github.com/Brandon-Rozek/GymHTTP
# Conflicts:
#	examples/example_dqn.py
This commit is contained in:
commit
5923c9893c
1 changed file with 39 additions and 41 deletions
|
@@ -1,4 +1,4 @@
|
||||||
from gymclient import env
|
from gymclient import Environment
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
@@ -51,6 +51,7 @@ class Value(nn.Module):
|
||||||
self.advantage_fc = nn.Linear(384, 384)
|
self.advantage_fc = nn.Linear(384, 384)
|
||||||
self.advantage = nn.Linear(384, action_size)
|
self.advantage = nn.Linear(384, action_size)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = x.float() / 255
|
x = x.float() / 255
|
||||||
# Size changes from (batch_size, 4, 80, 70) to ()
|
# Size changes from (batch_size, 4, 80, 70) to ()
|
||||||
|
@@ -116,6 +117,7 @@ class DQNAgent:
|
||||||
action = self.act_random() if (action_values[0] == action_values).all() else action_values.argmax().item()
|
action = self.act_random() if (action_values[0] == action_values).all() else action_values.argmax().item()
|
||||||
return action
|
return action
|
||||||
|
|
||||||
|
|
||||||
def replay(self, batch_size):
|
def replay(self, batch_size):
|
||||||
minibatch = self.memory.sample(batch_size)
|
minibatch = self.memory.sample(batch_size)
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*minibatch)
|
state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*minibatch)
|
||||||
|
@@ -230,25 +232,23 @@ class FrameStack(gym.Wrapper):
|
||||||
assert len(self.frames) == self.k
|
assert len(self.frames) == self.k
|
||||||
return LazyFrames(list(self.frames))
|
return LazyFrames(list(self.frames))
|
||||||
|
|
||||||
|
env = Environment("127.0.0.1", 5000)
|
||||||
|
# env = FrameStack(ProcessFrame(FireResetEnv(env)), 4)
|
||||||
|
env = FrameStack(FireResetEnv(env), 4)
|
||||||
|
# env.seed(SEED)
|
||||||
|
state_size = [1, 4, 80, 70]
|
||||||
|
action_size = env.action_space.n
|
||||||
|
|
||||||
|
agent = DQNAgent(state_size, action_size)
|
||||||
|
done = False
|
||||||
|
batch_size = 32
|
||||||
|
EPISODES = 100
|
||||||
|
epsilon = 0.999
|
||||||
|
|
||||||
def train():
|
replaySkip = 4
|
||||||
global env
|
batch_size = batch_size * replaySkip
|
||||||
# env = FrameStack(ProcessFrame(FireResetEnv(env)), 4)
|
# Now that we have some experiences in our buffer, start training
|
||||||
env = FrameStack(FireResetEnv(env), 4)
|
for episode_num in range(EPISODES):
|
||||||
# env.seed(SEED)
|
|
||||||
state_size = [1, 4, 80, 70]
|
|
||||||
action_size = env.action_space.n
|
|
||||||
|
|
||||||
agent = DQNAgent(state_size, action_size)
|
|
||||||
done = False
|
|
||||||
batch_size = 32
|
|
||||||
EPISODES = 100
|
|
||||||
epsilon = 0.999
|
|
||||||
|
|
||||||
replaySkip = 4
|
|
||||||
batch_size = batch_size * replaySkip
|
|
||||||
for episode_num in range(EPISODES):
|
|
||||||
state = env.reset(preprocess = True)
|
state = env.reset(preprocess = True)
|
||||||
total_reward = 0
|
total_reward = 0
|
||||||
done = False
|
done = False
|
||||||
|
@@ -275,6 +275,4 @@ def train():
|
||||||
replaySkip = 4
|
replaySkip = 4
|
||||||
agent.replay(batch_size)
|
agent.replay(batch_size)
|
||||||
|
|
||||||
train()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue