Initial implementation of n-step loss
This commit is contained in:
parent
07c90a09f9
commit
ed62e148d5
4 changed files with 111 additions and 33 deletions
|
@ -28,10 +28,19 @@ class DQfDAgent:
|
|||
weight_importance = self.config['prioritized_replay_weight_importance']
|
||||
# If it's a scheduler then get the next value by calling next, otherwise just use it's value
|
||||
beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
|
||||
minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
|
||||
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
|
||||
|
||||
demo_indexes = batch_indexes < self.memory.demo_position
|
||||
# Check to see if we are doing N-Step DQN
|
||||
steps = self.config['n_step'] if 'n_step' in self.config else None
|
||||
if steps is not None:
|
||||
minibatch = self.memory.sample_n_steps(self.config['batch_size'], steps, beta)
|
||||
else:
|
||||
minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
|
||||
|
||||
# Process batch
|
||||
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
|
||||
|
||||
batch_index_tensors = torch.tensor(batch_indexes)
|
||||
demo_mask = batch_index_tensors < self.memory.demo_position
|
||||
|
||||
# Send to their appropriate devices
|
||||
state_batch = state_batch.to(self.net.device)
|
||||
|
@ -43,6 +52,7 @@ class DQfDAgent:
|
|||
state_values = self.net(state_batch)
|
||||
obtained_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1))
|
||||
|
||||
# DQN Loss
|
||||
with torch.no_grad():
|
||||
# Use the target net to produce action values for the next state
|
||||
# and the regular net to select the action
|
||||
|
@ -61,19 +71,48 @@ class DQfDAgent:
|
|||
|
||||
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
||||
|
||||
# Demonstration loss
|
||||
l = torch.ones_like(state_values[demo_indexes])
|
||||
expert_actions = action_batch[demo_indexes]
|
||||
# l(s, a) is zero for every action the expert doesn't take
|
||||
for i,a in zip(range(len(l)), expert_actions):
|
||||
l[i].fill_(0.8) # According to paper
|
||||
l[i, a] = 0
|
||||
if self.target_net is not None:
|
||||
expert_value = self.target_net(state_batch[demo_indexes])
|
||||
else:
|
||||
expert_value = state_values[demo_indexes]
|
||||
expert_value = expert_value.gather(1, expert_actions.view((self.config['batch_size'], 1))).squeeze(1)
|
||||
# N-Step DQN Loss
|
||||
expected_n_step_values = []
|
||||
with torch.no_grad():
|
||||
for i in range(0, len(state_batch), steps):
|
||||
# Get the estimated value at the last state in a sequence
|
||||
if self.target_net is not None:
|
||||
expected_nth_values = self.target_net(state_batch[i + steps])
|
||||
best_nth_action = self.net(state_batch[i + steps]).argmax(1)
|
||||
else:
|
||||
expected_nth_values = self.net(state_batch[i + steps])
|
||||
best_nth_action = expected_nth_values.argmax(1)
|
||||
best_expected_nth_value = expected_nth_values[best_nth_action].squeeze(1)
|
||||
# Calculate the value leading up to it by taking the rewards and multiplying it by the discount rate
|
||||
received_n_value = 0
|
||||
for j in range(steps):
|
||||
received_n_value += self.config['discount_rate']**j * reward_batch[j]
|
||||
# Q(s, a) = r_0 + lambda_1 * r_1 + lambda_2^2 * r_2 + ... + lambda_{steps}^{steps} * max_{a}(Q(s + steps, a))
|
||||
expected_n_step_values.append(received_n_value + self.config['discount_rate']**steps * best_expected_nth_value)
|
||||
expected_n_step_values = torch.stack(expected_n_step_values)
|
||||
# Gather the value the current network thinks it should be
|
||||
observed_n_step_values = []
|
||||
for i in range(0, len(state_batch), steps):
|
||||
observed_nth_value = self.net(state_batch[i])[action_batch[i]]
|
||||
observed_n_step_values.append(observed_nth_value)
|
||||
observed_n_step_values = torch.stack(observed_n_step_values)
|
||||
|
||||
|
||||
|
||||
# Demonstration loss
|
||||
if demo_mask.sum() > 0:
|
||||
l = torch.ones_like(state_values[demo_mask])
|
||||
expert_actions = action_batch[demo_mask]
|
||||
# l(s, a) is zero for every action the expert doesn't take
|
||||
for i,a in zip(range(len(l)), expert_actions):
|
||||
l[i].fill_(0.8) # According to paper
|
||||
l[i, a] = 0
|
||||
if self.target_net is not None:
|
||||
expert_value = self.target_net(state_batch[demo_mask])
|
||||
else:
|
||||
expert_value = state_values[demo_mask]
|
||||
expert_value = expert_value.gather(1, expert_actions.view(demo_mask.sum(), 1))
|
||||
|
||||
# Iterate through hyperparamters
|
||||
if isinstance(self.config['dqfd_demo_loss_weight'], collections.Iterable):
|
||||
demo_importance = next(self.config['dqfd_demo_loss_weight'])
|
||||
|
@ -84,9 +123,14 @@ class DQfDAgent:
|
|||
else:
|
||||
td_importance = self.config['dqfd_td_loss_weight']
|
||||
|
||||
|
||||
# Since dqn_loss and demo_loss are different sizes, the reduction has to happen before they are combined
|
||||
dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none')).mean()
|
||||
demo_loss = (torch.as_tensor(importance_weights[demo_indexes], device = self.net.device) * F.mse_loss((state_values[demo_indexes] + l).max(1)[0], expert_value, reduction = 'none')).mean()
|
||||
dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean()
|
||||
dqn_n_step_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none').squeeze(1)).mean()
|
||||
if demo_mask.sum() > 0:
|
||||
demo_loss = (torch.as_tensor(importance_weights, device = self.net.device)[demo_mask] * F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction = 'none').squeeze(1)).mean()
|
||||
else:
|
||||
demo_loss = 0
|
||||
loss = td_importance * dqn_loss + demo_importance * demo_loss
|
||||
|
||||
if self.logger is not None:
|
||||
|
@ -106,6 +150,9 @@ class DQfDAgent:
|
|||
# If we're sampling by TD error, readjust the weights of the experiences
|
||||
# TODO: Can probably adjust demonstration priority here
|
||||
td_error = (obtained_values - expected_values).detach().abs()
|
||||
td_error[demo_mask] = td_error[demo_mask] + self.config['demo_prio_bonus']
|
||||
observed_mask = batch_index_tensors >= self.memory.demo_position
|
||||
td_error[observed_mask] = td_error[observed_mask] + self.config['observed_prio_bonus']
|
||||
self.memory.update_priorities(batch_indexes, td_error)
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from .PrioritizedReplayMemory import PrioritizedReplayMemory
|
||||
from collections import namedtuple
|
||||
|
||||
Transition = namedtuple('Transition',
|
||||
('state', 'action', 'reward', 'next_state', 'done'))
|
||||
|
@ -11,9 +12,9 @@ class DQfDMemory(PrioritizedReplayMemory):
|
|||
self.obtained_transitions_length = 0
|
||||
|
||||
def append(self, *args, **kwargs):
|
||||
super().append(self, *args, **kwargs)
|
||||
super().append(*args, **kwargs)
|
||||
# Don't overwrite demonstration data
|
||||
self.position = self.demo_position + ((self.position + 1) % (self.capacity - self.demo_position))
|
||||
self.position = self.demo_position + ((self.position + 1) % (len(self.memory) - self.demo_position))
|
||||
|
||||
def append_demonstration(self, *args):
|
||||
demonstrations = self.memory[:self.demo_position]
|
||||
|
@ -25,4 +26,4 @@ class DQfDMemory(PrioritizedReplayMemory):
|
|||
if len(demonstrations) + len(obtained_transitions) + 1 > self.capacity:
|
||||
obtained_transitions = obtained_transitions[:(self.capacity - len(demonstrations) - 1)]
|
||||
self.memory = demonstrations + [Transition(*args)] + obtained_transitions
|
||||
self.demo_position += 1
|
||||
self.demo_position += 1
|
||||
|
|
|
@ -105,7 +105,7 @@ class SumSegmentTree(SegmentTree):
|
|||
"""Returns arr[start] + ... + arr[end]"""
|
||||
return super(SumSegmentTree, self).reduce(start, end)
|
||||
|
||||
@jit(forceobj = True)
|
||||
@jit(forceobj = True, parallel = True)
|
||||
def find_prefixsum_idx(self, prefixsum):
|
||||
"""Find the highest index `i` in the array such that
|
||||
sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum
|
||||
|
@ -204,17 +204,6 @@ class PrioritizedReplayMemory(ReplayMemory):
|
|||
(0 - no corrections, 1 - full correction)
|
||||
Returns
|
||||
-------
|
||||
obs_batch: np.array
|
||||
batch of observations
|
||||
act_batch: np.array
|
||||
batch of actions executed given obs_batch
|
||||
rew_batch: np.array
|
||||
rewards received as results of executing act_batch
|
||||
next_obs_batch: np.array
|
||||
next set of observations seen after executing act_batch
|
||||
done_mask: np.array
|
||||
done_mask[i] = 1 if executing act_batch[i] resulted in
|
||||
the end of an episode and 0 otherwise.
|
||||
weights: np.array
|
||||
Array of shape (batch_size,) and dtype np.float32
|
||||
denoting importance weight of each sampled transition
|
||||
|
@ -224,21 +213,55 @@ class PrioritizedReplayMemory(ReplayMemory):
|
|||
"""
|
||||
assert beta > 0
|
||||
|
||||
# Sample indexes
|
||||
idxes = self._sample_proportional(batch_size)
|
||||
|
||||
# Calculate weights
|
||||
weights = []
|
||||
p_min = self._it_min.min() / self._it_sum.sum()
|
||||
max_weight = (p_min * len(self.memory)) ** (-beta)
|
||||
|
||||
for idx in idxes:
|
||||
p_sample = self._it_sum[idx] / self._it_sum.sum()
|
||||
weight = (p_sample * len(self.memory)) ** (-beta)
|
||||
weights.append(weight / max_weight)
|
||||
weights = np.array(weights)
|
||||
|
||||
# Combine all data into a batch
|
||||
encoded_sample = tuple(zip(*self._encode_sample(idxes)))
|
||||
batch = list(zip(*encoded_sample, weights, idxes))
|
||||
return batch
|
||||
|
||||
def sample_n_steps(self, batch_size, steps, beta):
|
||||
assert beta > 0
|
||||
|
||||
memory = self.memory
|
||||
self.memory = self.memory[:-steps]
|
||||
sample_size = batch_size // steps
|
||||
|
||||
# Sample indexes and get n-steps after that
|
||||
idxes = self._sample_proportional(sample_size)
|
||||
step_idxes = []
|
||||
for i in idxes:
|
||||
step_idxes += range(i, i + steps)
|
||||
|
||||
# Calculate appropriate weights
|
||||
weights = []
|
||||
p_min = self._it_min.min() / self._it_sum.sum()
|
||||
max_weight = (p_min * len(self.memory)) ** (-beta)
|
||||
for idx in step_idxes:
|
||||
p_sample = self._it_sum[idx] / self._it_sum.sum()
|
||||
weight = (p_sample * len(self.memory)) ** (-beta)
|
||||
weights.append(weight / max_weight)
|
||||
weights = np.array(weights)
|
||||
|
||||
# Combine all the data together into a batch
|
||||
encoded_sample = tuple(zip(*self._encode_sample(step_idxes)))
|
||||
batch = list(zip(*encoded_sample, weights, idxes))
|
||||
|
||||
# Restore memory and return batch
|
||||
self.memory = memory
|
||||
return batch
|
||||
|
||||
@jit(forceobj = True)
|
||||
def update_priorities(self, idxes, priorities):
|
||||
"""Update priorities of sampled transitions.
|
||||
|
|
|
@ -38,6 +38,13 @@ class ReplayMemory(object):
|
|||
|
||||
def sample(self, batch_size):
|
||||
return random.sample(self.memory, batch_size)
|
||||
|
||||
def sample_n_steps(self, batch_size, steps):
|
||||
idxes = random.sample(range(len(self.memory) - steps), batch_size // steps)
|
||||
step_idxes = []
|
||||
for i in idxes:
|
||||
step_idxes += range(i, i + steps)
|
||||
return self._encode_sample(step_idxes)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.memory)
|
||||
|
|
Loading…
Reference in a new issue