Added a way to cap the number of demonstrations that are kept in the buffer
This commit is contained in:
parent
038d406d0f
commit
23532fc372
1 changed files with 5 additions and 2 deletions
|
@ -7,9 +7,11 @@ Transition = namedtuple('Transition',
|
||||||
|
|
||||||
|
|
||||||
class DQfDMemory(PrioritizedReplayMemory):
|
class DQfDMemory(PrioritizedReplayMemory):
|
||||||
def __init__(self, capacity, alpha):
|
def __init__(self, capacity, alpha, max_demo = -1):
|
||||||
|
assert max_demo <= capacity
|
||||||
super().__init__(capacity, alpha)
|
super().__init__(capacity, alpha)
|
||||||
self.demo_position = 0
|
self.demo_position = 0
|
||||||
|
self.max_demo = max_demo # -1 means no maximum number of demonstrations
|
||||||
|
|
||||||
def append(self, *args, **kwargs):
|
def append(self, *args, **kwargs):
|
||||||
last_position = self.position # Get position before super classes change it
|
last_position = self.position # Get position before super classes change it
|
||||||
|
@ -21,7 +23,8 @@ class DQfDMemory(PrioritizedReplayMemory):
|
||||||
def append_demonstration(self, *args):
|
def append_demonstration(self, *args):
|
||||||
demonstrations = self.memory[:self.demo_position]
|
demonstrations = self.memory[:self.demo_position]
|
||||||
obtained_transitions = self.memory[self.demo_position:]
|
obtained_transitions = self.memory[self.demo_position:]
|
||||||
if len(demonstrations) + 1 > self.capacity:
|
max_demo = self.max_demo if self.max_demo > -1 else self.capacity
|
||||||
|
if len(demonstrations) + 1 > max_demo:
|
||||||
self.memory.pop(0)
|
self.memory.pop(0)
|
||||||
self.memory.append(Transition(*args))
|
self.memory.append(Transition(*args))
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in a new issue