Made sure the reward_batch is float across different agents
parent cdfd3ab6b9
commit b2f5220585

4 changed files with 5 additions and 5 deletions
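
Why the cast matters (an explanatory sketch, not part of the commit): rewards collected from the environment are typically Python numbers or numpy.float64 values, so building the batch with torch.tensor can yield a float64 (or int64) tensor, while the value networks work in float32. The snippet below reproduces the dtype gap with made-up reward values; the variable names are illustrative, not taken from the repository.

    import numpy as np
    import torch

    # Rewards coming back from an environment are often numpy.float64 (or plain ints),
    # so batching them with torch.tensor does not produce a float32 tensor:
    reward_batch = torch.tensor(np.array([1.0, 0.0, -1.0]))
    print(reward_batch.dtype)              # torch.float64

    # The value networks, however, produce float32 estimates:
    value_estimates = torch.zeros(3)
    print(value_estimates.dtype)           # torch.float32

    # Casting once, where the batch is built, keeps every later operation in float32:
    reward_batch = reward_batch.float()
    loss = torch.nn.functional.mse_loss(value_estimates, reward_batch)
    print(reward_batch.dtype, loss.dtype)  # torch.float32 torch.float32
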
				
			
@@ -44,7 +44,7 @@ class A2CSingleAgent:
     # Send batches to the appropriate device
     state_batch = torch.cat(state_batch).to(self.value_net.device)
-    reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
+    reward_batch = torch.tensor(reward_batch).to(self.value_net.device).float()
     not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
     next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
     log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)

@@ -30,7 +30,7 @@ class DQNAgent:
         # Send to their appropriate devices
         state_batch = state_batch.to(self.net.device)
         action_batch = action_batch.to(self.net.device)
-        reward_batch = reward_batch.to(self.net.device)
+        reward_batch = reward_batch.to(self.net.device).float()
         next_state_batch = next_state_batch.to(self.net.device)
         not_done_batch = not_done_batch.to(self.net.device)

@@ -31,7 +31,7 @@ class PPOAgent:
     # Send batches to the appropriate device
     state_batch = torch.cat(state_batch).to(self.value_net.device)
     action_batch = torch.tensor(action_batch).to(self.value_net.device)
-    reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
+    reward_batch = torch.tensor(reward_batch).to(self.value_net.device).float()
     not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
     next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
     log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)

@@ -59,7 +59,7 @@ class QEPAgent:
         # Send to their appropriate devices
         state_batch = state_batch.to(self.value_net.device)
         action_batch = action_batch.to(self.value_net.device)
-        reward_batch = reward_batch.to(self.value_net.device)
+        reward_batch = reward_batch.to(self.value_net.device).float()
         next_state_batch = next_state_batch.to(self.value_net.device)
         not_done_batch = not_done_batch.to(self.value_net.device)

@@ -82,7 +82,7 @@ class QEPAgent:
             best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
             best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)

-        expected_values = (reward_batch.float() + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
+        expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
         if (isinstance(self.memory, M.PrioritizedReplayMemory)):
             value_loss = (torch.as_tensor(importance_weights, device = self.value_net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
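
A note on the last hunk: with reward_batch already cast to float32 where the batch is assembled, the explicit .float() inside the expected-value computation is redundant, which is why it is dropped at the use site. A minimal sketch with made-up values (the discount_rate of 0.99 is assumed, not taken from the config):

    import torch

    # Hypothetical values, only to show the resulting dtype and shape:
    reward_batch = torch.tensor([1.0, 0.0]).float()       # float32 after the cast
    best_next_state_value = torch.zeros(2)                 # float32, as a value net would produce
    discount_rate = 0.99                                   # assumed config value

    expected_values = (reward_batch + (discount_rate * best_next_state_value)).unsqueeze(1)
    print(expected_values.dtype, expected_values.shape)   # torch.float32 torch.Size([2, 1])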