"""
Matrix model semantics and satisfiability of
a given logic.
"""
from common import set_to_str
from logic import (
    get_propostional_variables, Logic,
    Operation, PropositionalVariable, Term
)
from collections import defaultdict
from functools import cached_property, lru_cache, reduce
from itertools import chain, combinations_with_replacement, permutations, product
from typing import Dict, List, Optional, Set, Tuple


# Names exported by a star-import of this module. The helper functions
# defined below (evaluate_term, satisfiable, model_closure, ...) are not
# listed here and therefore have to be imported explicitly by name.
__all__ = ['ModelValue', 'ModelFunction', 'Model', 'Interpretation']


class ModelValue:
    """
    An immutable, hashable value identified solely by its name.

    Two model values are equal exactly when their names are equal, and
    the hash is derived from the name, so instances can be used as
    dictionary keys and set members.
    """

    def __init__(self, name):
        # The class-level __setattr__ below rejects every assignment, so
        # the initial attributes are written through object.__setattr__.
        object.__setattr__(self, "name", name)
        # The name never changes, so the hash is computed once up front.
        object.__setattr__(self, "hashed_value", hash(name))

    # NOTE: Python looks up __setattr__ / __delattr__ on the type, never
    # on the instance, so these hooks must be defined on the class for
    # the immutability guarantee to actually be enforced.
    def __setattr__(self, name, value):
        raise AttributeError("Model values are immutable")

    def __delattr__(self, name):
        raise AttributeError("Model values are immutable")

    def __str__(self):
        return self.name

    def __hash__(self):
        return self.hashed_value

    def __eq__(self, other):
        return isinstance(other, ModelValue) and self.name == other.name

    def __deepcopy__(self, _):
        # Build a fresh value from the name instead of copying the
        # instance dictionary.
        return ModelValue(self.name)

class ModelFunction:
    """
    A finite function over model values, stored as an explicit lookup table.

    `mapping` associates arguments with results. Its keys may be tuples or
    lists of length `arity`, or (for unary functions only) bare values;
    they are all normalized to tuples so the function can be called with
    positional arguments.
    """

    def __init__(self, arity: int, mapping, operation_name = ""):
        self.operation_name = operation_name
        self.arity = arity

        # Normalize every key into a tuple of model values
        normalized: Dict[Tuple[ModelValue], ModelValue] = {}
        for raw_key, result in mapping.items():
            if isinstance(raw_key, (tuple, list)):
                assert len(raw_key) == arity
                key = raw_key if isinstance(raw_key, tuple) else tuple(raw_key)
            else:
                # Anything else is taken to be a single (atomic) argument
                assert arity == 1
                key = (raw_key,)
            normalized[key] = result

        self.mapping = normalized

    @cached_property
    def domain(self):
        """Set of all values that occur in any argument position of the table."""
        return {value for arguments in self.mapping for value in arguments}

    def __str__(self):
        # Unary and binary functions get a table layout
        if self.arity == 1:
            return unary_function_str(self)
        if self.arity == 2:
            return binary_function_str(self)

        # Every other arity falls back to a dictionary-style listing
        rendered = {
            "(" + ", ".join(map(str, arguments)) + ")": str(result)
            for arguments, result in self.mapping.items()
        }
        return self.operation_name + " " + str(rendered)

    def __call__(self, *args):
        """Look up the result for the given positional arguments."""
        return self.mapping[args]


def unary_function_str(f: ModelFunction) -> str:
    """
    Render a unary model function as a three-line text table.

    The first line shows the operation name followed by the domain values
    sorted by name, the second is a dashed separator, and the third lists
    the function's result for each domain value in the same order. The
    returned string ends with a newline.
    """
    assert isinstance(f, ModelFunction) and f.arity == 1
    ordered_values = sorted(f.domain, key=lambda value: value.name)
    op_name = f.operation_name
    total_name_length = sum(len(value.name) for value in ordered_values)

    header = f" {op_name} | " + " ".join(str(value) for value in ordered_values)
    # Dashes under the label column, a "+" under the "|", then one dash
    # per remaining character: the gutter space, one per value and one
    # per character of every value name.
    separator = "-" * (len(op_name) + 2) + "+" +\
         "-" * (1 + len(ordered_values) + total_name_length)
    outputs = " ".join(str(f.mapping[(value,)]) for value in ordered_values)
    body = " " * (len(op_name) + 2) + "| " + outputs
    return f"{header}\n{separator}\n{body}\n"

def binary_function_str(f: ModelFunction) -> str:
    """
    Render a binary model function as a text operation table.

    The header line shows the operation name and the domain values sorted
    by name; after a dashed separator there is one line per domain value
    (the first argument) listing the result for every second argument in
    the same sorted order. The returned string ends with a newline.

    The label column is padded to the width of the longest row label or
    the operation name, whichever is wider, so the "|" gutter lines up on
    the header and on every row. Result cells are separated by single
    spaces and are not padded.
    """
    assert isinstance(f, ModelFunction) and f.arity == 2
    sorted_domain = sorted(f.domain, key=lambda v : v.name)
    # Width shared by the operation name and every row label
    max_col_width = max(chain((len(v.name) for v in sorted_domain), (len(f.operation_name),)))
    header_line = f" {f.operation_name.ljust(max_col_width)} | " +\
         " ".join(str(v) for v in sorted_domain)
    sep_line = "-" * (max_col_width + 2) + "+" +\
         "-" * (1 + len(sorted_domain) + sum(len(v.name) for v in sorted_domain))
    # Row labels get the same padding as the header so the gutter stays
    # aligned when labels differ in length.
    data_lines = "".join(
        f" {row_v.name.ljust(max_col_width)} | " +
        " ".join(str(f.mapping[(row_v, col_v)]) for col_v in sorted_domain) +
        "\n"
        for row_v in sorted_domain
    )
    return "\n".join((header_line, sep_line, data_lines))

# Type alias: associates each logic Operation with the ModelFunction
# that is used to evaluate it.
Interpretation = Dict[Operation, ModelFunction]

class Model:
    """
    A carrier set together with model functions over it and a subset of
    designated values.

    When no name is supplied, one is derived from the first five digits
    of the absolute hash of the model's contents.
    """

    def __init__(
            self,
            carrier_set: Set[ModelValue],
            logical_operations: Set[ModelFunction],
            designated_values: Set[ModelValue],
            name: Optional[str] = None
    ):
        assert designated_values <= carrier_set
        self.carrier_set = carrier_set
        self.logical_operations = logical_operations
        self.designated_values = designated_values

        if name is None:
            fingerprint = hash((
                frozenset(carrier_set),
                frozenset(logical_operations),
                frozenset(designated_values)
            ))
            name = str(abs(fingerprint))[:5]
        self.name = name

    def __str__(self):
        banner = "=" * 25
        pieces = [
            banner,
            f"\nModel Name: {self.name}\n",
            f"Carrier Set: {set_to_str(self.carrier_set)}\n",
            f"Designated Values: {set_to_str(self.designated_values)}\n",
        ]
        # One rendered table (plus a newline) per model function
        pieces.extend(str(function) + "\n" for function in self.logical_operations)
        pieces.append(banner + "\n")
        return "".join(pieces)


def evaluate_term(
        t: Term, f: Dict[PropositionalVariable, ModelValue],
        interpretation: Dict[Operation, ModelFunction]) -> ModelValue:
    """
    Recursively evaluate a term of the logic to a model value.

    A propositional variable is looked up directly in the valuation `f`.
    Any other term is treated as an operation applied to arguments:
    its operation is resolved through `interpretation`, each argument
    is evaluated in turn, and the resulting model function is applied
    to the evaluated arguments.
    """
    if isinstance(t, PropositionalVariable):
        return f[t]

    # Resolve the operation before touching the arguments, so a missing
    # interpretation fails before any sub-term is evaluated.
    apply_operation = interpretation[t.operation]
    evaluated_arguments: List[ModelValue] = [
        evaluate_term(sub_term, f, interpretation)
        for sub_term in t.arguments
    ]
    return apply_operation(*evaluated_arguments)

