matmod/model.py

263 lines
8.7 KiB
Python
Raw Normal View History

2024-04-08 23:59:21 -04:00
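"""
Defining what it means to be a model: a carrier set of values, operations
on those values, a set of designated values, and checks of a logic's rules
(and of the variable sharing property) against a model.
"""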
2024-04-15 00:08:00 -04:00
from common import set_to_str
2024-04-08 23:59:21 -04:00
from logic import (
PropositionalVariable, get_propostional_variables, Logic, Term,
2024-05-03 13:06:52 -04:00
Operation, Conjunction, Disjunction, Implication
2024-04-21 12:15:24 -04:00
)
2024-05-03 13:06:52 -04:00
from typing import Set, Dict, Tuple, Optional
2024-04-15 00:08:00 -04:00
from functools import lru_cache
2024-05-03 13:06:52 -04:00
from itertools import combinations, chain, product
from copy import deepcopy
2024-04-08 23:59:21 -04:00
# Names exported by `from model import *`.
__all__ = [
    "ModelValue",
    "ModelFunction",
    "Model",
]
class ModelValue:
    """
    A named, immutable element of a model's carrier set.

    Equality and hashing are both derived from ``name``, so model values
    can be stored in sets and used as dictionary keys.
    """
    def __init__(self, name):
        # Go through object.__setattr__ because our own __setattr__
        # rejects every assignment.
        object.__setattr__(self, "name", name)
        object.__setattr__(self, "hashed_value", hash(name))

    def __setattr__(self, name, value):
        # Must be defined on the class: Python looks special methods up on
        # the type, so assigning self.__setattr__ on an instance is ignored.
        raise AttributeError("Model values are immutable")

    def __delattr__(self, name):
        raise AttributeError("Model values are immutable")

    def __str__(self):
        return self.name

    def __hash__(self):
        return self.hashed_value

    def __eq__(self, other):
        return isinstance(other, ModelValue) and self.name == other.name

    def __lt__(self, other):
        """
        Build the order constraint ``self < other``.

        This does not compare the values. It returns a ModelOrderConstraint
        object (always truthy), intended to be looked up in a model's
        ``ordering`` set.
        """
        assert isinstance(other, ModelValue)
        return ModelOrderConstraint(self, other)
2024-04-08 23:59:21 -04:00
class ModelFunction:
    """
    A finite operation on model values, stored as an explicit lookup table.

    ``mapping`` sends argument tuples to results. Its keys may also be
    lists (converted to tuples) or, when ``arity`` is 1, bare values
    (wrapped in a 1-tuple). Every key is checked against ``arity``.
    """
    def __init__(self, arity: int, mapping, operation_name=""):
        self.operation_name = operation_name
        self.arity = arity
        self.mapping = {
            self._normalize_key(key, arity): value
            for key, value in mapping.items()
        }

    @staticmethod
    def _normalize_key(key, arity):
        """Turn one mapping key into an argument tuple of length ``arity``."""
        if isinstance(key, tuple):
            assert len(key) == arity
            return key
        if isinstance(key, list):
            assert len(key) == arity
            return tuple(key)
        # Anything else is treated as a single (atomic) argument
        assert arity == 1
        return (key,)

    def __str__(self):
        rendered = {
            "(" + ", ".join(map(str, args)) + ")": str(result)
            for args, result in self.mapping.items()
        }
        return str(rendered)

    def __call__(self, *args):
        # Raises KeyError for argument tuples outside the table
        return self.mapping[args]
2024-04-21 12:15:24 -04:00
class ModelOrderConstraint:
    """
    The order constraint ``a < b`` between two model values.

    Produced by ``ModelValue.__lt__``. Equality and hashing depend on the
    ordered pair (a, b), so ``a < b`` and ``b < a`` are distinct constraints.
    """
    def __init__(self, a: "ModelValue", b: "ModelValue"):
        self.a = a
        self.b = b

    def __hash__(self):
        # Hash the ordered pair. A product of the two hashes is symmetric,
        # which makes every constraint collide with its reverse.
        return hash((self.a, self.b))

    def __eq__(self, other):
        return isinstance(other, ModelOrderConstraint) and \
            self.a == other.a and self.b == other.b

    def __str__(self):
        return f"{self.a} < {self.b}"
2024-04-08 23:59:21 -04:00
class Model:
    """
    A finite model: a carrier set, operations over it, the designated
    subset of the carrier set, and an optional set of ``a < b`` order
    constraints.
    """
    def __init__(
        self,
        carrier_set: Set[ModelValue],
        logical_operations: Set[ModelFunction],
        designated_values: Set[ModelValue],
        ordering: Optional[Set[ModelOrderConstraint]] = None
    ):
        # Designated values must come from the carrier set
        assert designated_values <= carrier_set
        self.carrier_set = carrier_set
        self.logical_operations = logical_operations
        self.designated_values = designated_values
        # Keep the caller's set object when one is given
        self.ordering = set() if ordering is None else ordering
        # TODO: Make sure ordering is "valid"
        # That is: transitive, etc.

    def __str__(self):
        pieces = [
            f"Carrier Set: {set_to_str(self.carrier_set)}\n",
            f"Designated Values: {set_to_str(self.designated_values)}\n",
        ]
        pieces.extend(f"{str(function)}\n" for function in self.logical_operations)
        return "".join(pieces)
2024-04-15 00:08:00 -04:00
def evaluate_term(t: Term, f: Dict[PropositionalVariable, ModelValue], interpretation: Dict[Operation, ModelFunction]) -> ModelValue:
    """
    Evaluate a logical term in a model.

    A propositional variable is looked up in the valuation ``f``. A compound
    term applies the model function that ``interpretation`` assigns to its
    operation to the recursively evaluated arguments.
    """
    if not isinstance(t, PropositionalVariable):
        mfunction = interpretation[t.operation]
        evaluated_args = [
            evaluate_term(sub_term, f, interpretation)
            for sub_term in t.arguments
        ]
        return mfunction(*evaluated_args)
    return f[t]
def all_model_valuations(
        pvars: Tuple[PropositionalVariable],
        mvalues: Tuple[ModelValue]):
    """
    Yield every assignment of model values to the given variables.

    Each yielded item is a fresh dict mapping each variable in ``pvars``
    to one element of ``mvalues``. Assignments are produced lazily, in
    itertools.product order.
    """
    columns = [mvalues for _ in pvars]
    for combo in product(*columns):
        assert len(pvars) == len(combo)
        yield dict(zip(pvars, combo))
2024-04-15 00:08:00 -04:00
@lru_cache(maxsize=128)
def all_model_valuations_cached(
        pvars: Tuple[PropositionalVariable],
        mvalues: Tuple[ModelValue]):
    """
    Memoized, fully materialized form of all_model_valuations.

    Arguments must be hashable (hence tuples). Cache hits return the same
    list object, so callers should not mutate it or the dicts inside it.
    """
    valuations = all_model_valuations(pvars, mvalues)
    return list(valuations)
2024-04-21 17:37:21 -04:00
def rule_ordering_satisfied(model: Model, interpretation: Dict[Operation, ModelFunction]) -> bool:
    """
    Check that conjunction and disjunction agree with ``model.ordering``.

    For every pair of carrier values with a constraint ``lo < hi`` in
    ``model.ordering``, conjunction must return ``lo`` and disjunction must
    return ``hi``. An operation absent from ``interpretation`` is skipped.

    Returns False at the first disagreement, True otherwise.
    """
    # Each entry: the operation, and which of (lo, hi) it must return.
    expectations = (
        (Conjunction, lambda lo, hi: lo),
        (Disjunction, lambda lo, hi: hi),
    )
    for operation, expected in expectations:
        if operation not in interpretation:
            continue
        mfunction = interpretation[operation]
        for a, b in product(model.carrier_set, model.carrier_set):
            output = mfunction(a, b)
            # The parentheses matter. `a < b in S` is a chained comparison,
            # `(a < b) and (b in S)`. It tests whether the value b is in the
            # set of constraints, not whether the constraint a < b is.
            if (a < b) in model.ordering and output != expected(a, b):
                return False
            if (b < a) in model.ordering and output != expected(b, a):
                return False
    return True
2024-04-15 00:08:00 -04:00
def satisfiable(logic: Logic, model: Model, interpretation: Dict[Operation, ModelFunction]) -> bool:
    """
    Check whether every rule of ``logic`` holds in ``model``.

    A rule holds when, for every valuation of the logic's propositional
    variables into the carrier set, designated premises force a designated
    conclusion. An interpretation rejected by rule_ordering_satisfied fails
    immediately.

    Returns True if no valuation falsifies any rule, False otherwise.
    """
    # Reject interpretations that conflict with model.ordering before
    # enumerating any valuations.
    if not rule_ordering_satisfied(model, interpretation):
        return False

    pvars = tuple(get_propostional_variables(tuple(logic.rules)))
    mappings = all_model_valuations_cached(pvars, tuple(model.carrier_set))
    designated = model.designated_values

    for mapping in mappings:
        for rule in logic.rules:
            # all() stops evaluating at the first non-designated premise
            premises_designated = all(
                evaluate_term(premise, mapping, interpretation) in designated
                for premise in rule.premises
            )
            if not premises_designated:
                continue
            conclusion = evaluate_term(rule.conclusion, mapping, interpretation)
            if conclusion not in designated:
                return False

    return True
2024-05-03 13:06:52 -04:00
def model_closure(initial_set: Set["ModelValue"], mfunctions: Set["ModelFunction"]):
    """
    Return the closure of ``initial_set`` under the given model functions.

    Repeatedly applies every function in ``mfunctions`` (each exposing
    ``arity`` and callable on that many values) to every argument tuple
    drawn from the values collected so far, until no new value appears.

    ``initial_set`` is not modified; a new set is returned.
    """
    # Copy so that growing the closure never mutates the caller's set
    current_set: Set[ModelValue] = set(initial_set)
    last_set: Set[ModelValue] = set()
    while last_set != current_set:
        # A shallow snapshot is enough: only membership changes between
        # rounds, and the values themselves do not need copying.
        last_set = set(current_set)
        for mfun in mfunctions:
            # Every input configuration drawn from last_set
            for args in product(last_set, repeat=mfun.arity):
                current_set.add(mfun(*args))
    return current_set
def violates_vsp(model: Model, interpretation: Dict[Operation, ModelFunction]) -> bool:
    """
    Tell whether a model violates the variable sharing property.

    A False result does not rule out a violation. It only means that none
    of the subalgebras checked here witnessed one.

    The check works as follows. Collect the pairs (x, y) for which x -> y
    is not designated. For every non-empty subset of those pairs, take the
    closure of its xs and the closure of its ys under all model operations.
    Return True, after printing a diagnostic, as soon as the two closures
    intersect, or as soon as some x2 -> y2 (x2 from the left closure, y2
    from the right) is designated.
    """
    implies = interpretation[Implication]
    designated = model.designated_values

    # Pairs whose implication does not take a designated value
    failing_pairs: Set[Tuple[ModelValue, ModelValue]] = {
        (x, y)
        for x, y in product(model.carrier_set, model.carrier_set)
        if implies(x, y) not in designated
    }
    candidates = list(failing_pairs)

    # Walk every non-empty subset of the candidates, smallest subsets first
    for size in range(1, len(candidates) + 1):
        for chosen in combinations(candidates, size):
            left = model_closure({x for x, _ in chosen}, model.logical_operations)
            right = model_closure({y for _, y in chosen}, model.logical_operations)
            # NOTE(review): intersecting closures are reported as a violation
            # here; confirm this matches the intended VSP criterion.
            if not left.isdisjoint(right):
                print("FAIL: Carrier sets intersect")
                return True
            for x2, y2 in product(left, right):
                if implies(x2, y2) in designated:
                    print(f"({x2}, {y2}) take on a designated value")
                    return True
    return False