from copy import deepcopy

import numpy as np
import torch

import rltorch


class REINFORCEAgent:
    """Policy-gradient agent implementing REINFORCE with shaped rewards."""

    def __init__(self, net, memory, config, target_net=None, logger=None):
        self.net = net
        if not isinstance(memory, rltorch.memory.EpisodeMemory):
            raise ValueError("Memory must be an instance of EpisodeMemory")
        self.memory = memory
        self.config = deepcopy(config)
        self.target_net = target_net
        self.logger = logger

    # _shape_rewards implements three improvements to REINFORCE:
    # 1) Discounted rewards: future rewards matter less than current ones
    # 2) Baselines: the mean reward is subtracted to judge whether a reward is advantageous
    # 3) Causality: current actions only affect present and future rewards, never past ones
    def _shape_rewards(self, rewards):
        shaped_rewards = torch.zeros_like(rewards)
        baseline = rewards.mean()
        for i in range(len(rewards)):
            # Discount factors gamma^1, ..., gamma^(T - i) for the remaining steps
            gammas = torch.cumprod(
                torch.tensor(self.config['discount_rate']).repeat(len(rewards) - i),
                dim=0
            )
            advantages = rewards[i:] - baseline
            shaped_rewards[i] = (gammas * advantages).sum()
        return shaped_rewards
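
    # For reference, the loop above computes, for each timestep i,
    #     shaped_rewards[i] = sum_{j=i}^{T-1} gamma^(j-i+1) * (rewards[j] - mean(rewards)),
    # i.e. a discounted, baselined reward-to-go. The exponent starts at 1 because
    # cumprod begins at gamma^1; that common factor of gamma is effectively cancelled
    # by the standard-deviation scaling in learn().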

    def learn(self):
        episode_batch = self.memory.recall()
        (state_batch, action_batch, reward_batch,
         next_state_batch, done_batch, log_prob_batch) = zip(*episode_batch)

        # Calculate discounted rewards so that immediate rewards carry more weight than distant ones
        shaped_reward_batch = self._shape_rewards(torch.tensor(reward_batch))

        # Scale the shaped rewards to unit variance (stabilizes training)
        shaped_reward_batch = shaped_reward_batch / (shaped_reward_batch.std() + np.finfo('float').eps)

        log_prob_batch = torch.cat(log_prob_batch)

        # REINFORCE loss: negative log-probability of each action, weighted by its shaped reward
        policy_loss = (-log_prob_batch * shaped_reward_batch).sum()

        if self.logger is not None:
            self.logger.append("Loss", policy_loss.item())

        self.net.zero_grad()
        policy_loss.backward()
        self.net.clamp_gradients()
        self.net.step()

        if self.target_net is not None:
            if 'target_sync_tau' in self.config:
                self.target_net.partial_sync(self.config['target_sync_tau'])
            else:
                self.target_net.sync()

        # Memory gathered under the old policy is not needed for future training
        self.memory.clear()
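

# Minimal self-check sketch, not part of the library API: it mirrors the
# arithmetic of _shape_rewards with plain tensors to make explicit what the
# cumprod trick computes. The reward values and discount rate below are
# arbitrary illustrative choices.
if __name__ == "__main__":
    rewards = torch.tensor([1.0, 0.0, 2.0, -1.0])
    gamma = 0.99
    baseline = rewards.mean()

    # cumprod-based computation, matching _shape_rewards
    shaped = torch.zeros_like(rewards)
    for i in range(len(rewards)):
        gammas = torch.cumprod(torch.tensor(gamma).repeat(len(rewards) - i), dim=0)
        shaped[i] = (gammas * (rewards[i:] - baseline)).sum()

    # Explicit double sum: shaped[i] = sum_{j >= i} gamma^(j - i + 1) * (r_j - baseline)
    expected = torch.tensor([
        sum(gamma ** (j - i + 1) * (rewards[j].item() - baseline.item())
            for j in range(i, len(rewards)))
        for i in range(len(rewards))
    ])

    assert torch.allclose(shaped, expected)
    print("shaped rewards:", shaped.tolist())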