From 559efa38b08e956ff4c7f0417a451a635240487b Mon Sep 17 00:00:00 2001
From: Brandon Rozek
Date: Thu, 19 Sep 2019 07:57:39 -0400
Subject: [PATCH] Corrected for numba deprecation

Enable the ability to render out scenes to play back data
---
 rltorch/agents/QEPAgent.py                |  3 ++-
 rltorch/env/simulate.py                   | 11 +++++++++--
 rltorch/memory/PrioritizedReplayMemory.py | 18 +++++++++---------
 3 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/rltorch/agents/QEPAgent.py b/rltorch/agents/QEPAgent.py
index de203b0..a1ae0dd 100644
--- a/rltorch/agents/QEPAgent.py
+++ b/rltorch/agents/QEPAgent.py
@@ -34,7 +34,8 @@ class QEPAgent:
         self.value_net.model.to(self.value_net.device)
         self.policy_net.model.state_dict(checkpoint['policy'])
         self.policy_net.model.to(self.policy_net.device)
-        self.target_net.sync()
+        if self.target_value_net is not None:
+            self.target_net.sync()
 
     def fitness(self, policy_net, value_net, state_batch):
         batch_size = len(state_batch)
diff --git a/rltorch/env/simulate.py b/rltorch/env/simulate.py
index a792ccb..ec544cb 100644
--- a/rltorch/env/simulate.py
+++ b/rltorch/env/simulate.py
@@ -1,7 +1,8 @@
 from copy import deepcopy
 import rltorch
+import time
 
-def simulateEnvEps(env, actor, config, total_episodes = 1, memory = None, logger = None, name = ""):
+def simulateEnvEps(env, actor, config, total_episodes = 1, memory = None, logger = None, name = "", render = False):
     for episode in range(total_episodes):
         state = env.reset()
         done = False
@@ -9,6 +10,9 @@
         while not done:
             action = actor.act(state)
             next_state, reward, done, _ = env.step(action)
+            if render:
+                env.render()
+                time.sleep(0.01)
             episode_reward = episode_reward + reward
 
             if memory is not None:
@@ -24,7 +28,7 @@
 
 
 class EnvironmentRunSync():
-    def __init__(self, env, actor, config, memory = None, logwriter = None, name = ""):
+    def __init__(self, env, actor, config, memory = None, logwriter = None, name = "", render = False):
         self.env = env
         self.name = name
         self.actor = actor
@@ -34,6 +38,7 @@ class EnvironmentRunSync():
         self.episode_num = 1
         self.episode_reward = 0
         self.last_state = env.reset()
+        self.render = render
 
     def run(self, iterations):
         state = self.last_state
@@ -41,6 +46,8 @@ class EnvironmentRunSync():
         for _ in range(iterations):
             action = self.actor.act(state)
             next_state, reward, done, _ = self.env.step(action)
+            if self.render:
+                self.env.render()
             self.episode_reward += reward
 
             if self.memory is not None:
diff --git a/rltorch/memory/PrioritizedReplayMemory.py b/rltorch/memory/PrioritizedReplayMemory.py
index efa477e..00b1d6e 100644
--- a/rltorch/memory/PrioritizedReplayMemory.py
+++ b/rltorch/memory/PrioritizedReplayMemory.py
@@ -34,7 +34,7 @@ class SegmentTree(object):
         self._value = [neutral_element for _ in range(2 * capacity)]
         self._operation = operation
 
-    @jit
+    @jit(forceobj = True)
     def _reduce_helper(self, start, end, node, node_start, node_end):
         if start == node_start and end == node_end:
             return self._value[node]
@@ -50,7 +50,7 @@ class SegmentTree(object):
             self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
         )
 
-    @jit
+    @jit(forceobj = True)
     def reduce(self, start=0, end=None):
         """Returns result of applying `self.operation`
         to a contiguous subsequence of the array.
@@ -73,7 +73,7 @@ class SegmentTree(object):
         end -= 1
         return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
 
-    @jit
+    @jit(forceobj = True)
     def __setitem__(self, idx, val):
         # index of the leaf
         idx += self._capacity
@@ -86,7 +86,7 @@ class SegmentTree(object):
             )
             idx //= 2
 
-    @jit
+    @jit(forceobj = True)
     def __getitem__(self, idx):
         assert 0 <= idx < self._capacity
         return self._value[self._capacity + idx]
@@ -100,12 +100,12 @@ class SumSegmentTree(SegmentTree):
             neutral_element=0.0
         )
 
-    @jit
+    @jit(forceobj = True)
     def sum(self, start=0, end=None):
         """Returns arr[start] + ... + arr[end]"""
         return super(SumSegmentTree, self).reduce(start, end)
 
-    @jit
+    @jit(forceobj = True)
    def find_prefixsum_idx(self, prefixsum):
         """Find the highest index `i` in the array such that
             sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum
@@ -140,7 +140,7 @@ class MinSegmentTree(SegmentTree):
             neutral_element=float('inf')
         )
 
-    @jit
+    @jit(forceobj = True)
     def min(self, start=0, end=None):
         """Returns min(arr[start], ..., arr[end])"""
         return super(MinSegmentTree, self).reduce(start, end)
@@ -179,7 +179,7 @@ class PrioritizedReplayMemory(ReplayMemory):
         self._it_sum[idx] = self._max_priority ** self._alpha
         self._it_min[idx] = self._max_priority ** self._alpha
 
-    @jit
+    @jit(forceobj = True)
     def _sample_proportional(self, batch_size):
         res = []
         p_total = self._it_sum.sum(0, len(self.memory) - 1)
@@ -239,7 +239,7 @@ class PrioritizedReplayMemory(ReplayMemory):
         batch = list(zip(*encoded_sample, weights, idxes))
         return batch
 
-    @jit
+    @jit(forceobj = True)
     def update_priorities(self, idxes, priorities):
         """Update priorities of sampled transitions.
         sets priority of transition at index idxes[i] in buffer
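
Notes (not part of the patch):

On the numba side, a bare @jit falls back to object mode when a function cannot be compiled in nopython mode, and newer numba releases emit a deprecation warning for that implicit fallback; @jit(forceobj = True) requests object mode explicitly, keeping the previous behaviour without the warning.

Below is a minimal usage sketch for the new render flag, assuming the import path implied by rltorch/env/simulate.py. The Gym environment, the RandomActor helper, and the config key are illustrative assumptions only, not taken from this patch; the actor just needs an act(state) method, matching how simulateEnvEps calls it in the diff above.

# Illustrative usage sketch; names marked below as assumed are not part of the patch.
import gym
from rltorch.env.simulate import simulateEnvEps

class RandomActor:
    """Hypothetical stand-in for a trained policy: it just samples random actions."""
    def __init__(self, action_space):
        self.action_space = action_space

    def act(self, state):
        return self.action_space.sample()

env = gym.make("CartPole-v0")          # any Gym-style env whose render() opens a viewer
actor = RandomActor(env.action_space)
config = {'print_stat_n_eps': 1}       # assumed config key; adjust to your rltorch setup

# With render=True the patched simulateEnvEps calls env.render() after every step
# (followed by a 10 ms sleep), so the collected episodes can be watched live.
simulateEnvEps(env, actor, config, total_episodes=3, render=True)

EnvironmentRunSync takes the same render flag in its constructor and calls self.env.render() inside run() in the same way, just without the sleep.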