If memory or logger does not exist, then don't create those shared memory structures

Brandon Rozek 2019-02-14 21:06:44 -05:00
parent 460d4c05c1
commit 19a859a4f6


@@ -9,7 +9,7 @@ def envrun(actor, env, episode_num, config, runcondition, iterations = 1, memory
   # Wait for signal to start running through the environment
   while runcondition.wait():
     # Start a logger to log the rewards
-    logger = rltorch.log.Logger()
+    logger = rltorch.log.Logger() if logqueue is not None else None
     for _ in range(iterations):
       action = actor.act(state)
       next_state, reward, done, _ = env.step(action)
@@ -33,7 +33,8 @@ def envrun(actor, env, episode_num, config, runcondition, iterations = 1, memory
         with episode_num.get_lock():
           episode_num.value += 1
-    logqueue.put(logger)
+    if logqueue is not None:
+      logqueue.put(logger)
 
 
 class EnvironmentRun():
   def __init__(self, env, actor, config, memory = None, logwriter = None, name = ""):
@@ -42,8 +43,8 @@ class EnvironmentRun():
     self.memory = memory
     self.episode_num = mp.Value(c_uint)
     self.runcondition = mp.Event()
-    self.memory_queue = mp.Queue(maxsize = config['replay_skip'] + 1)
-    self.logqueue = mp.Queue(maxsize = 1)
+    self.memory_queue = mp.Queue(maxsize = config['replay_skip'] + 1) if memory is not None else None
+    self.logqueue = mp.Queue(maxsize = 1) if logwriter is not None else None
     with self.episode_num.get_lock():
       self.episode_num.value = 1
     self.runner = mp.Process(target=envrun,
@@ -66,7 +67,8 @@
       self.memory.append(*self.memory_queue.get())
   def _get_reward_logger(self):
-    return self.logqueue.get()
+    if self.logqueue is not None:
+      return self.logqueue.get()
   def terminate(self):
     self.runner.terminate()
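
The net effect of the change: a multiprocessing queue is only allocated when something will actually consume it (a replay memory or a log writer), and every put/get on those queues is guarded by a None check, so the worker process never blocks on a bounded queue that nothing drains. Below is a minimal standalone sketch of that pattern under the same idea; the names (worker, Runner, transition_queue, log_queue) are illustrative and not rltorch's actual API.

# Minimal sketch, assuming a producer process and an optional consumer
# for each queue. Queues are created only when a consumer exists, and
# all queue operations are guarded by None checks.
import multiprocessing as mp

def worker(transition_queue, log_queue, iterations=3):
    episode_reward = 0.0
    for step in range(iterations):
        episode_reward += step
        # Only hand transitions over if a replay memory was configured.
        if transition_queue is not None:
            transition_queue.put(("state", step))
    # Only send a reward summary if a log writer was configured.
    if log_queue is not None:
        log_queue.put({"episode_reward": episode_reward})

class Runner:
    def __init__(self, use_memory=True, use_logger=False):
        # Create a queue only when something will drain it; a full,
        # never-read queue would eventually block the worker.
        self.transition_queue = mp.Queue(maxsize=4) if use_memory else None
        self.log_queue = mp.Queue(maxsize=1) if use_logger else None
        self.proc = mp.Process(target=worker,
                               args=(self.transition_queue, self.log_queue))

    def start(self):
        self.proc.start()

    def drain(self):
        # Mirror the guards on the consuming side.
        results = []
        if self.transition_queue is not None:
            for _ in range(3):
                results.append(self.transition_queue.get())
        if self.log_queue is not None:
            results.append(self.log_queue.get())
        return results

if __name__ == "__main__":
    runner = Runner(use_memory=True, use_logger=False)
    runner.start()
    print(runner.drain())  # transitions only; no log queue was created
    runner.proc.join()

Running the sketch with use_logger=False shows the worker finishing cleanly even though no log queue exists, which is the failure mode the None guards are meant to prevent.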