Added parallel version of ES
This commit is contained in:
		
							parent
							
								
									9ad63a6921
								
							
						
					
					
						commit
						6d3a78cd20
					
				
					 3 changed files with 91 additions and 0 deletions
				
			
		| 
						 | 
				
			
			@ -16,6 +16,7 @@ class ESNetwork(Network):
 | 
			
		|||
        self.population_size = population_size
 | 
			
		||||
        self.fitness = fitness_fn
 | 
			
		||||
        self.sigma = sigma
 | 
			
		||||
        assert self.sigma > 0
 | 
			
		||||
 | 
			
		||||
    # We're not going to be calculating gradients in the traditional way
 | 
			
		||||
    # So there's no need to waste computation time keeping track
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										89
									
								
								rltorch/network/ESNetworkMP.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								rltorch/network/ESNetworkMP.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,89 @@
 | 
			
		|||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from .Network import Network
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
import torch.multiprocessing as mp
 | 
			
		||||
import functools
 | 
			
		||||
 | 
			
		||||
class fn_copy:
 | 
			
		||||
    def __init__(self, fn, args):
 | 
			
		||||
        self.fn = fn
 | 
			
		||||
        self.args = args
 | 
			
		||||
    def __call__(self, x):
 | 
			
		||||
        return self.fn(x, *(self.args))
 | 
			
		||||
 | 
			
		||||
# [TODO] Should we torch.no_grad the __call__?
 | 
			
		||||
# What if we want to sometimes do gradient descent as well?
 | 
			
		||||
class ESNetworkMP(Network):
 | 
			
		||||
    """
 | 
			
		||||
    Network that functions from the paper Evolutionary Strategies (https://arxiv.org/abs/1703.03864)
 | 
			
		||||
    fitness_fun := model, *args -> fitness_value (float)
 | 
			
		||||
    We wish to find a model that maximizes the fitness function
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self, model, optimizer, population_size, fitness_fn, config, sigma = 0.05, device = None, logger = None, name = ""):
 | 
			
		||||
        super(ESNetworkMP, self).__init__(model, optimizer, config, device, logger, name)
 | 
			
		||||
        self.population_size = population_size
 | 
			
		||||
        self.fitness = fitness_fn
 | 
			
		||||
        self.sigma = sigma
 | 
			
		||||
        assert self.sigma > 0
 | 
			
		||||
        self.pool = mp.Pool(processes=3) #[TODO] Probably should make number of processes a config variable
 | 
			
		||||
 | 
			
		||||
    # We're not going to be calculating gradients in the traditional way
 | 
			
		||||
    # So there's no need to waste computation time keeping track
 | 
			
		||||
    def __call__(self, *args):
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            result = self.model(*args)
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def _generate_noise_dicts(self):
 | 
			
		||||
        model_dict = self.model.state_dict()
 | 
			
		||||
        white_noise_dict = {}
 | 
			
		||||
        noise_dict = {}
 | 
			
		||||
        for key in model_dict.keys():
 | 
			
		||||
            white_noise_dict[key] = torch.randn(self.population_size, *model_dict[key].shape, device = self.device)
 | 
			
		||||
            noise_dict[key] = self.sigma * white_noise_dict[key]
 | 
			
		||||
        return white_noise_dict, noise_dict
 | 
			
		||||
 | 
			
		||||
    def _generate_candidate_solutions(self, noise_dict):
 | 
			
		||||
        model_dict = self.model.state_dict()
 | 
			
		||||
        candidate_solutions = []
 | 
			
		||||
        for i in range(self.population_size):
 | 
			
		||||
            candidate_statedict = {}
 | 
			
		||||
            for key in model_dict.keys():
 | 
			
		||||
                candidate_statedict[key] = model_dict[key] + noise_dict[key][i]
 | 
			
		||||
            candidate = deepcopy(self.model)
 | 
			
		||||
            candidate.load_state_dict(candidate_statedict)
 | 
			
		||||
            candidate_solutions.append(candidate)
 | 
			
		||||
        return candidate_solutions
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
    def calc_gradients(self, *args):
 | 
			
		||||
        ## Generate Noise
 | 
			
		||||
        white_noise_dict, noise_dict = self._generate_noise_dicts()
 | 
			
		||||
        
 | 
			
		||||
        ## Generate candidate solutions
 | 
			
		||||
        candidate_solutions = self._generate_candidate_solutions(noise_dict)
 | 
			
		||||
        
 | 
			
		||||
        ## Calculate fitness then mean shift, scale
 | 
			
		||||
        fitness_values = torch.tensor(list(self.pool.map(fn_copy(self.fitness, args), candidate_solutions)), device = self.device)
 | 
			
		||||
 | 
			
		||||
        if self.logger is not None:
 | 
			
		||||
            self.logger.append(self.name + "/" + "fitness_value", fitness_values.mean().item())
 | 
			
		||||
        fitness_values = (fitness_values - fitness_values.mean()) / (fitness_values.std() + np.finfo('float').eps)
 | 
			
		||||
 | 
			
		||||
        ## Insert adjustments into gradients slot
 | 
			
		||||
        self.zero_grad()
 | 
			
		||||
        for name, param in self.model.named_parameters():
 | 
			
		||||
            if param.requires_grad:
 | 
			
		||||
                noise_dim_n = len(white_noise_dict[name].shape)
 | 
			
		||||
                dim = np.repeat(1, noise_dim_n - 1).tolist() if noise_dim_n > 0 else []
 | 
			
		||||
                param.grad = (white_noise_dict[name] * fitness_values.float().reshape(self.population_size, *dim)).mean(0) / self.sigma
 | 
			
		||||
        
 | 
			
		||||
        del white_noise_dict, noise_dict, candidate_solutions
 | 
			
		||||
    
 | 
			
		||||
    # To address error that you can't pickle pool objects...
 | 
			
		||||
    def __getstate__(self):
 | 
			
		||||
        self_dict = self.__dict__.copy()
 | 
			
		||||
        del self_dict['pool']
 | 
			
		||||
        return self_dict
 | 
			
		||||
| 
						 | 
				
			
			@ -1,4 +1,5 @@
 | 
			
		|||
from .ESNetwork import *
 | 
			
		||||
from .ESNetworkMP import *
 | 
			
		||||
from .Network import *
 | 
			
		||||
from .NoisyLinear import *
 | 
			
		||||
from .TargetNetwork import *
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue