Implemented REINFORCE into the library
This commit is contained in:
		
							parent
							
								
									14ba64d525
								
							
						
					
					
						commit
						21b820b401
					
				
					 7 changed files with 250 additions and 2 deletions
				
			
		
							
								
								
									
										126
									
								
								examples/acrobot_reinforce.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								examples/acrobot_reinforce.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,126 @@
 | 
			
		|||
import gym
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import rltorch
 | 
			
		||||
import rltorch.network as rn
 | 
			
		||||
import rltorch.memory as M
 | 
			
		||||
import rltorch.env as E
 | 
			
		||||
from rltorch.action_selector import StochasticSelector
 | 
			
		||||
from tensorboardX import SummaryWriter
 | 
			
		||||
import torch.multiprocessing as mp
 | 
			
		||||
import signal
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
 | 
			
		||||
class Value(nn.Module):
 | 
			
		||||
  def __init__(self, state_size, action_size):
 | 
			
		||||
    super(Value, self).__init__()
 | 
			
		||||
    self.state_size = state_size
 | 
			
		||||
    self.action_size = action_size
 | 
			
		||||
 | 
			
		||||
    self.fc1 = rn.NoisyLinear(state_size, 64)
 | 
			
		||||
    self.fc_norm = nn.LayerNorm(64)
 | 
			
		||||
    
 | 
			
		||||
    self.value_fc = rn.NoisyLinear(64, 64)
 | 
			
		||||
    self.value_fc_norm = nn.LayerNorm(64)
 | 
			
		||||
    self.value = rn.NoisyLinear(64, 1)
 | 
			
		||||
    
 | 
			
		||||
    self.advantage_fc = rn.NoisyLinear(64, 64)
 | 
			
		||||
    self.advantage_fc_norm = nn.LayerNorm(64)
 | 
			
		||||
    self.advantage = rn.NoisyLinear(64, action_size)
 | 
			
		||||
 | 
			
		||||
  def forward(self, x):
 | 
			
		||||
    x = F.relu(self.fc_norm(self.fc1(x)))
 | 
			
		||||
    
 | 
			
		||||
    state_value = F.relu(self.value_fc_norm(self.value_fc(x)))
 | 
			
		||||
    state_value = self.value(state_value)
 | 
			
		||||
    
 | 
			
		||||
    advantage = F.relu(self.advantage_fc_norm(self.advantage_fc(x)))
 | 
			
		||||
    advantage = self.advantage(advantage)
 | 
			
		||||
    
 | 
			
		||||
    x = F.softmax(state_value + advantage - advantage.mean(), dim = 1)
 | 
			
		||||
    
 | 
			
		||||
    return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
config = {}
 | 
			
		||||
config['seed'] = 901
 | 
			
		||||
config['environment_name'] = 'Acrobot-v1'
 | 
			
		||||
config['memory_size'] = 2000
 | 
			
		||||
config['total_training_episodes'] = 100
 | 
			
		||||
config['total_evaluation_episodes'] = 10
 | 
			
		||||
config['batch_size'] = 32
 | 
			
		||||
config['learning_rate'] = 1e-3
 | 
			
		||||
config['target_sync_tau'] = 1e-1
 | 
			
		||||
config['discount_rate'] = 0.99
 | 
			
		||||
config['replay_skip'] = 0
 | 
			
		||||
# How many episodes between printing out the episode stats
 | 
			
		||||
config['print_stat_n_eps'] = 1
 | 
			
		||||
config['disable_cuda'] = False
 | 
			
		||||
 | 
			
		||||
