Encapsulated State better and added a preprocessing function

This commit is contained in:
Brandon Rozek 2019-06-08 14:52:56 -04:00
parent 147d682c1c
commit 18ad080026
2 changed files with 81 additions and 58 deletions

View file

@ -18,8 +18,9 @@ class Environment:
def get_environment_name(self):
    """Return the environment name reported by the server.

    Issues a GET to ``<server>/environment`` and returns the response body
    as text.
    """
    response = requests.get(self.server + "/environment")
    return response.text
def get_state(self, preprocess=False):
    """Fetch the current state from the server.

    Args:
        preprocess: when true, append ``?preprocess`` to the request so the
            server runs its preprocessing step before returning the state.

    Returns:
        The unpickled state object sent by the server.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    parameter = "?preprocess" if preprocess else ""
    response = requests.get(self.server + "/state" + parameter)
    # Fail with a clear HTTP error instead of trying to unpickle an error page.
    response.raise_for_status()
    # NOTE(security): pickle.loads executes arbitrary code from the payload;
    # only point this client at a server you trust.
    return pickle.loads(response.content)
def get_reward(self): def get_reward(self):
r = requests.get(self.server + "/reward") r = requests.get(self.server + "/reward")
@ -56,11 +57,13 @@ class Environment:
##
# Common API
##
def reset(self, preprocess=False):
    """Reset the remote environment and return its initial state.

    Args:
        preprocess: when true, append ``?preprocess`` to the request so the
            server runs its preprocessing step before returning the state.

    Returns:
        The unpickled state object sent by the server.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    parameter = "?preprocess" if preprocess else ""
    response = requests.get(self.server + "/reset" + parameter)
    # Fail with a clear HTTP error instead of trying to unpickle an error page.
    response.raise_for_status()
    # NOTE(security): pickle.loads executes arbitrary code from the payload;
    # only point this client at a server you trust.
    return pickle.loads(response.content)
def step(self, action, preprocess=False):
    """Send an action to the server and return the resulting transition.

    Args:
        action: action identifier, posted as the ``id`` form field.
        preprocess: when true, append ``?preprocess`` to the request so the
            server runs its preprocessing step on the returned state.

    Returns:
        A ``(state, reward, done, info)`` tuple taken from the unpickled
        response.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    parameter = "?preprocess" if preprocess else ""
    response = requests.post(
        self.server + "/action" + parameter,
        data={'id': action},
    )
    # Fail with a clear HTTP error instead of trying to unpickle an error page.
    response.raise_for_status()
    # NOTE(security): pickle.loads executes arbitrary code from the payload;
    # only point this client at a server you trust.
    content = pickle.loads(response.content)
    return content['state'], content['reward'], content['done'], content['info']

View file

@ -13,27 +13,54 @@ log.setLevel(logging.ERROR)
##
# OpenAI Gym State
##
class Environment:
    """Stateful wrapper around a Gym simulator.

    Keeps the most recent ``state``, ``reward``, ``done`` flag and ``info``
    produced by the simulator, together with a running ``score`` for the
    current episode, so callers do not need module-level globals.
    """

    def __init__(self, environment_name):
        """Build the simulator with ``gym.make`` and start a first episode."""
        self.sim = gym.make(environment_name)
        self.environment_name = environment_name
        self.reset()

    def step(self, action):
        """Apply ``action`` to the simulator and record the outcome.

        Updates ``state``, ``reward``, ``done`` and ``info`` from the
        simulator's return value and adds the reward to ``score``.
        """
        # [TODO] Check to see if 'action' is valid
        self.state, self.reward, self.done, self.info = self.sim.step(action)
        self.score += self.reward

    def reset(self):
        """Reset the simulator and clear the per-episode bookkeeping."""
        self.state = self.sim.reset()
        self.reward = 0
        self.score = 0
        self.done = False
        self.info = {}

    def preprocess(self, state):
        """Transform a raw state; subclasses must override this hook."""
        raise NotImplementedError

    def get_state(self, preprocess=False, pickleobj=False):
        """Return the current state, optionally preprocessed and/or pickled.

        Args:
            preprocess: pass the state through ``self.preprocess`` first.
            pickleobj: return ``pickle.dumps`` of the (possibly preprocessed)
                state instead of the object itself.
        """
        state = self.state
        if preprocess:
            state = self.preprocess(state)
        if pickleobj:
            state = pickle.dumps(state)
        return state
##
# Pong Specific Environment Information
##
import cv2


class PongEnv(Environment):
    """Environment for ``PongNoFrameskip-v4`` with image preprocessing."""

    def __init__(self):
        super(PongEnv, self).__init__("PongNoFrameskip-v4")

    def preprocess(self, state):
        """Turn a raw RGB frame into a normalised 80x80 grayscale frame."""
        # Grayscale
        frame = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY)
        # Crop irrelevant parts
        frame = frame[34:194, 15:145]  # Crops to shape (160, 130)
        # Downsample
        frame = cv2.resize(frame, (80, 80), interpolation=cv2.INTER_AREA)
        # Normalize; the float divisor guarantees true division on any interpreter
        frame = frame / 255.0
        return frame


env = PongEnv()
##
# Flask Environment
##
@ -41,100 +68,93 @@ app = Flask(__name__)
@app.route('/environment', methods=['GET'])
def get_env():
    """Return the environment name, or its shapes when ``?shape`` is given.

    With the ``shape`` query argument present, responds with a JSON object
    holding the simulator's observation-space shape and action-space ``n``.
    Otherwise responds with the environment name.
    """
    if request.args.get('shape') is None:
        return env.environment_name
    shape = {
        'observation': env.sim.observation_space.shape,
        'action': env.sim.action_space.n,
    }
    return json.dumps(shape)
@app.route('/gym', methods=['GET'])
def get_extra_data():
    """Return a pickled dict of the simulator attributes named in the query.

    Recognised query arguments: ``action_space``, ``observation_space``,
    ``reward_range``, ``metadata`` and ``action_meanings``. Only the ones
    present in the request appear in the result.
    """
    sim = env.sim
    # Lookups are wrapped in callables so an attribute is only touched when
    # the client actually asked for it.
    providers = {
        'action_space': lambda: sim.action_space,
        'observation_space': lambda: sim.observation_space,
        'reward_range': lambda: sim.reward_range,
        'metadata': lambda: sim.metadata,
        'action_meanings': lambda: sim.unwrapped.get_action_meanings(),
    }
    data = {}
    for key, provider in providers.items():
        if request.args.get(key) is not None:
            data[key] = provider()
    return pickle.dumps(data)
@app.route('/action_space', methods=['GET'])
def get_action_space():
    """Return the simulator's action space, pickled."""
    return pickle.dumps(env.sim.action_space)
