This commit is contained in:
Brandon Rozek 2026-01-12 17:07:44 -05:00
parent 25bd83f032
commit 95e482a265

37
smt.py
View file

@ -108,12 +108,6 @@ def logic_falsification_rule_to_smt_constraints(
class SMTLogicEncoder:
"""
Encapsulates the SMT encoding of a logic system with a fixed carrier set size.
This class handles:
- Creating the SMT sorts and functions
- Encoding logic rules as SMT constraints
- Managing the solver state
- Converting between Model objects and SMT representations
"""
def __init__(self, logic: Logic, size: int):
@ -140,18 +134,21 @@ class SMTLogicEncoder:
# Create operation functions
self.operation_function_map: Dict[Operation, z3.FuncDeclRef] = {}
for operation in logic.operations:
self.operation_function_map[operation] = Function(
operation.symbol,
*(self.carrier_sort for _ in range(operation.arity + 1))
)
self.operation_function_map[operation] = self.create_function(operation.symbol, operation.arity)
# Create designation function
self.is_designated = Function("D", self.carrier_sort, BoolSort(ctx=self.ctx))
self.is_designated = self.create_predicate("D", 1)
# Add logic rules as constraints
self._add_logic_constraints()
self._add_designation_symmetry_constraints()
def create_predicate(self, name: str, arity: int) -> z3.FuncDeclRef:
return Function(name, *(self.carrier_sort for _ in range(arity)), BoolSort(ctx=self.ctx))
def create_function(self, name: str, arity: int) -> z3.FuncDeclRef:
return Function(name, *(self.carrier_sort for _ in range(arity + 1)))
def _add_logic_constraints(self):
"""Add all logic rules and falsification rules as SMT constraints."""
# Add regular rules
@ -177,12 +174,6 @@ class SMTLogicEncoder:
def extract_model(self, smt_model) -> Model:
"""
Extract a Model object from an SMT model.
Args:
smt_model: The Z3 model to extract from
Returns:
A Model object representing the logic model
"""
carrier_set = {ModelValue(f"{i}") for i in range(self.size)}
@ -281,10 +272,6 @@ class SMTLogicEncoder:
return self.extract_model(self.solver.model())
return None
def reset(self):
"""Reset the solver state."""
self.solver.reset()
def __del__(self):
"""Cleanup resources."""
try:
@ -377,3 +364,11 @@ class SMTModelEncoder:
smt_inputs = tuple(self.model_value_to_smt[inp] for inp in inputs)
smt_output = self.model_value_to_smt[output]
self.solver.add(smt_fn(*smt_inputs) == smt_output)
def __del__(self):
"""Cleanup resources."""
try:
self.solver.reset()
del self.ctx
except:
pass