Corrected A2C and PPO to train at the end of an episode

parent 1958fc7c7e
commit e42f5bba1b

5 changed files with 48 additions and 28 deletions
@@ -94,15 +94,12 @@ config['disable_cuda'] = False

 def train(runner, agent, config, logger = None, logwriter = None):
     finished = False
-    last_episode_num = 1
     while not finished:
-        runner.run(config['replay_skip'] + 1)
+        runner.run()
         agent.learn()
         if logwriter is not None:
-            if last_episode_num < runner.episode_num:
-                last_episode_num = runner.episode_num
-                agent.value_net.log_named_parameters()
-                agent.policy_net.log_named_parameters()
+            agent.value_net.log_named_parameters()
+            agent.policy_net.log_named_parameters()
             logwriter.write(logger)
         finished = runner.episode_num > config['total_training_episodes']

@@ -141,8 +138,8 @@ if __name__ == "__main__":
     # agent = rltorch.agents.REINFORCEAgent(net, memory, config, target_net = target_net, logger = logger)
     agent = rltorch.agents.A2CSingleAgent(policy_net, value_net, memory, config, logger = logger)

-    # Runner performs a certain number of steps in the environment
-    runner = rltorch.env.EnvironmentRunSync(env, actor, config, name = "Training", memory = memory, logwriter = logwriter)
+    # Runner performs one episode in the environment
+    runner = rltorch.env.EnvironmentEpisodeSync(env, actor, config, name = "Training", memory = memory, logwriter = logwriter)

     print("Training...")
     train(runner, agent, config, logger = logger, logwriter = logwriter)
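The net effect of these two hunks is that each pass through train() now covers exactly one full episode instead of config['replay_skip'] + 1 environment steps, so agent.learn() fires once per episode. The self-contained toy below only illustrates that loop shape; StubRunner and StubAgent are stand-ins invented for this sketch, not rltorch classes:

# Toy illustration of the episode-driven loop this commit moves to.
class StubAgent:
    def learn(self):
        print("learn() called after a completed episode")

class StubRunner:
    def __init__(self, total_episodes):
        self.episode_num = 1
        self.total_episodes = total_episodes

    def run(self):
        # Stand-in for EnvironmentEpisodeSync.run(): simulate playing one whole episode.
        self.episode_num += 1

agent, runner = StubAgent(), StubRunner(total_episodes = 3)
finished = False
while not finished:
    runner.run()     # one episode per iteration (previously: replay_skip + 1 steps)
    agent.learn()    # exactly one update per episode
    finished = runner.episode_num > runner.total_episodes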
@@ -94,21 +94,16 @@ config['disable_cuda'] = False

 def train(runner, agent, config, logger = None, logwriter = None):
     finished = False
-    last_episode_num = 1
     while not finished:
-        runner.run(config['replay_skip'] + 1)
+        runner.run()
         agent.learn()
         if logwriter is not None:
-            if last_episode_num < runner.episode_num:
-                last_episode_num = runner.episode_num
-                agent.value_net.log_named_parameters()
-                agent.policy_net.log_named_parameters()
+            agent.value_net.log_named_parameters()
+            agent.policy_net.log_named_parameters()
             logwriter.write(logger)
         finished = runner.episode_num > config['total_training_episodes']

 if __name__ == "__main__":
-    torch.multiprocessing.set_sharing_strategy('file_system') # To not hit file descriptor memory limit
-
     # Setting up the environment
     rltorch.set_seed(config['seed'])
     print("Setting up environment...", end = " ")

@@ -142,7 +137,7 @@ if __name__ == "__main__":
     agent = rltorch.agents.PPOAgent(policy_net, value_net, memory, config, logger = logger)

     # Runner performs a certain number of steps in the environment
-    runner = rltorch.env.EnvironmentRunSync(env, actor, config, name = "Training", memory = memory, logwriter = logwriter)
+    runner = rltorch.env.EnvironmentEpisodeSync(env, actor, config, name = "Training", memory = memory, logwriter = logwriter)

     print("Training...")
     train(runner, agent, config, logger = logger, logwriter = logwriter)
@@ -27,9 +27,6 @@ class A2CSingleAgent:


     def learn(self):
-        if len(self.memory) < self.config['batch_size']:
-            return
-
         episode_batch = self.memory.recall()
         state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)

@@ -40,7 +37,7 @@ class A2CSingleAgent:
         log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)

         ## Value Loss
-        value_loss = F.mse_loss(self._discount_rewards(reward_batch), self.value_net(state_batch[0]))
+        value_loss = F.mse_loss(self._discount_rewards(reward_batch).sum(), self.value_net(state_batch[0]))
         self.value_net.zero_grad()
         value_loss.backward()
         self.value_net.step()
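The new value loss pairs the summed discounted rewards with the value estimate of the episode's first state, state_batch[0]. The helper _discount_rewards is not shown in these hunks, so the sketch below is only an assumption about its shape (a tensor of gamma^t-weighted rewards whose sum is the discounted episode return), not the repository's actual implementation:

import torch

# Hypothetical stand-in for A2CSingleAgent._discount_rewards (not shown in the diff).
# It weights each reward by gamma^t, so .sum() gives the discounted episode return
# G = sum_t gamma^t * r_t that V(s_0) is regressed toward via F.mse_loss.
def discount_rewards(rewards, gamma=0.99):
    rewards = torch.as_tensor(rewards, dtype=torch.float32)
    discounts = gamma ** torch.arange(len(rewards), dtype=torch.float32)
    return discounts * rewards

print(discount_rewards([1.0, 0.0, 2.0]).sum())  # tensor(2.9602), a scalar regression target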
@@ -1,5 +1,3 @@
-# Deprecated since the idea of the idea shouldn't work without having some sort of "mental model" of the environment
-
 from copy import deepcopy
 import numpy as np
 import torch

@@ -30,9 +28,6 @@ class PPOAgent:


     def learn(self):
-        if len(self.memory) < self.config['batch_size']:
-            return
-
         episode_batch = self.memory.recall()
         state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)

@@ -44,7 +39,7 @@ class PPOAgent:
         log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)

         ## Value Loss
-        value_loss = F.mse_loss(self._discount_rewards(reward_batch), self.value_net(state_batch[0]))
+        value_loss = F.mse_loss(self._discount_rewards(reward_batch).sum(), self.value_net(state_batch[0]))
         self.value_net.zero_grad()
         value_loss.backward()
         self.value_net.step()
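Only the value-loss half of PPOAgent.learn() appears in these hunks; the policy half is not shown. For context, a PPO policy update conventionally uses the clipped surrogate objective. The snippet below is a generic, self-contained sketch of that objective, not code taken from this repository:

import torch

# Generic PPO clipped surrogate loss (illustrative, not PPOAgent's actual code).
# ratio = pi_new(a|s) / pi_old(a|s); clipping keeps each update close to the old policy.
def clipped_surrogate_loss(new_log_probs, old_log_probs, advantages, epsilon=0.2):
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    return -torch.min(unclipped, clipped).mean()

# Toy numbers, just to show the call shape:
loss = clipped_surrogate_loss(torch.tensor([-0.9, -1.2]),
                              torch.tensor([-1.0, -1.0]),
                              torch.tensor([0.5, -0.3]))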
rltorch/env/simulate.py (vendored): 38 changes

@@ -62,4 +62,40 @@ class EnvironmentRunSync():
         if self.logwriter is not None:
             self.logwriter.write(logger)

         self.last_state = state
+
+
+class EnvironmentEpisodeSync():
+    def __init__(self, env, actor, config, memory = None, logwriter = None, name = ""):
+        self.env = env
+        self.name = name
+        self.actor = actor
+        self.config = deepcopy(config)
+        self.logwriter = logwriter
+        self.memory = memory
+        self.episode_num = 1
+
+    def run(self):
+        state = self.env.reset()
+        done = False
+        episodeReward = 0
+        logger = rltorch.log.Logger() if self.logwriter is not None else None
+        while not done:
+            action = self.actor.act(state)
+            next_state, reward, done, _ = self.env.step(action)
+
+            episodeReward += reward
+            if self.memory is not None:
+                self.memory.append(state, action, reward, next_state, done)
+
+            state = next_state
+
+        if self.episode_num % self.config['print_stat_n_eps'] == 0:
+            print("episode: {}/{}, score: {}"
+                  .format(self.episode_num, self.config['total_training_episodes'], episodeReward))
+
+        if self.logwriter is not None:
+            logger.append(self.name + '/EpisodeReward', episodeReward)
+            self.logwriter.write(logger)
+
+        self.episode_num += 1
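Taken together, the updated example scripts drive this new runner with the episode-based train() loop shown above. The condensed wiring below restates that flow; every object on the right-hand side (env, actor, policy_net, value_net, memory, config, logger, logwriter) is assumed to be constructed as in those scripts, so this is a sketch of the call sequence rather than a standalone program:

# Condensed from the updated A2C example above (all constructor arguments assumed
# to exist as set up earlier in that script).
runner = rltorch.env.EnvironmentEpisodeSync(env, actor, config,
                                            name = "Training", memory = memory, logwriter = logwriter)
agent = rltorch.agents.A2CSingleAgent(policy_net, value_net, memory, config, logger = logger)

finished = False
while not finished:
    runner.run()    # play exactly one episode
    agent.learn()   # one update per completed episode
    if logwriter is not None:
        agent.value_net.log_named_parameters()
        agent.policy_net.log_named_parameters()
        logwriter.write(logger)
    finished = runner.episode_num > config['total_training_episodes']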