Corrected A2C and PPO to train at the end of an episode
parent 1958fc7c7e
commit e42f5bba1b

5 changed files with 48 additions and 28 deletions
@@ -94,15 +94,12 @@ config['disable_cuda'] = False
 
 def train(runner, agent, config, logger = None, logwriter = None):
     finished = False
-    last_episode_num = 1
     while not finished:
-        runner.run(config['replay_skip'] + 1)
+        runner.run()
         agent.learn()
         if logwriter is not None:
-            if last_episode_num < runner.episode_num:
-                last_episode_num = runner.episode_num
-                agent.value_net.log_named_parameters()
-                agent.policy_net.log_named_parameters()
+            agent.value_net.log_named_parameters()
+            agent.policy_net.log_named_parameters()
             logwriter.write(logger)
         finished = runner.episode_num > config['total_training_episodes']
@@ -141,8 +138,8 @@ if __name__ == "__main__":
     # agent = rltorch.agents.REINFORCEAgent(net, memory, config, target_net = target_net, logger = logger)
     agent = rltorch.agents.A2CSingleAgent(policy_net, value_net, memory, config, logger = logger)
 
-    # Runner performs a certain number of steps in the environment
-    runner = rltorch.env.EnvironmentRunSync(env, actor, config, name = "Training", memory = memory, logwriter = logwriter)
+    # Runner performs one episode in the environment
+    runner = rltorch.env.EnvironmentEpisodeSync(env, actor, config, name = "Training", memory = memory, logwriter = logwriter)
 
     print("Training...")
     train(runner, agent, config, logger = logger, logwriter = logwriter)
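
The hunks above are from the A2C example; the hunks below apply the same change to the PPO example. Consolidated, the training loop after this commit reads roughly as follows (a sketch assembled from the added lines above; it contains nothing beyond what the diff shows):

# Sketch of train() after this commit, assembled from the "+" lines above.
# Each loop iteration now runs one full episode before learning, so the
# old last_episode_num bookkeeping and replay_skip stepping are gone.
def train(runner, agent, config, logger = None, logwriter = None):
    finished = False
    while not finished:
        runner.run()        # EnvironmentEpisodeSync: one complete episode per call
        agent.learn()       # update on the episode that just finished
        if logwriter is not None:
            agent.value_net.log_named_parameters()
            agent.policy_net.log_named_parameters()
            logwriter.write(logger)
        finished = runner.episode_num > config['total_training_episodes']
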
@@ -94,21 +94,16 @@ config['disable_cuda'] = False
 
 def train(runner, agent, config, logger = None, logwriter = None):
     finished = False
-    last_episode_num = 1
     while not finished:
-        runner.run(config['replay_skip'] + 1)
+        runner.run()
         agent.learn()
         if logwriter is not None:
-            if last_episode_num < runner.episode_num:
-                last_episode_num = runner.episode_num
-                agent.value_net.log_named_parameters()
-                agent.policy_net.log_named_parameters()
+            agent.value_net.log_named_parameters()
+            agent.policy_net.log_named_parameters()
             logwriter.write(logger)
         finished = runner.episode_num > config['total_training_episodes']
 
 if __name__ == "__main__":
     torch.multiprocessing.set_sharing_strategy('file_system') # To not hit file descriptor memory limit
 
     # Setting up the environment
     rltorch.set_seed(config['seed'])
     print("Setting up environment...", end = " ")
@@ -142,7 +137,7 @@ if __name__ == "__main__":
     agent = rltorch.agents.PPOAgent(policy_net, value_net, memory, config, logger = logger)
 
     # Runner performs a certain number of steps in the environment
-    runner = rltorch.env.EnvironmentRunSync(env, actor, config, name = "Training", memory = memory, logwriter = logwriter)
+    runner = rltorch.env.EnvironmentEpisodeSync(env, actor, config, name = "Training", memory = memory, logwriter = logwriter)
 
     print("Training...")
     train(runner, agent, config, logger = logger, logwriter = logwriter)
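
Why the last_episode_num guard could be dropped: with an episode-scoped runner, every run() call finishes exactly one episode, so episode_num advances on every loop iteration and the parameter logging already happens once per episode. A hypothetical sketch of the contract the loop now relies on (illustration only, not rltorch's actual EnvironmentEpisodeSync implementation):

# Hypothetical illustration of an episode-scoped runner; the real
# rltorch.env.EnvironmentEpisodeSync may differ in details.
class EpisodeRunnerSketch:
    def __init__(self, env, actor):
        self.env = env
        self.actor = actor
        self.episode_num = 1

    def run(self):
        # Step until the episode terminates, then advance episode_num,
        # so the caller can learn/log once per finished episode.
        state, done = self.env.reset(), False
        while not done:
            action = self.actor.act(state)                 # actor.act() is an assumed interface
            state, reward, done, info = self.env.step(action)
        self.episode_num += 1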