Corrected gamma multiplication
parent 190eb1f0c4
commit 8683b75ad9
3 changed files with 10 additions and 4 deletions
|
@@ -14,7 +14,9 @@ class A2CSingleAgent:
         self.logger = logger
 
     def _discount_rewards(self, rewards):
-        gammas = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards)), dim = 0)
+        gammas = torch.ones_like(rewards)
+        if len(rewards) > 1:
+            gammas[1:] = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - 1), dim = 0)
         return gammas * rewards
 
     # This function is currently not used since the performance gains hasn't been shown
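For reference, a minimal standalone sketch of the corrected discounting, assuming rewards is a 1-D float tensor and the discount rate is passed in directly (the self.config['discount_rate'] lookup and the class wrapper from the diff are omitted; the helper name is illustrative). With the fix, reward k is weighted by gamma**k instead of gamma**(k + 1):

import torch

def discount_rewards(rewards, gamma):
    # Weight vector [1, gamma, gamma**2, ...], matching the corrected diff above.
    gammas = torch.ones_like(rewards)
    if len(rewards) > 1:
        gammas[1:] = torch.cumprod(torch.tensor(gamma).repeat(len(rewards) - 1), dim=0)
    return gammas * rewards

rewards = torch.tensor([1.0, 1.0, 1.0, 1.0])
print(discount_rewards(rewards, 0.9))  # tensor([1.0000, 0.9000, 0.8100, 0.7290])
# The pre-fix version returned tensor([0.9000, 0.8100, 0.7290, 0.6561]),
# i.e. every reward carried one extra factor of gamma.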
@@ -18,7 +18,9 @@ class PPOAgent:
         self.logger = logger
 
     def _discount_rewards(self, rewards):
-        gammas = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards)), dim = 0)
+        gammas = torch.ones_like(rewards)
+        if len(rewards) > 1:
+            gammas[1:] = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - 1), dim = 0)
         return gammas * rewards
 
 
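The PPO agent receives the identical fix. As a quick sanity check (illustrative values only, assuming a discount rate of 0.99 and a rollout of length 3), the old weight vector equals the corrected one times one extra factor of the discount rate:

import torch

gamma, n = 0.99, 3
old = torch.cumprod(torch.tensor(gamma).repeat(n), dim=0)          # tensor([0.9900, 0.9801, 0.9703])
new = torch.ones(n)
new[1:] = torch.cumprod(torch.tensor(gamma).repeat(n - 1), dim=0)  # tensor([1.0000, 0.9900, 0.9801])
assert torch.allclose(old, gamma * new)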
@@ -21,7 +21,9 @@ class REINFORCEAgent:
         shaped_rewards = torch.zeros_like(rewards)
         baseline = rewards.mean()
         for i in range(len(rewards)):
-            gammas = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - i), dim = 0)
+            gammas = torch.ones_like(rewards[i:])
+            if i != len(rewards) - 1:
+                gammas[1:] = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - i - 1), dim = 0)
             advantages = rewards[i:] - baseline
             shaped_rewards[i] = (gammas * advantages).sum()
         return shaped_rewards
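For the REINFORCE case, a minimal sketch of the corrected reward shaping as a free function (the self.config lookup and class plumbing from the diff are omitted; names are illustrative). With the fix, shaped_rewards[i] becomes the discounted sum of baseline-subtracted rewards from step i onward, with the immediate term weighted by 1:

import torch

def shape_rewards(rewards, gamma):
    shaped_rewards = torch.zeros_like(rewards)
    baseline = rewards.mean()
    for i in range(len(rewards)):
        # Corrected weights [1, gamma, gamma**2, ...] over the remaining steps, so
        # shaped_rewards[i] = sum over k >= i of gamma**(k - i) * (rewards[k] - baseline).
        gammas = torch.ones_like(rewards[i:])
        if i != len(rewards) - 1:
            gammas[1:] = torch.cumprod(torch.tensor(gamma).repeat(len(rewards) - i - 1), dim=0)
        advantages = rewards[i:] - baseline
        shaped_rewards[i] = (gammas * advantages).sum()
    return shaped_rewards

rewards = torch.tensor([0.0, 0.0, 1.0])
print(shape_rewards(rewards, 0.5))  # tensor([-0.3333,  0.0000,  0.6667]), baseline = 1/3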