def train(env, agent, actor, memory, config, logger = None, logwriter = None):
 | 
			
		||||
    finished = False
 | 
			
		||||
    episode_num = 1
 | 
			
		||||
    while not finished:
 | 
			
		||||
        rltorch.env.simulateEnvEps(env, actor, config, memory = memory, logger = logger, name = "Training")
 | 
			
		||||
        episode_num += 1
 | 
			
		||||
        agent.learn()
 | 
			
		||||
        # When the episode number changes, log network paramters
 | 
			
		||||
        if logwriter is not None:
 | 
			
		||||
          agent.net.log_named_parameters()
 | 
			
		||||
          logwriter.write(logger)
 | 
			
		||||
        finished = episode_num > config['total_training_episodes']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
  torch.multiprocessing.set_sharing_strategy('file_system') # To not hit file descriptor memory limit
 | 
			
		||||
 | 
			
		||||
  # Setting up the environment
 | 
			
		||||
  rltorch.set_seed(config['seed'])
 | 
			
		||||
  print("Setting up environment...", end = " ")
 | 
			
		||||
  env = E.TorchWrap(gym.make(config['environment_name']))
 | 
			
		||||
  env.seed(config['seed'])
 | 
			
		||||
  print("Done.")
 | 
			
		||||
      
 | 
			
		||||
  state_size = env.observation_space.shape[0]
 | 
			
		||||
  action_size = env.action_space.n
 | 
			
		||||
 | 
			
		||||
  # Logging
 | 
			
		||||
  logger = rltorch.log.Logger()
 | 
			
		||||
  logwriter = rltorch.log.LogWriter(SummaryWriter())
 | 
			
		||||
 | 
			
		||||
  # Setting up the networks
 | 
			
		||||
  device = torch.device("cuda:0" if torch.cuda.is_available() and not config['disable_cuda'] else "cpu")
 | 
			
		||||
  net = rn.Network(Value(state_size, action_size), 
 | 
			
		||||
                      torch.optim.Adam, config, device = device, name = "DQN")
 | 
			
		||||
  target_net = rn.TargetNetwork(net, device = device)
 | 
			
		||||
  net.model.share_memory()
 | 
			
		||||
  target_net.model.share_memory()
 | 
			
		||||
 | 
			
		||||
  # Memory stores experiences for later training
 | 
			
		||||
  memory = M.EpisodeMemory()
 | 
			
		||||
 | 
			
		||||
  # Actor takes a net and uses it to produce actions from given states
 | 
			
		||||
  actor = StochasticSelector(net, action_size, memory, device = device)
 | 
			
		||||
 | 
			
		||||
  # Agent is what performs the training
 | 
			
		||||
  agent = rltorch.agents.REINFORCEAgent(net, memory, config, target_net = target_net, logger = logger)
 | 
			
		||||
    
 | 
			
		||||
  print("Training...")
 | 
			
		||||
 | 
			
		||||
  train(env, agent, actor, memory, config, logger = logger, logwriter = logwriter) 
 | 
			
		||||
 | 
			
		||||
  # For profiling...
 | 
			
		||||
  # import cProfile
 | 
			
		||||
  # cProfile.run('train(runner, agent, config, logger = logger, logwriter = logwriter )')
 | 
			
		||||
  # python -m torch.utils.bottleneck /path/to/source/script.py [args] is also a good solution...
 | 
			
		||||
 | 
			
		||||
  print("Training Finished.")
 | 
			
		||||
 | 
			
		||||
  print("Evaluating...")
 | 
			
		||||
  rltorch.env.simulateEnvEps(env, actor, config, total_episodes = config['total_evaluation_episodes'], logger = logger, name = "Evaluation")
 | 
			
		||||
  print("Evaulations Done.")
 | 
			
		||||
 | 
			
		||||
  logwriter.close() # We don't need to write anything out to disk anymore
 | 
			
		||||
							
								
								
									
										24
									
								
								rltorch/action_selector/StochasticSelector.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								rltorch/action_selector/StochasticSelector.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,24 @@
 | 
			
		|||
from random import randrange
 | 
			
		||||
import torch
 | 
			
		||||
from torch.distributions import Categorical
 | 
			
		||||
import rltorch
 | 
			
		||||
from rltorch.action_selector import ArgMaxSelector
 | 
			
		||||
 | 
			
		||||
class StochasticSelector(ArgMaxSelector):
 | 
			
		||||
    def __init__(self, model, action_size, memory, device = None):
 | 
			
		||||
        super(StochasticSelector, self).__init__(model, action_size, device = device)
 | 
			
		||||
        self.model = model
 | 
			
		||||
        self.action_size = action_size
 | 
			
		||||
        self.device = device
 | 
			
		||||
        if not isinstance(memory, rltorch.memory.EpisodeMemory):
 | 
			
		||||
            raise ValueError("Memory must be of instance EpisodeMemory")
 | 
			
		||||
        self.memory = memory
 | 
			
		||||
    def best_act(self, state, log_prob = True):
 | 
			
		||||
        if self.device is not None:
 | 
			
		||||
            state = state.to(self.device)
 | 
			
		||||
        action_probabilities = self.model(state)
 | 
			
		||||
        distribution = Categorical(action_probabilities)
 | 
			
		||||
        action = distribution.sample()
 | 
			
		||||
        if log_prob:
 | 
			
		||||
            self.memory.append_log_probs(distribution.log_prob(action))
 | 
			
		||||
        return action.item()
 | 
			
		||||
| 
						 | 
				
			
			@ -1,3 +1,4 @@
 | 
			
		|||
from .ArgMaxSelector import * 
 | 
			
		||||
from .EpsilonGreedySelector import * 
 | 
			
		||||
from .RandomSelector import * 
 | 
			
		||||
from .RandomSelector import * 
 | 
			
		||||
from .StochasticSelector import * 
 | 
			
		||||
							
								
								
									
										51
									
								
								rltorch/agents/REINFORCEAgent.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								rltorch/agents/REINFORCEAgent.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,51 @@
 | 
			
		|||
