Code cleanup

This commit is contained in:
Brandon Rozek 2025-05-03 16:42:15 -04:00
parent fa9e5026ca
commit 01204a9551
4 changed files with 286 additions and 353 deletions

171
model.py
View file

@ -16,9 +16,9 @@ from typing import Dict, List, Optional, Set, Tuple
__all__ = ['ModelValue', 'ModelFunction', 'Model', 'Interpretation']
class ModelValue:
def __init__(self, name):
def __init__(self, name: str, hashed_value: Optional[int] = None):
self.name = name
self.hashed_value = hash(self.name)
self.hashed_value = hashed_value if hashed_value is not None else hash(self.name)
self.__setattr__ = immutable
def __str__(self):
return self.name
@ -27,7 +27,7 @@ class ModelValue:
def __eq__(self, other):
return isinstance(other, ModelValue) and self.name == other.name
def __deepcopy__(self, _):
return ModelValue(self.name)
return ModelValue(self.name, self.hashed_value)
class ModelFunction:
def __init__(self, arity: int, mapping, operation_name = ""):
@ -109,57 +109,75 @@ Interpretation = Dict[Operation, ModelFunction]
class OrderTable:
def __init__(self):
# a : {x | x <= a }
self.ordering: Dict[ModelValue, Set[ModelValue]] = defaultdict(set)
self.le_map: Dict[ModelValue, Set[ModelValue]] = defaultdict(set)
# a : {x | x >= a}
self.ge_map: Dict[ModelValue, Set[ModelValue]] = defaultdict(set)
def add(self, x, y):
"""
Add x <= y
"""
self.ordering[y].add(x)
self.le_map[y].add(x)
self.ge_map[x].add(y)
def is_lt(self, x, y):
return y in self.ordering[x]
return x in self.le_map[y]
def meet(self, x, y) -> Optional[ModelValue]:
X = self.ordering[x]
Y = self.ordering[y]
X = self.le_map[x]
Y = self.le_map[y]
candidates = X.intersection(Y)
for m in candidates:
gt_all_candidates = True
for w in candidates:
if not self.is_lt(w, m):
gt_all_candidates = False
break
# Grab all elements greater than each of the candidates
candidate_ge_maps = (self.ge_map[candidate] for candidate in candidates)
common_ge_values = reduce(set.intersection, candidate_ge_maps)
if gt_all_candidates:
return m
# Intersect with candidates to get the values that satisfy
# the meet properties
result_set = candidates.intersection(common_ge_values)
# Otherwise the meet does not exist
print("Meet does not exist", (x, y), candidates)
return None
# NOTE: The meet may not exist, in which case return None
result = next(iter(result_set), None)
return result
def join(self, x, y) -> Optional[ModelValue]:
# Grab the collection of elements greater than x and y
candidates = set()
for w in self.ordering:
if self.is_lt(x, w) and self.is_lt(y, w):
candidates.add(w)
X = self.ge_map[x]
Y = self.ge_map[y]
for j in candidates:
lt_all_candidates = True
for w in candidates:
if not self.is_lt(j, w):
lt_all_candidates = False
break
candidates = X.intersection(Y)
if lt_all_candidates:
return j
# Grab all elements smaller than each of the candidates
candidate_le_maps = (self.le_map[candidate] for candidate in candidates)
common_le_values = reduce(set.intersection, candidate_le_maps)
# Otherwise the join does not exist
print("Join does not exist", (x, y), candidates)
return None
# Intersect with candidatse to get the values that satisfy
# the join properties
result_set = candidates.intersection(common_le_values)
# NOTE: The join may not exist, in which case return None
result = next(iter(result_set), None)
return result
def top(self) -> Optional[ModelValue]:
ge_maps = (self.ge_map[candidate] for candidate in self.ge_map)
result_set = reduce(set.intersection, ge_maps)
# Either not unique or does not exist
if len(result_set) != 1:
return None
return next(iter(result_set))
def bottom(self) -> Optional[ModelValue]:
le_maps = (self.le_map[candidate] for candidate in self.le_map)
result_set = reduce(set.intersection, le_maps)
# Either not unique or does not exist
if len(result_set) != 1:
return None
return next(iter(result_set))
class Model:
@ -276,86 +294,61 @@ def satisfiable(logic: Logic, model: Model, interpretation: Dict[Operation, Mode
return True
def model_closure(initial_set: Set[ModelValue], mfunctions: Set[ModelFunction], forbidden_element: Optional[ModelValue]) -> Set[ModelValue]:
"""
Given an initial set of model values and a set of model functions,
compute the complete set of model values that are closed
under the operations.
If top or bottom is encountered, then we end the saturation procedure early.
If the forbidden element is encountered, then we end the saturation procedure early.
"""
closure_set: Set[ModelValue] = initial_set
last_new: Set[ModelValue] = initial_set
changed: bool = True
forbidden_found = False
arities = set()
for mfun in mfunctions:
arities.add(mfun.arity)
while changed:
changed = False
new_elements: Set[ModelValue] = set()
old_closure: Set[ModelValue] = closure_set - last_new
# arity -> args
cached_args = defaultdict(list)
args_by_arity = defaultdict(list)
# Pass elements into each model function
for mfun in mfunctions:
# If a previous function shared the same arity,
# we'll use the same set of computed arguments
# to pass into the model functions.
if mfun.arity in cached_args:
for args in cached_args[mfun.arity]:
# Compute the new elements
# given the cached arguments.
element = mfun(*args)
if element not in closure_set:
new_elements.add(element)
# Optimization: Break out of computation
# early when forbidden element is found
if forbidden_element is not None and element == forbidden_element:
forbidden_found = True
break
if forbidden_found:
break
# We don't need to compute the arguments
# thanks to the cache, so move onto the
# next function.
continue
# At this point, we don't have cached arguments, so we need
# to compute this set.
# Each argument must have at least one new element to not repeat
# work. We'll range over the number of new model values within our
# argument.
for num_new in range(1, mfun.arity + 1):
# Motivation: We want to only compute arguments that we have not
# seen before
for arity in arities:
for num_new in range(1, arity + 1):
new_args = combinations_with_replacement(last_new, r=num_new)
old_args = combinations_with_replacement(old_closure, r=mfun.arity - num_new)
# Determine every possible ordering of the concatenated
# new and old model values.
old_args = combinations_with_replacement(old_closure, r=arity - num_new)
# Determine every possible ordering of the concatenated new and
# old model values.
for new_arg, old_arg in product(new_args, old_args):
for args in permutations(new_arg + old_arg):
cached_args[mfun.arity].append(args)
element = mfun(*args)
if element not in closure_set:
new_elements.add(element)
for combined_args in permutations(new_arg + old_arg):
args_by_arity[arity].append(combined_args)
# Optimization: Break out of computation
# early when forbidden element is found
if forbidden_element is not None and element == forbidden_element:
forbidden_found = True
break
if forbidden_found:
break
# Pass each argument into each model function
for mfun in mfunctions:
for args in args_by_arity[mfun.arity]:
# Compute the new elements
# given the cached arguments.
element = mfun(*args)
if element not in closure_set:
new_elements.add(element)
if forbidden_found:
# Optimization: Break out of computation
# early when forbidden element is found
if forbidden_element is not None and element == forbidden_element:
forbidden_found = True
break
if forbidden_found:
break
closure_set.update(new_elements)
changed = len(new_elements) > 0