Work towards simplifying ReplayMemory
This commit is contained in:
parent
c6172f309d
commit
cb87105305
1 changed files with 76 additions and 0 deletions
76
rltorch/memory/SimplifiedMemory.py
Normal file
76
rltorch/memory/SimplifiedMemory.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
from random import sample
|
||||
from collections import deque
|
||||
|
||||
class ReplayMemory:
    """
    Fixed-size FIFO buffer of transitions.

    Transitions are stored in a ``deque`` bounded by ``capacity``; once the
    buffer is full, appending a new transition discards the oldest one.

    Parameters
    ----------
    capacity : int
      The maximum size of the buffer
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = deque(maxlen=capacity)

    def append(self, **kwargs):
        """
        Adds a transition to the buffer.

        Parameters
        ----------
        **kwargs
          The fields of the transition (e.g. state, action, reward,
          next_state, done). They are stored together as a single dict.
        """
        self.memory.append(kwargs)

    def clear(self):
        """
        Clears the buffer.
        """
        self.memory.clear()

    def _encode_sample(self, indices):
        """
        Returns the stored transitions at the given positions, in order.

        Parameters
        ----------
        indices : iterable of int
          Positions in the buffer to look up.
        """
        return [self.memory[i] for i in indices]

    def sample(self, batch_size):
        """
        Returns a random sample from the buffer.

        Sampling is without replacement, so ``batch_size`` may not exceed
        the number of stored transitions.

        Parameters
        ----------
        batch_size : int
          The number of observations to sample.
        """
        return sample(self.memory, batch_size)

    def sample_n_steps(self, batch_size, steps):
        r"""
        Returns a random sample of sequential batches of size steps.

        Notes
        -----
        The number of batches sampled is :math:`\lfloor\frac{batch\_size}{steps}\rfloor`.
        Starting positions are drawn without replacement; the runs that
        follow them may overlap.

        Parameters
        ----------
        batch_size : int
          The total number of observations to sample.
        steps : int
          The length of each run of consecutive observations, counting the
          randomly selected starting observation.

        Raises
        ------
        ValueError
          If ``steps`` is not positive, or if more runs are requested than
          there are valid starting positions.
        """
        if steps < 1:
            raise ValueError("steps must be a positive integer")
        # A run starting at i covers i .. i + steps - 1, so the last valid
        # start is len - steps (inclusive). max() keeps the range empty,
        # rather than negative, when the buffer holds fewer than `steps` items.
        num_starts = max(len(self.memory) - steps + 1, 0)
        starts = sample(range(num_starts), batch_size // steps)
        step_idxes = []
        for start in starts:
            step_idxes.extend(range(start, start + steps))
        return self._encode_sample(step_idxes)

    def __len__(self):
        return len(self.memory)
|
Loading…
Reference in a new issue