From 038d406d0fff6546115bf9d1fce8c64888af1457 Mon Sep 17 00:00:00 2001 From: Brandon Rozek Date: Wed, 13 Nov 2019 22:56:27 -0500 Subject: [PATCH] Fixed errors with n-step returns --- rltorch/agents/DQfDAgent.py | 83 ++++++++++++++--------- rltorch/memory/DQfDMemory.py | 42 +++++++++++- rltorch/memory/PrioritizedReplayMemory.py | 20 +++--- 3 files changed, 98 insertions(+), 47 deletions(-) diff --git a/rltorch/agents/DQfDAgent.py b/rltorch/agents/DQfDAgent.py index 0f742b1..99af560 100644 --- a/rltorch/agents/DQfDAgent.py +++ b/rltorch/agents/DQfDAgent.py @@ -25,16 +25,22 @@ class DQfDAgent: if len(self.memory) < self.config['batch_size']: return + if 'n_step' in self.config: + batch_size = (self.config['batch_size'] // self.config['n_step']) * self.config['n_step'] + steps = self.config['n_step'] + else: + batch_size = self.config['batch_size'] + steps = None + weight_importance = self.config['prioritized_replay_weight_importance'] # If it's a scheduler then get the next value by calling next, otherwise just use it's value beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance # Check to see if we are doing N-Step DQN - steps = self.config['n_step'] if 'n_step' in self.config else None if steps is not None: - minibatch = self.memory.sample_n_steps(self.config['batch_size'], steps, beta) + minibatch = self.memory.sample_n_steps(batch_size, steps, beta) else: - minibatch = self.memory.sample(self.config['batch_size'], beta = beta) + minibatch = self.memory.sample(batch_size, beta = beta) # Process batch state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True) @@ -50,7 +56,7 @@ class DQfDAgent: not_done_batch = not_done_batch.to(self.net.device) state_values = self.net(state_batch) - obtained_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1)) + obtained_values = state_values.gather(1, action_batch.view(batch_size, 1)) # DQN Loss with torch.no_grad(): @@ -66,38 +72,44 @@ class DQfDAgent: next_state_values[not_done_batch] = self.net(next_state_batch[not_done_batch]) next_best_action = next_state_values[not_done_batch].argmax(1) - best_next_state_value = torch.zeros(self.config['batch_size'], device = self.net.device) + best_next_state_value = torch.zeros(batch_size, device = self.net.device) best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1) - expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1) + expected_values = (reward_batch + (batch_size * best_next_state_value)).unsqueeze(1) # N-Step DQN Loss - expected_n_step_values = [] - with torch.no_grad(): + # num_steps capture how many steps actually exist before the end of episode + if steps != None: + expected_n_step_values = [] + with torch.no_grad(): + for i in range(0, len(state_batch), steps): + num_steps = not_done_batch[i:(i + steps)].sum() + if num_steps < 2: + continue # No point processing this + # Get the estimated value at the last state in a sequence + if self.target_net is not None: + expected_nth_values = self.target_net(state_batch[i + num_steps - 1].unsqueeze(0)).squeeze(0) + best_nth_action = self.net(state_batch[i + num_steps - 1].unsqueeze(0)).squeeze(0).argmax(0) + else: + expected_nth_values = self.net(state_batch[i + num_steps - 1].unsqueeze(0)).squeeze(0) + best_nth_action = expected_nth_values.argmax(0) + best_expected_nth_value = expected_nth_values[best_nth_action] + # Calculate the value leading up to it by taking the rewards and multiplying it by the discount rate + received_n_value = 0 + for j in range(num_steps): + received_n_value += self.config['discount_rate']**j * reward_batch[j] + # Q(s, a) = r_0 + lambda_1 * r_1 + lambda_2^2 * r_2 + ... + lambda_{steps}^{steps} * max_{a}(Q(s + steps, a)) + expected_n_step_values.append(received_n_value + self.config['discount_rate']**num_steps * best_expected_nth_value) + expected_n_step_values = torch.stack(expected_n_step_values) + # Gather the value the current network thinks it should be + observed_n_step_values = [] for i in range(0, len(state_batch), steps): - # Get the estimated value at the last state in a sequence - if self.target_net is not None: - expected_nth_values = self.target_net(state_batch[i + steps]) - best_nth_action = self.net(state_batch[i + steps]).argmax(1) - else: - expected_nth_values = self.net(state_batch[i + steps]) - best_nth_action = expected_nth_values.argmax(1) - best_expected_nth_value = expected_nth_values[best_nth_action].squeeze(1) - # Calculate the value leading up to it by taking the rewards and multiplying it by the discount rate - received_n_value = 0 - for j in range(steps): - received_n_value += self.config['discount_rate']**j * reward_batch[j] - # Q(s, a) = r_0 + lambda_1 * r_1 + lambda_2^2 * r_2 + ... + lambda_{steps}^{steps} * max_{a}(Q(s + steps, a)) - expected_n_step_values.append(received_n_value + self.config['discount_rate']**steps * best_expected_nth_value) - expected_n_step_values = torch.stack(expected_n_step_values) - # Gather the value the current network thinks it should be - observed_n_step_values = [] - for i in range(0, len(state_batch), steps): - observed_nth_value = self.net(state_batch[i])[action_batch[i]] - observed_n_step_values.append(observed_nth_value) - observed_n_step_values = torch.stack(observed_n_step_values) - - + num_steps = not_done_batch[i:(i + steps)].sum() + if num_steps < 2: + continue # No point processing this + observed_nth_value = self.net(state_batch[i].unsqueeze(0)).squeeze(0)[action_batch[i]] + observed_n_step_values.append(observed_nth_value) + observed_n_step_values = torch.stack(observed_n_step_values) # Demonstration loss if demo_mask.sum() > 0: @@ -126,12 +138,17 @@ class DQfDAgent: # Since dqn_loss and demo_loss are different sizes, the reduction has to happen before they are combined dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean() - dqn_n_step_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none').squeeze(1)).mean() + + if steps != None: + dqn_n_step_loss = (torch.as_tensor(importance_weights[::steps], device = self.net.device) * F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none')).mean() + else: + dqn_n_step_loss = torch.tensor(0, device = self.net.device) + if demo_mask.sum() > 0: demo_loss = (torch.as_tensor(importance_weights, device = self.net.device)[demo_mask] * F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction = 'none').squeeze(1)).mean() else: demo_loss = 0 - loss = td_importance * dqn_loss + demo_importance * demo_loss + loss = td_importance * dqn_loss + td_importance * dqn_n_step_loss + demo_importance * demo_loss if self.logger is not None: self.logger.append("Loss", loss.item()) diff --git a/rltorch/memory/DQfDMemory.py b/rltorch/memory/DQfDMemory.py index 77e6014..d4ed582 100644 --- a/rltorch/memory/DQfDMemory.py +++ b/rltorch/memory/DQfDMemory.py @@ -1,5 +1,6 @@ from .PrioritizedReplayMemory import PrioritizedReplayMemory from collections import namedtuple +import numpy as np Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done')) @@ -9,12 +10,13 @@ class DQfDMemory(PrioritizedReplayMemory): def __init__(self, capacity, alpha): super().__init__(capacity, alpha) self.demo_position = 0 - self.obtained_transitions_length = 0 def append(self, *args, **kwargs): + last_position = self.position # Get position before super classes change it super().append(*args, **kwargs) # Don't overwrite demonstration data - self.position = self.demo_position + ((self.position + 1) % (len(self.memory) - self.demo_position)) + new_position = ((last_position + 1) % (self.capacity - self.demo_position + 1)) + self.position = new_position if new_position > self.demo_position else self.demo_position + new_position def append_demonstration(self, *args): demonstrations = self.memory[:self.demo_position] @@ -24,6 +26,40 @@ class DQfDMemory(PrioritizedReplayMemory): self.memory.append(Transition(*args)) else: if len(demonstrations) + len(obtained_transitions) + 1 > self.capacity: - obtained_transitions = obtained_transitions[:(self.capacity - len(demonstrations) - 1)] + obtained_transitions = obtained_transitions[1:] self.memory = demonstrations + [Transition(*args)] + obtained_transitions self.demo_position += 1 + self.position += 1 + + def sample_n_steps(self, batch_size, steps, beta): + assert beta > 0 + + sample_size = batch_size // steps + + # Sample indexes and get n-steps after that + idxes = self._sample_proportional(sample_size) + step_idxes = [] + for i in idxes: + # If the interval of experiences fall between demonstration and obtained, move it over to the demonstration half + if i < self.demo_position and i + steps > self.demo_position: + diff = i + steps - self.demo_position + step_idxes += range(i - diff, i + steps - diff) + elif i > steps: + step_idxes += range(i - steps, i) + else: + step_idxes += range(i, i + steps) + + # Calculate appropriate weights and assign it to the values of the same sequence + weights = [] + p_min = self._it_min.min() / self._it_sum.sum() + max_weight = (p_min * len(self.memory)) ** (-beta) + for idx in idxes: + p_sample = self._it_sum[idx] / self._it_sum.sum() + weight = (p_sample * len(self.memory)) ** (-beta) + weights += [(weight / max_weight) for i in range(steps)] + weights = np.array(weights) + + # Combine all the data together into a batch + encoded_sample = tuple(zip(*self._encode_sample(step_idxes))) + batch = list(zip(*encoded_sample, weights, step_idxes)) + return batch \ No newline at end of file diff --git a/rltorch/memory/PrioritizedReplayMemory.py b/rltorch/memory/PrioritizedReplayMemory.py index c9aedca..da3c767 100644 --- a/rltorch/memory/PrioritizedReplayMemory.py +++ b/rltorch/memory/PrioritizedReplayMemory.py @@ -234,32 +234,30 @@ class PrioritizedReplayMemory(ReplayMemory): def sample_n_steps(self, batch_size, steps, beta): assert beta > 0 - memory = self.memory - self.memory = self.memory[:-steps] sample_size = batch_size // steps - + # Sample indexes and get n-steps after that idxes = self._sample_proportional(sample_size) step_idxes = [] for i in idxes: - step_idxes += range(i, i + steps) + if i > steps: + step_idxes += range(i - steps, i) + else: + step_idxes += range(i, i + steps) - # Calculate appropriate weights + # Calculate appropriate weights and assign it to the values of the same sequence weights = [] p_min = self._it_min.min() / self._it_sum.sum() max_weight = (p_min * len(self.memory)) ** (-beta) - for idx in step_idxes: + for idx in idxes: p_sample = self._it_sum[idx] / self._it_sum.sum() weight = (p_sample * len(self.memory)) ** (-beta) - weights.append(weight / max_weight) + weights += [(weight / max_weight) for i in range(steps)] weights = np.array(weights) # Combine all the data together into a batch encoded_sample = tuple(zip(*self._encode_sample(step_idxes))) - batch = list(zip(*encoded_sample, weights, idxes)) - - # Restore memory and return batch - self.memory = memory + batch = list(zip(*encoded_sample, weights, step_idxes)) return batch @jit(forceobj = True)