Created documentation for memory module
parent 711c2e8dd1
commit 1cad98fcf9

4 changed files with 119 additions and 16 deletions
@@ -1,4 +1,8 @@
 Memory Structures
 =================
-.. automodule:: rltorch.memory
+.. autoclass:: rltorch.memory.ReplayMemory
+   :members:
+.. autoclass:: rltorch.memory.PrioritizedReplayMemory
+   :members:
+.. autoclass:: rltorch.memory.EpisodeMemory
    :members:
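The new autoclass directives pull the NumPy-style docstrings added in this commit into the rendered documentation. As a rough orientation only, a minimal Sphinx configuration sketch that would support them is shown below; the project's actual conf.py is not part of this diff, so the extension list is an assumption:

# conf.py (sketch, not the project's actual configuration)
extensions = [
    'sphinx.ext.autodoc',    # resolves the .. autoclass:: directives
    'sphinx.ext.napoleon',   # renders the NumPy-style Parameters/Returns sections
    'sphinx.ext.mathjax',    # renders the :math: roles used in the docstrings
]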
@@ -5,22 +5,43 @@ Transition = namedtuple('Transition',
     ('state', 'action', 'reward', 'next_state', 'done'))

 class EpisodeMemory(object):
+    """
+    Memory structure that stores an entire episode and
+    the observation's associated log-based probabilities.
+    """
     def __init__(self):
         self.memory = []
         self.log_probs = []

     def append(self, *args):
-        """Saves a transition."""
+        """
+        Adds a transition to the memory.
+
+        Parameters
+        ----------
+        *args
+            The state, action, reward, next_state, done tuple
+        """
         self.memory.append(Transition(*args))

     def append_log_probs(self, logprob):
+        """
+        Adds a log-based probability to the observation.
+        """
         self.log_probs.append(logprob)

     def clear(self):
+        """
+        Clears the transitions and log-based probabilities.
+        """
         self.memory.clear()
         self.log_probs.clear()

     def recall(self):
+        """
+        Return a list of the transitions with their
+        associated log-based probabilities.
+        """
         if len(self.memory) != len(self.log_probs):
             raise ValueError("Memory and recorded log probabilities must be the same length.")
         return list(zip(*tuple(zip(*self.memory)), self.log_probs))
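For orientation, here is a minimal usage sketch of the EpisodeMemory API documented above. The state, action, and log-probability values are made-up placeholders; in practice they come from an environment step and from the policy distribution's log_prob():

from rltorch.memory import EpisodeMemory

memory = EpisodeMemory()

# Placeholder step data for a single transition of an episode.
state, action, reward, next_state, done = [0.0], 1, 1.0, [0.1], False
log_prob = -0.69

memory.append(state, action, reward, next_state, done)
memory.append_log_probs(log_prob)

# recall() pairs each transition with its log-probability and raises
# ValueError if the two lists have drifted out of sync.
for s, a, r, s_next, d, lp in memory.recall():
    print(s, a, r, s_next, d, lp)

memory.clear()  # drops both the transitions and the log-probabilities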
@@ -147,7 +147,9 @@ class MinSegmentTree(SegmentTree):

 class PrioritizedReplayMemory(ReplayMemory):
     def __init__(self, capacity, alpha):
-        """Create Prioritized Replay buffer.
+        """
+        Create Prioritized Replay buffer.
+
         Parameters
         ----------
         capacity: int
@@ -156,9 +158,6 @@ class PrioritizedReplayMemory(ReplayMemory):
         alpha: float
             how much prioritization is used
             (0 - no prioritization, 1 - full prioritization)
-        See Also
-        --------
-        ReplayBuffer.__init__
         """
         super(PrioritizedReplayMemory, self).__init__(capacity)
         assert alpha >= 0
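A brief construction sketch for the constructor documented above; the capacity and alpha values are arbitrary. Since stored priorities are raised to the power alpha (as the append hunk below shows), alpha = 0 collapses every priority to 1 and sampling becomes uniform, while alpha = 1 uses the raw priorities:

from rltorch.memory import PrioritizedReplayMemory

# Illustrative hyperparameters
memory = PrioritizedReplayMemory(capacity=10000, alpha=0.6)

# Sampling weight of a transition with priority p is p ** alpha:
#   alpha = 0.0 -> p ** 0 = 1   (no prioritization, uniform sampling)
#   alpha = 1.0 -> p ** 1 = p   (full prioritization)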
@@ -173,7 +172,14 @@ class PrioritizedReplayMemory(ReplayMemory):
         self._max_priority = 1.0

     def append(self, *args, **kwargs):
-        """See ReplayBuffer.store_effect"""
+        """
+        Adds a transition to the buffer and adds an initial prioritization.
+
+        Parameters
+        ----------
+        *args
+            The state, action, reward, next_state, done tuple
+        """
         idx = self.position
         super().append(*args, **kwargs)
         self._it_sum[idx] = self._max_priority ** self._alpha
@@ -191,10 +197,11 @@ class PrioritizedReplayMemory(ReplayMemory):
         return res

     def sample(self, batch_size, beta):
-        """Sample a batch of experiences.
-        compared to ReplayBuffer.sample
-        it also returns importance weights and idxes
+        """
+        Sample a batch of experiences,
+        while returning importance weights and idxes
         of sampled experiences.
+
         Parameters
         ----------
         batch_size: int
@@ -202,6 +209,7 @@ class PrioritizedReplayMemory(ReplayMemory):
         beta: float
             To what degree to use importance weights
             (0 - no corrections, 1 - full correction)
+
         Returns
         -------
         weights: np.array
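Continuing the construction sketch above, and assuming the buffer has already been filled with at least a batch worth of transitions, the call itself looks like this. The batch layout beyond the documented weights and idxes is not spelled out in this diff, and the beta value is illustrative; beta is typically annealed toward 1 over training:

# beta close to 0 barely corrects for the sampling bias, 1.0 fully corrects.
batch = memory.sample(batch_size=64, beta=0.4)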
@@ -232,6 +240,32 @@ class PrioritizedReplayMemory(ReplayMemory):
         return batch

     def sample_n_steps(self, batch_size, steps, beta):
+        r"""
+        Sample a batch of sequential experiences,
+        while returning importance weights and idxes
+        of sampled experiences.
+
+        Parameters
+        ----------
+        batch_size: int
+            How many transitions to sample.
+        beta: float
+            To what degree to use importance weights
+            (0 - no corrections, 1 - full correction)
+
+        Notes
+        -----
+        The number of batches sampled is :math:`\lfloor\frac{batch\_size}{steps}\rfloor`.
+
+        Returns
+        -------
+        weights: np.array
+            Array of shape (batch_size,) and dtype np.float32
+            denoting importance weight of each sampled transition
+        idxes: np.array
+            Array of shape (batch_size,) and dtype np.int32
+            indexes in buffer of sampled experiences
+        """
         assert beta > 0

         sample_size = batch_size // steps
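A quick numeric illustration of the Notes section above: because of the integer division, a batch_size that is not a multiple of steps is silently rounded down.

batch_size, steps = 64, 5
sample_size = batch_size // steps   # floor(64 / 5) = 12 sequences
print(sample_size * steps)          # 60 transitions are drawn, not 64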
@@ -262,9 +296,11 @@ class PrioritizedReplayMemory(ReplayMemory):

     @jit(forceobj = True)
     def update_priorities(self, idxes, priorities):
-        """Update priorities of sampled transitions.
+        """
+        Update priorities of sampled transitions.
         sets priority of transition at index idxes[i] in buffer
         to priorities[i].
+
         Parameters
         ----------
         idxes: [int]
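A sketch of the usual prioritized-replay update cycle these methods support. The new priorities here are random stand-ins for the |TD error| of each sampled transition, and the idxes are stand-ins for the indexes that sample() reports back; the batch layout itself is not shown in this diff, so both are placeholders:

import numpy as np
from rltorch.memory import PrioritizedReplayMemory

memory = PrioritizedReplayMemory(capacity=100, alpha=0.6)
for i in range(100):
    # Toy transitions: state, action, reward, next_state, done
    memory.append([float(i)], 0, 0.0, [float(i + 1)], False)

batch = memory.sample(batch_size=10, beta=0.4)

# In a real agent: new_priorities = |TD error| of each sampled transition,
# and idxes are the buffer indexes returned alongside the batch.
idxes = list(range(10))
new_priorities = np.abs(np.random.randn(10)) + 1e-6  # keep priorities strictly positive
memory.update_priorities(idxes, new_priorities)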
@@ -4,21 +4,38 @@ import torch
 Transition = namedtuple('Transition',
     ('state', 'action', 'reward', 'next_state', 'done'))

-# Implements a Ring Buffer
 class ReplayMemory(object):
+    """
+    Creates a ring buffer of a fixed size.
+
+    Parameters
+    ----------
+    capacity : int
+        The maximum size of the buffer
+    """
     def __init__(self, capacity):
         self.capacity = capacity
         self.memory = []
         self.position = 0

     def append(self, *args):
-        """Saves a transition."""
+        """
+        Adds a transition to the buffer.
+
+        Parameters
+        ----------
+        *args
+            The state, action, reward, next_state, done tuple
+        """
         if len(self.memory) < self.capacity:
             self.memory.append(None)
         self.memory[self.position] = Transition(*args)
         self.position = (self.position + 1) % self.capacity

     def clear(self):
+        """
+        Clears the buffer.
+        """
         self.memory.clear()
         self.position = 0

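A short sketch of the ring-buffer behaviour documented above: once capacity is reached, append() overwrites the oldest entries because position wraps around with the modulo. The transition values are toy data:

from rltorch.memory import ReplayMemory

memory = ReplayMemory(capacity=3)
for i in range(5):
    # Toy transitions: state, action, reward, next_state, done
    memory.append([float(i)], 0, float(i), [float(i + 1)], False)

# Five appends into a capacity-3 buffer: slots 0 and 1 now hold
# transitions 3 and 4, and the buffer never grows past capacity.
print(len(memory.memory))      # 3
print(memory.memory[0].state)  # [3.0] -- the oldest slot was overwritten
print(memory.position)         # 2 -- next slot to be overwritten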
@@ -37,10 +54,35 @@ class ReplayMemory(object):


     def sample(self, batch_size):
+        """
+        Returns a random sample from the buffer.
+
+        Parameters
+        ----------
+        batch_size : int
+            The number of observations to sample.
+        """
         return random.sample(self.memory, batch_size)

     def sample_n_steps(self, batch_size, steps):
-        idxes = random.sample(range(len(self.memory) - steps), batch_size // steps)
+        r"""
+        Returns a random sample of sequential batches of size steps.
+
+        Notes
+        -----
+        The number of batches sampled is :math:`\lfloor\frac{batch\_size}{steps}\rfloor`.
+
+        Parameters
+        ----------
+        batch_size : int
+            The total number of observations to sample.
+        steps : int
+            The number of observations after the one selected to sample.
+        """
+        idxes = random.sample(
+            range(len(self.memory) - steps),
+            batch_size // steps
+        )
         step_idxes = []
         for i in idxes:
             step_idxes += range(i, i + steps)
@@ -56,10 +98,10 @@ class ReplayMemory(object):
         return value in self.memory

     def __getitem__(self, index):
-        return self.memory[index]
+        return self.memory[index % self.capacity]

     def __setitem__(self, index, value):
-        self.memory[index] = value
+        self.memory[index % self.capacity] = value

     def __reversed__(self):
         return reversed(self.memory)
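The switch to index % self.capacity above makes indexing wrap around instead of raising IndexError for positions past the end. A small sketch with toy data, assuming the buffer has been filled to capacity so the wrapped index always hits a stored entry:

from rltorch.memory import ReplayMemory

memory = ReplayMemory(capacity=3)
for i in range(3):
    memory.append([float(i)], 0, float(i), [float(i + 1)], False)

# Index 4 wraps to 4 % 3 == 1, so this returns the second stored
# transition rather than raising IndexError.
print(memory[4].state)   # [1.0]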