GymRecord/play_env.py

94 lines
3.1 KiB
Python
Raw Normal View History

2019-09-03 07:16:26 -04:00
import play
import gym
from collections import namedtuple
from datetime import datetime
import pickle
import threading
from time import sleep
import argparse
import sys
import numpy as np
# One recorded environment step: the observation before the action, the action
# taken, the reward received, the observation afterwards, and the episode-done flag.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'reward', 'next_state', 'done'],
)
class PlayClass(threading.Thread):
    """Thread that runs the interactive ``play.play`` loop for an environment.

    Running the game loop on its own thread lets the caller keep doing other
    work (such as periodically flushing recorded transitions) while it runs.

    Args:
        env: Environment object handed to ``play.play``.
        fps: Frames-per-second value passed to ``play.play``.
        zoom: Zoom factor passed to ``play.play``. Defaults to 4, the value
            that was previously hard-coded.
    """

    def __init__(self, env, fps = 60, zoom = 4):
        super(PlayClass, self).__init__()
        self.env = env
        self.fps = fps
        self.zoom = zoom

    def run(self):
        """Thread entry point: call ``play.play`` with the stored settings."""
        play.play(self.env, fps = self.fps, zoom = self.zoom)
class Record(gym.Wrapper):
    """Gym wrapper that stores a subsample of the transitions stepped through it.

    After every ``skipframes`` skipped steps, one step is appended to
    ``memory`` as a ``Transition``. ``log_transitions`` writes the buffered
    transitions to ``.npy`` files and empties the buffer. ``step`` and
    ``log_transitions`` both hold ``memory_lock`` while they run, so they can
    be called from different threads.

    Args:
        env: The environment to wrap.
        memory: Mutable sequence (e.g. a list) that receives the transitions;
            it must support ``append``, ``len``, iteration and ``clear``.
        args: Mapping providing the ``'logdir'`` and ``'environment_name'``
            entries used to build output file names.
        skipframes: Number of steps to skip between two recorded steps.
    """

    def __init__(self, env, memory, args, skipframes = 3):
        gym.Wrapper.__init__(self, env)
        self.memory_lock = threading.Lock()
        self.memory = memory
        self.args = args
        self.skipframes = skipframes
        # Countdown to the next recorded step; starts full, so the first
        # `skipframes` steps are not recorded.
        self.current_i = skipframes

    def reset(self, **kwargs):
        """Reset the wrapped environment, forwarding any keyword arguments."""
        return self.env.reset(**kwargs)

    def step(self, action):
        """Step the wrapped environment and possibly record the transition.

        Args:
            action: Action forwarded to the wrapped environment.

        Returns:
            The ``(next_state, reward, done, info)`` tuple returned by the
            wrapped environment's ``step``.
        """
        # `with` guarantees the lock is released even if the environment
        # raises; otherwise log_transitions() would block forever.
        with self.memory_lock:
            # NOTE(review): assumes the wrapped env has exactly one more
            # wrapper layer and that the inner env implements `_get_obs()`
            # -- confirm for the environments this is used with.
            state = self.env.env._get_obs()
            next_state, reward, done, info = self.env.step(action)
            if self.current_i <= 0:
                self.memory.append(Transition(state, action, reward, next_state, done))
                self.current_i = self.skipframes
            else:
                self.current_i -= 1
        return next_state, reward, done, info

    def log_transitions(self):
        """Write buffered transitions to ``.npy`` files and clear the buffer.

        Does nothing when the buffer is empty. Otherwise five files are
        written under ``args['logdir']``, sharing a base name made of the
        environment name and the current timestamp, with the suffixes
        ``-state``, ``-action``, ``-reward``, ``-nextstate`` and ``-done``.
        """
        # Held for the whole write so step() cannot append while the buffer
        # is being saved and cleared; released even if saving fails.
        with self.memory_lock:
            if len(self.memory) > 0:
                # %S is the seconds field; the previous %s is a non-portable
                # platform extension (epoch seconds where supported at all).
                basename = self.args['logdir'] + "/{}.{}".format(self.args['environment_name'], datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
                print("Base Filename: ", basename)
                # Transpose the list of transitions into one sequence per field.
                state, action, reward, next_state, done = zip(*self.memory)
                np.save(basename + "-state.npy", np.array(state), allow_pickle = False)
                np.save(basename + "-action.npy", np.array(action), allow_pickle = False)
                np.save(basename + "-reward.npy", np.array(reward), allow_pickle = False)
                np.save(basename + "-nextstate.npy", np.array(next_state), allow_pickle = False)
                np.save(basename + "-done.npy", np.array(done), allow_pickle = False)
                self.memory.clear()
## Parsing arguments
parser = argparse.ArgumentParser(description="Play and log the environment")
parser.add_argument("--environment_name", type=str, help="The environment name in OpenAI gym to play.")
parser.add_argument("--logdir", type=str, help="Directory to log video and (state, action, reward, next_state, done) in.")
parser.add_argument("--skip", type=int, default=3, help="Number of frames to skip logging.")
parser.add_argument("--fps", type=int, default=30, help="Number of frames per second")
args = vars(parser.parse_args())

# Both of these are mandatory: show the full help and exit with status 1
# when either one is missing.
if args['environment_name'] is None or args['logdir'] is None:
    parser.print_help()
    sys.exit(1)

## Starting the game
memory = []
# Keep a direct handle on the recording wrapper: logging below must not depend
# on the outer Monitor wrapper forwarding unknown attributes to it.
recorder = Record(gym.make(args['environment_name']), memory, args, skipframes = args['skip'])
env = gym.wrappers.Monitor(recorder, args['logdir'], force=True)
playThread = PlayClass(env, args['fps'])
playThread.start()

## Logging portion
# Flush the recorded transitions to disk once every 60 seconds for as long as
# the play thread is running.
while playThread.is_alive():
    playThread.join(60)
    print("Logging....", end = " ")
    recorder.log_transitions()

# Save what's remaining after process died
recorder.log_transitions()