Merge branch 'master' of github.com:Brandon-Rozek/rltorch

commit a667b3734b
29 changed files with 536 additions and 78 deletions

.gitignore (vendored): 6 additions
@@ -2,3 +2,9 @@ __pycache__/
 *.py[cod]
 rlenv/
 runs/
+*.tox
+*.coverage
+.vscode/
+docs/build
+.mypy_cache/
+*egg-info*

docs/Makefile (new file): 20 additions
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= sphinx-build
+SOURCEDIR     = source
+BUILDDIR      = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

docs/make.bat (new file): 35 additions
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.http://sphinx-doc.org/
+	exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd

docs/source/action_selector.rst (new file): 4 additions
@@ -0,0 +1,4 @@
+Action Selector
+===============
+.. automodule:: rltorch.action_selector
+   :members:

docs/source/agents.rst (new file): 4 additions
@@ -0,0 +1,4 @@
+Agents
+======
+.. automodule:: rltorch.agents
+   :members:

docs/source/conf.py (new file): 58 additions
@@ -0,0 +1,58 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Project information -----------------------------------------------------
+
+project = 'RLTorch'
+copyright = '2020, Brandon Rozek'
+author = 'Brandon Rozek'
+
+# The full version, including alpha/beta/rc tags
+release = '0.1.0'
+
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    "sphinx.ext.autodoc",
+    'sphinx.ext.autosummary',
+    'sphinx.ext.napoleon',
+    "sphinx.ext.viewcode",
+    "sphinx.ext.mathjax",
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = []
+
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages.  See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'alabaster'
+
+
+html_sidebars = {
+    '**': [
+        'about.html',
+        'navigation.html',
+        'searchbox.html',
+    ]
+}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']

docs/source/env.rst (new file): 5 additions
@@ -0,0 +1,5 @@
+Environment Utilities
+=====================
+.. automodule:: rltorch.env
+   :members:
+

docs/source/index.rst (new file): 15 additions
@@ -0,0 +1,15 @@
+Welcome to RLTorch's documentation!
+===================================
+.. toctree::
+   :maxdepth: 2
+   :caption: Contents:
+
+   action_selector
+   agents
+   env
+   memory
+   mp
+   network
+   scheduler
+   log
+   seed

docs/source/log.rst (new file): 4 additions
@@ -0,0 +1,4 @@
+Logging
+=======
+.. automodule:: rltorch.log
+   :members:

docs/source/memory.rst (new file): 8 additions
@@ -0,0 +1,8 @@
+Memory Structures
+=================
+.. autoclass:: rltorch.memory.ReplayMemory
+   :members:
+.. autoclass:: rltorch.memory.PrioritizedReplayMemory
+   :members:
+.. autoclass:: rltorch.memory.EpisodeMemory
+   :members:

docs/source/mp.rst (new file): 4 additions
@@ -0,0 +1,4 @@
+Multiprocessing
+===============
+.. automodule:: rltorch.mp
+   :members:

docs/source/network.rst (new file): 10 additions
@@ -0,0 +1,10 @@
+Neural Networks
+===============
+.. autoclass:: rltorch.network.Network
+   :members:
+.. autoclass:: rltorch.network.TargetNetwork
+   :members:
+.. autoclass:: rltorch.network.ESNetwork
+   :members:
+.. autoclass:: rltorch.network.NoisyLinear
+   :members:

docs/source/scheduler.rst (new file): 6 additions
@@ -0,0 +1,6 @@
+Hyperparameter Scheduling
+=========================
+.. autoclass:: rltorch.scheduler.LinearScheduler
+   :members:
+.. autoclass:: rltorch.scheduler.ExponentialScheduler
+   :members:

docs/source/seed.rst (new file): 4 additions
@@ -0,0 +1,4 @@
+Seeding
+=======
+.. automodule:: rltorch.seed
+   :members:

requirements.txt (deleted): 32 deletions
@@ -1,32 +0,0 @@
-absl-py==0.7.0
-astor==0.7.1
-atari-py==0.1.7
-certifi==2018.11.29
-chardet==3.0.4
-future==0.17.1
-gast==0.2.2
-grpcio==1.18.0
-gym==0.10.11
-h5py==2.9.0
-idna==2.8
-Keras-Applications==1.0.7
-Keras-Preprocessing==1.0.8
-Markdown==3.0.1
-numpy==1.16.0
-opencv-python==4.0.0.21
-Pillow==5.4.1
-pkg-resources==0.0.0
-protobuf==3.6.1
-pyglet==1.3.2
-PyOpenGL==3.1.0
-requests==2.21.0
-scipy==1.2.0
-six==1.12.0
-tensorboard==1.12.2
-tensorboardX==1.6
-tensorflow==1.12.0
-termcolor==1.1.0
-torch==1.0.0
-urllib3==1.24.1
-Werkzeug==0.14.1
-numba==0.42.1

rltorch/log.py
@@ -3,6 +3,13 @@ import numpy as np
 import torch
 
 class Logger:
+    """
+    Keeps track of lists of items separated by tags.
+
+    Notes
+    -----
+    Logger is a dictionary of lists.
+    """
    def __init__(self):
        self.log = {}
    def append(self, tag, value):
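
As a quick illustration of the Logger shown above (a dictionary of lists keyed by tag), a minimal sketch; the tag name and values are made up:

    from rltorch.log import Logger

    logger = Logger()
    logger.append("episode_reward", 200.0)
    logger.append("episode_reward", 215.0)
    # Each tag accumulates its own list of values
    print(logger.log["episode_reward"])  # [200.0, 215.0]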

@@ -26,26 +33,22 @@ class Logger:
     def __reversed__(self):
         return reversed(self.log)
+
-# Workaround since we can't use SummaryWriter in a different process
-# class LogWriter:
-#     def __init__(self, logger, writer):
-#         self.logger = logger
-#         self.writer = writer
-#         self.steps = Counter()
-#     def write(self):
-#         for key in self.logger.keys():
-#             for value in self.logger[key]:
-#                 self.steps[key] += 1
-#                 if isinstance(value, int) or isinstance(value, float):
-#                     self.writer.add_scalar(key, value, self.steps[key])
-#                 if isinstance(value, np.ndarray) or isinstance(value, torch.Tensor):
-#                     self.writer.add_histogram(key, value, self.steps[key])
-#         self.logger.log = {}
-#     def close(self):
-#         self.writer.close()
-    
 class LogWriter:
+    """
+    Takes a logger and writes it to a writer,
+    while keeping track of the number of times it
+    has written to a certain tag.
+
+    Notes
+    -----
+    Used to keep track of scalars and histograms in
+    Tensorboard.
+
+    Parameters
+    ----------
+    writer
+      The tensorboard writer.
+    """
     def __init__(self, writer):
         self.writer = writer
         self.steps = Counter()
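
Based on the new docstring and the commented-out implementation deleted above, a hedged sketch of pairing a Logger with TensorBoard; the write call is an assumption inferred from the removed code, not a confirmed API:

    from torch.utils.tensorboard import SummaryWriter
    from rltorch.log import Logger, LogWriter

    logger = Logger()
    logger.append("loss", 0.42)
    log_writer = LogWriter(SummaryWriter())
    # Presumably a write(logger)-style call follows, mirroring the deleted
    # version: scalars go through add_scalar, arrays and tensors through
    # add_histogram, with self.steps counting writes per tag.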

rltorch/memory/EpisodeMemory.py
@@ -5,22 +5,43 @@ Transition = namedtuple('Transition',
     ('state', 'action', 'reward', 'next_state', 'done'))
 
 class EpisodeMemory(object):
+    """
+    Memory structure that stores an entire episode and
+    the observation's associated log-based probabilities.
+    """
     def __init__(self):
         self.memory = []
         self.log_probs = []
 
     def append(self, *args):
-        """Saves a transition."""
+        """
+        Adds a transition to the memory.
+
+        Parameters
+        ----------
+          *args
+             The state, action, reward, next_state, done tuple
+        """
         self.memory.append(Transition(*args))
     
     def append_log_probs(self, logprob):
+        """
+        Adds a log-based probability to the observation.
+        """
         self.log_probs.append(logprob)
 
     def clear(self):
+        """
+        Clears the transitions and log-based probabilities.
+        """
         self.memory.clear()
         self.log_probs.clear()
 
     def recall(self):
+        """
+        Returns a list of the transitions with their
+        associated log-based probabilities.
+        """
         if len(self.memory) != len(self.log_probs):
             raise ValueError("Memory and recorded log probabilities must be the same length.")
         return list(zip(*tuple(zip(*self.memory)), self.log_probs))
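
A hypothetical usage sketch of EpisodeMemory over one episode of a policy-gradient agent; the transition values are placeholders:

    import torch
    from rltorch.memory import EpisodeMemory

    memory = EpisodeMemory()
    # One entry per environment step: (state, action, reward, next_state, done)
    memory.append([0.0, 1.0], 1, 1.0, [0.1, 0.9], False)
    memory.append_log_probs(torch.log(torch.tensor(0.8)))  # paired with the step above

    # recall() zips each transition with its log probability
    for state, action, reward, next_state, done, log_prob in memory.recall():
        print(reward, log_prob)

    memory.clear()  # ready for the next episode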

rltorch/memory/PrioritizedReplayMemory.py
@@ -147,7 +147,9 @@ class MinSegmentTree(SegmentTree):
 
 class PrioritizedReplayMemory(ReplayMemory):
     def __init__(self, capacity, alpha):
-        """Create Prioritized Replay buffer.
+        """
+        Create Prioritized Replay buffer.
+
         Parameters
         ----------
         capacity: int
@@ -156,9 +158,6 @@ class PrioritizedReplayMemory(ReplayMemory):
         alpha: float
             how much prioritization is used
             (0 - no prioritization, 1 - full prioritization)
-        See Also
-        --------
-        ReplayBuffer.__init__
         """
         super(PrioritizedReplayMemory, self).__init__(capacity)
         assert alpha >= 0
@@ -173,7 +172,14 @@ class PrioritizedReplayMemory(ReplayMemory):
         self._max_priority = 1.0
 
     def append(self, *args, **kwargs):
-        """See ReplayBuffer.store_effect"""
+        """
+        Adds a transition to the buffer and adds an initial prioritization.
+
+        Parameters
+        ----------
+          *args
+             The state, action, reward, next_state, done tuple
+        """
         idx = self.position
         super().append(*args, **kwargs)
         self._it_sum[idx] = self._max_priority ** self._alpha
@@ -191,10 +197,11 @@ class PrioritizedReplayMemory(ReplayMemory):
         return res
 
     def sample(self, batch_size, beta):
-        """Sample a batch of experiences.
-        compared to ReplayBuffer.sample
-        it also returns importance weights and idxes
+        """
+        Sample a batch of experiences,
+        while returning importance weights and idxes
         of sampled experiences.
+
         Parameters
         ----------
         batch_size: int
@@ -202,6 +209,7 @@ class PrioritizedReplayMemory(ReplayMemory):
         beta: float
             To what degree to use importance weights
             (0 - no corrections, 1 - full correction)
+
         Returns
         -------
         weights: np.array
@@ -232,6 +240,32 @@ class PrioritizedReplayMemory(ReplayMemory):
         return batch
 
     def sample_n_steps(self, batch_size, steps, beta):
+        r"""
+        Sample a batch of sequential experiences,
+        while returning importance weights and idxes
+        of sampled experiences.
+
+        Parameters
+        ----------
+        batch_size: int
+            How many transitions to sample.
+        beta: float
+            To what degree to use importance weights
+            (0 - no corrections, 1 - full correction)
+
+        Notes
+        -----
+        The number of batches sampled is :math:`\lfloor\frac{batch\_size}{steps}\rfloor`.
+
+        Returns
+        -------
+        weights: np.array
+            Array of shape (batch_size,) and dtype np.float32
+            denoting importance weight of each sampled transition
+        idxes: np.array
+            Array of shape (batch_size,) and dtype np.int32
+            indexes in buffer of sampled experiences
+        """
         assert beta > 0
 
         sample_size = batch_size // steps
@@ -262,9 +296,11 @@ class PrioritizedReplayMemory(ReplayMemory):
 
     @jit(forceobj = True)
     def update_priorities(self, idxes, priorities):
-        """Update priorities of sampled transitions.
+        """
+        Update priorities of sampled transitions:
         sets priority of transition at index idxes[i] in buffer
         to priorities[i].
+
         Parameters
         ----------
         idxes: [int]
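
An end-to-end sketch of the prioritized buffer; alpha and beta are illustrative values (alpha weights prioritization, beta weights the importance correction, per the docstrings above):

    from rltorch.memory import PrioritizedReplayMemory

    memory = PrioritizedReplayMemory(capacity=10000, alpha=0.6)
    for step in range(64):
        # (state, action, reward, next_state, done)
        memory.append([float(step)], 0, 1.0, [float(step + 1)], False)

    batch = memory.sample(batch_size=32, beta=0.4)
    # Per the Returns section above, the sample carries importance weights
    # and buffer indexes; after computing TD errors, the new priorities go
    # back in through update_priorities(idxes, priorities).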

rltorch/memory/ReplayMemory.py
@@ -4,21 +4,38 @@ import torch
 Transition = namedtuple('Transition',
     ('state', 'action', 'reward', 'next_state', 'done'))
 
-# Implements a Ring Buffer
 class ReplayMemory(object):
+    """
+    Creates a ring buffer of a fixed size.
+
+    Parameters
+    ----------
+    capacity : int
+      The maximum size of the buffer
+    """
     def __init__(self, capacity):
         self.capacity = capacity
         self.memory = []
         self.position = 0
 
     def append(self, *args):
-        """Saves a transition."""
+        """
+        Adds a transition to the buffer.
+
+        Parameters
+        ----------
+        *args
+          The state, action, reward, next_state, done tuple
+        """
         if len(self.memory) < self.capacity:
             self.memory.append(None)
         self.memory[self.position] = Transition(*args)
         self.position = (self.position + 1) % self.capacity
 
     def clear(self):
+        """
+        Clears the buffer.
+        """
         self.memory.clear()
         self.position = 0
 
@@ -37,10 +54,35 @@ class ReplayMemory(object):
 
 
     def sample(self, batch_size):
+        """
+        Returns a random sample from the buffer.
+
+        Parameters
+        ----------
+        batch_size : int
+          The number of observations to sample.
+        """
         return random.sample(self.memory, batch_size)
     
     def sample_n_steps(self, batch_size, steps):
-        idxes = random.sample(range(len(self.memory) - steps), batch_size // steps)
+        r"""
+        Returns a random sample of sequential batches of size steps.
+
+        Notes
+        -----
+        The number of batches sampled is :math:`\lfloor\frac{batch\_size}{steps}\rfloor`.
+
+        Parameters
+        ----------
+        batch_size : int
+          The total number of observations to sample.
+        steps : int
+          The number of observations after the one selected to sample.
+        """
+        idxes = random.sample(
+            range(len(self.memory) - steps),
+            batch_size // steps
+        )
         step_idxes = []
         for i in idxes:
             step_idxes += range(i, i + steps)
@@ -56,10 +98,10 @@ class ReplayMemory(object):
         return value in self.memory
 
     def __getitem__(self, index):
-        return self.memory[index]
+        return self.memory[index % self.capacity]
 
     def __setitem__(self, index, value):
-        self.memory[index] = value
+        self.memory[index % self.capacity] = value
 
     def __reversed__(self):
         return reversed(self.memory)
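
The modulo added to __getitem__ and __setitem__ makes indexing wrap around the ring buffer rather than raise IndexError; a tiny illustration with placeholder transitions:

    from rltorch.memory import ReplayMemory

    memory = ReplayMemory(capacity=3)
    for i in range(3):
        memory.append(i, 0, 0.0, i + 1, False)  # state, action, reward, next_state, done

    # Index 3 wraps back around to slot 0
    assert memory[3] == memory[0]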

rltorch/network/ESNetwork.py
@@ -7,9 +7,36 @@ from copy import deepcopy
 # What if we want to sometimes do gradient descent as well?
 class ESNetwork(Network):
     """
-    Network that functions from the paper Evolutionary Strategies (https://arxiv.org/abs/1703.03864)
-    fitness_fun := model, *args -> fitness_value (float)
-    We wish to find a model that maximizes the fitness function
+    Uses evolutionary techniques to optimize a neural network.
+
+    Notes
+    -----
+    Derived from the paper
+    Evolution Strategies
+    (https://arxiv.org/abs/1703.03864)
+
+    Parameters
+    ----------
+    model : nn.Module
+      A PyTorch nn.Module.
+    optimizer
+      A PyTorch optimizer from torch.optim.
+    population_size : int
+      The number of networks to evaluate each iteration.
+    fitness_fn : function
+      Function that evaluates a network and returns a higher
+      number for better performing networks.
+    sigma : number
+      The standard deviation of the Gaussian noise added to
+      the parameters when creating the population.
+    config : dict
+      A dictionary of configuration items.
+    device
+      A device to send the weights to.
+    logger
+      Keeps track of historical weights
+    name
+      For use in logger to differentiate in analysis.
     """
     def __init__(self, model, optimizer, population_size, fitness_fn, config, sigma = 0.05, device = None, logger = None, name = ""):
         super(ESNetwork, self).__init__(model, optimizer, config, device, logger, name)
@@ -18,9 +45,15 @@ class ESNetwork(Network):
         self.sigma = sigma
         assert self.sigma > 0
 
-    # We're not going to be calculating gradients in the traditional way
-    # So there's no need to waste computation time keeping track
     def __call__(self, *args):
+        """
+        Notes
+        -----
+        Since gradients aren't going to be computed in the
+        traditional fashion, there is no need to keep
+        track of the computations performed on the
+        tensors.
+        """
         with torch.no_grad():
             result = self.model(*args)
         return result
@@ -48,6 +81,14 @@ class ESNetwork(Network):
         return candidate_solutions
 
     def calc_gradients(self, *args):
+        """
+        Calculates gradients by shifting parameters
+        towards the networks with the highest fitness value.
+
+        This is calculated by evaluating the fitness of multiple
+        networks according to the fitness function specified in
+        the class.
+        """
         ## Generate Noise
         white_noise_dict, noise_dict = self._generate_noise_dicts()
         
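
A hypothetical sketch of wiring up an ESNetwork; the fitness function takes the model first and returns a float where higher is better, as the docstring describes. The model, config, and scoring rule are all illustrative:

    import torch
    import torch.nn as nn
    from rltorch.network import ESNetwork

    model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    def fitness(model, states, rewards):
        # Stand-in score: reward collected by the greedy action choices
        return float((model(states).argmax(dim=1).float() * rewards).sum())

    net = ESNetwork(model, optimizer, population_size=20,
                    fitness_fn=fitness, config={}, sigma=0.05)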

rltorch/network/Network.py
@@ -1,6 +1,21 @@
 class Network:
     """
-    Wrapper around model which provides copy of it instead of trained weights
+    Wrapper around a model and optimizer in PyTorch to abstract away common use cases.
+
+    Parameters
+    ----------
+    model : nn.Module
+      A PyTorch nn.Module.
+    optimizer
+      A PyTorch optimizer from torch.optim.
+    config : dict
+      A dictionary of configuration items.
+    device
+      A device to send the weights to.
+    logger
+      Keeps track of historical weights
+    name
+      For use in logger to differentiate in analysis.
     """
     def __init__(self, model, optimizer, config, device = None, logger = None, name = ""):
         self.model = model
@@ -18,14 +33,29 @@ class Network:
         return self.model(*args)
 
     def clamp_gradients(self, x = 1):
+        """
+        Forces gradients to stay within a certain interval
+        by setting them to the bound if they go over it.
+
+        Parameters
+        ----------
+        x : number > 0
+          Sets the interval to be [-x, x]
+        """
         assert x > 0
         for param in self.model.parameters():
             param.grad.data.clamp_(-x, x)
     
     def zero_grad(self):
+        """
+        Clears out gradients held in the model.
+        """
         self.model.zero_grad()
 
     def step(self):
+        """
+        Runs a step of the optimizer on `model`.
+        """
         self.optimizer.step()
     
     def log_named_parameters(self):
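
A minimal sketch of the Network wrapper in a single training step; the loss and data are placeholders, and config={} assumes no configuration is needed for this path:

    import torch
    import torch.nn as nn
    from rltorch.network import Network

    model = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    net = Network(model, optimizer, config={})

    net.zero_grad()
    loss = net(torch.randn(8, 4)).pow(2).mean()  # placeholder loss
    loss.backward()
    net.clamp_gradients(1)  # keep every gradient inside [-1, 1]
    net.step()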

rltorch/network/NoisyLinear.py
@@ -6,6 +6,24 @@ import math
 # This class utilizes this property of the normal distribution
 # N(mu, sigma) = mu + sigma * N(0, 1)
 class NoisyLinear(nn.Linear):
+  """
+  Draws the parameters of nn.Linear from a normal distribution.
+  The parameters of the normal distribution are registered as
+  learnable parameters in the neural network.
+
+  Parameters
+  ----------
+  in_features
+    Size of each input sample.
+  out_features
+    Size of each output sample.
+  sigma_init
+    The starting standard deviation of the Gaussian noise.
+  bias
+     If set to False, the layer will not
+     learn an additive bias.
+     Default: True
+  """
  def __init__(self, in_features, out_features, sigma_init = 0.017, bias = True):
    super(NoisyLinear, self).__init__(in_features, out_features, bias = bias)
    # One of the parameters the network is going to tune is the 
@@ -27,6 +45,15 @@ class NoisyLinear(nn.Linear):
     nn.init.uniform_(self.bias, -std, std)
   
   def forward(self, x):
+    r"""
+    Calculates the output :math:`y` through the following:
+
+    :math:`\sigma \sim N(\mu_1, std_1)`
+
+    :math:`bias \sim N(\mu_2, std_2)`
+
+    :math:`y = \sigma \cdot x + bias`
+    """
     # Fill s_normal_weight with values from the standard normal distribution
     self.s_normal_weight.normal_()
     weight_noise = self.sigma_weight * self.s_normal_weight.clone().requires_grad_()
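
A drop-in usage sketch of NoisyLinear; because fresh noise is drawn on every forward pass, two calls on the same input generally differ, which is what makes the layer useful for exploration:

    import torch
    from rltorch.network import NoisyLinear

    layer = NoisyLinear(4, 2, sigma_init=0.017)
    x = torch.randn(1, 4)
    y1, y2 = layer(x), layer(x)
    print(torch.allclose(y1, y2))  # expected False: noise is resampled per call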

rltorch/network/TargetNetwork.py
@@ -1,25 +1,43 @@
 from copy import deepcopy
-# Derived from ptan library
+
 class TargetNetwork:
     """
-    Wrapper around model which provides copy of it instead of trained weights
+    Creates a clone of a network with syncing capabilities.
+
+    Parameters
+    ----------
+    network
+      The network to clone.
+    device
+      The device to put the cloned parameters in.
     """
     def __init__(self, network, device = None):
         self.model = network.model
         self.target_model = deepcopy(network.model)
-        if network.device is not None:
+        if device is not None:
+            self.target_model = self.target_model.to(device)
+        elif network.device is not None:
             self.target_model = self.target_model.to(network.device)
 
     def __call__(self, *args):
         return self.model(*args)
 
     def sync(self):
+        """
+        Performs a full state sync with the originating model.
+        """
         self.target_model.load_state_dict(self.model.state_dict())
 
     def partial_sync(self, tau):
         """
-        Blend params of target net with params from the model
-        :param tau:
+        Partially move closer to the parameters of the originating
+        model by updating parameters to be a mix of the
+        originating and the clone models.
+
+        Parameters
+        ----------
+        tau : number
+          A number between 0 and 1 which indicates the proportion of the originator and clone in the new clone.
         """
         assert isinstance(tau, float)
         assert 0.0 < tau <= 1.0
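
partial_sync is the usual Polyak-style soft update; reading the docstring above, tau is the share of the live model mixed into the clone, roughly target = tau * model + (1 - tau) * target. A hypothetical sketch:

    import torch
    import torch.nn as nn
    from rltorch.network import Network, TargetNetwork

    model = nn.Linear(4, 2)
    net = Network(model, torch.optim.Adam(model.parameters()), config={})

    target = TargetNetwork(net)
    target.sync()               # hard copy of the current weights
    target.partial_sync(0.01)   # soft update: mostly the old clone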

rltorch/scheduler/ExponentialScheduler.py
@@ -1,5 +1,32 @@
 from .Scheduler import Scheduler
 class ExponentialScheduler(Scheduler):
+    r"""
+    An exponential scheduler that, given a certain number
+    of iterations, spaces the values between
+    a start and an end point in an exponential order.
+
+    Notes
+    -----
+    The formula used to produce the value :math:`y` is based on the number of
+    times you call `next` (denoted as :math:`i`):
+
+    :math:`y(1) = initial\_value`
+
+    :math:`base = \sqrt[iterations]{\frac{end\_value}{initial\_value}}`
+
+    :math:`y(i) = y(1) \cdot base^{i - 1}`
+
+    Another property is that :math:`y(iterations) = end\_value`.
+
+    Parameters
+    ----------
+    initial_value : number
+      The first value returned in the schedule.
+    end_value : number
+      The value returned when the maximum number of iterations is reached.
+    iterations : int
+      The total number of iterations.
+    """
     def __init__(self, initial_value, end_value, iterations):
         super(ExponentialScheduler, self).__init__(initial_value, end_value, iterations)
         self.base = (end_value / initial_value) ** (1.0 / iterations)
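
A worked instance with illustrative numbers: decaying a value from 1.0 to 0.1 over 100 iterations gives base = (0.1/1.0)^(1/100) ≈ 0.977, so each call shrinks the value by about 2.3%. This assumes the Scheduler base class implements the iterator protocol, as the docstring's `next` suggests:

    from rltorch.scheduler import ExponentialScheduler

    epsilon = ExponentialScheduler(initial_value=1.0, end_value=0.1, iterations=100)
    values = [next(epsilon) for _ in range(100)]
    print(values[0], values[-1])  # roughly 1.0 ... 0.1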

rltorch/scheduler/LinearScheduler.py
@@ -1,5 +1,29 @@
 from .Scheduler import Scheduler
 class LinearScheduler(Scheduler):
+    r"""
+    A linear scheduler that, given a certain number
+    of iterations, equally spaces the values between
+    a start and an end point.
+
+    Notes
+    -----
+    The formula used to produce the value :math:`y` is based on the number of
+    times you call `next` (denoted as :math:`i`):
+
+    :math:`y(1) = initial\_value`
+
+    :math:`y(i) = slope \cdot (i - 1) + y(1)`
+    where :math:`slope = \frac{end\_value - initial\_value}{iterations}`.
+
+    Parameters
+    ----------
+    initial_value : number
+      The first value returned in the schedule.
+    end_value : number
+      The value returned when the maximum number of iterations is reached.
+    iterations : int
+      The total number of iterations.
+    """
     def __init__(self, initial_value, end_value, iterations):
         super(LinearScheduler, self).__init__(initial_value, end_value, iterations)
         self.slope = (end_value - initial_value) / iterations
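
The linear counterpart with the same illustrative numbers: slope = (0.1 - 1.0) / 100 = -0.009, so each call steps the value down by 0.009:

    from rltorch.scheduler import LinearScheduler

    epsilon = LinearScheduler(initial_value=1.0, end_value=0.1, iterations=100)
    print(next(epsilon))  # 1.0
    print(next(epsilon))  # 0.991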

rltorch/seed.py
@@ -4,6 +4,14 @@ import random
 import torch
 
 def set_seed(SEED):
+    """
+    Set the seed for repeatability purposes.
+
+    Parameters
+    ----------
+    SEED : int
+      The seed to set numpy, random, and torch to.
+    """
     # Set `PYTHONHASHSEED` environment variable at a fixed value
     environ['PYTHONHASHSEED'] = str(SEED)
 
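
Usage is a single call at the top of a training script; the import path follows the module layout documented above, and the seed value is arbitrary:

    from rltorch.seed import set_seed

    set_seed(42)  # fixes PYTHONHASHSEED, random, numpy, and torch in one call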
setup.py: 9 changes
@@ -12,4 +12,11 @@ setuptools.setup(
     description="Reinforcement Learning Framework for PyTorch",
     version="0.1",
     packages=setuptools.find_packages(),
-)
+    install_requires=[
+        "numpy~=1.16.0",
+        "opencv-python~=4.2.0.32",
+        "gym~=0.10.11",
+        "torch~=1.4.0",
+        "numba~=0.48.0"
+    ]
+)

tests/test.py (new file): 6 additions
@@ -0,0 +1,6 @@
+import rltorch
+import unittest
+
+class Test(unittest.TestCase):
+    def test(self):
+        pass

tox.ini (new file): 17 additions
@@ -0,0 +1,17 @@
+[tox]
+envlist = 
+    py36
+    py37
+    py38
+
+[testenv]
+deps = coverage
+commands = 
+    coverage run --source=tests,rltorch -m unittest discover tests
+
+
+[testenv:py38]
+commands =
+    coverage run --source=tests,rltorch -m unittest discover tests
+    coverage report -m
+