commit 758d819ad12ad30571306d9a6f4a42f50c000c9e Author: Brandon Rozek Date: Tue Sep 3 07:16:26 2019 -0400 Initial Commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5f46325 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +playlogs/ diff --git a/play.py b/play.py new file mode 100644 index 0000000..ab64eae --- /dev/null +++ b/play.py @@ -0,0 +1,174 @@ +import gym +import pygame +import sys +import time +import matplotlib +try: + matplotlib.use('GTK3Agg') + import matplotlib.pyplot as plt +except Exception: + pass + + +import pyglet.window as pw + +from collections import deque +from pygame.locals import HWSURFACE, DOUBLEBUF, RESIZABLE, VIDEORESIZE +from threading import Thread + +def display_arr(screen, arr, video_size, transpose): + arr_min, arr_max = arr.min(), arr.max() + arr = 255.0 * (arr - arr_min) / (arr_max - arr_min) + pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr) + pyg_img = pygame.transform.scale(pyg_img, video_size) + screen.blit(pyg_img, (0,0)) + +def play(env, transpose=True, fps=30, zoom=None, callback=None, keys_to_action=None): + """Allows one to play the game using keyboard. + To simply play the game use: + play(gym.make("Pong-v3")) + Above code works also if env is wrapped, so it's particularly useful in + verifying that the frame-level preprocessing does not render the game + unplayable. + If you wish to plot real time statistics as you play, you can use + gym.utils.play.PlayPlot. Here's a sample code for plotting the reward + for last 5 second of gameplay. + def callback(obs_t, obs_tp1, rew, done, info): + return [rew,] + env_plotter = EnvPlotter(callback, 30 * 5, ["reward"]) + env = gym.make("Pong-v3") + play(env, callback=env_plotter.callback) + Arguments + --------- + env: gym.Env + Environment to use for playing. + transpose: bool + If True the output of observation is transposed. + Defaults to true. + fps: int + Maximum number of steps of the environment to execute every second. + Defaults to 30. + zoom: float + Make screen edge this many times bigger + callback: lambda or None + Callback if a callback is provided it will be executed after + every step. It takes the following input: + obs_t: observation before performing action + obs_tp1: observation after performing action + action: action that was executed + rew: reward that was received + done: whether the environment is done or not + info: debug info + keys_to_action: dict: tuple(int) -> int or None + Mapping from keys pressed to action performed. + For example if pressed 'w' and space at the same time is supposed + to trigger action number 2 then key_to_action dict would look like this: + { + # ... + sorted(ord('w'), ord(' ')) -> 2 + # ... + } + If None, default key_to_action mapping for that env is used, if provided. + """ + + obs_s = env.observation_space + assert type(obs_s) == gym.spaces.box.Box + assert len(obs_s.shape) == 2 or (len(obs_s.shape) == 3 and obs_s.shape[2] in [1,3]) + + if keys_to_action is None: + if hasattr(env, 'get_keys_to_action'): + keys_to_action = env.get_keys_to_action() + elif hasattr(env.unwrapped, 'get_keys_to_action'): + keys_to_action = env.unwrapped.get_keys_to_action() + else: + assert False, env.spec.id + " does not have explicit key to action mapping, " + \ + "please specify one manually" + relevant_keys = set(sum(map(list, keys_to_action.keys()),[])) + + if transpose: + video_size = env.observation_space.shape[1], env.observation_space.shape[0] + else: + video_size = env.observation_space.shape[0], env.observation_space.shape[1] + + if zoom is not None: + video_size = int(video_size[0] * zoom), int(video_size[1] * zoom) + + pressed_keys = [] + running = True + env_done = True + + screen = pygame.display.set_mode(video_size) + clock = pygame.time.Clock() + + + while running: + if env_done: + env_done = False + obs = env.reset() + else: + action = keys_to_action.get(tuple(sorted(pressed_keys)), 0) + prev_obs = obs + obs, rew, env_done, info = env.step(action) + if callback is not None: + callback(prev_obs, obs, action, rew, env_done, info) + if obs is not None: + if len(obs.shape) == 2: + obs = obs[:, :, None] + if obs.shape[2] == 1: + obs = obs.repeat(3, axis=2) + display_arr(screen, obs, transpose=transpose, video_size=video_size) + + # process pygame events + for event in pygame.event.get(): + # test events, set key states + if event.type == pygame.KEYDOWN: + if event.key in relevant_keys: + pressed_keys.append(event.key) + elif event.key == 27: + running = False + elif event.type == pygame.KEYUP: + if event.key in relevant_keys: + pressed_keys.remove(event.key) + elif event.type == pygame.QUIT: + running = False + elif event.type == VIDEORESIZE: + video_size = event.size + screen = pygame.display.set_mode(video_size) + print(video_size) + + pygame.display.flip() + clock.tick(fps) + pygame.quit() + +class PlayPlot(object): + def __init__(self, callback, horizon_timesteps, plot_names): + self.data_callback = callback + self.horizon_timesteps = horizon_timesteps + self.plot_names = plot_names + + num_plots = len(self.plot_names) + self.fig, self.ax = plt.subplots(num_plots) + if num_plots == 1: + self.ax = [self.ax] + for axis, name in zip(self.ax, plot_names): + axis.set_title(name) + self.t = 0 + self.cur_plot = [None for _ in range(num_plots)] + self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)] + + def callback(self, obs_t, obs_tp1, action, rew, done, info): + points = self.data_callback(obs_t, obs_tp1, action, rew, done, info) + for point, data_series in zip(points, self.data): + data_series.append(point) + self.t += 1 + + xmin, xmax = max(0, self.t - self.horizon_timesteps), self.t + + for i, plot in enumerate(self.cur_plot): + if plot is not None: + plot.remove() + self.cur_plot[i] = self.ax[i].scatter(range(xmin, xmax), list(self.data[i])) + self.ax[i].set_xlim(xmin, xmax) + plt.pause(0.000001) + + diff --git a/play_env.py b/play_env.py new file mode 100644 index 0000000..9b99d4d --- /dev/null +++ b/play_env.py @@ -0,0 +1,94 @@ +import play +import gym +from collections import namedtuple +from datetime import datetime +import pickle +import threading +from time import sleep +import argparse +import sys +import numpy as np + +Transition = namedtuple('Transition', + ('state', 'action', 'reward', 'next_state', 'done')) + +class PlayClass(threading.Thread): + def __init__(self, env, fps = 60): + super(PlayClass, self).__init__() + self.env = env + self.fps = fps + + def run(self): + play.play(self.env, fps = self.fps, zoom = 4) + +class Record(gym.Wrapper): + def __init__(self, env, memory, args, skipframes = 3): + gym.Wrapper.__init__(self, env) + self.memory_lock = threading.Lock() + self.memory = memory + self.args = args + self.skipframes = skipframes + self.current_i = skipframes + + def reset(self): + return self.env.reset() + + def step(self, action): + self.memory_lock.acquire() + state = self.env.env._get_obs() + next_state, reward, done, info = self.env.step(action) + if self.current_i <= 0: + self.memory.append(Transition(state, action, reward, next_state, done)) + self.current_i = self.skipframes + else: self.current_i -= 1 + self.memory_lock.release() + return next_state, reward, done, info + + def log_transitions(self): + self.memory_lock.acquire() + if len(self.memory) > 0: + basename = self.args['logdir'] + "/{}.{}".format(self.args['environment_name'], datetime.now().strftime("%Y-%m-%d-%H-%M-%s")) + print("Base Filename: ", basename) + state, action, reward, next_state, done = zip(*self.memory) + np.save(basename + "-state.npy", np.array(state), allow_pickle = False) + np.save(basename + "-action.npy", np.array(action), allow_pickle = False) + np.save(basename + "-reward.npy", np.array(reward), allow_pickle = False) + np.save(basename + "-nextstate.npy", np.array(next_state), allow_pickle = False) + np.save(basename + "-done.npy", np.array(done), allow_pickle = False) + self.memory.clear() + self.memory_lock.release() + + +## Parsing arguments +parser = argparse.ArgumentParser(description="Play and log the environment") +parser.add_argument("--environment_name", type=str, help="The environment name in OpenAI gym to play.") +parser.add_argument("--logdir", type=str, help="Directory to log video and (state, action, reward, next_state, done) in.") +parser.add_argument("--skip", type=int, help="Number of frames to skip logging.") +parser.add_argument("--fps", type=int, help="Number of frames per second") +args = vars(parser.parse_args()) + +if args['environment_name'] is None or args['logdir'] is None: + parser.print_help() + sys.exit(1) + +if args['skip'] is None: + args['skip'] = 3 + +if args['fps'] is None: + args['fps'] = 30 + +## Starting the game +memory = [] +env = Record(gym.make(args['environment_name']), memory, args, skipframes = args['skip']) +env = gym.wrappers.Monitor(env, args['logdir'], force=True) +playThread = PlayClass(env, args['fps']) +playThread.start() + +## Logging portion +while playThread.is_alive(): + playThread.join(60) + print("Logging....", end = " ") + env.log_transitions() + +# Save what's remaining after process died +env.log_transitions() \ No newline at end of file diff --git a/play_pong.sh b/play_pong.sh new file mode 100755 index 0000000..b7591ce --- /dev/null +++ b/play_pong.sh @@ -0,0 +1,2 @@ +#!/bin/sh +python play_env.py --environment_name=PongNoFrameskip-v4 --logdir=playlogs