Corrected device when constructing fitness tensor
This commit is contained in:
		
							parent
							
								
									9740c40527
								
							
						
					
					
						commit
						1958fc7c7e
					
				
					 1 changed files with 2 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -3,6 +3,7 @@ import torch
 | 
			
		|||
from .Network import Network
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
 | 
			
		||||
# [TODO] See if you need to move network to device
 | 
			
		||||
class ESNetwork(Network):
 | 
			
		||||
    """
 | 
			
		||||
    Network that functions from the paper Evolutionary Strategies (https://arxiv.org/abs/1703.03864)
 | 
			
		||||
| 
						 | 
				
			
			@ -52,7 +53,7 @@ class ESNetwork(Network):
 | 
			
		|||
        candidate_solutions = self._generate_candidate_solutions(noise_dict)
 | 
			
		||||
        
 | 
			
		||||
        ## Calculate fitness then mean shift, scale
 | 
			
		||||
        fitness_values = torch.tensor([self.fitness(x, *args) for x in candidate_solutions])
 | 
			
		||||
        fitness_values = torch.tensor([self.fitness(x, *args) for x in candidate_solutions], device = self.device)
 | 
			
		||||
        if self.logger is not None:
 | 
			
		||||
            self.logger.append(self.name + "/" + "fitness_value", fitness_values.mean().item())
 | 
			
		||||
        fitness_values = (fitness_values - fitness_values.mean()) / (fitness_values.std() + np.finfo('float').eps)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue