Correct discount_rewards function to only multiply with gamma throughout

Brandon Rozek 2019-03-04 21:59:02 -05:00
parent 11d99df977
commit 190eb1f0c4
2 changed files with 30 additions and 14 deletions


@@ -1,4 +1,5 @@
 from copy import deepcopy
+import numpy as np
 import torch
 import torch.nn.functional as F
 import rltorch
@@ -13,13 +14,27 @@ class A2CSingleAgent:
         self.logger = logger

     def _discount_rewards(self, rewards):
-        discounted_rewards = torch.zeros_like(rewards)
-        running_add = 0
-        for t in reversed(range(len(rewards))):
-            running_add = running_add * self.config['discount_rate'] + rewards[t]
-            discounted_rewards[t] = running_add
-        return discounted_rewards
+        gammas = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards)), dim = 0)
+        return gammas * rewards
+
+    # This function is currently not used since the performance gains haven't been shown
+    # May be due to a faulty implementation, need to investigate more...
+    def _generalized_advantage_estimation(self, states, rewards, next_states, not_done):
+        tradeoff = 0.5
+        with torch.no_grad():
+            next_values = torch.zeros_like(rewards)
+            next_values[not_done] = self.value_net(next_states[not_done]).squeeze(1)
+            values = self.value_net(states).squeeze(1)
+        generalized_advantages = torch.zeros_like(rewards)
+        for i in range(len(generalized_advantages)):
+            weights = torch.ones_like(rewards[i:])
+            if i != len(generalized_advantages) - 1:
+                weights[1:] = torch.cumprod(torch.tensor(self.config['discount_rate'] * tradeoff).repeat(len(rewards) - i - 1), dim = 0)
+            generalized_advantages[i] = (weights * (rewards[i:] + self.config['discount_rate'] * next_values[i:] - values[i:])).sum()
+        return generalized_advantages

     def learn(self):
         episode_batch = self.memory.recall()
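
The rewritten _discount_rewards no longer accumulates a running return; it scales each reward element-wise by the matching power of the discount rate. Below is a minimal standalone sketch of the new behaviour next to the old one, assuming a flat reward tensor and a scalar discount rate (the function names are illustrative, not the repo's API):

import torch

def discount_rewards(rewards, discount_rate=0.99):
    # New behaviour: reward at index t is scaled by discount_rate ** (t + 1)
    gammas = torch.cumprod(torch.tensor(discount_rate).repeat(len(rewards)), dim=0)
    return gammas * rewards

def returns_to_go(rewards, discount_rate=0.99):
    # Old behaviour, kept for comparison: discounted return from each timestep onward
    discounted = torch.zeros_like(rewards)
    running_add = 0.0
    for t in reversed(range(len(rewards))):
        running_add = running_add * discount_rate + rewards[t]
        discounted[t] = running_add
    return discounted

rewards = torch.tensor([1.0, 1.0, 1.0])
print(discount_rewards(rewards))  # tensor([0.9900, 0.9801, 0.9703])
print(returns_to_go(rewards))     # tensor([2.9701, 1.9900, 1.0000])

Note that cumprod starts at discount_rate ** 1, so even the first reward is discounted once; summing the new output therefore gives gamma times the episode's discounted return from the start.
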
@@ -35,7 +50,9 @@ class A2CSingleAgent:
         ## Value Loss
         # In A2C, the value loss is the difference between the discounted reward and the value of the first state
         # The value of the first state is supposed to tell us the expected return of the whole episode under the current policy
-        value_loss = F.mse_loss(self._discount_rewards(reward_batch).sum(), self.value_net(state_batch[0]))
+        discounted_reward = self._discount_rewards(reward_batch)
+        observed_value = discounted_reward.sum()
+        value_loss = F.mse_loss(observed_value, self.value_net(state_batch[0]))
         self.value_net.zero_grad()
         value_loss.backward()
         self.value_net.step()
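
With that change, the value-loss hunk builds its target in two steps: sum the gamma-scaled rewards into a single observed value, then regress the first state's value estimate toward it with a mean-squared error. A rough sketch with a plain tensor standing in for value_net(state_batch[0]):

import torch
import torch.nn.functional as F

rewards = torch.tensor([1.0, 1.0, 1.0])
gammas = torch.cumprod(torch.tensor(0.99).repeat(len(rewards)), dim=0)
observed_value = (gammas * rewards).sum()                  # ~2.9404
predicted_value = torch.tensor(2.5, requires_grad=True)    # stand-in for value_net(state_batch[0])
value_loss = F.mse_loss(predicted_value, observed_value)   # (2.5 - 2.9404) ** 2
value_loss.backward()                                      # gradient reaches only the prediction
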
@@ -50,6 +67,10 @@ class A2CSingleAgent:
         advantages = (reward_batch.unsqueeze(1) + self.config['discount_rate'] * next_state_values) - state_values
         advantages = advantages.squeeze(1)
+        # advantages = self._generalized_advantage_estimation(state_batch, reward_batch, next_state_batch, not_done_batch)
+
+        # Scale for more stable learning
+        advantages = advantages / (advantages.std() + np.finfo('float').eps)

         policy_loss = (-log_prob_batch * advantages).sum()

         if self.logger is not None:
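
The added scaling divides the advantage vector by its standard deviation, with machine epsilon guarding against division by zero, before the advantages weight the log-probabilities. A small sketch of just those lines with made-up numbers:

import numpy as np
import torch

advantages = torch.tensor([0.5, -1.2, 2.0, 0.1])
log_prob_batch = torch.tensor([-0.7, -1.1, -0.3, -0.9])

# Scale for more stable learning; torch's std() is the unbiased (Bessel-corrected) estimate
advantages = advantages / (advantages.std() + np.finfo('float').eps)
policy_loss = (-log_prob_batch * advantages).sum()

Only the spread is normalised here; the mean is left untouched, which is a lighter touch than the common (x - mean) / std standardisation.
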


@@ -18,13 +18,8 @@ class PPOAgent:
         self.logger = logger

     def _discount_rewards(self, rewards):
-        discounted_rewards = torch.zeros_like(rewards)
-        running_add = 0
-        for t in reversed(range(len(rewards))):
-            running_add = running_add * self.config['discount_rate'] + rewards[t]
-            discounted_rewards[t] = running_add
-        return discounted_rewards
+        gammas = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards)), dim = 0)
+        return gammas * rewards

     def learn(self):
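
For comparison with the new (and, per the commit's own comment, still unused and possibly faulty) _generalized_advantage_estimation, here is a rough sketch of the usual backward recursion for GAE from Schulman et al., with the commit's tradeoff = 0.5 playing the role of lambda. The tensor names and shapes are assumptions, not the repo's API:

import torch

def gae(rewards, values, next_values, not_done, gamma=0.99, lam=0.5):
    # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), masking V(s_{t+1}) at episode ends
    deltas = rewards + gamma * next_values * not_done.float() - values
    advantages = torch.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        # A_t = delta_t + gamma * lam * A_{t+1}, truncated when the episode terminates
        running = deltas[t] + gamma * lam * not_done[t].float() * running
        advantages[t] = running
    return advantages

Over a single episode this builds the same (gamma * lambda)-weighted sums of TD residuals as the commit's nested loop, but in one backward pass.
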