Incorporated concepts from the paper "Deep Q-Learning From Demonstrations"

Brandon Rozek 2019-11-17 18:36:35 -05:00
parent 744656aaa9
commit 8dd9ca617e
2 changed files with 7 additions and 4 deletions


@@ -151,7 +151,7 @@ class Play:
         for i in range(self.num_sneaky_episodes):
             print("Episode: %d / %d, Reward: " % ((self.num_sneaky_episodes * self.sneaky_iteration) + i + 1, (self.sneaky_iteration + 1) * self.num_sneaky_episodes), end = "")
-            # Reset all episode releated variables
+            # Reset all episode related variables
             prev_obs = self.sneaky_env.reset()
             done = False
             step = 0
@@ -280,6 +280,9 @@ class Play:
         # Increment the timer if it's the human or shown computer's turn
         if self.state is COMPUTER_PLAY or self.state is HUMAN_PLAY:
-            self.agent.memory.append(prev_obs, action, reward, obs, env_done)
+            if self.state == HUMAN_PLAY and isinstance(self.agent.memory, 'DQfDMemory'):
+                self.agent.memory.append_demonstration(prev_obs, action, reward, obs, env_done)
+            else:
+                self.agent.memory.append(prev_obs, action, reward, obs, env_done)
             i += 1
         # Perform a quick learning process and increment the state after a certain time period has passed

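Taken together, the hunk above routes transitions by who produced them: when it is the human's turn and the agent's buffer is a DQfDMemory, the transition goes through append_demonstration and is stored as demonstration data; everything else goes through the ordinary append. Below is a minimal standalone sketch of that routing, not part of the commit. It assumes DQfDMemory is importable from rltorch.memory (aliased M, matching the example script) and that the state flags are plain constants; note that isinstance expects the class object itself, so the sketch passes M.DQfDMemory rather than the string 'DQfDMemory'.

import rltorch.memory as M

# Placeholder state flags standing in for the ones defined in the Play class
COMPUTER_PLAY, HUMAN_PLAY = 0, 1

def store_transition(agent, state, prev_obs, action, reward, obs, env_done):
    # Human moves are treated as expert demonstrations for DQfD;
    # isinstance takes the class itself, not its name as a string.
    if state == HUMAN_PLAY and isinstance(agent.memory, M.DQfDMemory):
        agent.memory.append_demonstration(prev_obs, action, reward, obs, env_done)
    else:
        # Computer self-play falls back to the ordinary replay append
        agent.memory.append(prev_obs, action, reward, obs, env_done)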

@@ -152,9 +152,9 @@ net = rn.Network(Value(state_size, action_size),
 target_net = rn.TargetNetwork(net, device = device)
 # Relevant components from RLTorch
-memory = PrioritizedReplayMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'])
+memory = M.DQfDMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'], max_demo = config['memory_size'] // 2)
 actor = ArgMaxSelector(net, action_size, device = device)
-agent = rltorch.agents.DQNAgent(net, memory, config, target_net = target_net)
+agent = rltorch.agents.DQfDAgent(net, memory, config, target_net = target_net)
 # Use a different environment for when the computer trains on the side so that the current game state isn't manipuated
 # Also use MaxEnvSkip to speed up processing
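For reference, a sketch of how the DQfD pieces from the second file fit together, assuming M aliases rltorch.memory and that the DQfDMemory and DQfDAgent constructors take exactly the arguments shown in the diff; net, target_net, and config are built elsewhere in the example script and are taken here as parameters.

import rltorch.agents
import rltorch.memory as M

def build_dqfd_components(net, target_net, config):
    # Reserve half of the replay capacity for demonstration transitions
    memory = M.DQfDMemory(
        capacity = config['memory_size'],
        alpha = config['prioritized_replay_sampling_priority'],
        max_demo = config['memory_size'] // 2,
    )
    # DQfDAgent keeps the same constructor shape as the DQNAgent it replaces
    agent = rltorch.agents.DQfDAgent(net, memory, config, target_net = target_net)
    return memory, agent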