diff --git a/examples/example_a2c.py b/examples/example_a2c.py index 4574207..6f06de7 100644 --- a/examples/example_a2c.py +++ b/examples/example_a2c.py @@ -217,7 +217,7 @@ class FrameStack(gym.Wrapper): NUMBER_ENVIRONMENTS = 32 pool = ThreadPool(NUMBER_ENVIRONMENTS) -envs = [Environment("127.0.0.1", i) for i in range(5000, 5000 + NUMBER_ENVIRONMENTS)] +envs = [Environment("ipc", "/tmp/zerogym", i) for i in range(5000, 5000 + NUMBER_ENVIRONMENTS)] envs = [FrameStack(FireResetEnv(env), 4) for env in envs] # env.seed(SEED) state_size = [1, 4, 80, 70] diff --git a/gymclient.py b/gymclient.py index bdac868..fe15579 100644 --- a/gymclient.py +++ b/gymclient.py @@ -4,10 +4,12 @@ import numpy # [TODO] Error handling for if server is down class Environment: - def __init__(self, address, port): + def __init__(self, proto, address, port): self.context = zmq.Context() self.socket = self.context.socket(zmq.REQ) - self.socket.connect("tcp://%s:%s" % (address, port)) + if proto != "tcp" and proto != "ipc": + raise ValueError("proto must be tcp or ipc") + self.socket.connect("%s://%s:%s" % (proto, address, port)) self.address = address self.port = port self.observation_space, self.action_space, self.reward_range, self.metadata, self.action_meanings = self.get_initial_metadata() diff --git a/gymserver.py b/gymserver.py index c689f2f..55c6a9a 100644 --- a/gymserver.py +++ b/gymserver.py @@ -129,15 +129,25 @@ def respond(msg, env, socket): # # [TODO] Make sure port is an int -if len(sys.argv) != 2: - print("Usage: gymserver_zero.py ", file=sys.stderr) +if len(sys.argv) != 3: + print("Usage: gymserver_zero.py ", file=sys.stderr) + print("Where proto equals tcp or ipc", file=sys.stderr) sys.exit(1) env = PongEnv() -port = int(sys.argv[1]) +port = int(sys.argv[2]) context = zmq.Context() socket = context.socket(zmq.REP) -socket.bind("tcp://*:%s" % port) + +proto = sys.argv[1] +if proto == "tcp": + socket.bind("tcp://*:%s" % port) +elif proto == "ipc": + socket.bind("ipc:///tmp/zerogym:%s" % port) +else: + print("proto not recognized", file=sys.stderr) + sys.exit(1) + while True: msg = socket.recv_json() respond(msg, env, socket) diff --git a/start_servers.sh b/start_servers.sh index 5a80db8..1422157 100755 --- a/start_servers.sh +++ b/start_servers.sh @@ -2,5 +2,5 @@ export FLASK_APP=gymserver.py for i in {0..31} do - python gymserver.py $((5000 + i)) & + python gymserver.py $1 $((5000 + i)) & done