Sends importance weights to the correct device for prioritized replay
parent 013d40a4f9
commit 04e54cddc2
1 changed file with 1 addition and 1 deletion
@@ -54,7 +54,7 @@ class DQNAgent:
         expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
 
         if (isinstance(self.memory, M.PrioritizedReplayMemory)):
-            loss = (torch.as_tensor(importance_weights) * (obtained_values - expected_values)**2).mean()
+            loss = (torch.as_tensor(importance_weights, device = self.net.device) * (obtained_values - expected_values)**2).mean()
         else:
             loss = F.mse_loss(obtained_values, expected_values)
 
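For context, the change avoids a device mismatch: importance-sampling weights drawn from a prioritized replay buffer typically arrive as a NumPy array on the CPU, while obtained_values and expected_values live on the network's device, so the weights must be moved before the elementwise multiply. Below is a minimal sketch of the weighted loss, not the repository's code; the function name prioritized_loss, the standalone device argument, and the example shapes are illustrative assumptions.

import numpy as np
import torch

def prioritized_loss(obtained_values, expected_values, importance_weights, device):
    # importance_weights usually comes from the replay sampler as a NumPy array;
    # move it onto the same device (and dtype) as the value tensors first.
    w = torch.as_tensor(importance_weights, dtype=obtained_values.dtype, device=device)
    # Per-sample squared TD error, scaled by the importance-sampling weights,
    # then averaged over the batch.
    return (w * (obtained_values - expected_values) ** 2).mean()

# Illustrative usage with assumed batch size and shapes:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
obtained = torch.randn(32, 1, device=device)
expected = torch.randn(32, 1, device=device)
weights = np.random.rand(32, 1).astype(np.float32)
loss = prioritized_loss(obtained, expected, weights, device)

Without the device argument, multiplying a CPU weight tensor by GPU value tensors raises a runtime error, which is what the one-line change above addresses.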