diff --git a/gymclient.py b/gymclient.py index 0c4d8fb..71feb1a 100644 --- a/gymclient.py +++ b/gymclient.py @@ -9,7 +9,8 @@ class Environment: self.address = address self.port = port protocol = "https://" if ssl else "http://" - self.server = protocol + address + ":" + str(port) + self.server = protocol + address + ":" + str(port) + self.observation_space, self.action_space, self.reward_range, self.metadata, self.action_meanings = self.get_initial_metadata() ## # Helper Functions @@ -32,6 +33,25 @@ class Environment: def get_info(self): r = requests.get(self.server + "/info") return r.json() + def get_observation_space(self): + r = requests.get(self.server + '/observation_space') + return pickle.loads(r.content) + def get_action_space(self): + r = requests.get(self.server + '/action_space') + return pickle.loads(r.content) + def get_reward_range(self): + r = requests.get(self.server + '/reward_range') + return pickle.loads(r.content) + def get_metadata(self): + r = requests.get(self.server + '/metadata') + return pickle.loads(r.content) + def get_action_meanings(self): + r = requests.get(self.server + '/action_meanings') + return pickle.loads(r.content) + def get_initial_metadata(self): + r = requests.get(self.server + '/gym?observation_space&action_space&reward_range&metadata&action_meanings') + content = pickle.loads(r.content) + return content['observation_space'], content['action_space'], content['reward_range'], content['metadata'], content['action_meanings'] ## # Common API @@ -41,7 +61,7 @@ class Environment: return pickle.loads(r.content) def step(self, action): r = requests.post(self.server + "/action", data={'id': action}) - content = r.json() - return self.get_state(), float(content['reward']), content['done'] == "True", content['info'] + content = pickle.loads(r.content) + return content['state'], content['reward'], content['done'], content['info'] # env = Environment("127.0.0.1", 5000) \ No newline at end of file diff --git a/gymserver.py b/gymserver.py index 7a74370..cc88f7f 100644 --- a/gymserver.py +++ b/gymserver.py @@ -5,11 +5,14 @@ from flask import request import pickle import json +# Make it so that it doesn't log every HTTP request +import logging +log = logging.getLogger('werkzeug') +log.setLevel(logging.ERROR) ## # OpenAI Gym State ## -# environment_name = sys.argv[1] # environment_name = "Acrobot-v1" environment_name = "Pong-v0" env = gym.make(environment_name) @@ -46,6 +49,47 @@ def get_env(): return json.dumps(shape) return environment_name +@app.route('/gym', methods=['GET']) +def get_extra_data(): + global env + data = {} + if request.args.get('action_space') is not None: + data['action_space'] = env.action_space + if request.args.get('observation_space') is not None: + data['observation_space'] = env.observation_space + if request.args.get('reward_range') is not None: + data['reward_range'] = env.reward_range + if request.args.get('metadata') is not None: + data['metadata'] = env.metadata + if request.args.get('action_meanings') is not None: + data['action_meanings'] = env.unwrapped.get_action_meanings() + return pickle.dumps(data) + +@app.route('/action_space', methods=['GET']) +def get_action_space(): + global env + return pickle.dumps(env.action_space) + +@app.route('/observation_space', methods=['GET']) +def get_observation_space(): + global env + return pickle.dumps(env.observation_space) + +@app.route('/reward_range', methods=['GET']) +def get_reward_range(): + global env + return pickle.dumps(env.reward_range) + +@app.route('/metadata', methods=['GET']) +def get_metadata(): + global env + return pickle.dumps(env.metadata) + +@app.route('/action_meanings', methods=['GET']) +def get_action_meanings(): + global env + return pickle.dumps(env.unwrapped.get_action_meanings()) + @app.route('/state', methods=['GET']) def get_state(): return pickle_state() @@ -78,10 +122,11 @@ def perform_action(): score += reward content = {} + content['state'] = state content['reward'] = reward content['done'] = done content['info'] = info - return json.dumps(content) + return pickle.dumps(content) @app.route('/reset') def reset_env():