Added numba as a dependency and decorated the Prioiritzed Replay function
This commit is contained in:
parent
19a859a4f6
commit
2caf869fd6
3 changed files with 14 additions and 1 deletions
|
@ -29,3 +29,4 @@ termcolor==1.1.0
|
|||
torch==1.0.0
|
||||
urllib3==1.24.1
|
||||
Werkzeug==0.14.1
|
||||
numba==0.42.1
|
||||
|
|
|
@ -4,6 +4,7 @@ from .ReplayMemory import ReplayMemory
|
|||
import operator
|
||||
import random
|
||||
import numpy as np
|
||||
from numba import jit
|
||||
|
||||
class SegmentTree(object):
|
||||
def __init__(self, capacity, operation, neutral_element):
|
||||
|
@ -33,6 +34,7 @@ class SegmentTree(object):
|
|||
self._value = [neutral_element for _ in range(2 * capacity)]
|
||||
self._operation = operation
|
||||
|
||||
@jit
|
||||
def _reduce_helper(self, start, end, node, node_start, node_end):
|
||||
if start == node_start and end == node_end:
|
||||
return self._value[node]
|
||||
|
@ -48,6 +50,7 @@ class SegmentTree(object):
|
|||
self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
|
||||
)
|
||||
|
||||
@jit
|
||||
def reduce(self, start=0, end=None):
|
||||
"""Returns result of applying `self.operation`
|
||||
to a contiguous subsequence of the array.
|
||||
|
@ -70,6 +73,7 @@ class SegmentTree(object):
|
|||
end -= 1
|
||||
return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
|
||||
|
||||
@jit
|
||||
def __setitem__(self, idx, val):
|
||||
# index of the leaf
|
||||
idx += self._capacity
|
||||
|
@ -82,6 +86,7 @@ class SegmentTree(object):
|
|||
)
|
||||
idx //= 2
|
||||
|
||||
@jit
|
||||
def __getitem__(self, idx):
|
||||
assert 0 <= idx < self._capacity
|
||||
return self._value[self._capacity + idx]
|
||||
|
@ -95,10 +100,12 @@ class SumSegmentTree(SegmentTree):
|
|||
neutral_element=0.0
|
||||
)
|
||||
|
||||
@jit
|
||||
def sum(self, start=0, end=None):
|
||||
"""Returns arr[start] + ... + arr[end]"""
|
||||
return super(SumSegmentTree, self).reduce(start, end)
|
||||
|
||||
@jit
|
||||
def find_prefixsum_idx(self, prefixsum):
|
||||
"""Find the highest index `i` in the array such that
|
||||
sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum
|
||||
|
@ -133,6 +140,7 @@ class MinSegmentTree(SegmentTree):
|
|||
neutral_element=float('inf')
|
||||
)
|
||||
|
||||
@jit
|
||||
def min(self, start=0, end=None):
|
||||
"""Returns min(arr[start], ..., arr[end])"""
|
||||
return super(MinSegmentTree, self).reduce(start, end)
|
||||
|
@ -171,6 +179,7 @@ class PrioritizedReplayMemory(ReplayMemory):
|
|||
self._it_sum[idx] = self._max_priority ** self._alpha
|
||||
self._it_min[idx] = self._max_priority ** self._alpha
|
||||
|
||||
@jit
|
||||
def _sample_proportional(self, batch_size):
|
||||
res = []
|
||||
p_total = self._it_sum.sum(0, len(self.memory) - 1)
|
||||
|
@ -230,6 +239,7 @@ class PrioritizedReplayMemory(ReplayMemory):
|
|||
batch = list(zip(*encoded_sample, weights, idxes))
|
||||
return batch
|
||||
|
||||
@jit
|
||||
def update_priorities(self, idxes, priorities):
|
||||
"""Update priorities of sampled transitions.
|
||||
sets priority of transition at index idxes[i] in buffer
|
||||
|
|
|
@ -15,6 +15,8 @@ class ReplayMemory(object):
|
|||
"""Saves a transition."""
|
||||
if len(self.memory) < self.capacity:
|
||||
self.memory.append(None)
|
||||
if self.memory[self.position] is not None:
|
||||
del self.memory[self.position]
|
||||
self.memory[self.position] = Transition(*args)
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
|
||||
|
|
Loading…
Reference in a new issue