2019-06-15 10:23:18 -04:00
|
|
|
import zmq
|
|
|
|
import numpy
|
|
|
|
|
|
|
|
|
|
|
|
# [TODO] Add error handling for the case where the server is down
|
|
|
|
class Environment:
    """Client for an environment server reached over a ZeroMQ REQ socket.

    Each request is sent to the server as a JSON document and the reply is
    read back with ``recv_pyobj``.  On construction the client queries the
    server once and stores the returned ``observation_space``,
    ``action_space``, ``reward_range``, ``metadata`` and ``action_meanings``
    values as attributes of the same names.

    The instance owns a ZeroMQ context and socket; release them with
    :meth:`close` or by using the instance as a context manager.

    NOTE(security): replies are decoded with ``recv_pyobj``, which unpickles
    whatever the peer sends.  Only connect to servers you trust.
    """

    # Transports accepted by __init__.
    _SUPPORTED_PROTOCOLS = ("tcp", "ipc")

    def __init__(self, proto, address, port):
        """Connect to the server and fetch its initial metadata.

        Args:
            proto: Transport name; must be ``"tcp"`` or ``"ipc"``.
            address: Address part of the endpoint.
            port: Port part of the endpoint.

        Raises:
            ValueError: If ``proto`` is neither ``"tcp"`` nor ``"ipc"``.
        """
        # Validate before allocating any ZeroMQ resources so that a bad
        # argument cannot leave an open context/socket behind.
        if proto not in self._SUPPORTED_PROTOCOLS:
            raise ValueError("proto must be tcp or ipc")

        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        try:
            self.socket.connect("%s://%s:%s" % (proto, address, port))
            self.proto = proto
            self.address = address
            self.port = port
            (
                self.observation_space,
                self.action_space,
                self.reward_range,
                self.metadata,
                self.action_meanings,
            ) = self.get_initial_metadata()
        except BaseException:
            # Setup failed or was interrupted: release the socket and the
            # context before propagating, otherwise they would leak.
            self.close()
            raise

    ##
    # Resource management
    ##
    def close(self):
        """Close the socket and terminate the context.

        Tolerates a partially initialised instance and skips anything that
        already reports itself as closed.
        """
        sock = getattr(self, "socket", None)
        if sock is not None and not sock.closed:
            # linger=0 discards any unsent request so that terminating the
            # context below cannot block on an unreachable server.
            sock.close(linger=0)
        ctx = getattr(self, "context", None)
        if ctx is not None and not ctx.closed:
            ctx.term()

    def __enter__(self):
        """Return the environment itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release ZeroMQ resources when leaving a ``with`` block."""
        self.close()

    ##
    # Helper Functions
    ##
    def _request(self, payload):
        """Send ``payload`` as JSON and return the server's reply object."""
        self.socket.send_json(payload)
        # NOTE(security): recv_pyobj unpickles the reply; the server must be
        # trusted.
        return self.socket.recv_pyobj()

    def _query(self, *items, **extra):
        """Send a ``query`` request for ``items`` and return the reply.

        Any keyword arguments are added to the request as extra fields.
        """
        payload = {'method': 'query', 'items': list(items)}
        payload.update(extra)
        return self._request(payload)

    def get_environment_name(self):
        """Return the ``environment_name`` entry of the server's reply."""
        return self._query('environment_name')['environment_name']

    def get_state(self, preprocess = False):
        """Return the ``state`` entry of the server's reply.

        Args:
            preprocess: Forwarded to the server as the ``preprocess`` field
                of the request.
        """
        return self._query('state', preprocess=preprocess)['state']

    def get_reward(self):
        """Return the ``reward`` entry of the server's reply."""
        return self._query('reward')['reward']

    def get_score(self):
        """Return the ``cumulative_reward`` entry of the server's reply."""
        return self._query('cumulative_reward')['cumulative_reward']

    def get_done(self):
        """Return the ``done`` entry of the server's reply."""
        return self._query('done')['done']

    def get_info(self):
        """Return the ``info`` entry of the server's reply."""
        return self._query('info')['info']

    def get_observation_space(self):
        """Return the ``observation_space`` entry of the server's reply."""
        return self._query('observation_space')['observation_space']

    def get_action_space(self):
        """Return the ``action_space`` entry of the server's reply."""
        return self._query('action_space')['action_space']

    def get_reward_range(self):
        """Return the ``reward_range`` entry of the server's reply."""
        return self._query('reward_range')['reward_range']

    def get_metadata(self):
        """Return the ``metadata`` entry of the server's reply."""
        return self._query('metadata')['metadata']

    def get_action_meanings(self):
        """Return the ``action_meanings`` entry of the server's reply."""
        return self._query('action_meanings')['action_meanings']

    def get_initial_metadata(self):
        """Fetch all static metadata in a single round trip.

        Returns:
            A 5-tuple ``(observation_space, action_space, reward_range,
            metadata, action_meanings)`` taken from the server's reply.
        """
        content = self._query(
            'observation_space',
            'action_space',
            'reward_range',
            'metadata',
            'action_meanings',
        )
        return (
            content['observation_space'],
            content['action_space'],
            content['reward_range'],
            content['metadata'],
            content['action_meanings'],
        )

    ##
    # Common API
    ##
    def reset(self, preprocess = False):
        """Send a ``reset`` request and return the server's reply unchanged.

        Args:
            preprocess: Forwarded to the server as the ``preprocess`` field
                of the request.
        """
        return self._request(
            {'method': 'set', 'type': 'reset', 'preprocess': preprocess}
        )

    def step(self, action_id, preprocess = False):
        """Send an ``action`` request and unpack the server's reply.

        Args:
            action_id: Sent to the server as the ``id`` field of the request.
            preprocess: Forwarded to the server as the ``preprocess`` field
                of the request.

        Returns:
            A 4-tuple ``(state, reward, done, info)`` taken from the reply.
        """
        content = self._request(
            {
                'method': 'set',
                'type': 'action',
                'id': action_id,
                'preprocess': preprocess,
            }
        )
        return (
            content['state'],
            content['reward'],
            content['done'],
            content['info'],
        )
|
|
|
|
|
|
|
|
# Example usage: env = Environment("tcp", "127.0.0.1", 5000)
|