From ae8658fda23e244213b63971c3570ac11ce0e54c Mon Sep 17 00:00:00 2001 From: Brandon Rozek Date: Sun, 21 Apr 2024 12:15:24 -0400 Subject: [PATCH] Introduced ordering at model level... --- generate_model.py | 24 +++++++++++++---- model.py | 65 ++++++++++++++++++++++++++++++++++------------- 2 files changed, 67 insertions(+), 22 deletions(-) diff --git a/generate_model.py b/generate_model.py index e7aa93e..bae9e1a 100644 --- a/generate_model.py +++ b/generate_model.py @@ -22,7 +22,7 @@ def possible_functions(operation, carrier_set): new_function = dict() for input, output in zip(inputs, outputs): new_function[input] = output - + yield ModelFunction(new_function, operation.symbol) @@ -46,7 +46,6 @@ def only_rules_with(rules: Set[Rule], operation: Operation) -> Set[Rule]: return result_rules - def possible_interpretations( logic: Logic, carrier_set: Set[ModelValue], designated_values: Set[ModelValue]): @@ -88,10 +87,25 @@ def possible_interpretations( yield interpretation def generate_model(logic: Logic, number_elements: int, num_solutions: int = -1, print_model=False): + assert number_elements > 0 carrier_set = { ModelValue("a" + str(i)) for i in range(number_elements) } + ordering = set() + + # a(0) is less than all other elements + a0 = ModelValue("a0") + for v in carrier_set: + if v != a0: + ordering.add(a0 < v) + + # Every other element is less than a(n - 1) + an = ModelValue(f"a{number_elements-1}") + for v in carrier_set: + if an != v: + ordering.add(v < an) + possible_designated_values = possible_designations(carrier_set) satisfied_models = [] @@ -102,7 +116,7 @@ def generate_model(logic: Logic, number_elements: int, num_solutions: int = -1, for interpretation in possible_interps: is_valid = True - model = Model(carrier_set, set(interpretation.values()), designated_values) + model = Model(carrier_set, set(interpretation.values()), designated_values, ordering) # Iteratively test possible interpretations # by adding one axiom at a time for rule in logic.rules: @@ -110,12 +124,12 @@ def generate_model(logic: Logic, number_elements: int, num_solutions: int = -1, if not satisfiable(small_logic, model, interpretation): is_valid = False break - + if is_valid: satisfied_models.append(model) if print_model: print(model, flush=True) - + if num_solutions >= 0 and len(satisfied_models) >= num_solutions: return satisfied_models diff --git a/model.py b/model.py index e5a325b..513fdf5 100644 --- a/model.py +++ b/model.py @@ -5,8 +5,8 @@ from common import set_to_str from logic import ( PropositionalVariable, get_propostional_variables, Logic, Term, Operation -) -from typing import Set, List, Dict, Tuple +) +from typing import Set, List, Dict, Tuple, Optional from itertools import product from functools import lru_cache @@ -27,6 +27,9 @@ class ModelValue: return self.hashed_value def __eq__(self, other): return isinstance(other, ModelValue) and self.name == other.name + def __lt__(self, other): + assert isinstance(other, ModelValue) + return ModelOrderConstraint(self, other) class ModelFunction: @@ -42,9 +45,9 @@ class ModelFunction: corrected_mapping[tuple(k)] = v else: # Assume it's atomic corrected_mapping[(k,)] = v - + self.mapping = corrected_mapping - + def __str__(self): str_dict = dict() for k, v in self.mapping.items(): @@ -54,29 +57,44 @@ class ModelFunction: def __call__(self, *args): return self.mapping[args] - + # def __eq__(self, other): # return isinstance(other, ModelFunction) and self.name == other.name and self.arity == other.arity +class ModelOrderConstraint: + # a < b + def __init__(self, a: ModelValue, b: ModelValue): + self.a = a + self.b = b + def __hash__(self): + return hash(self.a) * hash(self.b) + def __eq__(self, other): + return isinstance(other, ModelOrderConstraint) and \ + self.a == other.a and self.b == other.b + class Model: def __init__( self, carrier_set: Set[ModelValue], logical_operations: Set[ModelFunction], - designated_values: Set[ModelValue] + designated_values: Set[ModelValue], + ordering: Optional[Set[ModelOrderConstraint]] = None ): assert designated_values <= carrier_set self.carrier_set = carrier_set self.logical_operations = logical_operations self.designated_values = designated_values - + self.ordering = ordering if ordering is not None else set() + # TODO: Make sure ordering is "valid" + # That is: transitive, etc. + def __str__(self): result = f"""Carrier Set: {set_to_str(self.carrier_set)} Designated Values: {set_to_str(self.designated_values)} """ for function in self.logical_operations: result += f"{str(function)}\n" - + return result @@ -89,13 +107,13 @@ def evaluate_term(t: Term, f: Dict[PropositionalVariable, ModelValue], interpret for logic_arg in t.arguments: model_arg = evaluate_term(logic_arg, f, interpretation) model_arguments.append(model_arg) - + return model_function(*model_arguments) def all_model_valuations( pvars: Tuple[PropositionalVariable], mvalues: Tuple[ModelValue]): - + possible_valuations = [mvalues for _ in pvars] all_possible_values = product(*possible_valuations) @@ -116,20 +134,33 @@ def satisfiable(logic: Logic, model: Model, interpretation: Dict[Operation, Mode pvars = tuple(get_propostional_variables(tuple(logic.rules))) mappings = all_model_valuations_cached(pvars, tuple(model.carrier_set)) + """ + TODO: Make sure that ordering for conjunction and disjunction + at the model function level. + """ + for mapping in mappings: for rule in logic.rules: premise_met = True + premise_ts = set() for premise in rule.premises: - t = evaluate_term(premise, mapping, interpretation) - if t not in model.designated_values: + premise_t = evaluate_term(premise, mapping, interpretation) + if premise_t not in model.designated_values: premise_met = False break - + premise_ts.add(premise_t) + if not premise_met: continue - - t = evaluate_term(rule.conclusion, mapping, interpretation) - if t not in model.designated_values: + + consequent_t = evaluate_term(rule.conclusion, mapping, interpretation) + + if consequent_t not in model.designated_values: return False - + + # Make sure ordering constraint is met + for premise_t in premise_ts: + if consequent_t < premise_t in model.ordering: + return False + return True