diff --git a/rltorch/memory/DQfDMemory.py b/rltorch/memory/DQfDMemory.py index d4ed582..f32fb8f 100644 --- a/rltorch/memory/DQfDMemory.py +++ b/rltorch/memory/DQfDMemory.py @@ -7,9 +7,11 @@ Transition = namedtuple('Transition', class DQfDMemory(PrioritizedReplayMemory): - def __init__(self, capacity, alpha): + def __init__(self, capacity, alpha, max_demo = -1): + assert max_demo <= capacity super().__init__(capacity, alpha) self.demo_position = 0 + self.max_demo = max_demo # -1 means no maximum number of demonstrations def append(self, *args, **kwargs): last_position = self.position # Get position before super classes change it @@ -21,7 +23,8 @@ class DQfDMemory(PrioritizedReplayMemory): def append_demonstration(self, *args): demonstrations = self.memory[:self.demo_position] obtained_transitions = self.memory[self.demo_position:] - if len(demonstrations) + 1 > self.capacity: + max_demo = self.max_demo if self.max_demo > -1 else self.capacity + if len(demonstrations) + 1 > max_demo: self.memory.pop(0) self.memory.append(Transition(*args)) else: