2019-06-05 21:56:01 -04:00
|
|
|
import pickle
|
|
|
|
import numpy
|
|
|
|
import requests
|
|
|
|
|
|
|
|
|
|
|
|
# [TODO] Add error handling for when the server is down or unreachable.
|
|
|
|
class Environment:
    """Client for a remote gym-style environment exposed over HTTP.

    Every public method performs one blocking HTTP request against
    ``self.server`` and decodes the response body (plain text, JSON, or a
    pickled Python object, depending on the endpoint).

    SECURITY NOTE: several endpoints are decoded with ``pickle.loads``, which
    can execute arbitrary code embedded in the response. Only point this
    client at servers you fully trust.
    """

    # Per-request timeout in seconds handed to ``requests``. ``None`` waits
    # indefinitely (the historical behaviour). Kept as a class-level default
    # so instances always have the attribute.
    timeout = None

    def __init__(self, address, port, ssl=False, timeout=None):
        """Build the server URL and fetch the environment's static metadata.

        Args:
            address: Host name or IP address of the server. It is
                concatenated into the URL, so it must be a ``str``.
            port: Server port; converted with ``str()``.
            ssl: Use ``https://`` when true, ``http://`` otherwise.
            timeout: Optional per-request timeout in seconds. ``None`` (the
                default) waits indefinitely.

        Raises:
            requests.RequestException: if the initial metadata request fails
                (connection error, timeout, or non-2xx HTTP status).
        """
        self.address = address
        self.port = port
        self.timeout = timeout
        self.server = self._build_server_url(address, port, ssl)
        (
            self.observation_space,
            self.action_space,
            self.reward_range,
            self.metadata,
            self.action_meanings,
        ) = self.get_initial_metadata()

    ##
    # Internal request plumbing
    ##

    @staticmethod
    def _build_server_url(address, port, ssl=False):
        """Return the base URL, e.g. ``"http://127.0.0.1:5000"``."""
        protocol = "https://" if ssl else "http://"
        return protocol + address + ":" + str(port)

    @staticmethod
    def _preprocess_suffix(preprocess):
        """Return the query suffix requesting server-side preprocessing.

        The bare ``?preprocess`` flag (no ``=value``) is kept exactly as this
        client has always sent it.
        """
        return "?preprocess" if preprocess else ""

    def _request(self, method, path, **kwargs):
        """Send ``method`` to ``self.server + path`` and return the response.

        Extra keyword arguments (``params``, ``data``, ...) are passed to
        ``requests.request`` unchanged.

        Raises:
            requests.HTTPError: on a 4xx/5xx status. Previously such error
                pages were decoded as if they were valid payloads, which
                produced confusing errors or, for ``get_done``, a silently
                wrong ``False``.
            requests.RequestException: on connection failure or timeout.
        """
        response = requests.request(
            method, self.server + path, timeout=self.timeout, **kwargs
        )
        response.raise_for_status()
        return response

    def _get(self, path, **kwargs):
        """GET ``path`` and return the checked response."""
        return self._request("get", path, **kwargs)

    def _get_pickle(self, path):
        """GET ``path`` and unpickle the response body.

        Unpickling untrusted data is unsafe; see the class docstring.
        """
        return pickle.loads(self._get(path).content)

    ##
    # Helper Functions
    ##

    def get_environment_name(self):
        """Return the body of ``/environment`` as text."""
        return self._get("/environment").text

    def get_state(self, preprocess=False):
        """Return the unpickled ``/state`` response.

        Args:
            preprocess: When true, append the ``?preprocess`` flag to the
                request URL.
        """
        return self._get_pickle("/state" + self._preprocess_suffix(preprocess))

    def get_reward(self):
        """Return the body of ``/reward`` parsed as a ``float``."""
        return float(self._get("/reward").text)

    def get_score(self):
        """Return ``/reward`` queried with the ``all`` parameter, as a float.

        NOTE(review): presumably this is the cumulative episode reward;
        confirm against the server implementation.
        """
        return float(self._get("/reward", params={'all': ''}).text)

    def get_done(self):
        """Return ``True`` only when ``/done`` responds with exactly ``"True"``."""
        return self._get("/done").text == "True"

    def get_info(self):
        """Return the JSON-decoded body of ``/info``."""
        return self._get("/info").json()

    def get_observation_space(self):
        """Return the unpickled ``/observation_space`` response."""
        return self._get_pickle('/observation_space')

    def get_action_space(self):
        """Return the unpickled ``/action_space`` response."""
        return self._get_pickle('/action_space')

    def get_reward_range(self):
        """Return the unpickled ``/reward_range`` response."""
        return self._get_pickle('/reward_range')

    def get_metadata(self):
        """Return the unpickled ``/metadata`` response."""
        return self._get_pickle('/metadata')

    def get_action_meanings(self):
        """Return the unpickled ``/action_meanings`` response."""
        return self._get_pickle('/action_meanings')

    def get_initial_metadata(self):
        """Fetch all static environment metadata in one request.

        Returns:
            A tuple ``(observation_space, action_space, reward_range,
            metadata, action_meanings)`` taken from the keys of the
            unpickled ``/gym`` response dict.
        """
        content = self._get_pickle(
            '/gym?observation_space&action_space&reward_range&metadata&action_meanings'
        )
        return (
            content['observation_space'],
            content['action_space'],
            content['reward_range'],
            content['metadata'],
            content['action_meanings'],
        )

    ##
    # Common API
    ##

    def reset(self, preprocess=False):
        """Request ``/reset`` and return the unpickled response.

        Args:
            preprocess: When true, append the ``?preprocess`` flag.
        """
        return self._get_pickle("/reset" + self._preprocess_suffix(preprocess))

    def step(self, action, preprocess=False):
        """POST ``action`` to ``/action`` and return the step result.

        Args:
            action: Sent as the form field ``id``.
            preprocess: When true, append the ``?preprocess`` flag.

        Returns:
            ``(state, reward, done, info)`` read from the unpickled
            response dict.
        """
        response = self._request(
            "post",
            "/action" + self._preprocess_suffix(preprocess),
            data={'id': action},
        )
        content = pickle.loads(response.content)
        return content['state'], content['reward'], content['done'], content['info']
|
2019-06-05 21:56:01 -04:00
|
|
|
|
|
|
|
# Example usage: env = Environment("127.0.0.1", 5000)
|