def all_model_valuations(
        pvars: Tuple[PropositionalVariable],
        mvalues: Tuple[ModelValue]):
    """
    Lazily generate every assignment of model values to the
    given propositional variables.

    Each yielded item is a fresh dictionary that maps every
    variable in `pvars` to one of the values in `mvalues`.
    """
    # product() always yields tuples with exactly len(pvars) entries,
    # so each one pairs up one-to-one with the variables.
    for assignment in product(mvalues, repeat=len(pvars)):
        yield dict(zip(pvars, assignment))

@lru_cache
def all_model_valuations_cached(
        pvars: Tuple[PropositionalVariable],
        mvalues: Tuple[ModelValue]):
    """
    Memoized variant of all_model_valuations that returns
    the complete list of valuations instead of a generator.

    Results are cached by lru_cache keyed on the arguments, so
    both `pvars` and `mvalues` must be hashable (e.g. tuples).
    Repeated calls with equal arguments return the very same
    list object, so callers should treat the list and the
    dictionaries inside it as read-only.
    """
    return list(all_model_valuations(pvars, mvalues))


def satisfiable(logic: Logic, model: Model, interpretation: Dict[Operation, ModelFunction]) -> bool:
    """
    Determine whether a model satisfies a logic
    given an interpretation.

    Every assignment of carrier-set values to the propositional
    variables of the logic's rules is checked. For each rule whose
    premises all evaluate to designated values, the conclusion must
    evaluate to a designated value as well.

    Returns False as soon as one valuation designates all premises of
    some rule but not its conclusion, and True otherwise.
    """
    pvars = tuple(get_propostional_variables(tuple(logic.rules)))
    mappings = all_model_valuations_cached(pvars, tuple(model.carrier_set))
    designated = model.designated_values

    for mapping in mappings:
        # Make sure that the model satisfies each of the rules
        for rule in logic.rules:
            # The check only applies if every premise is designated.
            # all() stops at the first premise that is not, so the
            # remaining premises of that rule are never evaluated.
            premises_designated = all(
                evaluate_term(premise, mapping, interpretation) in designated
                for premise in rule.premises
            )
            if not premises_designated:
                continue

            # With the premises designated, make sure the consequent is designated
            consequent_t = evaluate_term(rule.conclusion, mapping, interpretation)
            if consequent_t not in designated:
                return False

    return True



def _arguments_with_new_value(arity, last_new, old_closure):
    """
    Yield every argument tuple of length `arity` that takes at least one
    value from `last_new` and the remaining values from `old_closure`.

    The same tuple can be yielded more than once when it contains
    repeated values.
    """
    # Range over how many of the arguments are newly discovered values
    for num_new in range(1, arity + 1):
        new_parts = combinations_with_replacement(last_new, r=num_new)
        old_parts = combinations_with_replacement(old_closure, r=arity - num_new)
        for new_part, old_part in product(new_parts, old_parts):
            # Every ordering of the combined values is its own argument tuple
            yield from permutations(new_part + old_part)


def model_closure(initial_set: Set[ModelValue], mfunctions: Set[ModelFunction]):
    """
    Given an initial set of model values and a set of model functions,
    compute the complete set of model values that are closed
    under the operations.

    The computation proceeds in rounds. In each round every function is
    applied only to argument tuples that contain at least one value found
    in the previous round, since all other tuples were already tried.
    It stops once a round produces no new value.

    `initial_set` is not modified; a new set is returned. Any iterable of
    hashable values (including a frozenset) is accepted.

    NOTE: functions of arity 0 are never applied, because no argument
    tuple is generated for them.
    """
    # Work on copies so that the caller's set is never mutated
    closure_set: Set[ModelValue] = set(initial_set)
    last_new: Set[ModelValue] = set(closure_set)

    while last_new:
        new_elements: Set[ModelValue] = set()
        old_closure: Set[ModelValue] = closure_set - last_new

        # arity -> argument tuples for this round. Functions that share
        # an arity reuse the same tuples instead of recomputing them.
        cached_args: Dict[int, List[tuple]] = {}

        for mfun in mfunctions:
            if mfun.arity not in cached_args:
                cached_args[mfun.arity] = list(
                    _arguments_with_new_value(mfun.arity, last_new, old_closure)
                )

            for args in cached_args[mfun.arity]:
                element = mfun(*args)
                if element not in closure_set:
                    new_elements.add(element)

        closure_set.update(new_elements)
        last_new = new_elements

    return closure_set