Fixed another mismatched device error

This commit is contained in:
Brandon Rozek 2019-02-03 00:53:13 -05:00
parent e62385b574
commit 0a6f1e73f3

View file

@ -39,7 +39,7 @@ class DQNAgent:
next_state_values = self.net(next_state_batch)
next_best_action = next_state_values.argmax(1)
best_next_state_value = torch.zeros(self.config['batch_size'])
best_next_state_value = torch.zeros(self.config['batch_size'], device = self.net.device)
best_next_state_value[not_done_batch] = next_state_values.gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)