DQfD memory was adjusted to actually update the weights in the priority trees, fixing a bug in the sampling

Brandon Rozek 2019-11-17 19:50:49 -05:00
parent 23532fc372
commit 3217c76a79
2 changed files with 9 additions and 5 deletions


@@ -27,10 +27,14 @@ class DQfDMemory(PrioritizedReplayMemory):
             if len(demonstrations) + 1 > max_demo:
                 self.memory.pop(0)
             self.memory.append(Transition(*args))
+            self._it_sum[len(self.memory) - 1] = self._max_priority ** self._alpha
+            self._it_min[len(self.memory) - 1] = self._max_priority ** self._alpha
         else:
             if len(demonstrations) + len(obtained_transitions) + 1 > self.capacity:
                 obtained_transitions = obtained_transitions[1:]
             self.memory = demonstrations + [Transition(*args)] + obtained_transitions
+            self._it_sum[len(demonstrations)] = self._max_priority ** self._alpha
+            self._it_min[len(demonstrations)] = self._max_priority ** self._alpha
         self.demo_position += 1
         self.position += 1
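
Why the added lines matter: `_it_sum` and `_it_min` are the sum and min segment trees that proportional prioritized sampling walks to pick indices and to compute the minimum sampling probability. Before this change, transitions spliced into `self.memory` here never had a priority written into the trees, so those slots were effectively invisible to sampling. A minimal, self-contained sketch of the invariant being restored (the `ToySumTree` stand-in and its default values are illustrative, not the library's actual classes):

# Minimal sketch: every slot written into self.memory must also get an
# entry in the priority trees, or proportional sampling can never pick it.
# ToySumTree is an O(n) stand-in for the real sum segment tree; the
# _max_priority / _alpha names mirror the identifiers in this diff.

class ToySumTree:
    def __init__(self, capacity):
        self.values = [0.0] * capacity

    def __setitem__(self, idx, value):
        self.values[idx] = value

    def sum(self):
        return sum(self.values)

class ToyPrioritizedMemory:
    def __init__(self, capacity, alpha=0.6):
        self.memory = []
        self._alpha = alpha
        self._max_priority = 1.0            # new transitions enter at max priority
        self._it_sum = ToySumTree(capacity)

    def append(self, transition):
        self.memory.append(transition)
        # Without this write, the slot keeps priority 0 and is never sampled.
        self._it_sum[len(self.memory) - 1] = self._max_priority ** self._alpha

memory = ToyPrioritizedMemory(capacity=4)
memory.append(("state", "action", "reward", "next_state"))
assert memory._it_sum.sum() > 0   # the new transition is now sampleable
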
@@ -56,13 +60,13 @@ class DQfDMemory(PrioritizedReplayMemory):
         weights = []
         p_min = self._it_min.min() / self._it_sum.sum()
         max_weight = (p_min * len(self.memory)) ** (-beta)
-        for idx in idxes:
+        for idx in step_idxes:
             p_sample = self._it_sum[idx] / self._it_sum.sum()
             weight = (p_sample * len(self.memory)) ** (-beta)
-            weights += [(weight / max_weight) for i in range(steps)]
+            weights.append(weight / max_weight)
         weights = np.array(weights)
         # Combine all the data together into a batch
         encoded_sample = tuple(zip(*self._encode_sample(step_idxes)))
         batch = list(zip(*encoded_sample, weights, step_idxes))
-        return batch
+        return batch
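
The sampling hunk fixes an alignment bug in the n-step importance-sampling weights: the old loop iterated over `idxes` (one entry per sampled sequence) and duplicated each sequence's weight `steps` times, while the batch below is zipped against `step_idxes` (one entry per transition). Looping over `step_idxes` gives each transition a weight computed from its own priority and keeps the lengths consistent. A self-contained numeric check of the weight formula used above, with made-up priorities:

# w_i = (N * P(i)) ** (-beta), normalized by the largest possible weight,
# which belongs to the lowest-probability transition. The priorities here
# are invented purely for illustration.
import numpy as np

priorities = np.array([0.5, 1.0, 2.0, 4.0])  # hypothetical p_i ** alpha values
beta = 0.4
N = len(priorities)

probs = priorities / priorities.sum()        # P(i)
p_min = probs.min()
max_weight = (p_min * N) ** (-beta)

# One weight per sampled index, matching the fixed loop over step_idxes.
weights = (probs * N) ** (-beta) / max_weight
assert weights.max() <= 1.0                  # normalization bounds weights at 1
print(weights)
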


@@ -249,10 +249,10 @@ class PrioritizedReplayMemory(ReplayMemory):
         weights = []
         p_min = self._it_min.min() / self._it_sum.sum()
         max_weight = (p_min * len(self.memory)) ** (-beta)
-        for idx in idxes:
+        for idx in step_idxes:
             p_sample = self._it_sum[idx] / self._it_sum.sum()
             weight = (p_sample * len(self.memory)) ** (-beta)
-            weights += [(weight / max_weight) for i in range(steps)]
+            weights.append(weight / max_weight)
         weights = np.array(weights)
         # Combine all the data together into a batch
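
The same loop fix is applied to the base `PrioritizedReplayMemory` that `DQfDMemory` inherits from. For completeness, a hedged sketch of how a caller typically drives such a memory, annealing `beta` toward 1 over training (the `sample` and `update_priorities` calls are assumptions based on the parameters visible in this diff and the usual prioritized-replay interface, not confirmed against this repository):

# Hedged usage sketch: annealing beta toward 1.0 is the standard schedule
# for prioritized replay. The commented calls illustrate the assumed
# interface; only the schedule itself runs as-is.
beta_start, beta_end, anneal_steps = 0.4, 1.0, 1_000

for step in range(anneal_steps):
    beta = beta_start + (beta_end - beta_start) * (step / anneal_steps)
    # batch = memory.sample(batch_size=32, beta=beta)
    # ... compute TD errors on the batch, then:
    # memory.update_priorities(idxes, abs(td_errors) + 1e-6)
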