This commit is contained in:
Brandon Rozek 2026-01-12 17:07:44 -05:00
parent 25bd83f032
commit 95e482a265

37
smt.py
View file

@ -108,12 +108,6 @@ def logic_falsification_rule_to_smt_constraints(
class SMTLogicEncoder: class SMTLogicEncoder:
""" """
Encapsulates the SMT encoding of a logic system with a fixed carrier set size. Encapsulates the SMT encoding of a logic system with a fixed carrier set size.
This class handles:
- Creating the SMT sorts and functions
- Encoding logic rules as SMT constraints
- Managing the solver state
- Converting between Model objects and SMT representations
""" """
def __init__(self, logic: Logic, size: int): def __init__(self, logic: Logic, size: int):
@ -140,18 +134,21 @@ class SMTLogicEncoder:
# Create operation functions # Create operation functions
self.operation_function_map: Dict[Operation, z3.FuncDeclRef] = {} self.operation_function_map: Dict[Operation, z3.FuncDeclRef] = {}
for operation in logic.operations: for operation in logic.operations:
self.operation_function_map[operation] = Function( self.operation_function_map[operation] = self.create_function(operation.symbol, operation.arity)
operation.symbol,
*(self.carrier_sort for _ in range(operation.arity + 1))
)
# Create designation function # Create designation function
self.is_designated = Function("D", self.carrier_sort, BoolSort(ctx=self.ctx)) self.is_designated = self.create_predicate("D", 1)
# Add logic rules as constraints # Add logic rules as constraints
self._add_logic_constraints() self._add_logic_constraints()
self._add_designation_symmetry_constraints() self._add_designation_symmetry_constraints()
def create_predicate(self, name: str, arity: int) -> z3.FuncDeclRef:
return Function(name, *(self.carrier_sort for _ in range(arity)), BoolSort(ctx=self.ctx))
def create_function(self, name: str, arity: int) -> z3.FuncDeclRef:
return Function(name, *(self.carrier_sort for _ in range(arity + 1)))
def _add_logic_constraints(self): def _add_logic_constraints(self):
"""Add all logic rules and falsification rules as SMT constraints.""" """Add all logic rules and falsification rules as SMT constraints."""
# Add regular rules # Add regular rules
@ -177,12 +174,6 @@ class SMTLogicEncoder:
def extract_model(self, smt_model) -> Model: def extract_model(self, smt_model) -> Model:
""" """
Extract a Model object from an SMT model. Extract a Model object from an SMT model.
Args:
smt_model: The Z3 model to extract from
Returns:
A Model object representing the logic model
""" """
carrier_set = {ModelValue(f"{i}") for i in range(self.size)} carrier_set = {ModelValue(f"{i}") for i in range(self.size)}
@ -281,10 +272,6 @@ class SMTLogicEncoder:
return self.extract_model(self.solver.model()) return self.extract_model(self.solver.model())
return None return None
def reset(self):
"""Reset the solver state."""
self.solver.reset()
def __del__(self): def __del__(self):
"""Cleanup resources.""" """Cleanup resources."""
try: try:
@ -377,3 +364,11 @@ class SMTModelEncoder:
smt_inputs = tuple(self.model_value_to_smt[inp] for inp in inputs) smt_inputs = tuple(self.model_value_to_smt[inp] for inp in inputs)
smt_output = self.model_value_to_smt[output] smt_output = self.model_value_to_smt[output]
self.solver.add(smt_fn(*smt_inputs) == smt_output) self.solver.add(smt_fn(*smt_inputs) == smt_output)
def __del__(self):
"""Cleanup resources."""
try:
self.solver.reset()
del self.ctx
except:
pass