From 17391c7467df210ef41a6f406c28e34f6912b211 Mon Sep 17 00:00:00 2001 From: Brandon Rozek Date: Thu, 31 Oct 2019 20:54:52 -0400 Subject: [PATCH] First draft of Deep Q Learning From Demonstrations --- rltorch/agents/DQfDAgent.py | 114 ++++++++++++++++++++++++++++++++++++ rltorch/agents/__init__.py | 1 + 2 files changed, 115 insertions(+) create mode 100644 rltorch/agents/DQfDAgent.py diff --git a/rltorch/agents/DQfDAgent.py b/rltorch/agents/DQfDAgent.py new file mode 100644 index 0000000..45a84fc --- /dev/null +++ b/rltorch/agents/DQfDAgent.py @@ -0,0 +1,114 @@ +import collections +import rltorch.memory as M +import torch +import torch.nn.functional as F +from copy import deepcopy +import numpy as np +from pathlib import Path +from rltorch.action_selector import ArgMaxSelector + +class DQfDAgent: + def __init__(self, net, imitation_net, memory, config, target_net = None, logger = None): + self.net = net + self.imitation_net = imitation_net + self.imitator = ArgMaxSelector(imitation_net, self.imitation_net.model.action_size, device = imitation_net.device) + self.target_net = target_net + self.memory = memory + self.config = deepcopy(config) + self.logger = logger + def save(self, file_location): + torch.save(self.net.model.state_dict(), file_location) + def load(self, file_location): + self.net.model.state_dict(torch.load(file_location)) + self.net.model.to(self.net.device) + self.target_net.sync() + + def learn(self, logger = None): + if len(self.memory) < self.config['batch_size']: + return + + if (isinstance(self.memory, M.PrioritizedReplayMemory)): + weight_importance = self.config['prioritized_replay_weight_importance'] + # If it's a scheduler then get the next value by calling next, otherwise just use it's value + beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance + minibatch = self.memory.sample(self.config['batch_size'], beta = beta) + state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True) + else: + minibatch = self.memory.sample(self.config['batch_size']) + state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch) + + # Send to their appropriate devices + state_batch = state_batch.to(self.net.device) + action_batch = action_batch.to(self.net.device) + reward_batch = reward_batch.to(self.net.device).float() + next_state_batch = next_state_batch.to(self.net.device) + not_done_batch = not_done_batch.to(self.net.device) + + state_values = self.net(state_batch) + obtained_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1)) + + with torch.no_grad(): + # Use the target net to produce action values for the next state + # and the regular net to select the action + # That way we decouple the value and action selecting processes (DOUBLE DQN) + not_done_size = not_done_batch.sum() + next_state_values = torch.zeros_like(state_values, device = self.net.device) + if self.target_net is not None: + next_state_values[not_done_batch] = self.target_net(next_state_batch[not_done_batch]) + next_best_action = self.net(next_state_batch[not_done_batch]).argmax(1) + else: + next_state_values[not_done_batch] = self.net(next_state_batch[not_done_batch]) + next_best_action = next_state_values[not_done_batch].argmax(1) + + best_next_state_value = torch.zeros(self.config['batch_size'], device = self.net.device) + best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1) + + expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1) + + # If we're sampling by TD error, multiply loss by a importance weight which helps decrease overfitting + if (isinstance(self.memory, M.PrioritizedReplayMemory)): + # dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.smooth_l1_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean() + dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean() + else: + # dqn_loss = F.smooth_l1_loss(obtained_values, expected_values) + dqn_loss = F.mse_loss(obtained_values, expected_values) + + # Demonstration loss + l = torch.ones_like(state_values) + expert_actions = self.imitation_net(state_batch).argmax(1) + # l(s, a) is zero for every action the expert doesn't take + for i,a in zip(range(len(state_values)), expert_actions): + l[i].fill_(0.8) # According to paper + l[i, a] = 0 + if self.target_net is not None: + expert_value = self.target_net(state_batch) + else: + expert_value = self.net(state_batch) + expert_value = expert_value.gather(1, expert_actions.view((self.config['batch_size'], 1))).squeeze(1) + if isinstance(self.config['dqfd_demo_loss_weight'], collections.Iterable): + demo_importance = next(self.config['dqfd_demo_loss_weight']) + else: + demo_importance = self.config['dqfd_demo_loss_weight'] + demo_loss = F.mse_loss((state_values + l).max(1)[0], expert_value) + loss = ((1 - demo_importance) * dqn_loss + demo_importance * demo_loss) / (dqn_loss + demo_loss) + + if self.logger is not None: + self.logger.append("Loss", loss.item()) + + self.net.zero_grad() + loss.backward() + self.net.clamp_gradients() + self.net.step() + + if self.target_net is not None: + if 'target_sync_tau' in self.config: + self.target_net.partial_sync(self.config['target_sync_tau']) + else: + self.target_net.sync() + + # If we're sampling by TD error, readjust the weights of the experiences + if (isinstance(self.memory, M.PrioritizedReplayMemory)): + td_error = (obtained_values - expected_values).detach().abs() + self.memory.update_priorities(batch_indexes, td_error) + + diff --git a/rltorch/agents/__init__.py b/rltorch/agents/__init__.py index 245932b..2d12612 100644 --- a/rltorch/agents/__init__.py +++ b/rltorch/agents/__init__.py @@ -1,5 +1,6 @@ from .A2CSingleAgent import * from .DQNAgent import * +from .DQfDAgent import * from .PPOAgent import * from .QEPAgent import * from .REINFORCEAgent import * \ No newline at end of file