Added improvements to the REINFORCE algorithm

This commit is contained in:
Brandon Rozek 2019-03-04 17:10:24 -05:00
parent a59f84b446
commit 11d99df977
3 changed files with 302 additions and 14 deletions

@@ -13,26 +13,33 @@ class REINFORCEAgent:
self.target_net = target_net
self.logger = logger
def _discount_rewards(self, rewards):
discounted_rewards = torch.zeros_like(rewards)
running_add = 0
for t in reversed(range(len(rewards))):
running_add = running_add * self.config['discount_rate'] + rewards[t]
discounted_rewards[t] = running_add
# Normalize rewards
discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + np.finfo('float').eps)
return discounted_rewards
# Shaped rewards implement three improvements to REINFORCE:
# 1) Discounted rewards: future rewards matter less than current ones
# 2) Baselines: we use the mean reward to judge whether the current reward is advantageous or not
# 3) Causality: your current actions do not affect the past, only the present and future
def _shape_rewards(self, rewards):
shaped_rewards = torch.zeros_like(rewards)
baseline = rewards.mean()
for i in range(len(rewards)):
gammas = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - i), dim = 0)
advantages = rewards[i:] - baseline
shaped_rewards[i] = (gammas * advantages).sum()
return shaped_rewards
def learn(self):
episode_batch = self.memory.recall()
state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
discount_reward_batch = self._discount_rewards(torch.tensor(reward_batch))
# Calculate discounted rewards to place more importance on recent rewards
shaped_reward_batch = self._shape_rewards(torch.tensor(reward_batch))
# Scale the shaped rewards to have variance 1 (stabilizes training)
shaped_reward_batch = shaped_reward_batch / (shaped_reward_batch.std() + np.finfo('float').eps)
log_prob_batch = torch.cat(log_prob_batch)
policy_loss = (-log_prob_batch * discount_reward_batch).sum()
policy_loss = (-log_prob_batch * shaped_reward_batch).sum()
if self.logger is not None:
self.logger.append("Loss", policy_loss.item())
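For illustration, a minimal standalone sketch of the shaped-rewards computation added in _shape_rewards above (not part of the commit): the function name shape_rewards, the sample reward tensor, and the 0.99 discount rate are assumptions for the example, and the loop mirrors the committed method's cumulative product of discount factors over the remaining steps.

import torch

def shape_rewards(rewards, discount_rate):
    shaped = torch.zeros_like(rewards)
    # (2) Baseline: compare each reward against the episode mean
    baseline = rewards.mean()
    for i in range(len(rewards)):
        # (1) Discounting: cumulative powers of the discount rate over the remaining steps
        gammas = torch.cumprod(torch.full((len(rewards) - i,), discount_rate), dim=0)
        # (3) Causality: only rewards from step i onward contribute to step i
        advantages = rewards[i:] - baseline
        shaped[i] = (gammas * advantages).sum()
    return shaped

print(shape_rewards(torch.tensor([1.0, 0.0, 0.0, 2.0]), discount_rate=0.99))

A positive shaped value then marks a step where the rest of the episode did better than the episode average, and a negative one the opposite.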
@@ -47,5 +54,5 @@ class REINFORCEAgent:
else:
self.target_net.sync()
# Memory is irrelevant for future training
# Memory under the old policy is not needed for future training
self.memory.clear()
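As a rough sketch of how learn() turns the shaped rewards into the policy-gradient loss (the stand-in tensors below are invented for the example; in the committed code the log-probabilities come from the agent's memory and carry gradients back to the policy network):

import numpy as np
import torch

# Stand-in values for illustration only
log_prob_batch = torch.log(torch.tensor([0.6, 0.7, 0.5, 0.9]))
shaped_reward_batch = torch.tensor([1.2, -0.3, 0.4, 0.9])

# Scale to variance 1 (stabilizes training); eps guards against division by zero
shaped_reward_batch = shaped_reward_batch / (shaped_reward_batch.std() + np.finfo('float').eps)
policy_loss = (-log_prob_batch * shaped_reward_batch).sum()
print(policy_loss.item())

Because REINFORCE is on-policy, transitions gathered under the old policy are no longer valid training data after this update, which is why the hunk above ends with self.memory.clear().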