mirror of
				https://github.com/Brandon-Rozek/matmod.git
				synced 2025-11-03 03:11:12 +00:00 
			
		
		
		
	Code cleanup
This commit is contained in:
		
							parent
							
								
									fa9e5026ca
								
							
						
					
					
						commit
						01204a9551
					
				
					 4 changed files with 286 additions and 353 deletions
				
			
		
							
								
								
									
										171
									
								
								model.py
									
										
									
									
									
								
							
							
						
						
									
										171
									
								
								model.py
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -16,9 +16,9 @@ from typing import Dict, List, Optional, Set, Tuple
 | 
			
		|||
__all__ = ['ModelValue', 'ModelFunction', 'Model', 'Interpretation']
 | 
			
		||||
 | 
			
		||||
class ModelValue:
 | 
			
		||||
    def __init__(self, name):
 | 
			
		||||
    def __init__(self, name: str, hashed_value: Optional[int] = None):
 | 
			
		||||
        self.name = name
 | 
			
		||||
        self.hashed_value = hash(self.name)
 | 
			
		||||
        self.hashed_value = hashed_value if hashed_value is not None else hash(self.name)
 | 
			
		||||
        self.__setattr__ = immutable
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return self.name
 | 
			
		||||
| 
						 | 
				
			
			@ -27,7 +27,7 @@ class ModelValue:
 | 
			
		|||
    def __eq__(self, other):
 | 
			
		||||
        return isinstance(other, ModelValue) and self.name == other.name
 | 
			
		||||
    def __deepcopy__(self, _):
 | 
			
		||||
        return ModelValue(self.name)
 | 
			
		||||
        return ModelValue(self.name, self.hashed_value)
 | 
			
		||||
 | 
			
		||||
class ModelFunction:
 | 
			
		||||
    def __init__(self, arity: int, mapping, operation_name = ""):
 | 
			
		||||
| 
						 | 
				
			
			@ -109,57 +109,75 @@ Interpretation = Dict[Operation, ModelFunction]
 | 
			
		|||
