Fixed another mismatched device error
This commit is contained in:
parent
e62385b574
commit
0a6f1e73f3
1 changed files with 1 additions and 1 deletions
|
@ -39,7 +39,7 @@ class DQNAgent:
|
|||
next_state_values = self.net(next_state_batch)
|
||||
next_best_action = next_state_values.argmax(1)
|
||||
|
||||
best_next_state_value = torch.zeros(self.config['batch_size'])
|
||||
best_next_state_value = torch.zeros(self.config['batch_size'], device = self.net.device)
|
||||
best_next_state_value[not_done_batch] = next_state_values.gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
||||
|
||||
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
||||
|
|
Loading…
Reference in a new issue