Refactored out valuation generation code in SMT

This commit is contained in:
Brandon Rozek 2026-01-27 15:22:47 -05:00
parent 85ef364a57
commit 06cca7d32e
2 changed files with 43 additions and 25 deletions

View file

@ -13,7 +13,7 @@ from itertools import (
chain, combinations_with_replacement, chain, combinations_with_replacement,
permutations, product permutations, product
) )
from typing import Dict, List, Optional, Set, Tuple from typing import Any, Dict, Generator, List, Optional, Set, Tuple
__all__ = ['ModelValue', 'ModelFunction', 'Model', 'Interpretation'] __all__ = ['ModelValue', 'ModelFunction', 'Model', 'Interpretation']
@ -255,7 +255,7 @@ def evaluate_term(
def all_model_valuations( def all_model_valuations(
pvars: Tuple[PropositionalVariable], pvars: Tuple[PropositionalVariable],
mvalues: Tuple[ModelValue]): mvalues: Tuple[ModelValue]) -> Generator[Dict[PropositionalVariable, ModelValue], Any, None]:
""" """
Given propositional variables and model values, Given propositional variables and model values,
produce every possible mapping between the two. produce every possible mapping between the two.

64
smt.py
View file

@ -1,5 +1,6 @@
from functools import lru_cache
from itertools import product from itertools import product
from typing import Dict, Generator, Optional, Set, Tuple, TYPE_CHECKING from typing import Dict, Generator, Optional, Set, Tuple
from logic import Logic, Operation, Rule, PropositionalVariable, Term, OpTerm, get_prop_vars_from_rule from logic import Logic, Operation, Rule, PropositionalVariable, Term, OpTerm, get_prop_vars_from_rule
from model import Model, ModelValue, ModelFunction from model import Model, ModelValue, ModelFunction
@ -27,11 +28,32 @@ def term_to_smt(
assert isinstance(t, OpTerm) assert isinstance(t, OpTerm)
# Recursively convert all arguments to SMT
arguments = [term_to_smt(arg, op_mapping, var_mapping) for arg in t.arguments] arguments = [term_to_smt(arg, op_mapping, var_mapping) for arg in t.arguments]
fn = op_mapping[t.operation] fn = op_mapping[t.operation]
return fn(*arguments) return fn(*arguments)
def all_smt_valuations(pvars: Tuple[PropositionalVariable], smtvalues):
"""
Generator which maps all the propositional variable to
smt variables representing the carrier set.
Exhaust the generator to get all such mappings.
"""
all_possible_values = product(smtvalues, repeat=len(pvars))
for valuation in all_possible_values:
mapping = dict()
assert len(pvars) == len(valuation)
for pvar, value in zip(pvars, valuation):
mapping[pvar] = value
yield mapping
@lru_cache
def all_smt_valuations_cached(pvars: Tuple[PropositionalVariable], smtvalues):
return list(all_smt_valuations(pvars, smtvalues))
def logic_rule_to_smt_constraints( def logic_rule_to_smt_constraints(
rule: Rule, rule: Rule,
IsDesignated: "z3.FuncDeclRef", IsDesignated: "z3.FuncDeclRef",
@ -43,32 +65,30 @@ def logic_rule_to_smt_constraints(
For all valuations: if premises are designated, then conclusion is designated. For all valuations: if premises are designated, then conclusion is designated.
""" """
prop_vars = get_prop_vars_from_rule(rule) prop_vars = tuple(get_prop_vars_from_rule(rule))
valuations = all_smt_valuations_cached(prop_vars, tuple(smt_carrier_set))
# Requires that the rule holds under all valuations
for smt_vars in product(smt_carrier_set, repeat=len(prop_vars)):
assert len(prop_vars) == len(smt_vars)
var_mapping = {
prop_var: smt_var
for (prop_var, smt_var) in zip(prop_vars, smt_vars)
}
for valuation in valuations:
premises = [ premises = [
IsDesignated(term_to_smt(premise, op_mapping, var_mapping)) == True IsDesignated(term_to_smt(premise, op_mapping, valuation)) == True
for premise in rule.premises for premise in rule.premises
] ]
conclusion = IsDesignated(term_to_smt(rule.conclusion, op_mapping, var_mapping)) == True conclusion = IsDesignated(term_to_smt(rule.conclusion, op_mapping, valuation)) == True
if len(premises) == 0: if len(premises) == 0:
# If there are no premises, then the conclusion must always be designated
yield conclusion yield conclusion
else: else:
# Otherwise, combine all the premises with and
# and have that if the premises are designated
# then the conclusion is designated
premise = premises[0] premise = premises[0]
for p in premises[1:]: for p in premises[1:]:
premise = And(premise, p) premise = And(premise, p)
yield Implies(premise, conclusion) yield Implies(premise, conclusion)
def logic_falsification_rule_to_smt_constraints( def logic_falsification_rule_to_smt_constraints(
rule: Rule, rule: Rule,
IsDesignated: "z3.FuncDeclRef", IsDesignated: "z3.FuncDeclRef",
@ -81,24 +101,22 @@ def logic_falsification_rule_to_smt_constraints(
There exists at least one valuation where premises are designated There exists at least one valuation where premises are designated
but conclusion is not designated. but conclusion is not designated.
""" """
prop_vars = get_prop_vars_from_rule(rule) prop_vars = tuple(get_prop_vars_from_rule(rule))
valuations = all_smt_valuations_cached(prop_vars, tuple(smt_carrier_set))
# Collect all possible counter-examples (valuations that falsify the rule) # Collect all possible counter-examples (valuations that falsify the rule)
counter_examples = [] counter_examples = []
for smt_vars in product(smt_carrier_set, repeat=len(prop_vars)): for valuation in valuations:
assert len(prop_vars) == len(smt_vars) # The rule is falsified when all of our premises
# are designated but our conclusion is not designated
var_mapping = {
prop_var: smt_var
for (prop_var, smt_var) in zip(prop_vars, smt_vars)
}
premises = [ premises = [
IsDesignated(term_to_smt(premise, op_mapping, var_mapping)) == True IsDesignated(term_to_smt(premise, op_mapping, valuation)) == True
for premise in rule.premises for premise in rule.premises
] ]
conclusion = IsDesignated(term_to_smt(rule.conclusion, op_mapping, var_mapping)) == False
conclusion = IsDesignated(term_to_smt(rule.conclusion, op_mapping, valuation)) == False
if len(premises) == 0: if len(premises) == 0:
counter_examples.append(conclusion) counter_examples.append(conclusion)