If memory or logger does not exist, then don't create those shared memory structures
This commit is contained in:
		
							parent
							
								
									460d4c05c1
								
							
						
					
					
						commit
						19a859a4f6
					
				
					 1 changed files with 7 additions and 5 deletions
				
			
		| 
						 | 
				
			
			@ -9,7 +9,7 @@ def envrun(actor, env, episode_num, config, runcondition, iterations = 1, memory
 | 
			
		|||
  # Wait for signal to start running through the environment
 | 
			
		||||
  while runcondition.wait():
 | 
			
		||||
    # Start a logger to log the rewards
 | 
			
		||||
    logger = rltorch.log.Logger()
 | 
			
		||||
    logger = rltorch.log.Logger() if logqueue is not None else None
 | 
			
		||||
    for _ in range(iterations):
 | 
			
		||||
      action = actor.act(state)
 | 
			
		||||
      next_state, reward, done, _ = env.step(action)
 | 
			
		||||
| 
						 | 
				
			
			@ -33,7 +33,8 @@ def envrun(actor, env, episode_num, config, runcondition, iterations = 1, memory
 | 
			
		|||
        with episode_num.get_lock():
 | 
			
		||||
          episode_num.value +=  1
 | 
			
		||||
          
 | 
			
		||||
    logqueue.put(logger)
 | 
			
		||||
    if logqueue is not None:
 | 
			
		||||
      logqueue.put(logger)
 | 
			
		||||
  
 | 
			
		||||
class EnvironmentRun():
 | 
			
		||||
  def __init__(self, env, actor, config, memory = None, logwriter = None, name = ""):
 | 
			
		||||
| 
						 | 
				
			
			@ -42,8 +43,8 @@ class EnvironmentRun():
 | 
			
		|||
    self.memory = memory
 | 
			
		||||
    self.episode_num = mp.Value(c_uint)
 | 
			
		||||
    self.runcondition = mp.Event()
 | 
			
		||||
    self.memory_queue = mp.Queue(maxsize = config['replay_skip'] + 1)
 | 
			
		||||
    self.logqueue = mp.Queue(maxsize = 1)
 | 
			
		||||
    self.memory_queue = mp.Queue(maxsize = config['replay_skip'] + 1) if memory is not None else None
 | 
			
		||||
    self.logqueue = mp.Queue(maxsize = 1) if logwriter is not None else None
 | 
			
		||||
    with self.episode_num.get_lock():
 | 
			
		||||
      self.episode_num.value = 1
 | 
			
		||||
    self.runner = mp.Process(target=envrun, 
 | 
			
		||||
| 
						 | 
				
			
			@ -66,7 +67,8 @@ class EnvironmentRun():
 | 
			
		|||
        self.memory.append(*self.memory_queue.get())
 | 
			
		||||
 | 
			
		||||
  def _get_reward_logger(self):
 | 
			
		||||
    return self.logqueue.get()
 | 
			
		||||
    if self.logqueue is not None:
 | 
			
		||||
      return self.logqueue.get()
 | 
			
		||||
 | 
			
		||||
  def terminate(self):
 | 
			
		||||
    self.runner.terminate()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue