SneakyTrain uses separate replay buffer

Scripts were cleaned up considerably and comments were added
This commit is contained in:
Brandon Rozek 2019-10-23 21:53:20 -04:00
parent b7aa4a4ec6
commit d78892e62c
2 changed files with 186 additions and 175 deletions
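The main change is that the hidden ("sneaky") training phase no longer learns against the shared demonstration buffer: it grabs a lock, trains against a temporary buffer, and restores the original buffer afterwards. Below is a minimal, standalone sketch of that pattern for orientation only; SimpleReplayMemory, collect_episode and learn_step are illustrative stand-ins, not the repo's rltorch API.

# Illustrative sketch of "train on a separate buffer under a lock" -- not the repo's code.
from collections import deque
from threading import Lock
import random

class SimpleReplayMemory:
    """Tiny stand-in for a replay buffer."""
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    def append(self, transition):
        self.buffer.append(transition)
    def sample(self, k):
        return random.sample(list(self.buffer), min(k, len(self.buffer)))
    def __len__(self):
        return len(self.buffer)

def sneaky_train_sketch(memory_lock, collect_episode, learn_step,
                        num_episodes=10, temp_capacity=10**4):
    """Run hidden training episodes against a temporary buffer.

    The main (demonstration) buffer is never touched; the lock keeps other
    threads from mutating shared memory while the hidden training runs.
    """
    with memory_lock:
        temp_memory = SimpleReplayMemory(capacity=temp_capacity)
        for _ in range(num_episodes):
            for transition in collect_episode():  # play one episode off-screen
                temp_memory.append(transition)
            learn_step(temp_memory)               # update the agent from the temp buffer
        # temp_memory is discarded here, so hidden play never dilutes the
        # demonstration data in the main buffer

# Example wiring with dummy callables:
lock = Lock()
sneaky_train_sketch(lock,
                    collect_episode=lambda: [("s", 0, 0.0, "s_next", True)],
                    learn_step=lambda mem: None,
                    num_episodes=2)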

play.py (207 changed lines)

@@ -1,41 +1,33 @@
-import gym
+from gym.spaces.box import Box
 import pygame
-import sys
-import time
-import matplotlib
-import rltorch.memory as M
-try:
-    matplotlib.use('GTK3Agg')
-    import matplotlib.pyplot as plt
-except Exception:
-    pass
-import pyglet.window as pw
-from collections import deque
-from pygame.locals import HWSURFACE, DOUBLEBUF, RESIZABLE, VIDEORESIZE
-from threading import Thread, Event, Timer
+from pygame.locals import VIDEORESIZE
+from rltorch.memory import ReplayMemory
 
 class Play:
-    def __init__(self, env, action_selector, memory, agent, sneaky_env, transpose = True, fps = 30, zoom = None, keys_to_action = None):
+    def __init__(self, env, action_selector, memory, memory_lock, agent, sneaky_env, config):
         self.env = env
         self.action_selector = action_selector
-        self.transpose = transpose
-        self.fps = fps
-        self.zoom = zoom
-        self.keys_to_action = None
+        self.memory = memory
+        self.memory_lock = memory_lock
+        self.agent = agent
+        self.sneaky_env = sneaky_env
+        # Get relevant parameters from config or set sane defaults
+        self.transpose = config['transpose'] if 'transpose' in config else True
+        self.fps = config['fps'] if 'fps' in config else 30
+        self.zoom = config['zoom'] if 'zoom' in config else 1
+        self.keys_to_action = config['keys_to_action'] if 'keys_to_action' in config else None
+        self.seconds_play_per_state = config['seconds_play_per_state'] if 'seconds_play_per_state' in config else 30
+        self.num_sneaky_episodes = config['num_sneaky_episodes'] if 'num_sneaky_episodes' in config else 10
+        self.memory_size = config['memory_size'] if 'memory_size' in config else 10**4
+        self.replay_skip = config['replay_skip'] if 'replay_skip' in config else 0
+        # Initial values...
         self.video_size = (0, 0)
         self.pressed_keys = []
         self.screen = None
         self.relevant_keys = set()
         self.running = True
-        self.switch = Event()
         self.state = 0
-        self.paused = False
-        self.memory = memory
-        self.agent = agent
-        self.sneaky_env = sneaky_env
+        self.clock = pygame.time.Clock()
 
     def _display_arr(self, obs, screen, arr, video_size):
         if obs is not None:
@@ -49,6 +41,21 @@ class Play:
             pyg_img = pygame.transform.scale(pyg_img, video_size)
             screen.blit(pyg_img, (0,0))
 
+    def _process_common_pygame_events(self, event):
+        if event.type == pygame.QUIT:
+            self.running = False
+        elif event.type == VIDEORESIZE:
+            self.video_size = event.size
+            self.screen = pygame.display.set_mode(self.video_size)
+        elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
+            self.running = False
+        else:
+            # No event was matched here
+            return False
+        # One of the events above matched
+        return True
+
     def _human_play(self, obs):
         action = self.keys_to_action.get(tuple(sorted(self.pressed_keys)), 0)
         prev_obs = obs
@@ -57,20 +64,14 @@ class Play:
         # process pygame events
         for event in pygame.event.get():
-            # test events, set key states
-            if event.type == pygame.KEYDOWN:
+            if self._process_common_pygame_events(event):
+                continue
+            elif event.type == pygame.KEYDOWN:
                 if event.key in self.relevant_keys:
                     self.pressed_keys.append(event.key)
-                elif event.key == pygame.K_ESCAPE:
-                    self.running = False
             elif event.type == pygame.KEYUP:
                 if event.key in self.relevant_keys:
                     self.pressed_keys.remove(event.key)
-            elif event.type == pygame.QUIT:
-                self.running = False
-            elif event.type == VIDEORESIZE:
-                self.video_size = event.size
-                self.screen = pygame.display.set_mode(self.video_size)
 
         pygame.display.flip()
         self.clock.tick(self.fps)
@@ -84,13 +85,7 @@ class Play:
         # process pygame events
         for event in pygame.event.get():
-            if event.type == pygame.QUIT:
-                self.running = False
-            elif event.type == VIDEORESIZE:
-                self.video_size = event.size
-                self.screen = pygame.display.set_mode(self.video_size)
-            elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
-                self.running = False
+            self._process_common_pygame_events(event)
 
         pygame.display.flip()
         self.clock.tick(self.fps)
@@ -107,7 +102,7 @@ class Play:
         self.video_size = video_size
         self.screen = pygame.display.set_mode(self.video_size)
-        pygame.font.init() # For later text
+        pygame.font.init()
 
     def _setup_keys(self):
         if self.keys_to_action is None:
@@ -124,48 +119,59 @@ class Play:
         self.state = (self.state + 1) % 5
 
     def pause(self, text = ""):
-        self.paused = True
         myfont = pygame.font.SysFont('Comic Sans MS', 50)
         textsurface = myfont.render(text, False, (0, 0, 0))
         self.screen.blit(textsurface,(0,0))
+        # Process pygame events
         for event in pygame.event.get():
-            if event.type == pygame.QUIT:
-                self.running = False
-            elif event.type == VIDEORESIZE:
-                self.video_size = event.size
-                self.screen = pygame.display.set_mode(self.video_size)
+            if self._process_common_pygame_events(event):
+                continue
             elif event.type == pygame.KEYDOWN:
                 if event.key == pygame.K_SPACE:
                     self.pressed_keys.append(event.key)
-                elif event.key == pygame.K_ESCAPE:
-                    self.running = False
             elif event.type == pygame.KEYUP and event.key == pygame.K_SPACE:
                 self.pressed_keys.remove(event.key)
                 self._increment_state()
-                self.paused = False
 
         pygame.display.flip()
         self.clock.tick(self.fps)
 
     def sneaky_train(self):
+        self.memory_lock.acquire()
         # Backup memory
         backup_memory = self.memory
-        self.memory = M.ReplayMemory(capacity = 2000) # Another configurable parameter
-        EPISODES = 30 # Make this configurable
-        replay_skip = 4 # Make this configurable
-        for _ in range(EPISODES):
+        self.memory = ReplayMemory(capacity = self.memory_size)
+
+        # Do a standard RL algorithm process for a certain number of episodes
+        for i in range(self.num_sneaky_episodes):
+            print("Episode: %d / %d, Reward: " % (i + 1, self.num_sneaky_episodes), end = "")
+
+            # Reset all episode related variables
             prev_obs = self.sneaky_env.reset()
             done = False
             step = 0
+            total_reward = 0
             while not done:
                 action = self.action_selector.act(prev_obs)
                 obs, reward, done, _ = self.sneaky_env.step(action)
+                total_reward += reward
                 self.memory.append(prev_obs, action, reward, obs, done)
                 prev_obs = obs
                 step += 1
-                if step % replay_skip == 0:
+                if step % self.replay_skip == 0:
                     self.agent.learn()
+
+            # Finish the previous print with the total reward obtained during the episode
+            print(total_reward)
+
+        # Reset the memory back to the human demonstration / shown computer data
         self.memory = backup_memory
+        self.memory_lock.release()
 
+        # Thoughts:
         # It would be cool instead of throwing away all this new data, we keep just a sample of it
         # Not sure if i want all of it because then it'll drown out the expert demonstration data
@@ -178,55 +184,14 @@ class Play:
         Above code works also if env is wrapped, so it's particularly useful in
         verifying that the frame-level preprocessing does not render the game
         unplayable.
-
-        If you wish to plot real time statistics as you play, you can use
-        gym.utils.play.PlayPlot. Here's a sample code for plotting the reward
-        for last 5 second of gameplay.
-            def callback(obs_t, obs_tp1, rew, done, info):
-                return [rew,]
-            env_plotter = EnvPlotter(callback, 30 * 5, ["reward"])
-            env = gym.make("Pong-v3")
-            play(env, callback=env_plotter.callback)
-
-        Arguments
-        ---------
-        env: gym.Env
-            Environment to use for playing.
-        transpose: bool
-            If True the output of observation is transposed.
-            Defaults to true.
-        fps: int
-            Maximum number of steps of the environment to execute every second.
-            Defaults to 30.
-        zoom: float
-            Make screen edge this many times bigger
-        callback: lambda or None
-            Callback if a callback is provided it will be executed after
-            every step. It takes the following input:
-                obs_t: observation before performing action
-                obs_tp1: observation after performing action
-                action: action that was executed
-                rew: reward that was received
-                done: whether the environment is done or not
-                info: debug info
-        keys_to_action: dict: tuple(int) -> int or None
-            Mapping from keys pressed to action performed.
-            For example if pressed 'w' and space at the same time is supposed
-            to trigger action number 2 then key_to_action dict would look like this:
-                {
-                    # ...
-                    sorted(ord('w'), ord(' ')) -> 2
-                    # ...
-                }
-            If None, default key_to_action mapping for that env is used, if provided.
         """
         obs_s = self.env.unwrapped.observation_space
-        assert type(obs_s) == gym.spaces.box.Box
+        assert type(obs_s) == Box
         assert len(obs_s.shape) == 2 or (len(obs_s.shape) == 3 and obs_s.shape[2] in [1,3])
 
         self._setup_keys()
         self._setup_video()
-        self.clock = pygame.time.Clock()
 
         # States
         HUMAN_PLAY = 0
         SNEAKY_COMPUTER_PLAY = 1
@@ -234,41 +199,61 @@ class Play:
         COMPUTER_PLAY = 3
         TRANSITION2 = 4
 
         env_done = True
-        prev_obs = None
         obs = None
-        reward = 0
         i = 0
         while self.running:
+            # If the environment is done after a turn, reset it so we can keep going
             if env_done:
                 obs = self.env.reset()
                 env_done = False
 
             if self.state is HUMAN_PLAY:
-                prev_obs, action, reward, obs, env_done = self._human_play(obs)
+                _, _, _, obs, env_done = self._human_play(obs)
+            # The computer will train for a few episodes without showing to the user.
+            # Mainly to speed up the learning process a bit
             elif self.state is SNEAKY_COMPUTER_PLAY:
-                print("Sneaky Computer Time")
+                # Display "Training..." text to user
                 myfont = pygame.font.SysFont('Comic Sans MS', 50)
                 textsurface = myfont.render("Training....", False, (0, 0, 0))
                 self.screen.blit(textsurface,(0,0))
+                pygame.display.flip()
+
+                # Have the agent play a few rounds without showing to the user
                 self.sneaky_train()
+
+                # To take away training text
+                self._display_arr(obs, self.screen, self.env.unwrapped._get_obs(), video_size=self.video_size)
+                pygame.display.flip()
+
+                # Go to the next step immediately
                 self._increment_state()
             elif self.state is TRANSITION:
                 self.pause("Computers Turn! Press <Space> to Start")
             elif self.state is COMPUTER_PLAY:
-                prev_obs, action, reward, obs, env_done = self._computer_play(obs)
+                _, _, _, obs, env_done = self._computer_play(obs)
             elif self.state is TRANSITION2:
                 self.pause("Your Turn! Press <Space> to Start")
 
+            # Increment the timer if it's the human or shown computer's turn
             if self.state is COMPUTER_PLAY or self.state is HUMAN_PLAY:
-                self.memory.append(prev_obs, action, reward, obs, env_done)
                 i += 1
-                # Every 30 seconds...
-                if i % (self.fps * 30) == 0:
-                    print("Training...")
+                # Perform a quick learning process and increment the state after a certain time period has passed
+                if i % (self.fps * self.seconds_play_per_state) == 0:
+                    self.memory_lock.acquire()
+                    print("Number of transitions in buffer: ", len(self.memory))
                     self.agent.learn()
-                    print("PAUSING...")
+                    self.memory_lock.release()
                     self._increment_state()
                     i = 0
 
+        # Stop the pygame environment when done
         pygame.quit()
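As the diff above shows, Play now reads everything through a single config dict instead of keyword arguments. A minimal, self-contained sketch (not the repo's code) of how the new __init__ falls back to defaults; the default values below mirror the diff above:

# Sketch of the config fallback behaviour in the new Play.__init__ (illustrative only).
def resolve_play_config(config):
    defaults = {
        'transpose': True,
        'fps': 30,
        'zoom': 1,
        'keys_to_action': None,
        'seconds_play_per_state': 30,  # seconds of play before the turn switches
        'num_sneaky_episodes': 10,     # hidden training episodes per sneaky phase
        'memory_size': 10**4,          # capacity of the temporary sneaky buffer
        # Note: the diff's fallback of 0 would make `step % self.replay_skip`
        # divide by zero, so callers should always pass replay_skip explicitly.
        'replay_skip': 0,
    }
    resolved = dict(defaults)
    resolved.update({key: value for key, value in config.items() if key in defaults})
    return resolved

# The second changed file only overrides a few of these keys:
print(resolve_play_config({'zoom': 4, 'seconds_play_per_state': 60,
                           'num_sneaky_episodes': 20, 'replay_skip': 14,
                           'memory_size': 10**4}))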

[second changed file]

@@ -1,21 +1,31 @@
-import play
-import rltorch
-import rltorch.memory as M
-import torch
-import gym
+# Import Python Standard Libraries
+from threading import Thread, Lock
+from argparse import ArgumentParser
 from collections import namedtuple
 from datetime import datetime
+
+# Import Pytorch related packages for NNs
+from numpy import array as np_array
+from numpy import save as np_save
+import torch
+from torch.optim import Adam
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Import my custom RL library
+import rltorch
+from rltorch.memory import PrioritizedReplayMemory
 from rltorch.action_selector import EpsilonGreedySelector
 import rltorch.env as E
 import rltorch.network as rn
-import torch.nn as nn
-import torch.nn.functional as F
-import pickle
-import threading
-from time import sleep
-import argparse
-import sys
-import numpy as np
+
+# Import OpenAI gym and related packages
+from gym import make as makeEnv
+from gym import Wrapper as GymWrapper
+from gym.wrappers import Monitor as GymMonitor
+import play
 
 #
 ## Networks
@@ -73,56 +83,56 @@ class Value(nn.Module):
 Transition = namedtuple('Transition',
     ('state', 'action', 'reward', 'next_state', 'done'))
 
-class PlayClass(threading.Thread):
-    def __init__(self, env, action_selector, memory, agent, sneaky_env, fps = 60):
+class PlayClass(Thread):
+    def __init__(self, env, action_selector, memory, memory_lock, agent, sneaky_env, config):
         super(PlayClass, self).__init__()
-        self.env = env
-        self.fps = fps
-        self.play = play.Play(self.env, action_selector, memory, agent, sneaky_env, fps = fps, zoom = 4)
+        self.play = play.Play(env, action_selector, memory, memory_lock, agent, sneaky_env, config)
 
     def run(self):
         self.play.start()
 
-class Record(gym.Wrapper):
-    def __init__(self, env, memory, args, skipframes = 3):
-        gym.Wrapper.__init__(self, env)
-        self.memory_lock = threading.Lock()
+class Record(GymWrapper):
+    def __init__(self, env, memory, memory_lock, args):
+        GymWrapper.__init__(self, env)
+        self.memory_lock = memory_lock
         self.memory = memory
-        self.args = args
-        self.skipframes = skipframes
-        self.current_i = skipframes
+        self.skipframes = args['skip']
+        self.environment_name = args['environment_name']
+        self.logdir = args['logdir']
+        self.current_i = 0
 
     def reset(self):
         return self.env.reset()
 
     def step(self, action):
-        self.memory_lock.acquire()
         state = self.env.env._get_obs()
         next_state, reward, done, info = self.env.step(action)
-        if self.current_i <= 0:
-            self.memory.append(Transition(state, action, reward, next_state, done))
-            self.current_i = self.skipframes
-        else: self.current_i -= 1
-        self.memory_lock.release()
+        self.current_i += 1
+        # Don't add to memory until a certain number of frames is reached
+        if self.current_i % self.skipframes == 0:
+            self.memory_lock.acquire()
+            self.memory.append(state, action, reward, next_state, done)
+            self.memory_lock.release()
+            self.current_i = 0
         return next_state, reward, done, info
 
     def log_transitions(self):
         self.memory_lock.acquire()
         if len(self.memory) > 0:
-            basename = self.args['logdir'] + "/{}.{}".format(self.args['environment_name'], datetime.now().strftime("%Y-%m-%d-%H-%M-%s"))
+            basename = self.logdir + "/{}.{}".format(self.environment_name, datetime.now().strftime("%Y-%m-%d-%H-%M-%s"))
             print("Base Filename: ", basename)
             state, action, reward, next_state, done = zip(*self.memory)
-            np.save(basename + "-state.npy", np.array(state), allow_pickle = False)
-            np.save(basename + "-action.npy", np.array(action), allow_pickle = False)
-            np.save(basename + "-reward.npy", np.array(reward), allow_pickle = False)
-            np.save(basename + "-nextstate.npy", np.array(next_state), allow_pickle = False)
-            np.save(basename + "-done.npy", np.array(done), allow_pickle = False)
+            np_save(basename + "-state.npy", np_array(state), allow_pickle = False)
+            np_save(basename + "-action.npy", np_array(action), allow_pickle = False)
+            np_save(basename + "-reward.npy", np_array(reward), allow_pickle = False)
+            np_save(basename + "-nextstate.npy", np_array(next_state), allow_pickle = False)
+            np_save(basename + "-done.npy", np_array(done), allow_pickle = False)
             self.memory.clear()
         self.memory_lock.release()
 
 ## Parsing arguments
-parser = argparse.ArgumentParser(description="Play and log the environment")
+parser = ArgumentParser(description="Play and log the environment")
 parser.add_argument("--environment_name", type=str, help="The environment name in OpenAI gym to play.")
 parser.add_argument("--logdir", type=str, help="Directory to log video and (state, action, reward, next_state, done) in.")
 parser.add_argument("--skip", type=int, help="Number of frames to skip logging.")
@@ -130,14 +140,20 @@ parser.add_argument("--fps", type=int, help="Number of frames per second")
 parser.add_argument("--model", type=str, help = "The path location of the PyTorch model")
 args = vars(parser.parse_args())
 
+## Main configuration for script
 config = {}
 config['seed'] = 901
+config['seconds_play_per_state'] = 60
+config['zoom'] = 4
 config['environment_name'] = 'PongNoFrameskip-v4'
 config['learning_rate'] = 1e-4
 config['target_sync_tau'] = 1e-3
 config['discount_rate'] = 0.99
 config['exploration_rate'] = rltorch.scheduler.ExponentialScheduler(initial_value = 1, end_value = 0.1, iterations = 10**5)
-config['batch_size'] = 480
+# Number of episodes for the computer to train the agent without the human seeing
+config['num_sneaky_episodes'] = 20
+config['replay_skip'] = 14
+config['batch_size'] = 32 * (config['replay_skip'] + 1)
 config['disable_cuda'] = False
 config['memory_size'] = 10**4
 # Prioritized vs Random Sampling
@@ -151,63 +167,73 @@ config['prioritized_replay_sampling_priority'] = 0.6
 config['prioritized_replay_weight_importance'] = rltorch.scheduler.ExponentialScheduler(initial_value = 0.4, end_value = 1, iterations = 10**5)
 
+# Environment name and log directory are required, so show the help message and exit if not provided
 if args['environment_name'] is None or args['logdir'] is None:
     parser.print_help()
-    sys.exit(1)
+    exit(1)
 
+# Number of frames to skip when recording and fps can have sane defaults
 if args['skip'] is None:
     args['skip'] = 3
 if args['fps'] is None:
     args['fps'] = 30
 
-def wrap_preprocessing(env):
+def wrap_preprocessing(env, MaxAndSkipEnv = False):
+    env = E.NoopResetEnv(
+        E.EpisodicLifeEnv(env),
+        noop_max = 30
+    )
+    if MaxAndSkipEnv:
+        env = E.MaxAndSkipEnv(env, skip = 4)
     return E.ClippedRewardsWrapper(
         E.FrameStack(
             E.TorchWrap(
                 E.ProcessFrame84(
-                    E.FireResetEnv(
-                        # E.MaxAndSkipEnv(
-                        E.NoopResetEnv(
-                            E.EpisodicLifeEnv(env)
-                        , noop_max = 30)
-                        # , skip=4)
+                    E.FireResetEnv(env)
                 )
             )
-        ),
-        4)
+        , 4)
    )
 
-## Starting the game
-memory = []
-env = Record(gym.make(args['environment_name']), memory, args, skipframes = args['skip'])
+## Set up environment to be recorded and preprocessed
+memory = PrioritizedReplayMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'])
+memory_lock = Lock()
+env = Record(makeEnv(args['environment_name']), memory, memory_lock, args)
 
+# Bind record_env to current env so that we can reference log_transitions easier later
 record_env = env
-env = gym.wrappers.Monitor(env, args['logdir'], force=True)
+
+# Use native gym monitor to get video recording
+env = GymMonitor(env, args['logdir'], force=True)
+
+# Preprocess environment
 env = wrap_preprocessing(env)
 
-sneaky_env = wrap_preprocessing(gym.make(args['environment_name']))
+# Use a different environment for when the computer trains on the side so that the current game state isn't manipulated
+# Also use MaxAndSkipEnv to speed up processing
+sneaky_env = wrap_preprocessing(makeEnv(args['environment_name']), MaxAndSkipEnv = True)
 
+# Set seeds
 rltorch.set_seed(config['seed'])
+env.seed(config['seed'])
 
 device = torch.device("cuda:0" if torch.cuda.is_available() and not config['disable_cuda'] else "cpu")
 state_size = env.observation_space.shape[0]
 action_size = env.action_space.n
 
+# Set up the networks
 net = rn.Network(Value(state_size, action_size),
-                torch.optim.Adam, config, device = device)
+                Adam, config, device = device)
 target_net = rn.TargetNetwork(net, device = device)
 
+# Relevant components from RLTorch
 actor = EpsilonGreedySelector(net, action_size, device = device, epsilon = config['exploration_rate'])
-memory = M.PrioritizedReplayMemory(capacity = config['memory_size'], alpha = config['prioritized_replay_sampling_priority'])
 agent = rltorch.agents.DQNAgent(net, memory, config, target_net = target_net)
 
-env.seed(config['seed'])
-
-playThread = PlayClass(env, actor, memory, agent, sneaky_env, fps = args['fps'])
+# Pass all this information into the thread that will handle the game play and start
+playThread = PlayClass(env, actor, memory, memory_lock, agent, sneaky_env, config)
 playThread.start()
 
-## Logging portion
+# While the play thread is running, we'll periodically log transitions we've encountered
 while playThread.is_alive():
     playThread.join(60)
     print("Logging....", end = " ")