From 3217c76a79a295fb76e9b408d66f36e58cefeb0a Mon Sep 17 00:00:00 2001
From: Brandon Rozek
Date: Sun, 17 Nov 2019 19:50:49 -0500
Subject: [PATCH] DQfD memory was adjusted to actually update the weights in
 the priority trees, fixing a bug in the sampling

---
 rltorch/memory/DQfDMemory.py              | 10 +++++++---
 rltorch/memory/PrioritizedReplayMemory.py |  4 ++--
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/rltorch/memory/DQfDMemory.py b/rltorch/memory/DQfDMemory.py
index f32fb8f..af09e68 100644
--- a/rltorch/memory/DQfDMemory.py
+++ b/rltorch/memory/DQfDMemory.py
@@ -27,10 +27,14 @@ class DQfDMemory(PrioritizedReplayMemory):
             if len(demonstrations) + 1 > max_demo:
                 self.memory.pop(0)
             self.memory.append(Transition(*args))
+            self._it_sum[len(self.memory) - 1] = self._max_priority ** self._alpha
+            self._it_min[len(self.memory) - 1] = self._max_priority ** self._alpha
         else:
             if len(demonstrations) + len(obtained_transitions) + 1 > self.capacity:
                 obtained_transitions = obtained_transitions[1:]
             self.memory = demonstrations + [Transition(*args)] + obtained_transitions
+            self._it_sum[len(demonstrations)] = self._max_priority ** self._alpha
+            self._it_min[len(demonstrations)] = self._max_priority ** self._alpha
         self.demo_position += 1
         self.position += 1
@@ -56,13 +60,13 @@ class DQfDMemory(PrioritizedReplayMemory):
         weights = []
         p_min = self._it_min.min() / self._it_sum.sum()
         max_weight = (p_min * len(self.memory)) ** (-beta)
-        for idx in idxes:
+        for idx in step_idxes:
             p_sample = self._it_sum[idx] / self._it_sum.sum()
             weight = (p_sample * len(self.memory)) ** (-beta)
-            weights += [(weight / max_weight) for i in range(steps)]
+            weights.append(weight / max_weight)
         weights = np.array(weights)
         # Combine all the data together into a batch
         encoded_sample = tuple(zip(*self._encode_sample(step_idxes)))
         batch = list(zip(*encoded_sample, weights, step_idxes))
-        return batch
\ No newline at end of file
+        return batch
diff --git a/rltorch/memory/PrioritizedReplayMemory.py b/rltorch/memory/PrioritizedReplayMemory.py
index da3c767..58843e3 100644
--- a/rltorch/memory/PrioritizedReplayMemory.py
+++ b/rltorch/memory/PrioritizedReplayMemory.py
@@ -249,10 +249,10 @@ class PrioritizedReplayMemory(ReplayMemory):
         weights = []
         p_min = self._it_min.min() / self._it_sum.sum()
         max_weight = (p_min * len(self.memory)) ** (-beta)
-        for idx in idxes:
+        for idx in step_idxes:
             p_sample = self._it_sum[idx] / self._it_sum.sum()
             weight = (p_sample * len(self.memory)) ** (-beta)
-            weights += [(weight / max_weight) for i in range(steps)]
+            weights.append(weight / max_weight)
         weights = np.array(weights)
         # Combine all the data together into a batch
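
For context, a minimal sketch of the two invariants this patch restores, assuming
the proportional-prioritization scheme that rltorch's segment-tree code appears to
follow: (1) every appended transition must be written into the priority trees
(otherwise its sampling probability is effectively undefined and it may never be
drawn), and (2) each sampled index contributes exactly one importance-sampling
weight (N * P(i))^(-beta) / max_w, rather than `steps` duplicated copies. The
`TinyPrioritizedBuffer` class below is hypothetical: it replaces the `_it_sum` /
`_it_min` segment trees with a plain priority list, so it is illustrative only.

import numpy as np

class TinyPrioritizedBuffer:
    """Hypothetical stand-in for a prioritized replay buffer.

    Uses a plain list of priorities instead of sum/min segment trees,
    but the probability and importance-sampling weight math matches
    the code touched by the patch above.
    """
    def __init__(self, capacity, alpha=0.6, max_priority=1.0):
        self.capacity = capacity
        self.alpha = alpha
        self.max_priority = max_priority
        self.memory = []
        self.priorities = []  # stands in for _it_sum / _it_min

    def append(self, transition):
        if len(self.memory) >= self.capacity:
            self.memory.pop(0)
            self.priorities.pop(0)
        self.memory.append(transition)
        # The step the patch adds: register a priority for every new
        # entry, mirroring `_it_sum[i] = _it_min[i] = max_priority ** alpha`.
        self.priorities.append(self.max_priority ** self.alpha)

    def sample(self, batch_size, beta=0.4):
        probs = np.array(self.priorities)
        probs = probs / probs.sum()  # P(i) = p_i^alpha / sum_k p_k^alpha
        idxes = np.random.choice(len(self.memory), batch_size, p=probs)
        # Exactly one importance-sampling weight per sampled index;
        # the bug being fixed appended each weight `steps` times.
        max_weight = (probs.min() * len(self.memory)) ** (-beta)
        weights = (probs[idxes] * len(self.memory)) ** (-beta) / max_weight
        return [self.memory[i] for i in idxes], weights, idxes

buf = TinyPrioritizedBuffer(capacity=8)
for t in range(8):
    buf.append(("obs%d" % t,))
batch, weights, idxes = buf.sample(4)
assert len(weights) == len(idxes) == 4  # one weight per index, as in the fix

Note that before the fix, the weight loop iterated over `idxes` (the start of each
multi-step window) while the batch was built from `step_idxes`, so the weights and
transitions could disagree in both length and alignment; iterating over
`step_idxes` and appending once per index keeps them one-to-one.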