Corrected device when constructing fitness tensor
This commit is contained in:
parent
9740c40527
commit
1958fc7c7e
1 changed files with 2 additions and 1 deletions
|
@ -3,6 +3,7 @@ import torch
|
||||||
from .Network import Network
|
from .Network import Network
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
# [TODO] See if you need to move network to device
|
||||||
class ESNetwork(Network):
|
class ESNetwork(Network):
|
||||||
"""
|
"""
|
||||||
Network that functions from the paper Evolutionary Strategies (https://arxiv.org/abs/1703.03864)
|
Network that functions from the paper Evolutionary Strategies (https://arxiv.org/abs/1703.03864)
|
||||||
|
@ -52,7 +53,7 @@ class ESNetwork(Network):
|
||||||
candidate_solutions = self._generate_candidate_solutions(noise_dict)
|
candidate_solutions = self._generate_candidate_solutions(noise_dict)
|
||||||
|
|
||||||
## Calculate fitness then mean shift, scale
|
## Calculate fitness then mean shift, scale
|
||||||
fitness_values = torch.tensor([self.fitness(x, *args) for x in candidate_solutions])
|
fitness_values = torch.tensor([self.fitness(x, *args) for x in candidate_solutions], device = self.device)
|
||||||
if self.logger is not None:
|
if self.logger is not None:
|
||||||
self.logger.append(self.name + "/" + "fitness_value", fitness_values.mean().item())
|
self.logger.append(self.name + "/" + "fitness_value", fitness_values.mean().item())
|
||||||
fitness_values = (fitness_values - fitness_values.mean()) / (fitness_values.std() + np.finfo('float').eps)
|
fitness_values = (fitness_values - fitness_values.mean()) / (fitness_values.std() + np.finfo('float').eps)
|
||||||
|
|
Loading…
Reference in a new issue