Updated GymInteract to introduce a hidden training phase between the human's turn at the game and the computer's turn
This commit is contained in:
parent
1bf2c15542
commit
b7aa4a4ec6
2 changed files with 71 additions and 33 deletions
67
play.py
67
play.py
|
@ -3,6 +3,7 @@ import pygame
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
import rltorch.memory as M
|
||||||
try:
|
try:
|
||||||
matplotlib.use('GTK3Agg')
|
matplotlib.use('GTK3Agg')
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
@ -17,7 +18,7 @@ from pygame.locals import HWSURFACE, DOUBLEBUF, RESIZABLE, VIDEORESIZE
|
||||||
from threading import Thread, Event, Timer
|
from threading import Thread, Event, Timer
|
||||||
|
|
||||||
class Play:
|
class Play:
|
||||||
def __init__(self, env, action_selector, memory, agent, transpose = True, fps = 30, zoom = None, keys_to_action = None):
|
def __init__(self, env, action_selector, memory, agent, sneaky_env, transpose = True, fps = 30, zoom = None, keys_to_action = None):
|
||||||
self.env = env
|
self.env = env
|
||||||
self.action_selector = action_selector
|
self.action_selector = action_selector
|
||||||
self.transpose = transpose
|
self.transpose = transpose
|
||||||
|
@ -34,7 +35,7 @@ class Play:
|
||||||
self.paused = False
|
self.paused = False
|
||||||
self.memory = memory
|
self.memory = memory
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
print("FPS ", 30)
|
self.sneaky_env = sneaky_env
|
||||||
|
|
||||||
def _display_arr(self, obs, screen, arr, video_size):
|
def _display_arr(self, obs, screen, arr, video_size):
|
||||||
if obs is not None:
|
if obs is not None:
|
||||||
|
@ -120,7 +121,7 @@ class Play:
|
||||||
self.relevant_keys = set(sum(map(list, self.keys_to_action.keys()),[]))
|
self.relevant_keys = set(sum(map(list, self.keys_to_action.keys()),[]))
|
||||||
|
|
||||||
def _increment_state(self):
|
def _increment_state(self):
|
||||||
self.state = (self.state + 1) % 4
|
self.state = (self.state + 1) % 5
|
||||||
|
|
||||||
def pause(self, text = ""):
|
def pause(self, text = ""):
|
||||||
self.paused = True
|
self.paused = True
|
||||||
|
@ -145,6 +146,31 @@ class Play:
|
||||||
pygame.display.flip()
|
pygame.display.flip()
|
||||||
self.clock.tick(self.fps)
|
self.clock.tick(self.fps)
|
||||||
|
|
||||||
|
|
||||||
|
def sneaky_train(self, episodes = 30, replay_skip = 4, memory_capacity = 2000):
    """Train the agent off-screen on ``self.sneaky_env``.

    ``self.memory`` is temporarily replaced with a fresh ``M.ReplayMemory``
    so the transitions generated here are not mixed into the memory used
    for the on-screen play. The original memory is always put back, even
    if the environment or the agent raises part-way through.

    Args:
        episodes: Number of full episodes to run on ``self.sneaky_env``.
        replay_skip: Call ``self.agent.learn()`` once every this many
            environment steps (the step counter restarts each episode).
        memory_capacity: Capacity of the temporary replay memory.

    Raises:
        ValueError: If ``replay_skip`` is smaller than 1.
    """
    if replay_skip < 1:
        # A zero skip would otherwise fail with ZeroDivisionError mid-episode.
        raise ValueError("replay_skip must be at least 1")

    # Back up the real memory and swap in a scratch one for this phase.
    # NOTE(review): self.agent.learn() takes no arguments here; confirm the
    # agent actually reads from self.memory rather than its own reference.
    backup_memory = self.memory
    self.memory = M.ReplayMemory(capacity = memory_capacity)
    try:
        for _ in range(episodes):
            prev_obs = self.sneaky_env.reset()
            done = False
            step = 0
            while not done:
                action = self.action_selector.act(prev_obs)
                obs, reward, done, _ = self.sneaky_env.step(action)
                self.memory.append(prev_obs, action, reward, obs, done)
                prev_obs = obs
                step += 1
                if step % replay_skip == 0:
                    self.agent.learn()
    finally:
        # Restore the original memory no matter how training ended.
        self.memory = backup_memory
    # It would be nice to keep a sample of this new data instead of throwing
    # all of it away. Keeping all of it is probably undesirable because it
    # would drown out the expert demonstration data.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Allows one to play the game using keyboard.
