Cleaned up scripts, added more comments
parent e42f5bba1b
commit a59f84b446
11 changed files with 103 additions and 436 deletions
@@ -1,12 +1,8 @@
from copy import deepcopy
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Categorical
import rltorch
import rltorch.memory as M
import collections
import random

class A2CSingleAgent:
    def __init__(self, policy_net, value_net, memory, config, logger = None):
@@ -25,11 +21,11 @@ class A2CSingleAgent:
        return discounted_rewards

    def learn(self):
        episode_batch = self.memory.recall()
        state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)

        # Send batches to the appropriate device
        state_batch = torch.cat(state_batch).to(self.value_net.device)
        reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
        not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
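The body of _discount_rewards is not shown in this diff. As a rough, illustrative sketch of how a discounted return is usually accumulated (the free-standing name discount_rewards and its discount_rate argument are assumptions, not code from this commit):

import torch

# Illustrative only, not the code elided from this diff: accumulate a
# discounted return for each timestep by walking the rewards backwards.
def discount_rewards(rewards, discount_rate):
    discounted = torch.zeros_like(rewards)
    running_return = 0.0
    for i in reversed(range(len(rewards))):
        running_return = rewards[i] + discount_rate * running_return
        discounted[i] = running_return
    return discounted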
@@ -37,12 +33,16 @@ class A2CSingleAgent:
        log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)

        ## Value Loss
        # In A2C, the value loss is the mean squared error between the episode's discounted return and the value of the first state
        # The value of the first state is supposed to estimate the expected return of the whole episode under the current policy
        value_loss = F.mse_loss(self._discount_rewards(reward_batch).sum(), self.value_net(state_batch[0]))
        self.value_net.zero_grad()
        value_loss.backward()
        self.value_net.step()

        ## Policy Loss
        # Increase the probabilities of advantageous actions
        # and decrease the probabilities of non-advantageous ones
        with torch.no_grad():
            state_values = self.value_net(state_batch)
            next_state_values = torch.zeros_like(state_values)
@@ -61,8 +61,7 @@ class A2CSingleAgent:
        policy_loss.backward()
        self.policy_net.step()

        # Memory is irrelevant for future training
        # Memory under the old policy is not needed for future training
        self.memory.clear()
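The lines where the advantages and policy_loss are actually formed fall between the two hunks above and are not part of this diff. As a minimal sketch of the usual advantage actor-critic objective, assuming one-step bootstrapped advantages and reusing the tensor names from the hunks (the helper name and the discount_rate default are assumptions):

import torch

# Minimal sketch, not the elided code: one-step advantages and the
# policy-gradient loss that the hunks above backpropagate.
# All arguments are 1-D tensors over the episode's timesteps.
def a2c_policy_loss(log_prob_batch, reward_batch, state_values,
                    next_state_values, not_done_batch, discount_rate=0.99):
    # Bootstrap from the next state's value only when the episode continues
    targets = reward_batch + discount_rate * next_state_values * not_done_batch.float()
    advantages = targets - state_values
    # Raise log-probabilities of actions with positive advantage,
    # lower those with negative advantage
    return (-log_prob_batch * advantages.detach()).sum()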
@@ -55,6 +55,7 @@ class DQNAgent:
        expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)

        # If we're sampling by TD error, multiply the loss by an importance weight, which corrects for the bias of prioritized sampling and helps avoid overfitting to high-priority transitions
        if (isinstance(self.memory, M.PrioritizedReplayMemory)):
            loss = (torch.as_tensor(importance_weights, device = self.net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
        else:
@@ -74,6 +75,7 @@ class DQNAgent:
        else:
            self.target_net.sync()

        # If we're sampling by TD error, readjust the priorities of the experiences using the new TD errors
        if (isinstance(self.memory, M.PrioritizedReplayMemory)):
            td_error = (obtained_values - expected_values).detach().abs()
            self.memory.update_priorities(batch_indexes, td_error)
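Not part of this diff: the importance weights above typically follow the prioritized experience replay paper (Schaul et al., https://arxiv.org/abs/1511.05952), where each sampled transition i gets weight w_i = (N * P(i))^(-beta), normalized by the largest weight in the batch. A sketch of that computation, under the assumption that the sampling probabilities are available:

import torch

# Sketch of the importance weights from the prioritized experience replay
# paper; rltorch's PrioritizedReplayMemory may compute them differently.
def importance_weights(sample_probabilities, beta=0.4):
    # sample_probabilities: 1-D tensor of P(i) for the sampled transitions
    n = len(sample_probabilities)
    weights = (n * sample_probabilities) ** (-beta)
    # Normalize by the largest weight so the weights only scale the loss down
    return weights / weights.max()

The weighted loss in the first DQN hunk then averages w_i * (obtained - expected)^2 over the batch, and the absolute TD errors are fed back as the new priorities.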
@@ -31,6 +31,7 @@ class PPOAgent:
        episode_batch = self.memory.recall()
        state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)

        # Send batches to the appropriate device
        state_batch = torch.cat(state_batch).to(self.value_net.device)
        action_batch = torch.tensor(action_batch).to(self.value_net.device)
        reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
@@ -39,12 +40,16 @@ class PPOAgent:
        log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)

        ## Value Loss
        # In PPO, the value loss is the mean squared error between the episode's discounted return and the value of the first state
        # The value of the first state is supposed to estimate the expected return of the whole episode under the current policy
        value_loss = F.mse_loss(self._discount_rewards(reward_batch).sum(), self.value_net(state_batch[0]))
        self.value_net.zero_grad()
        value_loss.backward()
        self.value_net.step()

        ## Policy Loss
        # Increase the probabilities of advantageous actions
        # and decrease the probabilities of non-advantageous ones
        with torch.no_grad():
            state_values = self.value_net(state_batch)
            next_state_values = torch.zeros_like(state_values)
@@ -56,6 +61,7 @@ class PPOAgent:
        distributions = list(map(Categorical, action_probabilities))
        old_log_probs = torch.stack(list(map(lambda distribution, action: distribution.log_prob(action), distributions, action_batch)))

        # For PPO we want to stay within a certain ratio of the old policy
        policy_ratio = torch.exp(log_prob_batch - old_log_probs) # Equivalent to (prob / old_prob)
        policy_loss1 = policy_ratio * advantages
        policy_loss2 = policy_ratio.clamp(min = 0.8, max = 1.2) * advantages # From original paper
@@ -72,7 +78,7 @@ class PPOAgent:
        self.policy_net.step()

        # Memory is irrelevant for future training
        # Memory under the old policy is not needed for future training
        self.memory.clear()
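The step that combines policy_loss1 and policy_loss2 falls between the last two PPO hunks and is not shown. The clipped surrogate objective from the PPO paper (https://arxiv.org/abs/1707.06347) takes the element-wise minimum of the two terms; a minimal sketch, with the epsilon of 0.2 matching the clamp(min = 0.8, max = 1.2) above:

import torch

# Sketch of the clipped surrogate objective; this may not be exactly how
# PPOAgent combines the two terms.
def ppo_policy_loss(policy_ratio, advantages, epsilon=0.2):
    unclipped = policy_ratio * advantages
    clipped = policy_ratio.clamp(1.0 - epsilon, 1.0 + epsilon) * advantages
    # Take the pessimistic (smaller) objective per timestep, then negate it
    # so that gradient descent maximizes the surrogate.
    return -torch.min(unclipped, clipped).mean()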
@@ -3,7 +3,8 @@ import torch
from .Network import Network
from copy import deepcopy

# [TODO] See if you need to move network to device
# [TODO] Should we torch.no_grad the __call__?
# What if we want to sometimes do gradient descent as well?
class ESNetwork(Network):
    """
    Network that uses the Evolution Strategies technique from the paper (https://arxiv.org/abs/1703.03864)
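For context, the referenced paper estimates the gradient of the expected fitness by perturbing the parameters with Gaussian noise and weighting each perturbation by its fitness score. A minimal sketch of that estimator (illustrative only; ESNetwork's actual update may differ):

import torch

# Sketch of the gradient estimator from Evolution Strategies
# (https://arxiv.org/abs/1703.03864).
def es_gradient(theta, fitness_fn, population_size=50, sigma=0.1):
    # Sample Gaussian perturbations and score the perturbed parameters
    noise = torch.randn(population_size, *theta.shape)
    fitness = torch.tensor([fitness_fn(theta + sigma * eps) for eps in noise])
    # Standardizing the fitness scores keeps the update scale stable
    fitness = (fitness - fitness.mean()) / (fitness.std() + 1e-8)
    # grad ~= (1 / (n * sigma)) * sum_i fitness_i * noise_i
    fitness = fitness.view(-1, *([1] * theta.dim()))
    return (fitness * noise).mean(dim=0) / sigma

A parameter update is then a plain gradient ascent step, for example theta = theta + learning_rate * es_gradient(theta, fitness_fn).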
|
Loading…
Add table
Add a link
Reference in a new issue