Added threaded implementation

This commit is contained in:
Brandon Rozek 2024-01-07 11:20:19 -05:00
parent 4a3e71fd66
commit 81b7b16c3f
No known key found for this signature in database
GPG key ID: 26E457DA82C9F480

View file

@ -11,20 +11,11 @@ For authentication, we rely on challenge
tokens and the unix permission system as tokens and the unix permission system as
both server and client run on the same both server and client run on the same
machine. machine.
Remaining TODO ...
TODO: Handle a user trying to connect multiple
times at the same time.
This might be handled automatically if only one
user can play at a time...
TODO: Handle timeout properly
""" """
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from threading import Thread
from typing import Union from typing import Union
import binascii import binascii
import json import json
@ -33,6 +24,8 @@ import pwd
import sys import sys
import socket import socket
__all__ = ['run_simple_server', 'run_simple_client']
MESSAGE_BUFFER_LEN = 1024 MESSAGE_BUFFER_LEN = 1024
TOKEN_LENGTH = 50 TOKEN_LENGTH = 50
TIMEOUT = 5 * 60 # 5 minutes TIMEOUT = 5 * 60 # 5 minutes
@ -59,14 +52,31 @@ def run_simple_server(address, fn, force_auth=True):
print("Started server at", address) print("Started server at", address)
try: try:
while True: while True:
with client_connection(sock) as connection: connection, _ = sock.accept()
connection.settimeout(TIMEOUT)
t = Thread(target=thread_connection, args=[connection, force_auth, fn])
t.daemon = True # TODO: Implement graceful cleanup instead
t.start()
except KeyboardInterrupt:
print("Stopping server...")
def thread_connection(connection, force_auth, fn):
try:
user = None user = None
if force_auth: if force_auth:
user = authenticate(connection) user = authenticate(connection)
receive_message(connection, StartMessage) receive_message(connection, StartMessage)
fn(connection, user) fn(connection, user)
except KeyboardInterrupt: except (
print("Stopping server...") ProtocolException,
BrokenPipeError,
TimeoutError,
ConnectionResetError) as e:
# Ignore as client can reconnect
pass
finally: # clean up the connection
if connection is not None:
connection.close()
@contextmanager @contextmanager
def start_server(address, allow_other=True): def start_server(address, allow_other=True):
@ -83,7 +93,6 @@ def start_server(address, allow_other=True):
# Create a unix domain socket # Create a unix domain socket
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(TIMEOUT)
sock.bind(address) sock.bind(address)
sock.listen() sock.listen()
@ -98,21 +107,6 @@ def start_server(address, allow_other=True):
# Delete game.sock when finished # Delete game.sock when finished
os.unlink(address) os.unlink(address)
@contextmanager
def client_connection(sock):
connection, _ = sock.accept()
try:
yield connection
except (
ProtocolException,
BrokenPipeError,
TimeoutError,
ConnectionResetError) as e:
# Ignore as client can reconnect
pass
finally: # clean up the connection
connection.close()
def generate_challenge(user): def generate_challenge(user):
SERVER_FOLDER = Path(__file__).parent.absolute() SERVER_FOLDER = Path(__file__).parent.absolute()
Path(f"{SERVER_FOLDER}/challenges").mkdir(mode=33279, exist_ok=True) Path(f"{SERVER_FOLDER}/challenges").mkdir(mode=33279, exist_ok=True)
@ -249,10 +243,13 @@ def send_message(connection, message):
def receive_message(connection, cls=None): def receive_message(connection, cls=None):
message = connection.recv(MESSAGE_BUFFER_LEN).decode() message = connection.recv(MESSAGE_BUFFER_LEN).decode()
if len(message) == 0:
raise ProtocolException("Sender closed the connection")
try: try:
message = json.loads(message) message = json.loads(message)
except Exception: except Exception:
print("Received:", message, flush=True)
close_with_error(connection, "Invalid Message Received") close_with_error(connection, "Invalid Message Received")
if cls is not None: if cls is not None:
@ -262,7 +259,6 @@ def receive_message(connection, cls=None):
if "type" in message and message['type'] == "error": if "type" in message and message['type'] == "error":
raise ProtocolException(message.get("message")) raise ProtocolException(message.get("message"))
else: else:
print("Received:", message, flush=True)
close_with_error(connection, f"Expected message of type {cls}") close_with_error(connection, f"Expected message of type {cls}")
return message return message