Created documentation for memory module

Brandon Rozek 2020-03-20 19:31:09 -04:00
parent 711c2e8dd1
commit 1cad98fcf9
4 changed files with 119 additions and 16 deletions

View file

@@ -1,4 +1,8 @@
 Memory Structures
 =================
-.. automodule:: rltorch.memory
+.. autoclass:: rltorch.memory.ReplayMemory
+   :members:
+.. autoclass:: rltorch.memory.PrioritizedReplayMemory
+   :members:
+.. autoclass:: rltorch.memory.EpisodeMemory
    :members:

View file

@@ -5,22 +5,43 @@ Transition = namedtuple('Transition',
     ('state', 'action', 'reward', 'next_state', 'done'))
 
 class EpisodeMemory(object):
+    """
+    Memory structure that stores an entire episode and
+    the observations' associated log-based probabilities.
+    """
     def __init__(self):
         self.memory = []
         self.log_probs = []
 
     def append(self, *args):
-        """Saves a transition."""
+        """
+        Adds a transition to the memory.
+
+        Parameters
+        ----------
+        *args
+            The state, action, reward, next_state, done tuple
+        """
         self.memory.append(Transition(*args))
 
     def append_log_probs(self, logprob):
+        """
+        Adds a log-based probability to the observation.
+        """
         self.log_probs.append(logprob)
 
     def clear(self):
+        """
+        Clears the transitions and log-based probabilities.
+        """
         self.memory.clear()
         self.log_probs.clear()
 
     def recall(self):
+        """
+        Return a list of the transitions with their
+        associated log-based probabilities.
+        """
         if len(self.memory) != len(self.log_probs):
             raise ValueError("Memory and recorded log probabilities must be the same length.")
         return list(zip(*tuple(zip(*self.memory)), self.log_probs))
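As a usage note, here is a minimal sketch of the EpisodeMemory API documented above. It is not part of the commit: the import path follows the documentation page, and the rollout values are placeholder tensors standing in for a real environment.

import torch
from rltorch.memory import EpisodeMemory

memory = EpisodeMemory()

# Hypothetical three-step rollout with placeholder observations.
for step in range(3):
    state, action, reward, next_state, done = (
        torch.zeros(4), 0, 1.0, torch.zeros(4), step == 2
    )
    memory.append(state, action, reward, next_state, done)
    memory.append_log_probs(torch.log(torch.tensor(0.5)))  # log pi(a|s)

# recall() pairs each transition with its log probability, yielding
# (state, action, reward, next_state, done, log_prob) tuples. It raises
# ValueError if the two lists have drifted out of sync.
for state, action, reward, next_state, done, log_prob in memory.recall():
    pass  # e.g. accumulate -log_prob * return for a REINFORCE-style loss

memory.clear()  # drops both the transitions and the log probabilities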

View file

@@ -147,7 +147,9 @@ class MinSegmentTree(SegmentTree):
 
 class PrioritizedReplayMemory(ReplayMemory):
     def __init__(self, capacity, alpha):
-        """Create Prioritized Replay buffer.
+        """
+        Create Prioritized Replay buffer.
+
         Parameters
         ----------
         capacity: int
@@ -156,9 +158,6 @@ class PrioritizedReplayMemory(ReplayMemory):
         alpha: float
             how much prioritization is used
             (0 - no prioritization, 1 - full prioritization)
-        See Also
-        --------
-        ReplayBuffer.__init__
         """
         super(PrioritizedReplayMemory, self).__init__(capacity)
         assert alpha >= 0
@@ -173,7 +172,14 @@ class PrioritizedReplayMemory(ReplayMemory):
         self._max_priority = 1.0
 
     def append(self, *args, **kwargs):
-        """See ReplayBuffer.store_effect"""
+        """
+        Adds a transition to the buffer and adds an initial prioritization.
+
+        Parameters
+        ----------
+        *args
+            The state, action, reward, next_state, done tuple
+        """
         idx = self.position
         super().append(*args, **kwargs)
         self._it_sum[idx] = self._max_priority ** self._alpha
@@ -191,10 +197,11 @@ class PrioritizedReplayMemory(ReplayMemory):
         return res
 
     def sample(self, batch_size, beta):
-        """Sample a batch of experiences.
-        compared to ReplayBuffer.sample
-        it also returns importance weights and idxes
+        """
+        Sample a batch of experiences,
+        while returning importance weights and idxes
         of sampled experiences.
+
         Parameters
         ----------
         batch_size: int
@@ -202,6 +209,7 @@ class PrioritizedReplayMemory(ReplayMemory):
         beta: float
             To what degree to use importance weights
             (0 - no corrections, 1 - full correction)
+
         Returns
         -------
         weights: np.array
@@ -232,6 +240,32 @@ class PrioritizedReplayMemory(ReplayMemory):
         return batch
 
     def sample_n_steps(self, batch_size, steps, beta):
+        r"""
+        Sample a batch of sequential experiences,
+        while returning importance weights and idxes
+        of sampled experiences.
+
+        Parameters
+        ----------
+        batch_size: int
+            How many transitions to sample.
+        beta: float
+            To what degree to use importance weights
+            (0 - no corrections, 1 - full correction)
+
+        Notes
+        -----
+        The number of batches sampled is :math:`\lfloor\frac{batch\_size}{steps}\rfloor`.
+
+        Returns
+        -------
+        weights: np.array
+            Array of shape (batch_size,) and dtype np.float32
+            denoting importance weight of each sampled transition
+        idxes: np.array
+            Array of shape (batch_size,) and dtype np.int32
+            indexes in buffer of sampled experiences
+        """
         assert beta > 0
 
         sample_size = batch_size // steps
@@ -262,9 +296,11 @@ class PrioritizedReplayMemory(ReplayMemory):
 
     @jit(forceobj = True)
     def update_priorities(self, idxes, priorities):
-        """Update priorities of sampled transitions.
+        """
+        Update priorities of sampled transitions.
+
         sets priority of transition at index idxes[i] in buffer
         to priorities[i].
         Parameters
         ----------
         idxes: [int]
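The docstrings above imply the usual prioritized-replay training loop. A hedged sketch follows; the diff does not show the body of sample, so the batch unpacking below is an assumption, and the TD errors are random stand-ins for values a real agent would compute.

import numpy as np
from rltorch.memory import PrioritizedReplayMemory

memory = PrioritizedReplayMemory(capacity=10000, alpha=0.6)
# ... fill via memory.append(state, action, reward, next_state, done) ...

batch = memory.sample(batch_size=32, beta=0.4)
# Assumed layout: transition fields followed by the importance weights
# and buffer indexes promised in the Returns section above.
*fields, weights, idxes = zip(*batch)

td_errors = np.random.rand(32)  # stand-in for per-transition TD errors
# Feed error magnitudes back as priorities; the small epsilon keeps
# every priority strictly positive.
memory.update_priorities(idxes, np.abs(td_errors) + 1e-6)

For sample_n_steps, note that only batch_size // steps starting indexes are drawn, so a batch_size that is not a multiple of steps is effectively rounded down, as its Notes section states.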

View file

@@ -4,21 +4,38 @@ import torch
 
 Transition = namedtuple('Transition',
     ('state', 'action', 'reward', 'next_state', 'done'))
 
-# Implements a Ring Buffer
 class ReplayMemory(object):
+    """
+    Creates a ring buffer of a fixed size.
+
+    Parameters
+    ----------
+    capacity : int
+        The maximum size of the buffer
+    """
     def __init__(self, capacity):
         self.capacity = capacity
         self.memory = []
         self.position = 0
 
     def append(self, *args):
-        """Saves a transition."""
+        """
+        Adds a transition to the buffer.
+
+        Parameters
+        ----------
+        *args
+            The state, action, reward, next_state, done tuple
+        """
         if len(self.memory) < self.capacity:
             self.memory.append(None)
         self.memory[self.position] = Transition(*args)
         self.position = (self.position + 1) % self.capacity
 
     def clear(self):
+        """
+        Clears the buffer.
+        """
         self.memory.clear()
         self.position = 0
@@ -37,10 +54,35 @@ class ReplayMemory(object):
 
     def sample(self, batch_size):
+        """
+        Returns a random sample from the buffer.
+
+        Parameters
+        ----------
+        batch_size : int
+            The number of observations to sample.
+        """
         return random.sample(self.memory, batch_size)
 
     def sample_n_steps(self, batch_size, steps):
-        idxes = random.sample(range(len(self.memory) - steps), batch_size // steps)
+        r"""
+        Returns a random sample of sequential batches of size steps.
+
+        Notes
+        -----
+        The number of batches sampled is :math:`\lfloor\frac{batch\_size}{steps}\rfloor`.
+
+        Parameters
+        ----------
+        batch_size : int
+            The total number of observations to sample.
+        steps : int
+            The number of sequential observations to sample, starting
+            at each selected index.
+        """
+        idxes = random.sample(
+            range(len(self.memory) - steps),
+            batch_size // steps
+        )
         step_idxes = []
         for i in idxes:
             step_idxes += range(i, i + steps)
@@ -56,10 +98,10 @@ class ReplayMemory(object):
         return value in self.memory
 
     def __getitem__(self, index):
-        return self.memory[index]
+        return self.memory[index % self.capacity]
 
     def __setitem__(self, index, value):
-        self.memory[index] = value
+        self.memory[index % self.capacity] = value
 
     def __reversed__(self):
         return reversed(self.memory)
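Finally, a small sketch of the ring-buffer behavior these docstrings describe, including the wrap-around indexing added at the bottom of this diff. The transition values are placeholders.

from rltorch.memory import ReplayMemory

memory = ReplayMemory(capacity=3)

# Four appends into a three-slot buffer: the fourth write wraps to
# position 0 and overwrites the oldest transition.
for i in range(4):
    memory.append(i, 0, 0.0, i + 1, False)

print(memory.position)  # -> 1 (the next write lands in slot 1)
print(memory[0].state)  # -> 3 (slot 0 now holds the fourth transition)

# With the new modulo arithmetic, an out-of-range index wraps around
# instead of raising IndexError once the buffer is full:
print(memory[4].state)  # -> 1 (4 % 3 == 1)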