Added PPOAgent and A2CAgent to the agents submodule.
Also made some small changes to how memories are queried
This commit is contained in:
parent
21b820b401
commit
26084d4c7c
8 changed files with 430 additions and 12 deletions
|
@ -10,8 +10,6 @@ class StochasticSelector(ArgMaxSelector):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.action_size = action_size
|
self.action_size = action_size
|
||||||
self.device = device
|
self.device = device
|
||||||
if not isinstance(memory, rltorch.memory.EpisodeMemory):
|
|
||||||
raise ValueError("Memory must be of instance EpisodeMemory")
|
|
||||||
self.memory = memory
|
self.memory = memory
|
||||||
def best_act(self, state, log_prob = True):
|
def best_act(self, state, log_prob = True):
|
||||||
if self.device is not None:
|
if self.device is not None:
|
||||||
|
@ -19,6 +17,6 @@ class StochasticSelector(ArgMaxSelector):
|
||||||
action_probabilities = self.model(state)
|
action_probabilities = self.model(state)
|
||||||
distribution = Categorical(action_probabilities)
|
distribution = Categorical(action_probabilities)
|
||||||
action = distribution.sample()
|
action = distribution.sample()
|
||||||
if log_prob:
|
if log_prob and isinstance(self.memory, rltorch.memory.EpisodeMemory):
|
||||||
self.memory.append_log_probs(distribution.log_prob(action))
|
self.memory.append_log_probs(distribution.log_prob(action))
|
||||||
return action.item()
|
return action.item()
|
73
rltorch/agents/A2CSingleAgent.py
Normal file
73
rltorch/agents/A2CSingleAgent.py
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
# Deprecated since the idea of the idea shouldn't work without having some sort of "mental model" of the environment
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.distributions import Categorical
|
||||||
|
import rltorch
|
||||||
|
import rltorch.memory as M
|
||||||
|
import collections
|
||||||
|
import random
|
||||||
|
|
||||||
|
class A2CSingleAgent:
|
||||||
|
def __init__(self, policy_net, value_net, memory, config, logger = None):
|
||||||
|
self.policy_net = policy_net
|
||||||
|
self.value_net = value_net
|
||||||
|
self.memory = memory
|
||||||
|
self.config = deepcopy(config)
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
def _discount_rewards(self, rewards):
|
||||||
|
discounted_rewards = torch.zeros_like(rewards)
|
||||||
|
running_add = 0
|
||||||
|
for t in reversed(range(len(rewards))):
|
||||||
|
running_add = running_add * self.config['discount_rate'] + rewards[t]
|
||||||
|
discounted_rewards[t] = running_add
|
||||||
|
|
||||||
|
return discounted_rewards
|
||||||
|
|
||||||
|
|
||||||
|
def learn(self):
|
||||||
|
if len(self.memory) < self.config['batch_size']:
|
||||||
|
return
|
||||||
|
|
||||||
|
episode_batch = self.memory.recall()
|
||||||
|
state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
|
||||||
|
|
||||||
|
state_batch = torch.cat(state_batch).to(self.value_net.device)
|
||||||
|
reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
|
||||||
|
not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
|
||||||
|
next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
|
||||||
|
log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
|
||||||
|
|
||||||
|
## Value Loss
|
||||||
|
value_loss = F.mse_loss(self._discount_rewards(reward_batch), self.value_net(state_batch[0]))
|
||||||
|
self.value_net.zero_grad()
|
||||||
|
value_loss.backward()
|
||||||
|
self.value_net.step()
|
||||||
|
|
||||||
|
## Policy Loss
|
||||||
|
with torch.no_grad():
|
||||||
|
state_values = self.value_net(state_batch)
|
||||||
|
next_state_values = torch.zeros_like(state_values)
|
||||||
|
next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
|
||||||
|
advantages = (reward_batch.unsqueeze(1) + self.config['discount_rate'] * next_state_values) - state_values
|
||||||
|
advantages = advantages.squeeze(1)
|
||||||
|
|
||||||
|
policy_loss = (-log_prob_batch * advantages).sum()
|
||||||
|
|
||||||
|
if self.logger is not None:
|
||||||
|
self.logger.append("Loss/Policy", policy_loss.item())
|
||||||
|
self.logger.append("Loss/Value", value_loss.item())
|
||||||
|
|
||||||
|
|
||||||
|
self.policy_net.zero_grad()
|
||||||
|
policy_loss.backward()
|
||||||
|
self.policy_net.step()
|
||||||
|
|
||||||
|
|
||||||
|
# Memory is irrelevant for future training
|
||||||
|
self.memory.clear()
|
||||||
|
|
||||||
|
|
|
@ -34,27 +34,29 @@ class DQNAgent:
|
||||||
next_state_batch = next_state_batch.to(self.net.device)
|
next_state_batch = next_state_batch.to(self.net.device)
|
||||||
not_done_batch = not_done_batch.to(self.net.device)
|
not_done_batch = not_done_batch.to(self.net.device)
|
||||||
|
|
||||||
obtained_values = self.net(state_batch).gather(1, action_batch.view(self.config['batch_size'], 1))
|
state_values = self.net(state_batch)
|
||||||
|
obtained_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1))
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Use the target net to produce action values for the next state
|
# Use the target net to produce action values for the next state
|
||||||
# and the regular net to select the action
|
# and the regular net to select the action
|
||||||
# That way we decouple the value and action selecting processes (DOUBLE DQN)
|
# That way we decouple the value and action selecting processes (DOUBLE DQN)
|
||||||
not_done_size = not_done_batch.sum()
|
not_done_size = not_done_batch.sum()
|
||||||
|
next_state_values = torch.zeros_like(state_values, device = self.net.device)
|
||||||
if self.target_net is not None:
|
if self.target_net is not None:
|
||||||
next_state_values = self.target_net(next_state_batch)
|
next_state_values[not_done_batch] = self.target_net(next_state_batch[not_done_batch])
|
||||||
next_best_action = self.net(next_state_batch).argmax(1)
|
next_best_action = self.net(next_state_batch[not_done_batch]).argmax(1)
|
||||||
else:
|
else:
|
||||||
next_state_values = self.net(next_state_batch)
|
next_state_values[not_done_batch] = self.net(next_state_batch[not_done_batch])
|
||||||
next_best_action = next_state_values.argmax(1)
|
next_best_action = next_state_values[not_done_batch].argmax(1)
|
||||||
|
|
||||||
best_next_state_value = torch.zeros(self.config['batch_size'], device = self.net.device)
|
best_next_state_value = torch.zeros(self.config['batch_size'], device = self.net.device)
|
||||||
best_next_state_value[not_done_batch] = next_state_values.gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
||||||
|
|
||||||
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
||||||
|
|
||||||
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
||||||
loss = (torch.as_tensor(importance_weights, device = self.net.device) * (obtained_values - expected_values)**2).mean()
|
loss = (torch.as_tensor(importance_weights, device = self.net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
|
||||||
else:
|
else:
|
||||||
loss = F.mse_loss(obtained_values, expected_values)
|
loss = F.mse_loss(obtained_values, expected_values)
|
||||||
|
|
||||||
|
|
83
rltorch/agents/PPOAgent.py
Normal file
83
rltorch/agents/PPOAgent.py
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
# Deprecated since the idea of the idea shouldn't work without having some sort of "mental model" of the environment
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.distributions import Categorical
|
||||||
|
import rltorch
|
||||||
|
import rltorch.memory as M
|
||||||
|
import collections
|
||||||
|
import random
|
||||||
|
|
||||||
|
class PPOAgent:
|
||||||
|
def __init__(self, policy_net, value_net, memory, config, logger = None):
|
||||||
|
self.policy_net = policy_net
|
||||||
|
self.old_policy_net = rltorch.network.TargetNetwork(policy_net)
|
||||||
|
self.value_net = value_net
|
||||||
|
self.memory = memory
|
||||||
|
self.config = deepcopy(config)
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
def _discount_rewards(self, rewards):
|
||||||
|
discounted_rewards = torch.zeros_like(rewards)
|
||||||
|
running_add = 0
|
||||||
|
for t in reversed(range(len(rewards))):
|
||||||
|
running_add = running_add * self.config['discount_rate'] + rewards[t]
|
||||||
|
discounted_rewards[t] = running_add
|
||||||
|
|
||||||
|
return discounted_rewards
|
||||||
|
|
||||||
|
|
||||||
|
def learn(self):
|
||||||
|
if len(self.memory) < self.config['batch_size']:
|
||||||
|
return
|
||||||
|
|
||||||
|
episode_batch = self.memory.recall()
|
||||||
|
state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
|
||||||
|
|
||||||
|
state_batch = torch.cat(state_batch).to(self.value_net.device)
|
||||||
|
action_batch = torch.tensor(action_batch).to(self.value_net.device)
|
||||||
|
reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
|
||||||
|
not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
|
||||||
|
next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
|
||||||
|
log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
|
||||||
|
|
||||||
|
## Value Loss
|
||||||
|
value_loss = F.mse_loss(self._discount_rewards(reward_batch), self.value_net(state_batch[0]))
|
||||||
|
self.value_net.zero_grad()
|
||||||
|
value_loss.backward()
|
||||||
|
self.value_net.step()
|
||||||
|
|
||||||
|
## Policy Loss
|
||||||
|
with torch.no_grad():
|
||||||
|
state_values = self.value_net(state_batch)
|
||||||
|
next_state_values = torch.zeros_like(state_values)
|
||||||
|
next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
|
||||||
|
advantages = (reward_batch.unsqueeze(1) + self.config['discount_rate'] * next_state_values) - state_values
|
||||||
|
advantages = advantages.squeeze(1)
|
||||||
|
|
||||||
|
action_probabilities = self.old_policy_net(state_batch).detach()
|
||||||
|
distributions = list(map(Categorical, action_probabilities))
|
||||||
|
old_log_probs = torch.stack(list(map(lambda distribution, action: distribution.log_prob(action), distributions, action_batch)))
|
||||||
|
|
||||||
|
policy_ratio = torch.exp(log_prob_batch - old_log_probs) # Equivalent to (log_prob / old_log_prob)
|
||||||
|
policy_loss1 = policy_ratio * advantages
|
||||||
|
policy_loss2 = policy_ratio.clamp(min = 0.8, max = 1.2) * advantages # From original paper
|
||||||
|
policy_loss = -torch.min(policy_loss1, policy_loss2).sum()
|
||||||
|
|
||||||
|
if self.logger is not None:
|
||||||
|
self.logger.append("Loss/Policy", policy_loss.item())
|
||||||
|
self.logger.append("Loss/Value", value_loss.item())
|
||||||
|
|
||||||
|
|
||||||
|
self.old_policy_net.sync()
|
||||||
|
self.policy_net.zero_grad()
|
||||||
|
policy_loss.backward()
|
||||||
|
self.policy_net.step()
|
||||||
|
|
||||||
|
|
||||||
|
# Memory is irrelevant for future training
|
||||||
|
self.memory.clear()
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ class REINFORCEAgent:
|
||||||
discount_reward_batch = self._discount_rewards(torch.tensor(reward_batch))
|
discount_reward_batch = self._discount_rewards(torch.tensor(reward_batch))
|
||||||
log_prob_batch = torch.cat(log_prob_batch)
|
log_prob_batch = torch.cat(log_prob_batch)
|
||||||
|
|
||||||
policy_loss = (-1 * log_prob_batch * discount_reward_batch).sum()
|
policy_loss = (-log_prob_batch * discount_reward_batch).sum()
|
||||||
|
|
||||||
if self.logger is not None:
|
if self.logger is not None:
|
||||||
self.logger.append("Loss", policy_loss.item())
|
self.logger.append("Loss", policy_loss.item())
|
||||||
|
|
260
rltorch/agents/_A2CSingleAgent.py
Normal file
260
rltorch/agents/_A2CSingleAgent.py
Normal file
|
@ -0,0 +1,260 @@
|
||||||
|
# Deprecated since the idea of the idea shouldn't work without having some sort of "mental model" of the environment
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.distributions import Categorical
|
||||||
|
import rltorch
|
||||||
|
import rltorch.memory as M
|
||||||
|
import collections
|
||||||
|
import random
|
||||||
|
|
||||||
|
class A2CSingleAgent:
|
||||||
|
def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None):
|
||||||
|
self.policy_net = policy_net
|
||||||
|
self.value_net = value_net
|
||||||
|
self.memory = memory
|
||||||
|
self.config = deepcopy(config)
|
||||||
|
self.target_value_net = target_value_net
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
def learn_value(self):
|
||||||
|
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
||||||
|
weight_importance = self.config['prioritized_replay_weight_importance']
|
||||||
|
# If it's a scheduler then get the next value by calling next, otherwise just use it's value
|
||||||
|
beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
|
||||||
|
minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
|
||||||
|
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
|
||||||
|
else:
|
||||||
|
minibatch = self.memory.sample(self.config['batch_size'])
|
||||||
|
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
|
||||||
|
|
||||||
|
# Send to their appropriate devices
|
||||||
|
state_batch = state_batch.to(self.value_net.device)
|
||||||
|
action_batch = action_batch.to(self.value_net.device)
|
||||||
|
reward_batch = reward_batch.to(self.value_net.device)
|
||||||
|
next_state_batch = next_state_batch.to(self.value_net.device)
|
||||||
|
not_done_batch = not_done_batch.to(self.value_net.device)
|
||||||
|
|
||||||
|
|
||||||
|
## Value Loss
|
||||||
|
state_values = self.value_net(state_batch)
|
||||||
|
obtained_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1))
|
||||||
|
with torch.no_grad():
|
||||||
|
# Use the target net to produce action values for the next state
|
||||||
|
# and the regular net to select the action
|
||||||
|
# That way we decouple the value and action selecting processes (DOUBLE DQN)
|
||||||
|
not_done_size = not_done_batch.sum()
|
||||||
|
next_state_values = torch.zeros_like(state_values)
|
||||||
|
if self.target_value_net is not None:
|
||||||
|
next_state_values[not_done_batch] = self.target_value_net(next_state_batch[not_done_batch])
|
||||||
|
next_best_action = self.value_net(next_state_batch).argmax(1)
|
||||||
|
else:
|
||||||
|
next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
|
||||||
|
next_best_action = next_state_values.argmax(1)
|
||||||
|
|
||||||
|
best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
|
||||||
|
# best_next_state_value[not_done_batch] = next_state_values.gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
||||||
|
best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action[not_done_batch].view((not_done_size, 1))).squeeze(1)
|
||||||
|
|
||||||
|
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
||||||
|
|
||||||
|
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
||||||
|
importance_weights = torch.as_tensor(importance_weights, device = self.value_net.device)
|
||||||
|
value_loss = (importance_weights * ((obtained_values - expected_values)**2).squeeze(1)).mean()
|
||||||
|
else:
|
||||||
|
value_loss = F.mse_loss(obtained_values, expected_values)
|
||||||
|
|
||||||
|
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
||||||
|
td_error = (obtained_values - expected_values).detach().abs()
|
||||||
|
self.memory.update_priorities(batch_indexes, td_error)
|
||||||
|
|
||||||
|
self.value_net.zero_grad()
|
||||||
|
value_loss.backward()
|
||||||
|
self.value_net.step()
|
||||||
|
|
||||||
|
if self.target_value_net is not None:
|
||||||
|
if 'target_sync_tau' in self.config:
|
||||||
|
self.target_value_net.partial_sync(self.config['target_sync_tau'])
|
||||||
|
else:
|
||||||
|
self.target_value_net.sync()
|
||||||
|
|
||||||
|
if self.logger is not None:
|
||||||
|
self.logger.append("Loss/Value", value_loss.item())
|
||||||
|
|
||||||
|
|
||||||
|
def learn_policy(self):
|
||||||
|
starting_index = random.randint(0, len(self.memory) - self.config['batch_size'])
|
||||||
|
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(self.memory[starting_index:(starting_index + self.config['batch_size'])])
|
||||||
|
|
||||||
|
state_batch = state_batch.to(self.policy_net.device)
|
||||||
|
action_batch = action_batch.to(self.policy_net.device)
|
||||||
|
reward_batch = reward_batch.to(self.policy_net.device)
|
||||||
|
next_state_batch = next_state_batch.to(self.policy_net.device)
|
||||||
|
not_done_batch = not_done_batch.to(self.policy_net.device)
|
||||||
|
|
||||||
|
# Find when episode ends and filter out the Transitions after
|
||||||
|
episode_ends = (~not_done_batch).nonzero().squeeze(1)
|
||||||
|
start_idx = 0
|
||||||
|
end_idx = self.config['batch_size']
|
||||||
|
if len(episode_ends) > 0:
|
||||||
|
if (episode_ends[0] == 0).item():
|
||||||
|
if len(episode_ends) > 1:
|
||||||
|
start_idx = 1
|
||||||
|
end_idx = episode_ends[1] + 1
|
||||||
|
else:
|
||||||
|
start_idx = 1
|
||||||
|
else:
|
||||||
|
end_idx = episode_ends[0] + 1
|
||||||
|
batch_size = end_idx - start_idx
|
||||||
|
|
||||||
|
# Now filter...
|
||||||
|
state_batch = state_batch[start_idx:end_idx]
|
||||||
|
action_batch = action_batch[start_idx:end_idx]
|
||||||
|
reward_batch = reward_batch[start_idx:end_idx]
|
||||||
|
next_state_batch = next_state_batch[start_idx:end_idx]
|
||||||
|
not_done_batch = not_done_batch[start_idx:end_idx]
|
||||||
|
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
if self.target_value_net is not None:
|
||||||
|
state_values = self.target_value_net(state_batch)
|
||||||
|
next_state_values = torch.zeros_like(state_values, device = self.value_net.device)
|
||||||
|
next_state_values[not_done_batch] = self.target_value_net(next_state_batch[not_done_batch])
|
||||||
|
else:
|
||||||
|
state_values = self.value_net(state_batch)
|
||||||
|
next_state_values = torch.zeros_like(state_values, device = self.value_net.device)
|
||||||
|
next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
|
||||||
|
|
||||||
|
obtained_values = state_values.gather(1, action_batch.view(batch_size, 1))
|
||||||
|
approx_state_action_values = reward_batch.unsqueeze(1) + self.config['discount_rate'] * next_state_values
|
||||||
|
advantage = (obtained_values - approx_state_action_values.mean(1).unsqueeze(1))
|
||||||
|
# Scale and squeeze the dimension
|
||||||
|
advantage = advantage.squeeze(1)
|
||||||
|
# advantage = (advantage / (state_values.std() + np.finfo('float').eps)).squeeze(1)
|
||||||
|
action_probabilities = self.policy_net(state_batch)
|
||||||
|
distributions = list(map(Categorical, action_probabilities))
|
||||||
|
log_probs = torch.stack(list(map(lambda distribution, action: distribution.log_prob(action), distributions, action_batch)))
|
||||||
|
policy_loss = (-log_probs * advantage).mean()
|
||||||
|
|
||||||
|
self.policy_net.zero_grad()
|
||||||
|
policy_loss.backward()
|
||||||
|
self.policy_net.step()
|
||||||
|
|
||||||
|
if self.logger is not None:
|
||||||
|
self.logger.append("Loss/Policy", policy_loss.item())
|
||||||
|
|
||||||
|
def learn(self):
|
||||||
|
if len(self.memory) < self.config['batch_size']:
|
||||||
|
return
|
||||||
|
self.learn_value()
|
||||||
|
self.learn_policy()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# def learn(self):
|
||||||
|
# if len(self.memory) < self.config['batch_size']:
|
||||||
|
# return
|
||||||
|
|
||||||
|
# if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
||||||
|
# weight_importance = self.config['prioritized_replay_weight_importance']
|
||||||
|
# # If it's a scheduler then get the next value by calling next, otherwise just use it's value
|
||||||
|
# beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
|
||||||
|
# minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
|
||||||
|
# state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
|
||||||
|
# else:
|
||||||
|
# minibatch = self.memory.sample(self.config['batch_size'])
|
||||||
|
# state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
|
||||||
|
|
||||||
|
# # Send to their appropriate devices
|
||||||
|
# # [TODO] Notice how we're sending it to the value_net's device, what if policy_net was on a different device?
|
||||||
|
# state_batch = state_batch.to(self.value_net.device)
|
||||||
|
# action_batch = action_batch.to(self.value_net.device)
|
||||||
|
# reward_batch = reward_batch.to(self.value_net.device)
|
||||||
|
# next_state_batch = next_state_batch.to(self.value_net.device)
|
||||||
|
# not_done_batch = not_done_batch.to(self.value_net.device)
|
||||||
|
|
||||||
|
|
||||||
|
# ## Value Loss
|
||||||
|
|
||||||
|
# obtained_values = self.value_net(state_batch).gather(1, action_batch.view(self.config['batch_size'], 1))
|
||||||
|
|
||||||
|
# with torch.no_grad():
|
||||||
|
# # Use the target net to produce action values for the next state
|
||||||
|
# # and the regular net to select the action
|
||||||
|
# # That way we decouple the value and action selecting processes (DOUBLE DQN)
|
||||||
|
# not_done_size = not_done_batch.sum()
|
||||||
|
# if self.target_value_net is not None:
|
||||||
|
# next_state_values = self.target_value_net(next_state_batch)
|
||||||
|
# next_best_action = self.value_net(next_state_batch).argmax(1)
|
||||||
|
# else:
|
||||||
|
# next_state_values = self.value_net(next_state_batch)
|
||||||
|
# next_best_action = next_state_values.argmax(1)
|
||||||
|
|
||||||
|
# best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
|
||||||
|
# best_next_state_value[not_done_batch] = next_state_values.gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
||||||
|
|
||||||
|
# expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
||||||
|
|
||||||
|
# if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
||||||
|
# importance_weights = torch.as_tensor(importance_weights, device = self.value_net.device)
|
||||||
|
# value_loss = (importance_weights * ((obtained_values - expected_values)**2).squeeze(1)).mean()
|
||||||
|
# else:
|
||||||
|
# value_loss = F.mse_loss(obtained_values, expected_values)
|
||||||
|
|
||||||
|
# self.value_net.zero_grad()
|
||||||
|
# value_loss.backward()
|
||||||
|
# self.value_net.step()
|
||||||
|
|
||||||
|
# if self.target_value_net is not None:
|
||||||
|
# if 'target_sync_tau' in self.config:
|
||||||
|
# self.target_value_net.partial_sync(self.config['target_sync_tau'])
|
||||||
|
# else:
|
||||||
|
# self.target_value_net.sync()
|
||||||
|
|
||||||
|
# if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
||||||
|
# td_error = (obtained_values - expected_values).detach().abs()
|
||||||
|
# self.memory.update_priorities(batch_indexes, td_error)
|
||||||
|
|
||||||
|
# if self.logger is not None:
|
||||||
|
# self.logger.append("ValueLoss", value_loss.item())
|
||||||
|
|
||||||
|
# ## Policy Loss
|
||||||
|
# with torch.no_grad():
|
||||||
|
# state_values = self.value_net(state_batch)
|
||||||
|
# if self.target_value_net is not None:
|
||||||
|
# next_state_values = self.target_value_net(next_state_batch)
|
||||||
|
# else:
|
||||||
|
# next_state_values = self.value_net(next_state_batch)
|
||||||
|
|
||||||
|
# state_action_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1))
|
||||||
|
# average_next_state_values = torch.zeros(self.config['batch_size'], device = self.value_net.device)
|
||||||
|
# average_next_state_values[not_done_batch] = next_state_values.mean(1)
|
||||||
|
|
||||||
|
# advantage = (state_action_values - (reward_batch + self.config['discount_rate'] * average_next_state_values).unsqueeze(1))
|
||||||
|
# # Scale and squeeze the dimension
|
||||||
|
# advantage = advantage.squeeze(1)
|
||||||
|
# # advantage = (advantage / (state_values.std() + np.finfo('float').eps)).squeeze(1)
|
||||||
|
# action_probabilities = self.policy_net(state_batch)
|
||||||
|
# distributions = list(map(Categorical, action_probabilities))
|
||||||
|
# log_probs = torch.stack(list(map(lambda distribution, action: distribution.log_prob(action), distributions, action_batch)))
|
||||||
|
# if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
||||||
|
# policy_loss = (importance_weights * -log_probs * advantage).sum()
|
||||||
|
# else:
|
||||||
|
# policy_loss = (-log_probs * advantage).sum()
|
||||||
|
|
||||||
|
# self.policy_net.zero_grad()
|
||||||
|
# policy_loss.backward()
|
||||||
|
# self.policy_net.step()
|
||||||
|
|
||||||
|
# if self.logger is not None:
|
||||||
|
# self.logger.append("PolicyLoss", policy_loss.item())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,2 +1,4 @@
|
||||||
|
from .A2CSingleAgent import *
|
||||||
from .DQNAgent import *
|
from .DQNAgent import *
|
||||||
|
from .PPOAgent import *
|
||||||
from .REINFORCEAgent import *
|
from .REINFORCEAgent import *
|
|
@ -67,7 +67,7 @@ def zip_batch(minibatch, priority = False):
|
||||||
action_batch = torch.tensor(action_batch)
|
action_batch = torch.tensor(action_batch)
|
||||||
reward_batch = torch.tensor(reward_batch)
|
reward_batch = torch.tensor(reward_batch)
|
||||||
not_done_batch = ~torch.tensor(done_batch)
|
not_done_batch = ~torch.tensor(done_batch)
|
||||||
next_state_batch = torch.cat(next_state_batch)[not_done_batch]
|
next_state_batch = torch.cat(next_state_batch)
|
||||||
|
|
||||||
if priority:
|
if priority:
|
||||||
return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, weights, indexes
|
return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, weights, indexes
|
||||||
|
|
Loading…
Add table
Reference in a new issue