Encapsulated State better and added a preprocessing function
This commit is contained in:
parent
147d682c1c
commit
18ad080026
2 changed files with 81 additions and 58 deletions
15
gymclient.py
15
gymclient.py
|
@ -18,8 +18,9 @@ class Environment:
|
||||||
def get_environment_name(self):
|
def get_environment_name(self):
|
||||||
r = requests.get(self.server + "/environment")
|
r = requests.get(self.server + "/environment")
|
||||||
return r.text
|
return r.text
|
||||||
def get_state(self):
|
def get_state(self, preprocess = False):
|
||||||
r = requests.get(self.server + "/state")
|
parameter = "?preprocess" if preprocess else ""
|
||||||
|
r = requests.get(self.server + "/state" + parameter)
|
||||||
return pickle.loads(r.content)
|
return pickle.loads(r.content)
|
||||||
def get_reward(self):
|
def get_reward(self):
|
||||||
r = requests.get(self.server + "/reward")
|
r = requests.get(self.server + "/reward")
|
||||||
|
@ -56,11 +57,13 @@ class Environment:
|
||||||
##
|
##
|
||||||
# Common API
|
# Common API
|
||||||
##
|
##
|
||||||
def reset(self):
|
def reset(self, preprocess = False):
|
||||||
r = requests.get(self.server + "/reset")
|
parameter = "?preprocess" if preprocess else ""
|
||||||
|
r = requests.get(self.server + "/reset" + parameter)
|
||||||
return pickle.loads(r.content)
|
return pickle.loads(r.content)
|
||||||
def step(self, action):
|
def step(self, action, preprocess = False):
|
||||||
r = requests.post(self.server + "/action", data={'id': action})
|
parameter = "?preprocess" if preprocess else ""
|
||||||
|
r = requests.post(self.server + "/action" + parameter, data={'id': action})
|
||||||
content = pickle.loads(r.content)
|
content = pickle.loads(r.content)
|
||||||
return content['state'], content['reward'], content['done'], content['info']
|
return content['state'], content['reward'], content['done'], content['info']
|
||||||
|
|
||||||
|
|
124
gymserver.py
124
gymserver.py
|
@ -13,27 +13,54 @@ log.setLevel(logging.ERROR)
|
||||||
##
|
##
|
||||||
# OpenAI Gym State
|
# OpenAI Gym State
|
||||||
##
|
##
|
||||||
# environment_name = "Acrobot-v1"
|
class Environment:
|
||||||
environment_name = "Pong-v0"
|
def __init__(self, environment_name):
|
||||||
env = gym.make(environment_name)
|
self.sim = gym.make(environment_name)
|
||||||
|
self.environment_name = environment_name
|
||||||
|
self.reset()
|
||||||
|
def step(self, action):
|
||||||
|
# [TODO] Check to see if 'action' is valid
|
||||||
|
self.state, self.reward, self.done, self.info = self.sim.step(action)
|
||||||
|
self.score += self.reward
|
||||||
|
def reset(self):
|
||||||
|
self.state = self.sim.reset()
|
||||||
|
self.reward = 0
|
||||||
|
self.score = 0
|
||||||
|
self.done = False
|
||||||
|
self.info = {}
|
||||||
|
def preprocess(self, state):
|
||||||
|
raise NotImplementedError
|
||||||
|
def get_state(self, preprocess = False, pickleobj = False):
|
||||||
|
state = self.state
|
||||||
|
if preprocess:
|
||||||
|
state = self.preprocess(state)
|
||||||
|
if pickleobj:
|
||||||
|
state = pickle.dumps(state)
|
||||||
|
return state
|
||||||
|
|
||||||
# Observations to release to agent
|
|
||||||
state = env.reset()
|
|
||||||
reward = 0
|
|
||||||
score = 0
|
|
||||||
done = False
|
|
||||||
info = {}
|
|
||||||
|
|
||||||
##
|
##
|
||||||
# Helper Functions
|
# Pong Specific Environment Information
|
||||||
##
|
##
|
||||||
# [TODO] Evaluate whether pickling is the right option here
|
import cv2
|
||||||
def pickle_state():
|
class PongEnv(Environment):
|
||||||
global state
|
def __init__(self):
|
||||||
return pickle.dumps(state)
|
super(PongEnv, self).__init__("PongNoFrameskip-v4")
|
||||||
|
def preprocess(self, state):
|
||||||
|
# Grayscale
|
||||||
|
frame = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY)
|
||||||
|
# Crop irrelevant parts
|
||||||
|
frame = frame[34:194, 15:145] # Crops to shape (160, 130)
|
||||||
|
# Downsample
|
||||||
|
frame = cv2.resize(frame, (80, 80), interpolation=cv2.INTER_AREA)
|
||||||
|
# Normalize
|
||||||
|
frame = frame / 255
|
||||||
|
return frame
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
env = PongEnv()
|
||||||
|
|
||||||
##
|
##
|
||||||
# Flask Environment
|
# Flask Environment
|
||||||
##
|
##
|
||||||
|
@ -41,100 +68,93 @@ app = Flask(__name__)
|
||||||
|
|
||||||
@app.route('/environment', methods=['GET'])
|
@app.route('/environment', methods=['GET'])
|
||||||
def get_env():
|
def get_env():
|
||||||
global env, environment_name
|
global env
|
||||||
if request.args.get('shape') is not None:
|
if request.args.get('shape') is not None:
|
||||||
shape = {}
|
shape = {}
|
||||||
shape['observation'] = env.observation_space.shape
|
shape['observation'] = env.sim.observation_space.shape
|
||||||
shape['action'] = env.action_space.n
|
shape['action'] = env.sim.action_space.n
|
||||||
return json.dumps(shape)
|
return json.dumps(shape)
|
||||||
return environment_name
|
return env.environment_name
|
||||||
|
|
||||||
@app.route('/gym', methods=['GET'])
|
@app.route('/gym', methods=['GET'])
|
||||||
def get_extra_data():
|
def get_extra_data():
|
||||||
global env
|
global env
|
||||||
data = {}
|
data = {}
|
||||||
if request.args.get('action_space') is not None:
|
if request.args.get('action_space') is not None:
|
||||||
data['action_space'] = env.action_space
|
data['action_space'] = env.sim.action_space
|
||||||
if request.args.get('observation_space') is not None:
|
if request.args.get('observation_space') is not None:
|
||||||
data['observation_space'] = env.observation_space
|
data['observation_space'] = env.sim.observation_space
|
||||||
if request.args.get('reward_range') is not None:
|
if request.args.get('reward_range') is not None:
|
||||||
data['reward_range'] = env.reward_range
|
data['reward_range'] = env.sim.reward_range
|
||||||
if request.args.get('metadata') is not None:
|
if request.args.get('metadata') is not None:
|
||||||
data['metadata'] = env.metadata
|
data['metadata'] = env.sim.metadata
|
||||||
if request.args.get('action_meanings') is not None:
|
if request.args.get('action_meanings') is not None:
|
||||||
data['action_meanings'] = env.unwrapped.get_action_meanings()
|
data['action_meanings'] = env.sim.unwrapped.get_action_meanings()
|
||||||
return pickle.dumps(data)
|
return pickle.dumps(data)
|
||||||
|
|
||||||
@app.route('/action_space', methods=['GET'])
|
@app.route('/action_space', methods=['GET'])
|
||||||
def get_action_space():
|
def get_action_space():
|
||||||
global env
|
global env
|
||||||
return pickle.dumps(env.action_space)
|
return pickle.dumps(env.sim.action_space)
|
||||||
|
|
||||||
@app.route('/observation_space', methods=['GET'])
|
@app.route('/observation_space', methods=['GET'])
|
||||||
def get_observation_space():
|
def get_observation_space():
|
||||||
global env
|
global env
|
||||||
return pickle.dumps(env.observation_space)
|
return pickle.dumps(env.sim.observation_space)
|
||||||
|
|
||||||
@app.route('/reward_range', methods=['GET'])
|
@app.route('/reward_range', methods=['GET'])
|
||||||
def get_reward_range():
|
def get_reward_range():
|
||||||
global env
|
global env
|
||||||
return pickle.dumps(env.reward_range)
|
return pickle.dumps(env.sim.reward_range)
|
||||||
|
|
||||||
@app.route('/metadata', methods=['GET'])
|
@app.route('/metadata', methods=['GET'])
|
||||||
def get_metadata():
|
def get_metadata():
|
||||||
global env
|
global env
|
||||||
return pickle.dumps(env.metadata)
|
return pickle.dumps(env.sim.metadata)
|
||||||
|
|
||||||
@app.route('/action_meanings', methods=['GET'])
|
@app.route('/action_meanings', methods=['GET'])
|
||||||
def get_action_meanings():
|
def get_action_meanings():
|
||||||
global env
|
global env
|
||||||
return pickle.dumps(env.unwrapped.get_action_meanings())
|
return pickle.dumps(env.sim.unwrapped.get_action_meanings())
|
||||||
|
|
||||||
@app.route('/state', methods=['GET'])
|
@app.route('/state', methods=['GET'])
|
||||||
def get_state():
|
def get_state():
|
||||||
return pickle_state()
|
return env.get_state(pickleobj = True, preprocess = request.args.get('preprocess') is not None)
|
||||||
|
|
||||||
@app.route('/reward', methods=['GET'])
|
@app.route('/reward', methods=['GET'])
|
||||||
def get_reward():
|
def get_reward():
|
||||||
global score, reward
|
global env
|
||||||
if request.args.get('all') is not None:
|
if request.args.get('all') is not None:
|
||||||
return str(score)
|
return str(env.score)
|
||||||
else:
|
else:
|
||||||
return str(reward)
|
return str(env.reward)
|
||||||
|
|
||||||
@app.route('/done', methods=['GET'])
|
@app.route('/done', methods=['GET'])
|
||||||
def is_done():
|
def is_done():
|
||||||
global done
|
global env
|
||||||
return str(done)
|
return str(env.done)
|
||||||
|
|
||||||
@app.route('/info', methods=['GET'])
|
@app.route('/info', methods=['GET'])
|
||||||
def get_info():
|
def get_info():
|
||||||
global info
|
global env
|
||||||
return json.dumps(info)
|
return json.dumps(env.info)
|
||||||
|
|
||||||
@app.route('/action', methods=['POST'])
|
@app.route('/action', methods=['POST'])
|
||||||
def perform_action():
|
def perform_action():
|
||||||
global state, reward, done, info, score
|
global env
|
||||||
action = int(request.form['id'])
|
action = int(request.form['id'])
|
||||||
|
env.step(action)
|
||||||
# [TODO] Check to see if 'action' is valid
|
|
||||||
state, reward, done, info = env.step(action)
|
|
||||||
score += reward
|
|
||||||
|
|
||||||
content = {}
|
content = {}
|
||||||
content['state'] = state
|
content['state'] = env.get_state(preprocess = request.args.get('preprocess') is not None)
|
||||||
content['reward'] = reward
|
content['reward'] = env.reward
|
||||||
content['done'] = done
|
content['done'] = env.done
|
||||||
content['info'] = info
|
content['info'] = env.info
|
||||||
return pickle.dumps(content)
|
return pickle.dumps(content)
|
||||||
|
|
||||||
@app.route('/reset')
|
@app.route('/reset')
|
||||||
def reset_env():
|
def reset_env():
|
||||||
global env, state, reward, done, info, score
|
global env
|
||||||
state = env.reset()
|
env.reset()
|
||||||
reward = 0
|
return env.get_state(pickleobj = True, preprocess = request.args.get('preprocess') is not None)
|
||||||
done = False
|
|
||||||
info = {}
|
|
||||||
score = 0
|
|
||||||
return pickle_state()
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue