matmod/model.py

316 lines
11 KiB
Python
Raw Normal View History

2024-04-08 23:59:21 -04:00
"""
Matrix model semantics and satisfiability of
a given logic.
"""
2024-04-15 00:08:00 -04:00
from common import set_to_str
2024-04-08 23:59:21 -04:00
from logic import (
2024-05-28 14:50:31 -04:00
get_propostional_variables, Logic,
Operation, PropositionalVariable, Term
2024-04-21 12:15:24 -04:00
)
2024-05-28 14:50:31 -04:00
from collections import defaultdict
2024-05-29 13:50:20 -04:00
from functools import cached_property, lru_cache, reduce
from itertools import chain, combinations_with_replacement, permutations, product
from typing import Dict, List, Optional, Set, Tuple
2024-05-03 13:06:52 -04:00
2024-04-08 23:59:21 -04:00
2024-05-28 16:05:06 -04:00
# Public names exported by a star-import of this module.
__all__ = [
    "ModelValue",
    "ModelFunction",
    "Model",
    "Interpretation",
]
2024-04-08 23:59:21 -04:00
class ModelValue:
    """
    An immutable, named element of a model's carrier set.

    Equality and hashing depend only on ``name``, so two ModelValue
    objects with the same name are interchangeable as set members and
    dictionary keys.
    """

    def __init__(self, name):
        # Write through object.__setattr__ because our own __setattr__
        # (below) rejects every assignment.
        object.__setattr__(self, "name", name)
        # Cache the hash; name can no longer change after this point.
        object.__setattr__(self, "hashed_value", hash(name))

    def __setattr__(self, name, value):
        # Python looks special methods up on the type, not the instance,
        # so the immutability guard must be defined on the class to work.
        raise AttributeError("Model values are immutable")

    def __delattr__(self, name):
        raise AttributeError("Model values are immutable")

    def __str__(self):
        return self.name

    def __hash__(self):
        return self.hashed_value

    def __eq__(self, other):
        return isinstance(other, ModelValue) and self.name == other.name

    def __deepcopy__(self, _):
        return ModelValue(self.name)
2024-04-08 23:59:21 -04:00
class ModelFunction:
    """
    A finite function on model values, stored as an explicit lookup table.

    ``mapping`` may be keyed by tuples, by lists, or (for unary functions
    only) by bare values. Keys are normalized so that ``self.mapping`` is
    always keyed by argument tuples of length ``arity``.
    """

    def __init__(self, arity: int, mapping, operation_name = ""):
        self.operation_name = operation_name
        self.arity = arity

        table = {}
        for key, result in mapping.items():
            if isinstance(key, tuple):
                assert len(key) == arity
                table[key] = result
                continue
            if isinstance(key, list):
                assert len(key) == arity
                table[tuple(key)] = result
                continue
            # Any other key is taken to be the sole argument of a unary function.
            assert arity == 1
            table[(key,)] = result

        self.mapping = table

    @cached_property
    def domain(self):
        # Every value that appears in some argument position of the table.
        return {value for arguments in self.mapping for value in arguments}

    def __str__(self):
        if self.arity == 1:
            return unary_function_str(self)
        if self.arity == 2:
            return binary_function_str(self)
        # Other arities fall back to a dictionary-style listing.
        rendered = {
            "(" + ", ".join(str(value) for value in arguments) + ")": str(result)
            for arguments, result in self.mapping.items()
        }
        return self.operation_name + " " + str(rendered)

    def __call__(self, *args):
        return self.mapping[args]
2024-04-21 12:15:24 -04:00
2024-05-29 13:50:20 -04:00
def unary_function_str(f: ModelFunction) -> str:
    """
    Render a unary model function as a small two-row table.

    The header row lists the function's domain, sorted by value name;
    the row beneath it lists the corresponding outputs. Every column is
    padded to one common width, so each output stays under its input even
    when value names differ in length. When all names have the same
    length, the layout is the same as an unpadded table.

    Asserts that ``f`` is a ModelFunction of arity 1.
    """
    assert isinstance(f, ModelFunction) and f.arity == 1
    sorted_domain = sorted(f.domain, key=lambda v: v.name)
    outputs = [str(f.mapping[(v,)]) for v in sorted_domain]

    # One width for every column: wide enough for each input label
    # and for each output value.
    cell_width = max(
        chain((len(str(v)) for v in sorted_domain), (len(o) for o in outputs)),
        default=0,
    )

    header_line = f" {f.operation_name} | " + \
        " ".join(str(v).ljust(cell_width) for v in sorted_domain)
    sep_line = "-" + ("-" * len(f.operation_name)) + "-+-" + \
        ("-" * (len(sorted_domain) * (cell_width + 1)))
    # Indent the data row so its "|" lines up with the header's.
    data_line = (" " * (len(f.operation_name) + 2)) + "| " + \
        " ".join(o.ljust(cell_width) for o in outputs)
    return "\n".join((header_line, sep_line, data_line)) + "\n"