|
"""Allows one to play the game using keyboard.
|
||||||
To simply play the game use:
|
To simply play the game use:
|
||||||
|
@ -202,8 +228,12 @@ class Play:
|
||||||
self.clock = pygame.time.Clock()
|
self.clock = pygame.time.Clock()
|
||||||
|
|
||||||
# States
|
# States
|
||||||
COMPUTER_PLAY = 0
|
HUMAN_PLAY = 0
|
||||||
HUMAN_PLAY = 2
|
SNEAKY_COMPUTER_PLAY = 1
|
||||||
|
TRANSITION = 2
|
||||||
|
COMPUTER_PLAY = 3
|
||||||
|
TRANSITION2 = 4
|
||||||
|
|
||||||
|
|
||||||
env_done = True
|
env_done = True
|
||||||
prev_obs = None
|
prev_obs = None
|
||||||
|
@ -214,28 +244,31 @@ class Play:
|
||||||
if env_done:
|
if env_done:
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
env_done = False
|
env_done = False
|
||||||
|
if self.state is HUMAN_PLAY:
|
||||||
if self.state == 0:
|
|
||||||
prev_obs, action, reward, obs, env_done = self._computer_play(obs)
|
|
||||||
elif self.state == 1:
|
|
||||||
self.pause("Your Turn! Press <Space> to Start")
|
|
||||||
elif self.state == 2:
|
|
||||||
prev_obs, action, reward, obs, env_done = self._human_play(obs)
|
prev_obs, action, reward, obs, env_done = self._human_play(obs)
|
||||||
elif self.state == 3:
|
elif self.state is SNEAKY_COMPUTER_PLAY:
|
||||||
|
myfont = pygame.font.SysFont('Comic Sans MS', 50)
|
||||||
|
textsurface = myfont.render("Training....", False, (0, 0, 0))
|
||||||
|
self.screen.blit(textsurface,(0,0))
|
||||||
|
self.sneaky_train()
|
||||||
|
self._increment_state()
|
||||||
|
elif self.state is TRANSITION:
|
||||||
self.pause("Computers Turn! Press <Space> to Start")
|
self.pause("Computers Turn! Press <Space> to Start")
|
||||||
|
elif self.state is COMPUTER_PLAY:
|
||||||
|
prev_obs, action, reward, obs, env_done = self._computer_play(obs)
|
||||||
|
elif self.state is TRANSITION2:
|
||||||
|
self.pause("Your Turn! Press <Space> to Start")
|
||||||
|
|
||||||
if self.state is COMPUTER_PLAY or self.state is HUMAN_PLAY:
|
if self.state is COMPUTER_PLAY or self.state is HUMAN_PLAY:
|
||||||
self.memory.append(prev_obs, action, reward, obs, env_done)
|
self.memory.append(prev_obs, action, reward, obs, env_done)
|
||||||
|
|
||||||
if not self.paused:
|
|
||||||
i += 1
|
i += 1
|
||||||
if i % (self.fps * 30) == 0: # Every 30 seconds...
|
# Every 30 seconds...
|
||||||
print("TRAINING...")
|
if i % (self.fps * 30) == 0:
|
||||||
|
print("Training...")
|
||||||
self.agent.learn()
|
self.agent.learn()
|
||||||
print("PAUSING...")
|
print("PAUSING...")
|
||||||
self._increment_state()
|
self._increment_state()
|
||||||
i = 0
|
i = 0
|
||||||
|
|
||||||
|
|
||||||
pygame.quit()
|
pygame.quit()
|
||||||
|
|
||||||
|
|
37
play_env.py
37
play_env.py
|
@ -17,11 +17,9 @@ import argparse
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
#
|
||||||
## CURRENT ISSUE: MaxSkipEnv applies to the human player as well, which makes for an awkward gaming experience
|
## Networks
|
||||||
# What are your thoughts? Training is different if expert isn't forced with the same constraint
|
#
|
||||||
# At some point I need to introduce learning
|
|
||||||
|
|
||||||
class Value(nn.Module):
|
class Value(nn.Module):
|
||||||
def __init__(self, state_size, action_size):
|
def __init__(self, state_size, action_size):
|
||||||
super(Value, self).__init__()
|
super(Value, self).__init__()
|
||||||
|
@ -69,16 +67,18 @@ class Value(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
## Play Related Classes
|
||||||
|
#
|
||||||
Transition = namedtuple('Transition',
|
Transition = namedtuple('Transition',
|
||||||
('state', 'action', 'reward', 'next_state', 'done'))
|
('state', 'action', 'reward', 'next_state', 'done'))
|
||||||
|
|
||||||
class PlayClass(threading.Thread):
|
class PlayClass(threading.Thread):
|
||||||
def __init__(self, env, action_selector, memory, agent, fps = 60):
|
def __init__(self, env, action_selector, memory, agent, sneaky_env, fps = 60):
|
||||||
super(PlayClass, self).__init__()
|
super(PlayClass, self).__init__()
|
||||||
self.env = env
|
self.env = env
|
||||||
self.fps = fps
|
self.fps = fps
|
||||||
self.play = play.Play(self.env, action_selector, memory, agent, fps = fps, zoom = 4)
|
self.play = play.Play(self.env, action_selector, memory, agent, sneaky_env, fps = fps, zoom = 4)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self.play.start()
|
self.play.start()
|
||||||
|
@ -162,19 +162,15 @@ if args['skip'] is None:
|
||||||
if args['fps'] is None:
|
if args['fps'] is None:
|
||||||
args['fps'] = 30
|
args['fps'] = 30
|
||||||
|
|
||||||
## Starting the game
|
def wrap_preprocessing(env):
|
||||||
memory = []
|
return E.ClippedRewardsWrapper(
|
||||||
env = Record(gym.make(args['environment_name']), memory, args, skipframes = args['skip'])
|
|
||||||
record_env = env
|
|
||||||
env = gym.wrappers.Monitor(env, args['logdir'], force=True)
|
|
||||||
env = E.ClippedRewardsWrapper(
|
|
||||||
E.FrameStack(
|
E.FrameStack(
|
||||||
E.TorchWrap(
|
E.TorchWrap(
|
||||||
E.ProcessFrame84(
|
E.ProcessFrame84(
|
||||||
E.FireResetEnv(
|
E.FireResetEnv(
|
||||||
# E.MaxAndSkipEnv(
|
# E.MaxAndSkipEnv(
|
||||||
E.NoopResetEnv(
|
E.NoopResetEnv(
|
||||||
E.EpisodicLifeEnv(gym.make(config['environment_name']))
|
E.EpisodicLifeEnv(env)
|
||||||
, noop_max = 30)
|
, noop_max = 30)
|
||||||
# , skip=4)
|
# , skip=4)
|
||||||
)
|
)
|
||||||
|
@ -183,6 +179,15 @@ env = E.ClippedRewardsWrapper(
|
||||||
4)
|
4)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## Starting the game
|
||||||
|
memory = []
|
||||||
|
env = Record(gym.make(args['environment_name']), memory, args, skipframes = args['skip'])
|
||||||
|
record_env = env
|
||||||
|
env = gym.wrappers.Monitor(env, args['logdir'], force=True)
|
||||||
|
env = wrap_preprocessing(env)
|
||||||
|
|
||||||
|
sneaky_env = wrap_preprocessing(gym.make(args['environment_name']))
|
||||||
|
|
||||||
rltorch.set_seed(config['seed'])
|
rltorch.set_seed(config['seed'])
|
||||||
|
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() and not config['disable_cuda'] else "cpu")
|
device = torch.device("cuda:0" if torch.cuda.is_available() and not config['disable_cuda'] else "cpu")
|
||||||
|
@ -199,7 +204,7 @@ agent = rltorch.agents.DQNAgent(net, memory, config, target_net = target_net)
|
||||||
|
|
||||||
env.seed(config['seed'])
|
env.seed(config['seed'])
|
||||||
|
|
||||||
playThread = PlayClass(env, actor, memory, agent, args['fps'])
|
playThread = PlayClass(env, actor, memory, agent, sneaky_env, fps = args['fps'])
|
||||||
playThread.start()
|
playThread.start()
|
||||||
|
|
||||||
## Logging portion
|
## Logging portion
|
||||||
|
|
Loading…
Reference in a new issue