From 26084d4c7c4e191e6bee186e173ddee6139aed6d Mon Sep 17 00:00:00 2001
From: Brandon Rozek
Date: Tue, 19 Feb 2019 20:54:30 -0500
Subject: [PATCH] Added PPOAgent and A2CAgent to the agents submodule. Also
 made some small changes to how memories are queried

---
 rltorch/action_selector/StochasticSelector.py |   4 +-
 rltorch/agents/A2CSingleAgent.py              |  73 +++++
 rltorch/agents/DQNAgent.py                    |  16 +-
 rltorch/agents/PPOAgent.py                    |  83 ++++++
 rltorch/agents/REINFORCEAgent.py              |   2 +-
 rltorch/agents/_A2CSingleAgent.py             | 260 ++++++++++++++++++
 rltorch/agents/__init__.py                    |   2 +
 rltorch/memory/ReplayMemory.py                |   2 +-
 8 files changed, 430 insertions(+), 12 deletions(-)
 create mode 100644 rltorch/agents/A2CSingleAgent.py
 create mode 100644 rltorch/agents/PPOAgent.py
 create mode 100644 rltorch/agents/_A2CSingleAgent.py

diff --git a/rltorch/action_selector/StochasticSelector.py b/rltorch/action_selector/StochasticSelector.py
index e9b7019..94d524f 100644
--- a/rltorch/action_selector/StochasticSelector.py
+++ b/rltorch/action_selector/StochasticSelector.py
@@ -10,8 +10,6 @@ class StochasticSelector(ArgMaxSelector):
         self.model = model
         self.action_size = action_size
         self.device = device
-        if not isinstance(memory, rltorch.memory.EpisodeMemory):
-            raise ValueError("Memory must be of instance EpisodeMemory")
         self.memory = memory
     def best_act(self, state, log_prob = True):
         if self.device is not None:
@@ -19,6 +17,6 @@ class StochasticSelector(ArgMaxSelector):
         action_probabilities = self.model(state)
         distribution = Categorical(action_probabilities)
         action = distribution.sample()
-        if log_prob:
+        if log_prob and isinstance(self.memory, rltorch.memory.EpisodeMemory):
            self.memory.append_log_probs(distribution.log_prob(action))
         return action.item()
\ No newline at end of file
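The StochasticSelector change above drops the hard requirement on EpisodeMemory and only records log-probabilities when such a memory is attached. For reference, the underlying pattern is plain torch.distributions usage: sample an action from the policy's output probabilities and keep its log-probability for a later policy-gradient update. A minimal standalone sketch (toy tensors, not the rltorch API):

import torch
from torch.distributions import Categorical

# Toy action probabilities for a single state, e.g. the output of a softmax policy head
probs = torch.tensor([[0.1, 0.7, 0.2]])

distribution = Categorical(probs)
action = distribution.sample()            # index of the sampled action
log_prob = distribution.log_prob(action)  # stored for the policy-gradient loss later

print(action.item(), log_prob.item())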
diff --git a/rltorch/agents/A2CSingleAgent.py b/rltorch/agents/A2CSingleAgent.py
new file mode 100644
index 0000000..8cb8047
--- /dev/null
+++ b/rltorch/agents/A2CSingleAgent.py
@@ -0,0 +1,73 @@
+# Deprecated since the idea shouldn't work without having some sort of "mental model" of the environment
+
+from copy import deepcopy
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.distributions import Categorical
+import rltorch
+import rltorch.memory as M
+import collections
+import random
+
+class A2CSingleAgent:
+    def __init__(self, policy_net, value_net, memory, config, logger = None):
+        self.policy_net = policy_net
+        self.value_net = value_net
+        self.memory = memory
+        self.config = deepcopy(config)
+        self.logger = logger
+
+    def _discount_rewards(self, rewards):
+        discounted_rewards = torch.zeros_like(rewards)
+        running_add = 0
+        for t in reversed(range(len(rewards))):
+            running_add = running_add * self.config['discount_rate'] + rewards[t]
+            discounted_rewards[t] = running_add
+
+        return discounted_rewards
+
+    def learn(self):
+        if len(self.memory) < self.config['batch_size']:
+            return
+
+        episode_batch = self.memory.recall()
+        state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
+
+        state_batch = torch.cat(state_batch).to(self.value_net.device)
+        reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
+        not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
+        next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
+        log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
+
+        ## Value Loss
+        value_loss = F.mse_loss(self._discount_rewards(reward_batch), self.value_net(state_batch[0]))
+        self.value_net.zero_grad()
+        value_loss.backward()
+        self.value_net.step()
+
+        ## Policy Loss
+        with torch.no_grad():
+            state_values = self.value_net(state_batch)
+            next_state_values = torch.zeros_like(state_values)
+            next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
+            advantages = (reward_batch.unsqueeze(1) + self.config['discount_rate'] * next_state_values) - state_values
+            advantages = advantages.squeeze(1)
+
+        policy_loss = (-log_prob_batch * advantages).sum()
+
+        if self.logger is not None:
+            self.logger.append("Loss/Policy", policy_loss.item())
+            self.logger.append("Loss/Value", value_loss.item())
+
+        self.policy_net.zero_grad()
+        policy_loss.backward()
+        self.policy_net.step()
+
+        # Memory is irrelevant for future training
+        self.memory.clear()
diff --git a/rltorch/agents/DQNAgent.py b/rltorch/agents/DQNAgent.py
index 8e665ab..2e3d407 100644
--- a/rltorch/agents/DQNAgent.py
+++ b/rltorch/agents/DQNAgent.py
@@ -34,27 +34,29 @@ class DQNAgent:
         next_state_batch = next_state_batch.to(self.net.device)
         not_done_batch = not_done_batch.to(self.net.device)
 
-        obtained_values = self.net(state_batch).gather(1, action_batch.view(self.config['batch_size'], 1))
+        state_values = self.net(state_batch)
+        obtained_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1))
 
         with torch.no_grad():
             # Use the target net to produce action values for the next state
             # and the regular net to select the action
             # That way we decouple the value and action selecting processes (DOUBLE DQN)
             not_done_size = not_done_batch.sum()
+            next_state_values = torch.zeros_like(state_values, device = self.net.device)
             if self.target_net is not None:
-                next_state_values = self.target_net(next_state_batch)
-                next_best_action = self.net(next_state_batch).argmax(1)
+                next_state_values[not_done_batch] = self.target_net(next_state_batch[not_done_batch])
+                next_best_action = self.net(next_state_batch[not_done_batch]).argmax(1)
             else:
-                next_state_values = self.net(next_state_batch)
-                next_best_action = next_state_values.argmax(1)
+                next_state_values[not_done_batch] = self.net(next_state_batch[not_done_batch])
+                next_best_action = next_state_values[not_done_batch].argmax(1)
 
             best_next_state_value = torch.zeros(self.config['batch_size'], device = self.net.device)
-            best_next_state_value[not_done_batch] = next_state_values.gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
+            best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
 
         expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
 
         if (isinstance(self.memory, M.PrioritizedReplayMemory)):
-            loss = (torch.as_tensor(importance_weights, device = self.net.device) * (obtained_values - expected_values)**2).mean()
+            loss = (torch.as_tensor(importance_weights, device = self.net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
         else:
             loss = F.mse_loss(obtained_values, expected_values)
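Both new agents discount an episode's rewards with the standard return recursion G_t = r_t + gamma * G_{t+1}, walking the episode backwards exactly as `_discount_rewards` above does. A self-contained sketch of that computation (gamma = 0.99 is an arbitrary example value, not a config key taken from this patch):

import torch

def discount_rewards(rewards, gamma = 0.99):
    # G_t = r_t + gamma * G_{t+1}, computed by iterating backwards over the episode
    discounted = torch.zeros_like(rewards)
    running_add = 0.0
    for t in reversed(range(len(rewards))):
        running_add = rewards[t] + gamma * running_add
        discounted[t] = running_add
    return discounted

# Example: three steps of reward 1 -> returns [1 + 0.99 + 0.99**2, 1 + 0.99, 1]
print(discount_rewards(torch.tensor([1.0, 1.0, 1.0])))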
diff --git a/rltorch/agents/PPOAgent.py b/rltorch/agents/PPOAgent.py
new file mode 100644
index 0000000..fed5840
--- /dev/null
+++ b/rltorch/agents/PPOAgent.py
@@ -0,0 +1,83 @@
+# Deprecated since the idea shouldn't work without having some sort of "mental model" of the environment
+
+from copy import deepcopy
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.distributions import Categorical
+import rltorch
+import rltorch.memory as M
+import collections
+import random
+
+class PPOAgent:
+    def __init__(self, policy_net, value_net, memory, config, logger = None):
+        self.policy_net = policy_net
+        self.old_policy_net = rltorch.network.TargetNetwork(policy_net)
+        self.value_net = value_net
+        self.memory = memory
+        self.config = deepcopy(config)
+        self.logger = logger
+
+    def _discount_rewards(self, rewards):
+        discounted_rewards = torch.zeros_like(rewards)
+        running_add = 0
+        for t in reversed(range(len(rewards))):
+            running_add = running_add * self.config['discount_rate'] + rewards[t]
+            discounted_rewards[t] = running_add
+
+        return discounted_rewards
+
+    def learn(self):
+        if len(self.memory) < self.config['batch_size']:
+            return
+
+        episode_batch = self.memory.recall()
+        state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
+
+        state_batch = torch.cat(state_batch).to(self.value_net.device)
+        action_batch = torch.tensor(action_batch).to(self.value_net.device)
+        reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
+        not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
+        next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
+        log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
+
+        ## Value Loss
+        value_loss = F.mse_loss(self._discount_rewards(reward_batch), self.value_net(state_batch[0]))
+        self.value_net.zero_grad()
+        value_loss.backward()
+        self.value_net.step()
+
+        ## Policy Loss
+        with torch.no_grad():
+            state_values = self.value_net(state_batch)
+            next_state_values = torch.zeros_like(state_values)
+            next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
+            advantages = (reward_batch.unsqueeze(1) + self.config['discount_rate'] * next_state_values) - state_values
+            advantages = advantages.squeeze(1)
+
+        action_probabilities = self.old_policy_net(state_batch).detach()
+        distributions = list(map(Categorical, action_probabilities))
+        old_log_probs = torch.stack(list(map(lambda distribution, action: distribution.log_prob(action), distributions, action_batch)))
+
+        policy_ratio = torch.exp(log_prob_batch - old_log_probs) # Equivalent to (log_prob / old_log_prob)
+        policy_loss1 = policy_ratio * advantages
+        policy_loss2 = policy_ratio.clamp(min = 0.8, max = 1.2) * advantages # From original paper
+        policy_loss = -torch.min(policy_loss1, policy_loss2).sum()
+
+        if self.logger is not None:
+            self.logger.append("Loss/Policy", policy_loss.item())
+            self.logger.append("Loss/Value", value_loss.item())
+
+        self.old_policy_net.sync()
+        self.policy_net.zero_grad()
+        policy_loss.backward()
+        self.policy_net.step()
+
+        # Memory is irrelevant for future training
+        self.memory.clear()
diff --git a/rltorch/agents/REINFORCEAgent.py b/rltorch/agents/REINFORCEAgent.py
index a9d034d..5948fef 100644
--- a/rltorch/agents/REINFORCEAgent.py
+++ b/rltorch/agents/REINFORCEAgent.py
@@ -31,7 +31,7 @@ class REINFORCEAgent:
         discount_reward_batch = self._discount_rewards(torch.tensor(reward_batch))
         log_prob_batch = torch.cat(log_prob_batch)
 
-        policy_loss = (-1 * log_prob_batch * discount_reward_batch).sum()
+        policy_loss = (-log_prob_batch * discount_reward_batch).sum()
 
         if self.logger is not None:
             self.logger.append("Loss", policy_loss.item())
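PPOAgent's `policy_loss2` hard-codes clamp(min = 0.8, max = 1.2), which corresponds to the clip parameter epsilon = 0.2 from the PPO paper. A standalone sketch of the clipped surrogate objective on toy tensors (names and values here are illustrative and not part of the patch):

import torch

def ppo_clipped_loss(new_log_probs, old_log_probs, advantages, epsilon = 0.2):
    # r_t = pi_new(a|s) / pi_old(a|s), computed in log-space for numerical stability
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = ratio.clamp(1.0 - epsilon, 1.0 + epsilon) * advantages
    # Take the pessimistic (elementwise minimum) objective and negate it to form a loss
    return -torch.min(unclipped, clipped).sum()

# Toy example: two transitions, one with positive and one with negative advantage
new_lp = torch.tensor([-0.1, -2.0], requires_grad = True)
old_lp = torch.tensor([-0.3, -1.5])
adv = torch.tensor([1.0, -0.5])
loss = ppo_clipped_loss(new_lp, old_lp, adv)
loss.backward()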
of "mental model" of the environment + +from copy import deepcopy +import numpy as np +import torch +import torch.nn.functional as F +from torch.distributions import Categorical +import rltorch +import rltorch.memory as M +import collections +import random + +class A2CSingleAgent: + def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None): + self.policy_net = policy_net + self.value_net = value_net + self.memory = memory + self.config = deepcopy(config) + self.target_value_net = target_value_net + self.logger = logger + + def learn_value(self): + if (isinstance(self.memory, M.PrioritizedReplayMemory)): + weight_importance = self.config['prioritized_replay_weight_importance'] + # If it's a scheduler then get the next value by calling next, otherwise just use it's value + beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance + minibatch = self.memory.sample(self.config['batch_size'], beta = beta) + state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True) + else: + minibatch = self.memory.sample(self.config['batch_size']) + state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch) + + # Send to their appropriate devices + state_batch = state_batch.to(self.value_net.device) + action_batch = action_batch.to(self.value_net.device) + reward_batch = reward_batch.to(self.value_net.device) + next_state_batch = next_state_batch.to(self.value_net.device) + not_done_batch = not_done_batch.to(self.value_net.device) + + + ## Value Loss + state_values = self.value_net(state_batch) + obtained_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1)) + with torch.no_grad(): + # Use the target net to produce action values for the next state + # and the regular net to select the action + # That way we decouple the value and action selecting processes (DOUBLE DQN) + not_done_size = not_done_batch.sum() + next_state_values = torch.zeros_like(state_values) + if self.target_value_net is not None: + next_state_values[not_done_batch] = self.target_value_net(next_state_batch[not_done_batch]) + next_best_action = self.value_net(next_state_batch).argmax(1) + else: + next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch]) + next_best_action = next_state_values.argmax(1) + + best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device) + # best_next_state_value[not_done_batch] = next_state_values.gather(1, next_best_action.view((not_done_size, 1))).squeeze(1) + best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action[not_done_batch].view((not_done_size, 1))).squeeze(1) + + expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1) + + if (isinstance(self.memory, M.PrioritizedReplayMemory)): + importance_weights = torch.as_tensor(importance_weights, device = self.value_net.device) + value_loss = (importance_weights * ((obtained_values - expected_values)**2).squeeze(1)).mean() + else: + value_loss = F.mse_loss(obtained_values, expected_values) + + if (isinstance(self.memory, M.PrioritizedReplayMemory)): + td_error = (obtained_values - expected_values).detach().abs() + self.memory.update_priorities(batch_indexes, td_error) + + self.value_net.zero_grad() + value_loss.backward() + self.value_net.step() + + if 
+        if self.target_value_net is not None:
+            if 'target_sync_tau' in self.config:
+                self.target_value_net.partial_sync(self.config['target_sync_tau'])
+            else:
+                self.target_value_net.sync()
+
+        if self.logger is not None:
+            self.logger.append("Loss/Value", value_loss.item())
+
+    def learn_policy(self):
+        starting_index = random.randint(0, len(self.memory) - self.config['batch_size'])
+        state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(self.memory[starting_index:(starting_index + self.config['batch_size'])])
+
+        state_batch = state_batch.to(self.policy_net.device)
+        action_batch = action_batch.to(self.policy_net.device)
+        reward_batch = reward_batch.to(self.policy_net.device)
+        next_state_batch = next_state_batch.to(self.policy_net.device)
+        not_done_batch = not_done_batch.to(self.policy_net.device)
+
+        # Find when episode ends and filter out the Transitions after
+        episode_ends = (~not_done_batch).nonzero().squeeze(1)
+        start_idx = 0
+        end_idx = self.config['batch_size']
+        if len(episode_ends) > 0:
+            if (episode_ends[0] == 0).item():
+                if len(episode_ends) > 1:
+                    start_idx = 1
+                    end_idx = episode_ends[1] + 1
+                else:
+                    start_idx = 1
+            else:
+                end_idx = episode_ends[0] + 1
+        batch_size = end_idx - start_idx
+
+        # Now filter...
+        state_batch = state_batch[start_idx:end_idx]
+        action_batch = action_batch[start_idx:end_idx]
+        reward_batch = reward_batch[start_idx:end_idx]
+        next_state_batch = next_state_batch[start_idx:end_idx]
+        not_done_batch = not_done_batch[start_idx:end_idx]
+
+        with torch.no_grad():
+            if self.target_value_net is not None:
+                state_values = self.target_value_net(state_batch)
+                next_state_values = torch.zeros_like(state_values, device = self.value_net.device)
+                next_state_values[not_done_batch] = self.target_value_net(next_state_batch[not_done_batch])
+            else:
+                state_values = self.value_net(state_batch)
+                next_state_values = torch.zeros_like(state_values, device = self.value_net.device)
+                next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
+
+            obtained_values = state_values.gather(1, action_batch.view(batch_size, 1))
+            approx_state_action_values = reward_batch.unsqueeze(1) + self.config['discount_rate'] * next_state_values
+            advantage = (obtained_values - approx_state_action_values.mean(1).unsqueeze(1))
+            # Scale and squeeze the dimension
+            advantage = advantage.squeeze(1)
+            # advantage = (advantage / (state_values.std() + np.finfo('float').eps)).squeeze(1)
+
+        action_probabilities = self.policy_net(state_batch)
+        distributions = list(map(Categorical, action_probabilities))
+        log_probs = torch.stack(list(map(lambda distribution, action: distribution.log_prob(action), distributions, action_batch)))
+        policy_loss = (-log_probs * advantage).mean()
+
+        self.policy_net.zero_grad()
+        policy_loss.backward()
+        self.policy_net.step()
+
+        if self.logger is not None:
+            self.logger.append("Loss/Policy", policy_loss.item())
+
+    def learn(self):
+        if len(self.memory) < self.config['batch_size']:
+            return
+        self.learn_value()
+        self.learn_policy()
+
+
+    # def learn(self):
+    #     if len(self.memory) < self.config['batch_size']:
+    #         return
+
+    #     if (isinstance(self.memory, M.PrioritizedReplayMemory)):
+    #         weight_importance = self.config['prioritized_replay_weight_importance']
+    #         # If it's a scheduler then get the next value by calling next, otherwise just use its value
+    #         beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
+    #         minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
+    #         state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
+    #     else:
+    #         minibatch = self.memory.sample(self.config['batch_size'])
+    #         state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
+
+    #     # Send to their appropriate devices
+    #     # [TODO] Notice how we're sending it to the value_net's device, what if policy_net was on a different device?
+    #     state_batch = state_batch.to(self.value_net.device)
+    #     action_batch = action_batch.to(self.value_net.device)
+    #     reward_batch = reward_batch.to(self.value_net.device)
+    #     next_state_batch = next_state_batch.to(self.value_net.device)
+    #     not_done_batch = not_done_batch.to(self.value_net.device)
+
+    #     ## Value Loss
+
+    #     obtained_values = self.value_net(state_batch).gather(1, action_batch.view(self.config['batch_size'], 1))
+
+    #     with torch.no_grad():
+    #         # Use the target net to produce action values for the next state
+    #         # and the regular net to select the action
+    #         # That way we decouple the value and action selecting processes (DOUBLE DQN)
+    #         not_done_size = not_done_batch.sum()
+    #         if self.target_value_net is not None:
+    #             next_state_values = self.target_value_net(next_state_batch)
+    #             next_best_action = self.value_net(next_state_batch).argmax(1)
+    #         else:
+    #             next_state_values = self.value_net(next_state_batch)
+    #             next_best_action = next_state_values.argmax(1)
+
+    #         best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
+    #         best_next_state_value[not_done_batch] = next_state_values.gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
+
+    #     expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
+
+    #     if (isinstance(self.memory, M.PrioritizedReplayMemory)):
+    #         importance_weights = torch.as_tensor(importance_weights, device = self.value_net.device)
+    #         value_loss = (importance_weights * ((obtained_values - expected_values)**2).squeeze(1)).mean()
+    #     else:
+    #         value_loss = F.mse_loss(obtained_values, expected_values)
+
+    #     self.value_net.zero_grad()
+    #     value_loss.backward()
+    #     self.value_net.step()
+
+    #     if self.target_value_net is not None:
+    #         if 'target_sync_tau' in self.config:
+    #             self.target_value_net.partial_sync(self.config['target_sync_tau'])
+    #         else:
+    #             self.target_value_net.sync()
+
+    #     if (isinstance(self.memory, M.PrioritizedReplayMemory)):
+    #         td_error = (obtained_values - expected_values).detach().abs()
+    #         self.memory.update_priorities(batch_indexes, td_error)
+
+    #     if self.logger is not None:
+    #         self.logger.append("ValueLoss", value_loss.item())
+
+    #     ## Policy Loss
+    #     with torch.no_grad():
+    #         state_values = self.value_net(state_batch)
+    #         if self.target_value_net is not None:
+    #             next_state_values = self.target_value_net(next_state_batch)
+    #         else:
+    #             next_state_values = self.value_net(next_state_batch)
+
+    #         state_action_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1))
+    #         average_next_state_values = torch.zeros(self.config['batch_size'], device = self.value_net.device)
+    #         average_next_state_values[not_done_batch] = next_state_values.mean(1)
+
+    #         advantage = (state_action_values - (reward_batch + self.config['discount_rate'] * average_next_state_values).unsqueeze(1))
+    #         # Scale and squeeze the dimension
+    #         advantage = advantage.squeeze(1)
+    #         # advantage = (advantage / (state_values.std() + np.finfo('float').eps)).squeeze(1)
+
+    #     action_probabilities = self.policy_net(state_batch)
+    #     distributions = list(map(Categorical, action_probabilities))
+    #     log_probs = torch.stack(list(map(lambda distribution, action: distribution.log_prob(action), distributions, action_batch)))
+    #     if (isinstance(self.memory, M.PrioritizedReplayMemory)):
+    #         policy_loss = (importance_weights * -log_probs * advantage).sum()
+    #     else:
+    #         policy_loss = (-log_probs * advantage).sum()
+
+    #     self.policy_net.zero_grad()
+    #     policy_loss.backward()
+    #     self.policy_net.step()
+
+    #     if self.logger is not None:
+    #         self.logger.append("PolicyLoss", policy_loss.item())
+
+
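In `learn_policy` above, the advantage of the chosen action is its value estimate minus a baseline built from the one-step bootstrapped action values, roughly A(s, a) = Q(s, a) - mean over a' of (r + gamma * Q(s', a')). A toy illustration of that tensor arithmetic with made-up value estimates (the shapes mirror the gather/mean pattern used above; none of these tensors come from rltorch):

import torch

batch_size, num_actions = 4, 3
discount_rate = 0.99

# Made-up per-action value estimates for a batch of states and next states
state_values = torch.randn(batch_size, num_actions)
next_state_values = torch.randn(batch_size, num_actions)
rewards = torch.randn(batch_size)
actions = torch.randint(num_actions, (batch_size,))

# Q(s, a) for the actions actually taken
obtained_values = state_values.gather(1, actions.view(batch_size, 1))

# Baseline: mean over actions of the one-step bootstrapped estimate r + gamma * Q(s', a')
approx_state_action_values = rewards.unsqueeze(1) + discount_rate * next_state_values
advantage = (obtained_values - approx_state_action_values.mean(1).unsqueeze(1)).squeeze(1)

print(advantage.shape)   # torch.Size([4])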
diff --git a/rltorch/agents/__init__.py b/rltorch/agents/__init__.py
index 6be341c..df93574 100644
--- a/rltorch/agents/__init__.py
+++ b/rltorch/agents/__init__.py
@@ -1,2 +1,4 @@
+from .A2CSingleAgent import *
 from .DQNAgent import *
+from .PPOAgent import *
 from .REINFORCEAgent import *
\ No newline at end of file
diff --git a/rltorch/memory/ReplayMemory.py b/rltorch/memory/ReplayMemory.py
index 367b9c9..89e6cd8 100644
--- a/rltorch/memory/ReplayMemory.py
+++ b/rltorch/memory/ReplayMemory.py
@@ -67,7 +67,7 @@ def zip_batch(minibatch, priority = False):
     action_batch = torch.tensor(action_batch)
     reward_batch = torch.tensor(reward_batch)
     not_done_batch = ~torch.tensor(done_batch)
-    next_state_batch = torch.cat(next_state_batch)[not_done_batch]
+    next_state_batch = torch.cat(next_state_batch)
 
     if priority:
         return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, weights, indexes
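The final hunk stops `zip_batch` from pre-filtering `next_state_batch` with `not_done_batch`; callers such as DQNAgent now apply the mask themselves when they bootstrap, so terminal transitions contribute only their immediate reward. A standalone sketch of that masking pattern, with stand-in tensors in place of the network outputs (a sketch of the idea, not rltorch code):

import torch

batch_size, num_actions = 5, 2
discount_rate = 0.99

rewards = torch.ones(batch_size)
not_done = torch.tensor([True, True, False, True, False])
next_q = torch.randn(batch_size, num_actions)   # stand-in for net(next_state_batch)

# Terminal transitions contribute no bootstrapped value; mask them out with not_done
best_next_state_value = torch.zeros(batch_size)
best_next_state_value[not_done] = next_q[not_done].max(1)[0]

expected_values = rewards + discount_rate * best_next_state_value
print(expected_values)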