Added PPOAgent and A2CAgent to the agents submodule.
Also made some small changes to how memories are queried
This commit is contained in:
		
							parent
							
								
									21b820b401
								
							
						
					
					
						commit
						26084d4c7c
					
				
					 8 changed files with 430 additions and 12 deletions
				
			
		| 
						 | 
				
			
			@ -10,8 +10,6 @@ class StochasticSelector(ArgMaxSelector):
 | 
			
		|||
        self.model = model
 | 
			
		||||
        self.action_size = action_size
 | 
			
		||||
        self.device = device
 | 
			
		||||
        if not isinstance(memory, rltorch.memory.EpisodeMemory):
 | 
			
		||||
            raise ValueError("Memory must be of instance EpisodeMemory")
 | 
			
		||||
        self.memory = memory
 | 
			
		||||
    def best_act(self, state, log_prob = True):
 | 
			
		||||
        if self.device is not None:
 | 
			
		||||
| 
						 | 
				
			
			@ -19,6 +17,6 @@ class StochasticSelector(ArgMaxSelector):
 | 
			
		|||
        action_probabilities = self.model(state)
 | 
			
		||||
        distribution = Categorical(action_probabilities)
 | 
			
		||||
        action = distribution.sample()
 | 
			
		||||
        if log_prob:
 | 
			
		||||
        if log_prob and isinstance(self.memory, rltorch.memory.EpisodeMemory):
 | 
			
		||||
            self.memory.append_log_probs(distribution.log_prob(action))
 | 
			
		||||
        return action.item()
 | 
			
		||||
							
								
								
									
										73
									
								
								rltorch/agents/A2CSingleAgent.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								rltorch/agents/A2CSingleAgent.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,73 @@
 | 
			
		|||
# Deprecated since the idea of the idea shouldn't work without having some sort of "mental model" of the environment
 | 
			
		||||
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from torch.distributions import Categorical
 | 
			
		||||
import rltorch
 | 
			
		||||
import rltorch.memory as M
 | 
			
		||||
import collections
 | 
			
		||||
import random
 | 
			
		||||
 | 
			
		||||
