Sends importance weights to the correct device for prioritized replay
parent 013d40a4f9
commit 04e54cddc2
1 changed file with 1 addition and 1 deletion
@@ -54,7 +54,7 @@ class DQNAgent:
         expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)

         if (isinstance(self.memory, M.PrioritizedReplayMemory)):
-            loss = (torch.as_tensor(importance_weights) * (obtained_values - expected_values)**2).mean()
+            loss = (torch.as_tensor(importance_weights, device = self.net.device) * (obtained_values - expected_values)**2).mean()
         else:
             loss = F.mse_loss(obtained_values, expected_values)

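A minimal standalone sketch (not the repository's code) of why the fix is needed: importance weights typically come back from a prioritized replay buffer as a CPU-side NumPy array or list, while the Q-value tensors live on whatever device the network runs on. torch.as_tensor() without a device argument keeps the weights on the CPU, so multiplying them against GPU tensors raises a RuntimeError; passing the network's device, as the commit does, puts both operands on the same device. The names below (device, weights, q_values, targets) are illustrative stand-ins, not identifiers from the repository.

import torch

# Pick whatever device the network would be on.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-ins for obtained_values and expected_values, already on the device.
q_values = torch.rand(32, device=device)
targets = torch.rand(32, device=device)

# Importance weights as they arrive from the buffer: a plain CPU-side list.
weights = [0.5] * 32

# Before the fix: torch.as_tensor(weights) stays on the CPU, and the
# multiplication below would fail whenever q_values is on a GPU.
# After the fix: pass the device explicitly so both operands match.
w = torch.as_tensor(weights, device=device)

# Importance-weighted squared TD error, mirroring the weighted loss in the diff.
loss = (w * (q_values - targets) ** 2).mean()
print(loss.item())

Using torch.as_tensor with an explicit device keeps the change to a single argument; the alternative of calling .to(self.net.device) on the resulting tensor would work equally well but adds a second copy step when the input is already a tensor on the wrong device.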