Implemented Schedulers and Prioritized Replay
This commit is contained in:
parent
8c78f47c0c
commit
013d40a4f9
10 changed files with 348 additions and 12 deletions
|
@ -4,5 +4,6 @@ from . import env
|
||||||
from . import memory
|
from . import memory
|
||||||
from . import network
|
from . import network
|
||||||
from . import mp
|
from . import mp
|
||||||
|
from . import scheduler
|
||||||
from .seed import *
|
from .seed import *
|
||||||
from . import log
|
from . import log
|
|
@ -1,7 +1,9 @@
|
||||||
|
import collections
|
||||||
import rltorch.memory as M
|
import rltorch.memory as M
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
class DQNAgent:
|
class DQNAgent:
|
||||||
def __init__(self, net , memory, config, target_net = None, logger = None):
|
def __init__(self, net , memory, config, target_net = None, logger = None):
|
||||||
|
@ -15,6 +17,13 @@ class DQNAgent:
|
||||||
if len(self.memory) < self.config['batch_size']:
|
if len(self.memory) < self.config['batch_size']:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
||||||
|
weight_importance = self.config['prioritized_replay_weight_importance']
|
||||||
|
# If it's a scheduler then get the next value by calling next, otherwise just use it's value
|
||||||
|
beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
|
||||||
|
minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
|
||||||
|
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
|
||||||
|
else:
|
||||||
minibatch = self.memory.sample(self.config['batch_size'])
|
minibatch = self.memory.sample(self.config['batch_size'])
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
|
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
|
||||||
|
|
||||||
|
@ -44,6 +53,9 @@ class DQNAgent:
|
||||||
|
|
||||||
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
||||||
|
|
||||||
|
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
||||||
|
loss = (torch.as_tensor(importance_weights) * (obtained_values - expected_values)**2).mean()
|
||||||
|
else:
|
||||||
loss = F.mse_loss(obtained_values, expected_values)
|
loss = F.mse_loss(obtained_values, expected_values)
|
||||||
|
|
||||||
if self.logger is not None:
|
if self.logger is not None:
|
||||||
|
@ -59,3 +71,9 @@ class DQNAgent:
|
||||||
self.target_net.partial_sync(self.config['target_sync_tau'])
|
self.target_net.partial_sync(self.config['target_sync_tau'])
|
||||||
else:
|
else:
|
||||||
self.target_net.sync()
|
self.target_net.sync()
|
||||||
|
|
||||||
|
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
||||||
|
td_error = (obtained_values - expected_values).detach().abs()
|
||||||
|
self.memory.update_priorities(batch_indexes, td_error)
|
||||||
|
|
||||||
|
|
||||||
|
|
255
rltorch/memory/PrioritizedReplayMemory.py
Normal file
255
rltorch/memory/PrioritizedReplayMemory.py
Normal file
|
@ -0,0 +1,255 @@
|
||||||
|
# From OpenAI Baselines https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py
|
||||||
|
|
||||||
|
from .ReplayMemory import ReplayMemory
|
||||||
|
import operator
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class SegmentTree(object):
|
||||||
|
def __init__(self, capacity, operation, neutral_element):
|
||||||
|
"""Build a Segment Tree data structure.
|
||||||
|
https://en.wikipedia.org/wiki/Segment_tree
|
||||||
|
Can be used as regular array, but with two
|
||||||
|
important differences:
|
||||||
|
a) setting item's value is slightly slower.
|
||||||
|
It is O(lg capacity) instead of O(1).
|
||||||
|
b) user has access to an efficient ( O(log segment size) )
|
||||||
|
`reduce` operation which reduces `operation` over
|
||||||
|
a contiguous subsequence of items in the array.
|
||||||
|
Paramters
|
||||||
|
---------
|
||||||
|
capacity: int
|
||||||
|
Total size of the array - must be a power of two.
|
||||||
|
operation: lambda obj, obj -> obj
|
||||||
|
and operation for combining elements (eg. sum, max)
|
||||||
|
must form a mathematical group together with the set of
|
||||||
|
possible values for array elements (i.e. be associative)
|
||||||
|
neutral_element: obj
|
||||||
|
neutral element for the operation above. eg. float('-inf')
|
||||||
|
for max and 0 for sum.
|
||||||
|
"""
|
||||||
|
assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2."
|
||||||
|
self._capacity = capacity
|
||||||
|
self._value = [neutral_element for _ in range(2 * capacity)]
|
||||||
|
self._operation = operation
|
||||||
|
|
||||||
|
def _reduce_helper(self, start, end, node, node_start, node_end):
|
||||||
|
if start == node_start and end == node_end:
|
||||||
|
return self._value[node]
|
||||||
|
mid = (node_start + node_end) // 2
|
||||||
|
if end <= mid:
|
||||||
|
return self._reduce_helper(start, end, 2 * node, node_start, mid)
|
||||||
|
else:
|
||||||
|
if mid + 1 <= start:
|
||||||
|
return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
|
||||||
|
else:
|
||||||
|
return self._operation(
|
||||||
|
self._reduce_helper(start, mid, 2 * node, node_start, mid),
|
||||||
|
self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
|
||||||
|
)
|
||||||
|
|
||||||
|
def reduce(self, start=0, end=None):
|
||||||
|
"""Returns result of applying `self.operation`
|
||||||
|
to a contiguous subsequence of the array.
|
||||||
|
self.operation(arr[start], operation(arr[start+1], operation(... arr[end])))
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
start: int
|
||||||
|
beginning of the subsequence
|
||||||
|
end: int
|
||||||
|
end of the subsequences
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
reduced: obj
|
||||||
|
result of reducing self.operation over the specified range of array elements.
|
||||||
|
"""
|
||||||
|
if end is None:
|
||||||
|
end = self._capacity
|
||||||
|
if end < 0:
|
||||||
|
end += self._capacity
|
||||||
|
end -= 1
|
||||||
|
return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
|
||||||
|
|
||||||
|
def __setitem__(self, idx, val):
|
||||||
|
# index of the leaf
|
||||||
|
idx += self._capacity
|
||||||
|
self._value[idx] = val
|
||||||
|
idx //= 2
|
||||||
|
while idx >= 1:
|
||||||
|
self._value[idx] = self._operation(
|
||||||
|
self._value[2 * idx],
|
||||||
|
self._value[2 * idx + 1]
|
||||||
|
)
|
||||||
|
idx //= 2
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
assert 0 <= idx < self._capacity
|
||||||
|
return self._value[self._capacity + idx]
|
||||||
|
|
||||||
|
|
||||||
|
class SumSegmentTree(SegmentTree):
|
||||||
|
def __init__(self, capacity):
|
||||||
|
super(SumSegmentTree, self).__init__(
|
||||||
|
capacity=capacity,
|
||||||
|
operation=operator.add,
|
||||||
|
neutral_element=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
def sum(self, start=0, end=None):
|
||||||
|
"""Returns arr[start] + ... + arr[end]"""
|
||||||
|
return super(SumSegmentTree, self).reduce(start, end)
|
||||||
|
|
||||||
|
def find_prefixsum_idx(self, prefixsum):
|
||||||
|
"""Find the highest index `i` in the array such that
|
||||||
|
sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum
|
||||||
|
if array values are probabilities, this function
|
||||||
|
allows to sample indexes according to the discrete
|
||||||
|
probability efficiently.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
perfixsum: float
|
||||||
|
upperbound on the sum of array prefix
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
idx: int
|
||||||
|
highest index satisfying the prefixsum constraint
|
||||||
|
"""
|
||||||
|
assert 0 <= prefixsum <= self.sum() + 1e-5
|
||||||
|
idx = 1
|
||||||
|
while idx < self._capacity: # while non-leaf
|
||||||
|
if self._value[2 * idx] > prefixsum:
|
||||||
|
idx = 2 * idx
|
||||||
|
else:
|
||||||
|
prefixsum -= self._value[2 * idx]
|
||||||
|
idx = 2 * idx + 1
|
||||||
|
return idx - self._capacity
|
||||||
|
|
||||||
|
|
||||||
|
class MinSegmentTree(SegmentTree):
|
||||||
|
def __init__(self, capacity):
|
||||||
|
super(MinSegmentTree, self).__init__(
|
||||||
|
capacity=capacity,
|
||||||
|
operation=min,
|
||||||
|
neutral_element=float('inf')
|
||||||
|
)
|
||||||
|
|
||||||
|
def min(self, start=0, end=None):
|
||||||
|
"""Returns min(arr[start], ..., arr[end])"""
|
||||||
|
return super(MinSegmentTree, self).reduce(start, end)
|
||||||
|
|
||||||
|
class PrioritizedReplayMemory(ReplayMemory):
|
||||||
|
def __init__(self, capacity, alpha):
|
||||||
|
"""Create Prioritized Replay buffer.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
capacity: int
|
||||||
|
Max number of transitions to store in the buffer. When the buffer
|
||||||
|
overflows the old memories are dropped.
|
||||||
|
alpha: float
|
||||||
|
how much prioritization is used
|
||||||
|
(0 - no prioritization, 1 - full prioritization)
|
||||||
|
See Also
|
||||||
|
--------
|
||||||
|
ReplayBuffer.__init__
|
||||||
|
"""
|
||||||
|
super(PrioritizedReplayMemory, self).__init__(capacity)
|
||||||
|
assert alpha >= 0
|
||||||
|
self._alpha = alpha
|
||||||
|
|
||||||
|
it_capacity = 1
|
||||||
|
while it_capacity < capacity:
|
||||||
|
it_capacity *= 2
|
||||||
|
|
||||||
|
self._it_sum = SumSegmentTree(it_capacity)
|
||||||
|
self._it_min = MinSegmentTree(it_capacity)
|
||||||
|
self._max_priority = 1.0
|
||||||
|
|
||||||
|
def append(self, *args, **kwargs):
|
||||||
|
"""See ReplayBuffer.store_effect"""
|
||||||
|
idx = self.position
|
||||||
|
super().append(*args, **kwargs)
|
||||||
|
self._it_sum[idx] = self._max_priority ** self._alpha
|
||||||
|
self._it_min[idx] = self._max_priority ** self._alpha
|
||||||
|
|
||||||
|
def _sample_proportional(self, batch_size):
|
||||||
|
res = []
|
||||||
|
p_total = self._it_sum.sum(0, len(self.memory) - 1)
|
||||||
|
every_range_len = p_total / batch_size
|
||||||
|
for i in range(batch_size):
|
||||||
|
mass = random.random() * every_range_len + i * every_range_len
|
||||||
|
idx = self._it_sum.find_prefixsum_idx(mass)
|
||||||
|
res.append(idx)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def sample(self, batch_size, beta):
|
||||||
|
"""Sample a batch of experiences.
|
||||||
|
compared to ReplayBuffer.sample
|
||||||
|
it also returns importance weights and idxes
|
||||||
|
of sampled experiences.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
batch_size: int
|
||||||
|
How many transitions to sample.
|
||||||
|
beta: float
|
||||||
|
To what degree to use importance weights
|
||||||
|
(0 - no corrections, 1 - full correction)
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
obs_batch: np.array
|
||||||
|
batch of observations
|
||||||
|
act_batch: np.array
|
||||||
|
batch of actions executed given obs_batch
|
||||||
|
rew_batch: np.array
|
||||||
|
rewards received as results of executing act_batch
|
||||||
|
next_obs_batch: np.array
|
||||||
|
next set of observations seen after executing act_batch
|
||||||
|
done_mask: np.array
|
||||||
|
done_mask[i] = 1 if executing act_batch[i] resulted in
|
||||||
|
the end of an episode and 0 otherwise.
|
||||||
|
weights: np.array
|
||||||
|
Array of shape (batch_size,) and dtype np.float32
|
||||||
|
denoting importance weight of each sampled transition
|
||||||
|
idxes: np.array
|
||||||
|
Array of shape (batch_size,) and dtype np.int32
|
||||||
|
idexes in buffer of sampled experiences
|
||||||
|
"""
|
||||||
|
assert beta > 0
|
||||||
|
|
||||||
|
idxes = self._sample_proportional(batch_size)
|
||||||
|
|
||||||
|
weights = []
|
||||||
|
p_min = self._it_min.min() / self._it_sum.sum()
|
||||||
|
max_weight = (p_min * len(self.memory)) ** (-beta)
|
||||||
|
|
||||||
|
for idx in idxes:
|
||||||
|
p_sample = self._it_sum[idx] / self._it_sum.sum()
|
||||||
|
weight = (p_sample * len(self.memory)) ** (-beta)
|
||||||
|
weights.append(weight / max_weight)
|
||||||
|
weights = np.array(weights)
|
||||||
|
encoded_sample = tuple(zip(*self._encode_sample(idxes)))
|
||||||
|
batch = list(zip(*encoded_sample, weights, idxes))
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def update_priorities(self, idxes, priorities):
|
||||||
|
"""Update priorities of sampled transitions.
|
||||||
|
sets priority of transition at index idxes[i] in buffer
|
||||||
|
to priorities[i].
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
idxes: [int]
|
||||||
|
List of idxes of sampled transitions
|
||||||
|
priorities: [float]
|
||||||
|
List of updated priorities corresponding to
|
||||||
|
transitions at the sampled idxes denoted by
|
||||||
|
variable `idxes`.
|
||||||
|
"""
|
||||||
|
assert len(idxes) == len(priorities)
|
||||||
|
priorities += np.finfo('float').eps
|
||||||
|
for idx, priority in zip(idxes, priorities):
|
||||||
|
assert priority > 0
|
||||||
|
assert 0 <= idx < len(self.memory)
|
||||||
|
self._it_sum[idx] = priority ** self._alpha
|
||||||
|
self._it_min[idx] = priority ** self._alpha
|
||||||
|
|
||||||
|
self._max_priority = max(self._max_priority, priority)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from random import sample
|
import random
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import torch
|
import torch
|
||||||
Transition = namedtuple('Transition',
|
Transition = namedtuple('Transition',
|
||||||
|
@ -22,8 +22,22 @@ class ReplayMemory(object):
|
||||||
self.memory.clear()
|
self.memory.clear()
|
||||||
self.position = 0
|
self.position = 0
|
||||||
|
|
||||||
|
def _encode_sample(self, indexes):
|
||||||
|
states, actions, rewards, next_states, dones = [], [], [], [], []
|
||||||
|
for i in indexes:
|
||||||
|
observation = self.memory[i]
|
||||||
|
state, action, reward, next_state, done = observation
|
||||||
|
states.append(state)
|
||||||
|
actions.append(action)
|
||||||
|
rewards.append(reward)
|
||||||
|
next_states.append(next_state)
|
||||||
|
dones.append(done)
|
||||||
|
batch = list(zip(states, actions, rewards, next_states, dones))
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
def sample(self, batch_size):
|
def sample(self, batch_size):
|
||||||
return sample(self.memory, batch_size)
|
return random.sample(self.memory, batch_size)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.memory)
|
return len(self.memory)
|
||||||
|
@ -43,7 +57,10 @@ class ReplayMemory(object):
|
||||||
def __reversed__(self):
|
def __reversed__(self):
|
||||||
return reversed(self.memory)
|
return reversed(self.memory)
|
||||||
|
|
||||||
def zip_batch(minibatch):
|
def zip_batch(minibatch, priority = False):
|
||||||
|
if priority:
|
||||||
|
state_batch, action_batch, reward_batch, next_state_batch, done_batch, weights, indexes = zip(*minibatch)
|
||||||
|
else:
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*minibatch)
|
state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*minibatch)
|
||||||
|
|
||||||
state_batch = torch.cat(state_batch)
|
state_batch = torch.cat(state_batch)
|
||||||
|
@ -52,4 +69,7 @@ def zip_batch(minibatch):
|
||||||
not_done_batch = ~torch.tensor(done_batch)
|
not_done_batch = ~torch.tensor(done_batch)
|
||||||
next_state_batch = torch.cat(next_state_batch)[not_done_batch]
|
next_state_batch = torch.cat(next_state_batch)[not_done_batch]
|
||||||
|
|
||||||
|
if priority:
|
||||||
|
return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, weights, indexes
|
||||||
|
else:
|
||||||
return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch
|
return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch
|
|
@ -1 +1,2 @@
|
||||||
from .ReplayMemory import *
|
from .ReplayMemory import *
|
||||||
|
from .PrioritizedReplayMemory import *
|
||||||
|
|
|
@ -4,7 +4,10 @@ class Network:
|
||||||
"""
|
"""
|
||||||
def __init__(self, model, optimizer, config, device = None, logger = None, name = ""):
|
def __init__(self, model, optimizer, config, device = None, logger = None, name = ""):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
if 'weight_decay' in config:
|
||||||
self.optimizer = optimizer(model.parameters(), lr = config['learning_rate'], weight_decay = config['weight_decay'])
|
self.optimizer = optimizer(model.parameters(), lr = config['learning_rate'], weight_decay = config['weight_decay'])
|
||||||
|
else:
|
||||||
|
self.optimizer = optimizer(model.parameters(), lr = config['learning_rate'])
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.name = name
|
self.name = name
|
||||||
self.device = device
|
self.device = device
|
||||||
|
@ -14,9 +17,10 @@ class Network:
|
||||||
def __call__(self, *args):
|
def __call__(self, *args):
|
||||||
return self.model(*args)
|
return self.model(*args)
|
||||||
|
|
||||||
def clamp_gradients(self):
|
def clamp_gradients(self, x = 1):
|
||||||
|
assert x > 0
|
||||||
for param in self.model.parameters():
|
for param in self.model.parameters():
|
||||||
param.grad.data.clamp_(-1, 1)
|
param.grad.data.clamp_(-x, x)
|
||||||
|
|
||||||
def zero_grad(self):
|
def zero_grad(self):
|
||||||
self.model.zero_grad()
|
self.model.zero_grad()
|
||||||
|
|
12
rltorch/scheduler/ExponentialScheduler.py
Normal file
12
rltorch/scheduler/ExponentialScheduler.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
from .Scheduler import Scheduler
|
||||||
|
class ExponentialScheduler(Scheduler):
|
||||||
|
def __init__(self, initial_value, end_value, iterations):
|
||||||
|
super(ExponentialScheduler, self).__init__(initial_value, end_value, iterations)
|
||||||
|
self.base = (end_value / initial_value) ** (1.0 / iterations)
|
||||||
|
def __next__(self):
|
||||||
|
if self.current_iteration < self.max_iterations:
|
||||||
|
self.current_iteration += 1
|
||||||
|
return self.initial_value * (self.base ** (self.current_iteration - 1))
|
||||||
|
else:
|
||||||
|
return self.end_value
|
||||||
|
|
12
rltorch/scheduler/LinearScheduler.py
Normal file
12
rltorch/scheduler/LinearScheduler.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
from .Scheduler import Scheduler
|
||||||
|
class LinearScheduler(Scheduler):
|
||||||
|
def __init__(self, initial_value, end_value, iterations):
|
||||||
|
super(LinearScheduler, self).__init__(initial_value, end_value, iterations)
|
||||||
|
self.slope = (end_value - initial_value) / iterations
|
||||||
|
def __next__(self):
|
||||||
|
if self.current_iteration < self.max_iterations:
|
||||||
|
self.current_iteration += 1
|
||||||
|
return self.slope * (self.current_iteration - 1) + self.initial_value
|
||||||
|
else:
|
||||||
|
return self.end_value
|
||||||
|
|
10
rltorch/scheduler/Scheduler.py
Normal file
10
rltorch/scheduler/Scheduler.py
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
class Scheduler():
|
||||||
|
def __init__(self, initial_value, end_value, iterations):
|
||||||
|
self.initial_value = initial_value
|
||||||
|
self.end_value = end_value
|
||||||
|
self.max_iterations = iterations
|
||||||
|
self.current_iteration = 0
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
def __next__(self):
|
||||||
|
raise NotImplementedError("Scheduler does not have it's function to create a value implemented")
|
3
rltorch/scheduler/__init__.py
Normal file
3
rltorch/scheduler/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from .Scheduler import Scheduler
|
||||||
|
from .LinearScheduler import LinearScheduler
|
||||||
|
from .ExponentialScheduler import ExponentialScheduler
|
Loading…
Reference in a new issue