diff --git a/generate_model.py b/generate_model.py index 1306adc..5a46efc 100644 --- a/generate_model.py +++ b/generate_model.py @@ -1,6 +1,9 @@ """ Generate all the models for a given logic with a specified number of elements. + +NOTE: This uses a naive brute-force method which +is extremely slow. """ from common import set_to_str from logic import Logic, Operation, Rule, get_operations_from_term diff --git a/logic.py b/logic.py index 7775590..e2024da 100644 --- a/logic.py +++ b/logic.py @@ -100,17 +100,22 @@ def get_prop_var_from_term(t: Term) -> Set[PropositionalVariable]: return result +def get_prop_vars_from_rule(r: Rule) -> Set[PropositionalVariable]: + vars: Set[PropositionalVariable] = set() + + for premise in r.premises: + vars |= get_prop_var_from_term(premise) + + vars |= get_prop_var_from_term(r.conclusion) + + return vars + @lru_cache def get_propostional_variables(rules: Tuple[Rule]) -> Set[PropositionalVariable]: vars: Set[PropositionalVariable] = set() for rule in rules: - # Get all vars in premises - for premise in rule.premises: - vars |= get_prop_var_from_term(premise) - - # Get vars in conclusion - vars |= get_prop_var_from_term(rule.conclusion) + vars |= get_prop_vars_from_rule(rule) return vars diff --git a/smt.py b/smt.py new file mode 100644 index 0000000..5a238f7 --- /dev/null +++ b/smt.py @@ -0,0 +1,101 @@ +from itertools import product +from typing import Dict, Generator, Optional, Tuple + +from logic import Logic, Operation, Rule, PropositionalVariable, Term, OpTerm, get_prop_vars_from_rule +from model import Model, ModelValue, ModelFunction + +from z3 import EnumSort, Function, BoolSort, z3, And, Implies, Solver, sat, Context + +# TODO: Add an assumption that a partial order exists over the carrier set. +# This adds three restrictions to the logic +# 1) A -> A is always designated +# 2) If A -> B is designated and B -> C is designated then A -> C is designated +# 3) If A -> B is designated and B -> A is designated then A and B share the same truth value + +def term_to_smt(t: Term, op_mapping: Dict[Operation, z3.FuncDeclRef], var_mapping: Dict[PropositionalVariable, z3.DatatypeRef]) -> z3.DatatypeRef: + if isinstance(t, PropositionalVariable): + return var_mapping[t] + + assert isinstance(t, OpTerm) + + arguments = [term_to_smt(a, op_mapping, var_mapping) for a in t.arguments] + fn = op_mapping[t.operation] + + return fn(*arguments) + + +def logic_rule_to_smt_constraints(rule: Rule, IsDesignated: z3.FuncDeclRef, smt_carrier_set, op_mapping: Dict[Operation, z3.FuncDeclRef]) -> Generator[z3.BoolRef, None, None]: + prop_vars = get_prop_vars_from_rule(rule) + + for smt_vars in product(smt_carrier_set, repeat=len(prop_vars)): + assert len(prop_vars) == len(smt_vars) + + var_mapping = { + prop_var : smt_var + for (prop_var, smt_var) in zip(prop_vars, smt_vars) + } + + premises = [IsDesignated(term_to_smt(premise, op_mapping, var_mapping)) == True for premise in rule.premises] + conclusion = IsDesignated(term_to_smt(rule.conclusion, op_mapping, var_mapping)) == True + + if len(premises) == 0: + yield conclusion + else: + premise = premises[0] + for p in premises[1:]: + premise = And(premise, p) + + yield Implies(premise, conclusion) + + +def find_model(l: Logic, size: int) -> Optional[Model]: + assert size > 0 + + ctx = Context() + solver = Solver(ctx=ctx) + + element_names = [f'{i}' for i in range(size)] + Carrier_sort, smt_carrier_set = EnumSort("C", element_names, ctx=ctx) + + operation_function_map: Dict[Operation, z3.FuncDeclRef] = {} + + for operation in l.operations: + operation_function_map[operation] = Function( + operation.symbol, + *(Carrier_sort for _ in range(operation.arity + 1)) + ) + + IsDesignated = Function("D", Carrier_sort, BoolSort(ctx=ctx)) + + for rule in l.rules: + for constraint in logic_rule_to_smt_constraints(rule, IsDesignated, smt_carrier_set, operation_function_map): + solver.add(constraint) + + smt_result = solver.check() + + if smt_result == sat: + smt_model = solver.model() + + carrier_set = {ModelValue(f"{i}") for i in range(size)} + + smt_designated = [x for x in smt_carrier_set if smt_model.evaluate(IsDesignated(x))] + designated_values = {ModelValue(str(x)) for x in smt_designated} + + model_functions = set() + for (operation, smt_function) in operation_function_map.items(): + mapping: Dict[Tuple[ModelValue], ModelValue] = {} + for smt_inputs in product(smt_carrier_set, repeat=operation.arity): + model_inputs = tuple((ModelValue(str(i)) for i in smt_inputs)) + smt_output = smt_model.evaluate(smt_function(*smt_inputs)) + model_output = ModelValue(str(smt_output)) + mapping[model_inputs] = model_output + model_functions.add(ModelFunction(operation.arity, mapping, operation.symbol, )) + + solver.reset() + del ctx + + return Model(carrier_set, model_functions, designated_values) + + + else: + return None