GymRecord/gymrecord

95 lines
3.1 KiB
Python

#!/usr/bin/env python
import play
import gym
from collections import namedtuple
from datetime import datetime
import pickle
import threading
from time import sleep
import argparse
import sys
import numpy as np
# A single recorded environment step: the observation seen before acting,
# the action taken, the reward returned, the observation returned by step(),
# and the done flag returned by step().
Transition = namedtuple(
    "Transition",
    ["state", "action", "reward", "next_state", "done"],
)
class PlayClass(threading.Thread):
    """Thread that runs the interactive ``play.play`` loop on an environment.

    Running the game on its own thread leaves the main thread free to
    periodically write recorded transitions to disk while the user plays.
    """

    def __init__(self, env, fps=60, zoom=4):
        """Store the environment and display settings used by ``run``.

        Args:
            env: Environment passed to ``play.play``.
            fps: Frames per second passed to ``play.play``.
            zoom: Zoom factor passed to ``play.play``; defaults to 4, the
                value that used to be hard-coded.
        """
        super().__init__()
        self.env = env
        self.fps = fps
        self.zoom = zoom

    def run(self):
        """Run the interactive game loop; the thread ends when it returns."""
        play.play(self.env, fps=self.fps, zoom=self.zoom)
class Record(gym.Wrapper):
    """Gym wrapper that samples steps into a shared list of ``Transition``s.

    Every ``step`` call is forwarded to the wrapped env. Every
    (skipframes + 1)-th call is also appended to ``memory`` (every call when
    ``skipframes <= 0``). ``log_transitions`` drains ``memory`` into ``.npy``
    files. In this script ``step`` runs on the game thread and
    ``log_transitions`` runs on the main thread, so ``memory_lock`` guards
    the shared list.
    """

    def __init__(self, env, memory, args, skipframes=3):
        """Wrap ``env`` and record sampled transitions into ``memory``.

        Args:
            env: Environment to wrap.
            memory: List that sampled transitions are appended to. It is
                emptied in place whenever they are written to disk.
            args: Mapping providing ``'logdir'`` (output directory) and
                ``'environment_name'`` (prefix of output file names).
            skipframes: Number of steps skipped between two recorded steps.
        """
        super().__init__(env)
        self.memory_lock = threading.Lock()
        self.memory = memory
        self.args = args
        self.skipframes = skipframes
        # Countdown of steps left to skip before the next recorded step.
        self.current_i = skipframes

    def reset(self, **kwargs):
        """Reset the wrapped env, forwarding any keyword arguments."""
        return self.env.reset(**kwargs)

    def step(self, action):
        """Step the wrapped env and record the step if it is due.

        Returns the wrapped env's ``(next_state, reward, done, info)``
        unchanged.
        """
        # step() only returns the post-action observation, so read the
        # pre-action one from the base env. ``unwrapped`` reaches it no matter
        # how many wrappers gym.make stacked on top.
        # NOTE(review): _get_obs is a private API that the base env is assumed
        # to provide (e.g. gym's Atari envs) -- confirm for other envs.
        state = self.env.unwrapped._get_obs()
        next_state, reward, done, info = self.env.step(action)
        if self.current_i <= 0:
            # `with` guarantees release, so a failure can never leave the
            # logging thread blocked on a lock that is still held.
            with self.memory_lock:
                self.memory.append(
                    Transition(state, action, reward, next_state, done))
            self.current_i = self.skipframes
        else:
            self.current_i -= 1
        return next_state, reward, done, info

    def log_transitions(self):
        """Write all pending transitions to disk and empty ``memory``.

        Each Transition field is saved to its own file named
        ``<logdir>/<environment_name>.<YYYY-mm-dd-HH-MM-SS>-<field>.npy``.
        Does nothing when no transitions are pending. If writing fails, the
        transitions are put back at the front of ``memory`` and the error is
        re-raised, so a later call can retry.
        """
        # Take the pending transitions under the lock, but do the slow disk
        # writes outside it so the game thread's step() never waits on I/O.
        with self.memory_lock:
            pending = list(self.memory)
            self.memory.clear()
        if not pending:
            return
        try:
            self._save_transitions(pending)
        except BaseException:
            # Includes KeyboardInterrupt: keep the data for a later retry.
            with self.memory_lock:
                self.memory[:0] = pending
            raise

    def _save_transitions(self, transitions):
        """Save ``transitions`` as one ``.npy`` array per Transition field."""
        # %S is seconds; lowercase %s is a platform-specific (glibc)
        # epoch-seconds extension that fails or differs on other platforms.
        timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        basename = self.args['logdir'] + "/{}.{}".format(
            self.args['environment_name'], timestamp)
        print("Base Filename: ", basename)
        state, action, reward, next_state, done = zip(*transitions)
        for suffix, values in (
            ("-state.npy", state),
            ("-action.npy", action),
            ("-reward.npy", reward),
            ("-nextstate.npy", next_state),
            ("-done.npy", done),
        ):
            np.save(basename + suffix, np.array(values), allow_pickle=False)
## Parsing arguments
parser = argparse.ArgumentParser(description="Play and log the environment")
parser.add_argument("--environment_name", type=str,
                    help="The environment name in OpenAI gym to play.")
parser.add_argument("--logdir", type=str,
                    help="Directory to log video and (state, action, reward, next_state, done) in.")
# Optional settings get their defaults from argparse itself instead of
# being patched into the parsed dict afterwards.
parser.add_argument("--skip", type=int, default=3,
                    help="Number of frames to skip logging.")
parser.add_argument("--fps", type=int, default=30,
                    help="Number of frames per second")
args = vars(parser.parse_args())
# Both options are mandatory. They are checked by hand (not required=True)
# so the full help text is printed and the exit status stays 1.
if args['environment_name'] is None or args['logdir'] is None:
    parser.print_help()
    sys.exit(1)
## Starting the game
memory = []
env = Record(gym.make(args['environment_name']), memory, args,
             skipframes=args['skip'])
# gym's Monitor sits outside Record and writes its own output into logdir.
env = gym.wrappers.Monitor(env, args['logdir'], force=True)
playThread = PlayClass(env, args['fps'])
playThread.start()

## Logging portion
# Flush recorded transitions to disk about once a minute while the game runs.
try:
    while playThread.is_alive():
        playThread.join(60)
        print("Logging....", end=" ", flush=True)
        env.log_transitions()
finally:
    # Save what remains after the game ends, and also when the main thread
    # is interrupted (e.g. Ctrl-C), so recorded transitions are not lost.
    env.log_transitions()
# The game thread has finished; close the env so Monitor can finalize its files.
env.close()