def binary_function_str(f: ModelFunction) -> str:
    """
    Render a binary model function as an operation table.

    Rows are indexed by the first argument and columns by the second,
    both in domain order sorted by value name. The header cell shows the
    operation name.

    Row labels are padded to the same width as the header cell, so the
    "|" separator lines up on every line. Data cells are padded to one
    common width so columns stay aligned when value names differ in
    length.

    Asserts that ``f`` is a ModelFunction of arity 2.
    """
    assert isinstance(f, ModelFunction) and f.arity == 2
    sorted_domain = sorted(f.domain, key=lambda v: v.name)

    # Width of the label column: fits the operation name (header)
    # and every row label.
    label_width = max(chain((len(v.name) for v in sorted_domain), (len(f.operation_name),)))

    rows = [
        [str(f.mapping[(row_v, col_v)]) for col_v in sorted_domain]
        for row_v in sorted_domain
    ]
    # Width of every data column: fits each column header and every table entry.
    cell_width = max(
        chain((len(str(v)) for v in sorted_domain), (len(c) for row in rows for c in row)),
        default=0,
    )

    header_line = f" {f.operation_name.ljust(label_width)} | " + \
        " ".join(str(v).ljust(cell_width) for v in sorted_domain)
    sep_line = "-" + ("-" * label_width) + "-+-" + \
        ("-" * (len(sorted_domain) * (cell_width + 1)))
    data_lines = ""
    for row_v, row in zip(sorted_domain, rows):
        data_line = f" {row_v.name.ljust(label_width)} | " + \
            " ".join(c.ljust(cell_width) for c in row)
        data_lines += data_line + "\n"
    return "\n".join((header_line, sep_line, data_lines))
2024-05-28 16:05:06 -04:00
# Type alias: pairs each logical Operation with the ModelFunction
# that interprets it in a model.
Interpretation = Dict[Operation, ModelFunction]
2024-04-08 23:59:21 -04:00
class Model:
    """
    A matrix model: a carrier set of values, the model functions defined
    over it, and the subset of values that count as designated.
    """

    def __init__(
        self,
        carrier_set: Set[ModelValue],
        logical_operations: Set[ModelFunction],
        designated_values: Set[ModelValue],
        name: Optional[str] = None
    ):
        assert designated_values <= carrier_set
        self.carrier_set = carrier_set
        self.logical_operations = logical_operations
        self.designated_values = designated_values

        if name is None:
            # No explicit name: derive a label of at most five characters
            # from a hash of the model's contents.
            fingerprint = hash((
                frozenset(carrier_set),
                frozenset(logical_operations),
                frozenset(designated_values)
            ))
            name = str(abs(fingerprint))[:5]
        self.name = name

    def __str__(self):
        divider = "=" * 25
        summary = (
            f"\nModel Name: {self.name}"
            f"\nCarrier Set: {set_to_str(self.carrier_set)}"
            f"\nDesignated Values: {set_to_str(self.designated_values)}\n"
        )
        # One rendered table per model function, each followed by a newline.
        tables = "".join(f"{function}\n" for function in self.logical_operations)
        return divider + summary + tables + divider + "\n"
2024-04-08 23:59:21 -04:00
2024-05-28 14:50:31 -04:00
def evaluate_term(
        t: Term, f: Dict[PropositionalVariable, ModelValue],
        interpretation: Dict[Operation, ModelFunction]) -> ModelValue:
    """
    Evaluate a logic term to a model value.

    ``f`` assigns a model value to each propositional variable, and
    ``interpretation`` assigns a model function to each operation.
    A variable evaluates to its value under ``f``. A compound term
    recursively evaluates its arguments, then applies the interpreting
    model function to them.
    """
    if not isinstance(t, PropositionalVariable):
        # Look the operation up before descending into the arguments,
        # so a missing interpretation is reported first.
        apply_operation = interpretation[t.operation]
        return apply_operation(*[
            evaluate_term(argument, f, interpretation)
            for argument in t.arguments
        ])
    return f[t]
def all_model_valuations(
        pvars: Tuple[PropositionalVariable],
        mvalues: Tuple[ModelValue]):
    """
    Lazily yield every assignment of model values to propositional variables.

    Each assignment is a fresh dict from each variable in ``pvars`` to one
    value of ``mvalues``. Assignments are produced in the order given by
    ``itertools.product``.
    """
    for choice in product(mvalues, repeat=len(pvars)):
        # product() gives exactly len(pvars) values per choice,
        # so zip pairs every variable with one value.
        yield dict(zip(pvars, choice))
