diff --git a/rltorch/agents/DQfDAgent.py b/rltorch/agents/DQfDAgent.py index 45a84fc..f81a5ef 100644 --- a/rltorch/agents/DQfDAgent.py +++ b/rltorch/agents/DQfDAgent.py @@ -8,10 +8,8 @@ from pathlib import Path from rltorch.action_selector import ArgMaxSelector class DQfDAgent: - def __init__(self, net, imitation_net, memory, config, target_net = None, logger = None): + def __init__(self, net, memory, config, target_net = None, logger = None): self.net = net - self.imitation_net = imitation_net - self.imitator = ArgMaxSelector(imitation_net, self.imitation_net.model.action_size, device = imitation_net.device) self.target_net = target_net self.memory = memory self.config = deepcopy(config) @@ -27,15 +25,13 @@ class DQfDAgent: if len(self.memory) < self.config['batch_size']: return - if (isinstance(self.memory, M.PrioritizedReplayMemory)): - weight_importance = self.config['prioritized_replay_weight_importance'] - # If it's a scheduler then get the next value by calling next, otherwise just use it's value - beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance - minibatch = self.memory.sample(self.config['batch_size'], beta = beta) - state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True) - else: - minibatch = self.memory.sample(self.config['batch_size']) - state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch) + weight_importance = self.config['prioritized_replay_weight_importance'] + # If it's a scheduler then get the next value by calling next, otherwise just use it's value + beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance + minibatch = self.memory.sample(self.config['batch_size'], beta = beta) + state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True) + + demo_indexes = batch_indexes < self.memory.demo_position # Send to their appropriate devices state_batch = state_batch.to(self.net.device) @@ -65,32 +61,33 @@ class DQfDAgent: expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1) - # If we're sampling by TD error, multiply loss by a importance weight which helps decrease overfitting - if (isinstance(self.memory, M.PrioritizedReplayMemory)): - # dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.smooth_l1_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean() - dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean() - else: - # dqn_loss = F.smooth_l1_loss(obtained_values, expected_values) - dqn_loss = F.mse_loss(obtained_values, expected_values) - # Demonstration loss - l = torch.ones_like(state_values) - expert_actions = self.imitation_net(state_batch).argmax(1) + l = torch.ones_like(state_values[demo_indexes]) + expert_actions = action_batch[demo_indexes] # l(s, a) is zero for every action the expert doesn't take - for i,a in zip(range(len(state_values)), expert_actions): + for i,a in zip(range(len(l)), expert_actions): l[i].fill_(0.8) # According to paper l[i, a] = 0 if self.target_net is not None: - expert_value = self.target_net(state_batch) + expert_value = self.target_net(state_batch[demo_indexes]) else: - expert_value = self.net(state_batch) + expert_value = state_values[demo_indexes] expert_value = expert_value.gather(1, expert_actions.view((self.config['batch_size'], 1))).squeeze(1) + + # Iterate through hyperparamters if isinstance(self.config['dqfd_demo_loss_weight'], collections.Iterable): demo_importance = next(self.config['dqfd_demo_loss_weight']) else: demo_importance = self.config['dqfd_demo_loss_weight'] - demo_loss = F.mse_loss((state_values + l).max(1)[0], expert_value) - loss = ((1 - demo_importance) * dqn_loss + demo_importance * demo_loss) / (dqn_loss + demo_loss) + if isinstance(self.config['dqfd_td_loss_weight'], collections.Iterable): + td_importance = next(self.config['dqfd_td_loss_weight']) + else: + td_importance = self.config['dqfd_td_loss_weight'] + + # Since dqn_loss and demo_loss are different sizes, the reduction has to happen before they are combined + dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none')).mean() + demo_loss = (torch.as_tensor(importance_weights[demo_indexes], device = self.net.device) * F.mse_loss((state_values[demo_indexes] + l).max(1)[0], expert_value, reduction = 'none')).mean() + loss = td_importance * dqn_loss + demo_importance * demo_loss if self.logger is not None: self.logger.append("Loss", loss.item()) @@ -105,10 +102,10 @@ class DQfDAgent: self.target_net.partial_sync(self.config['target_sync_tau']) else: self.target_net.sync() - + # If we're sampling by TD error, readjust the weights of the experiences - if (isinstance(self.memory, M.PrioritizedReplayMemory)): - td_error = (obtained_values - expected_values).detach().abs() - self.memory.update_priorities(batch_indexes, td_error) + # TODO: Can probably adjust demonstration priority here + td_error = (obtained_values - expected_values).detach().abs() + self.memory.update_priorities(batch_indexes, td_error) diff --git a/rltorch/memory/DQfDMemory.py b/rltorch/memory/DQfDMemory.py new file mode 100644 index 0000000..73ce175 --- /dev/null +++ b/rltorch/memory/DQfDMemory.py @@ -0,0 +1,24 @@ +from .PrioritizedReplayMemory import PrioritizedReplayMemory, Transition + +class DQfDMemory(PrioritizedReplayMemory): + def __init__(self, capacity, alpha): + super().__init__(capacity, alpha) + self.demo_position = 0 + self.obtained_transitions_length = 0 + + def append(self, *args, **kwargs): + super().append(self, *args, **kwargs) + # Don't overwrite demonstration data + self.position = self.demo_position + ((self.position + 1) % (self.capacity - self.demo_position)) + + def append_demonstration(self, *args): + demonstrations = self.memory[:self.demo_position] + obtained_transitions = self.memory[self.demo_position:] + if len(demonstrations) + 1 > self.capacity: + self.memory.pop(0) + self.memory.append(Transition(*args)) + else: + if len(demonstrations) + len(obtained_transitions) + 1 > self.capacity: + obtained_transitions = obtained_transitions[:(self.capacity - len(demonstrations) - 1)] + self.memory = demonstrations + [Transition(*args)] + obtained_transitions + self.demo_position += 1 \ No newline at end of file diff --git a/rltorch/memory/__init__.py b/rltorch/memory/__init__.py index 17b803f..eb9932c 100644 --- a/rltorch/memory/__init__.py +++ b/rltorch/memory/__init__.py @@ -1,3 +1,4 @@ from .EpisodeMemory import * from .ReplayMemory import * from .PrioritizedReplayMemory import * +from .DQfDMemory import * \ No newline at end of file