Initial Commit
This commit is contained in:
commit
758d819ad1
4 changed files with 272 additions and 0 deletions
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
__pycache__/
|
||||||
|
playlogs/
|
174
play.py
Normal file
174
play.py
Normal file
|
@ -0,0 +1,174 @@
|
||||||
|
import gym
|
||||||
|
import pygame
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import matplotlib
|
||||||
|
try:
|
||||||
|
matplotlib.use('GTK3Agg')
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
import pyglet.window as pw
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
from pygame.locals import HWSURFACE, DOUBLEBUF, RESIZABLE, VIDEORESIZE
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
def display_arr(screen, arr, video_size, transpose):
|
||||||
|
arr_min, arr_max = arr.min(), arr.max()
|
||||||
|
arr = 255.0 * (arr - arr_min) / (arr_max - arr_min)
|
||||||
|
pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr)
|
||||||
|
pyg_img = pygame.transform.scale(pyg_img, video_size)
|
||||||
|
screen.blit(pyg_img, (0,0))
|
||||||
|
|
||||||
|
def play(env, transpose=True, fps=30, zoom=None, callback=None, keys_to_action=None):
|
||||||
|
"""Allows one to play the game using keyboard.
|
||||||
|
To simply play the game use:
|
||||||
|
play(gym.make("Pong-v3"))
|
||||||
|
Above code works also if env is wrapped, so it's particularly useful in
|
||||||
|
verifying that the frame-level preprocessing does not render the game
|
||||||
|
unplayable.
|
||||||
|
If you wish to plot real time statistics as you play, you can use
|
||||||
|
gym.utils.play.PlayPlot. Here's a sample code for plotting the reward
|
||||||
|
for last 5 second of gameplay.
|
||||||
|
def callback(obs_t, obs_tp1, rew, done, info):
|
||||||
|
return [rew,]
|
||||||
|
env_plotter = EnvPlotter(callback, 30 * 5, ["reward"])
|
||||||
|
env = gym.make("Pong-v3")
|
||||||
|
play(env, callback=env_plotter.callback)
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
env: gym.Env
|
||||||
|
Environment to use for playing.
|
||||||
|
transpose: bool
|
||||||
|
If True the output of observation is transposed.
|
||||||
|
Defaults to true.
|
||||||
|
fps: int
|
||||||
|
Maximum number of steps of the environment to execute every second.
|
||||||
|
Defaults to 30.
|
||||||
|
zoom: float
|
||||||
|
Make screen edge this many times bigger
|
||||||
|
callback: lambda or None
|
||||||
|
Callback if a callback is provided it will be executed after
|
||||||
|
every step. It takes the following input:
|
||||||
|
obs_t: observation before performing action
|
||||||
|
obs_tp1: observation after performing action
|
||||||
|
action: action that was executed
|
||||||
|
rew: reward that was received
|
||||||
|
done: whether the environment is done or not
|
||||||
|
info: debug info
|
||||||
|
keys_to_action: dict: tuple(int) -> int or None
|
||||||
|
Mapping from keys pressed to action performed.
|
||||||
|
For example if pressed 'w' and space at the same time is supposed
|
||||||
|
to trigger action number 2 then key_to_action dict would look like this:
|
||||||
|
{
|
||||||
|
# ...
|
||||||
|
sorted(ord('w'), ord(' ')) -> 2
|
||||||
|
# ...
|
||||||
|
}
|
||||||
|
If None, default key_to_action mapping for that env is used, if provided.
|
||||||
|
"""
|
||||||
|
|
||||||
|
obs_s = env.observation_space
|
||||||
|
assert type(obs_s) == gym.spaces.box.Box
|
||||||
|
assert len(obs_s.shape) == 2 or (len(obs_s.shape) == 3 and obs_s.shape[2] in [1,3])
|
||||||
|
|
||||||
|
if keys_to_action is None:
|
||||||
|
if hasattr(env, 'get_keys_to_action'):
|
||||||
|
keys_to_action = env.get_keys_to_action()
|
||||||
|
elif hasattr(env.unwrapped, 'get_keys_to_action'):
|
||||||
|
keys_to_action = env.unwrapped.get_keys_to_action()
|
||||||
|
else:
|
||||||
|
assert False, env.spec.id + " does not have explicit key to action mapping, " + \
|
||||||
|
"please specify one manually"
|
||||||
|
relevant_keys = set(sum(map(list, keys_to_action.keys()),[]))
|
||||||
|
|
||||||
|
if transpose:
|
||||||
|
video_size = env.observation_space.shape[1], env.observation_space.shape[0]
|
||||||
|
else:
|
||||||
|
video_size = env.observation_space.shape[0], env.observation_space.shape[1]
|
||||||
|
|
||||||
|
if zoom is not None:
|
||||||
|
video_size = int(video_size[0] * zoom), int(video_size[1] * zoom)
|
||||||
|
|
||||||
|
pressed_keys = []
|
||||||
|
running = True
|
||||||
|
env_done = True
|
||||||
|
|
||||||
|
screen = pygame.display.set_mode(video_size)
|
||||||
|
clock = pygame.time.Clock()
|
||||||
|
|
||||||
|
|
||||||
|
while running:
|
||||||
|
if env_done:
|
||||||
|
env_done = False
|
||||||
|
obs = env.reset()
|
||||||
|
else:
|
||||||
|
action = keys_to_action.get(tuple(sorted(pressed_keys)), 0)
|
||||||
|
prev_obs = obs
|
||||||
|
obs, rew, env_done, info = env.step(action)
|
||||||
|
if callback is not None:
|
||||||
|
callback(prev_obs, obs, action, rew, env_done, info)
|
||||||
|
if obs is not None:
|
||||||
|
if len(obs.shape) == 2:
|
||||||
|
obs = obs[:, :, None]
|
||||||
|
if obs.shape[2] == 1:
|
||||||
|
obs = obs.repeat(3, axis=2)
|
||||||
|
display_arr(screen, obs, transpose=transpose, video_size=video_size)
|
||||||
|
|
||||||
|
# process pygame events
|
||||||
|
for event in pygame.event.get():
|
||||||
|
# test events, set key states
|
||||||
|
if event.type == pygame.KEYDOWN:
|
||||||
|
if event.key in relevant_keys:
|
||||||
|
pressed_keys.append(event.key)
|
||||||
|
elif event.key == 27:
|
||||||
|
running = False
|
||||||
|
elif event.type == pygame.KEYUP:
|
||||||
|
if event.key in relevant_keys:
|
||||||
|
pressed_keys.remove(event.key)
|
||||||
|
elif event.type == pygame.QUIT:
|
||||||
|
running = False
|
||||||
|
elif event.type == VIDEORESIZE:
|
||||||
|
video_size = event.size
|
||||||
|
screen = pygame.display.set_mode(video_size)
|
||||||
|
print(video_size)
|
||||||
|
|
||||||
|
pygame.display.flip()
|
||||||
|
clock.tick(fps)
|
||||||
|
pygame.quit()
|
||||||
|
|
||||||
|
class PlayPlot(object):
|
||||||
|
def __init__(self, callback, horizon_timesteps, plot_names):
|
||||||
|
self.data_callback = callback
|
||||||
|
self.horizon_timesteps = horizon_timesteps
|
||||||
|
self.plot_names = plot_names
|
||||||
|
|
||||||
|
num_plots = len(self.plot_names)
|
||||||
|
self.fig, self.ax = plt.subplots(num_plots)
|
||||||
|
if num_plots == 1:
|
||||||
|
self.ax = [self.ax]
|
||||||
|
for axis, name in zip(self.ax, plot_names):
|
||||||
|
axis.set_title(name)
|
||||||
|
self.t = 0
|
||||||
|
self.cur_plot = [None for _ in range(num_plots)]
|
||||||
|
self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]
|
||||||
|
|
||||||
|
def callback(self, obs_t, obs_tp1, action, rew, done, info):
|
||||||
|
points = self.data_callback(obs_t, obs_tp1, action, rew, done, info)
|
||||||
|
for point, data_series in zip(points, self.data):
|
||||||
|
data_series.append(point)
|
||||||
|
self.t += 1
|
||||||
|
|
||||||
|
xmin, xmax = max(0, self.t - self.horizon_timesteps), self.t
|
||||||
|
|
||||||
|
for i, plot in enumerate(self.cur_plot):
|
||||||
|
if plot is not None:
|
||||||
|
plot.remove()
|
||||||
|
self.cur_plot[i] = self.ax[i].scatter(range(xmin, xmax), list(self.data[i]))
|
||||||
|
self.ax[i].set_xlim(xmin, xmax)
|
||||||
|
plt.pause(0.000001)
|
||||||
|
|
||||||
|
|
94
play_env.py
Normal file
94
play_env.py
Normal file
|
@ -0,0 +1,94 @@
|
||||||
|
import play
|
||||||
|
import gym
|
||||||
|
from collections import namedtuple
|
||||||
|
from datetime import datetime
|
||||||
|
import pickle
|
||||||
|
import threading
|
||||||
|
from time import sleep
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
Transition = namedtuple('Transition',
|
||||||
|
('state', 'action', 'reward', 'next_state', 'done'))
|
||||||
|
|
||||||
|
class PlayClass(threading.Thread):
|
||||||
|
def __init__(self, env, fps = 60):
|
||||||
|
super(PlayClass, self).__init__()
|
||||||
|
self.env = env
|
||||||
|
self.fps = fps
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
play.play(self.env, fps = self.fps, zoom = 4)
|
||||||
|
|
||||||
|
class Record(gym.Wrapper):
|
||||||
|
def __init__(self, env, memory, args, skipframes = 3):
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
self.memory_lock = threading.Lock()
|
||||||
|
self.memory = memory
|
||||||
|
self.args = args
|
||||||
|
self.skipframes = skipframes
|
||||||
|
self.current_i = skipframes
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
return self.env.reset()
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
self.memory_lock.acquire()
|
||||||
|
state = self.env.env._get_obs()
|
||||||
|
next_state, reward, done, info = self.env.step(action)
|
||||||
|
if self.current_i <= 0:
|
||||||
|
self.memory.append(Transition(state, action, reward, next_state, done))
|
||||||
|
self.current_i = self.skipframes
|
||||||
|
else: self.current_i -= 1
|
||||||
|
self.memory_lock.release()
|
||||||
|
return next_state, reward, done, info
|
||||||
|
|
||||||
|
def log_transitions(self):
|
||||||
|
self.memory_lock.acquire()
|
||||||
|
if len(self.memory) > 0:
|
||||||
|
basename = self.args['logdir'] + "/{}.{}".format(self.args['environment_name'], datetime.now().strftime("%Y-%m-%d-%H-%M-%s"))
|
||||||
|
print("Base Filename: ", basename)
|
||||||
|
state, action, reward, next_state, done = zip(*self.memory)
|
||||||
|
np.save(basename + "-state.npy", np.array(state), allow_pickle = False)
|
||||||
|
np.save(basename + "-action.npy", np.array(action), allow_pickle = False)
|
||||||
|
np.save(basename + "-reward.npy", np.array(reward), allow_pickle = False)
|
||||||
|
np.save(basename + "-nextstate.npy", np.array(next_state), allow_pickle = False)
|
||||||
|
np.save(basename + "-done.npy", np.array(done), allow_pickle = False)
|
||||||
|
self.memory.clear()
|
||||||
|
self.memory_lock.release()
|
||||||
|
|
||||||
|
|
||||||
|
## Parsing arguments
|
||||||
|
parser = argparse.ArgumentParser(description="Play and log the environment")
|
||||||
|
parser.add_argument("--environment_name", type=str, help="The environment name in OpenAI gym to play.")
|
||||||
|
parser.add_argument("--logdir", type=str, help="Directory to log video and (state, action, reward, next_state, done) in.")
|
||||||
|
parser.add_argument("--skip", type=int, help="Number of frames to skip logging.")
|
||||||
|
parser.add_argument("--fps", type=int, help="Number of frames per second")
|
||||||
|
args = vars(parser.parse_args())
|
||||||
|
|
||||||
|
if args['environment_name'] is None or args['logdir'] is None:
|
||||||
|
parser.print_help()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if args['skip'] is None:
|
||||||
|
args['skip'] = 3
|
||||||
|
|
||||||
|
if args['fps'] is None:
|
||||||
|
args['fps'] = 30
|
||||||
|
|
||||||
|
## Starting the game
|
||||||
|
memory = []
|
||||||
|
env = Record(gym.make(args['environment_name']), memory, args, skipframes = args['skip'])
|
||||||
|
env = gym.wrappers.Monitor(env, args['logdir'], force=True)
|
||||||
|
playThread = PlayClass(env, args['fps'])
|
||||||
|
playThread.start()
|
||||||
|
|
||||||
|
## Logging portion
|
||||||
|
while playThread.is_alive():
|
||||||
|
playThread.join(60)
|
||||||
|
print("Logging....", end = " ")
|
||||||
|
env.log_transitions()
|
||||||
|
|
||||||
|
# Save what's remaining after process died
|
||||||
|
env.log_transitions()
|
2
play_pong.sh
Executable file
2
play_pong.sh
Executable file
|
@ -0,0 +1,2 @@
|
||||||
|
#!/bin/sh
|
||||||
|
python play_env.py --environment_name=PongNoFrameskip-v4 --logdir=playlogs
|
Loading…
Reference in a new issue