Made sure the reward_batch is float across different agents
This commit is contained in:
parent
cdfd3ab6b9
commit
b2f5220585
4 changed files with 5 additions and 5 deletions
|
@ -44,7 +44,7 @@ class A2CSingleAgent:
|
||||||
|
|
||||||
# Send batches to the appropriate device
|
# Send batches to the appropriate device
|
||||||
state_batch = torch.cat(state_batch).to(self.value_net.device)
|
state_batch = torch.cat(state_batch).to(self.value_net.device)
|
||||||
reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
|
reward_batch = torch.tensor(reward_batch).to(self.value_net.device).float()
|
||||||
not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
|
not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
|
||||||
next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
|
next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
|
||||||
log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
|
log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
|
||||||
|
|
|
@ -30,7 +30,7 @@ class DQNAgent:
|
||||||
# Send to their appropriate devices
|
# Send to their appropriate devices
|
||||||
state_batch = state_batch.to(self.net.device)
|
state_batch = state_batch.to(self.net.device)
|
||||||
action_batch = action_batch.to(self.net.device)
|
action_batch = action_batch.to(self.net.device)
|
||||||
reward_batch = reward_batch.to(self.net.device)
|
reward_batch = reward_batch.to(self.net.device).float()
|
||||||
next_state_batch = next_state_batch.to(self.net.device)
|
next_state_batch = next_state_batch.to(self.net.device)
|
||||||
not_done_batch = not_done_batch.to(self.net.device)
|
not_done_batch = not_done_batch.to(self.net.device)
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ class PPOAgent:
|
||||||
# Send batches to the appropriate device
|
# Send batches to the appropriate device
|
||||||
state_batch = torch.cat(state_batch).to(self.value_net.device)
|
state_batch = torch.cat(state_batch).to(self.value_net.device)
|
||||||
action_batch = torch.tensor(action_batch).to(self.value_net.device)
|
action_batch = torch.tensor(action_batch).to(self.value_net.device)
|
||||||
reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
|
reward_batch = torch.tensor(reward_batch).to(self.value_net.device).float()
|
||||||
not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
|
not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
|
||||||
next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
|
next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
|
||||||
log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
|
log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
|
||||||
|
|
|
@ -59,7 +59,7 @@ class QEPAgent:
|
||||||
# Send to their appropriate devices
|
# Send to their appropriate devices
|
||||||
state_batch = state_batch.to(self.value_net.device)
|
state_batch = state_batch.to(self.value_net.device)
|
||||||
action_batch = action_batch.to(self.value_net.device)
|
action_batch = action_batch.to(self.value_net.device)
|
||||||
reward_batch = reward_batch.to(self.value_net.device)
|
reward_batch = reward_batch.to(self.value_net.device).float()
|
||||||
next_state_batch = next_state_batch.to(self.value_net.device)
|
next_state_batch = next_state_batch.to(self.value_net.device)
|
||||||
not_done_batch = not_done_batch.to(self.value_net.device)
|
not_done_batch = not_done_batch.to(self.value_net.device)
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ class QEPAgent:
|
||||||
best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
|
best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
|
||||||
best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
||||||
|
|
||||||
expected_values = (reward_batch.float() + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
||||||
|
|
||||||
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
||||||
value_loss = (torch.as_tensor(importance_weights, device = self.value_net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
|
value_loss = (torch.as_tensor(importance_weights, device = self.value_net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
|
||||||
|
|
Loading…
Reference in a new issue