Incorporated concepts from the paper "Deep Q-Learning From Demonstrations"
parent 744656aaa9
commit 8dd9ca617e
2 changed files with 7 additions and 4 deletions
play.py | 7 +++++--
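DQfD (Hester et al., 2018) seeds the replay buffer with human demonstrations that are never evicted and adds a supervised large-margin loss that pushes the Q-value of each demonstrated action above every alternative. A minimal sketch of that loss, using illustrative names rather than rltorch's actual API:

# Sketch of the DQfD large-margin loss
#   J_E(Q) = max_a [Q(s, a) + l(a_E, a)] - Q(s, a_E)
# applied only to sampled demonstration transitions; names are illustrative.
import torch

def demonstration_margin_loss(q_values, demo_actions, margin = 0.8):
    # l(a_E, a): zero for the demonstrated action, a positive margin otherwise
    l = torch.full_like(q_values, margin)
    l.scatter_(1, demo_actions.unsqueeze(1), 0.0)
    best_with_margin = (q_values + l).max(dim = 1).values
    demo_q = q_values.gather(1, demo_actions.unsqueeze(1)).squeeze(1)
    return (best_with_margin - demo_q).mean()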
play.py

@@ -151,7 +151,7 @@ class Play:
         for i in range(self.num_sneaky_episodes):
             print("Episode: %d / %d, Reward: " % ((self.num_sneaky_episodes * self.sneaky_iteration) + i + 1, (self.sneaky_iteration + 1) * self.num_sneaky_episodes), end = "")
 
-            # Reset all episode releated variables
+            # Reset all episode related variables
             prev_obs = self.sneaky_env.reset()
             done = False
             step = 0
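The episode body falls outside this hunk; given the variables reset here, it presumably steps sneaky_env until done while accumulating the reward that the print above leaves open. A hypothetical continuation under that assumption:

# Hypothetical remainder of the sneaky-episode loop; the variable names come
# from the hunk, but the stepping and action-selection calls are assumptions.
total_reward = 0
while not done:
    action = self.actor.act(prev_obs)   # assumed actor API
    obs, reward, done, _ = self.sneaky_env.step(action)
    self.agent.memory.append(prev_obs, action, reward, obs, done)
    total_reward += reward
    prev_obs = obs
    step += 1
print(total_reward)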
@@ -280,7 +280,10 @@ class Play:
 
             # Increment the timer if it's the human or shown computer's turn
             if self.state is COMPUTER_PLAY or self.state is HUMAN_PLAY:
-                self.agent.memory.append(prev_obs, action, reward, obs, env_done)
+                if self.state == HUMAN_PLAY and isinstance(self.agent.memory, M.DQfDMemory):
+                    self.agent.memory.append_demonstration(prev_obs, action, reward, obs, env_done)
+                else:
+                    self.agent.memory.append(prev_obs, action, reward, obs, env_done)
                 i += 1
                 # Perform a quick learning process and increment the state after a certain time period has passed
                 if i % (self.fps * self.seconds_play_per_state) == 0:
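Two notes on the new branch: isinstance needs the class object itself, so the check above assumes play.py imports rltorch.memory as M, as the training script further down does; and append_demonstration routes human transitions into buffer slots that ordinary appends cannot evict. A conceptual sketch of that split, not rltorch's actual implementation:

# Toy demonstration-aware buffer mirroring the append/append_demonstration
# split used above; rltorch's real DQfDMemory differs in its details.
class DemoBuffer:
    def __init__(self, capacity, max_demo):
        self.capacity = capacity   # total transitions held
        self.max_demo = max_demo   # slots reserved for demonstrations
        self.demos = []            # protected: never evicted by append()
        self.self_play = []

    def append(self, *transition):
        self.self_play.append(transition)
        if len(self.demos) + len(self.self_play) > self.capacity:
            self.self_play.pop(0)  # evict the oldest self-play transition only

    def append_demonstration(self, *transition):
        if len(self.demos) < self.max_demo:
            self.demos.append(transition)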
@@ -152,9 +152,9 @@ net = rn.Network(Value(state_size, action_size),
 target_net = rn.TargetNetwork(net, device = device)
 
 # Relevant components from RLTorch
-memory = PrioritizedReplayMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'])
+memory = M.DQfDMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'], max_demo = config['memory_size'] // 2)
 actor = ArgMaxSelector(net, action_size, device = device)
-agent = rltorch.agents.DQNAgent(net, memory, config, target_net = target_net)
+agent = rltorch.agents.DQfDAgent(net, memory, config, target_net = target_net)
 
 # Use a different environment for when the computer trains on the side so that the current game state isn't manipulated
 # Also use MaxEnvSkip to speed up processing
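With max_demo set to half the capacity, demonstrations can occupy at most config['memory_size'] // 2 slots, leaving the other half for self-play; DQfD also typically pre-fills the demonstration region before any training. A hypothetical pre-fill pass reusing this script's memory object (the environment name and random actions are stand-ins, not repo code):

# Hypothetical demonstration pre-fill; everything except `memory` is assumed.
import gym

env = gym.make('PongNoFrameskip-v4')    # stand-in environment name
prev_obs = env.reset()
for _ in range(1000):
    action = env.action_space.sample()  # stand-in for recorded human actions
    obs, reward, done, _ = env.step(action)
    memory.append_demonstration(prev_obs, action, reward, obs, done)
    prev_obs = env.reset() if done else obs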