diff --git a/rltorch/network/ESNetwork.py b/rltorch/network/ESNetwork.py index c5ae4b7..b3f372c 100644 --- a/rltorch/network/ESNetwork.py +++ b/rltorch/network/ESNetwork.py @@ -3,6 +3,7 @@ import torch from .Network import Network from copy import deepcopy +# [TODO] See if you need to move network to device class ESNetwork(Network): """ Network that functions from the paper Evolutionary Strategies (https://arxiv.org/abs/1703.03864) @@ -52,7 +53,7 @@ class ESNetwork(Network): candidate_solutions = self._generate_candidate_solutions(noise_dict) ## Calculate fitness then mean shift, scale - fitness_values = torch.tensor([self.fitness(x, *args) for x in candidate_solutions]) + fitness_values = torch.tensor([self.fitness(x, *args) for x in candidate_solutions], device = self.device) if self.logger is not None: self.logger.append(self.name + "/" + "fitness_value", fitness_values.mean().item()) fitness_values = (fitness_values - fitness_values.mean()) / (fitness_values.std() + np.finfo('float').eps)