Fixed multiprocessing with CUDA. Added entropy importance as a config option.
parent 9d32a9edd1
commit a99ca66b4f

2 changed files with 9 additions and 3 deletions
@@ -5,12 +5,13 @@ import torch
 from torch.distributions import Categorical
 import rltorch
 import rltorch.memory as M
 import torch.nn.functional as F
 
 # Q-Evolutionary Policy Agent
 # Maximizes the policy with respect to the Q-Value function.
 # Since function is non-differentiabile, depends on the Evolutionary Strategy algorithm
 class QEPAgent:
-    def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None):
+    def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None, entropy_importance = 0):
         self.policy_net = policy_net
         assert isinstance(self.policy_net, rltorch.network.ESNetwork) or isinstance(self.policy_net, rltorch.network.ESNetworkMP)
         self.policy_net.fitness = self.fitness
@@ -20,6 +21,7 @@ class QEPAgent:
         self.config = deepcopy(config)
         self.logger = logger
         self.policy_skip = 4
+        self.entropy_importance = entropy_importance
 
     def save(self, file_location):
         torch.save({
@@ -36,7 +38,8 @@ class QEPAgent:
 
     def fitness(self, policy_net, value_net, state_batch):
         batch_size = len(state_batch)
-        action_probabilities = policy_net(state_batch)
+        with torch.no_grad():
+            action_probabilities = policy_net(state_batch)
         action_size = action_probabilities.shape[1]
         distributions = list(map(Categorical, action_probabilities))
         actions = torch.stack([d.sample() for d in distributions])
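Aside on the fitness change above: the policy forward pass here is only used to sample actions for scoring, so wrapping it in torch.no_grad() skips building an autograd graph that the evolutionary update never uses (the gradient estimate comes from perturbed parameter copies, not from backpropagation). A minimal sketch of the same pattern with hypothetical stand-in networks instead of rltorch's ESNetwork/ESNetworkMP; the real fitness also folds in the entropy term shown in the next hunk:

import torch
import torch.nn as nn
from torch.distributions import Categorical

# Hypothetical stand-ins for the policy and value networks wrapped by QEPAgent.
policy_net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2), nn.Softmax(dim=1))
value_net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))

state_batch = torch.randn(8, 4)  # batch of 8 states with 4 features each

# Fitness evaluation never backpropagates through the policy, so the forward
# pass can run without recording operations for autograd.
with torch.no_grad():
    action_probabilities = policy_net(state_batch)

distributions = [Categorical(p) for p in action_probabilities]
actions = torch.stack([d.sample() for d in distributions])

# Score the sampled actions with the value network; the negated mean plays the
# role of a fitness value to be minimised.
state_values = value_net(state_batch)
obtained_values = state_values.gather(1, actions.view(-1, 1)).squeeze(1)
print(-obtained_values.mean().item())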
@@ -53,6 +56,7 @@ class QEPAgent:
         obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
         # return -obtained_values.mean().item()
-        entropy_importance = 0 # Entropy accounting for 1% of loss seems to work well
+        entropy_importance = next(self.entropy_importance) if isinstance(self.entropy_importance, collections.Iterable) else self.entropy_importance
         value_importance = 1 - entropy_importance
 
         # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
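Aside on the new entropy weighting: entropy_importance can now be either a plain number or an iterable that is advanced with next() on every fitness call, so a schedule (for example a generator) can anneal how much of the fitness comes from entropy versus Q-values. A small sketch of that scalar-or-iterable pattern; resolve_weight is a hypothetical helper, and note that on Python 3.10+ the ABC has to come from collections.abc, since the bare collections.Iterable alias used in the diff has been removed:

import collections.abc

def resolve_weight(entropy_importance):
    # Accept either a constant or an iterator/generator yielding a new weight
    # per call, mirroring the scalar-or-iterable check in QEPAgent.fitness.
    if isinstance(entropy_importance, collections.abc.Iterable):
        return next(entropy_importance)
    return entropy_importance

# Constant weight: entropy contributes a fixed 1% of the fitness.
print(resolve_weight(0.01))

# Scheduled weight: a generator that decays the entropy share each call.
schedule = (0.1 * (0.99 ** step) for step in range(10_000))
for _ in range(3):
    entropy_importance = resolve_weight(schedule)
    value_importance = 1 - entropy_importance
    print(entropy_importance, value_importance)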
@@ -131,6 +135,7 @@ class QEPAgent:
           self.policy_skip -= 1
           return
         self.policy_skip = 4
 
         if self.target_value_net is not None:
           self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
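Aside on the surrounding learn logic: the context lines show that the evolutionary policy update only runs on every fourth call (the policy_skip counter) and that fitness is scored against the frozen target value network when one is configured. A hypothetical, rltorch-free skeleton of that pattern (names such as PolicyUpdater and maybe_update are illustrative only):

class PolicyUpdater:
    # Illustrative skeleton only; the real logic lives inside QEPAgent.learn().
    def __init__(self, calc_gradients, target_value_net=None, period=4):
        self.calc_gradients = calc_gradients      # e.g. the ES network's calc_gradients
        self.target_value_net = target_value_net
        self.period = period
        self.skip = period

    def maybe_update(self, value_net, state_batch):
        # Count down between expensive evolutionary updates...
        if self.skip > 0:
            self.skip -= 1
            return
        # ...then reset the counter and run one update.
        self.skip = self.period
        # Score fitness against the frozen target network when available, which
        # keeps the objective steadier while the online value network trains.
        scorer = self.target_value_net if self.target_value_net is not None else value_net
        self.calc_gradients(scorer, state_batch)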

@@ -26,7 +26,8 @@ class ESNetworkMP(Network):
         self.fitness = fitness_fn
         self.sigma = sigma
         assert self.sigma > 0
-        self.pool = mp.Pool(processes=3) #[TODO] Probably should make number of processes a config variable
+        mp_ctx = mp.get_context("spawn")
+        self.pool = mp_ctx.Pool(processes=2) #[TODO] Probably should make number of processes a config variable
 
     # We're not going to be calculating gradients in the traditional way
     # So there's no need to waste computation time keeping track
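Aside on the multiprocessing fix: CUDA cannot be safely re-initialised in a worker created with fork (the default start method on Linux), which is why the pool is now built from an explicit "spawn" context; this matches PyTorch's guidance for multiprocessing code that touches CUDA. A minimal sketch, assuming mp refers to Python's multiprocessing (torch.multiprocessing exposes the same get_context/Pool API); note that with spawn the worker function must be defined at module top level so it can be pickled:

import multiprocessing as mp
import torch

def score(seed):
    # Each spawned worker starts a fresh interpreter and initialises CUDA
    # itself, which is safe; a forked child inheriting an already-initialised
    # CUDA runtime is not.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(seed)
    return torch.randn(3, device=device).sum().item()

if __name__ == "__main__":
    ctx = mp.get_context("spawn")          # explicit start method, as in the commit
    with ctx.Pool(processes=2) as pool:    # small worker count, as in ESNetworkMP
        print(pool.map(score, range(4)))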