Corrected gamma multiplication

This commit is contained in:
Brandon Rozek 2019-03-04 22:04:13 -05:00
parent 190eb1f0c4
commit 8683b75ad9
3 changed files with 10 additions and 4 deletions

View file

@@ -14,9 +14,11 @@ class A2CSingleAgent:
self.logger = logger
def _discount_rewards(self, rewards):
gammas = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards)), dim = 0)
gammas = torch.ones_like(rewards)
if len(rewards) > 1:
gammas[1:] = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - 1), dim = 0)
return gammas * rewards
# This function is currently not used since the performance gains haven't been shown
# May be due to a faulty implementation; need to investigate more.
def _generalized_advantage_estimation(self, states, rewards, next_states, not_done):

View file

@@ -18,7 +18,9 @@ class PPOAgent:
self.logger = logger
def _discount_rewards(self, rewards):
gammas = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards)), dim = 0)
gammas = torch.ones_like(rewards)
if len(rewards) > 1:
gammas[1:] = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - 1), dim = 0)
return gammas * rewards

View file

@@ -21,7 +21,9 @@ class REINFORCEAgent:
shaped_rewards = torch.zeros_like(rewards)
baseline = rewards.mean()
for i in range(len(rewards)):
gammas = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - i), dim = 0)
gammas = torch.ones_like(rewards[i:])
if i != len(rewards) - 1:
gammas[1:] = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - i - 1), dim = 0)
advantages = rewards[i:] - baseline
shaped_rewards[i] = (gammas * advantages).sum()
return shaped_rewards