Implemented components necessary for Deep Q Learning from Demonstrations

parent 17391c7467
commit ad75539776

3 changed files with 53 additions and 31 deletions

@@ -8,10 +8,8 @@ from pathlib import Path
from rltorch.action_selector import ArgMaxSelector

class DQfDAgent:
    def __init__(self, net, imitation_net, memory, config, target_net = None, logger = None):
    def __init__(self, net, memory, config, target_net = None, logger = None):
        self.net = net
        self.imitation_net = imitation_net
        self.imitator = ArgMaxSelector(imitation_net, self.imitation_net.model.action_size, device = imitation_net.device)
        self.target_net = target_net
        self.memory = memory
        self.config = deepcopy(config)
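
The constructor change above drops the imitation network: demonstrations are now expected to arrive through the replay memory (see the DQfDMemory file added below) rather than through a separately trained imitator. A minimal wiring sketch under that assumption; net, config, target_net, and logger stand for whatever the surrounding training script already builds, the capacity and alpha values are arbitrary placeholders, and only the argument order comes from the new signature:

# Hypothetical wiring; everything except the DQfDAgent/DQfDMemory names is a placeholder.
import rltorch.memory as M

memory = M.DQfDMemory(capacity=10_000, alpha=0.6)   # pre-filled with demonstrations before training
agent = DQfDAgent(net, memory, config, target_net=target_net, logger=logger)
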
@@ -27,15 +25,13 @@ class DQfDAgent:
        if len(self.memory) < self.config['batch_size']:
            return

        if (isinstance(self.memory, M.PrioritizedReplayMemory)):
            weight_importance = self.config['prioritized_replay_weight_importance']
            # If it's a scheduler, get the next value by calling next; otherwise just use its value
            beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
            minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
            state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
        else:
            minibatch = self.memory.sample(self.config['batch_size'])
            state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)

        demo_indexes = batch_indexes < self.memory.demo_position

        # Send to their appropriate devices
        state_batch = state_batch.to(self.net.device)
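
The beta passed to sample() is the importance-sampling exponent for prioritized replay, and the code accepts either a plain number or any iterable it can call next() on once per learn() step. A minimal sketch of such a schedule, assuming nothing about rltorch's own scheduler classes (linear_beta_schedule below is illustrative, not part of the library):

import collections.abc

def linear_beta_schedule(start=0.4, end=1.0, steps=100_000):
    # Anneal beta linearly from `start` to `end`, then hold at `end` forever.
    for i in range(steps):
        yield start + (end - start) * (i / steps)
    while True:
        yield end

beta_schedule = linear_beta_schedule()
assert isinstance(beta_schedule, collections.abc.Iterable)
# config['prioritized_replay_weight_importance'] = beta_schedule
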
@@ -65,32 +61,33 @@ class DQfDAgent:

        expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)

        # If we're sampling by TD error, multiply the loss by an importance weight, which helps decrease overfitting
        if (isinstance(self.memory, M.PrioritizedReplayMemory)):
            # dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.smooth_l1_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean()
            dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
        else:
            # dqn_loss = F.smooth_l1_loss(obtained_values, expected_values)
            dqn_loss = F.mse_loss(obtained_values, expected_values)

        # Demonstration loss
        l = torch.ones_like(state_values)
        expert_actions = self.imitation_net(state_batch).argmax(1)
        l = torch.ones_like(state_values[demo_indexes])
        expert_actions = action_batch[demo_indexes]
        # l(s, a) is zero for the action the expert took and 0.8 for every other action
        for i,a in zip(range(len(state_values)), expert_actions):
        for i,a in zip(range(len(l)), expert_actions):
            l[i].fill_(0.8) # According to paper
            l[i, a] = 0
        if self.target_net is not None:
            expert_value = self.target_net(state_batch)
            expert_value = self.target_net(state_batch[demo_indexes])
        else:
            expert_value = self.net(state_batch)
            expert_value = state_values[demo_indexes]
        expert_value = expert_value.gather(1, expert_actions.view((self.config['batch_size'], 1))).squeeze(1)

        # Iterate through hyperparameters
        if isinstance(self.config['dqfd_demo_loss_weight'], collections.Iterable):
            demo_importance = next(self.config['dqfd_demo_loss_weight'])
        else:
            demo_importance = self.config['dqfd_demo_loss_weight']
        demo_loss = F.mse_loss((state_values + l).max(1)[0], expert_value)
        loss = ((1 - demo_importance) * dqn_loss + demo_importance * demo_loss) / (dqn_loss + demo_loss)
        if isinstance(self.config['dqfd_td_loss_weight'], collections.Iterable):
            td_importance = next(self.config['dqfd_td_loss_weight'])
        else:
            td_importance = self.config['dqfd_td_loss_weight']

        # Since dqn_loss and demo_loss are different sizes, the reduction has to happen before they are combined
        dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none')).mean()
        demo_loss = (torch.as_tensor(importance_weights[demo_indexes], device = self.net.device) * F.mse_loss((state_values[demo_indexes] + l).max(1)[0], expert_value, reduction = 'none')).mean()
        loss = td_importance * dqn_loss + demo_importance * demo_loss

        if self.logger is not None:
            self.logger.append("Loss", loss.item())
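
The demonstration term above is the large-margin supervised loss from the DQfD paper: for demonstration transitions, the maximum over actions of [Q(s, a) + l(a_E, a)] is pulled toward Q(s, a_E), where the margin l is 0.8 for every non-expert action and 0 for the expert's own action. This commit combines the two terms with an MSE; the paper's original formulation is the plain difference, shown in the standalone sketch below (tensor names are illustrative only, not tied to the agent's batching):

import torch

def large_margin_loss(q_values, expert_actions, margin=0.8):
    # q_values: (batch, num_actions) Q(s, .) for demonstration states
    # expert_actions: (batch,) action the demonstrator took in each state
    margins = torch.full_like(q_values, margin)
    margins.scatter_(1, expert_actions.unsqueeze(1), 0.0)            # zero margin at the expert action
    expert_q = q_values.gather(1, expert_actions.unsqueeze(1)).squeeze(1)
    return ((q_values + margins).max(1)[0] - expert_q).mean()        # J_E(Q) from the paper

q = torch.randn(4, 6)                  # 4 demonstration states, 6 actions
a_e = torch.tensor([0, 3, 2, 5])
print(large_margin_loss(q, a_e))
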
@@ -107,7 +104,7 @@ class DQfDAgent:
            self.target_net.sync()

        # If we're sampling by TD error, readjust the weights of the experiences
        if (isinstance(self.memory, M.PrioritizedReplayMemory)):
            # TODO: Can probably adjust demonstration priority here
            td_error = (obtained_values - expected_values).detach().abs()
            self.memory.update_priorities(batch_indexes, td_error)
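
Gathering the hyperparameters this agent reads from config: batch_size, discount_rate, prioritized_replay_weight_importance, dqfd_td_loss_weight, and dqfd_demo_loss_weight. A hypothetical configuration sketch with placeholder values; only the key names come from the code above, and any scalar entry can be swapped for an iterable schedule consumed with next():

config = {
    'batch_size': 32,
    'discount_rate': 0.99,
    # scalar, or an iterable advanced once per learn() call
    'prioritized_replay_weight_importance': 0.4,
    # relative weights of the TD loss and the demonstration (margin) loss
    'dqfd_td_loss_weight': 1.0,
    'dqfd_demo_loss_weight': 0.01,
}
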

rltorch/memory/DQfDMemory.py (new file, 24 additions)

@@ -0,0 +1,24 @@
from .PrioritizedReplayMemory import PrioritizedReplayMemory, Transition

class DQfDMemory(PrioritizedReplayMemory):
    def __init__(self, capacity, alpha):
        super().__init__(capacity, alpha)
        self.demo_position = 0
        self.obtained_transitions_length = 0

    def append(self, *args, **kwargs):
        super().append(*args, **kwargs)
        # Don't overwrite demonstration data
        self.position = self.demo_position + ((self.position + 1) % (self.capacity - self.demo_position))

    def append_demonstration(self, *args):
        demonstrations = self.memory[:self.demo_position]
        obtained_transitions = self.memory[self.demo_position:]
        if len(demonstrations) + 1 > self.capacity:
            self.memory.pop(0)
            self.memory.append(Transition(*args))
        else:
            if len(demonstrations) + len(obtained_transitions) + 1 > self.capacity:
                obtained_transitions = obtained_transitions[:(self.capacity - len(demonstrations) - 1)]
            self.memory = demonstrations + [Transition(*args)] + obtained_transitions
            self.demo_position += 1
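
A short usage sketch of the new memory, assuming the (state, action, reward, next_state, done) transition layout implied by zip_batch in the agent above; the data objects are placeholders:

import rltorch.memory as M

memory = M.DQfDMemory(capacity=10_000, alpha=0.6)

# Expert data goes in first so it occupies the protected [0, demo_position) region.
for (state, action, reward, next_state, done) in expert_transitions:   # expert_transitions: placeholder iterable
    memory.append_demonstration(state, action, reward, next_state, done)

# Regular environment transitions afterwards; append() wraps around without overwriting demonstrations.
memory.append(obs, act, rew, next_obs, done)                            # placeholder environment data
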

rltorch/memory/__init__.py

@@ -1,3 +1,4 @@
from .EpisodeMemory import *
from .ReplayMemory import *
from .PrioritizedReplayMemory import *
from .DQfDMemory import *