class OrderTable:
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        # a : {x | x <= a }
 | 
			
		||||
        self.ordering: Dict[ModelValue, Set[ModelValue]] = defaultdict(set)
 | 
			
		||||
        self.le_map: Dict[ModelValue, Set[ModelValue]] = defaultdict(set)
 | 
			
		||||
        # a : {x | x >= a}
 | 
			
		||||
        self.ge_map: Dict[ModelValue, Set[ModelValue]] = defaultdict(set)
 | 
			
		||||
 | 
			
		||||
    def add(self, x, y):
 | 
			
		||||
        """
 | 
			
		||||
        Add x <= y
 | 
			
		||||
        """
 | 
			
		||||
        self.ordering[y].add(x)
 | 
			
		||||
        self.le_map[y].add(x)
 | 
			
		||||
        self.ge_map[x].add(y)
 | 
			
		||||
 | 
			
		||||
    def is_lt(self, x, y):
 | 
			
		||||
        return y in self.ordering[x]
 | 
			
		||||
        return x in self.le_map[y]
 | 
			
		||||
 | 
			
		||||
    def meet(self, x, y) -> Optional[ModelValue]:
 | 
			
		||||
        X = self.ordering[x]
 | 
			
		||||
        Y = self.ordering[y]
 | 
			
		||||
        X = self.le_map[x]
 | 
			
		||||
        Y = self.le_map[y]
 | 
			
		||||
 | 
			
		||||
        candidates = X.intersection(Y)
 | 
			
		||||
 | 
			
		||||
        for m in candidates:
 | 
			
		||||
            gt_all_candidates = True
 | 
			
		||||
            for w in candidates:
 | 
			
		||||
                if not self.is_lt(w, m):
 | 
			
		||||
                    gt_all_candidates = False
 | 
			
		||||
                    break
 | 
			
		||||
        # Grab all elements greater than each of the candidates
 | 
			
		||||
        candidate_ge_maps = (self.ge_map[candidate] for candidate in candidates)
 | 
			
		||||
        common_ge_values = reduce(set.intersection, candidate_ge_maps)
 | 
			
		||||
 | 
			
		||||
            if gt_all_candidates:
 | 
			
		||||
                return m
 | 
			
		||||
        # Intersect with candidates to get the values that satisfy
 | 
			
		||||
        # the meet properties
 | 
			
		||||
        result_set = candidates.intersection(common_ge_values)
 | 
			
		||||
 | 
			
		||||
        # Otherwise the meet does not exist
 | 
			
		||||
        print("Meet does not exist", (x, y), candidates)
 | 
			
		||||
        return None
 | 
			
		||||
        # NOTE: The meet may not exist, in which case return None
 | 
			
		||||
        result = next(iter(result_set), None)
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    def join(self, x, y) -> Optional[ModelValue]:
 | 
			
		||||
        # Grab the collection of elements greater than x and y
 | 
			
		||||
        candidates = set()
 | 
			
		||||
        for w in self.ordering:
 | 
			
		||||
            if self.is_lt(x, w) and self.is_lt(y, w):
 | 
			
		||||
                candidates.add(w)
 | 
			
		||||
        X = self.ge_map[x]
 | 
			
		||||
        Y = self.ge_map[y]
 | 
			
		||||
 | 
			
		||||
        for j in candidates:
 | 
			
		||||
            lt_all_candidates = True
 | 
			
		||||
            for w in candidates:
 | 
			
		||||
                if not self.is_lt(j, w):
 | 
			
		||||
                    lt_all_candidates = False
 | 
			
		||||
                    break
 | 
			
		||||
        candidates = X.intersection(Y)
 | 
			
		||||
 | 
			
		||||
            if lt_all_candidates:
 | 
			
		||||
                return j
 | 
			
		||||
        # Grab all elements smaller than each of the candidates
 | 
			
		||||
        candidate_le_maps = (self.le_map[candidate] for candidate in candidates)
 | 
			
		||||
        common_le_values = reduce(set.intersection, candidate_le_maps)
 | 
			
		||||
 | 
			
		||||
        # Otherwise the join does not exist
 | 
			
		||||
        print("Join does not exist", (x, y), candidates)
 | 
			
		||||
        return None
 | 
			
		||||
        # Intersect with candidatse to get the values that satisfy
 | 
			
		||||
        # the join properties
 | 
			
		||||
        result_set = candidates.intersection(common_le_values)
 | 
			
		||||
 | 
			
		||||
        # NOTE: The join may not exist, in which case return None
 | 
			
		||||
        result = next(iter(result_set), None)
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    def top(self) -> Optional[ModelValue]:
 | 
			
		||||
        ge_maps = (self.ge_map[candidate] for candidate in self.ge_map)
 | 
			
		||||
        result_set = reduce(set.intersection, ge_maps)
 | 
			
		||||
 | 
			
		||||
        # Either not unique or does not exist
 | 
			
		||||
        if len(result_set) != 1:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        return next(iter(result_set))
 | 
			
		||||
 | 
			
		||||
    def bottom(self) -> Optional[ModelValue]:
 | 
			
		||||
        le_maps = (self.le_map[candidate] for candidate in self.le_map)
 | 
			
		||||
        result_set = reduce(set.intersection, le_maps)
 | 
			
		||||
 | 
			
		||||
        # Either not unique or does not exist
 | 
			
		||||
        if len(result_set) != 1:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        return next(iter(result_set))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Model:
 | 
			
		||||