@app.route('/observation_space', methods=['GET'])
def get_observation_space():
    """Return the simulator's observation space, pickled."""
    return pickle.dumps(env.sim.observation_space)
@app.route('/reward_range', methods=['GET'])
def get_reward_range():
    """Return the simulator's reward range, pickled."""
    return pickle.dumps(env.sim.reward_range)
@app.route('/metadata', methods=['GET'])
def get_metadata():
    """Return the simulator's metadata, pickled."""
    return pickle.dumps(env.sim.metadata)
@app.route('/action_meanings', methods=['GET'])
def get_action_meanings():
    """Return the unwrapped simulator's action meanings, pickled."""
    meanings = env.sim.unwrapped.get_action_meanings()
    return pickle.dumps(meanings)
@app.route('/state', methods=['GET'])
def get_state():
    """Return the current state, pickled.

    The state is preprocessed first when the ``preprocess`` query argument
    is present.
    """
    wants_preprocess = request.args.get('preprocess') is not None
    return env.get_state(pickleobj=True, preprocess=wants_preprocess)
@app.route('/reward', methods=['GET'])
def get_reward():
    """Return the last reward, or the episode score when ``?all`` is given."""
    if request.args.get('all') is not None:
        return str(env.score)
    return str(env.reward)
@app.route('/done', methods=['GET'])
def is_done():
    """Return the environment's ``done`` flag as a string."""
    return str(env.done)
@app.route('/info', methods=['GET'])
def get_info():
    """Return the environment's latest ``info`` dict as JSON."""
    return json.dumps(env.info)
@app.route('/action', methods=['POST'])
def perform_action():
    """Step the environment with the posted action and return the result.

    Reads the action from the ``id`` form field (converted with ``int``),
    steps the environment, and responds with a pickled dict containing
    ``state``, ``reward``, ``done`` and ``info``. The state is preprocessed
    when the ``preprocess`` query argument is present.
    """
    action = int(request.form['id'])
    # Environment.step stores its results on ``env``; it returns nothing.
    env.step(action)
    wants_preprocess = request.args.get('preprocess') is not None
    content = {
        'state': env.get_state(preprocess=wants_preprocess),
        'reward': env.reward,
        'done': env.done,
        'info': env.info,
    }
    return pickle.dumps(content)
@app.route('/reset')
def reset_env():
    """Reset the environment and return its initial state, pickled.

    The state is preprocessed first when the ``preprocess`` query argument
    is present.
    """
    env.reset()
    wants_preprocess = request.args.get('preprocess') is not None
    return env.get_state(pickleobj=True, preprocess=wants_preprocess)