diff --git a/rltorch/__init__.py b/rltorch/__init__.py
index d2e8f73..e58faa9 100644
--- a/rltorch/__init__.py
+++ b/rltorch/__init__.py
@@ -4,5 +4,6 @@ from . import env
 from . import memory
 from . import network
 from . import mp
+from . import scheduler
 from .seed import *
 from . import log
\ No newline at end of file
diff --git a/rltorch/agents/DQNAgent.py b/rltorch/agents/DQNAgent.py
index 08f96b0..6d7ab53 100644
--- a/rltorch/agents/DQNAgent.py
+++ b/rltorch/agents/DQNAgent.py
@@ -1,7 +1,9 @@
+import collections.abc
 import rltorch.memory as M
 import torch
 import torch.nn.functional as F
 from copy import deepcopy
+import numpy as np
 
 class DQNAgent:
     def __init__(self, net , memory, config, target_net = None, logger = None):
@@ -14,9 +16,16 @@ class DQNAgent:
     def learn(self):
         if len(self.memory) < self.config['batch_size']:
             return
-
-        minibatch = self.memory.sample(self.config['batch_size'])
-        state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
+
+        if isinstance(self.memory, M.PrioritizedReplayMemory):
+            weight_importance = self.config['prioritized_replay_weight_importance']
+            # If it's a scheduler, get the next value by calling next(); otherwise just use its value
+            beta = next(weight_importance) if isinstance(weight_importance, collections.abc.Iterable) else weight_importance
+            minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
+            state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
+        else:
+            minibatch = self.memory.sample(self.config['batch_size'])
+            state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
 
         # Send to their appropriate devices
         state_batch = state_batch.to(self.net.device)
@@ -44,7 +53,10 @@ class DQNAgent:
         expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
 
-        loss = F.mse_loss(obtained_values, expected_values)
+        if isinstance(self.memory, M.PrioritizedReplayMemory):
+            loss = (torch.as_tensor(importance_weights, device = self.net.device, dtype = torch.float32).unsqueeze(1) * (obtained_values - expected_values)**2).mean()
+        else:
+            loss = F.mse_loss(obtained_values, expected_values)
 
         if self.logger is not None:
             self.logger.append("Loss", loss.item())
 
@@ -59,3 +71,9 @@ class DQNAgent:
             self.target_net.partial_sync(self.config['target_sync_tau'])
         else:
             self.target_net.sync()
+
+        if isinstance(self.memory, M.PrioritizedReplayMemory):
+            td_error = (obtained_values - expected_values).detach().abs().squeeze(1).cpu().numpy()
+            self.memory.update_priorities(batch_indexes, td_error)
+
+
diff --git a/rltorch/memory/PrioritizedReplayMemory.py b/rltorch/memory/PrioritizedReplayMemory.py
new file mode 100644
index 0000000..a3f860b
--- /dev/null
+++ b/rltorch/memory/PrioritizedReplayMemory.py
@@ -0,0 +1,255 @@
+# From OpenAI Baselines https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py
+
+from .ReplayMemory import ReplayMemory
+import operator
+import random
+import numpy as np
+
+class SegmentTree(object):
+    def __init__(self, capacity, operation, neutral_element):
+        """Build a Segment Tree data structure.
+        https://en.wikipedia.org/wiki/Segment_tree
+        Can be used as regular array, but with two
+        important differences:
+            a) setting item's value is slightly slower.
+               It is O(lg capacity) instead of O(1).
+            b) user has access to an efficient ( O(log segment size) )
+               `reduce` operation which reduces `operation` over
+               a contiguous subsequence of items in the array.
+        Parameters
+        ----------
+        capacity: int
+            Total size of the array - must be a power of two.
+        operation: lambda obj, obj -> obj
+            an operation for combining elements (e.g. sum, max)
+            must form a mathematical group together with the set of
+            possible values for array elements (i.e. be associative)
+        neutral_element: obj
+            neutral element for the operation above. e.g. float('-inf')
+            for max and 0 for sum.
+        """
+        assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2."
+        self._capacity = capacity
+        self._value = [neutral_element for _ in range(2 * capacity)]
+        self._operation = operation
+
+    def _reduce_helper(self, start, end, node, node_start, node_end):
+        if start == node_start and end == node_end:
+            return self._value[node]
+        mid = (node_start + node_end) // 2
+        if end <= mid:
+            return self._reduce_helper(start, end, 2 * node, node_start, mid)
+        else:
+            if mid + 1 <= start:
+                return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
+            else:
+                return self._operation(
+                    self._reduce_helper(start, mid, 2 * node, node_start, mid),
+                    self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
+                )
+
+    def reduce(self, start=0, end=None):
+        """Returns result of applying `self.operation`
+        to a contiguous subsequence of the array.
+            self.operation(arr[start], operation(arr[start+1], operation(... arr[end])))
+        Parameters
+        ----------
+        start: int
+            beginning of the subsequence
+        end: int
+            end of the subsequence
+        Returns
+        -------
+        reduced: obj
+            result of reducing self.operation over the specified range of array elements.
+        """
+        if end is None:
+            end = self._capacity
+        if end < 0:
+            end += self._capacity
+        end -= 1
+        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
+
+    def __setitem__(self, idx, val):
+        # index of the leaf
+        idx += self._capacity
+        self._value[idx] = val
+        idx //= 2
+        while idx >= 1:
+            self._value[idx] = self._operation(
+                self._value[2 * idx],
+                self._value[2 * idx + 1]
+            )
+            idx //= 2
+
+    def __getitem__(self, idx):
+        assert 0 <= idx < self._capacity
+        return self._value[self._capacity + idx]
+
+
+class SumSegmentTree(SegmentTree):
+    def __init__(self, capacity):
+        super(SumSegmentTree, self).__init__(
+            capacity=capacity,
+            operation=operator.add,
+            neutral_element=0.0
+        )
+
+    def sum(self, start=0, end=None):
+        """Returns arr[start] + ... + arr[end]"""
+        return super(SumSegmentTree, self).reduce(start, end)
+
+    def find_prefixsum_idx(self, prefixsum):
+        """Find the highest index `i` in the array such that
+            sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum
+        if array values are probabilities, this function
+        allows sampling of indexes according to the discrete
+        probability distribution efficiently.
+        Parameters
+        ----------
+        prefixsum: float
+            upper bound on the sum of array prefix
+        Returns
+        -------
+        idx: int
+            highest index satisfying the prefixsum constraint
+        """
+        assert 0 <= prefixsum <= self.sum() + 1e-5
+        idx = 1
+        while idx < self._capacity:  # while non-leaf
+            if self._value[2 * idx] > prefixsum:
+                idx = 2 * idx
+            else:
+                prefixsum -= self._value[2 * idx]
+                idx = 2 * idx + 1
+        return idx - self._capacity
+
+
+class MinSegmentTree(SegmentTree):
+    def __init__(self, capacity):
+        super(MinSegmentTree, self).__init__(
+            capacity=capacity,
+            operation=min,
+            neutral_element=float('inf')
+        )
+
+    def min(self, start=0, end=None):
+        """Returns min(arr[start], ..., arr[end])"""
+        return super(MinSegmentTree, self).reduce(start, end)
+
+class PrioritizedReplayMemory(ReplayMemory):
+    def __init__(self, capacity, alpha):
+        """Create Prioritized Replay buffer.
+        Parameters
+        ----------
+        capacity: int
+            Max number of transitions to store in the buffer. When the buffer
+            overflows the old memories are dropped.
+        alpha: float
+            how much prioritization is used
+            (0 - no prioritization, 1 - full prioritization)
+        See Also
+        --------
+        ReplayMemory.__init__
+        """
+        super(PrioritizedReplayMemory, self).__init__(capacity)
+        assert alpha >= 0
+        self._alpha = alpha
+
+        it_capacity = 1
+        while it_capacity < capacity:
+            it_capacity *= 2
+
+        self._it_sum = SumSegmentTree(it_capacity)
+        self._it_min = MinSegmentTree(it_capacity)
+        self._max_priority = 1.0
+
+    def append(self, *args, **kwargs):
+        """See ReplayMemory.append"""
+        idx = self.position
+        super().append(*args, **kwargs)
+        self._it_sum[idx] = self._max_priority ** self._alpha
+        self._it_min[idx] = self._max_priority ** self._alpha
+
+    def _sample_proportional(self, batch_size):
+        res = []
+        p_total = self._it_sum.sum(0, len(self.memory) - 1)
+        every_range_len = p_total / batch_size
+        for i in range(batch_size):
+            mass = random.random() * every_range_len + i * every_range_len
+            idx = self._it_sum.find_prefixsum_idx(mass)
+            res.append(idx)
+        return res
+
+    def sample(self, batch_size, beta):
+        """Sample a batch of experiences.
+        Compared to ReplayMemory.sample, it also returns importance
+        weights and idxes of sampled experiences.
+        Parameters
+        ----------
+        batch_size: int
+            How many transitions to sample.
+        beta: float
+            To what degree to use importance weights
+            (0 - no corrections, 1 - full correction)
+        Returns
+        -------
+        obs_batch: np.array
+            batch of observations
+        act_batch: np.array
+            batch of actions executed given obs_batch
+        rew_batch: np.array
+            rewards received as results of executing act_batch
+        next_obs_batch: np.array
+            next set of observations seen after executing act_batch
+        done_mask: np.array
+            done_mask[i] = 1 if executing act_batch[i] resulted in
+            the end of an episode and 0 otherwise.
+        weights: np.array
+            Array of shape (batch_size,) and dtype np.float32
+            denoting importance weight of each sampled transition
+        idxes: np.array
+            Array of shape (batch_size,) and dtype np.int32
+            indexes in buffer of sampled experiences
+        """
+        assert beta > 0
+
+        idxes = self._sample_proportional(batch_size)
+
+        weights = []
+        p_min = self._it_min.min() / self._it_sum.sum()
+        max_weight = (p_min * len(self.memory)) ** (-beta)
+
+        for idx in idxes:
+            p_sample = self._it_sum[idx] / self._it_sum.sum()
+            weight = (p_sample * len(self.memory)) ** (-beta)
+            weights.append(weight / max_weight)
+        weights = np.array(weights)
+        encoded_sample = tuple(zip(*self._encode_sample(idxes)))
+        batch = list(zip(*encoded_sample, weights, idxes))
+        return batch
+
+    def update_priorities(self, idxes, priorities):
+        """Update priorities of sampled transitions.
+        sets priority of transition at index idxes[i] in buffer
+        to priorities[i].
+        Parameters
+        ----------
+        idxes: [int]
+            List of idxes of sampled transitions
+        priorities: [float]
+            List of updated priorities corresponding to
+            transitions at the sampled idxes denoted by
+            variable `idxes`.
+        """
+        assert len(idxes) == len(priorities)
+        priorities += np.finfo('float').eps
+        for idx, priority in zip(idxes, priorities):
+            assert priority > 0
+            assert 0 <= idx < len(self.memory)
+            self._it_sum[idx] = priority ** self._alpha
+            self._it_min[idx] = priority ** self._alpha
+
+            self._max_priority = max(self._max_priority, priority)
+
diff --git a/rltorch/memory/ReplayMemory.py b/rltorch/memory/ReplayMemory.py
index f9d6b2f..367b9c9 100644
--- a/rltorch/memory/ReplayMemory.py
+++ b/rltorch/memory/ReplayMemory.py
@@ -1,4 +1,4 @@
-from random import sample
+import random
 from collections import namedtuple
 import torch
 Transition = namedtuple('Transition',
@@ -22,8 +22,22 @@ class ReplayMemory(object):
         self.memory.clear()
         self.position = 0
 
+    def _encode_sample(self, indexes):
+        states, actions, rewards, next_states, dones = [], [], [], [], []
+        for i in indexes:
+            observation = self.memory[i]
+            state, action, reward, next_state, done = observation
+            states.append(state)
+            actions.append(action)
+            rewards.append(reward)
+            next_states.append(next_state)
+            dones.append(done)
+        batch = list(zip(states, actions, rewards, next_states, dones))
+        return batch
+
+
     def sample(self, batch_size):
-        return sample(self.memory, batch_size)
+        return random.sample(self.memory, batch_size)
 
     def __len__(self):
         return len(self.memory)
@@ -43,8 +57,11 @@ class ReplayMemory(object):
     def __reversed__(self):
         return reversed(self.memory)
 
-def zip_batch(minibatch):
-    state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*minibatch)
+def zip_batch(minibatch, priority = False):
+    if priority:
+        state_batch, action_batch, reward_batch, next_state_batch, done_batch, weights, indexes = zip(*minibatch)
+    else:
+        state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*minibatch)
 
     state_batch = torch.cat(state_batch)
     action_batch = torch.tensor(action_batch)
@@ -52,4 +69,7 @@ def zip_batch(minibatch):
     not_done_batch = ~torch.tensor(done_batch)
     next_state_batch = torch.cat(next_state_batch)[not_done_batch]
 
-    return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch
\ No newline at end of file
+    if priority:
+        return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, weights, indexes
+    else:
+        return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch
\ No newline at end of file
diff --git a/rltorch/memory/__init__.py b/rltorch/memory/__init__.py
index d4f414f..ca2a917 100644
--- a/rltorch/memory/__init__.py
+++ b/rltorch/memory/__init__.py
@@ -1 +1,2 @@
 from .ReplayMemory import *
+from .PrioritizedReplayMemory import *
diff --git a/rltorch/network/Network.py b/rltorch/network/Network.py
index 06603ac..ea523e1 100644
--- a/rltorch/network/Network.py
+++ b/rltorch/network/Network.py
@@ -4,7 +4,10 @@ class Network:
     """
    def __init__(self, model, optimizer, config, device = None, logger = None, name = ""):
         self.model = model
-        self.optimizer = optimizer(model.parameters(), lr = config['learning_rate'], weight_decay = config['weight_decay'])
+        if 'weight_decay' in config:
+            self.optimizer = optimizer(model.parameters(), lr = config['learning_rate'], weight_decay = config['weight_decay'])
+        else:
+            self.optimizer = optimizer(model.parameters(), lr = config['learning_rate'])
         self.logger = logger
         self.name = name
         self.device = device
@@ -14,9 +17,10 @@ class Network:
     def __call__(self, *args):
         return self.model(*args)
 
-    def clamp_gradients(self):
+    def clamp_gradients(self, x = 1):
+        assert x > 0
         for param in self.model.parameters():
-            param.grad.data.clamp_(-1, 1)
+            param.grad.data.clamp_(-x, x)
 
     def zero_grad(self):
         self.model.zero_grad()
diff --git a/rltorch/scheduler/ExponentialScheduler.py b/rltorch/scheduler/ExponentialScheduler.py
new file mode 100644
index 0000000..ca8d162
--- /dev/null
+++ b/rltorch/scheduler/ExponentialScheduler.py
@@ -0,0 +1,12 @@
+from .Scheduler import Scheduler
+class ExponentialScheduler(Scheduler):
+    def __init__(self, initial_value, end_value, iterations):
+        super(ExponentialScheduler, self).__init__(initial_value, end_value, iterations)
+        self.base = (end_value / initial_value) ** (1.0 / iterations)
+    def __next__(self):
+        if self.current_iteration < self.max_iterations:
+            self.current_iteration += 1
+            return self.initial_value * (self.base ** (self.current_iteration - 1))
+        else:
+            return self.end_value
+
diff --git a/rltorch/scheduler/LinearScheduler.py b/rltorch/scheduler/LinearScheduler.py
new file mode 100644
index 0000000..984f3eb
--- /dev/null
+++ b/rltorch/scheduler/LinearScheduler.py
@@ -0,0 +1,12 @@
+from .Scheduler import Scheduler
+class LinearScheduler(Scheduler):
+    def __init__(self, initial_value, end_value, iterations):
+        super(LinearScheduler, self).__init__(initial_value, end_value, iterations)
+        self.slope = (end_value - initial_value) / iterations
+    def __next__(self):
+        if self.current_iteration < self.max_iterations:
+            self.current_iteration += 1
+            return self.slope * (self.current_iteration - 1) + self.initial_value
+        else:
+            return self.end_value
+
\ No newline at end of file
diff --git a/rltorch/scheduler/Scheduler.py b/rltorch/scheduler/Scheduler.py
new file mode 100644
index 0000000..0314907
--- /dev/null
+++ b/rltorch/scheduler/Scheduler.py
@@ -0,0 +1,10 @@
+class Scheduler():
+    def __init__(self, initial_value, end_value, iterations):
+        self.initial_value = initial_value
+        self.end_value = end_value
+        self.max_iterations = iterations
+        self.current_iteration = 0
+    def __iter__(self):
+        return self
+    def __next__(self):
+        raise NotImplementedError("Scheduler subclasses must implement __next__ to produce a value")
\ No newline at end of file
diff --git a/rltorch/scheduler/__init__.py b/rltorch/scheduler/__init__.py
new file mode 100644
index 0000000..1d7d9b2
--- /dev/null
+++ b/rltorch/scheduler/__init__.py
@@ -0,0 +1,3 @@
+from .Scheduler import Scheduler
+from .LinearScheduler import LinearScheduler
+from .ExponentialScheduler import ExponentialScheduler
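
Usage note (illustration only, not part of the diff): a minimal sketch of how the new pieces fit together. It assumes an rltorch.network.Network instance `net` (and optionally a target network) built elsewhere; the buffer size, alpha, beta schedule, and other constants below are placeholder values rather than anything taken from this change.

    from rltorch.memory import PrioritizedReplayMemory
    from rltorch.scheduler import LinearScheduler

    # alpha controls how strongly priorities skew sampling (0 = uniform sampling)
    memory = PrioritizedReplayMemory(capacity = 10000, alpha = 0.6)

    config = {
        'batch_size': 32,
        'discount_rate': 0.99,
        'target_sync_tau': 0.01,
        # DQNAgent.learn() calls next() on this entry when it is iterable, so a
        # scheduler anneals beta from 0.4 toward 1.0 over 5000 learn() calls;
        # a plain float is also accepted.
        'prioritized_replay_weight_importance': LinearScheduler(0.4, 1.0, 5000),
    }

    # With an existing rltorch.network.Network `net` (and optional target_net):
    #     from rltorch.agents.DQNAgent import DQNAgent
    #     agent = DQNAgent(net, memory, config, target_net = target_net)
    # then append transitions to `memory` and call agent.learn() each training step.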