Compare commits

...

2 commits

3 changed files with 300 additions and 44 deletions

View file

@ -81,9 +81,11 @@ class Rule:
class Logic:
def __init__(self,
operations: Set[Operation], rules: Set[Rule],
falsifies: Optional[Set[Rule]] = None,
name: Optional[str] = None):
self.operations = operations
self.rules = rules
self.falsifies = falsifies
self.name = str(abs(hash((
frozenset(operations),
frozenset(rules)

View file

@ -266,9 +266,9 @@ def satisfiable(logic: Logic, model: Model, interpretation: Dict[Operation, Mode
pvars = tuple(get_propostional_variables(tuple(logic.rules)))
mappings = all_model_valuations_cached(pvars, tuple(model.carrier_set))
for mapping in mappings:
# Make sure that the model satisfies each of the rules
for rule in logic.rules:
for rule in logic.rules:
# The rule most hold for all valuations
for mapping in mappings:
# The check only applies if the premises are designated
premise_met = True
premise_ts: Set[ModelValue] = set()
@ -291,6 +291,36 @@ def satisfiable(logic: Logic, model: Model, interpretation: Dict[Operation, Mode
if consequent_t not in model.designated_values:
return False
for rule in logic.falsifies:
# We must find one mapping where this does not hold
counterexample_found = False
for mapping in mappings:
# The check only applies if the premises are designated
premise_met = True
premise_ts: Set[ModelValue] = set()
for premise in rule.premises:
premise_t = evaluate_term(premise, mapping, interpretation)
# As soon as one premise is not designated,
# move to the next rule.
if premise_t not in model.designated_values:
premise_met = False
break
# If designated, keep track of the evaluated term
premise_ts.add(premise_t)
if not premise_met:
continue
# With the premises designated, make sure the consequent is not designated
consequent_t = evaluate_term(rule.conclusion, mapping, interpretation)
if consequent_t not in model.designated_values:
counterexample_found = True
break
if not counterexample_found:
return False
return True

306
smt.py
View file

@ -4,38 +4,52 @@ from typing import Dict, Generator, Optional, Tuple
from logic import Logic, Operation, Rule, PropositionalVariable, Term, OpTerm, get_prop_vars_from_rule
from model import Model, ModelValue, ModelFunction
from z3 import EnumSort, Function, BoolSort, z3, And, Implies, Solver, sat, Context
from z3 import (
And, BoolSort, Context, EnumSort, Function, Implies, Or, sat, Solver, z3
)
# TODO: Add an assumption that a partial order exists over the carrier set.
# This adds three restrictions to the logic
# 1) A -> A is always designated
# 2) If A -> B is designated and B -> C is designated then A -> C is designated
# 3) If A -> B is designated and B -> A is designated then A and B share the same truth value
def term_to_smt(t: Term, op_mapping: Dict[Operation, z3.FuncDeclRef], var_mapping: Dict[PropositionalVariable, z3.DatatypeRef]) -> z3.DatatypeRef:
def term_to_smt(
t: Term,
op_mapping: Dict[Operation, z3.FuncDeclRef],
var_mapping: Dict[PropositionalVariable, z3.DatatypeRef]
) -> z3.DatatypeRef:
"""Convert a logic term to its SMT representation."""
if isinstance(t, PropositionalVariable):
return var_mapping[t]
assert isinstance(t, OpTerm)
arguments = [term_to_smt(a, op_mapping, var_mapping) for a in t.arguments]
arguments = [term_to_smt(arg, op_mapping, var_mapping) for arg in t.arguments]
fn = op_mapping[t.operation]
return fn(*arguments)
def logic_rule_to_smt_constraints(
rule: Rule,
IsDesignated: z3.FuncDeclRef,
smt_carrier_set,
op_mapping: Dict[Operation, z3.FuncDeclRef]
) -> Generator[z3.BoolRef, None, None]:
"""
Encode a logic rule as SMT constraints.
def logic_rule_to_smt_constraints(rule: Rule, IsDesignated: z3.FuncDeclRef, smt_carrier_set, op_mapping: Dict[Operation, z3.FuncDeclRef]) -> Generator[z3.BoolRef, None, None]:
For all valuations: if premises are designated, then conclusion is designated.
"""
prop_vars = get_prop_vars_from_rule(rule)
# Requires that the rule holds under all valuations
for smt_vars in product(smt_carrier_set, repeat=len(prop_vars)):
assert len(prop_vars) == len(smt_vars)
var_mapping = {
prop_var : smt_var
prop_var: smt_var
for (prop_var, smt_var) in zip(prop_vars, smt_vars)
}
premises = [IsDesignated(term_to_smt(premise, op_mapping, var_mapping)) == True for premise in rule.premises]
premises = [
IsDesignated(term_to_smt(premise, op_mapping, var_mapping)) == True
for premise in rule.premises
]
conclusion = IsDesignated(term_to_smt(rule.conclusion, op_mapping, var_mapping)) == True
if len(premises) == 0:
@ -47,55 +61,265 @@ def logic_rule_to_smt_constraints(rule: Rule, IsDesignated: z3.FuncDeclRef, smt_
yield Implies(premise, conclusion)
def logic_falsification_rule_to_smt_constraints(
rule: Rule,
IsDesignated: z3.FuncDeclRef,
smt_carrier_set,
op_mapping: Dict[Operation, z3.FuncDeclRef]
) -> z3.BoolRef:
"""
Encode a falsification rule as an SMT constraint.
def find_model(l: Logic, size: int) -> Optional[Model]:
assert size > 0
There exists at least one valuation where premises are designated
but conclusion is not designated.
"""
prop_vars = get_prop_vars_from_rule(rule)
ctx = Context()
solver = Solver(ctx=ctx)
# Collect all possible counter-examples (valuations that falsify the rule)
counter_examples = []
element_names = [f'{i}' for i in range(size)]
Carrier_sort, smt_carrier_set = EnumSort("C", element_names, ctx=ctx)
for smt_vars in product(smt_carrier_set, repeat=len(prop_vars)):
assert len(prop_vars) == len(smt_vars)
operation_function_map: Dict[Operation, z3.FuncDeclRef] = {}
var_mapping = {
prop_var: smt_var
for (prop_var, smt_var) in zip(prop_vars, smt_vars)
}
for operation in l.operations:
operation_function_map[operation] = Function(
operation.symbol,
*(Carrier_sort for _ in range(operation.arity + 1))
)
premises = [
IsDesignated(term_to_smt(premise, op_mapping, var_mapping)) == True
for premise in rule.premises
]
conclusion = IsDesignated(term_to_smt(rule.conclusion, op_mapping, var_mapping)) == False
IsDesignated = Function("D", Carrier_sort, BoolSort(ctx=ctx))
if len(premises) == 0:
counter_examples.append(conclusion)
else:
premise = premises[0]
for p in premises[1:]:
premise = And(premise, p)
for rule in l.rules:
for constraint in logic_rule_to_smt_constraints(rule, IsDesignated, smt_carrier_set, operation_function_map):
solver.add(constraint)
counter_examples.append(And(premise, conclusion))
smt_result = solver.check()
# At least one counter-example must exist (disjunction of all possibilities)
return Or(counter_examples)
if smt_result == sat:
smt_model = solver.model()
carrier_set = {ModelValue(f"{i}") for i in range(size)}
class SMTLogicEncoder:
"""
Encapsulates the SMT encoding of a logic system with a fixed carrier set size.
smt_designated = [x for x in smt_carrier_set if smt_model.evaluate(IsDesignated(x))]
This class handles:
- Creating the SMT sorts and functions
- Encoding logic rules as SMT constraints
- Managing the solver state
- Converting between Model objects and SMT representations
"""
def __init__(self, logic: Logic, size: int):
"""
Initialize the SMT encoding for a logic with given carrier set size.
Args:
logic: The logic system to encode
size: The size of the carrier set
"""
assert size > 0
self.logic = logic
self.size = size
# Create Z3 context and solver
self.ctx = Context()
self.solver = Solver(ctx=self.ctx)
# Create carrier set
element_names = [f'{i}' for i in range(size)]
self.carrier_sort, self.smt_carrier_set = EnumSort("C", element_names, ctx=self.ctx)
# Create operation functions
self.operation_function_map: Dict[Operation, z3.FuncDeclRef] = {}
for operation in logic.operations:
self.operation_function_map[operation] = Function(
operation.symbol,
*(self.carrier_sort for _ in range(operation.arity + 1))
)
# Create designation function
self.is_designated = Function("D", self.carrier_sort, BoolSort(ctx=self.ctx))
# Add logic rules as constraints
self._add_logic_constraints()
self._add_designation_symmetry_constraints()
def _add_logic_constraints(self):
"""Add all logic rules and falsification rules as SMT constraints."""
# Add regular rules
for rule in self.logic.rules:
for constraint in logic_rule_to_smt_constraints(
rule,
self.is_designated,
self.smt_carrier_set,
self.operation_function_map
):
self.solver.add(constraint)
# Add falsification rules
for falsification_rule in self.logic.falsifies:
constraint = logic_falsification_rule_to_smt_constraints(
falsification_rule,
self.is_designated,
self.smt_carrier_set,
self.operation_function_map
)
self.solver.add(constraint)
def extract_model(self, smt_model) -> Model:
"""
Extract a Model object from an SMT model.
Args:
smt_model: The Z3 model to extract from
Returns:
A Model object representing the logic model
"""
carrier_set = {ModelValue(f"{i}") for i in range(self.size)}
# Extract designated values
smt_designated = [
x for x in self.smt_carrier_set
if smt_model.evaluate(self.is_designated(x))
]
designated_values = {ModelValue(str(x)) for x in smt_designated}
# Extract operation functions
model_functions = set()
for (operation, smt_function) in operation_function_map.items():
for (operation, smt_function) in self.operation_function_map.items():
mapping: Dict[Tuple[ModelValue], ModelValue] = {}
for smt_inputs in product(smt_carrier_set, repeat=operation.arity):
model_inputs = tuple((ModelValue(str(i)) for i in smt_inputs))
for smt_inputs in product(self.smt_carrier_set, repeat=operation.arity):
model_inputs = tuple(ModelValue(str(i)) for i in smt_inputs)
smt_output = smt_model.evaluate(smt_function(*smt_inputs))
model_output = ModelValue(str(smt_output))
mapping[model_inputs] = model_output
model_functions.add(ModelFunction(operation.arity, mapping, operation.symbol, ))
solver.reset()
del ctx
model_functions.add(ModelFunction(operation.arity, mapping, operation.symbol))
return Model(carrier_set, model_functions, designated_values)
else:
def _add_designation_symmetry_constraints(self):
"""
Add symmetry breaking constraints to avoid isomorphic models.
Strategy: Enforce a lexicographic ordering on designated values.
If element i is not designated, then no element j < i can be designated.
This ensures designated elements are "packed to the right".
"""
for i in range(1, len(self.smt_carrier_set)):
elem_i = self.smt_carrier_set[i]
elem_j = self.smt_carrier_set[i - 1]
# If i is not designated, then j (which comes before i) cannot be designated
self.solver.add(
Implies(
self.is_designated(elem_i) == False,
self.is_designated(elem_j) == False
)
)
def create_exclusion_constraint(self, model: Model) -> z3.BoolRef:
"""
Create a constraint that excludes the given model from future solutions.
Args:
model: The model to exclude
Returns:
An SMT constraint ensuring at least one aspect differs
"""
constraints = []
# Create mapping from ModelValue to SMT element
model_value_to_smt = {
ModelValue(str(smt_elem)): smt_elem
for smt_elem in self.smt_carrier_set
}
# Exclude operation function mappings
for model_func in model.logical_operations:
operation = Operation(model_func.operation_name, model_func.arity)
smt_func = self.operation_function_map[operation]
for inputs, output in model_func.mapping.items():
smt_inputs = tuple(model_value_to_smt[inp] for inp in inputs)
smt_output = model_value_to_smt[output]
# This input->output mapping should differ
constraints.append(smt_func(*smt_inputs) != smt_output)
# Exclude designated value set
for smt_elem in self.smt_carrier_set:
model_val = ModelValue(str(smt_elem))
is_designated_in_model = model_val in model.designated_values
# Designation should differ
if is_designated_in_model:
constraints.append(self.is_designated(smt_elem) == False)
else:
constraints.append(self.is_designated(smt_elem) == True)
return Or(constraints)
def find_model(self) -> Optional[Model]:
"""
Find a single model satisfying the logic constraints.
Returns:
A Model if one exists, None otherwise
"""
if self.solver.check() == sat:
return self.extract_model(self.solver.model())
return None
def reset(self):
"""Reset the solver state."""
self.solver.reset()
def __del__(self):
"""Cleanup resources."""
try:
self.solver.reset()
del self.ctx
except:
pass
def find_model(logic: Logic, size: int) -> Optional[Model]:
"""Find a single model for the given logic and size."""
encoder = SMTLogicEncoder(logic, size)
return encoder.find_model()
def find_all_models(logic: Logic, size: int) -> Generator[Model, None, None]:
"""
Find all models for the given logic and size.
Args:
logic: The logic system to encode
size: The size of the carrier set
Yields:
Model instances that satisfy the logic
"""
encoder = SMTLogicEncoder(logic, size)
while True:
# Try to find a model
model = encoder.find_model()
if model is None:
break
yield model
# Add constraint to exclude this model from future solutions
exclusion_constraint = encoder.create_exclusion_constraint(model)
encoder.solver.add(exclusion_constraint)