Better handling of demonstration data
commit 838062813a (parent 3217c76a79)
5 changed files with 65 additions and 35 deletions
@@ -32,6 +32,7 @@ class DQfDAgent:
        batch_size = self.config['batch_size']
        steps = None

+       if isinstance(self.memory, M.DQfDMemory):
            weight_importance = self.config['prioritized_replay_weight_importance']
            # If it's a scheduler then get the next value by calling next, otherwise just use its value
            beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
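The comment above refers to the scheduler convention this code uses: a hyperparameter such as prioritized_replay_weight_importance may be either a plain number or an iterable that yields a new value on every call to next(). A minimal sketch of that pattern follows; linear_schedule is a hypothetical stand-in rather than the library's own scheduler, and the check is spelled collections.abc.Iterable because the bare collections.Iterable alias used in the diff was removed in Python 3.10.

    import collections.abc

    def linear_schedule(start, end, steps):
        # Hypothetical scheduler: anneal from start to end, then hold the final value.
        for i in range(steps):
            yield start + (end - start) * i / (steps - 1)
        while True:
            yield end

    weight_importance = linear_schedule(0.4, 1.0, 10000)   # or simply a float such as 0.4
    beta = next(weight_importance) if isinstance(weight_importance, collections.abc.Iterable) else weight_importance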
@@ -45,6 +46,16 @@ class DQfDAgent:
            # Process batch
            state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)

+       else:
+           # Check to see if we're doing N-Step DQN
+           if steps is not None:
+               minibatch = self.memory.sample_n_steps(batch_size, steps)
+           else:
+               minibatch = self.memory.sample(batch_size)
+
+           # Process batch
+           state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, batch_indexes = M.zip_batch(minibatch, want_indices = True)
+
        batch_index_tensors = torch.tensor(batch_indexes)
        demo_mask = batch_index_tensors < self.memory.demo_position
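For reference, demo_mask flags which sampled transitions came from the demonstration region at the front of a DQfDMemory buffer: any batch index below demo_position is demonstration data. A minimal sketch with made-up numbers:

    import torch

    demo_position = 3                        # hypothetical: first 3 slots hold demonstrations
    batch_indexes = [0, 5, 2, 7]             # indices returned by the sampler (made up)
    batch_index_tensors = torch.tensor(batch_indexes)
    demo_mask = batch_index_tensors < demo_position
    print(demo_mask)                         # tensor([ True, False,  True, False])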
@@ -75,7 +86,7 @@ class DQfDAgent:
        best_next_state_value = torch.zeros(batch_size, device = self.net.device)
        best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)

-       expected_values = (reward_batch + (batch_size * best_next_state_value)).unsqueeze(1)
+       expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)

        # N-Step DQN Loss
        # num_steps captures how many steps actually exist before the end of the episode
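The one-line change above fixes the TD target: the bootstrapped next-state value should be scaled by the discount rate, not by the batch size. As a quick numeric check (values made up):

    discount_rate, batch_size = 0.99, 32
    reward, best_next_state_value = 1.0, 2.0

    expected_value = reward + discount_rate * best_next_state_value   # 2.98 -- the intended Q-learning target
    wrong_value = reward + batch_size * best_next_state_value         # 65.0 -- what the old line computed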
@@ -137,17 +148,26 @@ class DQfDAgent:

        # Since dqn_loss and demo_loss are different sizes, the reduction has to happen before they are combined
+       if isinstance(self.memory, M.DQfDMemory):
            dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean()
+       else:
+           dqn_loss = F.mse_loss(obtained_values, expected_values)

        if steps != None:
+           if isinstance(self.memory, M.DQfDMemory):
                dqn_n_step_loss = (torch.as_tensor(importance_weights[::steps], device = self.net.device) * F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none')).mean()
+           else:
+               dqn_n_step_loss = F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none').mean()
        else:
            dqn_n_step_loss = torch.tensor(0, device = self.net.device)

        if demo_mask.sum() > 0:
+           if isinstance(self.memory, M.DQfDMemory):
                demo_loss = (torch.as_tensor(importance_weights, device = self.net.device)[demo_mask] * F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction = 'none').squeeze(1)).mean()
            else:
-               demo_loss = 0
+               demo_loss = F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction = 'none').squeeze(1).mean()
+       else:
+           demo_loss = 0.
        loss = td_importance * dqn_loss + td_importance * dqn_n_step_loss + demo_importance * demo_loss

        if self.logger is not None:
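As the comment notes, dqn_loss covers every row of the batch while demo_loss only covers the demonstration rows, so each term has to be reduced to a scalar before the weighted sum. A standalone sketch of that shape constraint (tensors and weights are made up):

    import torch
    import torch.nn.functional as F

    pred, target = torch.randn(8, 1), torch.randn(8, 1)
    demo_mask = torch.tensor([True, True, False, False, False, False, False, False])

    dqn_loss = F.mse_loss(pred, target)                         # scalar, averaged over all 8 rows
    demo_loss = F.mse_loss(pred[demo_mask], target[demo_mask])  # scalar, averaged over the 2 demo rows
    loss = 1.0 * dqn_loss + 0.5 * demo_loss                     # weights here are arbitrary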
@@ -165,7 +185,7 @@ class DQfDAgent:
            self.target_net.sync()

        # If we're sampling by TD error, readjust the weights of the experiences
-       # TODO: Can probably adjust demonstration priority here
+       if isinstance(self.memory, M.DQfDMemory):
            td_error = (obtained_values - expected_values).detach().abs()
            td_error[demo_mask] = td_error[demo_mask] + self.config['demo_prio_bonus']
            observed_mask = batch_index_tensors >= self.memory.demo_position
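After the update, absolute TD errors become the new priorities, and transitions that fall in the demonstration region receive a constant demo_prio_bonus so expert data keeps being replayed, in the spirit of the epsilon_d priority bonus from the DQfD paper. A small sketch with made-up values:

    import torch

    td_error = torch.tensor([0.10, 0.50, 0.05, 0.30])
    demo_mask = torch.tensor([True, False, True, False])
    demo_prio_bonus = 0.2                                  # hypothetical config value

    td_error[demo_mask] = td_error[demo_mask] + demo_prio_bonus
    # td_error is now tensor([0.3000, 0.5000, 0.2500, 0.3000])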
@@ -11,7 +11,7 @@ import torch.nn.functional as F
# Maximizes the policy with respect to the Q-Value function.
# Since the function is non-differentiable, this depends on the Evolutionary Strategy algorithm
class QEPAgent:
-   def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None, entropy_importance = 0):
+   def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None, entropy_importance = 0, policy_skip = 4, after_value_train = None):
        self.policy_net = policy_net
        assert isinstance(self.policy_net, rltorch.network.ESNetwork) or isinstance(self.policy_net, rltorch.network.ESNetworkMP)
        self.policy_net.fitness = self.fitness
@@ -20,8 +20,9 @@ class QEPAgent:
        self.memory = memory
        self.config = deepcopy(config)
        self.logger = logger
-       self.policy_skip = 4
+       self.policy_skip = policy_skip
        self.entropy_importance = entropy_importance
+       self.after_value_train = after_value_train

    def save(self, file_location):
        torch.save({
@@ -41,8 +42,10 @@ class QEPAgent:
        batch_size = len(state_batch)
        with torch.no_grad():
            action_probabilities = policy_net(state_batch)

        action_size = action_probabilities.shape[1]
        distributions = list(map(Categorical, action_probabilities))

        actions = torch.stack([d.sample() for d in distributions])

        with torch.no_grad():
@@ -51,17 +54,17 @@ class QEPAgent:
        # Weird hacky solution: in multiprocess mode the value net sometimes spits out NaNs,
        # so have it try again
        while torch.isnan(state_values).any():
+           print("NAN DETECTED")
            with torch.no_grad():
                state_values = value_net(state_batch)

-       obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
+       obtained_values = state_values.gather(1, actions.view(batch_size, 1)).squeeze(1)
-       # return -obtained_values.mean().item()
-       entropy_importance = 0 # Entropy accounting for 1% of loss seems to work well

        entropy_importance = next(self.entropy_importance) if isinstance(self.entropy_importance, collections.Iterable) else self.entropy_importance
        value_importance = 1 - entropy_importance

        # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
-       entropy_loss = (action_probabilities - torch.tensor(1 / action_size, device = state_batch.device).repeat(len(state_batch), action_size)).abs().sum(1)
+       entropy_loss = (action_probabilities - torch.tensor(1 / action_size, device = state_batch.device).repeat(batch_size, action_size)).abs().sum(1)

        return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()
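The fitness function keeps the standard Shannon-entropy term only as a comment and instead penalizes the L1 distance of each action distribution from the uniform distribution, which plays a similar regularizing role (the penalty is zero exactly when the policy is uniform). A standalone sketch with made-up probabilities:

    import torch

    action_probabilities = torch.tensor([[0.70, 0.20, 0.10],
                                         [0.34, 0.33, 0.33]])
    batch_size, action_size = action_probabilities.shape

    uniform = torch.full((batch_size, action_size), 1 / action_size)
    entropy_loss = (action_probabilities - uniform).abs().sum(1)
    # first row is far from uniform (larger penalty); second row is nearly uniform (penalty close to 0)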
@@ -121,6 +124,9 @@ class QEPAgent:
        self.value_net.clamp_gradients()
        self.value_net.step()

+       if callable(self.after_value_train):
+           self.after_value_train()
+
        if self.target_value_net is not None:
            if 'target_sync_tau' in self.config:
                self.target_value_net.partial_sync(self.config['target_sync_tau'])
@@ -135,8 +141,7 @@ class QEPAgent:
        if self.policy_skip > 0:
            self.policy_skip -= 1
            return
-       self.policy_skip = 4
+       self.policy_skip = self.config['policy_skip']

        if self.target_value_net is not None:
            self.policy_net.calc_gradients(self.target_value_net, state_batch)
        else:
@@ -17,8 +17,8 @@ class DQfDMemory(PrioritizedReplayMemory):
        last_position = self.position # Get position before super classes change it
        super().append(*args, **kwargs)
        # Don't overwrite demonstration data
-       new_position = ((last_position + 1) % (self.capacity - self.demo_position + 1))
-       self.position = new_position if new_position > self.demo_position else self.demo_position + new_position
+       new_position = ((last_position - self.demo_position + 1) % (self.capacity - self.demo_position))
+       self.position = self.demo_position + new_position

    def append_demonstration(self, *args):
        demonstrations = self.memory[:self.demo_position]
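The rewritten arithmetic keeps self.position cycling only through the slots after the demonstration block, so the demonstrations stored at the front of the buffer are never overwritten. A standalone model of the index math (capacities made up, not the class itself):

    capacity, demo_position = 10, 4     # hypothetical: 4 demo transitions, 10 slots total

    def next_position(last_position):
        new_position = (last_position - demo_position + 1) % (capacity - demo_position)
        return demo_position + new_position

    positions, p = [], demo_position
    for _ in range(8):
        positions.append(p)
        p = next_position(p)
    print(positions)   # [4, 5, 6, 7, 8, 9, 4, 5] -- wraps back to 4, never touching slots 0-3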
@@ -64,9 +64,11 @@ class ReplayMemory(object):
    def __reversed__(self):
        return reversed(self.memory)

-def zip_batch(minibatch, priority = False):
+def zip_batch(minibatch, priority = False, want_indices = False):
    if priority:
        state_batch, action_batch, reward_batch, next_state_batch, done_batch, weights, indexes = zip(*minibatch)
+   elif want_indices:
+       state_batch, action_batch, reward_batch, next_state_batch, done_batch, indexes = zip(*minibatch)
    else:
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*minibatch)
@@ -78,5 +80,7 @@ def zip_batch(minibatch, priority = False):

    if priority:
        return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, weights, indexes
+   elif want_indices:
+       return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, indexes
    else:
        return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch
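With want_indices, each stored transition is expected to carry its buffer index as a sixth element, and zip_batch splits it out alongside the usual fields. Only the unzipping step is illustrated below with a made-up two-transition minibatch; the real function goes on to derive not_done_batch and the other returned fields.

    # Each entry: (state, action, reward, next_state, done, index) -- contents made up
    minibatch = [
        ("s0", 0, 1.0, "s1", False, 7),
        ("s1", 1, 0.0, "s2", True, 2),
    ]
    state_batch, action_batch, reward_batch, next_state_batch, done_batch, indexes = zip(*minibatch)
    print(indexes)   # (7, 2)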
@@ -2,3 +2,4 @@ from .EpisodeMemory import *
from .ReplayMemory import *
from .PrioritizedReplayMemory import *
from .DQfDMemory import *
+from .iDQfDMemory import *