Initial Commit for ZeroMQ implentation

This commit is contained in:
Brandon Rozek 2019-06-15 10:23:18 -04:00
commit 52f4d4cbf3
3 changed files with 218 additions and 0 deletions

67
gymclient.py Normal file
View file

@ -0,0 +1,67 @@
import zmq
import numpy
# [TODO] Error handling for if server is down
class Environment:
def __init__(self, address, port):
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect("tcp://%s:%s" % (address, port))
self.address = address
self.port = port
self.observation_space, self.action_space, self.reward_range, self.metadata, self.action_meanings = self.get_initial_metadata()
##
# Helper Functions
##
def get_environment_name(self):
self.socket.send_json({'method':'query', 'items':['environment_name']})
return self.socket.recv_pyobj()['environment_name']
def get_state(self, preprocess = False):
self.socket.send_json({'method':'query', 'items':['state'], 'preprocess':preprocess})
return self.socket.recv_pyobj()['state']
def get_reward(self):
self.socket.send_json({'method':'query', 'items':['reward']})
return self.socket.recv_pyobj()['reward']
def get_score(self):
self.socket.send_json({'method':'query', 'items':['cumulative_reward']})
return self.socket.recv_pyobj()['cumulative_reward']
def get_done(self):
self.socket.send_json({'method':'query', 'items':['done']})
return self.socket.recv_pyobj()['done']
def get_info(self):
self.socket.send_json({'method':'query', 'items':['info']})
return self.socket.recv_pyobj()['info']
def get_observation_space(self):
self.socket.send_json({'method':'query', 'items':['observation_space']})
return self.socket.recv_pyobj()['observation_space']
def get_action_space(self):
self.socket.send_json({'method':'query', 'items':['action_space']})
return self.socket.recv_pyobj()['action_space']
def get_reward_range(self):
self.socket.send_json({'method':'query', 'items':['reward_range']})
return self.socket.recv_pyobj()['reward_range']
def get_metadata(self):
self.socket.send_json({'method':'query', 'items':['metadata']})
return self.socket.recv_pyobj()['metadata']
def get_action_meanings(self):
self.socket.send_json({'method':'query', 'items':['action_meanings']})
return self.socket.recv_pyobj()['action_meanings']
def get_initial_metadata(self):
self.socket.send_json({'method':'query', 'items':['observation_space', 'action_space', 'reward_range', 'metadata', 'action_meanings']})
content = self.socket.recv_pyobj()
return content['observation_space'], content['action_space'], content['reward_range'], content['metadata'], content['action_meanings']
##
# Common API
##
def reset(self, preprocess = False):
self.socket.send_json({'method':'set', 'type':'reset', 'preprocess':preprocess})
return self.socket.recv_pyobj()
def step(self, action_id, preprocess = False):
self.socket.send_json({'method':'set', 'type':'action', 'id': action_id, 'preprocess':preprocess})
content = self.socket.recv_pyobj()
return content['state'], content['reward'], content['done'], content['info']
# env = Environment("127.0.0.1", 5000)

145
gymserver.py Normal file
View file

@ -0,0 +1,145 @@
import sys
import gym
import zmq
import json
##
# OpenAI Gym State
##
class Environment:
def __init__(self, environment_name):
self.sim = gym.make(environment_name)
self.environment_name = environment_name
self.reset()
def step(self, action):
# [TODO] Check to see if 'action' is valid
self.state, self.reward, self.done, self.info = self.sim.step(action)
self.score += self.reward
def reset(self):
self.state = self.sim.reset()
self.reward = 0
self.score = 0
self.done = False
self.info = {}
def preprocess(self, state):
raise NotImplementedError
def get_state(self, preprocess = False):
state = self.state
if preprocess:
state = self.preprocess(state)
return state
##
# Pong Specific Environment Information
##
import cv2
class PongEnv(Environment):
def __init__(self):
super(PongEnv, self).__init__("PongNoFrameskip-v4")
def preprocess(self, state):
# Grayscale
frame = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY)
# Crop irrelevant parts
frame = frame[34:194, 15:145] # Crops to shape (160, 130)
# Downsample
frame = cv2.resize(frame, (80, 80), interpolation=cv2.INTER_AREA)
# Normalize
frame = frame / 255
return frame
#
## Response Methods
#
def perform_action(msg, env, socket):
if 'id' not in msg: # [TODO] Include an integer check
socket.send_string("ERROR: 'id' not in message when type is set to 'action'.")
action = int(msg['id'])
env.step(action)
p = msg['preprocess'] is not None and msg['preprocess']
content = {}
content['state'] = env.get_state(preprocess = p)
content['reward'] = env.reward
content['done'] = env.done
content['info'] = env.info
socket.send_pyobj(content)
def reset_env(msg, env, socket):
env.reset()
# Preprocess if part of message is set and equal to True
p = msg['preprocess'] is not None and msg['preprocess']
socket.send_pyobj(env.get_state(preprocess = p))
def respond_query(msg, env, socket):
if 'items' not in msg:
socket.send_string("ERROR: items not found in query message")
# [TODO] Have a check here to make sure msg['items'] is a list
response = {}
for item in msg['items']:
if item == "environment_name":
response[item] = env.environment_name
elif item == "action_space":
response[item] = env.sim.action_space
elif item == "observation_space":
response[item] = env.sim.observation_space
elif item == "reward_range":
response[item] = env.sim.reward_range
elif item == "metadata":
response[item] = env.sim.metadata
elif item == "action_meanings":
response[item] = env.sim.unwrapped.get_action_meanings()
elif item == "state":
p = msg['preprocess'] is not None and msg['preprocess']
response[item] = env.get_state(preprocess = p)
elif item == "cumulative_reward":
response[item] = env.score
elif item == "reward":
response[item] = env.reward
elif item == "done":
response[item] = env.done
elif item == "info":
response[item] = env.info
socket.send_pyobj(response)
def respond_set(msg, env, socket):
if 'type' not in msg:
socket.send_string("ERROR: type not found in JSON message.")
if msg['type'] == 'reset':
reset_env(msg, env, socket)
elif msg['type'] == 'action':
perform_action(msg, env, socket)
pass
def respond(msg, env, socket):
if 'method' not in msg:
socket.send_string("ERROR: method not found in JSON message.")
# From here establish what type of message it is....
# Maybe create a switch statement of some sort that sends it off to each of the specialized functions already created below
if msg['method'] == 'query':
respond_query(msg, env, socket)
elif msg['method'] == "set":
respond_set(msg, env, socket)
else:
socket.send_string("ERROR: method is not 'query' or 'set'.")
#
## Main Routine
#
# [TODO] Make sure port is an int
if len(sys.argv) != 2:
print("Usage: gymserver_zero.py <port>", file=sys.stderr)
sys.exit(1)
env = PongEnv()
port = int(sys.argv[1])
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind("tcp://*:%s" % port)
while True:
msg = socket.recv_json()
respond(msg, env, socket)

6
start_servers.sh Executable file
View file

@ -0,0 +1,6 @@
#!/bin/bash
export FLASK_APP=gymserver.py
for i in {0..31}
do
python gymserver.py $((5000 + i)) &
done