From 06cca7d32ecd3c69d36fbb3a50f2f4b7ebe4ddf3 Mon Sep 17 00:00:00 2001 From: Brandon Rozek Date: Tue, 27 Jan 2026 15:22:47 -0500 Subject: [PATCH] Refactored out valuation generation code in SMT --- model.py | 4 ++-- smt.py | 64 ++++++++++++++++++++++++++++++++++++-------------------- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/model.py b/model.py index 855b320..05a1a1d 100644 --- a/model.py +++ b/model.py @@ -13,7 +13,7 @@ from itertools import ( chain, combinations_with_replacement, permutations, product ) -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, Generator, List, Optional, Set, Tuple __all__ = ['ModelValue', 'ModelFunction', 'Model', 'Interpretation'] @@ -255,7 +255,7 @@ def evaluate_term( def all_model_valuations( pvars: Tuple[PropositionalVariable], - mvalues: Tuple[ModelValue]): + mvalues: Tuple[ModelValue]) -> Generator[Dict[PropositionalVariable, ModelValue], Any, None]: """ Given propositional variables and model values, produce every possible mapping between the two. diff --git a/smt.py b/smt.py index 4394713..9dbee99 100644 --- a/smt.py +++ b/smt.py @@ -1,5 +1,6 @@ +from functools import lru_cache from itertools import product -from typing import Dict, Generator, Optional, Set, Tuple, TYPE_CHECKING +from typing import Dict, Generator, Optional, Set, Tuple from logic import Logic, Operation, Rule, PropositionalVariable, Term, OpTerm, get_prop_vars_from_rule from model import Model, ModelValue, ModelFunction @@ -27,11 +28,32 @@ def term_to_smt( assert isinstance(t, OpTerm) + # Recursively convert all arguments to SMT arguments = [term_to_smt(arg, op_mapping, var_mapping) for arg in t.arguments] fn = op_mapping[t.operation] return fn(*arguments) +def all_smt_valuations(pvars: Tuple[PropositionalVariable], smtvalues): + """ + Generator which maps all the propositional variable to + smt variables representing the carrier set. + + Exhaust the generator to get all such mappings. + """ + all_possible_values = product(smtvalues, repeat=len(pvars)) + for valuation in all_possible_values: + mapping = dict() + assert len(pvars) == len(valuation) + for pvar, value in zip(pvars, valuation): + mapping[pvar] = value + yield mapping + + +@lru_cache +def all_smt_valuations_cached(pvars: Tuple[PropositionalVariable], smtvalues): + return list(all_smt_valuations(pvars, smtvalues)) + def logic_rule_to_smt_constraints( rule: Rule, IsDesignated: "z3.FuncDeclRef", @@ -43,32 +65,30 @@ def logic_rule_to_smt_constraints( For all valuations: if premises are designated, then conclusion is designated. """ - prop_vars = get_prop_vars_from_rule(rule) - - # Requires that the rule holds under all valuations - for smt_vars in product(smt_carrier_set, repeat=len(prop_vars)): - assert len(prop_vars) == len(smt_vars) - - var_mapping = { - prop_var: smt_var - for (prop_var, smt_var) in zip(prop_vars, smt_vars) - } + prop_vars = tuple(get_prop_vars_from_rule(rule)) + valuations = all_smt_valuations_cached(prop_vars, tuple(smt_carrier_set)) + for valuation in valuations: premises = [ - IsDesignated(term_to_smt(premise, op_mapping, var_mapping)) == True + IsDesignated(term_to_smt(premise, op_mapping, valuation)) == True for premise in rule.premises ] - conclusion = IsDesignated(term_to_smt(rule.conclusion, op_mapping, var_mapping)) == True + conclusion = IsDesignated(term_to_smt(rule.conclusion, op_mapping, valuation)) == True if len(premises) == 0: + # If there are no premises, then the conclusion must always be designated yield conclusion else: + # Otherwise, combine all the premises with and + # and have that if the premises are designated + # then the conclusion is designated premise = premises[0] for p in premises[1:]: premise = And(premise, p) yield Implies(premise, conclusion) + def logic_falsification_rule_to_smt_constraints( rule: Rule, IsDesignated: "z3.FuncDeclRef", @@ -81,24 +101,22 @@ def logic_falsification_rule_to_smt_constraints( There exists at least one valuation where premises are designated but conclusion is not designated. """ - prop_vars = get_prop_vars_from_rule(rule) + prop_vars = tuple(get_prop_vars_from_rule(rule)) + valuations = all_smt_valuations_cached(prop_vars, tuple(smt_carrier_set)) # Collect all possible counter-examples (valuations that falsify the rule) counter_examples = [] - for smt_vars in product(smt_carrier_set, repeat=len(prop_vars)): - assert len(prop_vars) == len(smt_vars) - - var_mapping = { - prop_var: smt_var - for (prop_var, smt_var) in zip(prop_vars, smt_vars) - } + for valuation in valuations: + # The rule is falsified when all of our premises + # are designated but our conclusion is not designated premises = [ - IsDesignated(term_to_smt(premise, op_mapping, var_mapping)) == True + IsDesignated(term_to_smt(premise, op_mapping, valuation)) == True for premise in rule.premises ] - conclusion = IsDesignated(term_to_smt(rule.conclusion, op_mapping, var_mapping)) == False + + conclusion = IsDesignated(term_to_smt(rule.conclusion, op_mapping, valuation)) == False if len(premises) == 0: counter_examples.append(conclusion)