Cleaned up scripts, added more comments

Brandon Rozek 2019-03-04 17:09:46 -05:00
parent e42f5bba1b
commit a59f84b446
11 changed files with 103 additions and 436 deletions

View file

@@ -1,12 +1,8 @@
from copy import deepcopy
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Categorical
import rltorch
import rltorch.memory as M
import collections
import random
class A2CSingleAgent:
def __init__(self, policy_net, value_net, memory, config, logger = None):
@@ -25,11 +21,11 @@ class A2CSingleAgent:
return discounted_rewards
def learn(self):
episode_batch = self.memory.recall()
state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
# Send batches to the appropriate device
state_batch = torch.cat(state_batch).to(self.value_net.device)
reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
@@ -37,12 +33,16 @@ class A2CSingleAgent:
log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
## Value Loss
# In A2C, the value loss is the difference between the summed discounted rewards and the value of the first state
# The value of the first state is supposed to tell us the expected return of the whole episode under the current policy
value_loss = F.mse_loss(self._discount_rewards(reward_batch).sum(), self.value_net(state_batch[0]))
self.value_net.zero_grad()
value_loss.backward()
self.value_net.step()
## Policy Loss
# Increase the probabilities of actions with positive advantage
# and decrease the probabilities of actions with negative advantage
with torch.no_grad():
state_values = self.value_net(state_batch)
next_state_values = torch.zeros_like(state_values)
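
For reference, a minimal standalone sketch of the value-loss idea described earlier in this hunk, using invented reward values and a stand-in scalar for value_net(state_batch[0]); the helper name discount_rewards is illustrative, not the repository's API:

import torch
import torch.nn.functional as F

def discount_rewards(rewards, gamma=0.99):
    # rewards: 1-D tensor of per-step rewards for a single episode
    discounts = torch.pow(gamma, torch.arange(len(rewards), dtype=torch.float32))
    return discounts * rewards

rewards = torch.tensor([1.0, 0.0, 1.0, 1.0])                   # made-up episode rewards
value_of_first_state = torch.tensor(2.5, requires_grad=True)   # stand-in for value_net(state_batch[0])

# The episode return is the sum of discounted rewards; the critic's estimate
# for the first state is regressed toward that single target.
episode_return = discount_rewards(rewards).sum()
value_loss = F.mse_loss(value_of_first_state, episode_return)
value_loss.backward()  # gradient the value optimizer would then apply
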
@@ -61,8 +61,7 @@ class A2CSingleAgent:
policy_loss.backward()
self.policy_net.step()
# Memory is irrelevant for future training
# Memory under the old policy is not needed for future training
self.memory.clear()
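
And a similarly hedged sketch of the advantage-weighted policy update described above, with made-up log-probabilities, rewards, and critic values; the one-step advantage r_t + gamma * V(s_{t+1}) - V(s_t) mirrors the structure of the code, not its exact tensors:

import torch

log_probs = torch.tensor([-0.2, -1.3, -0.7], requires_grad=True)  # log pi(a_t | s_t), invented
rewards = torch.tensor([1.0, 0.0, 1.0])
state_values = torch.tensor([0.9, 0.4, 0.6])        # V(s_t) from the critic
next_state_values = torch.tensor([0.4, 0.6, 0.0])   # V(s_{t+1}); zero at the terminal step

gamma = 0.99
with torch.no_grad():
    # Advantage estimate: how much better the taken action did than the critic expected
    advantages = rewards + gamma * next_state_values - state_values

# Minimizing -log_prob * advantage makes actions with positive advantage more
# likely and actions with negative advantage less likely.
policy_loss = (-log_probs * advantages).sum()
policy_loss.backward()
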

View file

@@ -55,6 +55,7 @@ class DQNAgent:
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
# If we're sampling by TD error, multiply the loss by an importance weight, which helps decrease overfitting
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
loss = (torch.as_tensor(importance_weights, device = self.net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
else:
@@ -74,6 +75,7 @@ class DQNAgent:
else:
self.target_net.sync()
# If we're sampling by TD error, readjust the priorities of the experiences
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
td_error = (obtained_values - expected_values).detach().abs()
self.memory.update_priorities(batch_indexes, td_error)
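
The two prioritized-replay comments added in this file follow the same pattern; below is a small self-contained sketch with invented batch values, and the memory call shown only as a comment mirroring the diff rather than a claim about the PrioritizedReplayMemory API:

import torch

obtained_values = torch.tensor([[1.2], [0.3], [2.0]], requires_grad=True)  # Q(s, a) from the online net
expected_values = torch.tensor([[1.0], [0.5], [1.5]])                      # r + discount * best next value
importance_weights = torch.tensor([0.7, 1.0, 0.4])                         # from the prioritized sampler

# Weighted MSE: importance weights compensate for drawing some experiences
# more often than others
loss = (importance_weights * ((obtained_values - expected_values) ** 2).squeeze(1)).mean()
loss.backward()

# The absolute TD error becomes the new priority for each sampled experience
td_error = (obtained_values - expected_values).detach().abs()
# self.memory.update_priorities(batch_indexes, td_error)
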

View file

@@ -31,6 +31,7 @@ class PPOAgent:
episode_batch = self.memory.recall()
state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
# Send batches to the appropriate device
state_batch = torch.cat(state_batch).to(self.value_net.device)
action_batch = torch.tensor(action_batch).to(self.value_net.device)
reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
@@ -39,12 +40,16 @@ class PPOAgent:
log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
## Value Loss
# In PPO, the value loss is the difference between the summed discounted rewards and the value of the first state
# The value of the first state is supposed to tell us the expected return of the whole episode under the current policy
value_loss = F.mse_loss(self._discount_rewards(reward_batch).sum(), self.value_net(state_batch[0]))
self.value_net.zero_grad()
value_loss.backward()
self.value_net.step()
## Policy Loss
# Increase the probabilities of actions with positive advantage
# and decrease the probabilities of actions with negative advantage
with torch.no_grad():
state_values = self.value_net(state_batch)
next_state_values = torch.zeros_like(state_values)
@@ -56,6 +61,7 @@ class PPOAgent:
distributions = list(map(Categorical, action_probabilities))
old_log_probs = torch.stack(list(map(lambda distribution, action: distribution.log_prob(action), distributions, action_batch)))
# For PPO we want to stay within a certain ratio of the old policy
policy_ratio = torch.exp(log_prob_batch - old_log_probs) # Equivalent to (prob / old_prob)
policy_loss1 = policy_ratio * advantages
policy_loss2 = policy_ratio.clamp(min = 0.8, max = 1.2) * advantages # From original paper
@@ -72,7 +78,7 @@ class PPOAgent:
self.policy_net.step()
# Memory is irrelevant for future training
# Memory under the old policy is not needed for future training
self.memory.clear()
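
To make the ratio-clipping comments above concrete, a minimal sketch of the clipped surrogate objective with invented values: exp(log pi_new - log pi_old) gives the probability ratio pi_new / pi_old, the ratio is clamped to [0.8, 1.2] (epsilon = 0.2, as in the original PPO paper), and the smaller of the two surrogate terms is kept:

import torch

log_prob_batch = torch.tensor([-0.5, -1.1, -0.3], requires_grad=True)  # log pi_new(a | s), invented
old_log_probs = torch.tensor([-0.6, -1.0, -0.4])                       # log pi_old(a | s)
advantages = torch.tensor([1.0, -0.5, 0.3])

# exp(log pi_new - log pi_old) = pi_new / pi_old
policy_ratio = torch.exp(log_prob_batch - old_log_probs)

policy_loss1 = policy_ratio * advantages
policy_loss2 = policy_ratio.clamp(min=0.8, max=1.2) * advantages  # clipped ratio
# Take the pessimistic surrogate per timestep, negate to turn maximization into a loss
policy_loss = -torch.min(policy_loss1, policy_loss2).sum()
policy_loss.backward()
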

View file

@@ -3,7 +3,8 @@ import torch
from .Network import Network
from copy import deepcopy
# [TODO] See if you need to move network to device
# [TODO] Should we torch.no_grad the __call__?
# What if we want to sometimes do gradient descent as well?
class ESNetwork(Network):
"""
Network based on the Evolution Strategies technique from the paper (https://arxiv.org/abs/1703.03864)