Code cleanup and documentation

This commit is contained in:
Brandon Rozek 2024-05-28 14:50:31 -04:00
parent 81a2d17965
commit 6b4d5828c8
No known key found for this signature in database
GPG key ID: 26E457DA82C9F480
3 changed files with 118 additions and 108 deletions

149
model.py
View file

@ -1,21 +1,20 @@
"""
Defining what it means to be a model
Matrix model semantics and satisfiability of
a given logic.
"""
from common import set_to_str
from logic import (
PropositionalVariable, get_propostional_variables, Logic, Term,
Operation, Conjunction, Disjunction, Implication
get_propostional_variables, Logic,
Operation, PropositionalVariable, Term
)
from typing import Set, Dict, Tuple, Optional
from collections import defaultdict
from functools import lru_cache
from itertools import combinations, chain, product, permutations
from copy import deepcopy
from itertools import combinations_with_replacement, permutations, product
from typing import Dict, List, Set, Tuple
__all__ = ['ModelValue', 'ModelFunction', 'Model']
class ModelValue:
def __init__(self, name):
self.name = name
@ -29,10 +28,7 @@ class ModelValue:
return self.hashed_value
def __eq__(self, other):
return isinstance(other, ModelValue) and self.name == other.name
def __lt__(self, other):
assert isinstance(other, ModelValue)
return ModelOrderConstraint(self, other)
def __deepcopy__(self, memo):
def __deepcopy__(self, _):
return ModelValue(self.name)
@ -41,8 +37,9 @@ class ModelFunction:
self.operation_name = operation_name
self.arity = arity
# Correct input to always be a tuple
corrected_mapping = dict()
# Transform the mapping such that the
# key is always a tuple of model values
corrected_mapping: Dict[Tuple[ModelValue], ModelValue] = {}
for k, v in mapping.items():
if isinstance(k, tuple):
assert len(k) == arity
@ -66,35 +63,17 @@ class ModelFunction:
def __call__(self, *args):
return self.mapping[args]
# def __eq__(self, other):
# return isinstance(other, ModelFunction) and self.name == other.name and self.arity == other.arity
class ModelOrderConstraint:
# a < b
def __init__(self, a: ModelValue, b: ModelValue):
self.a = a
self.b = b
def __hash__(self):
return hash(self.a) * hash(self.b)
def __eq__(self, other):
return isinstance(other, ModelOrderConstraint) and \
self.a == other.a and self.b == other.b
class Model:
def __init__(
self,
carrier_set: Set[ModelValue],
logical_operations: Set[ModelFunction],
designated_values: Set[ModelValue],
ordering: Optional[Set[ModelOrderConstraint]] = None
):
assert designated_values <= carrier_set
self.carrier_set = carrier_set
self.logical_operations = logical_operations
self.designated_values = designated_values
self.ordering = ordering if ordering is not None else set()
# TODO: Make sure ordering is "valid"
# That is: transitive, etc.
def __str__(self):
result = f"""Carrier Set: {set_to_str(self.carrier_set)}
@ -106,12 +85,22 @@ Designated Values: {set_to_str(self.designated_values)}
return result
def evaluate_term(t: Term, f: Dict[PropositionalVariable, ModelValue], interpretation: Dict[Operation, ModelFunction]) -> ModelValue:
def evaluate_term(
t: Term, f: Dict[PropositionalVariable, ModelValue],
interpretation: Dict[Operation, ModelFunction]) -> ModelValue:
"""
Given a term in a logic, mapping
between terms and model values,
as well as an interpretation
of operations to model functions,
return the evaluated model value.
"""
if isinstance(t, PropositionalVariable):
return f[t]
model_function = interpretation[t.operation]
model_arguments = []
model_arguments: List[ModelValue] = []
for logic_arg in t.arguments:
model_arg = evaluate_term(logic_arg, f, interpretation)
model_arguments.append(model_arg)
@ -121,11 +110,15 @@ def evaluate_term(t: Term, f: Dict[PropositionalVariable, ModelValue], interpret
def all_model_valuations(
pvars: Tuple[PropositionalVariable],
mvalues: Tuple[ModelValue]):
"""
Given propositional variables and model values,
produce every possible mapping between the two.
"""
all_possible_values = product(mvalues, repeat=len(pvars))
for valuation in all_possible_values:
mapping: Dict[PropositionalVariable, ModelValue] = dict()
mapping: Dict[PropositionalVariable, ModelValue] = {}
assert len(pvars) == len(valuation)
for pvar, value in zip(pvars, valuation):
mapping[pvar] = value
@ -137,98 +130,92 @@ def all_model_valuations_cached(
mvalues: Tuple[ModelValue]):
return list(all_model_valuations(pvars, mvalues))
def rule_ordering_satisfied(model: Model, interpretation: Dict[Operation, ModelFunction]) -> bool:
"""
Currently testing whether this function helps with runtime...
"""
if Conjunction in interpretation:
possible_inputs = ((a, b) for (a, b) in product(model.carrier_set, model.carrier_set))
for a, b in possible_inputs:
output = interpretation[Conjunction](a, b)
if a < b in model.ordering and output != a:
print("RETURNING FALSE")
return False
if b < a in model.ordering and output != b:
print("RETURNING FALSE")
return False
if Disjunction in interpretation:
possible_inputs = ((a, b) for (a, b) in product(model.carrier_set, model.carrier_set))
for a, b in possible_inputs:
output = interpretation[Disjunction](a, b)
if a < b in model.ordering and output != b:
print("RETURNING FALSE")
return False
if b < a in model.ordering and output != a:
print("RETURNING FALSE")
return False
return True
def satisfiable(logic: Logic, model: Model, interpretation: Dict[Operation, ModelFunction]) -> bool:
"""
Determine whether a model satisfies a logic
given an interpretation.
"""
pvars = tuple(get_propostional_variables(tuple(logic.rules)))
mappings = all_model_valuations_cached(pvars, tuple(model.carrier_set))
# NOTE: Does not look like rule ordering is helping for finding
# models of R...
if not rule_ordering_satisfied(model, interpretation):
return False
for mapping in mappings:
# Make sure that the model satisfies each of the rules
for rule in logic.rules:
# The check only applies if the premises are designated
premise_met = True
premise_ts = set()
premise_ts: Set[ModelValue] = set()
for premise in rule.premises:
premise_t = evaluate_term(premise, mapping, interpretation)
# As soon as one premise is not designated,
# move to the next rule.
if premise_t not in model.designated_values:
premise_met = False
break
# If designated, keep track of the evaluated term
premise_ts.add(premise_t)
if not premise_met:
continue
# With the premises designated, make sure the consequent is designated
consequent_t = evaluate_term(rule.conclusion, mapping, interpretation)
if consequent_t not in model.designated_values:
return False
return True
from itertools import combinations_with_replacement
from collections import defaultdict
def model_closure(initial_set: Set[ModelValue], mfunctions: Set[ModelFunction]):
"""
Given an initial set of model values and a set of model functions,
compute the complete set of model values that are closed
under the operations.
"""
closure_set: Set[ModelValue] = initial_set
last_new = initial_set
changed = True
last_new: Set[ModelValue] = initial_set
changed: bool = True
while changed:
changed = False
new_elements = set()
old_closure = closure_set - last_new
new_elements: Set[ModelValue] = set()
old_closure: Set[ModelValue] = closure_set - last_new
# arity -> args
cached_args = defaultdict(list)
# Pass elements into each model function
for mfun in mfunctions:
# Use cached args if this arity was looked at before
# If a previous function shared the same arity,
# we'll use the same set of computed arguments
# to pass into the model functions.
if mfun.arity in cached_args:
for args in cached_args[mfun.arity]:
# Compute the new elements
# given the cached arguments.
element = mfun(*args)
if element not in closure_set:
new_elements.add(element)
# Move onto next function
# We don't need to compute the arguments
# thanks to the cache, so move onto the
# next function.
continue
# Iterate over how many new elements would be within the arguments
# NOTE: To not repeat work, there must be at least one new element
# At this point, we don't have cached arguments, so we need
# to compute this set.
# Each argument must have at least one new element to not repeat
# work. We'll range over the number of new model values within our
# argument.
for num_new in range(1, mfun.arity + 1):
new_args = combinations_with_replacement(last_new, r=num_new)
old_args = combinations_with_replacement(old_closure, r=mfun.arity - num_new)
# Determine every possible ordering of the concatenated
# new and old model values.
for new_arg, old_arg in product(new_args, old_args):
for args in permutations(new_arg + old_arg):
cached_args[mfun.arity].append(args)