class A2CSingleAgent:
 | 
			
		||||
  def __init__(self, policy_net, value_net, memory, config, logger = None):
 | 
			
		||||
    self.policy_net = policy_net
 | 
			
		||||
    self.value_net = value_net
 | 
			
		||||
    self.memory = memory
 | 
			
		||||
    self.config = deepcopy(config)
 | 
			
		||||
    self.logger = logger
 | 
			
		||||
 | 
			
		||||
  def _discount_rewards(self, rewards):
 | 
			
		||||
    discounted_rewards = torch.zeros_like(rewards)
 | 
			
		||||
    running_add = 0
 | 
			
		||||
    for t in reversed(range(len(rewards))):
 | 
			
		||||
      running_add = running_add * self.config['discount_rate'] + rewards[t]
 | 
			
		||||
      discounted_rewards[t] = running_add
 | 
			
		||||
 | 
			
		||||
    return discounted_rewards
 | 
			
		||||
  
 | 
			
		||||
  
 | 
			
		||||
  def learn(self):
 | 
			
		||||
    if len(self.memory) < self.config['batch_size']:
 | 
			
		||||
      return
 | 
			
		||||
 | 
			
		||||
    episode_batch = self.memory.recall()
 | 
			
		||||
    state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)  
 | 
			
		||||
 | 
			
		||||
    state_batch = torch.cat(state_batch).to(self.value_net.device)
 | 
			
		||||
    reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
 | 
			
		||||
    not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
 | 
			
		||||
    next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
 | 
			
		||||
    log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
 | 
			
		||||
 | 
			
		||||
    ## Value Loss
 | 
			
		||||
    value_loss = F.mse_loss(self._discount_rewards(reward_batch), self.value_net(state_batch[0]))
 | 
			
		||||
    self.value_net.zero_grad()
 | 
			
		||||
    value_loss.backward()
 | 
			
		||||
    self.value_net.step()
 | 
			
		||||
 | 
			
		||||
    ## Policy Loss
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
      state_values = self.value_net(state_batch)
 | 
			
		||||
      next_state_values = torch.zeros_like(state_values) 
 | 
			
		||||
      next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
 | 
			
		||||
    advantages = (reward_batch.unsqueeze(1) + self.config['discount_rate'] * next_state_values) - state_values
 | 
			
		||||
    advantages = advantages.squeeze(1)
 | 
			
		||||
 | 
			
		||||
    policy_loss = (-log_prob_batch * advantages).sum()
 | 
			
		||||
    
 | 
			
		||||
    if self.logger is not None:
 | 
			
		||||
      self.logger.append("Loss/Policy", policy_loss.item())
 | 
			
		||||
      self.logger.append("Loss/Value", value_loss.item())
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
    self.policy_net.zero_grad()
 | 
			
		||||
    policy_loss.backward()
 | 
			
		||||
    self.policy_net.step()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    # Memory is irrelevant for future training
 | 
			
		||||
    self.memory.clear()
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -34,27 +34,29 @@ class DQNAgent:
 | 
			
		|||
        next_state_batch = next_state_batch.to(self.net.device)
 | 
			
		||||
        not_done_batch = not_done_batch.to(self.net.device)
 | 
			
		||||
 | 
			
		||||
        obtained_values = self.net(state_batch).gather(1, action_batch.view(self.config['batch_size'], 1))
 | 
			
		||||
        state_values = self.net(state_batch)
 | 
			
		||||
        obtained_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1))
 | 
			
		||||
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            # Use the target net to produce action values for the next state
 | 
			
		||||
            # and the regular net to select the action
 | 
			
		||||
            # That way we decouple the value and action selecting processes (DOUBLE DQN)
 | 
			
		||||
            not_done_size = not_done_batch.sum()
 | 
			
		||||
            next_state_values = torch.zeros_like(state_values, device = self.net.device)
 | 
			
		||||
            if self.target_net is not None:
 | 
			
		||||
                next_state_values = self.target_net(next_state_batch)
 | 
			
		||||
                next_best_action = self.net(next_state_batch).argmax(1)
 | 
			
		||||
                next_state_values[not_done_batch] = self.target_net(next_state_batch[not_done_batch])
 | 
			
		||||
                next_best_action = self.net(next_state_batch[not_done_batch]).argmax(1)
 | 
			
		||||
            else:
 | 
			
		||||
                next_state_values = self.net(next_state_batch)
 | 
			
		||||
                next_best_action = next_state_values.argmax(1)
 | 
			
		||||
                next_state_values[not_done_batch] = self.net(next_state_batch[not_done_batch])
 | 
			
		||||
                next_best_action = next_state_values[not_done_batch].argmax(1)
 | 
			
		||||
 | 
			
		||||
            best_next_state_value = torch.zeros(self.config['batch_size'], device = self.net.device)
 | 
			
		||||
            best_next_state_value[not_done_batch] = next_state_values.gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
 | 
			
		||||
            best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
 | 
			
		||||
            
 | 
			
		||||
        expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
 | 
			
		||||
 | 
			
		||||
        if (isinstance(self.memory, M.PrioritizedReplayMemory)):
 | 
			
		||||
            loss = (torch.as_tensor(importance_weights, device = self.net.device) * (obtained_values - expected_values)**2).mean()
 | 
			
		||||
            loss = (torch.as_tensor(importance_weights, device = self.net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
 | 
			
		||||
        else:
 | 
			
		||||
            loss = F.mse_loss(obtained_values, expected_values)
 | 
			
		||||
        
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										83
									
								
								rltorch/agents/PPOAgent.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								rltorch/agents/PPOAgent.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,83 @@
 | 
			
		|||
# Deprecated since the idea of the idea shouldn't work without having some sort of "mental model" of the environment
 | 
			
		||||
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from torch.distributions import Categorical
 | 
			
		||||
import rltorch
 | 
			
		||||
import rltorch.memory as M
 | 
			
		||||
import collections
 | 
			
		||||
import random
 | 
			
		||||
 | 
			
		||||
class PPOAgent:
 | 
			
		||||
  def __init__(self, policy_net, value_net, memory, config, logger = None):
 | 
			
		||||
    self.policy_net = policy_net
 | 
			
		||||
    self.old_policy_net = rltorch.network.TargetNetwork(policy_net)
 | 
			
		||||
    self.value_net = value_net
 | 
			
		||||
    self.memory = memory
 | 
			
		||||
    self.config = deepcopy(config)
 | 
			
		||||
    self.logger = logger
 | 
			
		||||
 | 
			
		||||
  def _discount_rewards(self, rewards):
 | 
			
		||||
    discounted_rewards = torch.zeros_like(rewards)
 | 
			
		||||
    running_add = 0
 | 
			
		||||
    for t in reversed(range(len(rewards))):
 | 
			
		||||
      running_add = running_add * self.config['discount_rate'] + rewards[t]
 | 
			
		||||
      discounted_rewards[t] = running_add
 | 
			
		||||
 | 
			
		||||
    return discounted_rewards
 | 
			
		||||
  
 | 
			
		||||
  
 | 
			
		||||
  def learn(self):
 | 
			
		||||
    if len(self.memory) < self.config['batch_size']:
 | 
			
		||||
      return
 | 
			
		||||
 | 
			
		||||
    episode_batch = self.memory.recall()
 | 
			
		||||
    state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)  
 | 
			
		||||
 | 
			
		||||
    state_batch = torch.cat(state_batch).to(self.value_net.device)
 | 
			
		||||
    action_batch = torch.tensor(action_batch).to(self.value_net.device)
 | 
			
		||||
    reward_batch = torch.tensor(reward_batch).to(self.value_net.device)
 | 
			
		||||
    not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
 | 
			
		||||
    next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
 | 
			
		||||
    log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
 | 
			
		||||
 | 
			
		||||
    ## Value Loss
 | 
			
		||||
    value_loss = F.mse_loss(self._discount_rewards(reward_batch), self.value_net(state_batch[0]))
 | 
			
		||||
    self.value_net.zero_grad()
 | 
			
		||||
    value_loss.backward()
 | 
			
		||||
    self.value_net.step()
 | 
			
		||||
 | 
			
		||||
    ## Policy Loss
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
      state_values = self.value_net(state_batch)
 | 
			
		||||
      next_state_values = torch.zeros_like(state_values) 
 | 
			
		||||
      next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
 | 
			
		||||
    advantages = (reward_batch.unsqueeze(1) + self.config['discount_rate'] * next_state_values) - state_values
 | 
			
		||||
    advantages = advantages.squeeze(1)
 | 
			
		||||
 | 
			
		||||
    action_probabilities = self.old_policy_net(state_batch).detach()
 | 
			
		||||
    distributions = list(map(Categorical, action_probabilities))
 | 
			
		||||
    old_log_probs = torch.stack(list(map(lambda distribution, action: distribution.log_prob(action), distributions, action_batch)))
 | 
			
		||||
 | 
			
		||||
    policy_ratio = torch.exp(log_prob_batch - old_log_probs) # Equivalent to (log_prob / old_log_prob)
 | 
			
		||||
    policy_loss1 = policy_ratio * advantages 
 | 
			
		||||
    policy_loss2 = policy_ratio.clamp(min = 0.8, max = 1.2) * advantages # From original paper
 | 
			
		||||
    policy_loss = -torch.min(policy_loss1, policy_loss2).sum()
 | 
			
		||||
    
 | 
			
		||||
    if self.logger is not None:
 | 
			
		||||
      self.logger.append("Loss/Policy", policy_loss.item())
 | 
			
		||||
      self.logger.append("Loss/Value", value_loss.item())
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
    self.old_policy_net.sync()
 | 
			
		||||
    self.policy_net.zero_grad()
 | 
			
		||||
    policy_loss.backward()
 | 
			
		||||
    self.policy_net.step()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    # Memory is irrelevant for future training
 | 
			
		||||
    self.memory.clear()
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -31,7 +31,7 @@ class REINFORCEAgent:
 | 
			
		|||
    discount_reward_batch = self._discount_rewards(torch.tensor(reward_batch))
 | 
			
		||||
    log_prob_batch = torch.cat(log_prob_batch)
 | 
			
		||||
 | 
			
		||||
    policy_loss = (-1 * log_prob_batch * discount_reward_batch).sum()
 | 
			
		||||
    policy_loss = (-log_prob_batch * discount_reward_batch).sum()
 | 
			
		||||
    
 | 
			
		||||
    if self.logger is not None:
 | 
			
		||||
            self.logger.append("Loss", policy_loss.item())
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										260
									
								
								rltorch/agents/_A2CSingleAgent.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										260
									
								
								rltorch/agents/_A2CSingleAgent.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,260 @@
 | 
			
		|||
# Deprecated since the idea of the idea shouldn't work without having some sort of "mental model" of the environment
 | 
			
		||||
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from torch.distributions import Categorical
 | 
			
		||||
import rltorch
 | 
			
		||||
import rltorch.memory as M
 | 
			
		||||
import collections
 | 
			
		||||
import random
 | 
			
		||||
 | 
			
		||||
class A2CSingleAgent:
 | 
			
		||||
  def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None):
 | 
			
		||||
    self.policy_net = policy_net
 | 
			
		||||
    self.value_net = value_net
 | 
			
		||||
    self.memory = memory
 | 
			
		||||
    self.config = deepcopy(config)
 | 
			
		||||
    self.target_value_net = target_value_net
 | 
			
		||||
    self.logger = logger
 | 
			
		||||
 | 
			
		||||
  def learn_value(self):
 | 
			
		||||
    if (isinstance(self.memory, M.PrioritizedReplayMemory)):
 | 
			
		||||
      weight_importance = self.config['prioritized_replay_weight_importance']
 | 
			
		||||
      # If it's a scheduler then get the next value by calling next, otherwise just use it's value
 | 
			
		||||
      beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
 | 
			
		||||
      minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
 | 
			
		||||
      state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
 | 
			
		||||
    else:
 | 
			
		||||
      minibatch = self.memory.sample(self.config['batch_size'])
 | 
			
		||||
      state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
 | 
			
		||||
    
 | 
			
		||||
    # Send to their appropriate devices 
 | 
			
		||||
    state_batch = state_batch.to(self.value_net.device)
 | 
			
		||||
    action_batch = action_batch.to(self.value_net.device)
 | 
			
		||||
    reward_batch = reward_batch.to(self.value_net.device)
 | 
			
		||||
    next_state_batch = next_state_batch.to(self.value_net.device)
 | 
			
		||||
    not_done_batch = not_done_batch.to(self.value_net.device)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    ## Value Loss
 | 
			
		||||
    state_values = self.value_net(state_batch)
 | 
			
		||||
    obtained_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1))
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
      # Use the target net to produce action values for the next state
 | 
			
		||||
      # and the regular net to select the action
 | 
			
		||||
      # That way we decouple the value and action selecting processes (DOUBLE DQN)
 | 
			
		||||
      not_done_size = not_done_batch.sum()
 | 
			
		||||
      next_state_values = torch.zeros_like(state_values)
 | 
			
		||||
      if self.target_value_net is not None:
 | 
			
		||||
        next_state_values[not_done_batch] = self.target_value_net(next_state_batch[not_done_batch])
 | 
			
		||||
        next_best_action = self.value_net(next_state_batch).argmax(1)
 | 
			
		||||
      else:
 | 
			
		||||
        next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
 | 
			
		||||
        next_best_action = next_state_values.argmax(1)
 | 
			
		||||
 | 
			
		||||
      best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
 | 
			
		||||
      # best_next_state_value[not_done_batch] = next_state_values.gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
 | 
			
		||||
      best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action[not_done_batch].view((not_done_size, 1))).squeeze(1)
 | 
			
		||||
            
 | 
			
		||||
    expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
 | 
			
		||||
 | 
			
		||||
    if (isinstance(self.memory, M.PrioritizedReplayMemory)):
 | 
			
		||||
      importance_weights = torch.as_tensor(importance_weights, device = self.value_net.device) 
 | 
			
		||||
      value_loss = (importance_weights * ((obtained_values - expected_values)**2).squeeze(1)).mean()
 | 
			
		||||
    else:
 | 
			
		||||
      value_loss = F.mse_loss(obtained_values, expected_values)
 | 
			
		||||
 | 
			
		||||
    if (isinstance(self.memory, M.PrioritizedReplayMemory)):
 | 
			
		||||
      td_error = (obtained_values - expected_values).detach().abs()
 | 
			
		||||
      self.memory.update_priorities(batch_indexes, td_error)
 | 
			
		||||
 | 
			
		||||
    self.value_net.zero_grad()
 | 
			
		||||
    value_loss.backward()
 | 
			
		||||
    self.value_net.step()
 | 
			
		||||
    
 | 
			
		||||
    if self.target_value_net is not None:
 | 
			
		||||
      if 'target_sync_tau' in self.config:
 | 
			
		||||
        self.target_value_net.partial_sync(self.config['target_sync_tau'])
 | 
			
		||||
      else:
 | 
			
		||||
        self.target_value_net.sync()
 | 
			
		||||
 | 
			
		||||
    if self.logger is not None:
 | 
			
		||||
      self.logger.append("Loss/Value", value_loss.item())
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
  def learn_policy(self):
 | 
			
		||||
    starting_index = random.randint(0, len(self.memory) - self.config['batch_size'])
 | 
			
		||||
    state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(self.memory[starting_index:(starting_index + self.config['batch_size'])])
 | 
			
		||||
    
 | 
			
		||||
    state_batch = state_batch.to(self.policy_net.device)
 | 
			
		||||
    action_batch = action_batch.to(self.policy_net.device)
 | 
			
		||||
    reward_batch = reward_batch.to(self.policy_net.device)
 | 
			
		||||
    next_state_batch = next_state_batch.to(self.policy_net.device)
 | 
			
		||||
    not_done_batch = not_done_batch.to(self.policy_net.device)
 | 
			
		||||
 | 
			
		||||
    # Find when episode ends and filter out the Transitions after
 | 
			
		||||
    episode_ends = (~not_done_batch).nonzero().squeeze(1)
 | 
			
		||||
    start_idx = 0
 | 
			
		||||
    end_idx = self.config['batch_size']
 | 
			
		||||
    if len(episode_ends) > 0:
 | 
			
		||||
      if (episode_ends[0] == 0).item():
 | 
			
		||||
        if len(episode_ends) > 1:
 | 
			
		||||
          start_idx = 1
 | 
			
		||||
          end_idx = episode_ends[1] + 1
 | 
			
		||||
        else:
 | 
			
		||||
          start_idx = 1
 | 
			
		||||
      else:
 | 
			
		||||
        end_idx = episode_ends[0] + 1
 | 
			
		||||
    batch_size = end_idx - start_idx
 | 
			
		||||
 | 
			
		||||
    # Now filter...
 | 
			
		||||
    state_batch = state_batch[start_idx:end_idx]
 | 
			
		||||
    action_batch = action_batch[start_idx:end_idx]
 | 
			
		||||
    reward_batch = reward_batch[start_idx:end_idx]
 | 
			
		||||
    next_state_batch = next_state_batch[start_idx:end_idx]
 | 
			
		||||
    not_done_batch = not_done_batch[start_idx:end_idx]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
      if self.target_value_net is not None:
 | 
			
		||||
        state_values = self.target_value_net(state_batch)
 | 
			
		||||
        next_state_values = torch.zeros_like(state_values, device = self.value_net.device) 
 | 
			
		||||
        next_state_values[not_done_batch] = self.target_value_net(next_state_batch[not_done_batch])
 | 
			
		||||
      else:
 | 
			
		||||
        state_values = self.value_net(state_batch)
 | 
			
		||||
        next_state_values = torch.zeros_like(state_values, device = self.value_net.device) 
 | 
			
		||||
        next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
 | 
			
		||||
      
 | 
			
		||||
    obtained_values = state_values.gather(1, action_batch.view(batch_size, 1))
 | 
			
		||||
    approx_state_action_values = reward_batch.unsqueeze(1) + self.config['discount_rate'] * next_state_values
 | 
			
		||||
    advantage = (obtained_values - approx_state_action_values.mean(1).unsqueeze(1)) 
 | 
			
		||||
    # Scale and squeeze the dimension
 | 
			
		||||
    advantage = advantage.squeeze(1)
 | 
			
		||||
    # advantage = (advantage / (state_values.std() + np.finfo('float').eps)).squeeze(1)
 | 
			
		||||
    action_probabilities = self.policy_net(state_batch)
 | 
			
		||||
    distributions = list(map(Categorical, action_probabilities))
 | 
			
		||||
    log_probs = torch.stack(list(map(lambda distribution, action: distribution.log_prob(action), distributions, action_batch)))
 | 
			
		||||
    policy_loss = (-log_probs * advantage).mean()
 | 
			
		||||
 | 
			
		||||
    self.policy_net.zero_grad()
 | 
			
		||||
    policy_loss.backward()
 | 
			
		||||
    self.policy_net.step()
 | 
			
		||||
 | 
			
		||||
    if self.logger is not None:
 | 
			
		||||
      self.logger.append("Loss/Policy", policy_loss.item())
 | 
			
		||||
  
 | 
			
		||||
  def learn(self):
 | 
			
		||||
    if len(self.memory) < self.config['batch_size']:
 | 
			
		||||
      return
 | 
			
		||||
    self.learn_value()
 | 
			
		||||
    self.learn_policy()
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  # def learn(self):
 | 
			
		||||
  #   if len(self.memory) < self.config['batch_size']:
 | 
			
		||||
  #     return
 | 
			
		||||
    
 | 
			
		||||
  #   if (isinstance(self.memory, M.PrioritizedReplayMemory)):
 | 
			
		||||
  #     weight_importance = self.config['prioritized_replay_weight_importance']
 | 
			
		||||
  #     # If it's a scheduler then get the next value by calling next, otherwise just use it's value
 | 
			
		||||
  #     beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
 | 
			
		||||
  #     minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
 | 
			
		||||
  #     state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
 | 
			
		||||
  #   else:
 | 
			
		||||
  #     minibatch = self.memory.sample(self.config['batch_size'])
 | 
			
		||||
  #     state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
 | 
			
		||||
    
 | 
			
		||||
  #   # Send to their appropriate devices 
 | 
			
		||||
  #   # [TODO] Notice how we're sending it to the value_net's device, what if policy_net was on a different device?
 | 
			
		||||
  #   state_batch = state_batch.to(self.value_net.device)
 | 
			
		||||
  #   action_batch = action_batch.to(self.value_net.device)
 | 
			
		||||
  #   reward_batch = reward_batch.to(self.value_net.device)
 | 
			
		||||
  #   next_state_batch = next_state_batch.to(self.value_net.device)
 | 
			
		||||
  #   not_done_batch = not_done_batch.to(self.value_net.device)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  #   ## Value Loss
 | 
			
		||||
 | 
			
		||||
  #   obtained_values = self.value_net(state_batch).gather(1, action_batch.view(self.config['batch_size'], 1))
 | 
			
		||||
 | 
			
		||||
  #   with torch.no_grad():
 | 
			
		||||
  #     # Use the target net to produce action values for the next state
 | 
			
		||||
  #     # and the regular net to select the action
 | 
			
		||||
  #     # That way we decouple the value and action selecting processes (DOUBLE DQN)
 | 
			
		||||
  #     not_done_size = not_done_batch.sum()
 | 
			
		||||
  #     if self.target_value_net is not None:
 | 
			
		||||
  #       next_state_values = self.target_value_net(next_state_batch)
 | 
			
		||||
  #       next_best_action = self.value_net(next_state_batch).argmax(1)
 | 
			
		||||
  #     else:
 | 
			
		||||
  #       next_state_values = self.value_net(next_state_batch)
 | 
			
		||||
  #       next_best_action = next_state_values.argmax(1)
 | 
			
		||||
 | 
			
		||||
  #     best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
 | 
			
		||||
  #     best_next_state_value[not_done_batch] = next_state_values.gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
 | 
			
		||||
            
 | 
			
		||||
  #   expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
 | 
			
		||||
 | 
			
		||||
  #   if (isinstance(self.memory, M.PrioritizedReplayMemory)):
 | 
			
		||||
  #     importance_weights = torch.as_tensor(importance_weights, device = self.value_net.device) 
 | 
			
		||||
  #     value_loss = (importance_weights * ((obtained_values - expected_values)**2).squeeze(1)).mean()
 | 
			
		||||
  #   else:
 | 
			
		||||
  #     value_loss = F.mse_loss(obtained_values, expected_values)
 | 
			
		||||
    
 | 
			
		||||
  #   self.value_net.zero_grad()
 | 
			
		||||
  #   value_loss.backward()
 | 
			
		||||
  #   self.value_net.step()
 | 
			
		||||
 | 
			
		||||
  #   if self.target_value_net is not None:
 | 
			
		||||
  #     if 'target_sync_tau' in self.config:
 | 
			
		||||
  #       self.target_value_net.partial_sync(self.config['target_sync_tau'])
 | 
			
		||||
  #     else:
 | 
			
		||||
  #       self.target_value_net.sync()
 | 
			
		||||
 | 
			
		||||
  #   if (isinstance(self.memory, M.PrioritizedReplayMemory)):
 | 
			
		||||
  #     td_error = (obtained_values - expected_values).detach().abs()
 | 
			
		||||
  #     self.memory.update_priorities(batch_indexes, td_error)
 | 
			
		||||
 | 
			
		||||
  #   if self.logger is not None:
 | 
			
		||||
  #     self.logger.append("ValueLoss", value_loss.item())
 | 
			
		||||
 | 
			
		||||
  #   ## Policy Loss
 | 
			
		||||
  #   with torch.no_grad():
 | 
			
		||||
  #     state_values = self.value_net(state_batch)
 | 
			
		||||
  #     if self.target_value_net is not None:
 | 
			
		||||
  #       next_state_values = self.target_value_net(next_state_batch)
 | 
			
		||||
  #     else:
 | 
			
		||||
  #       next_state_values = self.value_net(next_state_batch)
 | 
			
		||||
      
 | 
			
		||||
  #   state_action_values = state_values.gather(1, action_batch.view(self.config['batch_size'], 1))
 | 
			
		||||
  #   average_next_state_values = torch.zeros(self.config['batch_size'], device = self.value_net.device) 
 | 
			
		||||
  #   average_next_state_values[not_done_batch] = next_state_values.mean(1)
 | 
			
		||||
 | 
			
		||||
  #   advantage = (state_action_values - (reward_batch + self.config['discount_rate'] * average_next_state_values).unsqueeze(1)) 
 | 
			
		||||
  #   # Scale and squeeze the dimension
 | 
			
		||||
  #   advantage = advantage.squeeze(1)
 | 
			
		||||
  #   # advantage = (advantage / (state_values.std() + np.finfo('float').eps)).squeeze(1)
 | 
			
		||||
  #   action_probabilities = self.policy_net(state_batch)
 | 
			
		||||
  #   distributions = list(map(Categorical, action_probabilities))
 | 
			
		||||
  #   log_probs = torch.stack(list(map(lambda distribution, action: distribution.log_prob(action), distributions, action_batch)))
 | 
			
		||||
  #   if (isinstance(self.memory, M.PrioritizedReplayMemory)):
 | 
			
		||||
  #     policy_loss = (importance_weights * -log_probs * advantage).sum()
 | 
			
		||||
  #   else:
 | 
			
		||||
  #     policy_loss = (-log_probs * advantage).sum()
 | 
			
		||||
 | 
			
		||||
  #   self.policy_net.zero_grad()
 | 
			
		||||
  #   policy_loss.backward()
 | 
			
		||||
  #   self.policy_net.step()
 | 
			
		||||
 | 
			
		||||
  #   if self.logger is not None:
 | 
			
		||||
  #     self.logger.append("PolicyLoss", policy_loss.item())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1,2 +1,4 @@
 | 
			
		|||
from .A2CSingleAgent import *
 | 
			
		||||
from .DQNAgent import *
 | 
			
		||||
from .PPOAgent import *
 | 
			
		||||
from .REINFORCEAgent import *
 | 
			
		||||
| 
						 | 
				
			
			@ -67,7 +67,7 @@ def zip_batch(minibatch, priority = False):
 | 
			
		|||
    action_batch = torch.tensor(action_batch)
 | 
			
		||||
    reward_batch = torch.tensor(reward_batch)
 | 
			
		||||
    not_done_batch = ~torch.tensor(done_batch)
 | 
			
		||||
    next_state_batch = torch.cat(next_state_batch)[not_done_batch]
 | 
			
		||||
    next_state_batch = torch.cat(next_state_batch)
 | 
			
		||||
 | 
			
		||||
    if priority:
 | 
			
		||||
        return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, weights, indexes
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue