vspursuer/smt.py

101 lines
3.7 KiB
Python

from itertools import product
from typing import Dict, Generator, Optional, Tuple
from logic import Logic, Operation, Rule, PropositionalVariable, Term, OpTerm, get_prop_vars_from_rule
from model import Model, ModelValue, ModelFunction
from z3 import EnumSort, Function, BoolSort, z3, And, Implies, Solver, sat, Context
# TODO: Add an assumption that a partial order exists over the carrier set.
# This adds three restrictions to the logic:
# 1) Reflexivity: A -> A is always designated.
# 2) Transitivity: if A -> B is designated and B -> C is designated, then A -> C is designated.
# 3) Antisymmetry: if A -> B is designated and B -> A is designated, then A and B share the same truth value.
def term_to_smt(t: Term, op_mapping: Dict[Operation, z3.FuncDeclRef], var_mapping: Dict[PropositionalVariable, z3.DatatypeRef]) -> z3.DatatypeRef:
    """Recursively translate the term ``t`` into its SMT counterpart.

    Args:
        t: Either a ``PropositionalVariable`` or an ``OpTerm``.
        op_mapping: SMT function to apply for each operation.
        var_mapping: SMT value substituted for each propositional variable.

    Returns:
        ``var_mapping[t]`` for a variable; otherwise the result of applying
        the mapped function of ``t.operation`` to the translated arguments.

    Raises:
        AssertionError: If ``t`` is neither a variable nor an ``OpTerm``.
        KeyError: If a variable or operation is missing from its mapping.
    """
    if not isinstance(t, PropositionalVariable):
        assert isinstance(t, OpTerm)
        # Translate the sub-terms first, then look up and apply the operation.
        smt_args = tuple(
            term_to_smt(sub_term, op_mapping, var_mapping)
            for sub_term in t.arguments
        )
        smt_operation = op_mapping[t.operation]
        return smt_operation(*smt_args)
    return var_mapping[t]
def logic_rule_to_smt_constraints(rule: Rule, IsDesignated: z3.FuncDeclRef, smt_carrier_set, op_mapping: Dict[Operation, z3.FuncDeclRef]) -> Generator[z3.BoolRef, None, None]:
    """Yield one SMT constraint for every grounding of ``rule``.

    Each propositional variable of the rule is assigned every possible
    element of ``smt_carrier_set`` in turn (the full Cartesian product).
    For each assignment the yielded constraint is:

    * the designation of the conclusion, when the rule has no premises;
    * otherwise, "all premises designated implies conclusion designated".

    Args:
        rule: Rule providing ``premises`` and ``conclusion`` terms.
        IsDesignated: SMT predicate applied to each translated term.
        smt_carrier_set: Iterable of SMT values to substitute for variables.
        op_mapping: SMT function to apply for each operation.
    """
    variables = get_prop_vars_from_rule(rule)
    for assignment in product(smt_carrier_set, repeat=len(variables)):
        assert len(variables) == len(assignment)
        var_mapping = dict(zip(variables, assignment))

        # `== True` is applied to the result of an SMT function call and is
        # kept exactly as written so the emitted constraints are unchanged.
        designated_premises = [
            IsDesignated(term_to_smt(premise_term, op_mapping, var_mapping)) == True
            for premise_term in rule.premises
        ]
        designated_conclusion = (
            IsDesignated(term_to_smt(rule.conclusion, op_mapping, var_mapping)) == True
        )

        if not designated_premises:
            yield designated_conclusion
            continue

        # Fold the premises left-to-right into nested binary conjunctions.
        antecedent, *remaining = designated_premises
        for extra_premise in remaining:
            antecedent = And(antecedent, extra_premise)
        yield Implies(antecedent, designated_conclusion)
def find_model(l: Logic, size: int) -> Optional[Model]:
    """Search for a model of the logic ``l`` whose carrier set has ``size`` elements.

    The problem is encoded for z3 as follows: the carrier set is an
    enumeration sort ``C`` with elements named ``"0"`` .. ``str(size - 1)``,
    every operation of ``l`` is a function over ``C`` named after the
    operation's symbol, and designation is a predicate ``D`` from ``C`` to
    Bool. Every rule of ``l`` is grounded over the carrier set and asserted.

    Args:
        l: Logic supplying ``operations`` and ``rules``.
        size: Number of carrier-set elements; must be positive.

    Returns:
        A ``Model`` built from the solver's satisfying assignment, or
        ``None`` when the solver does not report ``sat``.

    Raises:
        AssertionError: If ``size`` is not positive.
    """
    assert size > 0

    # All declarations live in a dedicated context instead of z3's global one.
    ctx = Context()
    solver = Solver(ctx=ctx)
    try:
        element_names = [str(i) for i in range(size)]
        Carrier_sort, smt_carrier_set = EnumSort("C", element_names, ctx=ctx)

        operation_function_map: Dict[Operation, z3.FuncDeclRef] = {}
        for operation in l.operations:
            # `arity` argument sorts followed by one result sort.
            signature = [Carrier_sort] * (operation.arity + 1)
            operation_function_map[operation] = Function(operation.symbol, *signature)

        IsDesignated = Function("D", Carrier_sort, BoolSort(ctx=ctx))

        for rule in l.rules:
            for constraint in logic_rule_to_smt_constraints(
                rule, IsDesignated, smt_carrier_set, operation_function_map
            ):
                solver.add(constraint)

        if solver.check() != sat:
            return None

        smt_model = solver.model()
        carrier_set = {ModelValue(name) for name in element_names}

        # model_completion=True makes z3 choose a concrete value even for
        # functions the constraints never mention; without it the evaluation
        # can return the unevaluated application instead of a carrier element.
        designated_values = {
            ModelValue(str(element))
            for element in smt_carrier_set
            if z3.is_true(
                smt_model.evaluate(IsDesignated(element), model_completion=True)
            )
        }

        model_functions = set()
        for operation, smt_function in operation_function_map.items():
            mapping: Dict[Tuple[ModelValue, ...], ModelValue] = {}
            # Tabulate the function over every input tuple of the carrier set.
            for smt_inputs in product(smt_carrier_set, repeat=operation.arity):
                smt_output = smt_model.evaluate(
                    smt_function(*smt_inputs), model_completion=True
                )
                model_inputs = tuple(ModelValue(str(i)) for i in smt_inputs)
                mapping[model_inputs] = ModelValue(str(smt_output))
            model_functions.add(
                ModelFunction(operation.arity, mapping, operation.symbol)
            )

        # The result holds only values built from strings, so it stays valid
        # after the solver is reset below.
        return Model(carrier_set, model_functions, designated_values)
    finally:
        # Drop the solver's assertions on every exit path, not only on success.
        solver.reset()