diff --git a/rltorch/agents/DQfDAgent.py b/rltorch/agents/DQfDAgent.py
index 99af560..43f2cc9 100644
--- a/rltorch/agents/DQfDAgent.py
+++ b/rltorch/agents/DQfDAgent.py
@@ -32,18 +32,29 @@ class DQfDAgent:
         batch_size = self.config['batch_size']
         steps = None
 
-        weight_importance = self.config['prioritized_replay_weight_importance']
-        # If it's a scheduler then get the next value by calling next, otherwise just use it's value
-        beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
-
-        # Check to see if we are doing N-Step DQN
-        if steps is not None:
-            minibatch = self.memory.sample_n_steps(batch_size, steps, beta)
-        else:
-            minibatch = self.memory.sample(batch_size, beta = beta)
+        if isinstance(self.memory, M.DQfDMemory):
+            weight_importance = self.config['prioritized_replay_weight_importance']
+            # If it's a scheduler then get the next value by calling next, otherwise just use it's value
+            beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
 
-        # Process batch
-        state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
+            # Check to see if we are doing N-Step DQN
+            if steps is not None:
+                minibatch = self.memory.sample_n_steps(batch_size, steps, beta)
+            else:
+                minibatch = self.memory.sample(batch_size, beta = beta)
+
+            # Process batch
+            state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
+
+        else:
+            # Check to see if we're doing N-Step DQN
+            if steps is not None:
+                minibatch = self.memory.sample_n_steps(batch_size, steps)
+            else:
+                minibatch = self.memory.sample(batch_size)
+
+            # Process batch
+            state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, batch_indexes = M.zip_batch(minibatch, want_indices = True)
 
         batch_index_tensors = torch.tensor(batch_indexes)
         demo_mask = batch_index_tensors < self.memory.demo_position
@@ -75,7 +86,7 @@ class DQfDAgent:
         best_next_state_value = torch.zeros(batch_size, device = self.net.device)
         best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
 
-        expected_values = (reward_batch + (batch_size * best_next_state_value)).unsqueeze(1)
+        expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
 
         # N-Step DQN Loss
         # num_steps capture how many steps actually exist before the end of episode
@@ -137,17 +148,26 @@ class DQfDAgent:
 
         # Since dqn_loss and demo_loss are different sizes, the reduction has to happen before they are combined
-        dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean()
+        if isinstance(self.memory, M.DQfDMemory):
+            dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean()
+        else:
+            dqn_loss = F.mse_loss(obtained_values, expected_values)
 
         if steps != None:
-            dqn_n_step_loss = (torch.as_tensor(importance_weights[::steps], device = self.net.device) * F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none')).mean()
+            if isinstance(self.memory, M.DQfDMemory):
+                dqn_n_step_loss = (torch.as_tensor(importance_weights[::steps], device = self.net.device) * F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none')).mean()
+            else:
+                dqn_n_step_loss = F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none').mean()
         else:
             dqn_n_step_loss = torch.tensor(0, device = self.net.device)
 
         if demo_mask.sum() > 0:
-            demo_loss = (torch.as_tensor(importance_weights, device = self.net.device)[demo_mask] * F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction = 'none').squeeze(1)).mean()
+            if isinstance(self.memory, M.DQfDMemory):
+                demo_loss = (torch.as_tensor(importance_weights, device = self.net.device)[demo_mask] * F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction = 'none').squeeze(1)).mean()
+            else:
+                demo_loss = F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction = 'none').squeeze(1).mean()
         else:
-            demo_loss = 0
+            demo_loss = 0.
         loss = td_importance * dqn_loss + td_importance * dqn_n_step_loss + demo_importance * demo_loss
 
         if self.logger is not None:
@@ -165,11 +185,11 @@ class DQfDAgent:
             self.target_net.sync()
 
         # If we're sampling by TD error, readjust the weights of the experiences
-        # TODO: Can probably adjust demonstration priority here
-        td_error = (obtained_values - expected_values).detach().abs()
-        td_error[demo_mask] = td_error[demo_mask] + self.config['demo_prio_bonus']
-        observed_mask = batch_index_tensors >= self.memory.demo_position
-        td_error[observed_mask] = td_error[observed_mask] + self.config['observed_prio_bonus']
-        self.memory.update_priorities(batch_indexes, td_error)
+        if isinstance(self.memory, M.DQfDMemory):
+            td_error = (obtained_values - expected_values).detach().abs()
+            td_error[demo_mask] = td_error[demo_mask] + self.config['demo_prio_bonus']
+            observed_mask = batch_index_tensors >= self.memory.demo_position
+            td_error[observed_mask] = td_error[observed_mask] + self.config['observed_prio_bonus']
+            self.memory.update_priorities(batch_indexes, td_error)
diff --git a/rltorch/agents/QEPAgent.py b/rltorch/agents/QEPAgent.py
index a1ae0dd..9dd0fd9 100644
--- a/rltorch/agents/QEPAgent.py
+++ b/rltorch/agents/QEPAgent.py
@@ -11,7 +11,7 @@ import torch.nn.functional as F
 # Maximizes the policy with respect to the Q-Value function.
 # Since function is non-differentiabile, depends on the Evolutionary Strategy algorithm
 class QEPAgent:
