wordguess/pubnix.py

295 lines
7.6 KiB
Python
Raw Normal View History

2024-01-05 22:30:58 -05:00
"""
2024-01-05 23:42:44 -05:00
Pubnix Server/Client Library
2024-01-05 22:30:58 -05:00
Author: Brandon Rozek
2024-01-05 23:42:44 -05:00
The pubnix library contains server
and client code needed to communicate
over unix domain sockets on a single
machine.
For authentication, we rely on challenge
tokens and the unix permission system as
both server and client run on the same
machine.
Remaining TODO ...
2024-01-05 22:30:58 -05:00
TODO: Handle a user trying to connect multiple
times at the same time.
This might be handled automatically if only one
user can play at a time...
2024-01-05 23:42:44 -05:00
TODO: Handle timeout properly
"""
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Union
import binascii
2024-01-05 22:30:58 -05:00
import json
import os
2024-01-05 23:42:44 -05:00
import pwd
2024-01-05 22:30:58 -05:00
import sys
import socket
# Maximum number of bytes read from a socket for one protocol message.
MESSAGE_BUFFER_LEN = 1024
# Length (in hex characters) requested for each authentication challenge token.
TOKEN_LENGTH = 50
# Timeout, in seconds, applied to the server's listening socket.
TIMEOUT = 60 * 5  # 5 minutes
2024-01-05 23:42:44 -05:00
###
# Server
###
2024-01-05 22:30:58 -05:00
2024-01-05 23:42:44 -05:00
def run_simple_server(address, fn, force_auth=True):
    """
    Main entrypoint for a server that handles one client at a time.

    Opens the unix domain socket at `address`, then loops forever:
    accept a client, optionally authenticate it, wait for its start
    message and hand the connection over to `fn`.

    Parameters
    ==========
    address: filesystem path of the unix domain socket to serve on.
    fn: callable invoked as `fn(connection, user)` for every client.
    force_auth: when true the client must pass the challenge
        authentication first; otherwise `user` is None.

    Example
    =======
    if __name__ == "__main__":
        run_simple_server(
            "/home/project/.pubnix.sock",
            lambda connection, user: connection.sendall(f"Hello {user}".encode())
        )
    """
    with start_server(address) as listener:
        print("Started server at", address)
        try:
            while True:
                with client_connection(listener) as connection:
                    # Unauthenticated sessions are passed to `fn` with user=None
                    user = authenticate(connection) if force_auth else None
                    # The client announces that it is ready before `fn` takes over
                    receive_message(connection, StartMessage)
                    fn(connection, user)
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to shut the server down
            print("Stopping server...")
2024-01-05 22:30:58 -05:00
@contextmanager
def start_server(address):
    """
    Open a unix domain socket at `address` and listen for connections.

    Yields the listening socket. When the context exits the socket file
    is removed and the socket itself is closed. If `address` already
    exists the process prints a notice and exits with status 1, since
    another server is assumed to be running.
    """
    if os.path.exists(address):
        print(f"{address} exists -- server already running")
        sys.exit(1)

    # Create a unix domain socket
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        sock.settimeout(TIMEOUT)
        sock.bind(address)
        try:
            sock.listen()
            yield sock
        finally:
            # Only reached once bind() has created the socket file,
            # so a failed bind never unlinks a path we do not own.
            os.unlink(address)
    finally:
        # Release the file descriptor even when bind()/listen() fail
        sock.close()
2024-01-05 22:30:58 -05:00
@contextmanager
def client_connection(sock):
    """
    Accept one client on the listening socket `sock` and yield the connection.

    The listening socket may carry a timeout, in which case `accept()`
    raises when no client shows up in time. That is not an error for an
    idle server, so the accept is simply retried until a client arrives.

    Protocol violations and dropped connections raised while the caller
    is talking to the client are swallowed: the client can reconnect,
    and the server should keep serving. The connection is always closed
    on exit.
    """
    while True:
        try:
            connection, _ = sock.accept()
            break
        except (socket.timeout, TimeoutError):
            # Nobody connected within the timeout window -- keep waiting
            continue

    try:
        yield connection
    except (
        ProtocolException,
        BrokenPipeError,
        TimeoutError,
        ConnectionResetError,
    ):
        # Ignore as client can reconnect
        pass
    finally:  # clean up the connection
        connection.close()
def generate_challenge(user):
    """
    Build the challenge message sent to a client claiming to be `user`.

    The challenge carries a freshly generated token and the path, inside
    /home/<user>, where the client is expected to write that token.
    """
    fresh_token = generate_token(TOKEN_LENGTH)
    challenge_path = f"/home/{user}/.pubnix_challenge"
    return ChallengeMessage(
        username=user,
        token=fresh_token,
        location=challenge_path,
    )
