Initial Commit for ZeroMQ implentation
This commit is contained in:
commit
52f4d4cbf3
3 changed files with 218 additions and 0 deletions
67
gymclient.py
Normal file
67
gymclient.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
import zmq
|
||||
import numpy
|
||||
|
||||
|
||||
# [TODO] Error handling for if server is down
|
||||
class Environment:
|
||||
def __init__(self, address, port):
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.REQ)
|
||||
self.socket.connect("tcp://%s:%s" % (address, port))
|
||||
self.address = address
|
||||
self.port = port
|
||||
self.observation_space, self.action_space, self.reward_range, self.metadata, self.action_meanings = self.get_initial_metadata()
|
||||
|
||||
##
|
||||
# Helper Functions
|
||||
##
|
||||
def get_environment_name(self):
|
||||
self.socket.send_json({'method':'query', 'items':['environment_name']})
|
||||
return self.socket.recv_pyobj()['environment_name']
|
||||
def get_state(self, preprocess = False):
|
||||
self.socket.send_json({'method':'query', 'items':['state'], 'preprocess':preprocess})
|
||||
return self.socket.recv_pyobj()['state']
|
||||
def get_reward(self):
|
||||
self.socket.send_json({'method':'query', 'items':['reward']})
|
||||
return self.socket.recv_pyobj()['reward']
|
||||
def get_score(self):
|
||||
self.socket.send_json({'method':'query', 'items':['cumulative_reward']})
|
||||
return self.socket.recv_pyobj()['cumulative_reward']
|
||||
def get_done(self):
|
||||
self.socket.send_json({'method':'query', 'items':['done']})
|
||||
return self.socket.recv_pyobj()['done']
|
||||
def get_info(self):
|
||||
self.socket.send_json({'method':'query', 'items':['info']})
|
||||
return self.socket.recv_pyobj()['info']
|
||||
def get_observation_space(self):
|
||||
self.socket.send_json({'method':'query', 'items':['observation_space']})
|
||||
return self.socket.recv_pyobj()['observation_space']
|
||||
def get_action_space(self):
|
||||
self.socket.send_json({'method':'query', 'items':['action_space']})
|
||||
return self.socket.recv_pyobj()['action_space']
|
||||
def get_reward_range(self):
|
||||
self.socket.send_json({'method':'query', 'items':['reward_range']})
|
||||
return self.socket.recv_pyobj()['reward_range']
|
||||
def get_metadata(self):
|
||||
self.socket.send_json({'method':'query', 'items':['metadata']})
|
||||
return self.socket.recv_pyobj()['metadata']
|
||||
def get_action_meanings(self):
|
||||
self.socket.send_json({'method':'query', 'items':['action_meanings']})
|
||||
return self.socket.recv_pyobj()['action_meanings']
|
||||
def get_initial_metadata(self):
|
||||
self.socket.send_json({'method':'query', 'items':['observation_space', 'action_space', 'reward_range', 'metadata', 'action_meanings']})
|
||||
content = self.socket.recv_pyobj()
|
||||
return content['observation_space'], content['action_space'], content['reward_range'], content['metadata'], content['action_meanings']
|
||||
|
||||
##
|
||||
# Common API
|
||||
##
|
||||
def reset(self, preprocess = False):
|
||||
self.socket.send_json({'method':'set', 'type':'reset', 'preprocess':preprocess})
|
||||
return self.socket.recv_pyobj()
|
||||
def step(self, action_id, preprocess = False):
|
||||
self.socket.send_json({'method':'set', 'type':'action', 'id': action_id, 'preprocess':preprocess})
|
||||
content = self.socket.recv_pyobj()
|
||||
return content['state'], content['reward'], content['done'], content['info']
|
||||
|
||||
# env = Environment("127.0.0.1", 5000)
|
145
gymserver.py
Normal file
145
gymserver.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
import sys
|
||||
import gym
|
||||
import zmq
|
||||
import json
|
||||
|
||||
##
|
||||
# OpenAI Gym State
|
||||
##
|
||||
class Environment:
|
||||
def __init__(self, environment_name):
|
||||
self.sim = gym.make(environment_name)
|
||||
self.environment_name = environment_name
|
||||
self.reset()
|
||||
def step(self, action):
|
||||
# [TODO] Check to see if 'action' is valid
|
||||
self.state, self.reward, self.done, self.info = self.sim.step(action)
|
||||
self.score += self.reward
|
||||
def reset(self):
|
||||
self.state = self.sim.reset()
|
||||
self.reward = 0
|
||||
self.score = 0
|
||||
self.done = False
|
||||
self.info = {}
|
||||
def preprocess(self, state):
|
||||
raise NotImplementedError
|
||||
def get_state(self, preprocess = False):
|
||||
state = self.state
|
||||
if preprocess:
|
||||
state = self.preprocess(state)
|
||||
return state
|
||||
|
||||
|
||||
##
|
||||
# Pong Specific Environment Information
|
||||
##
|
||||
import cv2
|
||||
class PongEnv(Environment):
|
||||
def __init__(self):
|
||||
super(PongEnv, self).__init__("PongNoFrameskip-v4")
|
||||
def preprocess(self, state):
|
||||
# Grayscale
|
||||
frame = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY)
|
||||
# Crop irrelevant parts
|
||||
frame = frame[34:194, 15:145] # Crops to shape (160, 130)
|
||||
# Downsample
|
||||
frame = cv2.resize(frame, (80, 80), interpolation=cv2.INTER_AREA)
|
||||
# Normalize
|
||||
frame = frame / 255
|
||||
return frame
|
||||
|
||||
#
|
||||
## Response Methods
|
||||
#
|
||||
def perform_action(msg, env, socket):
|
||||
if 'id' not in msg: # [TODO] Include an integer check
|
||||
socket.send_string("ERROR: 'id' not in message when type is set to 'action'.")
|
||||
|
||||
action = int(msg['id'])
|
||||
env.step(action)
|
||||
|
||||
p = msg['preprocess'] is not None and msg['preprocess']
|
||||
content = {}
|
||||
content['state'] = env.get_state(preprocess = p)
|
||||
content['reward'] = env.reward
|
||||
content['done'] = env.done
|
||||
content['info'] = env.info
|
||||
socket.send_pyobj(content)
|
||||
|
||||
def reset_env(msg, env, socket):
|
||||
env.reset()
|
||||
# Preprocess if part of message is set and equal to True
|
||||
p = msg['preprocess'] is not None and msg['preprocess']
|
||||
socket.send_pyobj(env.get_state(preprocess = p))
|
||||
|
||||
def respond_query(msg, env, socket):
|
||||
if 'items' not in msg:
|
||||
socket.send_string("ERROR: items not found in query message")
|
||||
# [TODO] Have a check here to make sure msg['items'] is a list
|
||||
response = {}
|
||||
for item in msg['items']:
|
||||
if item == "environment_name":
|
||||
response[item] = env.environment_name
|
||||
elif item == "action_space":
|
||||
response[item] = env.sim.action_space
|
||||
elif item == "observation_space":
|
||||
response[item] = env.sim.observation_space
|
||||
elif item == "reward_range":
|
||||
response[item] = env.sim.reward_range
|
||||
elif item == "metadata":
|
||||
response[item] = env.sim.metadata
|
||||
elif item == "action_meanings":
|
||||
response[item] = env.sim.unwrapped.get_action_meanings()
|
||||
elif item == "state":
|
||||
p = msg['preprocess'] is not None and msg['preprocess']
|
||||
response[item] = env.get_state(preprocess = p)
|
||||
elif item == "cumulative_reward":
|
||||
response[item] = env.score
|
||||
elif item == "reward":
|
||||
response[item] = env.reward
|
||||
elif item == "done":
|
||||
response[item] = env.done
|
||||
elif item == "info":
|
||||
response[item] = env.info
|
||||
socket.send_pyobj(response)
|
||||
|
||||
def respond_set(msg, env, socket):
|
||||
if 'type' not in msg:
|
||||
socket.send_string("ERROR: type not found in JSON message.")
|
||||
if msg['type'] == 'reset':
|
||||
reset_env(msg, env, socket)
|
||||
elif msg['type'] == 'action':
|
||||
perform_action(msg, env, socket)
|
||||
pass
|
||||
|
||||
def respond(msg, env, socket):
|
||||
if 'method' not in msg:
|
||||
socket.send_string("ERROR: method not found in JSON message.")
|
||||
# From here establish what type of message it is....
|
||||
# Maybe create a switch statement of some sort that sends it off to each of the specialized functions already created below
|
||||
if msg['method'] == 'query':
|
||||
respond_query(msg, env, socket)
|
||||
elif msg['method'] == "set":
|
||||
respond_set(msg, env, socket)
|
||||
else:
|
||||
socket.send_string("ERROR: method is not 'query' or 'set'.")
|
||||
|
||||
#
|
||||
## Main Routine
|
||||
#
|
||||
|
||||
# [TODO] Make sure port is an int
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: gymserver_zero.py <port>", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
env = PongEnv()
|
||||
port = int(sys.argv[1])
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REP)
|
||||
socket.bind("tcp://*:%s" % port)
|
||||
while True:
|
||||
msg = socket.recv_json()
|
||||
respond(msg, env, socket)
|
||||
|
||||
|
6
start_servers.sh
Executable file
6
start_servers.sh
Executable file
|
@ -0,0 +1,6 @@
|
|||
#!/bin/bash
|
||||
export FLASK_APP=gymserver.py
|
||||
for i in {0..31}
|
||||
do
|
||||
python gymserver.py $((5000 + i)) &
|
||||
done
|
Loading…
Reference in a new issue