Back and forth between computer play and human play while training an agent
commit 1bf2c15542
4 changed files with 457 additions and 0 deletions

.gitignore (vendored, new file, +2)
@@ -0,0 +1,2 @@
__pycache__/
playlogs/

play.py (new file, +241)
@@ -0,0 +1,241 @@
import gym
import pygame
import sys
import time
import matplotlib
try:
    matplotlib.use('GTK3Agg')
    import matplotlib.pyplot as plt
except Exception:
    pass

import pyglet.window as pw

from collections import deque
from pygame.locals import HWSURFACE, DOUBLEBUF, RESIZABLE, VIDEORESIZE
from threading import Thread, Event, Timer

class Play:
    def __init__(self, env, action_selector, memory, agent, transpose = True, fps = 30, zoom = None, keys_to_action = None):
        self.env = env
        self.action_selector = action_selector
        self.transpose = transpose
        self.fps = fps
        self.zoom = zoom
        self.keys_to_action = keys_to_action  # filled in by _setup_keys() if None
        self.video_size = (0, 0)
        self.pressed_keys = []
        self.screen = None
        self.relevant_keys = set()
        self.running = True
        self.switch = Event()
        self.state = 0  # 0: computer plays, 1: pause, 2: human plays, 3: pause
        self.paused = False
        self.memory = memory
        self.agent = agent
        print("FPS ", fps)

    def _display_arr(self, obs, screen, arr, video_size):
        if obs is not None:
            if len(obs.shape) == 2:
                obs = obs[:, :, None]
            if obs.shape[2] == 1:
                obs = obs.repeat(3, axis=2)
            arr_min, arr_max = arr.min(), arr.max()
            arr = 255.0 * (arr - arr_min) / (arr_max - arr_min)
            pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if self.transpose else arr)
            pyg_img = pygame.transform.scale(pyg_img, video_size)
            screen.blit(pyg_img, (0, 0))

    def _human_play(self, obs):
        action = self.keys_to_action.get(tuple(sorted(self.pressed_keys)), 0)
        prev_obs = obs
        obs, reward, env_done, _ = self.env.step(action)
        self._display_arr(obs, self.screen, self.env.unwrapped._get_obs(), video_size=self.video_size)

        # process pygame events
        for event in pygame.event.get():
            # test events, set key states
            if event.type == pygame.KEYDOWN:
                if event.key in self.relevant_keys:
                    self.pressed_keys.append(event.key)
                elif event.key == pygame.K_ESCAPE:
                    self.running = False
            elif event.type == pygame.KEYUP:
                if event.key in self.relevant_keys:
                    self.pressed_keys.remove(event.key)
            elif event.type == pygame.QUIT:
                self.running = False
            elif event.type == VIDEORESIZE:
                self.video_size = event.size
                self.screen = pygame.display.set_mode(self.video_size)

        pygame.display.flip()
        self.clock.tick(self.fps)
        return prev_obs, action, reward, obs, env_done

    def _computer_play(self, obs):
        prev_obs = obs
        action = self.action_selector.act(obs)
        obs, reward, env_done, _ = self.env.step(action)
        self._display_arr(obs, self.screen, self.env.unwrapped._get_obs(), video_size=self.video_size)

        # process pygame events
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                self.running = False
            elif event.type == VIDEORESIZE:
                self.video_size = event.size
                self.screen = pygame.display.set_mode(self.video_size)
            elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
                self.running = False

        pygame.display.flip()
        self.clock.tick(self.fps)
        return prev_obs, action, reward, obs, env_done

    def _setup_video(self):
        if self.transpose:
            video_size = self.env.unwrapped.observation_space.shape[1], self.env.unwrapped.observation_space.shape[0]
        else:
            video_size = self.env.unwrapped.observation_space.shape[0], self.env.unwrapped.observation_space.shape[1]

        if self.zoom is not None:
            video_size = int(video_size[0] * self.zoom), int(video_size[1] * self.zoom)

        self.video_size = video_size
        self.screen = pygame.display.set_mode(self.video_size)
        pygame.font.init()  # For later text

    def _setup_keys(self):
        if self.keys_to_action is None:
            if hasattr(self.env, 'get_keys_to_action'):
                self.keys_to_action = self.env.get_keys_to_action()
            elif hasattr(self.env.unwrapped, 'get_keys_to_action'):
                self.keys_to_action = self.env.unwrapped.get_keys_to_action()
            else:
                assert False, self.env.spec.id + " does not have explicit key to action mapping, " + \
                              "please specify one manually"
        self.relevant_keys = set(sum(map(list, self.keys_to_action.keys()), []))

    def _increment_state(self):
        # Cycle through: 0 computer play -> 1 pause -> 2 human play -> 3 pause -> 0 ...
        self.state = (self.state + 1) % 4

    def pause(self, text = ""):
        self.paused = True
        myfont = pygame.font.SysFont('Comic Sans MS', 50)
        textsurface = myfont.render(text, False, (0, 0, 0))
        self.screen.blit(textsurface, (0, 0))
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                self.running = False
            elif event.type == VIDEORESIZE:
                self.video_size = event.size
                self.screen = pygame.display.set_mode(self.video_size)
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_SPACE:
                    self.pressed_keys.append(event.key)
                elif event.key == pygame.K_ESCAPE:
                    self.running = False
            elif event.type == pygame.KEYUP and event.key == pygame.K_SPACE:
                # Releasing <Space> advances to the next turn and unpauses
                self.pressed_keys.remove(event.key)
                self._increment_state()
                self.paused = False
        pygame.display.flip()
        self.clock.tick(self.fps)

    def start(self):
        """Run the game loop, alternating between computer and human control.

        The loop cycles through four states: the agent plays, a pause that
        reads "Your Turn! Press <Space> to Start", the human plays, and a
        pause that reads "Computers Turn! Press <Space> to Start". Releasing
        <Space> leaves a pause; <Escape> or closing the window quits.

        This also works if env is wrapped, so it is particularly useful in
        verifying that the frame-level preprocessing does not render the game
        unplayable.

        Relevant constructor arguments
        ------------------------------
        env: gym.Env
            Environment to use for playing.
        transpose: bool
            If True the output of observation is transposed.
            Defaults to True.
        fps: int
            Maximum number of steps of the environment to execute every second.
            Defaults to 30.
        zoom: float
            Make screen edge this many times bigger.
        keys_to_action: dict: tuple(int) -> int or None
            Mapping from keys pressed to action performed.
            For example, if pressing 'w' and space at the same time is supposed
            to trigger action number 2, then keys_to_action would contain the
            entry tuple(sorted((ord('w'), ord(' ')))): 2.
            If None, the default key-to-action mapping for the env is used,
            if provided.
        """
        obs_s = self.env.unwrapped.observation_space
        assert type(obs_s) == gym.spaces.box.Box
        assert len(obs_s.shape) == 2 or (len(obs_s.shape) == 3 and obs_s.shape[2] in [1, 3])

        self._setup_keys()
        self._setup_video()

        self.clock = pygame.time.Clock()

        # States
        COMPUTER_PLAY = 0
        HUMAN_PLAY = 2

        env_done = True
        prev_obs = None
        obs = None
        reward = 0
        i = 0
        while self.running:
            if env_done:
                obs = self.env.reset()
                env_done = False

            if self.state == 0:
                prev_obs, action, reward, obs, env_done = self._computer_play(obs)
            elif self.state == 1:
                self.pause("Your Turn! Press <Space> to Start")
            elif self.state == 2:
                prev_obs, action, reward, obs, env_done = self._human_play(obs)
            elif self.state == 3:
                self.pause("Computers Turn! Press <Space> to Start")

            # Only record a transition when an actual environment step happened
            if self.state == COMPUTER_PLAY or self.state == HUMAN_PLAY:
                self.memory.append(prev_obs, action, reward, obs, env_done)

            if not self.paused:
                i += 1
                if i % (self.fps * 30) == 0:  # Every 30 seconds...
                    print("TRAINING...")
                    self.agent.learn()
                    print("PAUSING...")
                    self._increment_state()
                    i = 0

        pygame.quit()
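
For reference, a minimal sketch of driving the Play class directly. This is not part of the commit; selector, replay_memory, and agent are placeholders for the rltorch objects wired up in play_env.py below, and the key-to-action mapping only illustrates the format described in the docstring (Pong action meanings are assumed).

# Hypothetical usage sketch (placeholder names, not part of the committed files)
import gym
from play import Play

env = gym.make("PongNoFrameskip-v4")

# Manual mapping in case the env does not expose get_keys_to_action():
# a sorted tuple of currently pressed key codes maps to one discrete action.
keys_to_action = {
    (): 0,            # no keys pressed -> NOOP (action meanings assumed)
    (ord('w'),): 2,   # 'w' held        -> move paddle up
    (ord('s'),): 3,   # 's' held        -> move paddle down
}

# selector, replay_memory, and agent would come from rltorch, as in play_env.py:
# player = Play(env, selector, replay_memory, agent, fps=30, zoom=4,
#               keys_to_action=keys_to_action)
# player.start()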

play_env.py (new file, +212)
@@ -0,0 +1,212 @@
import play
import rltorch
import rltorch.memory as M
import torch
import gym
from collections import namedtuple
from datetime import datetime
from rltorch.action_selector import EpsilonGreedySelector
import rltorch.env as E
import rltorch.network as rn
import torch.nn as nn
import torch.nn.functional as F
import pickle
import threading
from time import sleep
import argparse
import sys
import numpy as np


## CURRENT ISSUE: MaxAndSkipEnv applies to the human player as well, which makes for an awkward gaming experience.
# What are your thoughts? Training is different if the expert isn't forced to act under the same constraint.
# At some point I need to introduce learning.

class Value(nn.Module):
    def __init__(self, state_size, action_size):
        super(Value, self).__init__()
        self.state_size = state_size
        self.action_size = action_size

        self.conv1 = nn.Conv2d(4, 32, kernel_size = (8, 8), stride = (4, 4))
        self.conv2 = nn.Conv2d(32, 64, kernel_size = (4, 4), stride = (2, 2))
        self.conv3 = nn.Conv2d(64, 64, kernel_size = (3, 3), stride = (1, 1))

        self.fc1 = nn.Linear(3136, 512)
        self.fc1_norm = nn.LayerNorm(512)

        self.value_fc = rn.NoisyLinear(512, 512)
        self.value_fc_norm = nn.LayerNorm(512)
        self.value = nn.Linear(512, 1)

        self.advantage_fc = rn.NoisyLinear(512, 512)
        self.advantage_fc_norm = nn.LayerNorm(512)
        self.advantage = nn.Linear(512, action_size)

    def forward(self, x):
        x = x.float() / 256
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Makes batch_size dimension again
        x = x.view(-1, 3136)
        x = F.relu(self.fc1_norm(self.fc1(x)))

        state_value = F.relu(self.value_fc_norm(self.value_fc(x)))
        state_value = self.value(state_value)

        advantage = F.relu(self.advantage_fc_norm(self.advantage_fc(x)))
        advantage = self.advantage(advantage)

        # Dueling aggregation: Q(s, a) = V(s) + A(s, a) - mean(A)
        x = state_value + advantage - advantage.mean()

        # For debugging purposes...
        if torch.isnan(x).any().item():
            print("WARNING NAN IN MODEL DETECTED")

        return x

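As a side note, not part of the committed file: the aggregation above subtracts the global mean of the advantage tensor, whereas the dueling-DQN formulation subtracts a per-sample mean over actions, which in PyTorch would be advantage.mean(1, keepdim=True). A tiny standalone check of the shapes involved:

# Illustrative only; dummy tensors, independent of the Value network above.
import torch

state_value = torch.zeros(32, 1)   # V(s): one scalar per batch element
advantage = torch.randn(32, 6)     # A(s, a): one entry per action

q_global = state_value + advantage - advantage.mean()                   # as in this commit
q_dueling = state_value + advantage - advantage.mean(1, keepdim=True)   # per-sample mean

print(q_global.shape, q_dueling.shape)  # both torch.Size([32, 6])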

Transition = namedtuple('Transition',
        ('state', 'action', 'reward', 'next_state', 'done'))

class PlayClass(threading.Thread):
    def __init__(self, env, action_selector, memory, agent, fps = 60):
        super(PlayClass, self).__init__()
        self.env = env
        self.fps = fps
        self.play = play.Play(self.env, action_selector, memory, agent, fps = fps, zoom = 4)

    def run(self):
        self.play.start()

class Record(gym.Wrapper):
    def __init__(self, env, memory, args, skipframes = 3):
        gym.Wrapper.__init__(self, env)
        self.memory_lock = threading.Lock()
        self.memory = memory
        self.args = args
        self.skipframes = skipframes
        self.current_i = skipframes

    def reset(self):
        return self.env.reset()

    def step(self, action):
        self.memory_lock.acquire()
        state = self.env.env._get_obs()
        next_state, reward, done, info = self.env.step(action)
        if self.current_i <= 0:
            self.memory.append(Transition(state, action, reward, next_state, done))
            self.current_i = self.skipframes
        else:
            self.current_i -= 1
        self.memory_lock.release()
        return next_state, reward, done, info

    def log_transitions(self):
        self.memory_lock.acquire()
        if len(self.memory) > 0:
            basename = self.args['logdir'] + "/{}.{}".format(self.args['environment_name'], datetime.now().strftime("%Y-%m-%d-%H-%M-%s"))
            print("Base Filename: ", basename)
            state, action, reward, next_state, done = zip(*self.memory)
            np.save(basename + "-state.npy", np.array(state), allow_pickle = False)
            np.save(basename + "-action.npy", np.array(action), allow_pickle = False)
            np.save(basename + "-reward.npy", np.array(reward), allow_pickle = False)
            np.save(basename + "-nextstate.npy", np.array(next_state), allow_pickle = False)
            np.save(basename + "-done.npy", np.array(done), allow_pickle = False)
            self.memory.clear()
        self.memory_lock.release()


## Parsing arguments
parser = argparse.ArgumentParser(description="Play and log the environment")
parser.add_argument("--environment_name", type=str, help="The environment name in OpenAI gym to play.")
parser.add_argument("--logdir", type=str, help="Directory to log video and (state, action, reward, next_state, done) in.")
parser.add_argument("--skip", type=int, help="Number of frames to skip logging.")
parser.add_argument("--fps", type=int, help="Number of frames per second")
parser.add_argument("--model", type=str, help="The path location of the PyTorch model")
args = vars(parser.parse_args())

config = {}
config['seed'] = 901
config['environment_name'] = 'PongNoFrameskip-v4'
config['learning_rate'] = 1e-4
config['target_sync_tau'] = 1e-3
config['discount_rate'] = 0.99
config['exploration_rate'] = rltorch.scheduler.ExponentialScheduler(initial_value = 1, end_value = 0.1, iterations = 10**5)
config['batch_size'] = 480
config['disable_cuda'] = False
config['memory_size'] = 10**4
# Prioritized vs. random sampling
# 0 - random sampling
# 1 - only the highest priorities
config['prioritized_replay_sampling_priority'] = 0.6
# How important are the importance-sampling weights for the loss?
# 0 - treat all losses equally
# 1 - fully correct for the sampling bias (down-weight high-priority samples)
# Should ideally start near 0 and anneal towards 1 to prevent overfitting
config['prioritized_replay_weight_importance'] = rltorch.scheduler.ExponentialScheduler(initial_value = 0.4, end_value = 1, iterations = 10**5)

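A quick illustration, not part of the committed file, of how these two knobs behave under the standard prioritized-experience-replay formulas; the exact behavior depends on rltorch's PrioritizedReplayMemory implementation.

# Hypothetical priorities for four stored transitions.
import numpy as np

priorities = np.array([0.1, 1.0, 2.0, 5.0])
alpha = 0.6   # config['prioritized_replay_sampling_priority']
beta = 0.4    # starting value of config['prioritized_replay_weight_importance']

# Sampling probability: P(i) = p_i^alpha / sum_j p_j^alpha
probs = priorities ** alpha / np.sum(priorities ** alpha)

# Importance-sampling weight: w_i = (N * P(i))^(-beta), normalized by the max weight
weights = (len(priorities) * probs) ** (-beta)
weights /= weights.max()

print(probs)    # higher-priority transitions are sampled more often
print(weights)  # ...but their losses are scaled down to compensate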

if args['environment_name'] is None or args['logdir'] is None:
    parser.print_help()
    sys.exit(1)

if args['skip'] is None:
    args['skip'] = 3

if args['fps'] is None:
    args['fps'] = 30

## Starting the game
memory = []
env = Record(gym.make(args['environment_name']), memory, args, skipframes = args['skip'])
record_env = env
env = gym.wrappers.Monitor(env, args['logdir'], force=True)
env = E.ClippedRewardsWrapper(
    E.FrameStack(
        E.TorchWrap(
            E.ProcessFrame84(
                E.FireResetEnv(
                    # E.MaxAndSkipEnv(
                    E.NoopResetEnv(
                        E.EpisodicLifeEnv(gym.make(config['environment_name'])),
                        noop_max = 30)
                    # , skip=4)
                )
            )
        ),
        4)
)

rltorch.set_seed(config['seed'])

device = torch.device("cuda:0" if torch.cuda.is_available() and not config['disable_cuda'] else "cpu")
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

net = rn.Network(Value(state_size, action_size),
        torch.optim.Adam, config, device = device)
target_net = rn.TargetNetwork(net, device = device)

actor = EpsilonGreedySelector(net, action_size, device = device, epsilon = config['exploration_rate'])
memory = M.PrioritizedReplayMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'])
agent = rltorch.agents.DQNAgent(net, memory, config, target_net = target_net)

env.seed(config['seed'])

playThread = PlayClass(env, actor, memory, agent, args['fps'])
playThread.start()

## Logging portion
while playThread.is_alive():
    playThread.join(60)
    print("Logging....", end = " ")
    record_env.log_transitions()

# Save what's remaining after the play thread dies
record_env.log_transitions()
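
The logging loop above relies on Thread.join with a timeout: the main thread wakes roughly once a minute to flush transitions while the pygame loop runs in the PlayClass worker. A minimal standalone illustration of that pattern, not part of the committed file:

# Minimal sketch of the join-with-timeout flush pattern used above.
import threading
import time

def worker():
    time.sleep(5)  # stands in for the blocking Play.start() loop

t = threading.Thread(target=worker)
t.start()
while t.is_alive():
    t.join(1)                 # returns after at most 1 second
    print("flush logs here")  # periodic work on the main thread
print("final flush after the worker exits")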

play_pong.sh (new executable file, +2)
@@ -0,0 +1,2 @@
#!/bin/sh
python play_env.py --environment_name=PongNoFrameskip-v4 --logdir=playlogs
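
Once the script has run, the transition logs written by Record.log_transitions can be read back with numpy. The suffixes below mirror those used in play_env.py; the timestamped basename is a placeholder for whatever was printed at logging time.

# Hypothetical example of loading one batch of logged transitions.
import numpy as np

basename = "playlogs/PongNoFrameskip-v4.2019-02-15-12-00-00"  # placeholder basename
states      = np.load(basename + "-state.npy")
actions     = np.load(basename + "-action.npy")
rewards     = np.load(basename + "-reward.npy")
next_states = np.load(basename + "-nextstate.npy")
dones       = np.load(basename + "-done.npy")
print(states.shape, actions.shape, rewards.shape)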