From 2caf869fd6cc25bcaca19e095f24f553b25ec689 Mon Sep 17 00:00:00 2001 From: Brandon Rozek Date: Thu, 14 Feb 2019 21:42:31 -0500 Subject: [PATCH] Added numba as a dependency and decorated the Prioiritzed Replay function --- requirements.txt | 3 ++- rltorch/memory/PrioritizedReplayMemory.py | 10 ++++++++++ rltorch/memory/ReplayMemory.py | 2 ++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index fb01538..9f97f26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,5 @@ tensorflow==1.12.0 termcolor==1.1.0 torch==1.0.0 urllib3==1.24.1 -Werkzeug==0.14.1 \ No newline at end of file +Werkzeug==0.14.1 +numba==0.42.1 diff --git a/rltorch/memory/PrioritizedReplayMemory.py b/rltorch/memory/PrioritizedReplayMemory.py index 4f738bf..efa477e 100644 --- a/rltorch/memory/PrioritizedReplayMemory.py +++ b/rltorch/memory/PrioritizedReplayMemory.py @@ -4,6 +4,7 @@ from .ReplayMemory import ReplayMemory import operator import random import numpy as np +from numba import jit class SegmentTree(object): def __init__(self, capacity, operation, neutral_element): @@ -33,6 +34,7 @@ class SegmentTree(object): self._value = [neutral_element for _ in range(2 * capacity)] self._operation = operation + @jit def _reduce_helper(self, start, end, node, node_start, node_end): if start == node_start and end == node_end: return self._value[node] @@ -48,6 +50,7 @@ class SegmentTree(object): self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end) ) + @jit def reduce(self, start=0, end=None): """Returns result of applying `self.operation` to a contiguous subsequence of the array. @@ -70,6 +73,7 @@ class SegmentTree(object): end -= 1 return self._reduce_helper(start, end, 1, 0, self._capacity - 1) + @jit def __setitem__(self, idx, val): # index of the leaf idx += self._capacity @@ -82,6 +86,7 @@ class SegmentTree(object): ) idx //= 2 + @jit def __getitem__(self, idx): assert 0 <= idx < self._capacity return self._value[self._capacity + idx] @@ -95,10 +100,12 @@ class SumSegmentTree(SegmentTree): neutral_element=0.0 ) + @jit def sum(self, start=0, end=None): """Returns arr[start] + ... + arr[end]""" return super(SumSegmentTree, self).reduce(start, end) + @jit def find_prefixsum_idx(self, prefixsum): """Find the highest index `i` in the array such that sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum @@ -133,6 +140,7 @@ class MinSegmentTree(SegmentTree): neutral_element=float('inf') ) + @jit def min(self, start=0, end=None): """Returns min(arr[start], ..., arr[end])""" return super(MinSegmentTree, self).reduce(start, end) @@ -171,6 +179,7 @@ class PrioritizedReplayMemory(ReplayMemory): self._it_sum[idx] = self._max_priority ** self._alpha self._it_min[idx] = self._max_priority ** self._alpha + @jit def _sample_proportional(self, batch_size): res = [] p_total = self._it_sum.sum(0, len(self.memory) - 1) @@ -230,6 +239,7 @@ class PrioritizedReplayMemory(ReplayMemory): batch = list(zip(*encoded_sample, weights, idxes)) return batch + @jit def update_priorities(self, idxes, priorities): """Update priorities of sampled transitions. sets priority of transition at index idxes[i] in buffer diff --git a/rltorch/memory/ReplayMemory.py b/rltorch/memory/ReplayMemory.py index 367b9c9..34e3571 100644 --- a/rltorch/memory/ReplayMemory.py +++ b/rltorch/memory/ReplayMemory.py @@ -15,6 +15,8 @@ class ReplayMemory(object): """Saves a transition.""" if len(self.memory) < self.capacity: self.memory.append(None) + if self.memory[self.position] is not None: + del self.memory[self.position] self.memory[self.position] = Transition(*args) self.position = (self.position + 1) % self.capacity