diff --git a/gymclient.py b/gymclient.py index 71feb1a..f75277b 100644 --- a/gymclient.py +++ b/gymclient.py @@ -18,8 +18,9 @@ class Environment: def get_environment_name(self): r = requests.get(self.server + "/environment") return r.text - def get_state(self): - r = requests.get(self.server + "/state") + def get_state(self, preprocess = False): + parameter = "?preprocess" if preprocess else "" + r = requests.get(self.server + "/state" + parameter) return pickle.loads(r.content) def get_reward(self): r = requests.get(self.server + "/reward") @@ -56,11 +57,13 @@ class Environment: ## # Common API ## - def reset(self): - r = requests.get(self.server + "/reset") + def reset(self, preprocess = False): + parameter = "?preprocess" if preprocess else "" + r = requests.get(self.server + "/reset" + parameter) return pickle.loads(r.content) - def step(self, action): - r = requests.post(self.server + "/action", data={'id': action}) + def step(self, action, preprocess = False): + parameter = "?preprocess" if preprocess else "" + r = requests.post(self.server + "/action" + parameter, data={'id': action}) content = pickle.loads(r.content) return content['state'], content['reward'], content['done'], content['info'] diff --git a/gymserver.py b/gymserver.py index cc88f7f..ab96873 100644 --- a/gymserver.py +++ b/gymserver.py @@ -13,27 +13,54 @@ log.setLevel(logging.ERROR) ## # OpenAI Gym State ## -# environment_name = "Acrobot-v1" -environment_name = "Pong-v0" -env = gym.make(environment_name) +class Environment: + def __init__(self, environment_name): + self.sim = gym.make(environment_name) + self.environment_name = environment_name + self.reset() + def step(self, action): + # [TODO] Check to see if 'action' is valid + self.state, self.reward, self.done, self.info = self.sim.step(action) + self.score += self.reward + def reset(self): + self.state = self.sim.reset() + self.reward = 0 + self.score = 0 + self.done = False + self.info = {} + def preprocess(self, state): + raise NotImplementedError + def get_state(self, preprocess = False, pickleobj = False): + state = self.state + if preprocess: + state = self.preprocess(state) + if pickleobj: + state = pickle.dumps(state) + return state -# Observations to release to agent -state = env.reset() -reward = 0 -score = 0 -done = False -info = {} ## -# Helper Functions +# Pong Specific Environment Information ## -# [TODO] Evaluate whether pickling is the right option here -def pickle_state(): - global state - return pickle.dumps(state) +import cv2 +class PongEnv(Environment): + def __init__(self): + super(PongEnv, self).__init__("PongNoFrameskip-v4") + def preprocess(self, state): + # Grayscale + frame = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY) + # Crop irrelevant parts + frame = frame[34:194, 15:145] # Crops to shape (160, 130) + # Downsample + frame = cv2.resize(frame, (80, 80), interpolation=cv2.INTER_AREA) + # Normalize + frame = frame / 255 + return frame +env = PongEnv() + ## # Flask Environment ## @@ -41,100 +68,93 @@ app = Flask(__name__) @app.route('/environment', methods=['GET']) def get_env(): - global env, environment_name + global env if request.args.get('shape') is not None: shape = {} - shape['observation'] = env.observation_space.shape - shape['action'] = env.action_space.n + shape['observation'] = env.sim.observation_space.shape + shape['action'] = env.sim.action_space.n return json.dumps(shape) - return environment_name + return env.environment_name @app.route('/gym', methods=['GET']) def get_extra_data(): global env data = {} if request.args.get('action_space') is not None: - data['action_space'] = env.action_space + data['action_space'] = env.sim.action_space if request.args.get('observation_space') is not None: - data['observation_space'] = env.observation_space + data['observation_space'] = env.sim.observation_space if request.args.get('reward_range') is not None: - data['reward_range'] = env.reward_range + data['reward_range'] = env.sim.reward_range if request.args.get('metadata') is not None: - data['metadata'] = env.metadata + data['metadata'] = env.sim.metadata if request.args.get('action_meanings') is not None: - data['action_meanings'] = env.unwrapped.get_action_meanings() + data['action_meanings'] = env.sim.unwrapped.get_action_meanings() return pickle.dumps(data) @app.route('/action_space', methods=['GET']) def get_action_space(): global env - return pickle.dumps(env.action_space) + return pickle.dumps(env.sim.action_space) @app.route('/observation_space', methods=['GET']) def get_observation_space(): global env - return pickle.dumps(env.observation_space) + return pickle.dumps(env.sim.observation_space) @app.route('/reward_range', methods=['GET']) def get_reward_range(): global env - return pickle.dumps(env.reward_range) + return pickle.dumps(env.sim.reward_range) @app.route('/metadata', methods=['GET']) def get_metadata(): global env - return pickle.dumps(env.metadata) + return pickle.dumps(env.sim.metadata) @app.route('/action_meanings', methods=['GET']) def get_action_meanings(): global env - return pickle.dumps(env.unwrapped.get_action_meanings()) + return pickle.dumps(env.sim.unwrapped.get_action_meanings()) @app.route('/state', methods=['GET']) def get_state(): - return pickle_state() + return env.get_state(pickleobj = True, preprocess = request.args.get('preprocess') is not None) @app.route('/reward', methods=['GET']) def get_reward(): - global score, reward + global env if request.args.get('all') is not None: - return str(score) + return str(env.score) else: - return str(reward) + return str(env.reward) @app.route('/done', methods=['GET']) def is_done(): - global done - return str(done) + global env + return str(env.done) @app.route('/info', methods=['GET']) def get_info(): - global info - return json.dumps(info) + global env + return json.dumps(env.info) @app.route('/action', methods=['POST']) def perform_action(): - global state, reward, done, info, score + global env action = int(request.form['id']) - - # [TODO] Check to see if 'action' is valid - state, reward, done, info = env.step(action) - score += reward + env.step(action) content = {} - content['state'] = state - content['reward'] = reward - content['done'] = done - content['info'] = info + content['state'] = env.get_state(preprocess = request.args.get('preprocess') is not None) + content['reward'] = env.reward + content['done'] = env.done + content['info'] = env.info return pickle.dumps(content) @app.route('/reset') def reset_env(): - global env, state, reward, done, info, score - state = env.reset() - reward = 0 - done = False - info = {} - score = 0 - return pickle_state() + global env + env.reset() + return env.get_state(pickleobj = True, preprocess = request.args.get('preprocess') is not None)