Added a way to cap the number of demonstrations that are kept in the buffer

This commit is contained in:
Brandon Rozek 2019-11-17 18:29:12 -05:00
parent 038d406d0f
commit 23532fc372

View file

@@ -7,9 +7,11 @@ Transition = namedtuple('Transition',
class DQfDMemory(PrioritizedReplayMemory): class DQfDMemory(PrioritizedReplayMemory):
def __init__(self, capacity, alpha): def __init__(self, capacity, alpha, max_demo = -1):
assert max_demo <= capacity
super().__init__(capacity, alpha) super().__init__(capacity, alpha)
self.demo_position = 0 self.demo_position = 0
self.max_demo = max_demo # -1 means no maximum number of demonstrations
def append(self, *args, **kwargs): def append(self, *args, **kwargs):
last_position = self.position # Get position before super classes change it last_position = self.position # Get position before super classes change it
@@ -21,7 +23,8 @@ class DQfDMemory(PrioritizedReplayMemory):
def append_demonstration(self, *args): def append_demonstration(self, *args):
demonstrations = self.memory[:self.demo_position] demonstrations = self.memory[:self.demo_position]
obtained_transitions = self.memory[self.demo_position:] obtained_transitions = self.memory[self.demo_position:]
if len(demonstrations) + 1 > self.capacity: max_demo = self.max_demo if self.max_demo > -1 else self.capacity
if len(demonstrations) + 1 > max_demo:
self.memory.pop(0) self.memory.pop(0)
self.memory.append(Transition(*args)) self.memory.append(Transition(*args))
else: else: