#!/usr/bin/env python
"""Play a gym environment interactively and log the transitions to disk.

The game runs in a background thread while the main thread periodically
dumps the recorded (state, action, reward, next_state, done) tuples as
``.npy`` files into the chosen log directory.
"""
import argparse
import pickle
import sys
import threading
from collections import namedtuple
from datetime import datetime
from time import sleep

import gym
import numpy as np

import play

Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state', 'done'))


class PlayClass(threading.Thread):
    """Thread that runs ``play.play`` on an environment."""

    def __init__(self, env, fps=60):
        """Store the environment to play and the frame rate to play it at."""
        super().__init__()
        self.env = env
        self.fps = fps

    def run(self):
        """Thread body: hand the environment to ``play.play``."""
        play.play(self.env, fps=self.fps, zoom=4)


class Record(gym.Wrapper):
    """Wrapper that appends transitions of the wrapped env to ``memory``.

    After each recorded transition the next ``skipframes`` steps are not
    recorded. ``memory`` is shared with the caller and is guarded by
    ``memory_lock`` because ``step`` and ``log_transitions`` are called
    from different threads by this script.
    """

    def __init__(self, env, memory, args, skipframes=3):
        """Wrap ``env``.

        Args:
            env: The environment to wrap.
            memory: List that recorded ``Transition`` tuples are appended to.
            args: Mapping providing the ``'logdir'`` and
                ``'environment_name'`` keys used to build output file names.
            skipframes: Number of steps to skip between recorded transitions.
        """
        super().__init__(env)
        self.memory_lock = threading.Lock()
        self.memory = memory
        self.args = args
        self.skipframes = skipframes
        # Countdown until the next transition is recorded.
        self.current_i = skipframes

    def reset(self, **kwargs):
        """Reset the wrapped environment and return its result."""
        return self.env.reset(**kwargs)

    def step(self, action):
        """Step the wrapped environment, recording the transition when due.

        Returns:
            The ``(next_state, reward, done, info)`` tuple produced by the
            wrapped environment.
        """
        # ``with`` guarantees the lock is released even if the wrapped env
        # raises; otherwise the logging thread would block forever.
        with self.memory_lock:
            # NOTE(review): assumes the env one level below the wrapped one
            # exposes ``_get_obs`` -- confirm for the environments in use.
            state = self.env.env._get_obs()
            next_state, reward, done, info = self.env.step(action)
            if self.current_i <= 0:
                self.memory.append(
                    Transition(state, action, reward, next_state, done))
                self.current_i = self.skipframes
            else:
                self.current_i -= 1
        return next_state, reward, done, info

    def log_transitions(self):
        """Write all recorded transitions to ``.npy`` files, then clear them.

        One file per transition field is written under ``args['logdir']``,
        named after the environment and the current timestamp. Nothing is
        written when no transitions have been recorded.
        """
        # ``with`` releases the lock even when saving fails, so the game
        # thread is never left waiting on a lock that is still held.
        with self.memory_lock:
            if not self.memory:
                return
            # ``%S`` (seconds) replaces the original ``%s``, which is a
            # non-portable platform extension rather than a documented
            # strftime directive.
            timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            basename = self.args['logdir'] + "/{}.{}".format(
                self.args['environment_name'], timestamp)
            print("Base Filename: ", basename)
            state, action, reward, next_state, done = zip(*self.memory)
            columns = (
                ("-state.npy", state),
                ("-action.npy", action),
                ("-reward.npy", reward),
                ("-nextstate.npy", next_state),
                ("-done.npy", done),
            )
            for suffix, values in columns:
                np.save(basename + suffix, np.array(values),
                        allow_pickle=False)
            # Clear in place: the list is shared with the caller.
            self.memory.clear()


## Parsing arguments
parser = argparse.ArgumentParser(description="Play and log the environment")
parser.add_argument("--environment_name", type=str,
                    help="The environment name in OpenAI gym to play.")
parser.add_argument("--logdir", type=str,
                    help="Directory to log video and (state, action, reward, next_state, done) in.")
parser.add_argument("--skip", type=int, default=3,
                    help="Number of frames to skip logging.")
parser.add_argument("--fps", type=int, default=30,
                    help="Number of frames per second")
args = vars(parser.parse_args())

if args['environment_name'] is None or args['logdir'] is None:
    parser.print_help()
    sys.exit(1)

## Starting the game
memory = []
# Keep a direct handle on the recording wrapper so logging does not depend
# on the outer Monitor wrapper forwarding ``log_transitions``.
recorder = Record(gym.make(args['environment_name']), memory, args,
                  skipframes=args['skip'])
env = gym.wrappers.Monitor(recorder, args['logdir'], force=True)

playThread = PlayClass(env, args['fps'])
playThread.start()

## Logging portion
while playThread.is_alive():
    # Wake up at most every 60 seconds to flush what has been recorded.
    playThread.join(60)
    print("Logging....", end=" ")
    recorder.log_transitions()

# Save what's remaining after process died
recorder.log_transitions()