DQfD memory was adjusted to actually update the weights in the priority trees, fixing a bug in the sampling
This commit is contained in:
		
							parent
							
								
									23532fc372
								
							
						
					
					
						commit
						3217c76a79
					
				
					 2 changed files with 9 additions and 5 deletions
				
			
		|  | @ -27,10 +27,14 @@ class DQfDMemory(PrioritizedReplayMemory): | |||
|         if len(demonstrations) + 1 > max_demo: | ||||
|             self.memory.pop(0) | ||||
|             self.memory.append(Transition(*args)) | ||||
|             self._it_sum[len(self.memory) - 1] = self._max_priority ** self._alpha | ||||
|             self._it_min[len(self.memory) - 1] = self._max_priority ** self._alpha | ||||
|         else: | ||||
|             if len(demonstrations) + len(obtained_transitions) + 1 > self.capacity: | ||||
|                 obtained_transitions = obtained_transitions[1:] | ||||
|             self.memory = demonstrations + [Transition(*args)] + obtained_transitions | ||||
|             self._it_sum[len(demonstrations)] = self._max_priority ** self._alpha | ||||
|             self._it_min[len(demonstrations)] = self._max_priority ** self._alpha | ||||
|             self.demo_position += 1 | ||||
|             self.position += 1 | ||||
|      | ||||
|  | @ -56,13 +60,13 @@ class DQfDMemory(PrioritizedReplayMemory): | |||
|         weights = [] | ||||
|         p_min = self._it_min.min() / self._it_sum.sum() | ||||
|         max_weight = (p_min * len(self.memory)) ** (-beta) | ||||
|         for idx in idxes: | ||||
|         for idx in step_idxes: | ||||
|             p_sample = self._it_sum[idx] / self._it_sum.sum() | ||||
|             weight = (p_sample * len(self.memory)) ** (-beta) | ||||
|             weights += [(weight / max_weight) for i in range(steps)] | ||||
|             weights.append(weight / max_weight) | ||||
|         weights = np.array(weights) | ||||
|          | ||||
|         # Combine all the data together into a batch | ||||
|         encoded_sample = tuple(zip(*self._encode_sample(step_idxes))) | ||||
|         batch = list(zip(*encoded_sample, weights, step_idxes)) | ||||
|         return batch | ||||
|         return batch | ||||
|  |  | |||
|  | @ -249,10 +249,10 @@ class PrioritizedReplayMemory(ReplayMemory): | |||
|         weights = [] | ||||
|         p_min = self._it_min.min() / self._it_sum.sum() | ||||
|         max_weight = (p_min * len(self.memory)) ** (-beta) | ||||
|         for idx in idxes: | ||||
|         for idx in step_idxes: | ||||
|             p_sample = self._it_sum[idx] / self._it_sum.sum() | ||||
|             weight = (p_sample * len(self.memory)) ** (-beta) | ||||
|             weights += [(weight / max_weight) for i in range(steps)] | ||||
|             weights.append(weight / max_weight) | ||||
|         weights = np.array(weights) | ||||
|          | ||||
|         # Combine all the data together into a batch | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue