Cleaned up scripts, added more comments
This commit is contained in:
parent e42f5bba1b
commit a59f84b446
11 changed files with 103 additions and 436 deletions
@@ -1,5 +1,4 @@
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -9,9 +8,11 @@ import rltorch.memory as M
import rltorch.env as E
from rltorch.action_selector import StochasticSelector
from tensorboardX import SummaryWriter
import torch.multiprocessing as mp
from copy import deepcopy

#
## Networks
#
class Value(nn.Module):
    def __init__(self, state_size, action_size):
        super(Value, self).__init__()
@@ -39,7 +40,6 @@ class Value(nn.Module):
        advantage = self.advantage(advantage)

        x = state_value + advantage - advantage.mean()

        return x
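The `state_value + advantage - advantage.mean()` line is the dueling-DQN aggregation, Q(s, a) = V(s) + A(s, a) − mean(A), which keeps the value and advantage streams identifiable. A minimal self-contained sketch of such a head (layer sizes are illustrative, not taken from this diff; the per-row `dim=1` mean is the common variant, whereas the diff uses a global mean):

import torch
import torch.nn as nn

class DuelingHead(nn.Module):
    """Illustrative dueling head: combines a scalar state value V(s)
    with per-action advantages A(s, a) into Q-values."""
    def __init__(self, hidden_size, action_size):
        super(DuelingHead, self).__init__()
        self.state_value = nn.Linear(hidden_size, 1)
        self.advantage = nn.Linear(hidden_size, action_size)

    def forward(self, x):
        v = self.state_value(x)   # shape (batch, 1)
        a = self.advantage(x)     # shape (batch, action_size)
        # Subtracting the mean advantage anchors A around zero,
        # so V carries the overall state value.
        return v + a - a.mean(dim=1, keepdim=True)

q = DuelingHead(64, 3)(torch.randn(5, 64))  # -> Q-values of shape (5, 3)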
@@ -63,6 +63,9 @@ class Policy(nn.Module):
        x = F.softmax(self.action_prob(x), dim = 1)
        return x

#
## Configuration
#
config = {}
config['seed'] = 901
config['environment_name'] = 'Acrobot-v1'
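Since `Policy.forward` ends in a softmax over actions, a stochastic selector can sample directly from that distribution. A rough sketch of the idea (not the actual `StochasticSelector` implementation; `sample_action` and its calling convention are hypothetical):

import torch
from torch.distributions import Categorical

def sample_action(policy_net, state):
    """Sample an action index from the policy's softmax output.
    Assumes `state` is a 1-D feature tensor."""
    probs = policy_net(state.unsqueeze(0))  # (1, action_size), rows sum to 1
    return Categorical(probs=probs).sample().item()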
@@ -88,7 +91,9 @@ config['prioritized_replay_sampling_priority'] = 0.6
config['prioritized_replay_weight_importance'] = rltorch.scheduler.ExponentialScheduler(initial_value = 0.4, end_value = 1, iterations = 5000)


#
## Training Loop
#
def train(runner, agent, config, logger = None, logwriter = None):
    finished = False
    last_episode_num = 1
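`prioritized_replay_weight_importance` anneals the importance-sampling exponent β from 0.4 toward 1 over 5000 iterations, so the bias correction for prioritized sampling strengthens as training stabilizes. Assuming `ExponentialScheduler` interpolates by a constant multiplicative factor (the signature comes from the diff, the body is a guess), it behaves roughly like:

class ExponentialSchedulerSketch:
    """Moves from initial_value to end_value by a constant
    multiplicative factor per step, then stays at end_value."""
    def __init__(self, initial_value, end_value, iterations):
        # factor ** iterations == end_value / initial_value
        self.factor = (end_value / initial_value) ** (1.0 / iterations)
        self.value = initial_value
        self.steps_left = iterations

    def next(self):
        current = self.value
        if self.steps_left > 0:
            self.value *= self.factor
            self.steps_left -= 1
        return current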
@@ -103,6 +108,7 @@ def train(runner, agent, config, logger = None, logwriter = None):
            logwriter.write(logger)
        finished = runner.episode_num > config['total_training_episodes']


if __name__ == "__main__":
    # Setting up the environment
    rltorch.set_seed(config['seed'])
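The hunk only shows the tail of the loop. Given the names in this diff, the whole `train` body plausibly has this shape (the `runner.run()` call pattern and the episode-tracking branch are assumptions):

def train(runner, agent, config, logger = None, logwriter = None):
    finished = False
    last_episode_num = 1
    while not finished:
        runner.run()      # step the environment, filling the replay memory
        agent.learn()     # one optimization step off the sampled experiences
        if logwriter is not None:
            if last_episode_num < runner.episode_num:
                last_episode_num = runner.episode_num
            logwriter.write(logger)
        finished = runner.episode_num > config['total_training_episodes']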
@@ -116,7 +122,6 @@ if __name__ == "__main__":

    # Logging
    logger = rltorch.log.Logger()
    # logwriter = rltorch.log.LogWriter(logger, SummaryWriter())
    logwriter = rltorch.log.LogWriter(SummaryWriter())

    # Setting up the networks
@@ -127,13 +132,11 @@ if __name__ == "__main__":
        torch.optim.Adam, 500, None, config2, sigma = 0.1, device = device, name = "ES", logger = logger)
    value_net = rn.Network(Value(state_size, action_size),
        torch.optim.Adam, config, device = device, name = "DQN", logger = logger)

    target_net = rn.TargetNetwork(value_net, device = device)
    value_net.model.share_memory()
    target_net.model.share_memory()

    # Actor takes a net and uses it to produce actions from given states
    actor = StochasticSelector(policy_net, action_size, device = device)

    # Memory stores experiences for later training
    memory = M.PrioritizedReplayMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'])
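The `share_memory()` calls move the modules' parameter tensors into shared memory, so worker processes spawned via `torch.multiprocessing` (imported above) can read the same weights without copying. A minimal illustration of that pattern (the `worker` function is hypothetical):

import torch.nn as nn
import torch.multiprocessing as mp

def worker(model):
    # Sees the same underlying parameter storage as the parent process.
    print(sum(p.sum().item() for p in model.parameters()))

if __name__ == "__main__":
    model = nn.Linear(4, 2)
    model.share_memory()  # parameter tensors now live in shared memory
    p = mp.Process(target=worker, args=(model,))
    p.start()
    p.join()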
@@ -141,11 +144,9 @@ if __name__ == "__main__":
    runner = rltorch.env.EnvironmentRunSync(env, actor, config, name = "Training", memory = memory, logwriter = logwriter)

    # Agent is what performs the training
    # agent = TestAgent(policy_net, value_net, memory, config, target_value_net = target_net, logger = logger)
    agent = rltorch.agents.QEPAgent(policy_net, value_net, memory, config, target_value_net = target_net, logger = logger)

    print("Training...")

    train(runner, agent, config, logger = logger, logwriter = logwriter)

    # For profiling...
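`QEPAgent` pairs the DQN-trained value network with a policy trained by evolution strategies (the network built with `sigma = 0.1` and a population argument of 500 above, named "ES"). The core ES estimator perturbs the parameters with Gaussian noise and weights each perturbation by its fitness; a toy version of that update direction, under those assumptions:

import torch

def es_gradient(params, fitness_fn, population=500, sigma=0.1):
    """Estimate the gradient of the Gaussian-smoothed fitness at `params`
    (a 1-D tensor): grad ~ 1/(pop*sigma) * sum_i f(params + sigma*eps_i) * eps_i."""
    noise = torch.randn(population, params.numel())
    fitness = torch.tensor([fitness_fn(params + sigma * n) for n in noise])
    # Normalizing fitness keeps the update scale-invariant.
    fitness = (fitness - fitness.mean()) / (fitness.std() + 1e-8)
    return (fitness @ noise) / (population * sigma)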