import rltorch
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
import torch
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
class REINFORCEAgent:
 | 
			
		||||
  def __init__(self, net , memory, config, target_net = None, logger = None):
 | 
			
		||||
    self.net = net
 | 
			
		||||
    if not isinstance(memory, rltorch.memory.EpisodeMemory):
 | 
			
		||||
      raise ValueError("Memory must be of instance EpisodeMemory")
 | 
			
		||||
    self.memory = memory
 | 
			
		||||
    self.config = deepcopy(config)
 | 
			
		||||
    self.target_net = target_net
 | 
			
		||||
    self.logger = logger
 | 
			
		||||
 | 
			
		||||
  def _discount_rewards(self, rewards):
 | 
			
		||||
    discounted_rewards = torch.zeros_like(rewards)
 | 
			
		||||
    running_add = 0
 | 
			
		||||
    for t in reversed(range(len(rewards))):
 | 
			
		||||
      running_add = running_add * self.config['discount_rate'] + rewards[t]
 | 
			
		||||
      discounted_rewards[t] = running_add
 | 
			
		||||
 | 
			
		||||
    # Normalize rewards
 | 
			
		||||
    discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + np.finfo('float').eps)
 | 
			
		||||
    return discounted_rewards
 | 
			
		||||
  
 | 
			
		||||
  def learn(self):
 | 
			
		||||
    episode_batch = self.memory.recall()
 | 
			
		||||
    state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
 | 
			
		||||
 | 
			
		||||
    discount_reward_batch = self._discount_rewards(torch.tensor(reward_batch))
 | 
			
		||||
    log_prob_batch = torch.cat(log_prob_batch)
 | 
			
		||||
 | 
			
		||||
    policy_loss = (-1 * log_prob_batch * discount_reward_batch).sum()
 | 
			
		||||
    
 | 
			
		||||
    if self.logger is not None:
 | 
			
		||||
            self.logger.append("Loss", policy_loss.item())
 | 
			
		||||
 | 
			
		||||
    self.net.zero_grad()
 | 
			
		||||
    policy_loss.backward()
 | 
			
		||||
    self.net.clamp_gradients()
 | 
			
		||||
    self.net.step()
 | 
			
		||||
 | 
			
		||||
    if self.target_net is not None:
 | 
			
		||||
      if 'target_sync_tau' in self.config:
 | 
			
		||||
        self.target_net.partial_sync(self.config['target_sync_tau'])
 | 
			
		||||
      else:
 | 
			
		||||
        self.target_net.sync()
 | 
			
		||||
 | 
			
		||||
    # Memory is irrelevant for future training
 | 
			
		||||
    self.memory.clear()
 | 
			
		||||
| 
						 | 
				
			
			@ -1 +1,2 @@
 | 
			
		|||
from .DQNAgent import *
 | 
			
		||||
from .DQNAgent import *
 | 
			
		||||
from .REINFORCEAgent import *
 | 
			
		||||
							
								
								
									
										44
									
								
								rltorch/memory/EpisodeMemory.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								rltorch/memory/EpisodeMemory.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,44 @@
 | 
			
		|||
import random
 | 
			
		||||
from collections import namedtuple
 | 
			
		||||
import torch
 | 
			
		||||
Transition = namedtuple('Transition',
 | 
			
		||||
    ('state', 'action', 'reward', 'next_state', 'done'))
 | 
			
		||||
 | 
			
		||||
class EpisodeMemory(object):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.memory = []
 | 
			
		||||
        self.log_probs = []
 | 
			
		||||
 | 
			
		||||
    def append(self, *args):
 | 
			
		||||
        """Saves a transition."""
 | 
			
		||||
        self.memory.append(Transition(*args))
 | 
			
		||||
    
 | 
			
		||||
    def append_log_probs(self, logprob):
 | 
			
		||||
        self.log_probs.append(logprob)
 | 
			
		||||
 | 
			
		||||
    def clear(self):
 | 
			
		||||
        self.memory.clear()
 | 
			
		||||
        self.log_probs.clear()
 | 
			
		||||
 | 
			
		||||
    def recall(self):
 | 
			
		||||
        if len(self.memory) != len(self.log_probs):
 | 
			
		||||
            raise ValueError("Memory and recorded log probabilities must be the same length.")
 | 
			
		||||
        return list(zip(*tuple(zip(*self.memory)), self.log_probs))
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return len(self.memory)
 | 
			
		||||
 | 
			
		||||
    def __iter__(self):
 | 
			
		||||
        return iter(self.memory)
 | 
			
		||||
 | 
			
		||||
    def __contains__(self, value):
 | 
			
		||||
        return value in self.memory
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, index):
 | 
			
		||||
        return self.memory[index]
 | 
			
		||||
 | 
			
		||||
    def __setitem__(self, index, value):
 | 
			
		||||
        self.memory[index] = value
 | 
			
		||||
 | 
			
		||||
    def __reversed__(self):
 | 
			
		||||
        return reversed(self.memory)
 | 
			
		||||
| 
						 | 
				
			
			@ -1,2 +1,3 @@
 | 
			
		|||
from .EpisodeMemory import *
 | 
			
		||||
from .ReplayMemory import * 
 | 
			
		||||
from .PrioritizedReplayMemory import *
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue