From d78892e62c358512ea52aeea83d5b44d78bee64b Mon Sep 17 00:00:00 2001
From: Brandon Rozek
Date: Wed, 23 Oct 2019 21:53:20 -0400
Subject: [PATCH] SneakyTrain uses separate replay buffer

Scripts were cleaned up considerably and comments were added
---
 play.py     | 209 ++++++++++++++++++++++++----------------------
 play_env.py | 152 ++++++++++++++++++++++----------------
 2 files changed, 186 insertions(+), 175 deletions(-)

diff --git a/play.py b/play.py
index a26e254..48b7d52 100644
--- a/play.py
+++ b/play.py
@@ -1,41 +1,33 @@
-import gym
+from gym.spaces.box import Box
 import pygame
-import sys
-import time
-import matplotlib
-import rltorch.memory as M
-try:
-    matplotlib.use('GTK3Agg')
-    import matplotlib.pyplot as plt
-except Exception:
-    pass
-
-
-import pyglet.window as pw
-
-from collections import deque
-from pygame.locals import HWSURFACE, DOUBLEBUF, RESIZABLE, VIDEORESIZE
-from threading import Thread, Event, Timer
+from pygame.locals import VIDEORESIZE
+from rltorch.memory import ReplayMemory
 
 class Play:
-    def __init__(self, env, action_selector, memory, agent, sneaky_env, transpose = True, fps = 30, zoom = None, keys_to_action = None):
+    def __init__(self, env, action_selector, memory, memory_lock, agent, sneaky_env, config):
         self.env = env
         self.action_selector = action_selector
-        self.transpose = transpose
-        self.fps = fps
-        self.zoom = zoom
-        self.keys_to_action = None
+        self.memory = memory
+        self.memory_lock = memory_lock
+        self.agent = agent
+        self.sneaky_env = sneaky_env
+        # Get relevant parameters from config or set sane defaults
+        self.transpose = config['transpose'] if 'transpose' in config else True
+        self.fps = config['fps'] if 'fps' in config else 30
+        self.zoom = config['zoom'] if 'zoom' in config else 1
+        self.keys_to_action = config['keys_to_action'] if 'keys_to_action' in config else None
+        self.seconds_play_per_state = config['seconds_play_per_state'] if 'seconds_play_per_state' in config else 30
+        self.num_sneaky_episodes = config['num_sneaky_episodes'] if 'num_sneaky_episodes' in config else 10
+        self.memory_size = config['memory_size'] if 'memory_size' in config else 10**4
+        self.replay_skip = config['replay_skip'] if 'replay_skip' in config else 0
+        # Initial values...
         self.video_size = (0, 0)
         self.pressed_keys = []
         self.screen = None
         self.relevant_keys = set()
         self.running = True
-        self.switch = Event()
         self.state = 0
-        self.paused = False
-        self.memory = memory
-        self.agent = agent
-        self.sneaky_env = sneaky_env
+        self.clock = pygame.time.Clock()
 
     def _display_arr(self, obs, screen, arr, video_size):
         if obs is not None:
@@ -48,6 +40,21 @@ class Play:
         pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if self.transpose else arr)
         pyg_img = pygame.transform.scale(pyg_img, video_size)
         screen.blit(pyg_img, (0,0))
+
+    def _process_common_pygame_events(self, event):
+        if event.type == pygame.QUIT:
+            self.running = False
+        elif event.type == VIDEORESIZE:
+            self.video_size = event.size
+            self.screen = pygame.display.set_mode(self.video_size)
+        elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
+            self.running = False
+        else:
+            # No event was matched here
+            return False
+        # One of the events above matched
+        return True
+
     def _human_play(self, obs):
         action = self.keys_to_action.get(tuple(sorted(self.pressed_keys)), 0)
@@ -57,20 +64,14 @@ class Play:
         # process pygame events
         for event in pygame.event.get():
-            # test events, set key states
-            if event.type == pygame.KEYDOWN:
+            if self._process_common_pygame_events(event):
+                continue
+            elif event.type == pygame.KEYDOWN:
                 if event.key in self.relevant_keys:
                     self.pressed_keys.append(event.key)
-                elif event.key == pygame.K_ESCAPE:
-                    self.running = False
             elif event.type == pygame.KEYUP:
                 if event.key in self.relevant_keys:
                     self.pressed_keys.remove(event.key)
-            elif event.type == pygame.QUIT:
-                self.running = False
-            elif event.type == VIDEORESIZE:
-                self.video_size = event.size
-                self.screen = pygame.display.set_mode(self.video_size)
 
         pygame.display.flip()
         self.clock.tick(self.fps)
@@ -84,13 +85,7 @@ class Play:
         # process pygame events
         for event in pygame.event.get():
-            if event.type == pygame.QUIT:
-                self.running = False
-            elif event.type == VIDEORESIZE:
-                self.video_size = event.size
-                self.screen = pygame.display.set_mode(self.video_size)
-            elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
-                self.running = False
+            self._process_common_pygame_events(event)
 
         pygame.display.flip()
         self.clock.tick(self.fps)
@@ -107,7 +102,7 @@ class Play:
         self.video_size = video_size
         self.screen = pygame.display.set_mode(self.video_size)
 
-        pygame.font.init() # For later text
+        pygame.font.init()
 
     def _setup_keys(self):
         if self.keys_to_action is None:
@@ -124,48 +119,59 @@ class Play:
         self.state = (self.state + 1) % 5
 
     def pause(self, text = ""):
-        self.paused = True
         myfont = pygame.font.SysFont('Comic Sans MS', 50)
         textsurface = myfont.render(text, False, (0, 0, 0))
         self.screen.blit(textsurface,(0,0))
+
+        # Process pygame events
         for event in pygame.event.get():
-            if event.type == pygame.QUIT:
-                self.running = False
-            elif event.type == VIDEORESIZE:
-                self.video_size = event.size
-                self.screen = pygame.display.set_mode(self.video_size)
+            if self._process_common_pygame_events(event):
+                continue
             elif event.type == pygame.KEYDOWN:
                 if event.key == pygame.K_SPACE:
                     self.pressed_keys.append(event.key)
-                elif event.key == pygame.K_ESCAPE:
-                    self.running = False
            elif event.type == pygame.KEYUP and event.key == pygame.K_SPACE:
                 self.pressed_keys.remove(event.key)
                 self._increment_state()
-        self.paused = False
+
         pygame.display.flip()
         self.clock.tick(self.fps)
-
     def sneaky_train(self):
+        self.memory_lock.acquire()
+        # Backup memory
         backup_memory = self.memory
-        self.memory = M.ReplayMemory(capacity = 2000) # Another configurable parameter
-        EPISODES = 30 # Make this configurable
-        replay_skip = 4 # Make this configurable
-        for _ in range(EPISODES):
+        self.memory = ReplayMemory(capacity = self.memory_size)
+
+        # Do a standard RL algorithm process for a certain number of episodes
+        for i in range(self.num_sneaky_episodes):
+            print("Episode: %d / %d, Reward: " % (i + 1, self.num_sneaky_episodes), end = "")
+
+            # Reset all episode related variables
             prev_obs = self.sneaky_env.reset()
             done = False
             step = 0
+            total_reward = 0
+
             while not done:
                 action = self.action_selector.act(prev_obs)
                 obs, reward, done, _ = self.sneaky_env.step(action)
+                total_reward += reward
                 self.memory.append(prev_obs, action, reward, obs, done)
                 prev_obs = obs
                 step += 1
-                if step % replay_skip == 0:
+                if step % self.replay_skip == 0:
                     self.agent.learn()
+
+            # Finish the previous print with the total reward obtained during the episode
+            print(total_reward)
+
+        # Reset the memory back to the human demonstration / shown computer data
         self.memory = backup_memory
+        self.memory_lock.release()
+
+
    # Thoughts:
    # It would be cool if, instead of throwing away all this new data, we keep just a sample of it
    # Not sure if I want all of it because then it'll drown out the expert demonstration data
@@ -178,55 +184,14 @@ class Play:
         Above code works also if env is wrapped, so it's particularly useful
         in verifying that the frame-level preprocessing does not render the game
         unplayable.
 
-        If you wish to plot real time statistics as you play, you can use
-        gym.utils.play.PlayPlot. Here's a sample code for plotting the reward
-        for last 5 second of gameplay.
-            def callback(obs_t, obs_tp1, rew, done, info):
-                return [rew,]
-            env_plotter = EnvPlotter(callback, 30 * 5, ["reward"])
-            env = gym.make("Pong-v3")
-            play(env, callback=env_plotter.callback)
-        Arguments
-        ---------
-        env: gym.Env
-            Environment to use for playing.
-        transpose: bool
-            If True the output of observation is transposed.
-            Defaults to true.
-        fps: int
-            Maximum number of steps of the environment to execute every second.
-            Defaults to 30.
-        zoom: float
-            Make screen edge this many times bigger
-        callback: lambda or None
-            Callback if a callback is provided it will be executed after
-            every step. It takes the following input:
-                obs_t: observation before performing action
-                obs_tp1: observation after performing action
-                action: action that was executed
-                rew: reward that was received
-                done: whether the environment is done or not
-                info: debug info
-        keys_to_action: dict: tuple(int) -> int or None
-            Mapping from keys pressed to action performed.
-            For example if pressed 'w' and space at the same time is supposed
-            to trigger action number 2 then key_to_action dict would look like this:
-                {
-                    # ...
-                    sorted(ord('w'), ord(' ')) -> 2
-                    # ...
-                }
-            If None, default key_to_action mapping for that env is used, if provided.
""" obs_s = self.env.unwrapped.observation_space - assert type(obs_s) == gym.spaces.box.Box + assert type(obs_s) == Box assert len(obs_s.shape) == 2 or (len(obs_s.shape) == 3 and obs_s.shape[2] in [1,3]) self._setup_keys() self._setup_video() - self.clock = pygame.time.Clock() - # States HUMAN_PLAY = 0 SNEAKY_COMPUTER_PLAY = 1 @@ -234,41 +199,61 @@ class Play: COMPUTER_PLAY = 3 TRANSITION2 = 4 - env_done = True - prev_obs = None obs = None - reward = 0 i = 0 while self.running: + # If the environment is done after a turn, reset it so we can keep going if env_done: obs = self.env.reset() env_done = False + + if self.state is HUMAN_PLAY: - prev_obs, action, reward, obs, env_done = self._human_play(obs) + _, _, _, obs, env_done = self._human_play(obs) + + # The computer will train for a few episodes without showing to the user. + # Mainly to speed up the learning process a bit elif self.state is SNEAKY_COMPUTER_PLAY: + print("Sneaky Computer Time") + + # Display "Training..." text to user myfont = pygame.font.SysFont('Comic Sans MS', 50) textsurface = myfont.render("Training....", False, (0, 0, 0)) self.screen.blit(textsurface,(0,0)) + pygame.display.flip() + + # Have the agent play a few rounds without showing to the user self.sneaky_train() + + # To take away training text + self._display_arr(obs, self.screen, self.env.unwrapped._get_obs(), video_size=self.video_size) + pygame.display.flip() + + # Go to the next step immediately self._increment_state() + elif self.state is TRANSITION: self.pause("Computers Turn! Press to Start") + elif self.state is COMPUTER_PLAY: - prev_obs, action, reward, obs, env_done = self._computer_play(obs) + _, _, _, obs, env_done = self._computer_play(obs) + elif self.state is TRANSITION2: self.pause("Your Turn! Press to Start") + # Increment the timer if it's the human or shown computer's turn if self.state is COMPUTER_PLAY or self.state is HUMAN_PLAY: - self.memory.append(prev_obs, action, reward, obs, env_done) i += 1 - # Every 30 seconds... 
-                if i % (self.fps * 30) == 0:
-                    print("Training...")
+                # Perform a quick learning step and increment the state after the allotted play time has passed
+                if i % (self.fps * self.seconds_play_per_state) == 0:
+                    self.memory_lock.acquire()
+                    print("Number of transitions in buffer: ", len(self.memory))
                     self.agent.learn()
-                    print("PAUSING...")
+                    self.memory_lock.release()
                     self._increment_state()
                     i = 0
-
+
+        # Stop the pygame environment when done
         pygame.quit()
diff --git a/play_env.py b/play_env.py
index 7e53911..815115f 100644
--- a/play_env.py
+++ b/play_env.py
@@ -1,21 +1,31 @@
-import play
-import rltorch
-import rltorch.memory as M
-import torch
-import gym
+
+# Import Python Standard Libraries
+from threading import Thread, Lock
+from argparse import ArgumentParser
 from collections import namedtuple
 from datetime import datetime
+
+# Import PyTorch and NumPy related packages
+from numpy import array as np_array
+from numpy import save as np_save
+import torch
+from torch.optim import Adam
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Import my custom RL library
+import rltorch
+from rltorch.memory import PrioritizedReplayMemory
 from rltorch.action_selector import EpsilonGreedySelector
 import rltorch.env as E
 import rltorch.network as rn
-import torch.nn as nn
-import torch.nn.functional as F
-import pickle
-import threading
-from time import sleep
-import argparse
-import sys
-import numpy as np
+
+# Import OpenAI gym and related packages
+from gym import make as makeEnv
+from gym import Wrapper as GymWrapper
+from gym.wrappers import Monitor as GymMonitor
+import play
+
 
 #
 ## Networks
@@ -73,56 +83,56 @@ class Value(nn.Module):
 
 Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))
 
-class PlayClass(threading.Thread):
-    def __init__(self, env, action_selector, memory, agent, sneaky_env, fps = 60):
+class PlayClass(Thread):
+    def __init__(self, env, action_selector, memory, memory_lock, agent, sneaky_env, config):
         super(PlayClass, self).__init__()
-        self.env = env
-        self.fps = fps
-        self.play = play.Play(self.env, action_selector, memory, agent, sneaky_env, fps = fps, zoom = 4)
+        self.play = play.Play(env, action_selector, memory, memory_lock, agent, sneaky_env, config)
 
     def run(self):
         self.play.start()
 
-class Record(gym.Wrapper):
-    def __init__(self, env, memory, args, skipframes = 3):
-        gym.Wrapper.__init__(self, env)
-        self.memory_lock = threading.Lock()
+class Record(GymWrapper):
+    def __init__(self, env, memory, memory_lock, args):
+        GymWrapper.__init__(self, env)
+        self.memory_lock = memory_lock
         self.memory = memory
-        self.args = args
-        self.skipframes = skipframes
-        self.current_i = skipframes
+        self.skipframes = args['skip']
+        self.environment_name = args['environment_name']
+        self.logdir = args['logdir']
+        self.current_i = 0
 
     def reset(self):
         return self.env.reset()
 
     def step(self, action):
-        self.memory_lock.acquire()
         state = self.env.env._get_obs()
         next_state, reward, done, info = self.env.step(action)
-        if self.current_i <= 0:
-            self.memory.append(Transition(state, action, reward, next_state, done))
-            self.current_i = self.skipframes
-        else: self.current_i -= 1
-        self.memory_lock.release()
+        self.current_i += 1
+        # Only add a transition to memory once every `skipframes` frames
+        if self.current_i % self.skipframes == 0:
+            self.memory_lock.acquire()
+            self.memory.append(state, action, reward, next_state, done)
+            self.memory_lock.release()
+            self.current_i = 0
         return next_state, reward, done, info
 
     def log_transitions(self):
         self.memory_lock.acquire()
         if len(self.memory) > 0:
-            basename = self.args['logdir'] + "/{}.{}".format(self.args['environment_name'], datetime.now().strftime("%Y-%m-%d-%H-%M-%s"))
+            basename = self.logdir + "/{}.{}".format(self.environment_name, datetime.now().strftime("%Y-%m-%d-%H-%M-%s"))
             print("Base Filename: ", basename)
             state, action, reward, next_state, done = zip(*self.memory)
-            np.save(basename + "-state.npy", np.array(state), allow_pickle = False)
-            np.save(basename + "-action.npy", np.array(action), allow_pickle = False)
-            np.save(basename + "-reward.npy", np.array(reward), allow_pickle = False)
-            np.save(basename + "-nextstate.npy", np.array(next_state), allow_pickle = False)
-            np.save(basename + "-done.npy", np.array(done), allow_pickle = False)
+            np_save(basename + "-state.npy", np_array(state), allow_pickle = False)
+            np_save(basename + "-action.npy", np_array(action), allow_pickle = False)
+            np_save(basename + "-reward.npy", np_array(reward), allow_pickle = False)
+            np_save(basename + "-nextstate.npy", np_array(next_state), allow_pickle = False)
+            np_save(basename + "-done.npy", np_array(done), allow_pickle = False)
             self.memory.clear()
         self.memory_lock.release()
 
 ## Parsing arguments
-parser = argparse.ArgumentParser(description="Play and log the environment")
+parser = ArgumentParser(description="Play and log the environment")
 parser.add_argument("--environment_name", type=str, help="The environment name in OpenAI gym to play.")
 parser.add_argument("--logdir", type=str, help="Directory to log video and (state, action, reward, next_state, done) in.")
 parser.add_argument("--skip", type=int, help="Number of frames to skip logging.")
@@ -130,14 +140,20 @@ parser.add_argument("--fps", type=int, help="Number of frames per second")
 parser.add_argument("--model", type=str, help = "The path location of the PyTorch model")
 args = vars(parser.parse_args())
 
+## Main configuration for the script
 config = {}
 config['seed'] = 901
+config['seconds_play_per_state'] = 60
+config['zoom'] = 4
 config['environment_name'] = 'PongNoFrameskip-v4'
 config['learning_rate'] = 1e-4
 config['target_sync_tau'] = 1e-3
 config['discount_rate'] = 0.99
 config['exploration_rate'] = rltorch.scheduler.ExponentialScheduler(initial_value = 1, end_value = 0.1, iterations = 10**5)
-config['batch_size'] = 480
+# Number of episodes for the computer to train the agent while the human isn't watching
+config['num_sneaky_episodes'] = 20
+config['replay_skip'] = 14
+config['batch_size'] = 32 * (config['replay_skip'] + 1)
 config['disable_cuda'] = False
 config['memory_size'] = 10**4
 # Prioritized vs Random Sampling
@@ -151,63 +167,73 @@ config['prioritized_replay_sampling_priority'] = 0.6
 config['prioritized_replay_weight_importance'] = rltorch.scheduler.ExponentialScheduler(initial_value = 0.4, end_value = 1, iterations = 10**5)
 
-
+# The environment name and log directory are required, so show the help message and exit if either is missing
 if args['environment_name'] is None or args['logdir'] is None:
     parser.print_help()
-    sys.exit(1)
+    exit(1)
 
+# The number of frames to skip when recording and the fps can fall back to sane defaults
 if args['skip'] is None:
     args['skip'] = 3
-
 if args['fps'] is None:
     args['fps'] = 30
 
-def wrap_preprocessing(env):
+
+def wrap_preprocessing(env, MaxAndSkipEnv = False):
+    env = E.NoopResetEnv(
+        E.EpisodicLifeEnv(env),
+        noop_max = 30
+    )
+    if MaxAndSkipEnv:
+        env = E.MaxAndSkipEnv(env, skip = 4)
     return E.ClippedRewardsWrapper(
         E.FrameStack(
             E.TorchWrap(
                 E.ProcessFrame84(
-                    E.FireResetEnv(
-                        # E.MaxAndSkipEnv(
-                        E.NoopResetEnv(
-                            E.EpisodicLifeEnv(env)
-                        , noop_max = 30)
-                        # , skip=4)
-                    )
+                    E.FireResetEnv(env)
                 )
-            ),
-        4)
+            )
+        , 4)
     )
 
-## Starting the game
-memory = []
-env = Record(gym.make(args['environment_name']), memory, args, skipframes = args['skip'])
+
+## Set up environment to be recorded and preprocessed
+memory = PrioritizedReplayMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'])
+memory_lock = Lock()
+env = Record(makeEnv(args['environment_name']), memory, memory_lock, args)
+# Bind record_env to the current env so that we can reference log_transitions more easily later
 record_env = env
-env = gym.wrappers.Monitor(env, args['logdir'], force=True)
+# Use the native gym Monitor to get video recordings
+env = GymMonitor(env, args['logdir'], force=True)
+# Preprocess the environment
 env = wrap_preprocessing(env)
 
-sneaky_env = wrap_preprocessing(gym.make(args['environment_name']))
+# Use a different environment for when the computer trains on the side so that the current game state isn't manipulated
+# Also use MaxAndSkipEnv to speed up processing
+sneaky_env = wrap_preprocessing(makeEnv(args['environment_name']), MaxAndSkipEnv = True)
 
+# Set seeds
 rltorch.set_seed(config['seed'])
+env.seed(config['seed'])
 
 device = torch.device("cuda:0" if torch.cuda.is_available() and not config['disable_cuda'] else "cpu")
 state_size = env.observation_space.shape[0]
 action_size = env.action_space.n
 
+# Set up the networks
 net = rn.Network(Value(state_size, action_size),
-                torch.optim.Adam, config, device = device)
+                Adam, config, device = device)
 target_net = rn.TargetNetwork(net, device = device)
 
+# Relevant components from RLTorch
 actor = EpsilonGreedySelector(net, action_size, device = device, epsilon = config['exploration_rate'])
-memory = M.PrioritizedReplayMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'])
 agent = rltorch.agents.DQNAgent(net, memory, config, target_net = target_net)
 
-env.seed(config['seed'])
-
-playThread = PlayClass(env, actor, memory, agent, sneaky_env, fps = args['fps'])
+# Pass all this information into the thread that will handle the gameplay, then start it
+playThread = PlayClass(env, actor, memory, memory_lock, agent, sneaky_env, config)
 playThread.start()
 
-## Logging portion
+# While the play thread is running, we'll periodically log transitions we've encountered
 while playThread.is_alive():
     playThread.join(60)
     print("Logging....", end = " ")
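
Note on the "Thoughts" comment in sneaky_train() above, which asks about keeping a sample of the freshly generated data instead of discarding all of it: below is a minimal sketch of one way that could look. It assumes the replay memories can be iterated to yield (state, action, reward, next_state, done) tuples, the same property log_transitions() relies on via zip(*self.memory); merge_sample and keep_fraction are hypothetical names, not part of this patch.

    import random

    def merge_sample(sneaky_memory, main_memory, keep_fraction = 0.1):
        # Keep only a small random fraction of the sneaky transitions so they
        # don't drown out the human demonstration data already in main_memory.
        transitions = list(sneaky_memory)  # assumes iteration yields transition tuples
        sample_size = int(len(transitions) * keep_fraction)
        for state, action, reward, next_state, done in random.sample(transitions, sample_size):
            main_memory.append(state, action, reward, next_state, done)

Calling something like this right before the memory is restored to backup_memory in sneaky_train() would carry a small sample of the hidden training rounds back into the demonstration buffer while keeping the expert data dominant.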