def authenticate(connection):
    """
    Run the server side of the challenge authentication handshake.

    Steps:
      1. Receive an AuthenticateMessage naming the user.
      2. Send back a ChallengeMessage (token + file location).
      3. Wait for a ValidationMessage, then verify that the challenge
         file exists, is owned by the claimed user, is readable and
         contains the token.
      4. Send an AuthSuccessMessage.

    Returns the authenticated username. Any failed check is reported to
    the client through `close_with_error`, which raises
    ProtocolException.

    Everything about the challenge file is controlled by the connecting
    user, so filesystem errors (file removed between checks, path is a
    directory, owner uid without a passwd entry, contents that are not
    valid text) are reported as failed checks rather than being allowed
    to propagate and take the server down.
    """
    # First message should be an authentication message
    message = receive_message(connection, AuthenticateMessage)
    user = message.username

    # Send challenge message
    challenge = generate_challenge(user)
    send_message(connection, challenge)

    # Second message should be validation message
    receive_message(connection, ValidationMessage)

    # Check that challenge file exists
    if not os.path.exists(challenge.location):
        close_with_error(connection, "Challange file doesn't exist")

    # Check if user owns the file
    try:
        owner = find_owner(challenge.location)
    except (OSError, KeyError):
        # File vanished since the check above, or its uid has no passwd entry
        owner = None
    if owner != user:
        close_with_error(connection, "Challange file not owned by user")

    # Make sure we can read the file
    if not os.access(challenge.location, os.R_OK):
        close_with_error(connection, "Challange file cannot be read by server")

    # Check contents of challenge file
    try:
        with open(challenge.location, "r") as file:
            contents = file.read()
    except (OSError, UnicodeDecodeError):
        # close_with_error always raises, so `contents` is bound below
        close_with_error(connection, "Challange file cannot be read by server")
    if contents != challenge.token:
        close_with_error(connection, "Token within challange file is incorrect")

    # Send authentication successful message
    send_message(connection, AuthSuccessMessage())
    return user
2024-01-05 23:42:44 -05:00
def generate_token(length):
    """
    Return a random lowercase hexadecimal string of exactly `length` characters.

    Randomness comes from `os.urandom`. Each byte yields two hex
    characters, so for odd lengths one extra byte is drawn and the
    surplus character is dropped.

    Raises ValueError when `length` is negative.
    """
    if length < 0:
        raise ValueError("token length must be non-negative")
    # Round up so odd lengths are not silently shortened by one character
    num_bytes = (length + 1) // 2
    return binascii.hexlify(os.urandom(num_bytes)).decode()[:length]
def find_owner(path: Union[str, Path]) -> str:
    """Return the login name of the user owning the file at `path`."""
    target = Path(path)
    return target.owner()
###
# Client
###
def run_simple_client(address, fn, force_auth=True):
    """
    Main entrypoint for a client.

    Connects to the server socket at `address`, optionally logs in,
    sends the start message and then hands the connection to `fn`.

    Parameters
    ==========
    address: filesystem path of the server's unix domain socket.
    fn: callable invoked as `fn(connection, user)` once the session
        has started.
    force_auth: when true, first authenticate as the effective user
        running the program; otherwise `fn` receives user=None.

    Example
    =======
    if __name__ == "__main__":
        run_simple_client(
            "/home/project/.pubnix.sock",
            lambda connection, user: connection.sendall(f"Hello server, I'm {user}".encode())
        )
    """
    with start_client(address) as client:
        # Without authentication there is no username to report; use None
        # (as the server does) instead of leaving `user` unbound.
        user = login(client) if force_auth else None
        send_message(client, StartMessage())
        fn(client, user)
@contextmanager
def start_client(address):
    """
    Connect to the unix domain socket at `address` and yield the client socket.

    If no socket file exists at `address` a notice is printed and the
    process exits with status 1. The socket is closed when the context
    exits, including when the connection attempt itself fails.
    """
    # Create the Unix socket client
    client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        # Connect to the server
        try:
            client.connect(address)
        except FileNotFoundError:
            print("Server is not running at location", address)
            sys.exit(1)
        yield client
    finally:
        client.close()
