Added network documentation

Brandon Rozek 2020-03-20 20:16:29 -04:00
parent 5e7de5bed7
commit 4c6dc0a2ea
5 changed files with 134 additions and 12 deletions

View file

@@ -1,4 +1,10 @@
 Neural Networks
 ===============
-.. automodule:: rltorch.network
-   :members:
+.. autoclass:: rltorch.network.Network
+   :members:
+.. autoclass:: rltorch.network.TargetNetwork
+   :members:
+.. autoclass:: rltorch.network.ESNetwork
+   :members:
+.. autoclass:: rltorch.network.NoisyLinear
+   :members:

View file

@@ -7,9 +7,36 @@ from copy import deepcopy
 # What if we want to sometimes do gradient descent as well?
 class ESNetwork(Network):
     """
-    Network that functions from the paper Evolutionary Strategies (https://arxiv.org/abs/1703.03864)
-    fitness_fun := model, *args -> fitness_value (float)
-    We wish to find a model that maximizes the fitness function
+    Uses evolutionary techniques to optimize a neural network.
+
+    Notes
+    -----
+    Derived from the paper
+    Evolutionary Strategies
+    (https://arxiv.org/abs/1703.03864)
+
+    Parameters
+    ----------
+    model : nn.Module
+        A PyTorch nn.Module.
+    optimizer
+        A PyTorch optimizer from torch.optim.
+    population_size : int
+        The number of networks to evaluate each iteration.
+    fitness_fn : function
+        Function that evaluates a network and returns a higher
+        number for better-performing networks.
+    sigma : number
+        The standard deviation of the gaussian noise added to
+        the parameters when creating the population.
+    config : dict
+        A dictionary of configuration items.
+    device
+        A device to send the weights to.
+    logger
+        Keeps track of historical weights.
+    name
+        For use in the logger to differentiate runs in analysis.
     """
     def __init__(self, model, optimizer, population_size, fitness_fn, config, sigma = 0.05, device = None, logger = None, name = ""):
         super(ESNetwork, self).__init__(model, optimizer, config, device, logger, name)
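A minimal construction sketch based on the signature and parameter list above. The toy model, the do-nothing fitness function, and the config contents are placeholders; in particular, treating config['learning_rate'] as what the wrapper uses to build its optimizer is an assumption, not something this diff shows, so pass an already-built optimizer instance instead if that is what the wrapper expects.

import torch
import torch.nn as nn
from rltorch.network import ESNetwork  # class path as documented in the RST above

# Placeholder policy model; any nn.Module works here.
model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))

def fitness_fn(net, *args):
    # Placeholder: return a larger number for better-performing networks,
    # e.g. the total reward collected while acting with `net`.
    return 0.0

# Assumption: the wrapper builds its optimizer from config['learning_rate'].
net = ESNetwork(model, torch.optim.Adam, population_size=20, fitness_fn=fitness_fn,
                config={'learning_rate': 1e-3}, sigma=0.05)
scores = net(torch.rand(1, 4))  # inference only; updates come from calc_gradients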
@@ -18,9 +45,15 @@ class ESNetwork(Network):
         self.sigma = sigma
         assert self.sigma > 0

-    # We're not going to be calculating gradients in the traditional way
-    # So there's no need to waste computation time keeping track
     def __call__(self, *args):
+        """
+        Notes
+        -----
+        Since gradients aren't going to be computed in the
+        traditional fashion, there is no need to keep
+        track of the computations performed on the
+        tensors.
+        """
         with torch.no_grad():
             result = self.model(*args)
         return result
@@ -48,6 +81,14 @@ class ESNetwork(Network):
         return candidate_solutions

     def calc_gradients(self, *args):
+        """
+        Calculate gradients by shifting parameters
+        towards the networks with the highest fitness value.
+
+        This is calculated by evaluating the fitness of multiple
+        networks according to the fitness function specified in
+        the class.
+        """
         ## Generate Noise
         white_noise_dict, noise_dict = self._generate_noise_dicts()
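To make the calc_gradients description concrete, here is a standalone sketch of the evolution-strategies estimate it points at: perturb the parameters with gaussian noise, score each candidate with the fitness function, and shift towards the better-scoring ones. This follows the referenced paper rather than the library's exact code; the helper name es_gradient and the toy fitness function are ours.

import torch

def es_gradient(param, fitness_of, population_size=20, sigma=0.05):
    # Sample a population of perturbed parameter vectors.
    noise = [torch.randn_like(param) for _ in range(population_size)]
    fitness = torch.tensor([fitness_of(param + sigma * eps) for eps in noise])
    # Normalize so candidates better than average pull the parameters toward them.
    fitness = (fitness - fitness.mean()) / (fitness.std() + 1e-8)
    return sum(f * eps for f, eps in zip(fitness, noise)) / (population_size * sigma)

# Toy check: maximizing -||p||^2 should push p toward zero.
p = torch.ones(3)
g = es_gradient(p, lambda q: -q.pow(2).sum().item())
p = p + 0.1 * g  # ascend the estimated gradient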

View file

@@ -1,6 +1,21 @@
 class Network:
     """
-    Wrapper around model which provides copy of it instead of trained weights
+    Wrapper around a model and optimizer in PyTorch to abstract away common use cases.
+
+    Parameters
+    ----------
+    model : nn.Module
+        A PyTorch nn.Module.
+    optimizer
+        A PyTorch optimizer from torch.optim.
+    config : dict
+        A dictionary of configuration items.
+    device
+        A device to send the weights to.
+    logger
+        Keeps track of historical weights.
+    name
+        For use in the logger to differentiate runs in analysis.
     """
     def __init__(self, model, optimizer, config, device = None, logger = None, name = ""):
         self.model = model
@@ -18,14 +33,29 @@ class Network:
         return self.model(*args)

     def clamp_gradients(self, x = 1):
+        """
+        Forces gradients to stay within a certain interval
+        by setting them to the bound if they go over it.
+
+        Parameters
+        ----------
+        x : number > 0
+            Sets the interval to be [-x, x]
+        """
         assert x > 0
         for param in self.model.parameters():
             param.grad.data.clamp_(-x, x)

     def zero_grad(self):
+        """
+        Clears out gradients held in the model.
+        """
         self.model.zero_grad()

     def step(self):
+        """
+        Run a step of the optimizer on `model`.
+        """
         self.optimizer.step()

     def log_named_parameters(self):
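A sketch of how these helpers compose into a single gradient step. The model and loss are placeholders, and the assumption that the wrapper builds its optimizer from config['learning_rate'] is ours; if it expects an already-built optimizer instance, pass that instead.

import torch
import torch.nn as nn
from rltorch.network import Network  # class path as documented in the RST above

model = nn.Linear(4, 2)
# Assumption: the wrapper builds its optimizer from config['learning_rate'].
net = Network(model, torch.optim.Adam, config={'learning_rate': 1e-3})

loss = net(torch.rand(8, 4)).pow(2).mean()  # placeholder loss
net.zero_grad()          # clear old gradients
loss.backward()
net.clamp_gradients(1)   # keep each gradient in [-1, 1]
net.step()               # apply one optimizer step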

View file

@@ -6,6 +6,24 @@ import math
 # This class utilizes this property of the normal distribution
 # N(mu, sigma) = mu + sigma * N(0, 1)
 class NoisyLinear(nn.Linear):
+    """
+    Draws the parameters of nn.Linear from a normal distribution.
+    The parameters of the normal distribution are registered as
+    learnable parameters in the neural network.
+
+    Parameters
+    ----------
+    in_features
+        Size of each input sample.
+    out_features
+        Size of each output sample.
+    sigma_init
+        The starting standard deviation of the gaussian noise.
+    bias
+        If set to False, the layer will not
+        learn an additive bias.
+        Default: True
+    """
     def __init__(self, in_features, out_features, sigma_init = 0.017, bias = True):
         super(NoisyLinear, self).__init__(in_features, out_features, bias = bias)
         # One of the parameters the network is going to tune is the
@@ -27,6 +45,15 @@ class NoisyLinear(nn.Linear):
         nn.init.uniform_(self.bias, -std, std)

     def forward(self, x):
+        r"""
+        Calculates the output :math:`y` through the following:
+
+        :math:`weight \sim N(\mu_w, \sigma_w)`
+
+        :math:`bias \sim N(\mu_b, \sigma_b)`
+
+        :math:`y = weight \cdot x + bias`
+        """
         # Fill s_normal_weight with values from the standard normal distribution
         self.s_normal_weight.normal_()
         weight_noise = self.sigma_weight * self.s_normal_weight.clone().requires_grad_()
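A usage sketch with a placeholder surrounding model: the layer drops in for nn.Linear, and the last two lines show, standalone, the reparameterization the file's comment relies on (N(mu, sigma) = mu + sigma * N(0, 1)).

import torch
import torch.nn as nn
from rltorch.network import NoisyLinear  # class path as documented in the RST above

# Drop-in replacement for nn.Linear; the exploration noise lives in the layer itself.
policy = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), NoisyLinear(32, 2, sigma_init=0.017))
out = policy(torch.rand(1, 4))  # each forward pass redraws the gaussian noise

# The reparameterization trick itself, outside the layer:
mu, sigma = torch.zeros(3), torch.full((3,), 0.5)
sample = mu + sigma * torch.randn(3)  # a draw from N(mu, sigma)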

View file

@@ -1,25 +1,43 @@
 from copy import deepcopy

+# Derived from ptan library
 class TargetNetwork:
     """
-    Wrapper around model which provides copy of it instead of trained weights
+    Creates a clone of a network with syncing capabilities.
+
+    Parameters
+    ----------
+    network
+        The network to clone.
+    device
+        The device to put the cloned parameters on.
     """
     def __init__(self, network, device = None):
         self.model = network.model
         self.target_model = deepcopy(network.model)
-        if network.device is not None:
+        if device is not None:
+            self.target_model = self.target_model.to(device)
+        elif network.device is not None:
             self.target_model = self.target_model.to(network.device)

     def __call__(self, *args):
         return self.model(*args)

     def sync(self):
+        """
+        Perform a full state sync with the originating model.
+        """
         self.target_model.load_state_dict(self.model.state_dict())

     def partial_sync(self, tau):
         """
-        Blend params of target net with params from the model
-        :param tau:
+        Moves the clone's parameters partway towards those of the
+        originating model, so the new clone is a mix of the
+        originating and clone models.
+
+        Parameters
+        ----------
+        tau : number
+            A number between 0 and 1 indicating the proportion of the originator in the new clone.
         """
         assert isinstance(tau, float)
         assert 0.0 < tau <= 1.0
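Finally, a sketch of the two sync modes, reusing the (assumed) construction pattern from the Network example above; the comment spells out the blend that partial_sync describes.

import torch
import torch.nn as nn
from rltorch.network import Network, TargetNetwork  # class paths as documented in the RST above

model = nn.Linear(4, 2)
# Assumption: the wrapper builds its optimizer from config['learning_rate'].
net = Network(model, torch.optim.Adam, config={'learning_rate': 1e-3})
target = TargetNetwork(net)

target.sync()             # full copy: clone parameters := originator parameters
target.partial_sync(0.1)  # soft update: roughly clone := 0.1 * originator + 0.9 * clone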