mirror of
https://github.com/Brandon-Rozek/matmod.git
synced 2026-01-30 07:33:38 +00:00
Updated driver file R.py to showcase SMT techinques
Fixed minor bugs concerning lack of falsification rules and interfaces between VSP and SMT
This commit is contained in:
parent
f8eca388d4
commit
6d87793803
5 changed files with 120 additions and 65 deletions
75
smt.py
75
smt.py
|
|
@ -1,18 +1,26 @@
|
|||
from itertools import product
|
||||
from typing import Dict, Generator, Optional, Tuple
|
||||
from typing import Dict, Generator, Optional, Set, Tuple, TYPE_CHECKING
|
||||
|
||||
from logic import Logic, Operation, Rule, PropositionalVariable, Term, OpTerm, get_prop_vars_from_rule
|
||||
from model import Model, ModelValue, ModelFunction
|
||||
|
||||
from z3 import (
|
||||
And, BoolSort, Context, EnumSort, Function, Implies, Or, sat, Solver, z3
|
||||
)
|
||||
SMT_LOADED = True
|
||||
try:
|
||||
from z3 import (
|
||||
And, BoolSort, Context, EnumSort, Function, Implies, Or, sat, Solver, z3
|
||||
)
|
||||
except ImportError:
|
||||
SMT_LOADED = False
|
||||
|
||||
def smt_is_loaded() -> bool:
|
||||
global SMT_LOADED
|
||||
return SMT_LOADED
|
||||
|
||||
def term_to_smt(
|
||||
t: Term,
|
||||
op_mapping: Dict[Operation, z3.FuncDeclRef],
|
||||
var_mapping: Dict[PropositionalVariable, z3.DatatypeRef]
|
||||
) -> z3.DatatypeRef:
|
||||
op_mapping: Dict[Operation, "z3.FuncDeclRef"],
|
||||
var_mapping: Dict[PropositionalVariable, "z3.DatatypeRef"]
|
||||
) -> "z3.DatatypeRef":
|
||||
"""Convert a logic term to its SMT representation."""
|
||||
if isinstance(t, PropositionalVariable):
|
||||
return var_mapping[t]
|
||||
|
|
@ -26,10 +34,10 @@ def term_to_smt(
|
|||
|
||||
def logic_rule_to_smt_constraints(
|
||||
rule: Rule,
|
||||
IsDesignated: z3.FuncDeclRef,
|
||||
IsDesignated: "z3.FuncDeclRef",
|
||||
smt_carrier_set,
|
||||
op_mapping: Dict[Operation, z3.FuncDeclRef]
|
||||
) -> Generator[z3.BoolRef, None, None]:
|
||||
op_mapping: Dict[Operation, "z3.FuncDeclRef"]
|
||||
) -> Generator["z3.BoolRef", None, None]:
|
||||
"""
|
||||
Encode a logic rule as SMT constraints.
|
||||
|
||||
|
|
@ -63,10 +71,10 @@ def logic_rule_to_smt_constraints(
|
|||
|
||||
def logic_falsification_rule_to_smt_constraints(
|
||||
rule: Rule,
|
||||
IsDesignated: z3.FuncDeclRef,
|
||||
IsDesignated: "z3.FuncDeclRef",
|
||||
smt_carrier_set,
|
||||
op_mapping: Dict[Operation, z3.FuncDeclRef]
|
||||
) -> z3.BoolRef:
|
||||
op_mapping: Dict[Operation, "z3.FuncDeclRef"]
|
||||
) -> "z3.BoolRef":
|
||||
"""
|
||||
Encode a falsification rule as an SMT constraint.
|
||||
|
||||
|
|
@ -132,7 +140,7 @@ class SMTLogicEncoder:
|
|||
self.carrier_sort, self.smt_carrier_set = EnumSort("C", element_names, ctx=self.ctx)
|
||||
|
||||
# Create operation functions
|
||||
self.operation_function_map: Dict[Operation, z3.FuncDeclRef] = {}
|
||||
self.operation_function_map: Dict[Operation, "z3.FuncDeclRef"] = {}
|
||||
for operation in logic.operations:
|
||||
self.operation_function_map[operation] = self.create_function(operation.symbol, operation.arity)
|
||||
|
||||
|
|
@ -143,10 +151,10 @@ class SMTLogicEncoder:
|
|||
self._add_logic_constraints()
|
||||
self._add_designation_symmetry_constraints()
|
||||
|
||||
def create_predicate(self, name: str, arity: int) -> z3.FuncDeclRef:
|
||||
def create_predicate(self, name: str, arity: int) -> "z3.FuncDeclRef":
|
||||
return Function(name, *(self.carrier_sort for _ in range(arity)), BoolSort(ctx=self.ctx))
|
||||
|
||||
def create_function(self, name: str, arity: int) -> z3.FuncDeclRef:
|
||||
def create_function(self, name: str, arity: int) -> "z3.FuncDeclRef":
|
||||
return Function(name, *(self.carrier_sort for _ in range(arity + 1)))
|
||||
|
||||
def _add_logic_constraints(self):
|
||||
|
|
@ -171,9 +179,9 @@ class SMTLogicEncoder:
|
|||
)
|
||||
self.solver.add(constraint)
|
||||
|
||||
def extract_model(self, smt_model) -> Model:
|
||||
def extract_model(self, smt_model) -> Tuple[Model, Dict[Operation, ModelFunction]]:
|
||||
"""
|
||||
Extract a Model object from an SMT model.
|
||||
Extract a Model object and interpretation from an SMT model.
|
||||
"""
|
||||
carrier_set = {ModelValue(f"{i}") for i in range(self.size)}
|
||||
|
||||
|
|
@ -185,7 +193,8 @@ class SMTLogicEncoder:
|
|||
designated_values = {ModelValue(str(x)) for x in smt_designated}
|
||||
|
||||
# Extract operation functions
|
||||
model_functions = set()
|
||||
model_functions: Set[ModelFunction] = set()
|
||||
interpretation: Dict[Operation, ModelFunction] = dict()
|
||||
for (operation, smt_function) in self.operation_function_map.items():
|
||||
mapping: Dict[Tuple[ModelValue], ModelValue] = {}
|
||||
for smt_inputs in product(self.smt_carrier_set, repeat=operation.arity):
|
||||
|
|
@ -193,9 +202,12 @@ class SMTLogicEncoder:
|
|||
smt_output = smt_model.evaluate(smt_function(*smt_inputs))
|
||||
model_output = ModelValue(str(smt_output))
|
||||
mapping[model_inputs] = model_output
|
||||
model_functions.add(ModelFunction(operation.arity, mapping, operation.symbol))
|
||||
model_function = ModelFunction(operation.arity, mapping, operation.symbol)
|
||||
model_functions.add(model_function)
|
||||
interpretation[operation] = model_function
|
||||
|
||||
return Model(carrier_set, model_functions, designated_values)
|
||||
|
||||
return Model(carrier_set, model_functions, designated_values), interpretation
|
||||
|
||||
|
||||
def _add_designation_symmetry_constraints(self):
|
||||
|
|
@ -218,7 +230,7 @@ class SMTLogicEncoder:
|
|||
)
|
||||
)
|
||||
|
||||
def create_exclusion_constraint(self, model: Model) -> z3.BoolRef:
|
||||
def create_exclusion_constraint(self, model: Model) -> "z3.BoolRef":
|
||||
"""
|
||||
Create a constraint that excludes the given model from future solutions.
|
||||
"""
|
||||
|
|
@ -254,7 +266,7 @@ class SMTLogicEncoder:
|
|||
|
||||
return Or(constraints)
|
||||
|
||||
def find_model(self) -> Optional[Model]:
|
||||
def find_model(self) -> Optional[Tuple[Model, Dict[Operation, ModelFunction]]]:
|
||||
"""
|
||||
Find a single model satisfying the logic constraints.
|
||||
|
||||
|
|
@ -274,12 +286,12 @@ class SMTLogicEncoder:
|
|||
pass
|
||||
|
||||
|
||||
def find_model(logic: Logic, size: int) -> Optional[Model]:
|
||||
def find_model(logic: Logic, size: int) -> Optional[Tuple[Model, Dict[Operation, ModelFunction]]]:
|
||||
"""Find a single model for the given logic and size."""
|
||||
encoder = SMTLogicEncoder(logic, size)
|
||||
return encoder.find_model()
|
||||
|
||||
def find_all_models(logic: Logic, size: int) -> Generator[Model, None, None]:
|
||||
def find_all_models(logic: Logic, size: int) -> Generator[Tuple[Model, Dict[Operation, ModelFunction]], None, None]:
|
||||
"""
|
||||
Find all models for the given logic and size.
|
||||
|
||||
|
|
@ -294,13 +306,14 @@ def find_all_models(logic: Logic, size: int) -> Generator[Model, None, None]:
|
|||
|
||||
while True:
|
||||
# Try to find a model
|
||||
model = encoder.find_model()
|
||||
if model is None:
|
||||
solution = encoder.find_model()
|
||||
if solution is None:
|
||||
break
|
||||
|
||||
yield model
|
||||
yield solution
|
||||
|
||||
# Add constraint to exclude this model from future solutions
|
||||
model, _ = solution
|
||||
exclusion_constraint = encoder.create_exclusion_constraint(model)
|
||||
encoder.solver.add(exclusion_constraint)
|
||||
|
||||
|
|
@ -346,13 +359,13 @@ class SMTModelEncoder:
|
|||
is_designated = model_value in model.designated_values
|
||||
self.solver.add(self.is_designated(self.model_value_to_smt[model_value]) == is_designated)
|
||||
|
||||
def create_predicate(self, name: str, arity: int) -> z3.FuncDeclRef:
|
||||
def create_predicate(self, name: str, arity: int) -> "z3.FuncDeclRef":
|
||||
return Function(name, *(self.carrier_sort for _ in range(arity)), BoolSort(ctx=self.ctx))
|
||||
|
||||
def create_function(self, name: str, arity: int) -> z3.FuncDeclRef:
|
||||
def create_function(self, name: str, arity: int) -> "z3.FuncDeclRef":
|
||||
return Function(name, *(self.carrier_sort for _ in range(arity + 1)))
|
||||
|
||||
def add_function_constraints_from_table(self, smt_fn: z3.FuncDeclRef, model_fn: ModelFunction):
|
||||
def add_function_constraints_from_table(self, smt_fn: "z3.FuncDeclRef", model_fn: ModelFunction):
|
||||
for inputs, output in model_fn.mapping.items():
|
||||
smt_inputs = tuple(self.model_value_to_smt[inp] for inp in inputs)
|
||||
smt_output = self.model_value_to_smt[output]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue