Sends importance weights to the correct device for prioritized replay

Brandon Rozek 2019-02-10 23:16:44 -05:00
parent 013d40a4f9
commit 04e54cddc2


@@ -54,7 +54,7 @@ class DQNAgent:
         expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
         if (isinstance(self.memory, M.PrioritizedReplayMemory)):
-            loss = (torch.as_tensor(importance_weights) * (obtained_values - expected_values)**2).mean()
+            loss = (torch.as_tensor(importance_weights, device = self.net.device) * (obtained_values - expected_values)**2).mean()
         else:
             loss = F.mse_loss(obtained_values, expected_values)
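
For context, a minimal runnable sketch of the device mismatch this one-line change fixes; the batch size, values, and variable names here are hypothetical, not the repository's code. torch.as_tensor() places a Python list or NumPy array on the CPU by default, so the importance weights sampled from prioritized replay must be created on the network's device before being multiplied with GPU-resident Q-value tensors.

import torch

# Minimal sketch of the fix; batch size and values are hypothetical.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

obtained_values = torch.rand(32, 1, device=device)  # Q(s, a) from the network
expected_values = torch.rand(32, 1, device=device)  # r + gamma * max Q(s', a')
importance_weights = [0.5] * 32                     # sampled from prioritized replay

# Without device=..., the weights tensor is created on the CPU and the
# multiplication below raises a device-mismatch error whenever the
# Q-value tensors live on the GPU.
weights = torch.as_tensor(importance_weights, device=device).unsqueeze(1)

# Weighted MSE: each squared TD error is scaled by its importance weight,
# correcting for the sampling bias introduced by prioritized replay.
loss = (weights * (obtained_values - expected_values) ** 2).mean()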