mirror of
https://github.com/Brandon-Rozek/matmod.git
synced 2025-12-19 05:10:25 +00:00
Code cleanup
This commit is contained in:
parent
fa9e5026ca
commit
01204a9551
4 changed files with 286 additions and 353 deletions
171
model.py
171
model.py
|
|
@ -16,9 +16,9 @@ from typing import Dict, List, Optional, Set, Tuple
|
|||
__all__ = ['ModelValue', 'ModelFunction', 'Model', 'Interpretation']
|
||||
|
||||
class ModelValue:
|
||||
def __init__(self, name):
|
||||
def __init__(self, name: str, hashed_value: Optional[int] = None):
|
||||
self.name = name
|
||||
self.hashed_value = hash(self.name)
|
||||
self.hashed_value = hashed_value if hashed_value is not None else hash(self.name)
|
||||
self.__setattr__ = immutable
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
|
@ -27,7 +27,7 @@ class ModelValue:
|
|||
def __eq__(self, other):
|
||||
return isinstance(other, ModelValue) and self.name == other.name
|
||||
def __deepcopy__(self, _):
|
||||
return ModelValue(self.name)
|
||||
return ModelValue(self.name, self.hashed_value)
|
||||
|
||||
class ModelFunction:
|
||||
def __init__(self, arity: int, mapping, operation_name = ""):
|
||||
|
|
@ -109,57 +109,75 @@ Interpretation = Dict[Operation, ModelFunction]
|
|||
class OrderTable:
|
||||
def __init__(self):
|
||||
# a : {x | x <= a }
|
||||
self.ordering: Dict[ModelValue, Set[ModelValue]] = defaultdict(set)
|
||||
self.le_map: Dict[ModelValue, Set[ModelValue]] = defaultdict(set)
|
||||
# a : {x | x >= a}
|
||||
self.ge_map: Dict[ModelValue, Set[ModelValue]] = defaultdict(set)
|
||||
|
||||
def add(self, x, y):
|
||||
"""
|
||||
Add x <= y
|
||||
"""
|
||||
self.ordering[y].add(x)
|
||||
self.le_map[y].add(x)
|
||||
self.ge_map[x].add(y)
|
||||
|
||||
def is_lt(self, x, y):
|
||||
return y in self.ordering[x]
|
||||
return x in self.le_map[y]
|
||||
|
||||
def meet(self, x, y) -> Optional[ModelValue]:
|
||||
X = self.ordering[x]
|
||||
Y = self.ordering[y]
|
||||
X = self.le_map[x]
|
||||
Y = self.le_map[y]
|
||||
|
||||
candidates = X.intersection(Y)
|
||||
|
||||
for m in candidates:
|
||||
gt_all_candidates = True
|
||||
for w in candidates:
|
||||
if not self.is_lt(w, m):
|
||||
gt_all_candidates = False
|
||||
break
|
||||
# Grab all elements greater than each of the candidates
|
||||
candidate_ge_maps = (self.ge_map[candidate] for candidate in candidates)
|
||||
common_ge_values = reduce(set.intersection, candidate_ge_maps)
|
||||
|
||||
if gt_all_candidates:
|
||||
return m
|
||||
# Intersect with candidates to get the values that satisfy
|
||||
# the meet properties
|
||||
result_set = candidates.intersection(common_ge_values)
|
||||
|
||||
# Otherwise the meet does not exist
|
||||
print("Meet does not exist", (x, y), candidates)
|
||||
return None
|
||||
# NOTE: The meet may not exist, in which case return None
|
||||
result = next(iter(result_set), None)
|
||||
return result
|
||||
|
||||
def join(self, x, y) -> Optional[ModelValue]:
|
||||
# Grab the collection of elements greater than x and y
|
||||
candidates = set()
|
||||
for w in self.ordering:
|
||||
if self.is_lt(x, w) and self.is_lt(y, w):
|
||||
candidates.add(w)
|
||||
X = self.ge_map[x]
|
||||
Y = self.ge_map[y]
|
||||
|
||||
for j in candidates:
|
||||
lt_all_candidates = True
|
||||
for w in candidates:
|
||||
if not self.is_lt(j, w):
|
||||
lt_all_candidates = False
|
||||
break
|
||||
candidates = X.intersection(Y)
|
||||
|
||||
if lt_all_candidates:
|
||||
return j
|
||||
# Grab all elements smaller than each of the candidates
|
||||
candidate_le_maps = (self.le_map[candidate] for candidate in candidates)
|
||||
common_le_values = reduce(set.intersection, candidate_le_maps)
|
||||
|
||||
# Otherwise the join does not exist
|
||||
print("Join does not exist", (x, y), candidates)
|
||||
return None
|
||||
# Intersect with candidatse to get the values that satisfy
|
||||
# the join properties
|
||||
result_set = candidates.intersection(common_le_values)
|
||||
|
||||
# NOTE: The join may not exist, in which case return None
|
||||
result = next(iter(result_set), None)
|
||||
return result
|
||||
|
||||
def top(self) -> Optional[ModelValue]:
|
||||
ge_maps = (self.ge_map[candidate] for candidate in self.ge_map)
|
||||
result_set = reduce(set.intersection, ge_maps)
|
||||
|
||||
# Either not unique or does not exist
|
||||
if len(result_set) != 1:
|
||||
return None
|
||||
|
||||
return next(iter(result_set))
|
||||
|
||||
def bottom(self) -> Optional[ModelValue]:
|
||||
le_maps = (self.le_map[candidate] for candidate in self.le_map)
|
||||
result_set = reduce(set.intersection, le_maps)
|
||||
|
||||
# Either not unique or does not exist
|
||||
if len(result_set) != 1:
|
||||
return None
|
||||
|
||||
return next(iter(result_set))
|
||||
|
||||
|
||||
class Model:
|
||||
|
|
@ -276,86 +294,61 @@ def satisfiable(logic: Logic, model: Model, interpretation: Dict[Operation, Mode
|
|||
return True
|
||||
|
||||
|
||||
|
||||
def model_closure(initial_set: Set[ModelValue], mfunctions: Set[ModelFunction], forbidden_element: Optional[ModelValue]) -> Set[ModelValue]:
|
||||
"""
|
||||
Given an initial set of model values and a set of model functions,
|
||||
compute the complete set of model values that are closed
|
||||
under the operations.
|
||||
|
||||
If top or bottom is encountered, then we end the saturation procedure early.
|
||||
If the forbidden element is encountered, then we end the saturation procedure early.
|
||||
"""
|
||||
closure_set: Set[ModelValue] = initial_set
|
||||
last_new: Set[ModelValue] = initial_set
|
||||
changed: bool = True
|
||||
forbidden_found = False
|
||||
|
||||
arities = set()
|
||||
for mfun in mfunctions:
|
||||
arities.add(mfun.arity)
|
||||
|
||||
while changed:
|
||||
changed = False
|
||||
new_elements: Set[ModelValue] = set()
|
||||
old_closure: Set[ModelValue] = closure_set - last_new
|
||||
|
||||
# arity -> args
|
||||
cached_args = defaultdict(list)
|
||||
args_by_arity = defaultdict(list)
|
||||
|
||||
# Pass elements into each model function
|
||||
for mfun in mfunctions:
|
||||
|
||||
# If a previous function shared the same arity,
|
||||
# we'll use the same set of computed arguments
|
||||
# to pass into the model functions.
|
||||
if mfun.arity in cached_args:
|
||||
for args in cached_args[mfun.arity]:
|
||||
# Compute the new elements
|
||||
# given the cached arguments.
|
||||
element = mfun(*args)
|
||||
if element not in closure_set:
|
||||
new_elements.add(element)
|
||||
|
||||
# Optimization: Break out of computation
|
||||
# early when forbidden element is found
|
||||
if forbidden_element is not None and element == forbidden_element:
|
||||
forbidden_found = True
|
||||
break
|
||||
|
||||
if forbidden_found:
|
||||
break
|
||||
|
||||
# We don't need to compute the arguments
|
||||
# thanks to the cache, so move onto the
|
||||
# next function.
|
||||
continue
|
||||
|
||||
# At this point, we don't have cached arguments, so we need
|
||||
# to compute this set.
|
||||
|
||||
# Each argument must have at least one new element to not repeat
|
||||
# work. We'll range over the number of new model values within our
|
||||
# argument.
|
||||
for num_new in range(1, mfun.arity + 1):
|
||||
# Motivation: We want to only compute arguments that we have not
|
||||
# seen before
|
||||
for arity in arities:
|
||||
for num_new in range(1, arity + 1):
|
||||
new_args = combinations_with_replacement(last_new, r=num_new)
|
||||
old_args = combinations_with_replacement(old_closure, r=mfun.arity - num_new)
|
||||
# Determine every possible ordering of the concatenated
|
||||
# new and old model values.
|
||||
old_args = combinations_with_replacement(old_closure, r=arity - num_new)
|
||||
# Determine every possible ordering of the concatenated new and
|
||||
# old model values.
|
||||
for new_arg, old_arg in product(new_args, old_args):
|
||||
for args in permutations(new_arg + old_arg):
|
||||
cached_args[mfun.arity].append(args)
|
||||
element = mfun(*args)
|
||||
if element not in closure_set:
|
||||
new_elements.add(element)
|
||||
for combined_args in permutations(new_arg + old_arg):
|
||||
args_by_arity[arity].append(combined_args)
|
||||
|
||||
# Optimization: Break out of computation
|
||||
# early when forbidden element is found
|
||||
if forbidden_element is not None and element == forbidden_element:
|
||||
forbidden_found = True
|
||||
break
|
||||
|
||||
if forbidden_found:
|
||||
break
|
||||
# Pass each argument into each model function
|
||||
for mfun in mfunctions:
|
||||
for args in args_by_arity[mfun.arity]:
|
||||
# Compute the new elements
|
||||
# given the cached arguments.
|
||||
element = mfun(*args)
|
||||
if element not in closure_set:
|
||||
new_elements.add(element)
|
||||
|
||||
if forbidden_found:
|
||||
# Optimization: Break out of computation
|
||||
# early when forbidden element is found
|
||||
if forbidden_element is not None and element == forbidden_element:
|
||||
forbidden_found = True
|
||||
break
|
||||
|
||||
if forbidden_found:
|
||||
break
|
||||
|
||||
closure_set.update(new_elements)
|
||||
changed = len(new_elements) > 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue