If memory or logger does not exist, then don't create those shared memory structures

This commit is contained in:
Brandon Rozek 2019-02-14 21:06:44 -05:00
parent 460d4c05c1
commit 19a859a4f6

View file

@ -9,7 +9,7 @@ def envrun(actor, env, episode_num, config, runcondition, iterations = 1, memory
# Wait for signal to start running through the environment # Wait for signal to start running through the environment
while runcondition.wait(): while runcondition.wait():
# Start a logger to log the rewards # Start a logger to log the rewards
logger = rltorch.log.Logger() logger = rltorch.log.Logger() if logqueue is not None else None
for _ in range(iterations): for _ in range(iterations):
action = actor.act(state) action = actor.act(state)
next_state, reward, done, _ = env.step(action) next_state, reward, done, _ = env.step(action)
@ -33,7 +33,8 @@ def envrun(actor, env, episode_num, config, runcondition, iterations = 1, memory
with episode_num.get_lock(): with episode_num.get_lock():
episode_num.value += 1 episode_num.value += 1
logqueue.put(logger) if logqueue is not None:
logqueue.put(logger)
class EnvironmentRun(): class EnvironmentRun():
def __init__(self, env, actor, config, memory = None, logwriter = None, name = ""): def __init__(self, env, actor, config, memory = None, logwriter = None, name = ""):
@ -42,8 +43,8 @@ class EnvironmentRun():
self.memory = memory self.memory = memory
self.episode_num = mp.Value(c_uint) self.episode_num = mp.Value(c_uint)
self.runcondition = mp.Event() self.runcondition = mp.Event()
self.memory_queue = mp.Queue(maxsize = config['replay_skip'] + 1) self.memory_queue = mp.Queue(maxsize = config['replay_skip'] + 1) if memory is not None else None
self.logqueue = mp.Queue(maxsize = 1) self.logqueue = mp.Queue(maxsize = 1) if logwriter is not None else None
with self.episode_num.get_lock(): with self.episode_num.get_lock():
self.episode_num.value = 1 self.episode_num.value = 1
self.runner = mp.Process(target=envrun, self.runner = mp.Process(target=envrun,
@ -66,7 +67,8 @@ class EnvironmentRun():
self.memory.append(*self.memory_queue.get()) self.memory.append(*self.memory_queue.get())
def _get_reward_logger(self): def _get_reward_logger(self):
return self.logqueue.get() if self.logqueue is not None:
return self.logqueue.get()
def terminate(self): def terminate(self):
self.runner.terminate() self.runner.terminate()