Incorporated concepts from the paper "Deep Q-Learning From Demonstrations"
This commit is contained in:
parent
744656aaa9
commit
8dd9ca617e
2 changed files with 7 additions and 4 deletions
5
play.py
5
play.py
|
@@ -151,7 +151,7 @@ class Play:
|
||||||
for i in range(self.num_sneaky_episodes):
|
for i in range(self.num_sneaky_episodes):
|
||||||
print("Episode: %d / %d, Reward: " % ((self.num_sneaky_episodes * self.sneaky_iteration) + i + 1, (self.sneaky_iteration + 1) * self.num_sneaky_episodes), end = "")
|
print("Episode: %d / %d, Reward: " % ((self.num_sneaky_episodes * self.sneaky_iteration) + i + 1, (self.sneaky_iteration + 1) * self.num_sneaky_episodes), end = "")
|
||||||
|
|
||||||
# Reset all episode releated variables
|
# Reset all episode related variables
|
||||||
prev_obs = self.sneaky_env.reset()
|
prev_obs = self.sneaky_env.reset()
|
||||||
done = False
|
done = False
|
||||||
step = 0
|
step = 0
|
||||||
|
@@ -280,6 +280,9 @@ class Play:
|
||||||
|
|
||||||
# Increment the timer if it's the human or shown computer's turn
|
# Increment the timer if it's the human or shown computer's turn
|
||||||
if self.state is COMPUTER_PLAY or self.state is HUMAN_PLAY:
|
if self.state is COMPUTER_PLAY or self.state is HUMAN_PLAY:
|
||||||
|
if self.state == HUMAN_PLAY and isinstance(self.agent.memory, 'DQfDMemory'):
|
||||||
|
self.agent.memory.append_demonstration(prev_obs, action, reward, obs, env_done)
|
||||||
|
else:
|
||||||
self.agent.memory.append(prev_obs, action, reward, obs, env_done)
|
self.agent.memory.append(prev_obs, action, reward, obs, env_done)
|
||||||
i += 1
|
i += 1
|
||||||
# Perform a quick learning process and increment the state after a certain time period has passed
|
# Perform a quick learning process and increment the state after a certain time period has passed
|
||||||
|
|
|
@@ -152,9 +152,9 @@ net = rn.Network(Value(state_size, action_size),
|
||||||
target_net = rn.TargetNetwork(net, device = device)
|
target_net = rn.TargetNetwork(net, device = device)
|
||||||
|
|
||||||
# Relevant components from RLTorch
|
# Relevant components from RLTorch
|
||||||
memory = PrioritizedReplayMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'])
|
memory = M.DQfDMemory(capacity= config['memory_size'], alpha = config['prioritized_replay_sampling_priority'], max_demo = config['memory_size'] // 2)
|
||||||
actor = ArgMaxSelector(net, action_size, device = device)
|
actor = ArgMaxSelector(net, action_size, device = device)
|
||||||
agent = rltorch.agents.DQNAgent(net, memory, config, target_net = target_net)
|
agent = rltorch.agents.DQfDAgent(net, memory, config, target_net = target_net)
|
||||||
|
|
||||||
# Use a different environment for when the computer trains on the side so that the current game state isn't manipulated
|
# Use a different environment for when the computer trains on the side so that the current game state isn't manipulated
|
||||||
# Also use MaxEnvSkip to speed up processing
|
# Also use MaxEnvSkip to speed up processing
|
||||||
|
|
Loading…
Reference in a new issue