From b2f522058584e7d14240f6d56c6e5d4e50b34e20 Mon Sep 17 00:00:00 2001
From: Brandon Rozek
Date: Thu, 14 Mar 2019 10:43:14 -0400
Subject: [PATCH] Made sure the reward_batch is float across different agents

---
 rltorch/agents/A2CSingleAgent.py | 2 +-
 rltorch/agents/DQNAgent.py       | 2 +-
 rltorch/agents/PPOAgent.py       | 2 +-
 rltorch/agents/QEPAgent.py       | 4 ++--
 4 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/rltorch/agents/A2CSingleAgent.py b/rltorch/agents/A2CSingleAgent.py
index c7f367e..e7316ec 100644
--- a/rltorch/agents/A2CSingleAgent.py
+++ b/rltorch/agents/A2CSingleAgent.py
@@ -44,7 +44,7 @@ class A2CSingleAgent:
 
         # Send batches to the appropriate device
         state_batch = torch.cat(state_batch).to(self.value_net.device)
-        reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
+        reward_batch = torch.tensor(reward_batch).to(self.value_net.device).float()
         not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
         next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
         log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
diff --git a/rltorch/agents/DQNAgent.py b/rltorch/agents/DQNAgent.py
index a73391f..3f20b52 100644
--- a/rltorch/agents/DQNAgent.py
+++ b/rltorch/agents/DQNAgent.py
@@ -30,7 +30,7 @@ class DQNAgent:
 
         # Send to their appropriate devices
         state_batch = state_batch.to(self.net.device)
         action_batch = action_batch.to(self.net.device)
-        reward_batch = reward_batch.to(self.net.device)
+        reward_batch = reward_batch.to(self.net.device).float()
         next_state_batch = next_state_batch.to(self.net.device)
         not_done_batch = not_done_batch.to(self.net.device)
diff --git a/rltorch/agents/PPOAgent.py b/rltorch/agents/PPOAgent.py
index 0a3ded4..8f6b78e 100644
--- a/rltorch/agents/PPOAgent.py
+++ b/rltorch/agents/PPOAgent.py
@@ -31,7 +31,7 @@ class PPOAgent:
         # Send batches to the appropriate device
         state_batch = torch.cat(state_batch).to(self.value_net.device)
         action_batch = torch.tensor(action_batch).to(self.value_net.device)
-        reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
+        reward_batch = torch.tensor(reward_batch).to(self.value_net.device).float()
         not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
         next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
         log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
diff --git a/rltorch/agents/QEPAgent.py b/rltorch/agents/QEPAgent.py
index d636cd2..6040e28 100644
--- a/rltorch/agents/QEPAgent.py
+++ b/rltorch/agents/QEPAgent.py
@@ -59,7 +59,7 @@ class QEPAgent:
 
        # Send to their appropriate devices
        state_batch = state_batch.to(self.value_net.device)
        action_batch = action_batch.to(self.value_net.device)
-       reward_batch = reward_batch.to(self.value_net.device)
+       reward_batch = reward_batch.to(self.value_net.device).float()
        next_state_batch = next_state_batch.to(self.value_net.device)
        not_done_batch = not_done_batch.to(self.value_net.device)
@@ -82,7 +82,7 @@ class QEPAgent:
        best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
        best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
 
-       expected_values = (reward_batch.float() + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
+       expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
 
        if (isinstance(self.memory, M.PrioritizedReplayMemory)):
            value_loss = (torch.as_tensor(importance_weights, device = self.value_net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
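
Why the cast matters: rewards usually come out of the environment as Python ints (or numpy integers), so torch.tensor(reward_batch) produces an int64 tensor, while the value networks output float32. Doing the .float() cast once where the batch is built keeps the TD-target arithmetic in a single dtype, which is also why the second QEPAgent.py hunk can drop the per-use reward_batch.float(). Below is a minimal sketch of the dtype behaviour, assuming integer rewards; the variable names mirror the patch but the values are made up.

import torch

# Rewards collected from a typical environment are Python ints,
# so wrapping them directly yields an int64 (Long) tensor.
reward_batch = torch.tensor([1, 0, -1, 1])
print(reward_batch.dtype)  # torch.int64

# Stand-ins for the agent's tensors: a float32 bootstrap value and a
# discount factor, mirroring the expected_values line in QEPAgent.py.
best_next_state_value = torch.rand(4)  # torch.float32
discount_rate = 0.99

# Casting once, right after batching, keeps every later term float32
# and matches the value network's output dtype.
reward_batch = reward_batch.float()
expected_values = (reward_batch + discount_rate * best_next_state_value).unsqueeze(1)
print(expected_values.dtype)  # torch.float32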