-    def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None, entropy_importance = 0):
+    def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None, entropy_importance = 0, policy_skip = 4, after_value_train = None):
         self.policy_net = policy_net
         assert isinstance(self.policy_net, rltorch.network.ESNetwork) or isinstance(self.policy_net, rltorch.network.ESNetworkMP)
         self.policy_net.fitness = self.fitness
@@ -20,8 +20,9 @@ class QEPAgent:
         self.memory = memory
         self.config = deepcopy(config)
         self.logger = logger
-        self.policy_skip = 4
+        self.policy_skip = policy_skip
         self.entropy_importance = entropy_importance
+        self.after_value_train = after_value_train
 
     def save(self, file_location):
         torch.save({
@@ -41,27 +42,29 @@ class QEPAgent:
         batch_size = len(state_batch)
         with torch.no_grad():
             action_probabilities = policy_net(state_batch)
+        action_size = action_probabilities.shape[1]
         distributions = list(map(Categorical, action_probabilities))
+        actions = torch.stack([d.sample() for d in distributions])
 
         with torch.no_grad():
             state_values = value_net(state_batch)
-
+
         # Weird hacky solution where in multiprocess, it sometimes spits out nans
         # So have it try again
         while torch.isnan(state_values).any():
+            print("NAN DETECTED")
             with torch.no_grad():
                 state_values = value_net(state_batch)
 
-        obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
-        # return -obtained_values.mean().item()
-        entropy_importance = 0 # Entropy accounting for 1% of loss seems to work well
+        obtained_values = state_values.gather(1, actions.view(batch_size, 1)).squeeze(1)
+        entropy_importance = next(self.entropy_importance) if isinstance(self.entropy_importance, collections.Iterable) else self.entropy_importance
         value_importance = 1 - entropy_importance
 
         # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
-        entropy_loss = (action_probabilities - torch.tensor(1 / action_size, device = state_batch.device).repeat(len(state_batch), action_size)).abs().sum(1)
+        entropy_loss = (action_probabilities - torch.tensor(1 / action_size, device = state_batch.device).repeat(batch_size, action_size)).abs().sum(1)
 
         return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()
@@ -121,6 +124,9 @@ class QEPAgent:
             self.value_net.clamp_gradients()
         self.value_net.step()
 
+        if callable(self.after_value_train):
+            self.after_value_train()
+
         if self.target_value_net is not None:
             if 'target_sync_tau' in self.config:
                 self.target_value_net.partial_sync(self.config['target_sync_tau'])
@@ -135,8 +141,7 @@ class QEPAgent:
         if self.policy_skip > 0:
             self.policy_skip -= 1
             return
-        self.policy_skip = 4
-
+        self.policy_skip = self.config['policy_skip']
         if self.target_value_net is not None:
             self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
diff --git a/rltorch/memory/DQfDMemory.py b/rltorch/memory/DQfDMemory.py
index af09e68..3a8341b 100644
--- a/rltorch/memory/DQfDMemory.py
+++ b/rltorch/memory/DQfDMemory.py
@@ -17,8 +17,8 @@ class DQfDMemory(PrioritizedReplayMemory):
         last_position = self.position # Get position before super classes change it
         super().append(*args, **kwargs)
         # Don't overwrite demonstration data
-        new_position = ((last_position + 1) % (self.capacity - self.demo_position + 1))
-        self.position = new_position if new_position > self.demo_position else self.demo_position + new_position
+        new_position = ((last_position - self.demo_position + 1) % (self.capacity - self.demo_position))
+        self.position = self.demo_position + new_position
 
     def append_demonstration(self, *args):
         demonstrations = self.memory[:self.demo_position]
diff --git a/rltorch/memory/ReplayMemory.py b/rltorch/memory/ReplayMemory.py
index aa32ab7..a11229b 100644
--- a/rltorch/memory/ReplayMemory.py
+++ b/rltorch/memory/ReplayMemory.py
@@ -64,9 +64,11 @@ class ReplayMemory(object):
     def __reversed__(self):
         return reversed(self.memory)
 
-def zip_batch(minibatch, priority = False):
+def zip_batch(minibatch, priority = False, want_indices = False):
     if priority:
         state_batch, action_batch, reward_batch, next_state_batch, done_batch, weights, indexes = zip(*minibatch)
+    elif want_indices:
+        state_batch, action_batch, reward_batch, next_state_batch, done_batch, indexes = zip(*minibatch)
     else:
         state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*minibatch)
 
@@ -78,5 +80,7 @@
 
     if priority:
         return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, weights, indexes
+    elif want_indices:
+        return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, indexes
     else:
         return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch
\ No newline at end of file
diff --git a/rltorch/memory/__init__.py b/rltorch/memory/__init__.py
index eb9932c..05312d9 100644
--- a/rltorch/memory/__init__.py
+++ b/rltorch/memory/__init__.py
@@ -1,4 +1,5 @@
 from .EpisodeMemory import *
 from .ReplayMemory import *
 from .PrioritizedReplayMemory import *
-from .DQfDMemory import *
\ No newline at end of file
+from .DQfDMemory import *
+from .iDQfDMemory import *