Updated driver file R.py to showcase SMT techinques

Fixed minor bugs concerning lack of falsification rules and interfaces between VSP and SMT
This commit is contained in:
Brandon Rozek 2026-01-27 12:48:33 -05:00
parent f8eca388d4
commit 6d87793803
5 changed files with 120 additions and 65 deletions

81
R.py
View file

@ -12,7 +12,8 @@ from logic import (
) )
from model import Model, ModelFunction, ModelValue, satisfiable from model import Model, ModelFunction, ModelValue, satisfiable
from generate_model import generate_model from generate_model import generate_model
# from vsp import has_vsp from vsp import has_vsp
from smt import smt_is_loaded
# =================================================== # ===================================================
@ -56,12 +57,17 @@ disjunction_rules = {
Rule({Conjunction(x, Disjunction(y, z)),}, Disjunction(Conjunction(x, y), Conjunction(x, z))) Rule({Conjunction(x, Disjunction(y, z)),}, Disjunction(Conjunction(x, y), Conjunction(x, z)))
} }
falsification_rules = {
# At least one value is non-designated
Rule(set(), x)
}
logic_rules = implication_rules | negation_rules | conjunction_rules | disjunction_rules logic_rules = implication_rules | negation_rules | conjunction_rules | disjunction_rules
operations = {Negation, Conjunction, Disjunction, Implication} operations = {Negation, Conjunction, Disjunction, Implication}
R_logic = Logic(operations, logic_rules, "R") R_logic = Logic(operations, logic_rules, falsification_rules, "R")
# =============================== # ===============================
@ -69,36 +75,36 @@ R_logic = Logic(operations, logic_rules, "R")
Example 2-Element Model of R Example 2-Element Model of R
""" """
a0 = ModelValue("a0") a0 = ModelValue("0")
a1 = ModelValue("a1") a1 = ModelValue("1")
carrier_set = {a0, a1} carrier_set = {a0, a1}
mnegation = ModelFunction(1, { mnegation = ModelFunction(1, {
a0: a1, a0: a1,
a1: a0 a1: a0
}) }, "¬")
mimplication = ModelFunction(2, { mimplication = ModelFunction(2, {
(a0, a0): a1, (a0, a0): a1,
(a0, a1): a1, (a0, a1): a1,
(a1, a0): a0, (a1, a0): a0,
(a1, a1): a1 (a1, a1): a1
}) }, "")
mconjunction = ModelFunction(2, { mconjunction = ModelFunction(2, {
(a0, a0): a0, (a0, a0): a0,
(a0, a1): a0, (a0, a1): a0,
(a1, a0): a0, (a1, a0): a0,
(a1, a1): a1 (a1, a1): a1
}) }, "")
mdisjunction = ModelFunction(2, { mdisjunction = ModelFunction(2, {
(a0, a0): a0, (a0, a0): a0,
(a0, a1): a1, (a0, a1): a1,
(a1, a0): a1, (a1, a0): a1,
(a1, a1): a1 (a1, a1): a1
}) }, "")
designated_values = {a1} designated_values = {a1}
@ -117,11 +123,18 @@ interpretation = {
print(R_model_2) print(R_model_2)
print(f"Does {R_model_2.name} satisfy the logic R?", satisfiable(R_logic, R_model_2, interpretation))
if smt_is_loaded():
print(has_vsp(R_model_2, mimplication, True, True))
else:
print("Z3 not setup, skipping VSP check...")
# ================================= # =================================
""" """
Generate models of R of a specified size Generate models of R of a specified size using the slow approach
""" """
print("*" * 30) print("*" * 30)
@ -130,14 +143,20 @@ model_size = 2
print("Generating models of Logic", R_logic.name, "of size", model_size) print("Generating models of Logic", R_logic.name, "of size", model_size)
solutions = generate_model(R_logic, model_size, print_model=False) solutions = generate_model(R_logic, model_size, print_model=False)
print(f"Found {len(solutions)} satisfiable models") if smt_is_loaded():
num_satisfies_vsp = 0
for model, interpretation in solutions:
negation_defined = Negation in interpretation
conj_disj_defined = Conjunction in interpretation and Disjunction in interpretation
if has_vsp(model, interpretation[Implication], negation_defined, conj_disj_defined).has_vsp:
num_satisfies_vsp += 1
print(f"Found {len(solutions)} satisfiable models of size {model_size}, {num_satisfies_vsp} of which satisfy VSP")
# for model, interpretation in solutions:
# print(has_vsp(model, interpretation))
print("*" * 30) print("*" * 30)
###### # =================================
""" """
Showing the smallest model for R that has the Showing the smallest model for R that has the
@ -146,12 +165,12 @@ variable sharing property.
This model has 6 elements. This model has 6 elements.
""" """
a0 = ModelValue("a0") a0 = ModelValue("0")
a1 = ModelValue("a1") a1 = ModelValue("1")
a2 = ModelValue("a2") a2 = ModelValue("2")
a3 = ModelValue("a3") a3 = ModelValue("3")
a4 = ModelValue("a4") a4 = ModelValue("4")
a5 = ModelValue("a5") a5 = ModelValue("5")
carrier_set = { a0, a1, a2, a3, a4, a5 } carrier_set = { a0, a1, a2, a3, a4, a5 }
designated_values = {a1, a2, a3, a4, a5 } designated_values = {a1, a2, a3, a4, a5 }
@ -312,4 +331,26 @@ interpretation = {
print(R_model_6) print(R_model_6)
print(f"Model {R_model_6.name} satisfies logic {R_logic.name}?", satisfiable(R_logic, R_model_6, interpretation)) print(f"Model {R_model_6.name} satisfies logic {R_logic.name}?", satisfiable(R_logic, R_model_6, interpretation))
# print(has_vsp(R_model_6, interpretation)) if smt_is_loaded():
print(has_vsp(R_model_6, mimplication, True, True))
else:
print("Z3 not loaded, skipping VSP check...")
"""
Generate models of R of a specified size using the SMT approach
"""
from vsp import logic_has_vsp
size = 7
print(f"Searching for a model of size {size} which witness VSP...")
if smt_is_loaded():
solution = logic_has_vsp(R_logic, size)
if solution is None:
print(f"No models found of size {size} which witness VSP")
else:
model, vsp_result = solution
print(vsp_result)
print(model)
else:
print("Z3 not setup, skipping...")

View file

@ -67,7 +67,7 @@ def only_rules_with(rules: Set[Rule], operation: Operation) -> List[Rule]:
def possible_interpretations( def possible_interpretations(
logic: Logic, carrier_set: Set[ModelValue], logic: Logic, carrier_set: Set[ModelValue],
designated_values: Set[ModelValue]): designated_values: Set[ModelValue], debug: bool):
""" """
Consider every possible interpretation of operations Consider every possible interpretation of operations
within the specified logic given the carrier set of within the specified logic given the carrier set of
@ -100,7 +100,7 @@ def possible_interpretations(
passed_functions = candidate_functions passed_functions = candidate_functions
if len(passed_functions) == 0: if len(passed_functions) == 0:
raise Exception("No interpretation satisfies the axioms for the operation " + str(operation)) raise Exception("No interpretation satisfies the axioms for the operation " + str(operation))
else: elif debug:
print( print(
f"Operation {operation.symbol} has {len(passed_functions)} candidate functions" f"Operation {operation.symbol} has {len(passed_functions)} candidate functions"
) )
@ -120,7 +120,7 @@ def possible_interpretations(
def generate_model( def generate_model(
logic: Logic, number_elements: int, num_solutions: int = -1, logic: Logic, number_elements: int, num_solutions: int = -1,
print_model=False) -> List[Tuple[Model, Interpretation]]: print_model=False, debug=False) -> List[Tuple[Model, Interpretation]]:
""" """
Generate the specified number of models that Generate the specified number of models that
satisfy a logic of a certain size. satisfy a logic of a certain size.
@ -136,9 +136,10 @@ def generate_model(
for designated_values in possible_designated_values: for designated_values in possible_designated_values:
designated_values = set(designated_values) designated_values = set(designated_values)
print("Considering models for designated values", set_to_str(designated_values)) if debug:
print("Considering models for designated values", set_to_str(designated_values))
possible_interps = possible_interpretations(logic, carrier_set, designated_values) possible_interps = possible_interpretations(logic, carrier_set, designated_values, debug)
for interpretation in possible_interps: for interpretation in possible_interps:
is_valid = True is_valid = True
model = Model(carrier_set, set(interpretation.values()), designated_values) model = Model(carrier_set, set(interpretation.values()), designated_values)

View file

@ -85,7 +85,7 @@ class Logic:
name: Optional[str] = None): name: Optional[str] = None):
self.operations = operations self.operations = operations
self.rules = rules self.rules = rules
self.falsifies = falsifies self.falsifies = falsifies if falsifies is not None else set()
self.name = str(abs(hash(( self.name = str(abs(hash((
frozenset(operations), frozenset(operations),
frozenset(rules) frozenset(rules)

75
smt.py
View file

@ -1,18 +1,26 @@
from itertools import product from itertools import product
from typing import Dict, Generator, Optional, Tuple from typing import Dict, Generator, Optional, Set, Tuple, TYPE_CHECKING
from logic import Logic, Operation, Rule, PropositionalVariable, Term, OpTerm, get_prop_vars_from_rule from logic import Logic, Operation, Rule, PropositionalVariable, Term, OpTerm, get_prop_vars_from_rule
from model import Model, ModelValue, ModelFunction from model import Model, ModelValue, ModelFunction
from z3 import ( SMT_LOADED = True
And, BoolSort, Context, EnumSort, Function, Implies, Or, sat, Solver, z3 try:
) from z3 import (
And, BoolSort, Context, EnumSort, Function, Implies, Or, sat, Solver, z3
)
except ImportError:
SMT_LOADED = False
def smt_is_loaded() -> bool:
global SMT_LOADED
return SMT_LOADED
def term_to_smt( def term_to_smt(
t: Term, t: Term,
op_mapping: Dict[Operation, z3.FuncDeclRef], op_mapping: Dict[Operation, "z3.FuncDeclRef"],
var_mapping: Dict[PropositionalVariable, z3.DatatypeRef] var_mapping: Dict[PropositionalVariable, "z3.DatatypeRef"]
) -> z3.DatatypeRef: ) -> "z3.DatatypeRef":
"""Convert a logic term to its SMT representation.""" """Convert a logic term to its SMT representation."""
if isinstance(t, PropositionalVariable): if isinstance(t, PropositionalVariable):
return var_mapping[t] return var_mapping[t]
@ -26,10 +34,10 @@ def term_to_smt(
def logic_rule_to_smt_constraints( def logic_rule_to_smt_constraints(
rule: Rule, rule: Rule,
IsDesignated: z3.FuncDeclRef, IsDesignated: "z3.FuncDeclRef",
smt_carrier_set, smt_carrier_set,
op_mapping: Dict[Operation, z3.FuncDeclRef] op_mapping: Dict[Operation, "z3.FuncDeclRef"]
) -> Generator[z3.BoolRef, None, None]: ) -> Generator["z3.BoolRef", None, None]:
""" """
Encode a logic rule as SMT constraints. Encode a logic rule as SMT constraints.
@ -63,10 +71,10 @@ def logic_rule_to_smt_constraints(
def logic_falsification_rule_to_smt_constraints( def logic_falsification_rule_to_smt_constraints(
rule: Rule, rule: Rule,
IsDesignated: z3.FuncDeclRef, IsDesignated: "z3.FuncDeclRef",
smt_carrier_set, smt_carrier_set,
op_mapping: Dict[Operation, z3.FuncDeclRef] op_mapping: Dict[Operation, "z3.FuncDeclRef"]
) -> z3.BoolRef: ) -> "z3.BoolRef":
""" """
Encode a falsification rule as an SMT constraint. Encode a falsification rule as an SMT constraint.
@ -132,7 +140,7 @@ class SMTLogicEncoder:
self.carrier_sort, self.smt_carrier_set = EnumSort("C", element_names, ctx=self.ctx) self.carrier_sort, self.smt_carrier_set = EnumSort("C", element_names, ctx=self.ctx)
# Create operation functions # Create operation functions
self.operation_function_map: Dict[Operation, z3.FuncDeclRef] = {} self.operation_function_map: Dict[Operation, "z3.FuncDeclRef"] = {}
for operation in logic.operations: for operation in logic.operations:
self.operation_function_map[operation] = self.create_function(operation.symbol, operation.arity) self.operation_function_map[operation] = self.create_function(operation.symbol, operation.arity)
@ -143,10 +151,10 @@ class SMTLogicEncoder:
self._add_logic_constraints() self._add_logic_constraints()
self._add_designation_symmetry_constraints() self._add_designation_symmetry_constraints()
def create_predicate(self, name: str, arity: int) -> z3.FuncDeclRef: def create_predicate(self, name: str, arity: int) -> "z3.FuncDeclRef":
return Function(name, *(self.carrier_sort for _ in range(arity)), BoolSort(ctx=self.ctx)) return Function(name, *(self.carrier_sort for _ in range(arity)), BoolSort(ctx=self.ctx))
def create_function(self, name: str, arity: int) -> z3.FuncDeclRef: def create_function(self, name: str, arity: int) -> "z3.FuncDeclRef":
return Function(name, *(self.carrier_sort for _ in range(arity + 1))) return Function(name, *(self.carrier_sort for _ in range(arity + 1)))
def _add_logic_constraints(self): def _add_logic_constraints(self):
@ -171,9 +179,9 @@ class SMTLogicEncoder:
) )
self.solver.add(constraint) self.solver.add(constraint)
def extract_model(self, smt_model) -> Model: def extract_model(self, smt_model) -> Tuple[Model, Dict[Operation, ModelFunction]]:
""" """
Extract a Model object from an SMT model. Extract a Model object and interpretation from an SMT model.
""" """
carrier_set = {ModelValue(f"{i}") for i in range(self.size)} carrier_set = {ModelValue(f"{i}") for i in range(self.size)}
@ -185,7 +193,8 @@ class SMTLogicEncoder:
designated_values = {ModelValue(str(x)) for x in smt_designated} designated_values = {ModelValue(str(x)) for x in smt_designated}
# Extract operation functions # Extract operation functions
model_functions = set() model_functions: Set[ModelFunction] = set()
interpretation: Dict[Operation, ModelFunction] = dict()
for (operation, smt_function) in self.operation_function_map.items(): for (operation, smt_function) in self.operation_function_map.items():
mapping: Dict[Tuple[ModelValue], ModelValue] = {} mapping: Dict[Tuple[ModelValue], ModelValue] = {}
for smt_inputs in product(self.smt_carrier_set, repeat=operation.arity): for smt_inputs in product(self.smt_carrier_set, repeat=operation.arity):
@ -193,9 +202,12 @@ class SMTLogicEncoder:
smt_output = smt_model.evaluate(smt_function(*smt_inputs)) smt_output = smt_model.evaluate(smt_function(*smt_inputs))
model_output = ModelValue(str(smt_output)) model_output = ModelValue(str(smt_output))
mapping[model_inputs] = model_output mapping[model_inputs] = model_output
model_functions.add(ModelFunction(operation.arity, mapping, operation.symbol)) model_function = ModelFunction(operation.arity, mapping, operation.symbol)
model_functions.add(model_function)
interpretation[operation] = model_function
return Model(carrier_set, model_functions, designated_values)
return Model(carrier_set, model_functions, designated_values), interpretation
def _add_designation_symmetry_constraints(self): def _add_designation_symmetry_constraints(self):
@ -218,7 +230,7 @@ class SMTLogicEncoder:
) )
) )
def create_exclusion_constraint(self, model: Model) -> z3.BoolRef: def create_exclusion_constraint(self, model: Model) -> "z3.BoolRef":
""" """
Create a constraint that excludes the given model from future solutions. Create a constraint that excludes the given model from future solutions.
""" """
@ -254,7 +266,7 @@ class SMTLogicEncoder:
return Or(constraints) return Or(constraints)
def find_model(self) -> Optional[Model]: def find_model(self) -> Optional[Tuple[Model, Dict[Operation, ModelFunction]]]:
""" """
Find a single model satisfying the logic constraints. Find a single model satisfying the logic constraints.
@ -274,12 +286,12 @@ class SMTLogicEncoder:
pass pass
def find_model(logic: Logic, size: int) -> Optional[Model]: def find_model(logic: Logic, size: int) -> Optional[Tuple[Model, Dict[Operation, ModelFunction]]]:
"""Find a single model for the given logic and size.""" """Find a single model for the given logic and size."""
encoder = SMTLogicEncoder(logic, size) encoder = SMTLogicEncoder(logic, size)
return encoder.find_model() return encoder.find_model()
def find_all_models(logic: Logic, size: int) -> Generator[Model, None, None]: def find_all_models(logic: Logic, size: int) -> Generator[Tuple[Model, Dict[Operation, ModelFunction]], None, None]:
""" """
Find all models for the given logic and size. Find all models for the given logic and size.
@ -294,13 +306,14 @@ def find_all_models(logic: Logic, size: int) -> Generator[Model, None, None]:
while True: while True:
# Try to find a model # Try to find a model
model = encoder.find_model() solution = encoder.find_model()
if model is None: if solution is None:
break break
yield model yield solution
# Add constraint to exclude this model from future solutions # Add constraint to exclude this model from future solutions
model, _ = solution
exclusion_constraint = encoder.create_exclusion_constraint(model) exclusion_constraint = encoder.create_exclusion_constraint(model)
encoder.solver.add(exclusion_constraint) encoder.solver.add(exclusion_constraint)
@ -346,13 +359,13 @@ class SMTModelEncoder:
is_designated = model_value in model.designated_values is_designated = model_value in model.designated_values
self.solver.add(self.is_designated(self.model_value_to_smt[model_value]) == is_designated) self.solver.add(self.is_designated(self.model_value_to_smt[model_value]) == is_designated)
def create_predicate(self, name: str, arity: int) -> z3.FuncDeclRef: def create_predicate(self, name: str, arity: int) -> "z3.FuncDeclRef":
return Function(name, *(self.carrier_sort for _ in range(arity)), BoolSort(ctx=self.ctx)) return Function(name, *(self.carrier_sort for _ in range(arity)), BoolSort(ctx=self.ctx))
def create_function(self, name: str, arity: int) -> z3.FuncDeclRef: def create_function(self, name: str, arity: int) -> "z3.FuncDeclRef":
return Function(name, *(self.carrier_sort for _ in range(arity + 1))) return Function(name, *(self.carrier_sort for _ in range(arity + 1)))
def add_function_constraints_from_table(self, smt_fn: z3.FuncDeclRef, model_fn: ModelFunction): def add_function_constraints_from_table(self, smt_fn: "z3.FuncDeclRef", model_fn: ModelFunction):
for inputs, output in model_fn.mapping.items(): for inputs, output in model_fn.mapping.items():
smt_inputs = tuple(self.model_value_to_smt[inp] for inp in inputs) smt_inputs = tuple(self.model_value_to_smt[inp] for inp in inputs)
smt_output = self.model_value_to_smt[output] smt_output = self.model_value_to_smt[output]

16
vsp.py
View file

@ -10,12 +10,12 @@ from model import (
Model, model_closure, ModelFunction, ModelValue Model, model_closure, ModelFunction, ModelValue
) )
SMT_LOADED = True from smt import SMTModelEncoder, SMTLogicEncoder, smt_is_loaded
try: try:
from z3 import And, Or, Implies, sat from z3 import And, Or, Implies, sat
from smt import SMTModelEncoder, SMTLogicEncoder
except ImportError: except ImportError:
SMT_LOADED = False pass
class VSP_Result: class VSP_Result:
def __init__( def __init__(
@ -139,7 +139,7 @@ def has_vsp_smt(model: Model, impfn: ModelFunction) -> VSP_Result:
Checks whether a given model satisfies the variable Checks whether a given model satisfies the variable
sharing property via SMT sharing property via SMT
""" """
if not SMT_LOADED: if not smt_is_loaded():
raise Exception("Z3 is not property installed, cannot check via SMT") raise Exception("Z3 is not property installed, cannot check via SMT")
encoder = SMTModelEncoder(model) encoder = SMTModelEncoder(model)
@ -198,7 +198,7 @@ def has_vsp(model: Model, impfunction: ModelFunction,
if model.is_magical: if model.is_magical:
return has_vsp_magical(model, impfunction, negation_defined, conjunction_disjunction_defined) return has_vsp_magical(model, impfunction, negation_defined, conjunction_disjunction_defined)
return has_vsp_smt(model) return has_vsp_smt(model, impfunction)
def logic_has_vsp(logic: Logic, size: int) -> Optional[Tuple[Model, VSP_Result]]: def logic_has_vsp(logic: Logic, size: int) -> Optional[Tuple[Model, VSP_Result]]:
@ -254,15 +254,15 @@ def logic_has_vsp(logic: Logic, size: int) -> Optional[Tuple[Model, VSP_Result]]
) )
) )
model = encoder.find_model() solution = encoder.find_model()
# We failed to find a VSP witness # We failed to find a VSP witness
if model is None: if solution is None:
return None return None
# Otherwise, a matrix model and correspoding # Otherwise, a matrix model and correspoding
# subalgebras exist. # subalgebras exist.
model, _ = solution
smt_model = encoder.solver.model() smt_model = encoder.solver.model()
K1_smt = [x for x in encoder.smt_carrier_set if smt_model.evaluate(IsInK1(x))] K1_smt = [x for x in encoder.smt_carrier_set if smt_model.evaluate(IsInK1(x))]
K1 = {ModelValue(str(x)) for x in K1_smt} K1 = {ModelValue(str(x)) for x in K1_smt}