Made sure the reward_batch is float across different agents

Brandon Rozek 2019-03-14 10:43:14 -04:00
parent cdfd3ab6b9
commit b2f5220585
4 changed files with 5 additions and 5 deletions

@@ -44,7 +44,7 @@ class A2CSingleAgent:
         # Send batches to the appropriate device
         state_batch = torch.cat(state_batch).to(self.value_net.device)
-        reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
+        reward_batch = torch.tensor(reward_batch).to(self.value_net.device).float()
         not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
         next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
         log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)

@@ -30,7 +30,7 @@ class DQNAgent:
         # Send to their appropriate devices
         state_batch = state_batch.to(self.net.device)
         action_batch = action_batch.to(self.net.device)
-        reward_batch = reward_batch.to(self.net.device)
+        reward_batch = reward_batch.to(self.net.device).float()
         next_state_batch = next_state_batch.to(self.net.device)
         not_done_batch = not_done_batch.to(self.net.device)

@@ -31,7 +31,7 @@ class PPOAgent:
         # Send batches to the appropriate device
         state_batch = torch.cat(state_batch).to(self.value_net.device)
         action_batch = torch.tensor(action_batch).to(self.value_net.device)
-        reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
+        reward_batch = torch.tensor(reward_batch).to(self.value_net.device).float()
         not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
         next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
         log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)

@@ -59,7 +59,7 @@ class QEPAgent:
         # Send to their appropriate devices
         state_batch = state_batch.to(self.value_net.device)
         action_batch = action_batch.to(self.value_net.device)
-        reward_batch = reward_batch.to(self.value_net.device)
+        reward_batch = reward_batch.to(self.value_net.device).float()
         next_state_batch = next_state_batch.to(self.value_net.device)
         not_done_batch = not_done_batch.to(self.value_net.device)
@@ -82,7 +82,7 @@ class QEPAgent:
         best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
         best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
-        expected_values = (reward_batch.float() + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
+        expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
         if (isinstance(self.memory, M.PrioritizedReplayMemory)):
             value_loss = (torch.as_tensor(importance_weights, device = self.value_net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
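
The change is the same in all four agents: reward_batch is cast to float32 at the point where the batch is moved to the device, which also lets QEPAgent drop its inline reward_batch.float() call when building the Bellman target. Below is a minimal sketch (not from the repository) of the dtype mismatch this guards against; the nn.Linear value network and the hypothetical reward values are stand-ins for the agents' real networks and replay batches.

import numpy as np
import torch
import torch.nn as nn

# Rewards coming back from a Gym-style environment are often numpy float64
# (or plain Python ints), so the batched tensor is not float32 by default.
rewards = [np.float64(1.0), np.float64(0.0), np.float64(-1.0)]
reward_batch = torch.tensor(rewards)
print(reward_batch.dtype)                          # torch.float64

value_net = nn.Linear(4, 1)                        # stand-in for a float32 value network
state_batch = torch.randn(3, 4)
state_values = value_net(state_batch).squeeze(1)   # torch.float32

# Without the cast the Bellman target silently promotes to float64,
# which can then clash with the float32 predictions inside the loss.
expected_values = reward_batch + 0.99 * state_values
print(expected_values.dtype)                       # torch.float64

# The commit's fix: cast once when the batch is sent to the device.
reward_batch = reward_batch.to(state_values.device).float()
expected_values = reward_batch + 0.99 * state_values
print(expected_values.dtype)                       # torch.float32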