Corrected A2C and PPO to train at the end of an episode

This commit is contained in:
  parent 1958fc7c7e
  commit e42f5bba1b

5 changed files with 48 additions and 28 deletions
@@ -94,15 +94,12 @@ config['disable_cuda'] = False
 
 def train(runner, agent, config, logger = None, logwriter = None):
     finished = False
-    last_episode_num = 1
     while not finished:
-        runner.run(config['replay_skip'] + 1)
+        runner.run()
         agent.learn()
         if logwriter is not None:
-          if last_episode_num < runner.episode_num:
-            last_episode_num = runner.episode_num
-            agent.value_net.log_named_parameters()
-            agent.policy_net.log_named_parameters()
+          agent.value_net.log_named_parameters()
+          agent.policy_net.log_named_parameters()
           logwriter.write(logger)
         finished = runner.episode_num > config['total_training_episodes']
 

@@ -141,8 +138,8 @@ if __name__ == "__main__":
   # agent = rltorch.agents.REINFORCEAgent(net, memory, config, target_net = target_net, logger = logger)
   agent = rltorch.agents.A2CSingleAgent(policy_net, value_net, memory, config, logger = logger)
 
-  # Runner performs a certain number of steps in the environment
-  runner = rltorch.env.EnvironmentRunSync(env, actor, config, name = "Training", memory = memory, logwriter = logwriter)
+  # Runner performs one episode in the environment
+  runner = rltorch.env.EnvironmentEpisodeSync(env, actor, config, name = "Training", memory = memory, logwriter = logwriter)
     
   print("Training...")
   train(runner, agent, config, logger = logger, logwriter = logwriter)
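Note on the hunks above: runner.run() is now called with no step count, so each loop iteration plays out exactly one episode (via the new EnvironmentEpisodeSync added to rltorch/env/simulate.py below) and agent.learn() fires once per finished episode; that is why the last_episode_num bookkeeping and the config['replay_skip'] chunking disappear. A minimal sketch of the resulting control flow, where PlaceholderRunner and PlaceholderAgent are illustrative stand-ins and not part of this commit:

# Sketch of the per-episode training loop this commit switches to.
# PlaceholderRunner / PlaceholderAgent are assumptions for illustration only.

class PlaceholderRunner:
    """Plays exactly one full episode per run() call and counts finished episodes."""
    def __init__(self):
        self.episode_num = 1

    def run(self):
        # ... step the environment until done, appending transitions to memory ...
        self.episode_num += 1

class PlaceholderAgent:
    def learn(self):
        # ... update the policy/value networks from the stored episode ...
        pass

def train(runner, agent, total_training_episodes):
    finished = False
    while not finished:
        runner.run()    # one complete episode
        agent.learn()   # learn only once the episode has ended
        finished = runner.episode_num > total_training_episodes

train(PlaceholderRunner(), PlaceholderAgent(), total_training_episodes=5)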
@@ -94,21 +94,16 @@ config['disable_cuda'] = False
 
 def train(runner, agent, config, logger = None, logwriter = None):
     finished = False
-    last_episode_num = 1
     while not finished:
-        runner.run(config['replay_skip'] + 1)
+        runner.run()
         agent.learn()
         if logwriter is not None:
-          if last_episode_num < runner.episode_num:
-            last_episode_num = runner.episode_num
-            agent.value_net.log_named_parameters()
-            agent.policy_net.log_named_parameters()
+          agent.value_net.log_named_parameters()
+          agent.policy_net.log_named_parameters()
           logwriter.write(logger)
         finished = runner.episode_num > config['total_training_episodes']
 
 if __name__ == "__main__":
   torch.multiprocessing.set_sharing_strategy('file_system') # To not hit file descriptor memory limit
 
   # Setting up the environment
   rltorch.set_seed(config['seed'])
   print("Setting up environment...", end = " ")

@@ -142,7 +137,7 @@ if __name__ == "__main__":
   agent = rltorch.agents.PPOAgent(policy_net, value_net, memory, config, logger = logger)
 
   # Runner performs a certain number of steps in the environment
-  runner = rltorch.env.EnvironmentRunSync(env, actor, config, name = "Training", memory = memory, logwriter = logwriter)
+  runner = rltorch.env.EnvironmentEpisodeSync(env, actor, config, name = "Training", memory = memory, logwriter = logwriter)
     
   print("Training...")
   train(runner, agent, config, logger = logger, logwriter = logwriter)
@@ -27,9 +27,6 @@ class A2CSingleAgent:
   
   
   def learn(self):
-    if len(self.memory) < self.config['batch_size']:
-      return
-
     episode_batch = self.memory.recall()
     state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
 

@@ -40,7 +37,7 @@ class A2CSingleAgent:
     log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
 
     ## Value Loss
-    value_loss = F.mse_loss(self._discount_rewards(reward_batch), self.value_net(state_batch[0]))
+    value_loss = F.mse_loss(self._discount_rewards(reward_batch).sum(), self.value_net(state_batch[0]))
     self.value_net.zero_grad()
     value_loss.backward()
     self.value_net.step()
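Two related changes in A2CSingleAgent.learn() (mirrored in PPOAgent.learn() below): the early return when the memory held fewer than config['batch_size'] transitions is gone, since learn() now runs once per completed episode and the memory is expected to hold exactly that episode, and the value loss now regresses value_net(state_batch[0]) against the summed discounted return instead of the per-step discounted reward vector. A rough sketch of the quantity that .sum() produces, under the assumption (not shown in this diff) that _discount_rewards returns the per-timestep discounted rewards gamma**t * r_t:

import torch

# Illustrative only: assumes _discount_rewards(rewards) yields gamma**t * r_t per step.
def discounted_return(rewards, gamma=0.99):
    # Summing the discounted rewards gives a scalar estimate of the episode return G_0,
    # which the new value loss compares against the value estimate of the first state.
    discounts = gamma ** torch.arange(len(rewards), dtype=torch.float32)
    return (discounts * rewards).sum()

print(discounted_return(torch.tensor([1.0, 0.0, 2.0])))  # 1 + 0 + 0.99**2 * 2 = 2.9602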
@@ -1,5 +1,3 @@
-# Deprecated since the idea of the idea shouldn't work without having some sort of "mental model" of the environment
-
 from copy import deepcopy
 import numpy as np
 import torch

@@ -30,9 +28,6 @@ class PPOAgent:
   
   
   def learn(self):
-    if len(self.memory) < self.config['batch_size']:
-      return
-
     episode_batch = self.memory.recall()
     state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
 

@@ -44,7 +39,7 @@ class PPOAgent:
     log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
 
     ## Value Loss
-    value_loss = F.mse_loss(self._discount_rewards(reward_batch), self.value_net(state_batch[0]))
+    value_loss = F.mse_loss(self._discount_rewards(reward_batch).sum(), self.value_net(state_batch[0]))
     self.value_net.zero_grad()
     value_loss.backward()
     self.value_net.step()
rltorch/env/simulate.py (vendored): 38 changes
@@ -62,4 +62,40 @@ class EnvironmentRunSync():
     if self.logwriter is not None:
       self.logwriter.write(logger)
     
-    self.last_state = state
+    self.last_state = state
+
+
+class EnvironmentEpisodeSync():
+  def __init__(self, env, actor, config, memory = None, logwriter = None, name = ""):
+    self.env = env
+    self.name = name
+    self.actor = actor
+    self.config = deepcopy(config)
+    self.logwriter = logwriter
+    self.memory = memory
+    self.episode_num = 1
+
+  def run(self):
+    state = self.env.reset()
+    done = False
+    episodeReward = 0
+    logger = rltorch.log.Logger() if self.logwriter is not None else None
+    while not done:
+      action = self.actor.act(state)
+      next_state, reward, done, _ = self.env.step(action)
+
+      episodeReward += reward
+      if self.memory is not None:
+        self.memory.append(state, action, reward, next_state, done)
+
+      state = next_state
+
+    if self.episode_num % self.config['print_stat_n_eps'] == 0:
+      print("episode: {}/{}, score: {}"
+        .format(self.episode_num, self.config['total_training_episodes'], episodeReward))
+
+    if self.logwriter is not None:
+      logger.append(self.name + '/EpisodeReward', episodeReward)
+      self.logwriter.write(logger)
+
+    self.episode_num += 1
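For reference, a hedged usage sketch of the new EnvironmentEpisodeSync: run() takes no argument, resets the environment, steps until done while appending every transition to memory, prints the score every print_stat_n_eps episodes, logs the episode reward, and increments episode_num. The RandomActor and ListMemory classes below are illustrative stand-ins (assumptions, not rltorch APIs), and a classic Gym environment with the 4-tuple step() return is assumed, matching the env.step() call above:

import gym
import rltorch

# Stand-in actor: samples random actions. Not an rltorch class.
class RandomActor:
    def __init__(self, action_space):
        self.action_space = action_space

    def act(self, state):
        return self.action_space.sample()

# Stand-in memory: just records transition tuples. Not an rltorch class.
class ListMemory(list):
    def append(self, *transition):
        super().append(transition)

env = gym.make("CartPole-v1")
config = {'print_stat_n_eps': 10, 'total_training_episodes': 100}
runner = rltorch.env.EnvironmentEpisodeSync(
    env, RandomActor(env.action_space), config,
    memory=ListMemory(), name="Demo")

runner.run()  # plays exactly one full episode
runner.run()  # plays another; runner.episode_num is now 3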