DQfD memory was adjusted to actually update the priorities in the priority trees when transitions are added, fixing a bug in the sampling weights
parent 23532fc372
commit 3217c76a79
2 changed files with 9 additions and 5 deletions
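Background on the fix: in proportional prioritized replay, sampling is driven by the cumulative priorities stored in the sum tree, so a transition that is appended to self.memory without its priority ever being written into self._it_sum keeps the tree's neutral value and is effectively never drawn. A toy, list-based sketch of that effect (everything below is illustrative, not code from this repository):

import random

# Toy illustration of proportional prioritized sampling: index i is drawn
# with probability p_i / sum(p). Values are made up for the example.
priorities = [2.0, 1.0, 0.0, 4.0]   # slot 2 was stored but never given a priority

def sample_index(ps):
    mass = random.uniform(0.0, sum(ps))
    cumulative = 0.0
    for i, p in enumerate(ps):
        cumulative += p
        if mass <= cumulative:
            return i
    return len(ps) - 1

counts = [0] * len(priorities)
for _ in range(10000):
    counts[sample_index(priorities)] += 1
# counts[2] stays at 0: a memory slot whose priority was never written is
# never sampled, which is exactly what this commit fixes for new transitions.
print(counts)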
@@ -27,10 +27,14 @@ class DQfDMemory(PrioritizedReplayMemory):
             if len(demonstrations) + 1 > max_demo:
                 self.memory.pop(0)
             self.memory.append(Transition(*args))
+            self._it_sum[len(self.memory) - 1] = self._max_priority ** self._alpha
+            self._it_min[len(self.memory) - 1] = self._max_priority ** self._alpha
         else:
             if len(demonstrations) + len(obtained_transitions) + 1 > self.capacity:
                 obtained_transitions = obtained_transitions[1:]
             self.memory = demonstrations + [Transition(*args)] + obtained_transitions
+            self._it_sum[len(demonstrations)] = self._max_priority ** self._alpha
+            self._it_min[len(demonstrations)] = self._max_priority ** self._alpha
             self.demo_position += 1
         self.position += 1
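The four added lines follow the usual prioritized-replay convention: a freshly stored transition gets the current maximum priority (raised to alpha) so that it is sampled at least once before a real TD error is available for it. In the first branch the new transition lands at the end of memory (index len(self.memory) - 1); in the second it is inserted right after the demonstrations (index len(demonstrations)), and in both cases it is that slot's entries in _it_sum and _it_min that must be written. A minimal, self-contained sketch of the invariant being restored, using plain lists as a stand-in for the segment trees (ToyPrioritizedMemory and its fields are illustrative, not the project's classes):

from collections import namedtuple

# Toy stand-in (plain lists instead of segment trees, illustrative names)
# for the invariant the added lines restore: every slot in memory has a
# matching priority entry, and new entries get the current max priority.
Transition = namedtuple("Transition", ("state", "action", "reward", "next_state"))

class ToyPrioritizedMemory:
    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.memory = []
        self.priorities = []      # stands in for _it_sum / _it_min
        self.max_priority = 1.0

    def add(self, *args):
        if len(self.memory) >= self.capacity:
            self.memory.pop(0)
            self.priorities.pop(0)
        self.memory.append(Transition(*args))
        # New transitions are stored with max priority so they are sampled
        # at least once before a real TD error is available for them.
        self.priorities.append(self.max_priority ** self.alpha)

mem = ToyPrioritizedMemory(capacity=3)
mem.add(0, 1, 0.5, 1)
assert len(mem.memory) == len(mem.priorities)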
@@ -56,10 +60,10 @@ class DQfDMemory(PrioritizedReplayMemory):
         weights = []
         p_min = self._it_min.min() / self._it_sum.sum()
         max_weight = (p_min * len(self.memory)) ** (-beta)
-        for idx in idxes:
+        for idx in step_idxes:
             p_sample = self._it_sum[idx] / self._it_sum.sum()
             weight = (p_sample * len(self.memory)) ** (-beta)
-            weights += [(weight / max_weight) for i in range(steps)]
+            weights.append(weight / max_weight)
         weights = np.array(weights)

         # Combine all the data together into a batch
@@ -249,10 +249,10 @@ class PrioritizedReplayMemory(ReplayMemory):
         weights = []
         p_min = self._it_min.min() / self._it_sum.sum()
         max_weight = (p_min * len(self.memory)) ** (-beta)
-        for idx in idxes:
+        for idx in step_idxes:
             p_sample = self._it_sum[idx] / self._it_sum.sum()
             weight = (p_sample * len(self.memory)) ** (-beta)
-            weights += [(weight / max_weight) for i in range(steps)]
+            weights.append(weight / max_weight)
         weights = np.array(weights)

         # Combine all the data together into a batch
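Both sampling hunks apply the same fix: the loop now runs over step_idxes and appends exactly one importance-sampling weight per sampled index, rather than looping over idxes and duplicating each weight steps times. The weight itself is w_i = (N * P(i)) ** (-beta), normalised by the maximum possible weight so values stay in (0, 1]. A small worked example with made-up numbers (priorities, beta and sampled_idxes below are illustrative only):

import numpy as np

# Made-up numbers illustrating the importance-sampling weight formula the
# loop computes: w_i = (N * P(i)) ** (-beta), normalised by the largest
# possible weight so the values stay in (0, 1].
priorities = np.array([4.0, 2.0, 1.0, 1.0])   # assumed already raised to alpha
beta = 0.4
N = len(priorities)

probs = priorities / priorities.sum()          # P(i)
max_weight = (probs.min() * N) ** (-beta)      # weight of the rarest transition

sampled_idxes = [0, 2]                         # e.g. one entry per sampled step
weights = np.array([(probs[i] * N) ** (-beta) for i in sampled_idxes]) / max_weight
print(weights)   # one weight per sampled index, matching the .append fix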