Fixed errors with n-step returns
This commit is contained in:
parent
ed62e148d5
commit
038d406d0f
3 changed files with 98 additions and 47 deletions
|
@ -25,16 +25,22 @@ class DQfDAgent:
|
||||||
if len(self.memory) < self.config['batch_size']:
|
if len(self.memory) < self.config['batch_size']:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if 'n_step' in self.config:
|
||||||
|
batch_size = (self.config['batch_size'] // self.config['n_step']) * self.config['n_step']
|
||||||
|
steps = self.config['n_step']
|
||||||
|
else:
|
||||||
|
batch_size = self.config['batch_size']
|
||||||
|
steps = None
|
||||||
|
|
||||||
weight_importance = self.config['prioritized_replay_weight_importance']
|
weight_importance = self.config['prioritized_replay_weight_importance']
|
||||||
# If it's a scheduler then get the next value by calling next, otherwise just use it's value
|
# If it's a scheduler then get the next value by calling next, otherwise just use it's value
|
||||||
beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
|
beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
|
||||||
|
|
||||||
# Check to see if we are doing N-Step DQN
|
# Check to see if we are doing N-Step DQN
|
||||||
steps = self.config['n_step'] if 'n_step' in self.config else None
|
|
||||||
if steps is not None:
|
if steps is not None:
|
||||||
minibatch = self.memory.sample_n_steps(self.config['batch_size'], steps, beta)
|
minibatch = self.memory.sample_n_steps(batch_size, steps, beta)
|
||||||
else:
|
else:
|
||||||
minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
|
minibatch = self.memory.sample(batch_size, beta = beta)
|
||||||
|
|
||||||
# Process batch
|
# Process batch
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
|
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
|
||||||
|
@ -50,7 +56,7 @@ class DQfDAgent:
|
||||||
not_done_batch = not_done_batch.to(self.net.device)
|
not_done_batch = not_done_batch.to(self.net.device)
|
||||||
|
|
||||||
state_values = self.net(state_batch)
|
state_values = self.net(state_batch)
|
||||||
obtained_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1))
|
obtained_values = state_values.gather(1, action_batch.view(batch_size, 1))
|
||||||
|
|
||||||
# DQN Loss
|
# DQN Loss
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -66,38 +72,44 @@ class DQfDAgent:
|
||||||
next_state_values[not_done_batch] = self.net(next_state_batch[not_done_batch])
|
next_state_values[not_done_batch] = self.net(next_state_batch[not_done_batch])
|
||||||
next_best_action = next_state_values[not_done_batch].argmax(1)
|
next_best_action = next_state_values[not_done_batch].argmax(1)
|
||||||
|
|
||||||
best_next_state_value = torch.zeros(self.config['batch_size'], device = self.net.device)
|
best_next_state_value = torch.zeros(batch_size, device = self.net.device)
|
||||||
best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
||||||
|
|
||||||
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
expected_values = (reward_batch + (batch_size * best_next_state_value)).unsqueeze(1)
|
||||||
|
|
||||||
# N-Step DQN Loss
|
# N-Step DQN Loss
|
||||||
expected_n_step_values = []
|
# num_steps capture how many steps actually exist before the end of episode
|
||||||
with torch.no_grad():
|
if steps != None:
|
||||||
|
expected_n_step_values = []
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in range(0, len(state_batch), steps):
|
||||||
|
num_steps = not_done_batch[i:(i + steps)].sum()
|
||||||
|
if num_steps < 2:
|
||||||
|
continue # No point processing this
|
||||||
|
# Get the estimated value at the last state in a sequence
|
||||||
|
if self.target_net is not None:
|
||||||
|
expected_nth_values = self.target_net(state_batch[i + num_steps - 1].unsqueeze(0)).squeeze(0)
|
||||||
|
best_nth_action = self.net(state_batch[i + num_steps - 1].unsqueeze(0)).squeeze(0).argmax(0)
|
||||||
|
else:
|
||||||
|
expected_nth_values = self.net(state_batch[i + num_steps - 1].unsqueeze(0)).squeeze(0)
|
||||||
|
best_nth_action = expected_nth_values.argmax(0)
|
||||||
|
best_expected_nth_value = expected_nth_values[best_nth_action]
|
||||||
|
# Calculate the value leading up to it by taking the rewards and multiplying it by the discount rate
|
||||||
|
received_n_value = 0
|
||||||
|
for j in range(num_steps):
|
||||||
|
received_n_value += self.config['discount_rate']**j * reward_batch[j]
|
||||||
|
# Q(s, a) = r_0 + lambda_1 * r_1 + lambda_2^2 * r_2 + ... + lambda_{steps}^{steps} * max_{a}(Q(s + steps, a))
|
||||||
|
expected_n_step_values.append(received_n_value + self.config['discount_rate']**num_steps * best_expected_nth_value)
|
||||||
|
expected_n_step_values = torch.stack(expected_n_step_values)
|
||||||
|
# Gather the value the current network thinks it should be
|
||||||
|
observed_n_step_values = []
|
||||||
for i in range(0, len(state_batch), steps):
|
for i in range(0, len(state_batch), steps):
|
||||||
# Get the estimated value at the last state in a sequence
|
num_steps = not_done_batch[i:(i + steps)].sum()
|
||||||
if self.target_net is not None:
|
if num_steps < 2:
|
||||||
expected_nth_values = self.target_net(state_batch[i + steps])
|
continue # No point processing this
|
||||||
best_nth_action = self.net(state_batch[i + steps]).argmax(1)
|
observed_nth_value = self.net(state_batch[i].unsqueeze(0)).squeeze(0)[action_batch[i]]
|
||||||
else:
|
observed_n_step_values.append(observed_nth_value)
|
||||||
expected_nth_values = self.net(state_batch[i + steps])
|
observed_n_step_values = torch.stack(observed_n_step_values)
|
||||||
best_nth_action = expected_nth_values.argmax(1)
|
|
||||||
best_expected_nth_value = expected_nth_values[best_nth_action].squeeze(1)
|
|
||||||
# Calculate the value leading up to it by taking the rewards and multiplying it by the discount rate
|
|
||||||
received_n_value = 0
|
|
||||||
for j in range(steps):
|
|
||||||
received_n_value += self.config['discount_rate']**j * reward_batch[j]
|
|
||||||
# Q(s, a) = r_0 + lambda_1 * r_1 + lambda_2^2 * r_2 + ... + lambda_{steps}^{steps} * max_{a}(Q(s + steps, a))
|
|
||||||
expected_n_step_values.append(received_n_value + self.config['discount_rate']**steps * best_expected_nth_value)
|
|
||||||
expected_n_step_values = torch.stack(expected_n_step_values)
|
|
||||||
# Gather the value the current network thinks it should be
|
|
||||||
observed_n_step_values = []
|
|
||||||
for i in range(0, len(state_batch), steps):
|
|
||||||
observed_nth_value = self.net(state_batch[i])[action_batch[i]]
|
|
||||||
observed_n_step_values.append(observed_nth_value)
|
|
||||||
observed_n_step_values = torch.stack(observed_n_step_values)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Demonstration loss
|
# Demonstration loss
|
||||||
if demo_mask.sum() > 0:
|
if demo_mask.sum() > 0:
|
||||||
|
@ -126,12 +138,17 @@ class DQfDAgent:
|
||||||
|
|
||||||
# Since dqn_loss and demo_loss are different sizes, the reduction has to happen before they are combined
|
# Since dqn_loss and demo_loss are different sizes, the reduction has to happen before they are combined
|
||||||
dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean()
|
dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean()
|
||||||
dqn_n_step_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none').squeeze(1)).mean()
|
|
||||||
|
if steps != None:
|
||||||
|
dqn_n_step_loss = (torch.as_tensor(importance_weights[::steps], device = self.net.device) * F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none')).mean()
|
||||||
|
else:
|
||||||
|
dqn_n_step_loss = torch.tensor(0, device = self.net.device)
|
||||||
|
|
||||||
if demo_mask.sum() > 0:
|
if demo_mask.sum() > 0:
|
||||||
demo_loss = (torch.as_tensor(importance_weights, device = self.net.device)[demo_mask] * F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction = 'none').squeeze(1)).mean()
|
demo_loss = (torch.as_tensor(importance_weights, device = self.net.device)[demo_mask] * F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction = 'none').squeeze(1)).mean()
|
||||||
else:
|
else:
|
||||||
demo_loss = 0
|
demo_loss = 0
|
||||||
loss = td_importance * dqn_loss + demo_importance * demo_loss
|
loss = td_importance * dqn_loss + td_importance * dqn_n_step_loss + demo_importance * demo_loss
|
||||||
|
|
||||||
if self.logger is not None:
|
if self.logger is not None:
|
||||||
self.logger.append("Loss", loss.item())
|
self.logger.append("Loss", loss.item())
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from .PrioritizedReplayMemory import PrioritizedReplayMemory
|
from .PrioritizedReplayMemory import PrioritizedReplayMemory
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
Transition = namedtuple('Transition',
|
Transition = namedtuple('Transition',
|
||||||
('state', 'action', 'reward', 'next_state', 'done'))
|
('state', 'action', 'reward', 'next_state', 'done'))
|
||||||
|
@ -9,12 +10,13 @@ class DQfDMemory(PrioritizedReplayMemory):
|
||||||
def __init__(self, capacity, alpha):
|
def __init__(self, capacity, alpha):
|
||||||
super().__init__(capacity, alpha)
|
super().__init__(capacity, alpha)
|
||||||
self.demo_position = 0
|
self.demo_position = 0
|
||||||
self.obtained_transitions_length = 0
|
|
||||||
|
|
||||||
def append(self, *args, **kwargs):
|
def append(self, *args, **kwargs):
|
||||||
|
last_position = self.position # Get position before super classes change it
|
||||||
super().append(*args, **kwargs)
|
super().append(*args, **kwargs)
|
||||||
# Don't overwrite demonstration data
|
# Don't overwrite demonstration data
|
||||||
self.position = self.demo_position + ((self.position + 1) % (len(self.memory) - self.demo_position))
|
new_position = ((last_position + 1) % (self.capacity - self.demo_position + 1))
|
||||||
|
self.position = new_position if new_position > self.demo_position else self.demo_position + new_position
|
||||||
|
|
||||||
def append_demonstration(self, *args):
|
def append_demonstration(self, *args):
|
||||||
demonstrations = self.memory[:self.demo_position]
|
demonstrations = self.memory[:self.demo_position]
|
||||||
|
@ -24,6 +26,40 @@ class DQfDMemory(PrioritizedReplayMemory):
|
||||||
self.memory.append(Transition(*args))
|
self.memory.append(Transition(*args))
|
||||||
else:
|
else:
|
||||||
if len(demonstrations) + len(obtained_transitions) + 1 > self.capacity:
|
if len(demonstrations) + len(obtained_transitions) + 1 > self.capacity:
|
||||||
obtained_transitions = obtained_transitions[:(self.capacity - len(demonstrations) - 1)]
|
obtained_transitions = obtained_transitions[1:]
|
||||||
self.memory = demonstrations + [Transition(*args)] + obtained_transitions
|
self.memory = demonstrations + [Transition(*args)] + obtained_transitions
|
||||||
self.demo_position += 1
|
self.demo_position += 1
|
||||||
|
self.position += 1
|
||||||
|
|
||||||
|
def sample_n_steps(self, batch_size, steps, beta):
|
||||||
|
assert beta > 0
|
||||||
|
|
||||||
|
sample_size = batch_size // steps
|
||||||
|
|
||||||
|
# Sample indexes and get n-steps after that
|
||||||
|
idxes = self._sample_proportional(sample_size)
|
||||||
|
step_idxes = []
|
||||||
|
for i in idxes:
|
||||||
|
# If the interval of experiences fall between demonstration and obtained, move it over to the demonstration half
|
||||||
|
if i < self.demo_position and i + steps > self.demo_position:
|
||||||
|
diff = i + steps - self.demo_position
|
||||||
|
step_idxes += range(i - diff, i + steps - diff)
|
||||||
|
elif i > steps:
|
||||||
|
step_idxes += range(i - steps, i)
|
||||||
|
else:
|
||||||
|
step_idxes += range(i, i + steps)
|
||||||
|
|
||||||
|
# Calculate appropriate weights and assign it to the values of the same sequence
|
||||||
|
weights = []
|
||||||
|
p_min = self._it_min.min() / self._it_sum.sum()
|
||||||
|
max_weight = (p_min * len(self.memory)) ** (-beta)
|
||||||
|
for idx in idxes:
|
||||||
|
p_sample = self._it_sum[idx] / self._it_sum.sum()
|
||||||
|
weight = (p_sample * len(self.memory)) ** (-beta)
|
||||||
|
weights += [(weight / max_weight) for i in range(steps)]
|
||||||
|
weights = np.array(weights)
|
||||||
|
|
||||||
|
# Combine all the data together into a batch
|
||||||
|
encoded_sample = tuple(zip(*self._encode_sample(step_idxes)))
|
||||||
|
batch = list(zip(*encoded_sample, weights, step_idxes))
|
||||||
|
return batch
|
|
@ -234,32 +234,30 @@ class PrioritizedReplayMemory(ReplayMemory):
|
||||||
def sample_n_steps(self, batch_size, steps, beta):
|
def sample_n_steps(self, batch_size, steps, beta):
|
||||||
assert beta > 0
|
assert beta > 0
|
||||||
|
|
||||||
memory = self.memory
|
|
||||||
self.memory = self.memory[:-steps]
|
|
||||||
sample_size = batch_size // steps
|
sample_size = batch_size // steps
|
||||||
|
|
||||||
# Sample indexes and get n-steps after that
|
# Sample indexes and get n-steps after that
|
||||||
idxes = self._sample_proportional(sample_size)
|
idxes = self._sample_proportional(sample_size)
|
||||||
step_idxes = []
|
step_idxes = []
|
||||||
for i in idxes:
|
for i in idxes:
|
||||||
step_idxes += range(i, i + steps)
|
if i > steps:
|
||||||
|
step_idxes += range(i - steps, i)
|
||||||
|
else:
|
||||||
|
step_idxes += range(i, i + steps)
|
||||||
|
|
||||||
# Calculate appropriate weights
|
# Calculate appropriate weights and assign it to the values of the same sequence
|
||||||
weights = []
|
weights = []
|
||||||
p_min = self._it_min.min() / self._it_sum.sum()
|
p_min = self._it_min.min() / self._it_sum.sum()
|
||||||
max_weight = (p_min * len(self.memory)) ** (-beta)
|
max_weight = (p_min * len(self.memory)) ** (-beta)
|
||||||
for idx in step_idxes:
|
for idx in idxes:
|
||||||
p_sample = self._it_sum[idx] / self._it_sum.sum()
|
p_sample = self._it_sum[idx] / self._it_sum.sum()
|
||||||
weight = (p_sample * len(self.memory)) ** (-beta)
|
weight = (p_sample * len(self.memory)) ** (-beta)
|
||||||
weights.append(weight / max_weight)
|
weights += [(weight / max_weight) for i in range(steps)]
|
||||||
weights = np.array(weights)
|
weights = np.array(weights)
|
||||||
|
|
||||||
# Combine all the data together into a batch
|
# Combine all the data together into a batch
|
||||||
encoded_sample = tuple(zip(*self._encode_sample(step_idxes)))
|
encoded_sample = tuple(zip(*self._encode_sample(step_idxes)))
|
||||||
batch = list(zip(*encoded_sample, weights, idxes))
|
batch = list(zip(*encoded_sample, weights, step_idxes))
|
||||||
|
|
||||||
# Restore memory and return batch
|
|
||||||
self.memory = memory
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
@jit(forceobj = True)
|
@jit(forceobj = True)
|
||||||
|
|
Loading…
Reference in a new issue