diff --git a/rltorch/agents/DQfDAgent.py b/rltorch/agents/DQfDAgent.py
index f81a5ef..0f742b1 100644
--- a/rltorch/agents/DQfDAgent.py
+++ b/rltorch/agents/DQfDAgent.py
@@ -28,10 +28,19 @@ class DQfDAgent:
         weight_importance = self.config['prioritized_replay_weight_importance']
         # If it's a scheduler then get the next value by calling next, otherwise just use it's value
         beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance

-        minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
-        state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
-        demo_indexes = batch_indexes < self.memory.demo_position
+        # Check to see if we are doing N-Step DQN
+        steps = self.config['n_step'] if 'n_step' in self.config else None
+        if steps is not None:
+            minibatch = self.memory.sample_n_steps(self.config['batch_size'], steps, beta)
+        else:
+            minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
+
+        # Process batch
+        state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
+
+        batch_index_tensors = torch.tensor(batch_indexes)
+        demo_mask = batch_index_tensors < self.memory.demo_position

         # Send to their appropriate devices
         state_batch = state_batch.to(self.net.device)
@@ -43,6 +52,7 @@ class DQfDAgent:
         state_values = self.net(state_batch)
         obtained_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1))

+        # DQN Loss
         with torch.no_grad():
             # Use the target net to produce action values for the next state
             # and the regular net to select the action
@@ -61,19 +71,48 @@ class DQfDAgent:

             expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)

-        # Demonstration loss
-        l = torch.ones_like(state_values[demo_indexes])
-        expert_actions = action_batch[demo_indexes]
-        # l(s, a) is zero for every action the expert doesn't take
-        for i,a in zip(range(len(l)), expert_actions):
-            l[i].fill_(0.8) # According to paper
-            l[i, a] = 0
-        if self.target_net is not None:
-            expert_value = self.target_net(state_batch[demo_indexes])
-        else:
-            expert_value = state_values[demo_indexes]
-        expert_value = expert_value.gather(1, expert_actions.view((self.config['batch_size'], 1))).squeeze(1)
+        # N-Step DQN Loss
+        expected_n_step_values = []
+        with torch.no_grad():
+            for i in range(0, len(state_batch), steps):
+                # Get the estimated value at the last state in a sequence
+                if self.target_net is not None:
+                    expected_nth_values = self.target_net(state_batch[i + steps])
+                    best_nth_action = self.net(state_batch[i + steps]).argmax(1)
+                else:
+                    expected_nth_values = self.net(state_batch[i + steps])
+                    best_nth_action = expected_nth_values.argmax(1)
+                best_expected_nth_value = expected_nth_values[best_nth_action].squeeze(1)
+                # Accumulate the discounted rewards leading up to it
+                received_n_value = 0
+                for j in range(steps):
+                    received_n_value += self.config['discount_rate']**j * reward_batch[j]
+                # Q(s_0, a_0) = r_0 + gamma * r_1 + gamma^2 * r_2 + ... + gamma^steps * max_a Q(s_steps, a)
+                expected_n_step_values.append(received_n_value + self.config['discount_rate']**steps * best_expected_nth_value)
+        expected_n_step_values = torch.stack(expected_n_step_values)
+        # Gather the value the current network predicts for the start of each sequence
+        observed_n_step_values = []
+        for i in range(0, len(state_batch), steps):
+            observed_nth_value = self.net(state_batch[i])[action_batch[i]]
+            observed_n_step_values.append(observed_nth_value)
+        observed_n_step_values = torch.stack(observed_n_step_values)
+
+
+        # Demonstration loss
+        if demo_mask.sum() > 0:
+            l = torch.ones_like(state_values[demo_mask])
+            expert_actions = action_batch[demo_mask]
+            # l(s, a) is zero for every action the expert doesn't take
+            for i,a in zip(range(len(l)), expert_actions):
+                l[i].fill_(0.8) # According to paper
+                l[i, a] = 0
+            if self.target_net is not None:
+                expert_value = self.target_net(state_batch[demo_mask])
+            else:
+                expert_value = state_values[demo_mask]
+            expert_value = expert_value.gather(1, expert_actions.view(demo_mask.sum(), 1))
+
         # Iterate through hyperparamters
         if isinstance(self.config['dqfd_demo_loss_weight'], collections.Iterable):
             demo_importance = next(self.config['dqfd_demo_loss_weight'])
@@ -84,9 +123,14 @@ class DQfDAgent:
         else:
             td_importance = self.config['dqfd_td_loss_weight']

+        # Since dqn_loss and demo_loss are different sizes, the reduction has to happen before they are combined
-        dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none')).mean()
-        demo_loss = (torch.as_tensor(importance_weights[demo_indexes], device = self.net.device) * F.mse_loss((state_values[demo_indexes] + l).max(1)[0], expert_value, reduction = 'none')).mean()
+        dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean()
+        dqn_n_step_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none').squeeze(1)).mean()
+        if demo_mask.sum() > 0:
+            demo_loss = (torch.as_tensor(importance_weights, device = self.net.device)[demo_mask] * F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction = 'none').squeeze(1)).mean()
+        else:
+            demo_loss = 0
         loss = td_importance * dqn_loss + demo_importance * demo_loss

         if self.logger is not None:
@@ -106,6 +150,9 @@ class DQfDAgent:
             # If we're sampling by TD error, readjust the weights of the experiences
             # TODO: Can probably adjust demonstration priority here
             td_error = (obtained_values - expected_values).detach().abs()
+            td_error[demo_mask] = td_error[demo_mask] + self.config['demo_prio_bonus']
+            observed_mask = batch_index_tensors >= self.memory.demo_position
+            td_error[observed_mask] = td_error[observed_mask] + self.config['observed_prio_bonus']
             self.memory.update_priorities(batch_indexes, td_error)
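The per-sequence loop above builds the n-step target described by the `Q(s_0, a_0)` comment: each run of `steps` consecutive transitions contributes its discounted rewards plus a bootstrapped value from the last state of the run. For reference, here is a minimal vectorized sketch of that target; the function name, arguments, and the assumption that rewards arrive as contiguous runs of `steps` transitions with one bootstrap value per run are illustrative only and not part of rltorch.

```python
# A self-contained sketch of the n-step target, assuming the batch holds
# consecutive runs of `steps` transitions, so the run starting at index i
# covers transitions i .. i + steps - 1. Names are illustrative, not rltorch API.
import torch

def n_step_targets(rewards, bootstrap_values, steps, discount_rate):
    """r_0 + gamma*r_1 + ... + gamma^(steps-1)*r_{steps-1} + gamma^steps * V_boot per run.

    rewards: 1-D tensor of length num_runs * steps, grouped run by run
    bootstrap_values: 1-D tensor of length num_runs, max_a Q(s_steps, a) for each run
    """
    num_runs = bootstrap_values.shape[0]
    rewards = rewards.view(num_runs, steps)
    discounts = discount_rate ** torch.arange(steps, dtype=rewards.dtype)
    returns = (rewards * discounts).sum(dim=1)
    return returns + (discount_rate ** steps) * bootstrap_values

# Toy check: two runs of 3 transitions with gamma = 0.9
rewards = torch.tensor([1., 0., 1.,   0., 1., 0.])
bootstrap = torch.tensor([2.0, 3.0])
print(n_step_targets(rewards, bootstrap, steps=3, discount_rate=0.9))
# run 0: 1 + 0.9*0 + 0.81*1 + 0.729*2 = 3.268
```

With this layout, the run that starts at index `i` spans transitions `i` through `i + steps - 1`, and the discounted reward sum is taken over exactly that slice.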
diff --git a/rltorch/memory/DQfDMemory.py b/rltorch/memory/DQfDMemory.py
index ba957ed..77e6014 100644
--- a/rltorch/memory/DQfDMemory.py
+++ b/rltorch/memory/DQfDMemory.py
@@ -1,4 +1,5 @@
 from .PrioritizedReplayMemory import PrioritizedReplayMemory
+from collections import namedtuple

 Transition = namedtuple('Transition',
     ('state', 'action', 'reward', 'next_state', 'done'))
@@ -11,9 +12,9 @@ class DQfDMemory(PrioritizedReplayMemory):
         self.obtained_transitions_length = 0

     def append(self, *args, **kwargs):
-        super().append(self, *args, **kwargs)
+        super().append(*args, **kwargs)
         # Don't overwrite demonstration data
-        self.position = self.demo_position + ((self.position + 1) % (self.capacity - self.demo_position))
+        self.position = self.demo_position + ((self.position + 1) % (len(self.memory) - self.demo_position))

     def append_demonstration(self, *args):
         demonstrations = self.memory[:self.demo_position]
@@ -25,4 +26,4 @@ class DQfDMemory(PrioritizedReplayMemory):
         if len(demonstrations) + len(obtained_transitions) + 1 > self.capacity:
             obtained_transitions = obtained_transitions[:(self.capacity - len(demonstrations) - 1)]
         self.memory = demonstrations + [Transition(*args)] + obtained_transitions
-        self.demo_position += 1
\ No newline at end of file
+        self.demo_position += 1
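The invariant behind the `append` change is that the write pointer may wrap around the buffer but must never land inside the demonstration block kept at the front. Below is a toy trace of just that index arithmetic using plain integers; the variable names mirror the diff, but the snippet is illustrative and does not import rltorch.

```python
# Trace of the write-pointer update from DQfDMemory.append, showing that the
# pointer always stays in the non-demonstration region [demo_position, len(memory)).
demo_position = 3      # transitions [0, 3) are expert demonstrations
memory_length = 8      # current number of stored transitions

position = demo_position
for _ in range(10):
    position = demo_position + ((position + 1) % (memory_length - demo_position))
    assert demo_position <= position < memory_length
    print(position)
# The pointer cycles over indices 3..7 and never overwrites a demonstration.
```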
diff --git a/rltorch/memory/PrioritizedReplayMemory.py b/rltorch/memory/PrioritizedReplayMemory.py
index 00b1d6e..c9aedca 100644
--- a/rltorch/memory/PrioritizedReplayMemory.py
+++ b/rltorch/memory/PrioritizedReplayMemory.py
@@ -105,7 +105,7 @@ class SumSegmentTree(SegmentTree):
         """Returns arr[start] + ... + arr[end]"""
         return super(SumSegmentTree, self).reduce(start, end)

-    @jit(forceobj = True)
+    @jit(forceobj = True, parallel = True)
     def find_prefixsum_idx(self, prefixsum):
         """Find the highest index `i` in the array such that
             sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum
@@ -204,17 +204,6 @@ class PrioritizedReplayMemory(ReplayMemory):
             (0 - no corrections, 1 - full correction)
         Returns
         -------
-        obs_batch: np.array
-            batch of observations
-        act_batch: np.array
-            batch of actions executed given obs_batch
-        rew_batch: np.array
-            rewards received as results of executing act_batch
-        next_obs_batch: np.array
-            next set of observations seen after executing act_batch
-        done_mask: np.array
-            done_mask[i] = 1 if executing act_batch[i] resulted in
-            the end of an episode and 0 otherwise.
         weights: np.array
             Array of shape (batch_size,) and dtype np.float32
             denoting importance weight of each sampled transition
@@ -224,21 +213,55 @@ class PrioritizedReplayMemory(ReplayMemory):
         """
         assert beta > 0

+        # Sample indexes
         idxes = self._sample_proportional(batch_size)

+        # Calculate weights
         weights = []
         p_min = self._it_min.min() / self._it_sum.sum()
         max_weight = (p_min * len(self.memory)) ** (-beta)
-
         for idx in idxes:
             p_sample = self._it_sum[idx] / self._it_sum.sum()
             weight = (p_sample * len(self.memory)) ** (-beta)
             weights.append(weight / max_weight)
         weights = np.array(weights)
+
+        # Combine all data into a batch
         encoded_sample = tuple(zip(*self._encode_sample(idxes)))
         batch = list(zip(*encoded_sample, weights, idxes))
         return batch

+    def sample_n_steps(self, batch_size, steps, beta):
+        assert beta > 0
+
+        memory = self.memory
+        self.memory = self.memory[:-steps]
+        sample_size = batch_size // steps
+
+        # Sample indexes and get n-steps after that
+        idxes = self._sample_proportional(sample_size)
+        step_idxes = []
+        for i in idxes:
+            step_idxes += range(i, i + steps)
+
+        # Calculate appropriate weights
+        weights = []
+        p_min = self._it_min.min() / self._it_sum.sum()
+        max_weight = (p_min * len(self.memory)) ** (-beta)
+        for idx in step_idxes:
+            p_sample = self._it_sum[idx] / self._it_sum.sum()
+            weight = (p_sample * len(self.memory)) ** (-beta)
+            weights.append(weight / max_weight)
+        weights = np.array(weights)
+
+        # Combine all the data together into a batch
+        encoded_sample = tuple(zip(*self._encode_sample(step_idxes)))
+        batch = list(zip(*encoded_sample, weights, idxes))
+
+        # Restore memory and return batch
+        self.memory = memory
+        return batch
+
     @jit(forceobj = True)
     def update_priorities(self, idxes, priorities):
         """Update priorities of sampled transitions.
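Both `sample` and the new `sample_n_steps` compute the same importance-sampling correction: w_i = (N * P(i))^(-beta), normalized by the largest possible weight so every weight is at most 1. Here is a small standalone sketch of that computation with a plain numpy array standing in for the segment trees; the function name and arguments are illustrative, not rltorch API.

```python
# Sketch of the importance-sampling weights computed above, assuming the
# (already alpha-adjusted) priorities live in a numpy array rather than the
# SumSegmentTree / MinSegmentTree used by PrioritizedReplayMemory.
import numpy as np

def importance_weights(priorities, idxes, beta):
    """w_i = (N * P(i))^(-beta), normalized by the largest possible weight."""
    probs = priorities / priorities.sum()          # P(i)
    n = len(priorities)
    max_weight = (probs.min() * n) ** (-beta)      # weight of the rarest transition
    return (probs[idxes] * n) ** (-beta) / max_weight

priorities = np.array([1.0, 4.0, 2.0, 1.0])
print(importance_weights(priorities, idxes=[1, 2], beta=0.4))
# High-priority transitions get weights below 1, so their updates are scaled down.
```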
diff --git a/rltorch/memory/ReplayMemory.py b/rltorch/memory/ReplayMemory.py
index 89e6cd8..aa32ab7 100644
--- a/rltorch/memory/ReplayMemory.py
+++ b/rltorch/memory/ReplayMemory.py
@@ -38,6 +38,13 @@ class ReplayMemory(object):

     def sample(self, batch_size):
         return random.sample(self.memory, batch_size)
+
+    def sample_n_steps(self, batch_size, steps):
+        idxes = random.sample(range(len(self.memory) - steps), batch_size // steps)
+        step_idxes = []
+        for i in idxes:
+            step_idxes += range(i, i + steps)
+        return self._encode_sample(step_idxes)

     def __len__(self):
         return len(self.memory)
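`sample_n_steps` in the uniform `ReplayMemory` follows the same pattern as the prioritized version: draw `batch_size // steps` starting indices and expand each into a contiguous run of `steps` transitions, so the returned batch stays grouped run by run, which is what lets `DQfDAgent` stride over it in chunks of `steps`. Below is a minimal sketch of that index expansion, assuming `batch_size` is a multiple of `steps`; the helper name is hypothetical.

```python
# Sketch of the index expansion done by sample_n_steps: pick batch_size // steps
# starting points, then take `steps` consecutive transitions from each start.
import random

def n_step_indices(memory_size, batch_size, steps):
    starts = random.sample(range(memory_size - steps), batch_size // steps)
    return [j for i in starts for j in range(i, i + steps)]

print(n_step_indices(memory_size=20, batch_size=8, steps=4))
# e.g. [12, 13, 14, 15, 6, 7, 8, 9]: two contiguous runs of four transitions
```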