From 8dd9ca617efa4e2e1f992edafc78cebb96bf6af6 Mon Sep 17 00:00:00 2001
From: Brandon Rozek
Date: Sun, 17 Nov 2019 18:36:35 -0500
Subject: [PATCH] Incorporated concepts from the paper "Deep Q-Learning From
 Demonstrations"

---
 play.py     | 7 +++++--
 play_env.py | 4 ++--
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/play.py b/play.py
index c530d9e..61a102f 100644
--- a/play.py
+++ b/play.py
@@ -151,7 +151,7 @@ class Play:
         for i in range(self.num_sneaky_episodes):
             print("Episode: %d / %d, Reward: " % ((self.num_sneaky_episodes * self.sneaky_iteration) + i + 1, (self.sneaky_iteration + 1) * self.num_sneaky_episodes), end = "")
 
-            # Reset all episode releated variables
+            # Reset all episode related variables
             prev_obs = self.sneaky_env.reset()
             done = False
             step = 0
@@ -280,7 +280,10 @@ class Play:
 
         # Increment the timer if it's the human or shown computer's turn
         if self.state is COMPUTER_PLAY or self.state is HUMAN_PLAY:
-            self.agent.memory.append(prev_obs, action, reward, obs, env_done)
+            if self.state == HUMAN_PLAY and type(self.agent.memory).__name__ == 'DQfDMemory':
+                self.agent.memory.append_demonstration(prev_obs, action, reward, obs, env_done)
+            else:
+                self.agent.memory.append(prev_obs, action, reward, obs, env_done)
             i += 1
             # Perform a quick learning process and increment the state after a certain time period has passed
             if i % (self.fps * self.seconds_play_per_state) == 0:
diff --git a/play_env.py b/play_env.py
index fe1a3ce..e47041a 100644
--- a/play_env.py
+++ b/play_env.py
@@ -152,9 +152,9 @@ net = rn.Network(Value(state_size, action_size),
 target_net = rn.TargetNetwork(net, device = device)
 
 # Relevant components from RLTorch
-memory = PrioritizedReplayMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'])
+memory = M.DQfDMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'], max_demo = config['memory_size'] // 2)
 actor = ArgMaxSelector(net, action_size, device = device)
-agent = rltorch.agents.DQNAgent(net, memory, config, target_net = target_net)
+agent = rltorch.agents.DQfDAgent(net, memory, config, target_net = target_net)
 
 # Use a different environment for when the computer trains on the side so that the current game state isn't manipuated
 # Also use MaxEnvSkip to speed up processing
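
Note (not part of the patch): the diff routes the human player's transitions through
append_demonstration, but the memory class itself (M.DQfDMemory) lives in the rltorch
library and is not shown here. Below is a minimal, hypothetical sketch of the idea from
"Deep Q-Learning from Demonstrations" (Hester et al., 2018) that the change relies on:
demonstration transitions occupy a protected region of the replay buffer that self-play
experience can never overwrite. The class and parameter names are illustrative
assumptions, not rltorch's actual API.

    import random
    from collections import namedtuple

    Transition = namedtuple('Transition',
                            ('state', 'action', 'reward', 'next_state', 'done'))

    class DemoReplayMemory:
        """Hypothetical sketch of a DQfD-style buffer (not rltorch's DQfDMemory)."""

        def __init__(self, capacity, max_demo):
            assert max_demo <= capacity
            self.capacity = capacity
            self.max_demo = max_demo     # slots reserved for demonstration data
            self.demos = []              # demonstration transitions, never evicted
            self.agent_data = []         # ordinary self-play transitions
            self._pos = 0                # ring-buffer cursor for self-play data

        def append_demonstration(self, state, action, reward, next_state, done):
            # Keep up to max_demo demonstration transitions permanently.
            if len(self.demos) < self.max_demo:
                self.demos.append(Transition(state, action, reward, next_state, done))

        def append(self, state, action, reward, next_state, done):
            # Self-play transitions share the remaining capacity and wrap around,
            # so they can overwrite each other but never the demonstrations.
            limit = max(self.capacity - len(self.demos), 1)
            transition = Transition(state, action, reward, next_state, done)
            if len(self.agent_data) < limit:
                self.agent_data.append(transition)
            else:
                self.agent_data[self._pos % len(self.agent_data)] = transition
            self._pos = (self._pos + 1) % limit

        def sample(self, batch_size):
            # Uniform sampling for brevity; DQfD proper samples proportionally to
            # TD-error priority and gives demonstration transitions a bonus priority.
            pool = self.demos + self.agent_data
            return random.sample(pool, min(batch_size, len(pool)))

        def __len__(self):
            return len(self.demos) + len(self.agent_data)

Mirroring the play_env.py change, such a buffer would be constructed as
DemoReplayMemory(capacity = config['memory_size'], max_demo = config['memory_size'] // 2),
with human turns stored via append_demonstration and the hidden computer's turns via
append. In the paper, training then mixes both pools under a combined loss (one-step and
n-step TD losses, a large-margin supervised loss on demonstration actions, and L2
regularization); rltorch's DQfDAgent presumably implements a variant of this, which is
why the plain DQNAgent is swapped out in play_env.py.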