diff --git a/logic.py b/logic.py index 44c5c6b..e2024da 100644 --- a/logic.py +++ b/logic.py @@ -81,11 +81,9 @@ class Rule: class Logic: def __init__(self, operations: Set[Operation], rules: Set[Rule], - falsifies: Optional[Set[Rule]] = None, name: Optional[str] = None): self.operations = operations self.rules = rules - self.falsifies = falsifies self.name = str(abs(hash(( frozenset(operations), frozenset(rules) diff --git a/model.py b/model.py index 9c93c1e..429a26e 100644 --- a/model.py +++ b/model.py @@ -266,9 +266,9 @@ def satisfiable(logic: Logic, model: Model, interpretation: Dict[Operation, Mode pvars = tuple(get_propostional_variables(tuple(logic.rules))) mappings = all_model_valuations_cached(pvars, tuple(model.carrier_set)) - for rule in logic.rules: - # The rule most hold for all valuations - for mapping in mappings: + for mapping in mappings: + # Make sure that the model satisfies each of the rules + for rule in logic.rules: # The check only applies if the premises are designated premise_met = True premise_ts: Set[ModelValue] = set() @@ -291,36 +291,6 @@ def satisfiable(logic: Logic, model: Model, interpretation: Dict[Operation, Mode if consequent_t not in model.designated_values: return False - for rule in logic.falsifies: - # We must find one mapping where this does not hold - counterexample_found = False - for mapping in mappings: - # The check only applies if the premises are designated - premise_met = True - premise_ts: Set[ModelValue] = set() - - for premise in rule.premises: - premise_t = evaluate_term(premise, mapping, interpretation) - # As soon as one premise is not designated, - # move to the next rule. - if premise_t not in model.designated_values: - premise_met = False - break - # If designated, keep track of the evaluated term - premise_ts.add(premise_t) - - if not premise_met: - continue - - # With the premises designated, make sure the consequent is not designated - consequent_t = evaluate_term(rule.conclusion, mapping, interpretation) - if consequent_t not in model.designated_values: - counterexample_found = True - break - - if not counterexample_found: - return False - return True diff --git a/smt.py b/smt.py index f6f2c33..5a238f7 100644 --- a/smt.py +++ b/smt.py @@ -4,52 +4,38 @@ from typing import Dict, Generator, Optional, Tuple from logic import Logic, Operation, Rule, PropositionalVariable, Term, OpTerm, get_prop_vars_from_rule from model import Model, ModelValue, ModelFunction -from z3 import ( - And, BoolSort, Context, EnumSort, Function, Implies, Or, sat, Solver, z3 -) +from z3 import EnumSort, Function, BoolSort, z3, And, Implies, Solver, sat, Context -def term_to_smt( - t: Term, - op_mapping: Dict[Operation, z3.FuncDeclRef], - var_mapping: Dict[PropositionalVariable, z3.DatatypeRef] -) -> z3.DatatypeRef: - """Convert a logic term to its SMT representation.""" +# TODO: Add an assumption that a partial order exists over the carrier set. +# This adds three restrictions to the logic +# 1) A -> A is always designated +# 2) If A -> B is designated and B -> C is designated then A -> C is designated +# 3) If A -> B is designated and B -> A is designated then A and B share the same truth value + +def term_to_smt(t: Term, op_mapping: Dict[Operation, z3.FuncDeclRef], var_mapping: Dict[PropositionalVariable, z3.DatatypeRef]) -> z3.DatatypeRef: if isinstance(t, PropositionalVariable): return var_mapping[t] assert isinstance(t, OpTerm) - arguments = [term_to_smt(arg, op_mapping, var_mapping) for arg in t.arguments] + arguments = [term_to_smt(a, op_mapping, var_mapping) for a in t.arguments] fn = op_mapping[t.operation] return fn(*arguments) -def logic_rule_to_smt_constraints( - rule: Rule, - IsDesignated: z3.FuncDeclRef, - smt_carrier_set, - op_mapping: Dict[Operation, z3.FuncDeclRef] -) -> Generator[z3.BoolRef, None, None]: - """ - Encode a logic rule as SMT constraints. - For all valuations: if premises are designated, then conclusion is designated. - """ +def logic_rule_to_smt_constraints(rule: Rule, IsDesignated: z3.FuncDeclRef, smt_carrier_set, op_mapping: Dict[Operation, z3.FuncDeclRef]) -> Generator[z3.BoolRef, None, None]: prop_vars = get_prop_vars_from_rule(rule) - # Requires that the rule holds under all valuations for smt_vars in product(smt_carrier_set, repeat=len(prop_vars)): assert len(prop_vars) == len(smt_vars) var_mapping = { - prop_var: smt_var + prop_var : smt_var for (prop_var, smt_var) in zip(prop_vars, smt_vars) } - premises = [ - IsDesignated(term_to_smt(premise, op_mapping, var_mapping)) == True - for premise in rule.premises - ] + premises = [IsDesignated(term_to_smt(premise, op_mapping, var_mapping)) == True for premise in rule.premises] conclusion = IsDesignated(term_to_smt(rule.conclusion, op_mapping, var_mapping)) == True if len(premises) == 0: @@ -61,265 +47,55 @@ def logic_rule_to_smt_constraints( yield Implies(premise, conclusion) -def logic_falsification_rule_to_smt_constraints( - rule: Rule, - IsDesignated: z3.FuncDeclRef, - smt_carrier_set, - op_mapping: Dict[Operation, z3.FuncDeclRef] -) -> z3.BoolRef: - """ - Encode a falsification rule as an SMT constraint. - There exists at least one valuation where premises are designated - but conclusion is not designated. - """ - prop_vars = get_prop_vars_from_rule(rule) +def find_model(l: Logic, size: int) -> Optional[Model]: + assert size > 0 - # Collect all possible counter-examples (valuations that falsify the rule) - counter_examples = [] + ctx = Context() + solver = Solver(ctx=ctx) - for smt_vars in product(smt_carrier_set, repeat=len(prop_vars)): - assert len(prop_vars) == len(smt_vars) + element_names = [f'{i}' for i in range(size)] + Carrier_sort, smt_carrier_set = EnumSort("C", element_names, ctx=ctx) - var_mapping = { - prop_var: smt_var - for (prop_var, smt_var) in zip(prop_vars, smt_vars) - } + operation_function_map: Dict[Operation, z3.FuncDeclRef] = {} - premises = [ - IsDesignated(term_to_smt(premise, op_mapping, var_mapping)) == True - for premise in rule.premises - ] - conclusion = IsDesignated(term_to_smt(rule.conclusion, op_mapping, var_mapping)) == False + for operation in l.operations: + operation_function_map[operation] = Function( + operation.symbol, + *(Carrier_sort for _ in range(operation.arity + 1)) + ) - if len(premises) == 0: - counter_examples.append(conclusion) - else: - premise = premises[0] - for p in premises[1:]: - premise = And(premise, p) + IsDesignated = Function("D", Carrier_sort, BoolSort(ctx=ctx)) - counter_examples.append(And(premise, conclusion)) + for rule in l.rules: + for constraint in logic_rule_to_smt_constraints(rule, IsDesignated, smt_carrier_set, operation_function_map): + solver.add(constraint) - # At least one counter-example must exist (disjunction of all possibilities) - return Or(counter_examples) + smt_result = solver.check() + if smt_result == sat: + smt_model = solver.model() -class SMTLogicEncoder: - """ - Encapsulates the SMT encoding of a logic system with a fixed carrier set size. + carrier_set = {ModelValue(f"{i}") for i in range(size)} - This class handles: - - Creating the SMT sorts and functions - - Encoding logic rules as SMT constraints - - Managing the solver state - - Converting between Model objects and SMT representations - """ - - def __init__(self, logic: Logic, size: int): - """ - Initialize the SMT encoding for a logic with given carrier set size. - - Args: - logic: The logic system to encode - size: The size of the carrier set - """ - assert size > 0 - - self.logic = logic - self.size = size - - # Create Z3 context and solver - self.ctx = Context() - self.solver = Solver(ctx=self.ctx) - - # Create carrier set - element_names = [f'{i}' for i in range(size)] - self.carrier_sort, self.smt_carrier_set = EnumSort("C", element_names, ctx=self.ctx) - - # Create operation functions - self.operation_function_map: Dict[Operation, z3.FuncDeclRef] = {} - for operation in logic.operations: - self.operation_function_map[operation] = Function( - operation.symbol, - *(self.carrier_sort for _ in range(operation.arity + 1)) - ) - - # Create designation function - self.is_designated = Function("D", self.carrier_sort, BoolSort(ctx=self.ctx)) - - # Add logic rules as constraints - self._add_logic_constraints() - self._add_designation_symmetry_constraints() - - def _add_logic_constraints(self): - """Add all logic rules and falsification rules as SMT constraints.""" - # Add regular rules - for rule in self.logic.rules: - for constraint in logic_rule_to_smt_constraints( - rule, - self.is_designated, - self.smt_carrier_set, - self.operation_function_map - ): - self.solver.add(constraint) - - # Add falsification rules - for falsification_rule in self.logic.falsifies: - constraint = logic_falsification_rule_to_smt_constraints( - falsification_rule, - self.is_designated, - self.smt_carrier_set, - self.operation_function_map - ) - self.solver.add(constraint) - - def extract_model(self, smt_model) -> Model: - """ - Extract a Model object from an SMT model. - - Args: - smt_model: The Z3 model to extract from - - Returns: - A Model object representing the logic model - """ - carrier_set = {ModelValue(f"{i}") for i in range(self.size)} - - # Extract designated values - smt_designated = [ - x for x in self.smt_carrier_set - if smt_model.evaluate(self.is_designated(x)) - ] + smt_designated = [x for x in smt_carrier_set if smt_model.evaluate(IsDesignated(x))] designated_values = {ModelValue(str(x)) for x in smt_designated} - # Extract operation functions model_functions = set() - for (operation, smt_function) in self.operation_function_map.items(): + for (operation, smt_function) in operation_function_map.items(): mapping: Dict[Tuple[ModelValue], ModelValue] = {} - for smt_inputs in product(self.smt_carrier_set, repeat=operation.arity): - model_inputs = tuple(ModelValue(str(i)) for i in smt_inputs) + for smt_inputs in product(smt_carrier_set, repeat=operation.arity): + model_inputs = tuple((ModelValue(str(i)) for i in smt_inputs)) smt_output = smt_model.evaluate(smt_function(*smt_inputs)) model_output = ModelValue(str(smt_output)) mapping[model_inputs] = model_output - model_functions.add(ModelFunction(operation.arity, mapping, operation.symbol)) + model_functions.add(ModelFunction(operation.arity, mapping, operation.symbol, )) + + solver.reset() + del ctx return Model(carrier_set, model_functions, designated_values) - def _add_designation_symmetry_constraints(self): - """ - Add symmetry breaking constraints to avoid isomorphic models. - - Strategy: Enforce a lexicographic ordering on designated values. - If element i is not designated, then no element j < i can be designated. - This ensures designated elements are "packed to the right". - """ - for i in range(1, len(self.smt_carrier_set)): - elem_i = self.smt_carrier_set[i] - elem_j = self.smt_carrier_set[i - 1] - - # If i is not designated, then j (which comes before i) cannot be designated - self.solver.add( - Implies( - self.is_designated(elem_i) == False, - self.is_designated(elem_j) == False - ) - ) - - def create_exclusion_constraint(self, model: Model) -> z3.BoolRef: - """ - Create a constraint that excludes the given model from future solutions. - - Args: - model: The model to exclude - - Returns: - An SMT constraint ensuring at least one aspect differs - """ - constraints = [] - - # Create mapping from ModelValue to SMT element - model_value_to_smt = { - ModelValue(str(smt_elem)): smt_elem - for smt_elem in self.smt_carrier_set - } - - # Exclude operation function mappings - for model_func in model.logical_operations: - operation = Operation(model_func.operation_name, model_func.arity) - smt_func = self.operation_function_map[operation] - - for inputs, output in model_func.mapping.items(): - smt_inputs = tuple(model_value_to_smt[inp] for inp in inputs) - smt_output = model_value_to_smt[output] - - # This input->output mapping should differ - constraints.append(smt_func(*smt_inputs) != smt_output) - - # Exclude designated value set - for smt_elem in self.smt_carrier_set: - model_val = ModelValue(str(smt_elem)) - is_designated_in_model = model_val in model.designated_values - - # Designation should differ - if is_designated_in_model: - constraints.append(self.is_designated(smt_elem) == False) - else: - constraints.append(self.is_designated(smt_elem) == True) - - return Or(constraints) - - def find_model(self) -> Optional[Model]: - """ - Find a single model satisfying the logic constraints. - - Returns: - A Model if one exists, None otherwise - """ - if self.solver.check() == sat: - return self.extract_model(self.solver.model()) + else: return None - - def reset(self): - """Reset the solver state.""" - self.solver.reset() - - def __del__(self): - """Cleanup resources.""" - try: - self.solver.reset() - del self.ctx - except: - pass - - -def find_model(logic: Logic, size: int) -> Optional[Model]: - """Find a single model for the given logic and size.""" - encoder = SMTLogicEncoder(logic, size) - return encoder.find_model() - -def find_all_models(logic: Logic, size: int) -> Generator[Model, None, None]: - """ - Find all models for the given logic and size. - - Args: - logic: The logic system to encode - size: The size of the carrier set - - Yields: - Model instances that satisfy the logic - """ - encoder = SMTLogicEncoder(logic, size) - - while True: - # Try to find a model - model = encoder.find_model() - if model is None: - break - - yield model - - # Add constraint to exclude this model from future solutions - exclusion_constraint = encoder.create_exclusion_constraint(model) - encoder.solver.add(exclusion_constraint) \ No newline at end of file