Updated driver file R.py to showcase SMT techinques

Fixed minor bugs concerning lack of falsification rules and interfaces between VSP and SMT
This commit is contained in:
Brandon Rozek 2026-01-27 12:48:33 -05:00
parent f8eca388d4
commit 6d87793803
5 changed files with 120 additions and 65 deletions

75
smt.py
View file

@ -1,18 +1,26 @@
from itertools import product
from typing import Dict, Generator, Optional, Tuple
from typing import Dict, Generator, Optional, Set, Tuple, TYPE_CHECKING
from logic import Logic, Operation, Rule, PropositionalVariable, Term, OpTerm, get_prop_vars_from_rule
from model import Model, ModelValue, ModelFunction
from z3 import (
And, BoolSort, Context, EnumSort, Function, Implies, Or, sat, Solver, z3
)
SMT_LOADED = True
try:
from z3 import (
And, BoolSort, Context, EnumSort, Function, Implies, Or, sat, Solver, z3
)
except ImportError:
SMT_LOADED = False
def smt_is_loaded() -> bool:
global SMT_LOADED
return SMT_LOADED
def term_to_smt(
t: Term,
op_mapping: Dict[Operation, z3.FuncDeclRef],
var_mapping: Dict[PropositionalVariable, z3.DatatypeRef]
) -> z3.DatatypeRef:
op_mapping: Dict[Operation, "z3.FuncDeclRef"],
var_mapping: Dict[PropositionalVariable, "z3.DatatypeRef"]
) -> "z3.DatatypeRef":
"""Convert a logic term to its SMT representation."""
if isinstance(t, PropositionalVariable):
return var_mapping[t]
@ -26,10 +34,10 @@ def term_to_smt(
def logic_rule_to_smt_constraints(
rule: Rule,
IsDesignated: z3.FuncDeclRef,
IsDesignated: "z3.FuncDeclRef",
smt_carrier_set,
op_mapping: Dict[Operation, z3.FuncDeclRef]
) -> Generator[z3.BoolRef, None, None]:
op_mapping: Dict[Operation, "z3.FuncDeclRef"]
) -> Generator["z3.BoolRef", None, None]:
"""
Encode a logic rule as SMT constraints.
@ -63,10 +71,10 @@ def logic_rule_to_smt_constraints(
def logic_falsification_rule_to_smt_constraints(
rule: Rule,
IsDesignated: z3.FuncDeclRef,
IsDesignated: "z3.FuncDeclRef",
smt_carrier_set,
op_mapping: Dict[Operation, z3.FuncDeclRef]
) -> z3.BoolRef:
op_mapping: Dict[Operation, "z3.FuncDeclRef"]
) -> "z3.BoolRef":
"""
Encode a falsification rule as an SMT constraint.
@ -132,7 +140,7 @@ class SMTLogicEncoder:
self.carrier_sort, self.smt_carrier_set = EnumSort("C", element_names, ctx=self.ctx)
# Create operation functions
self.operation_function_map: Dict[Operation, z3.FuncDeclRef] = {}
self.operation_function_map: Dict[Operation, "z3.FuncDeclRef"] = {}
for operation in logic.operations:
self.operation_function_map[operation] = self.create_function(operation.symbol, operation.arity)
@ -143,10 +151,10 @@ class SMTLogicEncoder:
self._add_logic_constraints()
self._add_designation_symmetry_constraints()
def create_predicate(self, name: str, arity: int) -> z3.FuncDeclRef:
def create_predicate(self, name: str, arity: int) -> "z3.FuncDeclRef":
return Function(name, *(self.carrier_sort for _ in range(arity)), BoolSort(ctx=self.ctx))
def create_function(self, name: str, arity: int) -> z3.FuncDeclRef:
def create_function(self, name: str, arity: int) -> "z3.FuncDeclRef":
return Function(name, *(self.carrier_sort for _ in range(arity + 1)))
def _add_logic_constraints(self):
@ -171,9 +179,9 @@ class SMTLogicEncoder:
)
self.solver.add(constraint)
def extract_model(self, smt_model) -> Model:
def extract_model(self, smt_model) -> Tuple[Model, Dict[Operation, ModelFunction]]:
"""
Extract a Model object from an SMT model.
Extract a Model object and interpretation from an SMT model.
"""
carrier_set = {ModelValue(f"{i}") for i in range(self.size)}
@ -185,7 +193,8 @@ class SMTLogicEncoder:
designated_values = {ModelValue(str(x)) for x in smt_designated}
# Extract operation functions
model_functions = set()
model_functions: Set[ModelFunction] = set()
interpretation: Dict[Operation, ModelFunction] = dict()
for (operation, smt_function) in self.operation_function_map.items():
mapping: Dict[Tuple[ModelValue], ModelValue] = {}
for smt_inputs in product(self.smt_carrier_set, repeat=operation.arity):
@ -193,9 +202,12 @@ class SMTLogicEncoder:
smt_output = smt_model.evaluate(smt_function(*smt_inputs))
model_output = ModelValue(str(smt_output))
mapping[model_inputs] = model_output
model_functions.add(ModelFunction(operation.arity, mapping, operation.symbol))
model_function = ModelFunction(operation.arity, mapping, operation.symbol)
model_functions.add(model_function)
interpretation[operation] = model_function
return Model(carrier_set, model_functions, designated_values)
return Model(carrier_set, model_functions, designated_values), interpretation
def _add_designation_symmetry_constraints(self):
@ -218,7 +230,7 @@ class SMTLogicEncoder:
)
)
def create_exclusion_constraint(self, model: Model) -> z3.BoolRef:
def create_exclusion_constraint(self, model: Model) -> "z3.BoolRef":
"""
Create a constraint that excludes the given model from future solutions.
"""
@ -254,7 +266,7 @@ class SMTLogicEncoder:
return Or(constraints)
def find_model(self) -> Optional[Model]:
def find_model(self) -> Optional[Tuple[Model, Dict[Operation, ModelFunction]]]:
"""
Find a single model satisfying the logic constraints.
@ -274,12 +286,12 @@ class SMTLogicEncoder:
pass
def find_model(logic: Logic, size: int) -> Optional[Model]:
def find_model(logic: Logic, size: int) -> Optional[Tuple[Model, Dict[Operation, ModelFunction]]]:
"""Find a single model for the given logic and size."""
encoder = SMTLogicEncoder(logic, size)
return encoder.find_model()
def find_all_models(logic: Logic, size: int) -> Generator[Model, None, None]:
def find_all_models(logic: Logic, size: int) -> Generator[Tuple[Model, Dict[Operation, ModelFunction]], None, None]:
"""
Find all models for the given logic and size.
@ -294,13 +306,14 @@ def find_all_models(logic: Logic, size: int) -> Generator[Model, None, None]:
while True:
# Try to find a model
model = encoder.find_model()
if model is None:
solution = encoder.find_model()
if solution is None:
break
yield model
yield solution
# Add constraint to exclude this model from future solutions
model, _ = solution
exclusion_constraint = encoder.create_exclusion_constraint(model)
encoder.solver.add(exclusion_constraint)
@ -346,13 +359,13 @@ class SMTModelEncoder:
is_designated = model_value in model.designated_values
self.solver.add(self.is_designated(self.model_value_to_smt[model_value]) == is_designated)
def create_predicate(self, name: str, arity: int) -> z3.FuncDeclRef:
def create_predicate(self, name: str, arity: int) -> "z3.FuncDeclRef":
return Function(name, *(self.carrier_sort for _ in range(arity)), BoolSort(ctx=self.ctx))
def create_function(self, name: str, arity: int) -> z3.FuncDeclRef:
def create_function(self, name: str, arity: int) -> "z3.FuncDeclRef":
return Function(name, *(self.carrier_sort for _ in range(arity + 1)))
def add_function_constraints_from_table(self, smt_fn: z3.FuncDeclRef, model_fn: ModelFunction):
def add_function_constraints_from_table(self, smt_fn: "z3.FuncDeclRef", model_fn: ModelFunction):
for inputs, output in model_fn.mapping.items():
smt_inputs = tuple(self.model_value_to_smt[inp] for inp in inputs)
smt_output = self.model_value_to_smt[output]