Initial implementation of n-step loss
parent 07c90a09f9
commit ed62e148d5

4 changed files with 111 additions and 33 deletions
@@ -28,10 +28,19 @@ class DQfDAgent:
         weight_importance = self.config['prioritized_replay_weight_importance']
         # If it's a scheduler then get the next value by calling next, otherwise just use its value
         beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
-        minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
-        state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
-
-        demo_indexes = batch_indexes < self.memory.demo_position
+        # Check to see if we are doing N-Step DQN
+        steps = self.config['n_step'] if 'n_step' in self.config else None
+        if steps is not None:
+            minibatch = self.memory.sample_n_steps(self.config['batch_size'], steps, beta)
+        else:
+            minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
+
+        # Process batch
+        state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
+
+        batch_index_tensors = torch.tensor(batch_indexes)
+        demo_mask = batch_index_tensors < self.memory.demo_position
+
         # Send to their appropriate devices
         state_batch = state_batch.to(self.net.device)
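A side note on the `collections.Iterable` check in this hunk: that alias was deprecated in Python 3.3 and removed in Python 3.10, so current interpreters need `collections.abc.Iterable`. A minimal sketch of the same scheduler-or-scalar pattern, with a hypothetical `linear_schedule` generator standing in for the config entry:

from collections.abc import Iterable  # 'collections.Iterable' was removed in Python 3.10

def linear_schedule(start, end, num_steps):
    # Hypothetical scheduler: anneal from start to end, then hold the end value.
    for i in range(num_steps):
        yield start + (end - start) * i / (num_steps - 1)
    while True:
        yield end

weight_importance = linear_schedule(0.4, 1.0, 10000)  # or a plain float such as 0.4
# Same pattern as the agent: advance the schedule if it is iterable,
# otherwise use the constant directly.
beta = next(weight_importance) if isinstance(weight_importance, Iterable) else weight_importance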
@@ -43,6 +52,7 @@ class DQfDAgent:
         state_values = self.net(state_batch)
         obtained_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1))
+
         # DQN Loss
         with torch.no_grad():
             # Use the target net to produce action values for the next state
             # and the regular net to select the action
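The two comments above describe the Double DQN decoupling that the hidden lines presumably implement: the online network chooses the next action, the target network scores it. A minimal sketch of that target computation, assuming `net` and `target_net` both map a state batch to per-action values:

import torch

def double_dqn_next_values(net, target_net, next_states):
    # Online net selects the action; target net evaluates it.
    with torch.no_grad():
        best_actions = net(next_states).argmax(dim=1)
        next_values = target_net(next_states)
        return next_values.gather(1, best_actions.unsqueeze(1)).squeeze(1)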
@@ -61,19 +71,48 @@ class DQfDAgent:
 
         expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
 
-        # Demonstration loss
-        l = torch.ones_like(state_values[demo_indexes])
-        expert_actions = action_batch[demo_indexes]
-        # l(s, a) is zero for every action the expert doesn't take
-        for i,a in zip(range(len(l)), expert_actions):
-            l[i].fill_(0.8) # According to paper
-            l[i, a] = 0
-        if self.target_net is not None:
-            expert_value = self.target_net(state_batch[demo_indexes])
-        else:
-            expert_value = state_values[demo_indexes]
-        expert_value = expert_value.gather(1, expert_actions.view((self.config['batch_size'], 1))).squeeze(1)
+        # N-Step DQN Loss
+        expected_n_step_values = []
+        with torch.no_grad():
+            for i in range(0, len(state_batch), steps):
+                # Get the estimated value at the last state in a sequence
+                if self.target_net is not None:
+                    expected_nth_values = self.target_net(state_batch[i + steps])
+                    best_nth_action = self.net(state_batch[i + steps]).argmax(1)
+                else:
+                    expected_nth_values = self.net(state_batch[i + steps])
+                    best_nth_action = expected_nth_values.argmax(1)
+                best_expected_nth_value = expected_nth_values[best_nth_action].squeeze(1)
+                # Accumulate the discounted rewards leading up to it
+                received_n_value = 0
+                for j in range(steps):
+                    received_n_value += self.config['discount_rate']**j * reward_batch[j]
+                # Q(s, a) = r_0 + gamma * r_1 + gamma^2 * r_2 + ... + gamma^steps * max_a(Q(s_steps, a))
+                expected_n_step_values.append(received_n_value + self.config['discount_rate']**steps * best_expected_nth_value)
+            expected_n_step_values = torch.stack(expected_n_step_values)
+        # Gather the value the current network thinks it should be
+        observed_n_step_values = []
+        for i in range(0, len(state_batch), steps):
+            observed_nth_value = self.net(state_batch[i])[action_batch[i]]
+            observed_n_step_values.append(observed_nth_value)
+        observed_n_step_values = torch.stack(observed_n_step_values)
+
+
+
+        # Demonstration loss
+        if demo_mask.sum() > 0:
+            l = torch.ones_like(state_values[demo_mask])
+            expert_actions = action_batch[demo_mask]
+            # l(s, a) is zero for every action the expert doesn't take
+            for i,a in zip(range(len(l)), expert_actions):
+                l[i].fill_(0.8) # According to paper
+                l[i, a] = 0
+            if self.target_net is not None:
+                expert_value = self.target_net(state_batch[demo_mask])
+            else:
+                expert_value = state_values[demo_mask]
+            expert_value = expert_value.gather(1, expert_actions.view(demo_mask.sum(), 1))
 
         # Iterate through hyperparameters
         if isinstance(self.config['dqfd_demo_loss_weight'], collections.Iterable):
             demo_importance = next(self.config['dqfd_demo_loss_weight'])
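Two indexing details in the n-step loop above are worth flagging: `reward_batch[j]` reads the same first `steps` rewards for every sequence (rather than `reward_batch[i + j]`), and `state_batch[i + steps]` runs one past the end of the batch on the final sequence. A self-contained sketch of the target the loop appears to be aiming for, assuming per-sequence reward indexing and one precomputed bootstrap value per sequence (e.g. the target net's best value at that sequence's last state):

import torch

def n_step_targets(reward_batch, bootstrap_values, steps, discount_rate):
    # Sketch: G_i = sum_j gamma^j * r_{i+j} + gamma^steps * V_i for each
    # contiguous sequence of `steps` transitions in the batch.
    targets = []
    for seq, i in enumerate(range(0, len(reward_batch), steps)):
        g = 0.0
        for j in range(steps):
            g = g + discount_rate ** j * reward_batch[i + j]  # note: i + j, not j
        targets.append(g + discount_rate ** steps * bootstrap_values[seq])
    return torch.stack(targets)

This mirrors the commented formula Q(s, a) = r_0 + gamma * r_1 + ... + gamma^steps * max_a Q(s_steps, a).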
@@ -84,9 +123,14 @@ class DQfDAgent:
         else:
             td_importance = self.config['dqfd_td_loss_weight']
 
+
         # Since dqn_loss and demo_loss are different sizes, the reduction has to happen before they are combined
-        dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none')).mean()
-        demo_loss = (torch.as_tensor(importance_weights[demo_indexes], device = self.net.device) * F.mse_loss((state_values[demo_indexes] + l).max(1)[0], expert_value, reduction = 'none')).mean()
+        dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean()
+        dqn_n_step_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none').squeeze(1)).mean()
+        if demo_mask.sum() > 0:
+            demo_loss = (torch.as_tensor(importance_weights, device = self.net.device)[demo_mask] * F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction = 'none').squeeze(1)).mean()
+        else:
+            demo_loss = 0
         loss = td_importance * dqn_loss + demo_importance * demo_loss
 
         if self.logger is not None:
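The reduction pattern above, in isolation: each term is importance-weighted and reduced to a scalar before mixing, because the terms cover different numbers of samples (the full batch, one entry per n-step sequence, the demonstration subset). Note that `dqn_n_step_loss` is computed but not yet added to `loss` in this commit. A toy sketch of the per-term reduction, with a hypothetical `n_step_importance` weight for folding the n-step term in:

import torch
import torch.nn.functional as F

def weighted_mse(pred, target, weights):
    # Importance-weighted MSE reduced to a scalar, so differently-sized
    # terms can be mixed afterwards.
    return (weights * F.mse_loss(pred, target, reduction='none')).mean()

# Toy shapes: 32 one-step targets, 8 four-step sequences.
dqn_loss = weighted_mse(torch.rand(32), torch.rand(32), torch.rand(32))
n_step_loss = weighted_mse(torch.rand(8), torch.rand(8), torch.rand(8))
td_importance, n_step_importance = 1.0, 1.0  # hypothetical weights
loss = td_importance * dqn_loss + n_step_importance * n_step_loss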
@@ -106,6 +150,9 @@ class DQfDAgent:
         # If we're sampling by TD error, readjust the weights of the experiences
         # TODO: Can probably adjust demonstration priority here
         td_error = (obtained_values - expected_values).detach().abs()
+        td_error[demo_mask] = td_error[demo_mask] + self.config['demo_prio_bonus']
+        observed_mask = batch_index_tensors >= self.memory.demo_position
+        td_error[observed_mask] = td_error[observed_mask] + self.config['observed_prio_bonus']
         self.memory.update_priorities(batch_indexes, td_error)
 
 
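These bonus terms follow the DQfD practice of adding small positive constants to sampling priorities, with a larger constant for demonstration data, so that neither partition starves under prioritized sampling. A sketch of the adjustment, assuming a boolean demo mask and `demo_prio_bonus > observed_prio_bonus`:

import torch

def adjusted_priorities(td_error, demo_mask, demo_bonus, observed_bonus):
    # Add a per-partition constant to the absolute TD error; DQfD uses a
    # larger bonus (epsilon_d) for demonstrations than for self-play data.
    priorities = td_error.abs().clone()
    priorities[demo_mask] += demo_bonus
    priorities[~demo_mask] += observed_bonus
    return priorities

priorities = adjusted_priorities(torch.randn(8), torch.rand(8) < 0.5, 1.0, 0.001)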
@@ -1,4 +1,5 @@
+from .PrioritizedReplayMemory import PrioritizedReplayMemory
 from collections import namedtuple
 
 Transition = namedtuple('Transition',
     ('state', 'action', 'reward', 'next_state', 'done'))
@@ -11,9 +12,9 @@ class DQfDMemory(PrioritizedReplayMemory):
         self.obtained_transitions_length = 0
 
     def append(self, *args, **kwargs):
-        super().append(self, *args, **kwargs)
+        super().append(*args, **kwargs)
         # Don't overwrite demonstration data
-        self.position = self.demo_position + ((self.position + 1) % (self.capacity - self.demo_position))
+        self.position = self.demo_position + ((self.position + 1) % (len(self.memory) - self.demo_position))
 
     def append_demonstration(self, *args):
         demonstrations = self.memory[:self.demo_position]
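A worked example of the fixed wrap-around may help: with `demo_position = 4` and a memory currently holding 10 transitions, the write index cycles through the six non-demonstration slots 4..9 (in descending order, as it happens) and never touches the demonstration region 0..3. Using `len(self.memory)` instead of `self.capacity` presumably keeps the modulus in step with the filled region before the buffer is full.

# Sketch of the corrected position arithmetic, assuming demo_position = 4
# and len(self.memory) = 10.
demo_position, mem_len = 4, 10
position = demo_position
seen = []
for _ in range(8):
    position = demo_position + ((position + 1) % (mem_len - demo_position))
    seen.append(position)
print(seen)  # [9, 8, 7, 6, 5, 4, 9, 8] -- never an index below demo_position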
@@ -25,4 +26,4 @@ class DQfDMemory(PrioritizedReplayMemory):
             if len(demonstrations) + len(obtained_transitions) + 1 > self.capacity:
                 obtained_transitions = obtained_transitions[:(self.capacity - len(demonstrations) - 1)]
             self.memory = demonstrations + [Transition(*args)] + obtained_transitions
-            self.demo_position += 1
+            self.demo_position += 1
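The trimming arithmetic in `append_demonstration` keeps the buffer at capacity by cutting the tail of the self-play transitions before inserting the new demonstration. A quick worked example, assuming capacity 5 with 3 demonstrations already stored:

# Worked example: capacity = 5, 3 stored demonstrations, 2 obtained transitions.
capacity = 5
demonstrations = ['d0', 'd1', 'd2']
obtained_transitions = ['t0', 't1']

if len(demonstrations) + len(obtained_transitions) + 1 > capacity:
    obtained_transitions = obtained_transitions[:capacity - len(demonstrations) - 1]

memory = demonstrations + ['d3'] + obtained_transitions
print(memory)  # ['d0', 'd1', 'd2', 'd3', 't0'] -- 't1' dropped to stay at capacity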
@@ -105,7 +105,7 @@ class SumSegmentTree(SegmentTree):
         """Returns arr[start] + ... + arr[end]"""
         return super(SumSegmentTree, self).reduce(start, end)
 
-    @jit(forceobj = True)
+    @jit(forceobj = True, parallel = True)
     def find_prefixsum_idx(self, prefixsum):
         """Find the highest index `i` in the array such that
             sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum
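On the decorator change: Numba documents `parallel=True` as a nopython-mode feature, so pairing it with `forceobj = True` (object mode) is unlikely to parallelize anything, and newer Numba versions warn about the combination. A sketch of the flag in its documented setting:

import numpy as np
from numba import jit, prange

@jit(nopython=True, parallel=True)  # parallel=True is a nopython-mode feature
def doubled(arr):
    out = np.empty_like(arr)
    for i in prange(arr.shape[0]):  # prange marks the loop for parallel execution
        out[i] = arr[i] * 2.0
    return out

print(doubled(np.arange(4.0)))  # [0. 2. 4. 6.]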
@@ -204,17 +204,6 @@ class PrioritizedReplayMemory(ReplayMemory):
             (0 - no corrections, 1 - full correction)
         Returns
         -------
-        obs_batch: np.array
-            batch of observations
-        act_batch: np.array
-            batch of actions executed given obs_batch
-        rew_batch: np.array
-            rewards received as results of executing act_batch
-        next_obs_batch: np.array
-            next set of observations seen after executing act_batch
-        done_mask: np.array
-            done_mask[i] = 1 if executing act_batch[i] resulted in
-            the end of an episode and 0 otherwise.
         weights: np.array
             Array of shape (batch_size,) and dtype np.float32
             denoting importance weight of each sampled transition
@@ -224,21 +213,55 @@ class PrioritizedReplayMemory(ReplayMemory):
         """
         assert beta > 0
 
+        # Sample indexes
         idxes = self._sample_proportional(batch_size)
 
+        # Calculate weights
         weights = []
         p_min = self._it_min.min() / self._it_sum.sum()
         max_weight = (p_min * len(self.memory)) ** (-beta)
 
         for idx in idxes:
             p_sample = self._it_sum[idx] / self._it_sum.sum()
             weight = (p_sample * len(self.memory)) ** (-beta)
             weights.append(weight / max_weight)
         weights = np.array(weights)
 
+        # Combine all data into a batch
         encoded_sample = tuple(zip(*self._encode_sample(idxes)))
         batch = list(zip(*encoded_sample, weights, idxes))
         return batch
+
+    def sample_n_steps(self, batch_size, steps, beta):
+        assert beta > 0
+
+        memory = self.memory
+        self.memory = self.memory[:-steps]
+        sample_size = batch_size // steps
+
+        # Sample indexes and get n-steps after that
+        idxes = self._sample_proportional(sample_size)
+        step_idxes = []
+        for i in idxes:
+            step_idxes += range(i, i + steps)
+
+        # Calculate appropriate weights
+        weights = []
+        p_min = self._it_min.min() / self._it_sum.sum()
+        max_weight = (p_min * len(self.memory)) ** (-beta)
+        for idx in step_idxes:
+            p_sample = self._it_sum[idx] / self._it_sum.sum()
+            weight = (p_sample * len(self.memory)) ** (-beta)
+            weights.append(weight / max_weight)
+        weights = np.array(weights)
+
+        # Combine all the data together into a batch
+        encoded_sample = tuple(zip(*self._encode_sample(step_idxes)))
+        batch = list(zip(*encoded_sample, weights, idxes))
+
+        # Restore memory and return batch
+        self.memory = memory
+        return batch
 
     @jit(forceobj = True)
     def update_priorities(self, idxes, priorities):
         """Update priorities of sampled transitions.
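Two details in `sample_n_steps` above: temporarily truncating `self.memory` by `steps` guarantees every sampled start index has `steps` successors, which is sound; but the final combine zips against `idxes`, and since Python's `zip` stops at its shortest argument while `encoded_sample` and `weights` hold one entry per step, the returned batch shrinks to `batch_size // steps` entries. A sketch of the combine step, assuming `step_idxes` is the intended index list:

# encoded_sample, weights and step_idxes each hold batch_size entries
# (one per sampled step); idxes holds only batch_size // steps entries.
def combine_batch(encoded_sample, weights, step_idxes):
    # Pair every transition with its weight and its own index.
    return list(zip(*encoded_sample, weights, step_idxes))

transitions = (('s0', 's1'), ('a0', 'a1'), ('r0', 'r1'))  # toy encoded sample
print(combine_batch(transitions, [0.5, 1.0], [7, 8]))
# [('s0', 'a0', 'r0', 0.5, 7), ('s1', 'a1', 'r1', 1.0, 8)]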
@@ -38,6 +38,13 @@ class ReplayMemory(object):
 
     def sample(self, batch_size):
         return random.sample(self.memory, batch_size)
 
+    def sample_n_steps(self, batch_size, steps):
+        idxes = random.sample(range(len(self.memory) - steps), batch_size // steps)
+        step_idxes = []
+        for i in idxes:
+            step_idxes += range(i, i + steps)
+        return self._encode_sample(step_idxes)
+
     def __len__(self):
         return len(self.memory)
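A usage sketch for this uniform variant: with `batch_size = 6` and `steps = 3` it draws two random starting points and expands each into a contiguous run of three indexes, so the encoded batch still holds `batch_size` transitions. The index logic in isolation:

import random

def n_step_indexes(memory_len, batch_size, steps):
    # batch_size // steps random starts, each expanded into `steps`
    # consecutive indexes; starts are capped so runs stay in range.
    starts = random.sample(range(memory_len - steps), batch_size // steps)
    step_idxes = []
    for i in starts:
        step_idxes += range(i, i + steps)
    return step_idxes

print(n_step_indexes(memory_len=10, batch_size=6, steps=3))
# two contiguous runs of three, e.g. [6, 7, 8, 1, 2, 3]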