Fixed errors with n-step returns

Brandon Rozek 2019-11-13 22:56:27 -05:00
parent ed62e148d5
commit 038d406d0f
3 changed files with 98 additions and 47 deletions

@@ -25,16 +25,22 @@ class DQfDAgent:
         if len(self.memory) < self.config['batch_size']:
             return
+        if 'n_step' in self.config:
+            batch_size = (self.config['batch_size'] // self.config['n_step']) * self.config['n_step']
+            steps = self.config['n_step']
+        else:
+            batch_size = self.config['batch_size']
+            steps = None
         weight_importance = self.config['prioritized_replay_weight_importance']
         # If it's a scheduler then get the next value by calling next, otherwise just use it's value
         beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
         # Check to see if we are doing N-Step DQN
-        steps = self.config['n_step'] if 'n_step' in self.config else None
         if steps is not None:
-            minibatch = self.memory.sample_n_steps(self.config['batch_size'], steps, beta)
+            minibatch = self.memory.sample_n_steps(batch_size, steps, beta)
         else:
-            minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
+            minibatch = self.memory.sample(batch_size, beta = beta)
         # Process batch
         state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
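For reference, a small standalone sketch of why the batch size is rounded down here: the n-step sampler returns whole sequences of length n_step, so the batch has to split evenly into them. The config values below are made up for illustration.

```python
# Illustrative only; the config values are hypothetical, not from a real run.
config = {'batch_size': 32, 'n_step': 5}

if 'n_step' in config:
    # Round down so the batch divides evenly into n_step-long sequences.
    batch_size = (config['batch_size'] // config['n_step']) * config['n_step']
    steps = config['n_step']
else:
    batch_size = config['batch_size']
    steps = None

print(batch_size, steps)  # 30 5 -> six sequences of five transitions each
```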
@@ -50,7 +56,7 @@ class DQfDAgent:
         not_done_batch = not_done_batch.to(self.net.device)
 
         state_values = self.net(state_batch)
-        obtained_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1))
+        obtained_values = state_values.gather(1, action_batch.view(batch_size, 1))
 
         # DQN Loss
         with torch.no_grad():
@@ -66,38 +72,44 @@ class DQfDAgent:
                 next_state_values[not_done_batch] = self.net(next_state_batch[not_done_batch])
                 next_best_action = next_state_values[not_done_batch].argmax(1)
 
-            best_next_state_value = torch.zeros(self.config['batch_size'], device = self.net.device)
+            best_next_state_value = torch.zeros(batch_size, device = self.net.device)
             best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
 
-        expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
+        expected_values = (reward_batch + (batch_size * best_next_state_value)).unsqueeze(1)
 
         # N-Step DQN Loss
-        expected_n_step_values = []
-        with torch.no_grad():
-            for i in range(0, len(state_batch), steps):
-                # Get the estimated value at the last state in a sequence
-                if self.target_net is not None:
-                    expected_nth_values = self.target_net(state_batch[i + steps])
-                    best_nth_action = self.net(state_batch[i + steps]).argmax(1)
-                else:
-                    expected_nth_values = self.net(state_batch[i + steps])
-                    best_nth_action = expected_nth_values.argmax(1)
-                best_expected_nth_value = expected_nth_values[best_nth_action].squeeze(1)
-                # Calculate the value leading up to it by taking the rewards and multiplying it by the discount rate
-                received_n_value = 0
-                for j in range(steps):
-                    received_n_value += self.config['discount_rate']**j * reward_batch[j]
-                # Q(s, a) = r_0 + lambda_1 * r_1 + lambda_2^2 * r_2 + ... + lambda_{steps}^{steps} * max_{a}(Q(s + steps, a))
-                expected_n_step_values.append(received_n_value + self.config['discount_rate']**steps * best_expected_nth_value)
-        expected_n_step_values = torch.stack(expected_n_step_values)
-        # Gather the value the current network thinks it should be
-        observed_n_step_values = []
-        for i in range(0, len(state_batch), steps):
-            observed_nth_value = self.net(state_batch[i])[action_batch[i]]
-            observed_n_step_values.append(observed_nth_value)
-        observed_n_step_values = torch.stack(observed_n_step_values)
+        # num_steps capture how many steps actually exist before the end of episode
+        if steps != None:
+            expected_n_step_values = []
+            with torch.no_grad():
+                for i in range(0, len(state_batch), steps):
+                    num_steps = not_done_batch[i:(i + steps)].sum()
+                    if num_steps < 2:
+                        continue # No point processing this
+                    # Get the estimated value at the last state in a sequence
+                    if self.target_net is not None:
+                        expected_nth_values = self.target_net(state_batch[i + num_steps - 1].unsqueeze(0)).squeeze(0)
+                        best_nth_action = self.net(state_batch[i + num_steps - 1].unsqueeze(0)).squeeze(0).argmax(0)
+                    else:
+                        expected_nth_values = self.net(state_batch[i + num_steps - 1].unsqueeze(0)).squeeze(0)
+                        best_nth_action = expected_nth_values.argmax(0)
+                    best_expected_nth_value = expected_nth_values[best_nth_action]
+                    # Calculate the value leading up to it by taking the rewards and multiplying it by the discount rate
+                    received_n_value = 0
+                    for j in range(num_steps):
+                        received_n_value += self.config['discount_rate']**j * reward_batch[j]
+                    # Q(s, a) = r_0 + lambda_1 * r_1 + lambda_2^2 * r_2 + ... + lambda_{steps}^{steps} * max_{a}(Q(s + steps, a))
+                    expected_n_step_values.append(received_n_value + self.config['discount_rate']**num_steps * best_expected_nth_value)
+            expected_n_step_values = torch.stack(expected_n_step_values)
+            # Gather the value the current network thinks it should be
+            observed_n_step_values = []
+            for i in range(0, len(state_batch), steps):
+                num_steps = not_done_batch[i:(i + steps)].sum()
+                if num_steps < 2:
+                    continue # No point processing this
+                observed_nth_value = self.net(state_batch[i].unsqueeze(0)).squeeze(0)[action_batch[i]]
+                observed_n_step_values.append(observed_nth_value)
+            observed_n_step_values = torch.stack(observed_n_step_values)
 
         # Demonstration loss
         if demo_mask.sum() > 0:
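As a cross-check on the loop above, here is a minimal self-contained sketch of the n-step return target it builds for one contiguous sequence: the discounted sum of the observed rewards plus a discounted bootstrap from the last observed state. A single network is used for simplicity, and the function and argument names are placeholders, not the repository's API.

```python
import torch

def n_step_target(q_net, states, rewards, discount_rate):
    """Sketch: n-step return target for one contiguous sequence.

    states:  (n, state_dim) tensor of states s_0 .. s_{n-1}
    rewards: (n,) tensor of rewards r_0 .. r_{n-1}
    """
    n = rewards.shape[0]
    with torch.no_grad():
        # Bootstrap from the greedy value at the last state in the sequence.
        bootstrap = q_net(states[-1].unsqueeze(0)).squeeze(0).max()
    # G = sum_{j<n} gamma^j * r_j  +  gamma^n * max_a Q(s_{n-1}, a)
    discounts = discount_rate ** torch.arange(n, dtype=torch.float32)
    return (discounts * rewards).sum() + discount_rate ** n * bootstrap

# Toy usage with a linear Q-network over a 4-dimensional state and 2 actions.
q = torch.nn.Linear(4, 2)
states = torch.randn(3, 4)
rewards = torch.tensor([1.0, 0.0, 0.5])
print(n_step_target(q, states, rewards, discount_rate=0.99))
```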
@@ -126,12 +138,17 @@ class DQfDAgent:
         # Since dqn_loss and demo_loss are different sizes, the reduction has to happen before they are combined
         dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean()
-        dqn_n_step_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none').squeeze(1)).mean()
+        if steps != None:
+            dqn_n_step_loss = (torch.as_tensor(importance_weights[::steps], device = self.net.device) * F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none')).mean()
+        else:
+            dqn_n_step_loss = torch.tensor(0, device = self.net.device)
 
         if demo_mask.sum() > 0:
             demo_loss = (torch.as_tensor(importance_weights, device = self.net.device)[demo_mask] * F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction = 'none').squeeze(1)).mean()
         else:
             demo_loss = 0
-        loss = td_importance * dqn_loss + demo_importance * demo_loss
+        loss = td_importance * dqn_loss + td_importance * dqn_n_step_loss + demo_importance * demo_loss
 
         if self.logger is not None:
             self.logger.append("Loss", loss.item())
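A toy illustration of the importance_weights[::steps] slice used in the new n-step loss (the numbers below are made up): the sampler repeats each sequence's importance weight once per transition, so taking every steps-th entry leaves exactly one weight per sequence, which lines up with the one-target-per-sequence n-step values.

```python
import torch
import torch.nn.functional as F

steps = 3
# Two sampled sequences of three transitions each; weights repeat within a sequence.
importance_weights = torch.tensor([0.7, 0.7, 0.7, 1.0, 1.0, 1.0])
observed_n_step_values = torch.tensor([1.2, 0.4])   # one prediction per sequence
expected_n_step_values = torch.tensor([1.0, 0.5])   # one target per sequence

per_sequence_weights = importance_weights[::steps]  # tensor([0.7, 1.0])
dqn_n_step_loss = (per_sequence_weights
                   * F.mse_loss(observed_n_step_values, expected_n_step_values,
                                reduction='none')).mean()
print(dqn_n_step_loss)  # weighted mean of the two per-sequence squared errors
```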

@@ -1,5 +1,6 @@
 from .PrioritizedReplayMemory import PrioritizedReplayMemory
 from collections import namedtuple
+import numpy as np
 
 Transition = namedtuple('Transition',
         ('state', 'action', 'reward', 'next_state', 'done'))
@@ -9,12 +10,13 @@ class DQfDMemory(PrioritizedReplayMemory):
     def __init__(self, capacity, alpha):
         super().__init__(capacity, alpha)
         self.demo_position = 0
-        self.obtained_transitions_length = 0
 
     def append(self, *args, **kwargs):
+        last_position = self.position # Get position before super classes change it
         super().append(*args, **kwargs)
         # Don't overwrite demonstration data
-        self.position = self.demo_position + ((self.position + 1) % (len(self.memory) - self.demo_position))
+        new_position = ((last_position + 1) % (self.capacity - self.demo_position + 1))
+        self.position = new_position if new_position > self.demo_position else self.demo_position + new_position
 
     def append_demonstration(self, *args):
         demonstrations = self.memory[:self.demo_position]
@@ -24,6 +26,40 @@ class DQfDMemory(PrioritizedReplayMemory):
             self.memory.append(Transition(*args))
         else:
             if len(demonstrations) + len(obtained_transitions) + 1 > self.capacity:
-                obtained_transitions = obtained_transitions[:(self.capacity - len(demonstrations) - 1)]
+                obtained_transitions = obtained_transitions[1:]
             self.memory = demonstrations + [Transition(*args)] + obtained_transitions
         self.demo_position += 1
+        self.position += 1
+
+    def sample_n_steps(self, batch_size, steps, beta):
+        assert beta > 0
+
+        sample_size = batch_size // steps
+        # Sample indexes and get n-steps after that
+        idxes = self._sample_proportional(sample_size)
+        step_idxes = []
+        for i in idxes:
+            # If the interval of experiences fall between demonstration and obtained, move it over to the demonstration half
+            if i < self.demo_position and i + steps > self.demo_position:
+                diff = i + steps - self.demo_position
+                step_idxes += range(i - diff, i + steps - diff)
+            elif i > steps:
+                step_idxes += range(i - steps, i)
+            else:
+                step_idxes += range(i, i + steps)
+
+        # Calculate appropriate weights and assign it to the values of the same sequence
+        weights = []
+        p_min = self._it_min.min() / self._it_sum.sum()
+        max_weight = (p_min * len(self.memory)) ** (-beta)
+        for idx in idxes:
+            p_sample = self._it_sum[idx] / self._it_sum.sum()
+            weight = (p_sample * len(self.memory)) ** (-beta)
+            weights += [(weight / max_weight) for i in range(steps)]
+        weights = np.array(weights)
+
+        # Combine all the data together into a batch
+        encoded_sample = tuple(zip(*self._encode_sample(step_idxes)))
+        batch = list(zip(*encoded_sample, weights, step_idxes))
+        return batch
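The new append logic above advances the write index so that agent experience cycles through the slots after the demonstrations instead of overwriting them. Below is a simplified sketch of that intent, tracking only the index with a hypothetical helper name; it is not the exact modular arithmetic used in the commit.

```python
def next_write_position(position, demo_position, capacity):
    """Sketch: circular write index that skips the demonstration
    region [0, demo_position). Hypothetical helper, illustration only."""
    nxt = position + 1
    if nxt >= capacity or nxt < demo_position:
        nxt = demo_position  # wrap back to the first agent-experience slot
    return nxt

# Example: capacity 6 with 2 demonstration slots -> index cycles 2, 3, 4, 5, 2, ...
pos = 2
for _ in range(6):
    print(pos, end=' ')
    pos = next_write_position(pos, demo_position=2, capacity=6)
print()
```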

@@ -234,32 +234,30 @@ class PrioritizedReplayMemory(ReplayMemory):
     def sample_n_steps(self, batch_size, steps, beta):
         assert beta > 0
 
-        memory = self.memory
-        self.memory = self.memory[:-steps]
         sample_size = batch_size // steps
         # Sample indexes and get n-steps after that
         idxes = self._sample_proportional(sample_size)
         step_idxes = []
         for i in idxes:
-            step_idxes += range(i, i + steps)
+            if i > steps:
+                step_idxes += range(i - steps, i)
+            else:
+                step_idxes += range(i, i + steps)
 
-        # Calculate appropriate weights
+        # Calculate appropriate weights and assign it to the values of the same sequence
         weights = []
         p_min = self._it_min.min() / self._it_sum.sum()
         max_weight = (p_min * len(self.memory)) ** (-beta)
-        for idx in step_idxes:
+        for idx in idxes:
             p_sample = self._it_sum[idx] / self._it_sum.sum()
             weight = (p_sample * len(self.memory)) ** (-beta)
-            weights.append(weight / max_weight)
+            weights += [(weight / max_weight) for i in range(steps)]
         weights = np.array(weights)
 
         # Combine all the data together into a batch
         encoded_sample = tuple(zip(*self._encode_sample(step_idxes)))
-        batch = list(zip(*encoded_sample, weights, idxes))
-        # Restore memory and return batch
-        self.memory = memory
+        batch = list(zip(*encoded_sample, weights, step_idxes))
         return batch
 
     @jit(forceobj = True)
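For reference, a toy recomputation of the importance-sampling weights produced above, w = (N * P(i))^(-beta) normalized by the maximum weight, with each sampled sequence's weight repeated steps times so it lines up with the n-step transitions. The plain priority array below stands in for the repository's segment trees, and the numbers are made up.

```python
import numpy as np

def sequence_weights(priorities, sampled_idxes, steps, beta, alpha=0.6):
    """Sketch of prioritized-replay IS weights, repeated once per step
    of every sampled n-step sequence. Illustration only."""
    probs = priorities ** alpha
    probs = probs / probs.sum()
    max_weight = (probs.min() * len(priorities)) ** (-beta)
    weights = []
    for idx in sampled_idxes:
        w = (probs[idx] * len(priorities)) ** (-beta)
        weights += [w / max_weight] * steps  # same weight for each step in the sequence
    return np.array(weights)

# Toy example: 8 stored transitions, sequences starting at indexes 1 and 5.
print(sequence_weights(np.arange(1.0, 9.0), [1, 5], steps=3, beta=0.4))
```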