2024-04-15 00:08:00 -04:00
@lru_cache(maxsize=128)
def all_model_valuations_cached(
        pvars: Tuple[PropositionalVariable],
        mvalues: Tuple[ModelValue]):
    """
    Memoized, fully materialized form of ``all_model_valuations``.

    The same list object (and the same dicts inside it) is returned to
    every caller that passes equal arguments, so callers must not mutate
    the result.
    """
    return [*all_model_valuations(pvars, mvalues)]
2024-04-21 17:37:21 -04:00
2024-04-15 00:08:00 -04:00
def satisfiable(logic: Logic, model: Model, interpretation: Dict[Operation, ModelFunction]) -> bool:
    """
    Determine whether a model satisfies a logic given an interpretation.

    Every valuation of the propositional variables collected from the
    logic's rules is checked against every rule. For each pair, if all
    premises evaluate to designated values, the conclusion must also be
    designated.

    Returns False at the first counterexample, and True if none exists.
    """
    pvars = tuple(get_propostional_variables(tuple(logic.rules)))
    mappings = all_model_valuations_cached(pvars, tuple(model.carrier_set))
    designated = model.designated_values

    for mapping in mappings:
        for rule in logic.rules:
            # A rule only constrains valuations where all of its premises
            # are designated. all() stops at the first undesignated premise,
            # so later premises are not evaluated.
            premises_designated = all(
                evaluate_term(premise, mapping, interpretation) in designated
                for premise in rule.premises
            )
            if not premises_designated:
                continue

            # With the premises designated, the conclusion must be too.
            if evaluate_term(rule.conclusion, mapping, interpretation) not in designated:
                return False

    return True
2024-05-03 13:06:52 -04:00
2024-05-28 14:50:31 -04:00
2024-05-03 13:06:52 -04:00
def _new_argument_tuples(last_new, old_closure, arity):
    """
    Return every distinct argument tuple of length ``arity`` that contains
    at least one value from ``last_new``. Any remaining positions are
    filled from ``old_closure``.

    Tuples built only from ``old_closure`` values were already applied in
    an earlier round, so they are skipped to avoid repeating work.
    Arity 0 yields no tuples.
    """
    seen = set()
    argument_tuples = []
    for num_new in range(1, arity + 1):
        new_args = combinations_with_replacement(last_new, r=num_new)
        old_args = combinations_with_replacement(old_closure, r=arity - num_new)
        # Every ordering of the chosen new and old values is a candidate.
        # permutations() repeats orderings when a value occurs more than
        # once, so drop the repeats while keeping first-seen order.
        for new_arg, old_arg in product(new_args, old_args):
            for args in permutations(new_arg + old_arg):
                if args not in seen:
                    seen.add(args)
                    argument_tuples.append(args)
    return argument_tuples


def model_closure(initial_set: Set[ModelValue], mfunctions: Set[ModelFunction], top: Optional[ModelValue], bottom: Optional[ModelValue]) -> Set[ModelValue]:
    """
    Compute the closure of a set of model values under a set of functions.

    Starting from ``initial_set``, every function in ``mfunctions`` is applied
    repeatedly until no new values appear. Each round only uses argument
    tuples that contain at least one value first added in the previous round.

    If a function produces ``top`` or ``bottom`` (whichever is not None),
    saturation stops early. The values found so far in that round are
    still added to the result.

    ``initial_set`` is not modified; a new set is returned.
    """
    def is_top_or_bottom(element):
        return (top is not None and element == top) or \
            (bottom is not None and element == bottom)

    # Copy, so growing the closure never mutates the caller's set.
    closure_set: Set[ModelValue] = set(initial_set)
    # Values first added in the previous round (initially: all of them).
    last_new: Set[ModelValue] = set(initial_set)
    topbottom_found = False

    while last_new:
        new_elements: Set[ModelValue] = set()
        old_closure: Set[ModelValue] = closure_set - last_new
        # arity -> argument tuples, computed once per round and
        # shared by every function of that arity.
        args_by_arity: Dict[int, List[Tuple[ModelValue, ...]]] = {}

        for mfun in mfunctions:
            if mfun.arity not in args_by_arity:
                args_by_arity[mfun.arity] = _new_argument_tuples(last_new, old_closure, mfun.arity)

            for args in args_by_arity[mfun.arity]:
                element = mfun(*args)
                if element not in closure_set:
                    new_elements.add(element)
                # Optimization: stop as soon as top or bottom is produced.
                if is_top_or_bottom(element):
                    topbottom_found = True
                    break
            if topbottom_found:
                break

        closure_set.update(new_elements)
        if topbottom_found:
            break
        last_new = new_elements

    return closure_set