Sends importance weights to the correct device for prioritized replay
This commit is contained in:

parent 013d40a4f9
commit 04e54cddc2

1 changed file with 1 addition and 1 deletion
			
@@ -54,7 +54,7 @@ class DQNAgent:
         expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)

         if (isinstance(self.memory, M.PrioritizedReplayMemory)):
-            loss = (torch.as_tensor(importance_weights) * (obtained_values - expected_values)**2).mean()
+            loss = (torch.as_tensor(importance_weights, device = self.net.device) * (obtained_values - expected_values)**2).mean()
         else:
             loss = F.mse_loss(obtained_values, expected_values)

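Why the one-argument change matters: torch.as_tensor called on a NumPy array allocates the tensor on the CPU by default, while obtained_values and expected_values live on the network's device, so the elementwise product raises a device-mismatch RuntimeError whenever the agent trains on a GPU. Below is a minimal sketch of the failure and the fix, assuming importance_weights is a NumPy array returned by the prioritized replay sampler and that the value tensors have shape (batch, 1) as in the code above; the batch size and shapes here are illustrative assumptions, not taken from the repository.

import numpy as np
import torch

# Pick the device the agent's network would live on.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 32
importance_weights = np.random.rand(batch_size)             # CPU NumPy array from the replay sampler (assumed)
obtained_values = torch.rand(batch_size, 1, device=device)  # Q(s, a) from the online network (assumed shape)
expected_values = torch.rand(batch_size, 1, device=device)  # TD targets, (batch, 1) after unsqueeze(1)

# Before the fix: torch.as_tensor keeps the NumPy data on the CPU, so on a
# CUDA machine the multiply below fails with
# "Expected all tensors to be on the same device".
#   weights = torch.as_tensor(importance_weights)

# After the fix: create the weight tensor directly on the network's device.
# The unsqueeze(1) is an illustrative addition so the weights broadcast
# element-wise against the (batch, 1) value tensors.
weights = torch.as_tensor(importance_weights, device=device).unsqueeze(1)

loss = (weights * (obtained_values - expected_values) ** 2).mean()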