SneakyTrain uses separate replay buffer
Scripts were cleaned up considerably and comments were added
commit d78892e62c
parent b7aa4a4ec6
2 changed files with 186 additions and 175 deletions
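The heart of the change is in Play.sneaky_train below: the shared replay buffer is swapped out for a scratch ReplayMemory while the agent trains off-screen, with memory_lock held so the recording thread cannot append human transitions into the scratch buffer. The following is a minimal self-contained sketch of that swap pattern (ReplayMemory and SneakyTrainer here are illustrative stand-ins, not the rltorch API, and the sketch uses a `with` block where the commit calls acquire/release explicitly):

    from threading import Lock

    class ReplayMemory:
        # Stand-in for an rltorch-style replay memory: a bounded list of transitions
        def __init__(self, capacity):
            self.capacity = capacity
            self.transitions = []

        def append(self, state, action, reward, next_state, done):
            if len(self.transitions) >= self.capacity:
                self.transitions.pop(0)  # evict the oldest transition
            self.transitions.append((state, action, reward, next_state, done))

        def __len__(self):
            return len(self.transitions)

    class SneakyTrainer:
        # Illustrates the lock-guarded buffer swap used by sneaky_train
        def __init__(self, memory, memory_lock, memory_size):
            self.memory = memory            # buffer shared with the recording thread
            self.memory_lock = memory_lock
            self.memory_size = memory_size

        def sneaky_train(self, transitions):
            # Hold the lock across the whole swap so the recording thread can
            # never append human data into the scratch buffer by mistake
            with self.memory_lock:
                backup_memory = self.memory
                self.memory = ReplayMemory(capacity = self.memory_size)
                for t in transitions:       # hidden self-play would generate these
                    self.memory.append(*t)
                # Restore the demonstration buffer; the scratch data is discarded
                self.memory = backup_memory

    shared = ReplayMemory(capacity = 100)
    trainer = SneakyTrainer(shared, Lock(), memory_size = 10)
    trainer.sneaky_train([("obs", 0, 1.0, "next_obs", True)])
    assert len(shared) == 0                 # the shared buffer was left untouched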
play.py (209 lines changed)

@@ -1,41 +1,33 @@
-import gym
+from gym.spaces.box import Box
 import pygame
-import sys
-import time
-import matplotlib
-import rltorch.memory as M
-try:
-    matplotlib.use('GTK3Agg')
-    import matplotlib.pyplot as plt
-except Exception:
-    pass
-
-import pyglet.window as pw
-
-from collections import deque
-from pygame.locals import HWSURFACE, DOUBLEBUF, RESIZABLE, VIDEORESIZE
-from threading import Thread, Event, Timer
+from pygame.locals import VIDEORESIZE
+from rltorch.memory import ReplayMemory
 
 class Play:
-    def __init__(self, env, action_selector, memory, agent, sneaky_env, transpose = True, fps = 30, zoom = None, keys_to_action = None):
+    def __init__(self, env, action_selector, memory, memory_lock, agent, sneaky_env, config):
         self.env = env
         self.action_selector = action_selector
-        self.transpose = transpose
-        self.fps = fps
-        self.zoom = zoom
-        self.keys_to_action = None
+        self.memory = memory
+        self.memory_lock = memory_lock
+        self.agent = agent
+        self.sneaky_env = sneaky_env
+        # Get relevant parameters from config or set sane defaults
+        self.transpose = config['transpose'] if 'transpose' in config else True
+        self.fps = config['fps'] if 'fps' in config else 30
+        self.zoom = config['zoom'] if 'zoom' in config else 1
+        self.keys_to_action = config['keys_to_action'] if 'keys_to_action' in config else None
+        self.seconds_play_per_state = config['seconds_play_per_state'] if 'seconds_play_per_state' in config else 30
+        self.num_sneaky_episodes = config['num_sneaky_episodes'] if 'num_sneaky_episodes' in config else 10
+        self.memory_size = config['memory_size'] if 'memory_size' in config else 10**4
+        self.replay_skip = config['replay_skip'] if 'replay_skip' in config else 0
+        # Initial values...
         self.video_size = (0, 0)
         self.pressed_keys = []
         self.screen = None
         self.relevant_keys = set()
         self.running = True
-        self.switch = Event()
         self.state = 0
-        self.paused = False
-        self.memory = memory
-        self.agent = agent
-        self.sneaky_env = sneaky_env
+        self.clock = pygame.time.Clock()
 
     def _display_arr(self, obs, screen, arr, video_size):
         if obs is not None:
@@ -48,6 +40,21 @@ class Play:
         pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if self.transpose else arr)
         pyg_img = pygame.transform.scale(pyg_img, video_size)
         screen.blit(pyg_img, (0,0))
 
+    def _process_common_pygame_events(self, event):
+        if event.type == pygame.QUIT:
+            self.running = False
+        elif event.type == VIDEORESIZE:
+            self.video_size = event.size
+            self.screen = pygame.display.set_mode(self.video_size)
+        elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
+            self.running = False
+        else:
+            # No event was matched here
+            return False
+        # One of the events above matched
+        return True
+
     def _human_play(self, obs):
         action = self.keys_to_action.get(tuple(sorted(self.pressed_keys)), 0)
@@ -57,20 +64,14 @@ class Play:
 
         # process pygame events
         for event in pygame.event.get():
-            # test events, set key states
-            if event.type == pygame.KEYDOWN:
+            if self._process_common_pygame_events(event):
+                continue
+            elif event.type == pygame.KEYDOWN:
                 if event.key in self.relevant_keys:
                     self.pressed_keys.append(event.key)
-                elif event.key == pygame.K_ESCAPE:
-                    self.running = False
             elif event.type == pygame.KEYUP:
                 if event.key in self.relevant_keys:
                     self.pressed_keys.remove(event.key)
-            elif event.type == pygame.QUIT:
-                self.running = False
-            elif event.type == VIDEORESIZE:
-                self.video_size = event.size
-                self.screen = pygame.display.set_mode(self.video_size)
 
         pygame.display.flip()
         self.clock.tick(self.fps)
@@ -84,13 +85,7 @@ class Play:
 
         # process pygame events
         for event in pygame.event.get():
-            if event.type == pygame.QUIT:
-                self.running = False
-            elif event.type == VIDEORESIZE:
-                self.video_size = event.size
-                self.screen = pygame.display.set_mode(self.video_size)
-            elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
-                self.running = False
+            self._process_common_pygame_events(event)
 
         pygame.display.flip()
         self.clock.tick(self.fps)
@@ -107,7 +102,7 @@ class Play:
 
         self.video_size = video_size
         self.screen = pygame.display.set_mode(self.video_size)
-        pygame.font.init() # For later text
+        pygame.font.init()
 
     def _setup_keys(self):
         if self.keys_to_action is None:
@@ -124,48 +119,59 @@ class Play:
         self.state = (self.state + 1) % 5
 
     def pause(self, text = ""):
-        self.paused = True
         myfont = pygame.font.SysFont('Comic Sans MS', 50)
         textsurface = myfont.render(text, False, (0, 0, 0))
         self.screen.blit(textsurface,(0,0))
 
+        # Process pygame events
         for event in pygame.event.get():
-            if event.type == pygame.QUIT:
-                self.running = False
-            elif event.type == VIDEORESIZE:
-                self.video_size = event.size
-                self.screen = pygame.display.set_mode(self.video_size)
+            if self._process_common_pygame_events(event):
+                continue
             elif event.type == pygame.KEYDOWN:
                 if event.key == pygame.K_SPACE:
                     self.pressed_keys.append(event.key)
-                elif event.key == pygame.K_ESCAPE:
-                    self.running = False
             elif event.type == pygame.KEYUP and event.key == pygame.K_SPACE:
                 self.pressed_keys.remove(event.key)
                 self._increment_state()
-                self.paused = False
         pygame.display.flip()
         self.clock.tick(self.fps)
 
 
     def sneaky_train(self):
+        self.memory_lock.acquire()
+
         # Backup memory
         backup_memory = self.memory
-        self.memory = M.ReplayMemory(capacity = 2000) # Another configurable parameter
-        EPISODES = 30 # Make this configurable
-        replay_skip = 4 # Make this configurable
-        for _ in range(EPISODES):
+        self.memory = ReplayMemory(capacity = self.memory_size)
+
+        # Do a standard RL algorithm process for a certain number of episodes
+        for i in range(self.num_sneaky_episodes):
+            print("Episode: %d / %d, Reward: " % (i + 1, self.num_sneaky_episodes), end = "")
+
+            # Reset all episode-related variables
             prev_obs = self.sneaky_env.reset()
             done = False
             step = 0
+            total_reward = 0
+
             while not done:
                 action = self.action_selector.act(prev_obs)
                 obs, reward, done, _ = self.sneaky_env.step(action)
+                total_reward += reward
                 self.memory.append(prev_obs, action, reward, obs, done)
                 prev_obs = obs
                 step += 1
-                if step % replay_skip == 0:
+                if step % self.replay_skip == 0:
                     self.agent.learn()
 
+            # Finish the previous print with the total reward obtained during the episode
+            print(total_reward)
+
+        # Reset the memory back to the human demonstration / shown computer data
         self.memory = backup_memory
+        self.memory_lock.release()
 
+    # Thoughts:
     # It would be cool instead of throwing away all this new data, we keep just a sample of it
     # Not sure if I want all of it because then it'll drown out the expert demonstration data
@@ -178,55 +184,14 @@ class Play:
         Above code works also if env is wrapped, so it's particularly useful in
         verifying that the frame-level preprocessing does not render the game
         unplayable.
-        If you wish to plot real time statistics as you play, you can use
-        gym.utils.play.PlayPlot. Here's a sample code for plotting the reward
-        for last 5 second of gameplay.
-            def callback(obs_t, obs_tp1, rew, done, info):
-                return [rew,]
-            env_plotter = EnvPlotter(callback, 30 * 5, ["reward"])
-            env = gym.make("Pong-v3")
-            play(env, callback=env_plotter.callback)
-        Arguments
-        ---------
-        env: gym.Env
-            Environment to use for playing.
-        transpose: bool
-            If True the output of observation is transposed.
-            Defaults to true.
-        fps: int
-            Maximum number of steps of the environment to execute every second.
-            Defaults to 30.
-        zoom: float
-            Make screen edge this many times bigger
-        callback: lambda or None
-            Callback if a callback is provided it will be executed after
-            every step. It takes the following input:
-                obs_t: observation before performing action
-                obs_tp1: observation after performing action
-                action: action that was executed
-                rew: reward that was received
-                done: whether the environment is done or not
-                info: debug info
-        keys_to_action: dict: tuple(int) -> int or None
-            Mapping from keys pressed to action performed.
-            For example if pressed 'w' and space at the same time is supposed
-            to trigger action number 2 then key_to_action dict would look like this:
-                {
-                    # ...
-                    sorted(ord('w'), ord(' ')) -> 2
-                    # ...
-                }
-            If None, default key_to_action mapping for that env is used, if provided.
         """
         obs_s = self.env.unwrapped.observation_space
-        assert type(obs_s) == gym.spaces.box.Box
+        assert type(obs_s) == Box
         assert len(obs_s.shape) == 2 or (len(obs_s.shape) == 3 and obs_s.shape[2] in [1,3])
 
         self._setup_keys()
         self._setup_video()
 
-        self.clock = pygame.time.Clock()
-
         # States
         HUMAN_PLAY = 0
         SNEAKY_COMPUTER_PLAY = 1
@@ -234,41 +199,61 @@ class Play:
         COMPUTER_PLAY = 3
         TRANSITION2 = 4
 
 
         env_done = True
-        prev_obs = None
         obs = None
-        reward = 0
         i = 0
         while self.running:
+            # If the environment is done after a turn, reset it so we can keep going
             if env_done:
                 obs = self.env.reset()
                 env_done = False
 
             if self.state is HUMAN_PLAY:
-                prev_obs, action, reward, obs, env_done = self._human_play(obs)
+                _, _, _, obs, env_done = self._human_play(obs)
+
+            # The computer will train for a few episodes without showing to the user.
+            # Mainly to speed up the learning process a bit
             elif self.state is SNEAKY_COMPUTER_PLAY:
+                print("Sneaky Computer Time")
+
+                # Display "Training..." text to user
                 myfont = pygame.font.SysFont('Comic Sans MS', 50)
                 textsurface = myfont.render("Training....", False, (0, 0, 0))
                 self.screen.blit(textsurface,(0,0))
+                pygame.display.flip()
+
+                # Have the agent play a few rounds without showing to the user
                 self.sneaky_train()
+
+                # To take away training text
+                self._display_arr(obs, self.screen, self.env.unwrapped._get_obs(), video_size=self.video_size)
+                pygame.display.flip()
+
+                # Go to the next step immediately
                 self._increment_state()
+
             elif self.state is TRANSITION:
                 self.pause("Computers Turn! Press <Space> to Start")
+
             elif self.state is COMPUTER_PLAY:
-                prev_obs, action, reward, obs, env_done = self._computer_play(obs)
+                _, _, _, obs, env_done = self._computer_play(obs)
+
             elif self.state is TRANSITION2:
                 self.pause("Your Turn! Press <Space> to Start")
 
+            # Increment the timer if it's the human or shown computer's turn
             if self.state is COMPUTER_PLAY or self.state is HUMAN_PLAY:
-                self.memory.append(prev_obs, action, reward, obs, env_done)
                 i += 1
-                # Every 30 seconds...
-                if i % (self.fps * 30) == 0:
-                    print("Training...")
+                # Perform a quick learning process and increment the state after a certain time period has passed
+                if i % (self.fps * self.seconds_play_per_state) == 0:
+                    self.memory_lock.acquire()
+                    print("Number of transitions in buffer: ", len(self.memory))
                     self.agent.learn()
-                    print("PAUSING...")
+                    self.memory_lock.release()
                     self._increment_state()
                     i = 0
 
+        # Stop the pygame environment when done
         pygame.quit()
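One stylistic note on the defaults block in __init__ above: the `config['fps'] if 'fps' in config else 30` pattern is equivalent to Python's built-in dict.get, which the commit does not use but which reads more compactly:

    config = {'fps': 60}

    fps = config['fps'] if 'fps' in config else 30   # pattern used in the diff
    fps = config.get('fps', 30)                      # equivalent built-in form

    assert fps == 60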
play_env.py (152 lines changed)

@@ -1,21 +1,31 @@
-import play
-import rltorch
-import rltorch.memory as M
-import torch
-import gym
+# Import Python Standard Libraries
+from threading import Thread, Lock
+from argparse import ArgumentParser
 from collections import namedtuple
 from datetime import datetime
 
+# Import Pytorch related packages for NNs
+from numpy import array as np_array
+from numpy import save as np_save
+import torch
+from torch.optim import Adam
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Import my custom RL library
+import rltorch
+from rltorch.memory import PrioritizedReplayMemory
 from rltorch.action_selector import EpsilonGreedySelector
 import rltorch.env as E
 import rltorch.network as rn
-import torch.nn as nn
-import torch.nn.functional as F
-import pickle
-import threading
-from time import sleep
-import argparse
-import sys
-import numpy as np
+
+# Import OpenAI gym and related packages
+from gym import make as makeEnv
+from gym import Wrapper as GymWrapper
+from gym.wrappers import Monitor as GymMonitor
+import play
 
 #
 ## Networks
@@ -73,56 +83,56 @@ class Value(nn.Module):
 Transition = namedtuple('Transition',
     ('state', 'action', 'reward', 'next_state', 'done'))
 
-class PlayClass(threading.Thread):
-    def __init__(self, env, action_selector, memory, agent, sneaky_env, fps = 60):
+class PlayClass(Thread):
+    def __init__(self, env, action_selector, memory, memory_lock, agent, sneaky_env, config):
         super(PlayClass, self).__init__()
-        self.env = env
-        self.fps = fps
-        self.play = play.Play(self.env, action_selector, memory, agent, sneaky_env, fps = fps, zoom = 4)
+        self.play = play.Play(env, action_selector, memory, memory_lock, agent, sneaky_env, config)
 
     def run(self):
         self.play.start()
 
-class Record(gym.Wrapper):
-    def __init__(self, env, memory, args, skipframes = 3):
-        gym.Wrapper.__init__(self, env)
-        self.memory_lock = threading.Lock()
+class Record(GymWrapper):
+    def __init__(self, env, memory, memory_lock, args):
+        GymWrapper.__init__(self, env)
+        self.memory_lock = memory_lock
         self.memory = memory
-        self.args = args
-        self.skipframes = skipframes
-        self.current_i = skipframes
+        self.skipframes = args['skip']
+        self.environment_name = args['environment_name']
+        self.logdir = args['logdir']
+        self.current_i = 0
 
     def reset(self):
         return self.env.reset()
 
     def step(self, action):
-        self.memory_lock.acquire()
         state = self.env.env._get_obs()
         next_state, reward, done, info = self.env.step(action)
-        if self.current_i <= 0:
-            self.memory.append(Transition(state, action, reward, next_state, done))
-            self.current_i = self.skipframes
-        else: self.current_i -= 1
-        self.memory_lock.release()
+        self.current_i += 1
+        # Don't add to memory until a certain number of frames is reached
+        if self.current_i % self.skipframes == 0:
+            self.memory_lock.acquire()
+            self.memory.append(state, action, reward, next_state, done)
+            self.memory_lock.release()
+            self.current_i = 0
         return next_state, reward, done, info
 
     def log_transitions(self):
         self.memory_lock.acquire()
         if len(self.memory) > 0:
-            basename = self.args['logdir'] + "/{}.{}".format(self.args['environment_name'], datetime.now().strftime("%Y-%m-%d-%H-%M-%s"))
+            basename = self.logdir + "/{}.{}".format(self.environment_name, datetime.now().strftime("%Y-%m-%d-%H-%M-%s"))
             print("Base Filename: ", basename)
             state, action, reward, next_state, done = zip(*self.memory)
-            np.save(basename + "-state.npy", np.array(state), allow_pickle = False)
-            np.save(basename + "-action.npy", np.array(action), allow_pickle = False)
-            np.save(basename + "-reward.npy", np.array(reward), allow_pickle = False)
-            np.save(basename + "-nextstate.npy", np.array(next_state), allow_pickle = False)
-            np.save(basename + "-done.npy", np.array(done), allow_pickle = False)
+            np_save(basename + "-state.npy", np_array(state), allow_pickle = False)
+            np_save(basename + "-action.npy", np_array(action), allow_pickle = False)
+            np_save(basename + "-reward.npy", np_array(reward), allow_pickle = False)
+            np_save(basename + "-nextstate.npy", np_array(next_state), allow_pickle = False)
+            np_save(basename + "-done.npy", np_array(done), allow_pickle = False)
             self.memory.clear()
         self.memory_lock.release()
 
 
 ## Parsing arguments
-parser = argparse.ArgumentParser(description="Play and log the environment")
+parser = ArgumentParser(description="Play and log the environment")
 parser.add_argument("--environment_name", type=str, help="The environment name in OpenAI gym to play.")
 parser.add_argument("--logdir", type=str, help="Directory to log video and (state, action, reward, next_state, done) in.")
 parser.add_argument("--skip", type=int, help="Number of frames to skip logging.")
@@ -130,14 +140,20 @@ parser.add_argument("--fps", type=int, help="Number of frames per second")
 parser.add_argument("--model", type=str, help = "The path location of the PyTorch model")
 args = vars(parser.parse_args())
 
+## Main configuration for script
 config = {}
 config['seed'] = 901
+config['seconds_play_per_state'] = 60
+config['zoom'] = 4
 config['environment_name'] = 'PongNoFrameskip-v4'
 config['learning_rate'] = 1e-4
 config['target_sync_tau'] = 1e-3
 config['discount_rate'] = 0.99
 config['exploration_rate'] = rltorch.scheduler.ExponentialScheduler(initial_value = 1, end_value = 0.1, iterations = 10**5)
-config['batch_size'] = 480
+# Number of episodes for the computer to train the agent without the human seeing
+config['num_sneaky_episodes'] = 20
+config['replay_skip'] = 14
+config['batch_size'] = 32 * (config['replay_skip'] + 1)
 config['disable_cuda'] = False
 config['memory_size'] = 10**4
 # Prioritized vs Random Sampling
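Note on the batch size: with config['replay_skip'] = 14, the new expression 32 * (config['replay_skip'] + 1) evaluates to 32 * 15 = 480, so it reproduces the previously hardcoded batch_size of 480 while tying it to the replay-skip length.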
@@ -151,63 +167,73 @@ config['prioritized_replay_sampling_priority'] = 0.6
 config['prioritized_replay_weight_importance'] = rltorch.scheduler.ExponentialScheduler(initial_value = 0.4, end_value = 1, iterations = 10**5)
 
+# Environment name and log directory is vital so show help message and exit if not provided
 if args['environment_name'] is None or args['logdir'] is None:
     parser.print_help()
-    sys.exit(1)
+    exit(1)
 
+# Number of frames to skip when recording and fps can have sane defaults
 if args['skip'] is None:
     args['skip'] = 3
 
 if args['fps'] is None:
     args['fps'] = 30
 
-def wrap_preprocessing(env):
+def wrap_preprocessing(env, MaxAndSkipEnv = False):
+    env = E.NoopResetEnv(
+        E.EpisodicLifeEnv(env),
+        noop_max = 30
+    )
+    if MaxAndSkipEnv:
+        env = E.MaxAndSkipEnv(env, skip = 4)
     return E.ClippedRewardsWrapper(
         E.FrameStack(
             E.TorchWrap(
                 E.ProcessFrame84(
-                    E.FireResetEnv(
-                        # E.MaxAndSkipEnv(
-                        E.NoopResetEnv(
-                            E.EpisodicLifeEnv(env)
-                        , noop_max = 30)
-                        # , skip=4)
-                    )
+                    E.FireResetEnv(env)
                 )
-            ),
-        4)
+            )
+        , 4)
    )
 
-## Starting the game
-memory = []
-env = Record(gym.make(args['environment_name']), memory, args, skipframes = args['skip'])
+## Set up environment to be recorded and preprocessed
+memory = PrioritizedReplayMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'])
+memory_lock = Lock()
+env = Record(makeEnv(args['environment_name']), memory, memory_lock, args)
+# Bind record_env to current env so that we can reference log_transitions easier later
 record_env = env
-env = gym.wrappers.Monitor(env, args['logdir'], force=True)
+# Use native gym monitor to get video recording
+env = GymMonitor(env, args['logdir'], force=True)
+# Preprocess environment
 env = wrap_preprocessing(env)
 
-sneaky_env = wrap_preprocessing(gym.make(args['environment_name']))
+# Use a different environment for when the computer trains on the side so that the current game state isn't manipulated
+# Also use MaxEnvSkip to speed up processing
+sneaky_env = wrap_preprocessing(makeEnv(args['environment_name']), MaxAndSkipEnv = True)
 
+# Set seeds
 rltorch.set_seed(config['seed'])
+env.seed(config['seed'])
 
 device = torch.device("cuda:0" if torch.cuda.is_available() and not config['disable_cuda'] else "cpu")
 state_size = env.observation_space.shape[0]
 action_size = env.action_space.n
 
+# Set up the networks
 net = rn.Network(Value(state_size, action_size),
-    torch.optim.Adam, config, device = device)
+    Adam, config, device = device)
 target_net = rn.TargetNetwork(net, device = device)
 
+# Relevant components from RLTorch
 actor = EpsilonGreedySelector(net, action_size, device = device, epsilon = config['exploration_rate'])
-memory = M.PrioritizedReplayMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'])
 agent = rltorch.agents.DQNAgent(net, memory, config, target_net = target_net)
 
-env.seed(config['seed'])
-playThread = PlayClass(env, actor, memory, agent, sneaky_env, fps = args['fps'])
+# Pass all this information into the thread that will handle the game play and start
+playThread = PlayClass(env, actor, memory, memory_lock, agent, sneaky_env, config)
 playThread.start()
 
-## Logging portion
+# While the play thread is running, we'll periodically log transitions we've encountered
 while playThread.is_alive():
     playThread.join(60)
     print("Logging....", end = " ")
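For reference, a typical invocation of the updated script, given the argparse flags declared above (the log directory path is illustrative):

    python play_env.py --environment_name PongNoFrameskip-v4 --logdir ./logs

--skip and --fps fall back to 3 and 30 when omitted; --environment_name and --logdir are required, and the script prints the help text and exits if either is missing.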