From 03455accc8fadd7f349b747e4243905224ab11a6 Mon Sep 17 00:00:00 2001 From: Brandon Rozek Date: Sun, 3 Feb 2019 00:49:47 -0500 Subject: [PATCH] Attempting to see if this fixes the mismatched devices error --- rltorch/agents/DQNAgent.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/rltorch/agents/DQNAgent.py b/rltorch/agents/DQNAgent.py index cd57913..e053324 100644 --- a/rltorch/agents/DQNAgent.py +++ b/rltorch/agents/DQNAgent.py @@ -17,6 +17,13 @@ class DQNAgent: minibatch = self.memory.sample(self.config['batch_size']) state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch) + + # Send to their appropriate devices + state_batch = state_batch.to(net.device) + action_batch = action_batch.to(net.device) + reward_batch = reward_batch.to(net.device) + next_state_batch = next_state_batch.to(net.device) + not_done_batch = not_done_batch.to(net.device) obtained_values = self.net(state_batch).gather(1, action_batch.view(self.config['batch_size'], 1))