SneakyTrain uses separate replay buffer

Scripts were cleaned up considerably and comments were added
Brandon Rozek 2019-10-23 21:53:20 -04:00
parent b7aa4a4ec6
commit d78892e62c
2 changed files with 186 additions and 175 deletions
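In broad strokes, "SneakyTrain uses separate replay buffer" means the hidden training phase now swaps the shared replay memory out for a throwaway buffer, trains against it, and restores the original demonstration buffer afterwards. The sketch below is a condensed outline of that pattern as it appears in the diff, not a verbatim excerpt; sneaky_train_outline and run_sneaky_episode are placeholder names standing in for the actual method and per-episode rollout loop in play.py.

from rltorch.memory import ReplayMemory

def sneaky_train_outline(play, run_sneaky_episode):
    # Swap in a throwaway buffer so the hidden ("sneaky") episodes never
    # pollute the human / shown-computer demonstration data.
    play.memory_lock.acquire()
    backup_memory = play.memory
    play.memory = ReplayMemory(capacity=play.memory_size)
    for _ in range(play.num_sneaky_episodes):
        run_sneaky_episode(play)  # fills play.memory and periodically calls play.agent.learn()
    # Discard the sneaky transitions and restore the original buffer.
    play.memory = backup_memory
    play.memory_lock.release()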

play.py

@@ -1,41 +1,33 @@
import gym
from gym.spaces.box import Box
import pygame
import sys
import time
import matplotlib
import rltorch.memory as M
try:
matplotlib.use('GTK3Agg')
import matplotlib.pyplot as plt
except Exception:
pass
import pyglet.window as pw
from collections import deque
from pygame.locals import HWSURFACE, DOUBLEBUF, RESIZABLE, VIDEORESIZE
from threading import Thread, Event, Timer
from pygame.locals import VIDEORESIZE
from rltorch.memory import ReplayMemory
class Play:
def __init__(self, env, action_selector, memory, agent, sneaky_env, transpose = True, fps = 30, zoom = None, keys_to_action = None):
def __init__(self, env, action_selector, memory, memory_lock, agent, sneaky_env, config):
self.env = env
self.action_selector = action_selector
self.transpose = transpose
self.fps = fps
self.zoom = zoom
self.keys_to_action = None
self.memory = memory
self.memory_lock = memory_lock
self.agent = agent
self.sneaky_env = sneaky_env
# Get relevant parameters from config or set sane defaults
self.transpose = config['transpose'] if 'transpose' in config else True
self.fps = config['fps'] if 'fps' in config else 30
self.zoom = config['zoom'] if 'zoom' in config else 1
self.keys_to_action = config['keys_to_action'] if 'keys_to_action' in config else None
self.seconds_play_per_state = config['seconds_play_per_state'] if 'seconds_play_per_state' in config else 30
self.num_sneaky_episodes = config['num_sneaky_episodes'] if 'num_sneaky_episodes' in config else 10
self.memory_size = config['memory_size'] if 'memory_size' in config else 10**4
self.replay_skip = config['replay_skip'] if 'replay_skip' in config else 0
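# Aside: the same "read from config or fall back to a sane default" lookups could
# also be written with dict.get, e.g. self.fps = config.get('fps', 30); the explicit
# conditional form above is what this commit actually uses.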
# Initial values...
self.video_size = (0, 0)
self.pressed_keys = []
self.screen = None
self.relevant_keys = set()
self.running = True
self.switch = Event()
self.state = 0
self.paused = False
self.memory = memory
self.agent = agent
self.sneaky_env = sneaky_env
self.clock = pygame.time.Clock()
def _display_arr(self, obs, screen, arr, video_size):
if obs is not None:
@@ -49,6 +41,21 @@ class Play:
pyg_img = pygame.transform.scale(pyg_img, video_size)
screen.blit(pyg_img, (0,0))
def _process_common_pygame_events(self, event):
if event.type == pygame.QUIT:
self.running = False
elif event.type == VIDEORESIZE:
self.video_size = event.size
self.screen = pygame.display.set_mode(self.video_size)
elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
self.running = False
else:
# No event was matched here
return False
# One of the events above matched
return True
def _human_play(self, obs):
action = self.keys_to_action.get(tuple(sorted(self.pressed_keys)), 0)
prev_obs = obs
@@ -57,20 +64,14 @@ class Play:
# process pygame events
for event in pygame.event.get():
# test events, set key states
if event.type == pygame.KEYDOWN:
if self._process_common_pygame_events(event):
continue
elif event.type == pygame.KEYDOWN:
if event.key in self.relevant_keys:
self.pressed_keys.append(event.key)
elif event.key == pygame.K_ESCAPE:
self.running = False
elif event.type == pygame.KEYUP:
if event.key in self.relevant_keys:
self.pressed_keys.remove(event.key)
elif event.type == pygame.QUIT:
self.running = False
elif event.type == VIDEORESIZE:
self.video_size = event.size
self.screen = pygame.display.set_mode(self.video_size)
pygame.display.flip()
self.clock.tick(self.fps)
@@ -84,13 +85,7 @@ class Play:
# process pygame events
for event in pygame.event.get():
if event.type == pygame.QUIT:
self.running = False
elif event.type == VIDEORESIZE:
self.video_size = event.size
self.screen = pygame.display.set_mode(self.video_size)
elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
self.running = False
self._process_common_pygame_events(event)
pygame.display.flip()
self.clock.tick(self.fps)
@@ -107,7 +102,7 @@ class Play:
self.video_size = video_size
self.screen = pygame.display.set_mode(self.video_size)
pygame.font.init() # For later text
pygame.font.init()
def _setup_keys(self):
if self.keys_to_action is None:
@@ -124,48 +119,59 @@ class Play:
self.state = (self.state + 1) % 5
def pause(self, text = ""):
self.paused = True
myfont = pygame.font.SysFont('Comic Sans MS', 50)
textsurface = myfont.render(text, False, (0, 0, 0))
self.screen.blit(textsurface,(0,0))
# Process pygame events
for event in pygame.event.get():
if event.type == pygame.QUIT:
self.running = False
elif event.type == VIDEORESIZE:
self.video_size = event.size
self.screen = pygame.display.set_mode(self.video_size)
if self._process_common_pygame_events(event):
continue
elif event.type == pygame.KEYDOWN:
if event.key == pygame.K_SPACE:
self.pressed_keys.append(event.key)
elif event.key == pygame.K_ESCAPE:
self.running = False
elif event.type == pygame.KEYUP and event.key == pygame.K_SPACE:
self.pressed_keys.remove(event.key)
self._increment_state()
self.paused = False
pygame.display.flip()
self.clock.tick(self.fps)
def sneaky_train(self):
self.memory_lock.acquire()
# Backup memory
backup_memory = self.memory
self.memory = M.ReplayMemory(capacity = 2000) # Another configurable parameter
EPISODES = 30 # Make this configurable
replay_skip = 4 # Make this configurable
for _ in range(EPISODES):
self.memory = ReplayMemory(capacity = self.memory_size)
# Run a standard RL training loop for a certain number of episodes
for i in range(self.num_sneaky_episodes):
print("Episode: %d / %d, Reward: " % (i + 1, self.num_sneaky_episodes), end = "")
# Reset all episode-related variables
prev_obs = self.sneaky_env.reset()
done = False
step = 0
total_reward = 0
while not done:
action = self.action_selector.act(prev_obs)
obs, reward, done, _ = self.sneaky_env.step(action)
total_reward += reward
self.memory.append(prev_obs, action, reward, obs, done)
prev_obs = obs
step += 1
if step % replay_skip == 0:
if step % self.replay_skip == 0:
self.agent.learn()
# Finish the previous print with the total reward obtained during the episode
print(total_reward)
# Reset the memory back to the human demonstration / shown computer data
self.memory = backup_memory
self.memory_lock.release()
# Thoughts:
# It would be cool if, instead of throwing away all this new data, we kept just a sample of it
# Not sure if I want all of it because then it'll drown out the expert demonstration data
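# One hypothetical sketch of that idea (illustrative only, not part of this commit):
# before restoring backup_memory above, copy a random subset of the sneaky
# transitions into it so a little of the computer's own experience survives, e.g.
#     import random
#     keep = random.sample(list(self.memory), k=len(self.memory) // 10)
#     for transition in keep:
#         backup_memory.append(*transition)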
@@ -178,55 +184,14 @@ class Play:
Above code works also if env is wrapped, so it's particularly useful in
verifying that the frame-level preprocessing does not render the game
unplayable.
If you wish to plot real-time statistics as you play, you can use
gym.utils.play.PlayPlot. Here's a sample code for plotting the reward
for the last 5 seconds of gameplay.
def callback(obs_t, obs_tp1, rew, done, info):
return [rew,]
env_plotter = EnvPlotter(callback, 30 * 5, ["reward"])
env = gym.make("Pong-v3")
play(env, callback=env_plotter.callback)
Arguments
---------
env: gym.Env
Environment to use for playing.
transpose: bool
If True the output of observation is transposed.
Defaults to true.
fps: int
Maximum number of steps of the environment to execute every second.
Defaults to 30.
zoom: float
Make screen edge this many times bigger
callback: lambda or None
Callback if a callback is provided it will be executed after
every step. It takes the following input:
obs_t: observation before performing action
obs_tp1: observation after performing action
action: action that was executed
rew: reward that was received
done: whether the environment is done or not
info: debug info
keys_to_action: dict: tuple(int) -> int or None
Mapping from keys pressed to action performed.
For example, if pressing 'w' and space at the same time is supposed
to trigger action number 2, then the keys_to_action dict would look like this:
{
# ...
sorted(ord('w'), ord(' ')) -> 2
# ...
}
If None, default key_to_action mapping for that env is used, if provided.
"""
obs_s = self.env.unwrapped.observation_space
assert type(obs_s) == gym.spaces.box.Box
assert type(obs_s) == Box
assert len(obs_s.shape) == 2 or (len(obs_s.shape) == 3 and obs_s.shape[2] in [1,3])
self._setup_keys()
self._setup_video()
self.clock = pygame.time.Clock()
# States
HUMAN_PLAY = 0
SNEAKY_COMPUTER_PLAY = 1
@@ -234,41 +199,61 @@ class Play:
COMPUTER_PLAY = 3
TRANSITION2 = 4
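# The session cycles through these five states in order: the human plays, the
# computer trains hidden from view, a transition screen announces the computer's
# turn, the computer plays on screen, and a second transition hands control back
# to the human; _increment_state() advances the cycle modulo 5.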
env_done = True
prev_obs = None
obs = None
reward = 0
i = 0
while self.running:
# If the environment is done after a turn, reset it so we can keep going
if env_done:
obs = self.env.reset()
env_done = False
if self.state is HUMAN_PLAY:
prev_obs, action, reward, obs, env_done = self._human_play(obs)
_, _, _, obs, env_done = self._human_play(obs)
# The computer will train for a few episodes without showing them to the user,
# mainly to speed up the learning process a bit
elif self.state is SNEAKY_COMPUTER_PLAY:
print("Sneaky Computer Time")
# Display "Training..." text to user
myfont = pygame.font.SysFont('Comic Sans MS', 50)
textsurface = myfont.render("Training....", False, (0, 0, 0))
self.screen.blit(textsurface,(0,0))
pygame.display.flip()
# Have the agent play a few rounds without showing them to the user
self.sneaky_train()
# Redraw the frame to remove the training text
self._display_arr(obs, self.screen, self.env.unwrapped._get_obs(), video_size=self.video_size)
pygame.display.flip()
# Go to the next state immediately
self._increment_state()
elif self.state is TRANSITION:
self.pause("Computers Turn! Press <Space> to Start")
elif self.state is COMPUTER_PLAY:
prev_obs, action, reward, obs, env_done = self._computer_play(obs)
_, _, _, obs, env_done = self._computer_play(obs)
elif self.state is TRANSITION2:
self.pause("Your Turn! Press <Space> to Start")
# Store the transition and increment the timer if it's the human's or the shown computer's turn
if self.state is COMPUTER_PLAY or self.state is HUMAN_PLAY:
self.memory.append(prev_obs, action, reward, obs, env_done)
i += 1
# Every 30 seconds...
if i % (self.fps * 30) == 0:
print("Training...")
# Perform a quick learning process and increment the state after a certain time period has passed
if i % (self.fps * self.seconds_play_per_state) == 0:
self.memory_lock.acquire()
print("Number of transitions in buffer: ", len(self.memory))
self.agent.learn()
print("PAUSING...")
self.memory_lock.release()
self._increment_state()
i = 0
# Stop the pygame environment when done
pygame.quit()


@@ -1,21 +1,31 @@
import play
import rltorch
import rltorch.memory as M
import torch
import gym
# Import Python Standard Libraries
from threading import Thread, Lock
from argparse import ArgumentParser
from collections import namedtuple
from datetime import datetime
# Import PyTorch-related packages for NNs
from numpy import array as np_array
from numpy import save as np_save
import torch
from torch.optim import Adam
import torch.nn as nn
import torch.nn.functional as F
# Import my custom RL library
import rltorch
from rltorch.memory import PrioritizedReplayMemory
from rltorch.action_selector import EpsilonGreedySelector
import rltorch.env as E
import rltorch.network as rn
import torch.nn as nn
import torch.nn.functional as F
import pickle
import threading
from time import sleep
import argparse
import sys
import numpy as np
# Import OpenAI gym and related packages
from gym import make as makeEnv
from gym import Wrapper as GymWrapper
from gym.wrappers import Monitor as GymMonitor
import play
#
## Networks
@@ -73,56 +83,56 @@ class Value(nn.Module):
Transition = namedtuple('Transition',
('state', 'action', 'reward', 'next_state', 'done'))
class PlayClass(threading.Thread):
def __init__(self, env, action_selector, memory, agent, sneaky_env, fps = 60):
class PlayClass(Thread):
def __init__(self, env, action_selector, memory, memory_lock, agent, sneaky_env, config):
super(PlayClass, self).__init__()
self.env = env
self.fps = fps
self.play = play.Play(self.env, action_selector, memory, agent, sneaky_env, fps = fps, zoom = 4)
self.play = play.Play(env, action_selector, memory, memory_lock, agent, sneaky_env, config)
def run(self):
self.play.start()
class Record(gym.Wrapper):
def __init__(self, env, memory, args, skipframes = 3):
gym.Wrapper.__init__(self, env)
self.memory_lock = threading.Lock()
class Record(GymWrapper):
def __init__(self, env, memory, memory_lock, args):
GymWrapper.__init__(self, env)
self.memory_lock = memory_lock
self.memory = memory
self.args = args
self.skipframes = skipframes
self.current_i = skipframes
self.skipframes = args['skip']
self.environment_name = args['environment_name']
self.logdir = args['logdir']
self.current_i = 0
def reset(self):
return self.env.reset()
def step(self, action):
self.memory_lock.acquire()
state = self.env.env._get_obs()
next_state, reward, done, info = self.env.step(action)
if self.current_i <= 0:
self.memory.append(Transition(state, action, reward, next_state, done))
self.current_i = self.skipframes
else: self.current_i -= 1
self.memory_lock.release()
self.current_i += 1
# Don't add to memory until a certain number of frames is reached
if self.current_i % self.skipframes == 0:
self.memory_lock.acquire()
self.memory.append(state, action, reward, next_state, done)
self.memory_lock.release()
self.current_i = 0
return next_state, reward, done, info
def log_transitions(self):
self.memory_lock.acquire()
if len(self.memory) > 0:
basename = self.args['logdir'] + "/{}.{}".format(self.args['environment_name'], datetime.now().strftime("%Y-%m-%d-%H-%M-%s"))
basename = self.logdir + "/{}.{}".format(self.environment_name, datetime.now().strftime("%Y-%m-%d-%H-%M-%s"))
print("Base Filename: ", basename)
state, action, reward, next_state, done = zip(*self.memory)
np.save(basename + "-state.npy", np.array(state), allow_pickle = False)
np.save(basename + "-action.npy", np.array(action), allow_pickle = False)
np.save(basename + "-reward.npy", np.array(reward), allow_pickle = False)
np.save(basename + "-nextstate.npy", np.array(next_state), allow_pickle = False)
np.save(basename + "-done.npy", np.array(done), allow_pickle = False)
np_save(basename + "-state.npy", np_array(state), allow_pickle = False)
np_save(basename + "-action.npy", np_array(action), allow_pickle = False)
np_save(basename + "-reward.npy", np_array(reward), allow_pickle = False)
np_save(basename + "-nextstate.npy", np_array(next_state), allow_pickle = False)
np_save(basename + "-done.npy", np_array(done), allow_pickle = False)
self.memory.clear()
self.memory_lock.release()
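# For reference, the .npy files written above can be read back with numpy's load,
# e.g. (hypothetical snippet; substitute the basename printed during logging):
#     from numpy import load as np_load
#     states = np_load(basename + "-state.npy")
#     actions = np_load(basename + "-action.npy")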
## Parsing arguments
parser = argparse.ArgumentParser(description="Play and log the environment")
parser = ArgumentParser(description="Play and log the environment")
parser.add_argument("--environment_name", type=str, help="The environment name in OpenAI gym to play.")
parser.add_argument("--logdir", type=str, help="Directory to log video and (state, action, reward, next_state, done) in.")
parser.add_argument("--skip", type=int, help="Number of frames to skip logging.")
@@ -130,14 +140,20 @@ parser.add_argument("--fps", type=int, help="Number of frames per second")
parser.add_argument("--model", type=str, help = "The path location of the PyTorch model")
args = vars(parser.parse_args())
## Main configuration for script
config = {}
config['seed'] = 901
config['seconds_play_per_state'] = 60
config['zoom'] = 4
config['environment_name'] = 'PongNoFrameskip-v4'
config['learning_rate'] = 1e-4
config['target_sync_tau'] = 1e-3
config['discount_rate'] = 0.99
config['exploration_rate'] = rltorch.scheduler.ExponentialScheduler(initial_value = 1, end_value = 0.1, iterations = 10**5)
config['batch_size'] = 480
# Number of episodes for the computer to train the agent without the human seeing
config['num_sneaky_episodes'] = 20
config['replay_skip'] = 14
config['batch_size'] = 32 * (config['replay_skip'] + 1)
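# With replay_skip = 14 this works out to 32 * 15 = 480, the same value the
# batch size was previously hard-coded to.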
config['disable_cuda'] = False
config['memory_size'] = 10**4
# Prioritized vs Random Sampling
@@ -151,63 +167,73 @@ config['prioritized_replay_sampling_priority'] = 0.6
config['prioritized_replay_weight_importance'] = rltorch.scheduler.ExponentialScheduler(initial_value = 0.4, end_value = 1, iterations = 10**5)
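# In prioritized experience replay, the sampling priority (alpha, 0.6 here) controls how
# strongly high-error transitions are favored when sampling, while the weight importance
# (beta) anneals from 0.4 toward 1 so the importance-sampling correction becomes exact
# over training. (Standard PER behavior; assumed to match rltorch's implementation.)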
# Environment name and log directory are vital, so show the help message and exit if they're not provided
if args['environment_name'] is None or args['logdir'] is None:
parser.print_help()
sys.exit(1)
exit(1)
# Number of frames to skip when recording and FPS can fall back to sane defaults
if args['skip'] is None:
args['skip'] = 3
if args['fps'] is None:
args['fps'] = 30
def wrap_preprocessing(env):
def wrap_preprocessing(env, MaxAndSkipEnv = False):
env = E.NoopResetEnv(
E.EpisodicLifeEnv(env),
noop_max = 30
)
if MaxAndSkipEnv:
env = E.MaxAndSkipEnv(env, skip = 4)
return E.ClippedRewardsWrapper(
E.FrameStack(
E.TorchWrap(
E.ProcessFrame84(
E.FireResetEnv(
# E.MaxAndSkipEnv(
E.NoopResetEnv(
E.EpisodicLifeEnv(env)
, noop_max = 30)
# , skip=4)
)
E.FireResetEnv(env)
)
),
4)
)
, 4)
)
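# With these wrappers the agent presumably sees clipped rewards and observations of
# 4 stacked 84x84 preprocessed frames, assuming rltorch's ProcessFrame84 and FrameStack
# follow the usual Atari preprocessing conventions.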
## Starting the game
memory = []
env = Record(gym.make(args['environment_name']), memory, args, skipframes = args['skip'])
## Set up environment to be recorded and preprocessed
memory = PrioritizedReplayMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'])
memory_lock = Lock()
env = Record(makeEnv(args['environment_name']), memory, memory_lock, args)
# Bind record_env to the current env so that we can reference log_transitions more easily later
record_env = env
env = gym.wrappers.Monitor(env, args['logdir'], force=True)
# Use native gym monitor to get video recording
env = GymMonitor(env, args['logdir'], force=True)
# Preprocess environment
env = wrap_preprocessing(env)
sneaky_env = wrap_preprocessing(gym.make(args['environment_name']))
# Use a different environment for when the computer trains on the side so that the current game state isn't manipulated
# Also use MaxAndSkipEnv to speed up processing
sneaky_env = wrap_preprocessing(makeEnv(args['environment_name']), MaxAndSkipEnv = True)
# Set seeds
rltorch.set_seed(config['seed'])
env.seed(config['seed'])
device = torch.device("cuda:0" if torch.cuda.is_available() and not config['disable_cuda'] else "cpu")
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
# Set up the networks
net = rn.Network(Value(state_size, action_size),
torch.optim.Adam, config, device = device)
Adam, config, device = device)
target_net = rn.TargetNetwork(net, device = device)
# Relevant components from RLTorch
actor = EpsilonGreedySelector(net, action_size, device = device, epsilon = config['exploration_rate'])
memory = M.PrioritizedReplayMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'])
agent = rltorch.agents.DQNAgent(net, memory, config, target_net = target_net)
env.seed(config['seed'])
playThread = PlayClass(env, actor, memory, agent, sneaky_env, fps = args['fps'])
# Pass all this information into the thread that will handle the gameplay, and start it
playThread = PlayClass(env, actor, memory, memory_lock, agent, sneaky_env, config)
playThread.start()
## Logging portion
# While the play thread is running, we'll periodically log transitions we've encountered
while playThread.is_alive():
playThread.join(60)
print("Logging....", end = " ")