If memory or logger does not exist, then don't create those shared memory structures
This commit is contained in:
parent
460d4c05c1
commit
19a859a4f6
1 changed files with 7 additions and 5 deletions
|
@ -9,7 +9,7 @@ def envrun(actor, env, episode_num, config, runcondition, iterations = 1, memory
|
||||||
# Wait for signal to start running through the environment
|
# Wait for signal to start running through the environment
|
||||||
while runcondition.wait():
|
while runcondition.wait():
|
||||||
# Start a logger to log the rewards
|
# Start a logger to log the rewards
|
||||||
logger = rltorch.log.Logger()
|
logger = rltorch.log.Logger() if logqueue is not None else None
|
||||||
for _ in range(iterations):
|
for _ in range(iterations):
|
||||||
action = actor.act(state)
|
action = actor.act(state)
|
||||||
next_state, reward, done, _ = env.step(action)
|
next_state, reward, done, _ = env.step(action)
|
||||||
|
@ -33,7 +33,8 @@ def envrun(actor, env, episode_num, config, runcondition, iterations = 1, memory
|
||||||
with episode_num.get_lock():
|
with episode_num.get_lock():
|
||||||
episode_num.value += 1
|
episode_num.value += 1
|
||||||
|
|
||||||
logqueue.put(logger)
|
if logqueue is not None:
|
||||||
|
logqueue.put(logger)
|
||||||
|
|
||||||
class EnvironmentRun():
|
class EnvironmentRun():
|
||||||
def __init__(self, env, actor, config, memory = None, logwriter = None, name = ""):
|
def __init__(self, env, actor, config, memory = None, logwriter = None, name = ""):
|
||||||
|
@ -42,8 +43,8 @@ class EnvironmentRun():
|
||||||
self.memory = memory
|
self.memory = memory
|
||||||
self.episode_num = mp.Value(c_uint)
|
self.episode_num = mp.Value(c_uint)
|
||||||
self.runcondition = mp.Event()
|
self.runcondition = mp.Event()
|
||||||
self.memory_queue = mp.Queue(maxsize = config['replay_skip'] + 1)
|
self.memory_queue = mp.Queue(maxsize = config['replay_skip'] + 1) if memory is not None else None
|
||||||
self.logqueue = mp.Queue(maxsize = 1)
|
self.logqueue = mp.Queue(maxsize = 1) if logwriter is not None else None
|
||||||
with self.episode_num.get_lock():
|
with self.episode_num.get_lock():
|
||||||
self.episode_num.value = 1
|
self.episode_num.value = 1
|
||||||
self.runner = mp.Process(target=envrun,
|
self.runner = mp.Process(target=envrun,
|
||||||
|
@ -66,7 +67,8 @@ class EnvironmentRun():
|
||||||
self.memory.append(*self.memory_queue.get())
|
self.memory.append(*self.memory_queue.get())
|
||||||
|
|
||||||
def _get_reward_logger(self):
|
def _get_reward_logger(self):
|
||||||
return self.logqueue.get()
|
if self.logqueue is not None:
|
||||||
|
return self.logqueue.get()
|
||||||
|
|
||||||
def terminate(self):
|
def terminate(self):
|
||||||
self.runner.terminate()
|
self.runner.terminate()
|
||||||
|
|
Loading…
Reference in a new issue