Added improvements to the REINFORCE algorithm
parent a59f84b446
commit 11d99df977
3 changed files with 302 additions and 14 deletions
@@ -13,26 +13,33 @@ class REINFORCEAgent:
         self.target_net = target_net
         self.logger = logger
 
-    def _discount_rewards(self, rewards):
-        discounted_rewards = torch.zeros_like(rewards)
-        running_add = 0
-        for t in reversed(range(len(rewards))):
-            running_add = running_add * self.config['discount_rate'] + rewards[t]
-            discounted_rewards[t] = running_add
-
-        # Normalize rewards
-        discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + np.finfo('float').eps)
-        return discounted_rewards
+    # Shaped rewards implement three improvements to REINFORCE:
+    # 1) Discounted rewards: future rewards matter less than current ones
+    # 2) Baselines: we use the mean reward to judge whether a given reward is advantageous or not
+    # 3) Causality: your current actions do not affect the past, only the present and future
+    def _shape_rewards(self, rewards):
+        shaped_rewards = torch.zeros_like(rewards)
+        baseline = rewards.mean()
+        for i in range(len(rewards)):
+            gammas = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - i), dim=0)
+            advantages = rewards[i:] - baseline
+            shaped_rewards[i] = (gammas * advantages).sum()
+        return shaped_rewards
 
     def learn(self):
         episode_batch = self.memory.recall()
         state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
 
-        discount_reward_batch = self._discount_rewards(torch.tensor(reward_batch))
+        # Calculate discounted rewards to place more importance on near-term rewards
+        shaped_reward_batch = self._shape_rewards(torch.tensor(reward_batch))
 
+        # Scale the shaped rewards to have variance 1 (stabilizes training)
+        shaped_reward_batch = shaped_reward_batch / (shaped_reward_batch.std() + np.finfo('float').eps)
 
         log_prob_batch = torch.cat(log_prob_batch)
 
-        policy_loss = (-log_prob_batch * discount_reward_batch).sum()
+        policy_loss = (-log_prob_batch * shaped_reward_batch).sum()
 
         if self.logger is not None:
             self.logger.append("Loss", policy_loss.item())
 
@@ -47,5 +54,5 @@ class REINFORCEAgent:
         else:
             self.target_net.sync()
 
-        # Memory is irrelevant for future training
+        # Memory under the old policy is not needed for future training
         self.memory.clear()
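For reference, a minimal standalone sketch of what the removed _discount_rewards helper computed: reward-to-go discounting followed by mean/std normalization. The gamma value and the toy reward sequence are illustrative assumptions, not values taken from the repository's config.

import numpy as np
import torch

# Illustrative values only; the agent reads gamma from self.config['discount_rate'].
gamma = 0.99
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])

# Reward-to-go: each step accumulates its own reward plus the discounted future sum.
discounted = torch.zeros_like(rewards)
running_add = 0.0
for t in reversed(range(len(rewards))):
    running_add = running_add * gamma + rewards[t]
    discounted[t] = running_add
# discounted is approximately [1.9703, 0.9801, 0.9900, 1.0000]

# Normalize to zero mean and unit variance; eps guards against a zero std.
discounted = (discounted - discounted.mean()) / (discounted.std() + np.finfo('float').eps)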
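A sketch of the new shaping on the same toy inputs, showing how the three improvements interact: a mean-reward baseline, discount weights built with cumprod, and a sum that only runs over present and future steps (causality). Again, gamma and the rewards are assumed values rather than anything from the commit.

import torch

gamma = 0.99                                   # assumed discount rate
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])   # toy episode rewards

baseline = rewards.mean()                      # constant baseline over the episode
shaped = torch.zeros_like(rewards)
for i in range(len(rewards)):
    # cumprod over a constant vector gives weights gamma, gamma^2, ..., gamma^(T-i)
    gammas = torch.cumprod(torch.tensor(gamma).repeat(len(rewards) - i), dim=0)
    advantages = rewards[i:] - baseline        # how much better than the average reward
    shaped[i] = (gammas * advantages).sum()    # causality: only steps i..T-1 contribute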
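Finally, a sketch of how learn() turns the shaped rewards into the REINFORCE loss: rescale to unit variance, multiply by the stored log-probabilities, and sum with a minus sign so that gradient descent increases the probability of advantageous actions. The log-probability and reward values below are hypothetical; in the agent they come from the policy network via memory.recall().

import numpy as np
import torch

# Hypothetical per-step values; in the agent these come from the replayed episode.
log_prob_batch = torch.tensor([-0.5, -1.2, -0.8, -0.3], requires_grad=True)
shaped_reward_batch = torch.tensor([1.2, -0.4, 0.1, 0.9])

# Unit variance keeps the size of the policy-gradient step stable across episodes.
shaped_reward_batch = shaped_reward_batch / (shaped_reward_batch.std() + np.finfo('float').eps)

# REINFORCE objective: maximize sum of log pi(a_t|s_t) * shaped reward, i.e. minimize its negative.
policy_loss = (-log_prob_batch * shaped_reward_batch).sum()
policy_loss.backward()  # in the real agent, gradients flow back into the policy network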