2024-01-05 22:30:58 -05:00
def login(connection):
    """
    Run the client side of the challenge authentication handshake.

    Announces the effective user of this process to the server, writes
    the token from the server's challenge into the challenge file, asks
    the server to validate it and waits for the success message.

    Returns the username that was authenticated. The challenge file is
    removed again whether or not validation succeeds.
    """
    # Send authentication message
    user = pwd.getpwuid(os.geteuid()).pw_name
    send_message(connection, AuthenticateMessage(username=user))

    # Receive challenge message
    challenge = receive_message(connection, ChallengeMessage)

    # Write to challenge file
    # NOTE(review): the path comes from the server and is written to and
    # later deleted without validation -- confirm the server is trusted.
    with open(challenge.location, "w") as file:
        file.write(challenge.token)

    try:
        # Tell server to check the challenge file
        send_message(connection, ValidationMessage())
        receive_message(connection, AuthSuccessMessage)
    finally:
        # Always clean up, even if the server went away before validating
        os.unlink(challenge.location)

    return user
2024-01-05 23:42:44 -05:00
##
# Messages
##
2024-01-05 22:30:58 -05:00
class MessageEncoder(json.JSONEncoder):
    """JSON encoder that serialises message objects through their attributes."""

    def default(self, o):
        # Called by json for objects it cannot encode natively; the
        # instance attribute dictionary is used as the JSON object.
        attributes = o.__dict__
        return attributes
def send_message(connection, message):
    """Serialise `message` to JSON with MessageEncoder and send it in full over `connection`."""
    payload = json.dumps(message, cls=MessageEncoder)
    connection.sendall(payload.encode())
def receive_message(connection, cls=None):
    """
    Read one JSON message from `connection`.

    At most MESSAGE_BUFFER_LEN bytes are read. When `cls` is given the
    decoded JSON object is passed as keyword arguments to `cls` and the
    resulting instance is returned; otherwise the decoded JSON value is
    returned as is.

    Data that is not valid UTF-8 encoded JSON, or that does not fit
    `cls`, is reported to the peer through `close_with_error`, which
    raises ProtocolException.
    """
    raw = connection.recv(MESSAGE_BUFFER_LEN)
    try:
        # Decoding happens inside the try: bytes from the peer are untrusted
        # and may not even be valid UTF-8.
        message = json.loads(raw.decode())
    except Exception:
        print("Received:", raw.decode(errors="replace"), flush=True)
        close_with_error(connection, "Invalid Message Received")

    if cls is not None:
        try:
            message = cls(**message)
        except (TypeError, AssertionError):
            expected = getattr(cls, "__name__", cls)
            close_with_error(connection, f"Expected message of type {expected}")

    return message
class ProtocolException(Exception):
    """Raised after an error message has been sent to the peer for a protocol violation."""
def close_with_error(connection, content: str):
    """
    Send an error message carrying `content` to the peer, then abort.

    Always raises ProtocolException. The connection itself is not
    closed here; that is left to whoever owns it.
    """
    payload = json.dumps({"type": "error", "message": content})
    connection.sendall(payload.encode())
    raise ProtocolException()
@dataclass
class ChallengeMessage:
    """Server -> client: the token to write and the file to write it to."""

    username: str  # user the challenge was issued for
    token: str  # value the client must place in the challenge file
    location: str  # path of the challenge file
    action: str = "challenge"

    def __post_init__(self):
        # Raised explicitly instead of using `assert` so that validation
        # of incoming messages is not stripped when running `python -O`.
        if self.action != "challenge":
            raise AssertionError("action must be 'challenge'")
@dataclass
class AuthenticateMessage:
    """Client -> server: request to authenticate as `username`."""

    username: str  # must be non-empty
    action: str = "authenticate"

    def __post_init__(self):
        # Raised explicitly instead of using `assert` so that validation
        # of incoming messages is not stripped when running `python -O`.
        if self.action != "authenticate":
            raise AssertionError("action must be 'authenticate'")
        if not len(self.username) > 0:
            raise AssertionError("username must not be empty")
@dataclass
class ValidationMessage:
    """Client -> server: the challenge file is in place, please check it."""

    action: str = "validate"

    def __post_init__(self):
        # Raised explicitly instead of using `assert` so that validation
        # of incoming messages is not stripped when running `python -O`.
        if self.action != "validate":
            raise AssertionError("action must be 'validate'")
@dataclass
class AuthSuccessMessage:
    """Server -> client: authentication succeeded."""

    type: str = "authentication_success"

    def __post_init__(self):
        # Raised explicitly instead of using `assert` so that validation
        # of incoming messages is not stripped when running `python -O`.
        if self.type != "authentication_success":
            raise AssertionError("type must be 'authentication_success'")
@dataclass
class StartMessage:
    """Client -> server: the handshake is over, begin the session."""

    action: str = "start"

    def __post_init__(self):
        # Raised explicitly instead of using `assert` so that validation
        # of incoming messages is not stripped when running `python -O`.
        if self.action != "start":
            raise AssertionError("action must be 'start'")