| 
						 | 
				
			
			@ -276,86 +294,61 @@ def satisfiable(logic: Logic, model: Model, interpretation: Dict[Operation, Mode
 | 
			
		|||
    return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def model_closure(initial_set: Set[ModelValue], mfunctions: Set[ModelFunction], forbidden_element: Optional[ModelValue]) -> Set[ModelValue]:
 | 
			
		||||
    """
 | 
			
		||||
    Given an initial set of model values and a set of model functions,
 | 
			
		||||
    compute the complete set of model values that are closed
 | 
			
		||||
    under the operations.
 | 
			
		||||
 | 
			
		||||
    If top or bottom is encountered, then we end the saturation procedure early.
 | 
			
		||||
    If the forbidden element is encountered, then we end the saturation procedure early.
 | 
			
		||||
    """
 | 
			
		||||
    closure_set: Set[ModelValue] = initial_set
 | 
			
		||||
    last_new: Set[ModelValue] = initial_set
 | 
			
		||||
    changed: bool = True
 | 
			
		||||
    forbidden_found = False
 | 
			
		||||
 | 
			
		||||
    arities = set()
 | 
			
		||||
    for mfun in mfunctions:
 | 
			
		||||
        arities.add(mfun.arity)
 | 
			
		||||
 | 
			
		||||
    while changed:
 | 
			
		||||
        changed = False
 | 
			
		||||
        new_elements: Set[ModelValue] = set()
 | 
			
		||||
        old_closure: Set[ModelValue] = closure_set - last_new
 | 
			
		||||
 | 
			
		||||
        # arity -> args
 | 
			
		||||
        cached_args = defaultdict(list)
 | 
			
		||||
        args_by_arity = defaultdict(list)
 | 
			
		||||
 | 
			
		||||
        # Pass elements into each model function
 | 
			
		||||
        for mfun in mfunctions:
 | 
			
		||||
 | 
			
		||||
            # If a previous function shared the same arity,
 | 
			
		||||
            # we'll use the same set of computed arguments
 | 
			
		||||
            # to pass into the model functions.
 | 
			
		||||
            if mfun.arity in cached_args:
 | 
			
		||||
                for args in cached_args[mfun.arity]:
 | 
			
		||||
                    # Compute the new elements
 | 
			
		||||
                    # given the cached arguments.
 | 
			
		||||
                    element = mfun(*args)
 | 
			
		||||
                    if element not in closure_set:
 | 
			
		||||
                        new_elements.add(element)
 | 
			
		||||
 | 
			
		||||
                    # Optimization: Break out of computation
 | 
			
		||||
                    # early when forbidden element is found
 | 
			
		||||
                    if forbidden_element is not None and element == forbidden_element:
 | 
			
		||||
                        forbidden_found = True
 | 
			
		||||
                        break
 | 
			
		||||
 | 
			
		||||
                if forbidden_found:
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
                # We don't need to compute the arguments
 | 
			
		||||
                # thanks to the cache, so move onto the
 | 
			
		||||
                # next function.
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            # At this point, we don't have cached arguments, so we need
 | 
			
		||||
            # to compute this set.
 | 
			
		||||
 | 
			
		||||
            # Each argument must have at least one new element to not repeat
 | 
			
		||||
            # work. We'll range over the number of new model values within our
 | 
			
		||||
            # argument.
 | 
			
		||||
            for num_new in range(1, mfun.arity + 1):
 | 
			
		||||
        # Motivation: We want to only compute arguments that we have not
 | 
			
		||||
        # seen before
 | 
			
		||||
        for arity in arities:
 | 
			
		||||
            for num_new in range(1, arity + 1):
 | 
			
		||||
                new_args = combinations_with_replacement(last_new, r=num_new)
 | 
			
		||||
                old_args = combinations_with_replacement(old_closure, r=mfun.arity - num_new)
 | 
			
		||||
                # Determine every possible ordering of the concatenated
 | 
			
		||||
                # new and old model values.
 | 
			
		||||
                old_args = combinations_with_replacement(old_closure, r=arity - num_new)
 | 
			
		||||
                # Determine every possible ordering of the concatenated new and
 | 
			
		||||
                # old model values.
 | 
			
		||||
                for new_arg, old_arg in product(new_args, old_args):
 | 
			
		||||
                    for args in permutations(new_arg + old_arg):
 | 
			
		||||
                        cached_args[mfun.arity].append(args)
 | 
			
		||||
                        element = mfun(*args)
 | 
			
		||||
                        if element not in closure_set:
 | 
			
		||||
                            new_elements.add(element)
 | 
			
		||||
                    for combined_args in permutations(new_arg + old_arg):
 | 
			
		||||
                        args_by_arity[arity].append(combined_args)
 | 
			
		||||
 | 
			
		||||
                        # Optimization: Break out of computation
 | 
			
		||||
                        # early when forbidden element is found
 | 
			
		||||
                        if forbidden_element is not None and element == forbidden_element:
 | 
			
		||||
                            forbidden_found = True
 | 
			
		||||
                            break
 | 
			
		||||
 | 
			
		||||
                    if forbidden_found:
 | 
			
		||||
                        break
 | 
			
		||||
        # Pass each argument into each model function
 | 
			
		||||
        for mfun in mfunctions:
 | 
			
		||||
            for args in args_by_arity[mfun.arity]:
 | 
			
		||||
                # Compute the new elements
 | 
			
		||||
                # given the cached arguments.
 | 
			
		||||
                element = mfun(*args)
 | 
			
		||||
                if element not in closure_set:
 | 
			
		||||
                    new_elements.add(element)
 | 
			
		||||
 | 
			
		||||
                if forbidden_found:
 | 
			
		||||
                # Optimization: Break out of computation
 | 
			
		||||
                # early when forbidden element is found
 | 
			
		||||
                if forbidden_element is not None and element == forbidden_element:
 | 
			
		||||
                    forbidden_found = True
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
            if forbidden_found:
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
        closure_set.update(new_elements)
 | 
			
		||||
        changed = len(new_elements) > 0
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue