Some work on multiprocessing evolutionary strategies from last semester
parent 6d3a78cd20
commit da83f1470c

1 changed file with 13 additions and 5 deletions
@@ -12,7 +12,7 @@ import rltorch.memory as M
 class QEPAgent:
     def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None):
         self.policy_net = policy_net
-        assert isinstance(self.policy_net, rltorch.network.ESNetwork)
+        assert isinstance(self.policy_net, rltorch.network.ESNetwork) or isinstance(self.policy_net, rltorch.network.ESNetworkMP)
         self.policy_net.fitness = self.fitness
         self.value_net = value_net
         self.target_value_net = target_value_net
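ESNetworkMP itself is not part of this diff, so the sketch below only illustrates, under stated assumptions, the kind of multiprocess fitness evaluation such a class suggests: score noise-perturbed copies of the policy in worker processes, using the fitness(policy_net, value_net, state_batch) callable attached above. The names parallel_fitness, _evaluate_candidate, population_size, sigma, and processes are hypothetical and not rltorch API.

import copy
import torch
import torch.multiprocessing as mp

def _evaluate_candidate(args):
    # Worker: perturb a copy of the policy with seeded Gaussian noise,
    # then score it with the supplied fitness callable.
    policy_net, fitness_fn, fitness_args, sigma, seed = args
    candidate = copy.deepcopy(policy_net)
    generator = torch.Generator().manual_seed(seed)
    with torch.no_grad():
        for param in candidate.parameters():
            param.add_(sigma * torch.randn(param.shape, generator=generator))
    return seed, fitness_fn(candidate, *fitness_args)

def parallel_fitness(policy_net, fitness_fn, fitness_args,
                     population_size=32, sigma=0.05, processes=4):
    # Parent: one job per population member, evaluated in a process pool.
    # Assumes CPU tensors and a picklable network / fitness callable; on
    # spawn-based platforms this must run under `if __name__ == "__main__":`.
    seeds = torch.randint(0, 2**31 - 1, (population_size,)).tolist()
    jobs = [(policy_net, fitness_fn, fitness_args, sigma, seed) for seed in seeds]
    with mp.Pool(processes) as pool:
        return pool.map(_evaluate_candidate, jobs)  # list of (seed, fitness)

Returning the seed together with each score lets the parent rebuild the exact perturbation that produced it without shipping tensors back from the workers.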
@@ -22,6 +22,7 @@ class QEPAgent:
         self.policy_skip = 4

     def fitness(self, policy_net, value_net, state_batch):
+        # print("Worker started")
         batch_size = len(state_batch)
         action_probabilities = policy_net(state_batch)
         action_size = action_probabilities.shape[1]
@@ -30,15 +31,21 @@ class QEPAgent:

         with torch.no_grad():
             state_values = value_net(state_batch)
+
+        # Weird hacky solution where in multiprocess, it sometimes spits out nans
+        # So have it try again
+        while torch.isnan(state_values).any():
+            with torch.no_grad():
+                state_values = value_net(state_batch)
+
         obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
         # return -obtained_values.mean().item()

-        entropy_importance = 0.01 # Entropy accounting for 1% of loss seems to work well
+        entropy_importance = 0 # Entropy accounting for 1% of loss seems to work well
         value_importance = 1 - entropy_importance

         # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
         entropy_loss = (action_probabilities - torch.tensor(1 / action_size).repeat(len(state_batch), action_size)).abs().sum(1)

         # print("END WORKER")
         return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()
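The replacement entropy term in this hunk is not Shannon entropy (that is the commented-out log2 line); it is the per-sample total absolute deviation of the action distribution from the uniform distribution over action_size actions. A tiny standalone check with made-up probabilities shows what it computes: zero for a uniform row, and larger values the more peaked the distribution becomes.

import torch

# Made-up action probabilities for a batch of 2 states and 4 actions.
action_probabilities = torch.tensor([[0.55, 0.15, 0.15, 0.15],
                                     [0.25, 0.25, 0.25, 0.25]])
batch_size, action_size = action_probabilities.shape

# Same expression as in the diff: absolute distance from the uniform distribution.
uniform = torch.tensor(1 / action_size).repeat(batch_size, action_size)
entropy_loss = (action_probabilities - uniform).abs().sum(1)

print(entropy_loss)  # tensor([0.6000, 0.0000])

With entropy_importance set to 0 in this commit, the term is effectively switched off and the returned fitness reduces to -obtained_values.mean().item().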
@@ -112,10 +119,11 @@ class QEPAgent:
           self.policy_skip -= 1
           return
         self.policy_skip = 4

         if self.target_value_net is not None:
           self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
           self.policy_net.calc_gradients(self.value_net, state_batch)
         # self.policy_net.clamp_gradients()

         self.policy_net.step()
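calc_gradients() and step() are not shown in this commit. For context, a method like calc_gradients() typically turns the per-candidate fitness scores and the noise that produced each candidate into the standard evolution-strategies gradient estimate sketched below. This is the usual NES-style estimator under assumed shapes, not rltorch's actual implementation; sigma is the perturbation scale from the sketch above.

import torch

def es_gradient(fitness_scores, noise, sigma):
    # fitness_scores: (population_size,) fitness of each perturbed candidate.
    # noise: (population_size, num_params) Gaussian perturbations added to the
    #        flattened parameters to create each candidate.
    # Returns an estimate of the gradient of expected fitness with respect to
    # the unperturbed parameters: roughly (1 / (n * sigma)) * sum_i F_i * eps_i.
    scores = torch.as_tensor(fitness_scores, dtype=noise.dtype)
    # Standardizing the scores is a common variance-reduction step.
    scores = (scores - scores.mean()) / (scores.std() + 1e-8)
    return (noise * scores.unsqueeze(1)).mean(0) / sigma

An optimizer step would then move the flattened parameters along (or against, depending on whether fitness is maximized or minimized) this estimate, which is presumably what step() does here.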