Corrected for numba deprecation
Enable the ability to render out scenes to play back data
This commit is contained in:
parent
a99ca66b4f
commit
559efa38b0
3 changed files with 20 additions and 12 deletions
|
@ -34,7 +34,8 @@ class QEPAgent:
|
||||||
self.value_net.model.to(self.value_net.device)
|
self.value_net.model.to(self.value_net.device)
|
||||||
self.policy_net.model.state_dict(checkpoint['policy'])
|
self.policy_net.model.state_dict(checkpoint['policy'])
|
||||||
self.policy_net.model.to(self.policy_net.device)
|
self.policy_net.model.to(self.policy_net.device)
|
||||||
self.target_net.sync()
|
if self.target_value_net is not None:
|
||||||
|
self.target_net.sync()
|
||||||
|
|
||||||
def fitness(self, policy_net, value_net, state_batch):
|
def fitness(self, policy_net, value_net, state_batch):
|
||||||
batch_size = len(state_batch)
|
batch_size = len(state_batch)
|
||||||
|
|
11
rltorch/env/simulate.py
vendored
11
rltorch/env/simulate.py
vendored
|
@ -1,7 +1,8 @@
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import rltorch
|
import rltorch
|
||||||
|
import time
|
||||||
|
|
||||||
def simulateEnvEps(env, actor, config, total_episodes = 1, memory = None, logger = None, name = ""):
|
def simulateEnvEps(env, actor, config, total_episodes = 1, memory = None, logger = None, name = "", render = False):
|
||||||
for episode in range(total_episodes):
|
for episode in range(total_episodes):
|
||||||
state = env.reset()
|
state = env.reset()
|
||||||
done = False
|
done = False
|
||||||
|
@ -9,6 +10,9 @@ def simulateEnvEps(env, actor, config, total_episodes = 1, memory = None, logger
|
||||||
while not done:
|
while not done:
|
||||||
action = actor.act(state)
|
action = actor.act(state)
|
||||||
next_state, reward, done, _ = env.step(action)
|
next_state, reward, done, _ = env.step(action)
|
||||||
|
if render:
|
||||||
|
env.render()
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
episode_reward = episode_reward + reward
|
episode_reward = episode_reward + reward
|
||||||
if memory is not None:
|
if memory is not None:
|
||||||
|
@ -24,7 +28,7 @@ def simulateEnvEps(env, actor, config, total_episodes = 1, memory = None, logger
|
||||||
|
|
||||||
|
|
||||||
class EnvironmentRunSync():
|
class EnvironmentRunSync():
|
||||||
def __init__(self, env, actor, config, memory = None, logwriter = None, name = ""):
|
def __init__(self, env, actor, config, memory = None, logwriter = None, name = "", render = False):
|
||||||
self.env = env
|
self.env = env
|
||||||
self.name = name
|
self.name = name
|
||||||
self.actor = actor
|
self.actor = actor
|
||||||
|
@ -34,6 +38,7 @@ class EnvironmentRunSync():
|
||||||
self.episode_num = 1
|
self.episode_num = 1
|
||||||
self.episode_reward = 0
|
self.episode_reward = 0
|
||||||
self.last_state = env.reset()
|
self.last_state = env.reset()
|
||||||
|
self.render = render
|
||||||
|
|
||||||
def run(self, iterations):
|
def run(self, iterations):
|
||||||
state = self.last_state
|
state = self.last_state
|
||||||
|
@ -41,6 +46,8 @@ class EnvironmentRunSync():
|
||||||
for _ in range(iterations):
|
for _ in range(iterations):
|
||||||
action = self.actor.act(state)
|
action = self.actor.act(state)
|
||||||
next_state, reward, done, _ = self.env.step(action)
|
next_state, reward, done, _ = self.env.step(action)
|
||||||
|
if self.render:
|
||||||
|
self.env.render()
|
||||||
|
|
||||||
self.episode_reward += reward
|
self.episode_reward += reward
|
||||||
if self.memory is not None:
|
if self.memory is not None:
|
||||||
|
|
|
@ -34,7 +34,7 @@ class SegmentTree(object):
|
||||||
self._value = [neutral_element for _ in range(2 * capacity)]
|
self._value = [neutral_element for _ in range(2 * capacity)]
|
||||||
self._operation = operation
|
self._operation = operation
|
||||||
|
|
||||||
@jit
|
@jit(forceobj = True)
|
||||||
def _reduce_helper(self, start, end, node, node_start, node_end):
|
def _reduce_helper(self, start, end, node, node_start, node_end):
|
||||||
if start == node_start and end == node_end:
|
if start == node_start and end == node_end:
|
||||||
return self._value[node]
|
return self._value[node]
|
||||||
|
@ -50,7 +50,7 @@ class SegmentTree(object):
|
||||||
self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
|
self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
|
||||||
)
|
)
|
||||||
|
|
||||||
@jit
|
@jit(forceobj = True)
|
||||||
def reduce(self, start=0, end=None):
|
def reduce(self, start=0, end=None):
|
||||||
"""Returns result of applying `self.operation`
|
"""Returns result of applying `self.operation`
|
||||||
to a contiguous subsequence of the array.
|
to a contiguous subsequence of the array.
|
||||||
|
@ -73,7 +73,7 @@ class SegmentTree(object):
|
||||||
end -= 1
|
end -= 1
|
||||||
return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
|
return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
|
||||||
|
|
||||||
@jit
|
@jit(forceobj = True)
|
||||||
def __setitem__(self, idx, val):
|
def __setitem__(self, idx, val):
|
||||||
# index of the leaf
|
# index of the leaf
|
||||||
idx += self._capacity
|
idx += self._capacity
|
||||||
|
@ -86,7 +86,7 @@ class SegmentTree(object):
|
||||||
)
|
)
|
||||||
idx //= 2
|
idx //= 2
|
||||||
|
|
||||||
@jit
|
@jit(forceobj = True)
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
assert 0 <= idx < self._capacity
|
assert 0 <= idx < self._capacity
|
||||||
return self._value[self._capacity + idx]
|
return self._value[self._capacity + idx]
|
||||||
|
@ -100,12 +100,12 @@ class SumSegmentTree(SegmentTree):
|
||||||
neutral_element=0.0
|
neutral_element=0.0
|
||||||
)
|
)
|
||||||
|
|
||||||
@jit
|
@jit(forceobj = True)
|
||||||
def sum(self, start=0, end=None):
|
def sum(self, start=0, end=None):
|
||||||
"""Returns arr[start] + ... + arr[end]"""
|
"""Returns arr[start] + ... + arr[end]"""
|
||||||
return super(SumSegmentTree, self).reduce(start, end)
|
return super(SumSegmentTree, self).reduce(start, end)
|
||||||
|
|
||||||
@jit
|
@jit(forceobj = True)
|
||||||
def find_prefixsum_idx(self, prefixsum):
|
def find_prefixsum_idx(self, prefixsum):
|
||||||
"""Find the highest index `i` in the array such that
|
"""Find the highest index `i` in the array such that
|
||||||
sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum
|
sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum
|
||||||
|
@ -140,7 +140,7 @@ class MinSegmentTree(SegmentTree):
|
||||||
neutral_element=float('inf')
|
neutral_element=float('inf')
|
||||||
)
|
)
|
||||||
|
|
||||||
@jit
|
@jit(forceobj = True)
|
||||||
def min(self, start=0, end=None):
|
def min(self, start=0, end=None):
|
||||||
"""Returns min(arr[start], ..., arr[end])"""
|
"""Returns min(arr[start], ..., arr[end])"""
|
||||||
return super(MinSegmentTree, self).reduce(start, end)
|
return super(MinSegmentTree, self).reduce(start, end)
|
||||||
|
@ -179,7 +179,7 @@ class PrioritizedReplayMemory(ReplayMemory):
|
||||||
self._it_sum[idx] = self._max_priority ** self._alpha
|
self._it_sum[idx] = self._max_priority ** self._alpha
|
||||||
self._it_min[idx] = self._max_priority ** self._alpha
|
self._it_min[idx] = self._max_priority ** self._alpha
|
||||||
|
|
||||||
@jit
|
@jit(forceobj = True)
|
||||||
def _sample_proportional(self, batch_size):
|
def _sample_proportional(self, batch_size):
|
||||||
res = []
|
res = []
|
||||||
p_total = self._it_sum.sum(0, len(self.memory) - 1)
|
p_total = self._it_sum.sum(0, len(self.memory) - 1)
|
||||||
|
@ -239,7 +239,7 @@ class PrioritizedReplayMemory(ReplayMemory):
|
||||||
batch = list(zip(*encoded_sample, weights, idxes))
|
batch = list(zip(*encoded_sample, weights, idxes))
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
@jit
|
@jit(forceobj = True)
|
||||||
def update_priorities(self, idxes, priorities):
|
def update_priorities(self, idxes, priorities):
|
||||||
"""Update priorities of sampled transitions.
|
"""Update priorities of sampled transitions.
|
||||||
sets priority of transition at index idxes[i] in buffer
|
sets priority of transition at index idxes[i] in buffer
|
||||||
|
|
Loading…
Reference in a new issue