From 52f4d4cbf3e36bf38e61acb020aa265c79dcf77f Mon Sep 17 00:00:00 2001
From: Brandon Rozek
Date: Sat, 15 Jun 2019 10:23:18 -0400
Subject: [PATCH] Initial Commit for ZeroMQ implementation

---
Blob ids on the "index" lines below are carried over from the original
patch and no longer match the revised file contents.

 gymclient.py     |  75 +++++++++++++++++++++
 gymserver.py     | 172 +++++++++++++++++++++++++++++++++++++++++++++++
 start_servers.sh |   6 ++
 3 files changed, 253 insertions(+)
 create mode 100644 gymclient.py
 create mode 100644 gymserver.py
 create mode 100755 start_servers.sh

diff --git a/gymclient.py b/gymclient.py
new file mode 100644
index 0000000..bdac868
--- /dev/null
+++ b/gymclient.py
@@ -0,0 +1,75 @@
+import zmq
+import numpy
+
+
+# [TODO] Error handling for if server is down
+class Environment:
+    """Client-side proxy for a gym environment hosted by gymserver.py.
+
+    Each method performs one request/reply round trip on a ZeroMQ REQ
+    socket: the request goes out as JSON and the reply is read back with
+    recv_pyobj(), which unpickles it, so only connect to trusted servers.
+
+    gymserver.py sends its error replies with send_string(), not as
+    pickles, so recv_pyobj() is expected to raise on them (unverified).
+    """
+    def __init__(self, address, port):
+        """Connect to tcp://address:port and cache the static metadata."""
+        self.context = zmq.Context()
+        self.socket = self.context.socket(zmq.REQ)
+        self.socket.connect("tcp://%s:%s" % (address, port))
+        self.address = address
+        self.port = port
+        self.observation_space, self.action_space, self.reward_range, self.metadata, self.action_meanings = self.get_initial_metadata()
+
+    ##
+    # Helper Functions
+    ##
+    def _query(self, items, **extra):
+        """Send one 'query' request for `items` and return the reply dict."""
+        request = {'method': 'query', 'items': list(items)}
+        request.update(extra)
+        self.socket.send_json(request)
+        return self.socket.recv_pyobj()
+    def get_environment_name(self):
+        return self._query(['environment_name'])['environment_name']
+    def get_state(self, preprocess = False):
+        return self._query(['state'], preprocess=preprocess)['state']
+    def get_reward(self):
+        return self._query(['reward'])['reward']
+    def get_score(self):
+        return self._query(['cumulative_reward'])['cumulative_reward']
+    def get_done(self):
+        return self._query(['done'])['done']
+    def get_info(self):
+        return self._query(['info'])['info']
+    def get_observation_space(self):
+        return self._query(['observation_space'])['observation_space']
+    def get_action_space(self):
+        return self._query(['action_space'])['action_space']
+    def get_reward_range(self):
+        return self._query(['reward_range'])['reward_range']
+    def get_metadata(self):
+        return self._query(['metadata'])['metadata']
+    def get_action_meanings(self):
+        return self._query(['action_meanings'])['action_meanings']
+    def get_initial_metadata(self):
+        """Fetch all static environment metadata in a single round trip."""
+        keys = ['observation_space', 'action_space', 'reward_range', 'metadata', 'action_meanings']
+        content = self._query(keys)
+        return tuple(content[key] for key in keys)
+
+    ##
+    # Common API
+    ##
+    def reset(self, preprocess = False):
+        """Reset the remote environment and return its initial state."""
+        self.socket.send_json({'method':'set', 'type':'reset', 'preprocess':preprocess})
+        return self.socket.recv_pyobj()
+    def step(self, action_id, preprocess = False):
+        """Apply `action_id` remotely; return (state, reward, done, info)."""
+        self.socket.send_json({'method':'set', 'type':'action', 'id': action_id, 'preprocess':preprocess})
+        content = self.socket.recv_pyobj()
+        return content['state'], content['reward'], content['done'], content['info']
+
+# env = Environment("127.0.0.1", 5000)
diff --git a/gymserver.py b/gymserver.py
new file mode 100644
index 0000000..c689f2f
--- /dev/null
+++ b/gymserver.py
@@ -0,0 +1,172 @@
+import sys
+import gym
+import zmq
+import json
+
+##
+# OpenAI Gym State
+##
+class Environment:
+    """Wraps one gym simulation and remembers the latest step result."""
+    def __init__(self, environment_name):
+        self.sim = gym.make(environment_name)
+        self.environment_name = environment_name
+        self.reset()
+    def step(self, action):
+        """Advance the simulation by one action and add its reward to score."""
+        # [TODO] Check to see if 'action' is valid
+        self.state, self.reward, self.done, self.info = self.sim.step(action)
+        self.score += self.reward
+    def reset(self):
+        """Restart the simulation and clear reward, score, done and info."""
+        self.state = self.sim.reset()
+        self.reward = 0
+        self.score = 0
+        self.done = False
+        self.info = {}
+    def preprocess(self, state):
+        """Subclass hook that turns a raw state into its processed form."""
+        raise NotImplementedError
+    def get_state(self, preprocess = False):
+        """Return the current state, passed through preprocess() on request."""
+        state = self.state
+        if preprocess:
+            state = self.preprocess(state)
+        return state
+
+
+##
+# Pong Specific Environment Information
+##
+import cv2
+class PongEnv(Environment):
+    """Environment bound to "PongNoFrameskip-v4" with frame preprocessing."""
+    def __init__(self):
+        super(PongEnv, self).__init__("PongNoFrameskip-v4")
+    def preprocess(self, state):
+        """Grayscale, crop, resize to 80x80 and divide by 255."""
+        # Grayscale
+        frame = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY)
+        # Crop irrelevant parts
+        frame = frame[34:194, 15:145] # Crops to shape (160, 130)
+        # Downsample
+        frame = cv2.resize(frame, (80, 80), interpolation=cv2.INTER_AREA)
+        # Normalize
+        frame = frame / 255
+        return frame
+
+#
+## Response Methods
+#
+# A ZeroMQ REP socket has to send exactly one reply per request, so each
+# handler returns right after replying, error paths included.
+
+def _wants_preprocess(msg):
+    """Return True when the request carries a truthy 'preprocess' flag."""
+    # .get() keeps requests that omit the key from raising KeyError.
+    return bool(msg.get('preprocess'))
+
+def perform_action(msg, env, socket):
+    """Step the environment with msg['id'] and reply with the step result."""
+    if 'id' not in msg:
+        socket.send_string("ERROR: 'id' not in message when type is set to 'action'.")
+        return
+    try:
+        action = int(msg['id'])
+    except (TypeError, ValueError):
+        socket.send_string("ERROR: 'id' must be an integer.")
+        return
+    env.step(action)
+
+    content = {}
+    content['state'] = env.get_state(preprocess = _wants_preprocess(msg))
+    content['reward'] = env.reward
+    content['done'] = env.done
+    content['info'] = env.info
+    socket.send_pyobj(content)
+
+def reset_env(msg, env, socket):
+    """Reset the environment and reply with the initial state."""
+    env.reset()
+    socket.send_pyobj(env.get_state(preprocess = _wants_preprocess(msg)))
+
+def respond_query(msg, env, socket):
+    """Reply with a dict of the recognised msg['items']; others are skipped."""
+    if 'items' not in msg:
+        socket.send_string("ERROR: items not found in query message")
+        return
+    if not isinstance(msg['items'], list):
+        socket.send_string("ERROR: items must be a list")
+        return
+    response = {}
+    for item in msg['items']:
+        if item == "environment_name":
+            response[item] = env.environment_name
+        elif item == "action_space":
+            response[item] = env.sim.action_space
+        elif item == "observation_space":
+            response[item] = env.sim.observation_space
+        elif item == "reward_range":
+            response[item] = env.sim.reward_range
+        elif item == "metadata":
+            response[item] = env.sim.metadata
+        elif item == "action_meanings":
+            response[item] = env.sim.unwrapped.get_action_meanings()
+        elif item == "state":
+            response[item] = env.get_state(preprocess = _wants_preprocess(msg))
+        elif item == "cumulative_reward":
+            response[item] = env.score
+        elif item == "reward":
+            response[item] = env.reward
+        elif item == "done":
+            response[item] = env.done
+        elif item == "info":
+            response[item] = env.info
+    socket.send_pyobj(response)
+
+def respond_set(msg, env, socket):
+    """Dispatch a 'set' request to reset_env() or perform_action()."""
+    if 'type' not in msg:
+        socket.send_string("ERROR: type not found in JSON message.")
+        return
+    if msg['type'] == 'reset':
+        reset_env(msg, env, socket)
+    elif msg['type'] == 'action':
+        perform_action(msg, env, socket)
+    else:
+        # Without a reply the REQ client would block forever on its recv.
+        socket.send_string("ERROR: type is not 'reset' or 'action'.")
+
+def respond(msg, env, socket):
+    """Route one decoded request to the handler for its 'method'."""
+    if not isinstance(msg, dict) or 'method' not in msg:
+        socket.send_string("ERROR: method not found in JSON message.")
+        return
+    if msg['method'] == 'query':
+        respond_query(msg, env, socket)
+    elif msg['method'] == "set":
+        respond_set(msg, env, socket)
+    else:
+        socket.send_string("ERROR: method is not 'query' or 'set'.")
+
+#
+## Main Routine
+#
+
+if len(sys.argv) != 2:
+    print("Usage: gymserver.py PORT", file=sys.stderr)
+    sys.exit(1)
+
+try:
+    port = int(sys.argv[1])
+except ValueError:
+    print("ERROR: port must be an integer.", file=sys.stderr)
+    sys.exit(1)
+
+env = PongEnv()
+context = zmq.Context()
+socket = context.socket(zmq.REP)
+socket.bind("tcp://*:%s" % port)
+while True:
+    msg = socket.recv_json()
+    respond(msg, env, socket)
diff --git a/start_servers.sh b/start_servers.sh
new file mode 100755
index 0000000..5a80db8
--- /dev/null
+++ b/start_servers.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+export FLASK_APP=gymserver.py
+for i in {0..31}
+do
+    python gymserver.py $((5000 + i)) &
+done