Made sure reward_batch is a float tensor across the different agents
This commit is contained in:
parent
cdfd3ab6b9
commit
b2f5220585
4 changed files with 5 additions and 5 deletions
|
@ -44,7 +44,7 @@ class A2CSingleAgent:
|
|||
|
||||
# Send batches to the appropriate device
|
||||
state_batch = torch.cat(state_batch).to(self.value_net.device)
|
||||
reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
|
||||
reward_batch = torch.tensor(reward_batch).to(self.value_net.device).float()
|
||||
not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
|
||||
next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
|
||||
log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
|
||||
|
|
|
@ -30,7 +30,7 @@ class DQNAgent:
|
|||
# Send to their appropriate devices
|
||||
state_batch = state_batch.to(self.net.device)
|
||||
action_batch = action_batch.to(self.net.device)
|
||||
reward_batch = reward_batch.to(self.net.device)
|
||||
reward_batch = reward_batch.to(self.net.device).float()
|
||||
next_state_batch = next_state_batch.to(self.net.device)
|
||||
not_done_batch = not_done_batch.to(self.net.device)
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ class PPOAgent:
|
|||
# Send batches to the appropriate device
|
||||
state_batch = torch.cat(state_batch).to(self.value_net.device)
|
||||
action_batch = torch.tensor(action_batch).to(self.value_net.device)
|
||||
reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
|
||||
reward_batch = torch.tensor(reward_batch).to(self.value_net.device).float()
|
||||
not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
|
||||
next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
|
||||
log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
|
||||
|
|
|
@ -59,7 +59,7 @@ class QEPAgent:
|
|||
# Send to their appropriate devices
|
||||
state_batch = state_batch.to(self.value_net.device)
|
||||
action_batch = action_batch.to(self.value_net.device)
|
||||
reward_batch = reward_batch.to(self.value_net.device)
|
||||
reward_batch = reward_batch.to(self.value_net.device).float()
|
||||
next_state_batch = next_state_batch.to(self.value_net.device)
|
||||
not_done_batch = not_done_batch.to(self.value_net.device)
|
||||
|
||||
|
@ -82,7 +82,7 @@ class QEPAgent:
|
|||
best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
|
||||
best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
||||
|
||||
expected_values = (reward_batch.float() + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
||||
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
||||
|
||||
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
||||
value_loss = (torch.as_tensor(importance_weights, device = self.value_net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
|
||||
|
|
Loading…
Reference in a new issue