# Copyright 2023 IBM Corp. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from typing import Union, Tuple, List, TypeVar, Iterable

from . import _utils
from .constants import Fact
from lnn.symbolic.logic.leaf_formula import Predicate

import re
import random
import itertools
import numpy as np
from tabulate import tabulate
import matplotlib.pyplot as plt


def split_string_into_groundings(state: str) -> Tuple[str, ...]:
    r"""Split a stringified grounding into a tuple of its parts.

    Parameters
    ----------
    :param state: The state is given as a string value representing the
        groundings; i.e "('T','T')"

    Examples
    --------
    P,Q = Predicate("P","Q")
    x,y = Variables("x","y")
    PQ  = Or(P(x), Q(y))
    PQ.state(("T","T")) returns Fact.TRUE

    Returns
    -------
    Groundings, given as strings, as a tuple of strings; i.e
    "('T','T')" -> ("T","T"). A string without commas yields a 1-tuple.
    """
    # Drop quotes and parentheses, leaving a bare comma-separated list.
    # NOTE: grounding names that themselves contain quotes, parentheses or
    # commas cannot be recovered by this scheme.
    pattern = r'[\'"()]'
    grounding_strings = re.sub(pattern, "", state)
    return tuple(part.strip() for part in grounding_strings.split(","))

def get_binary_truth_table(formulae: dict) -> dict:
    r"""Build a column-oriented truth table for one or more formulae.

    Parameters
    ----------
    :param formulae: Formulae is a dictionary with str:formula format ;i.e {"P or Q": <lnn.symbolic.logic.n_ary_neuron >}

    Returns
    -------
    Truth table in a dictionary format. The ``""`` key holds the groundings
    (ordered by ``alphanumeric_sort``); every other key is a formula name
    mapped to the ``.name`` of that formula's state at each grounding.
    """
    # Evaluate every formula over the same rows: the union of groundings
    # known to any of them.
    # NOTE(review): assumes ``state()`` with no arguments returns a mapping
    # keyed by grounding, as the other table builders in this module rely
    # on — confirm.
    groundings = set()
    for f in formulae.values():
        groundings.update(f.state().keys())
    groundings = alphanumeric_sort(groundings)
    table = {"": list(groundings)}
    for name, formula in formulae.items():
        table[name] = [formula.state(g).name for g in groundings]
    return table

def alphanumeric_sort(iterable: Iterable):
    """Sort strings so that embedded runs of ASCII digits compare numerically.

    ``["a10", "a2", "a1"]`` sorts as ``["a1", "a2", "a10"]`` rather than the
    plain lexicographic ``["a1", "a10", "a2"]``.

    :param iterable: Strings to sort.
    :return: A new sorted list.
    """

    def alphanum_key(key):
        # ``re.split`` with a capturing group alternates text / digit-run
        # chunks, so odd positions are always ASCII digit runs. Converting by
        # position (instead of ``str.isdigit``) keeps keys type-aligned and
        # never calls ``int()`` on Unicode digits such as "²", which
        # ``isdigit`` accepts but ``int`` rejects.
        return [
            int(chunk) if position % 2 else chunk
            for position, chunk in enumerate(re.split("([0-9]+)", key))
        ]

    return sorted(iterable, key=alphanum_key)

def get_ternary_table(formulae: dict, unique_groundings: List[str] = None) -> np.ndarray:
    r"""Build a grid-shaped truth table for a two-operand formula.

    Parameters
    ----------
    :param formulae: Formulae is a dictionary with str:formula format; only
        the first entry is tabulated.
    :param unique_groundings: Values that label both the rows and the
        columns; defaults to ["F", "U", "T"]

    Returns
    -------
    A ``(k + 1, k + 1)`` string array with ``k = len(unique_groundings)``.
    Cell ``[0, 0]`` holds the formula name, row 0 and column 0 hold the
    grounding labels, and cell ``[i + 1, j + 1]`` holds
    ``formula.state((unique_groundings[i], unique_groundings[j])).name``.
    """
    # ``list(...)[0]`` keeps raising IndexError for an empty dict.
    formula, operator = list(formulae.items())[0]
    if unique_groundings is None:
        unique_groundings = ["F", "U", "T"]
    unique_groundings = list(unique_groundings)
    # Build nested Python lists and let NumPy size the string dtype from the
    # data; a fixed-width "<U10" buffer silently truncated longer state names.
    rows = [
        [row] + [operator.state((row, column)).name for column in unique_groundings]
        for row in unique_groundings
    ]
    header = [formula] + unique_groundings
    return np.array([header] + rows)

def get_n_ary_truth_table(formulae: dict, unique_groundings: List[str] = None) -> dict:
    r"""Build a truth table whose rows are keyed by the first grounding element.

    Parameters
    ----------
    :param formulae: Formulae is a dictionary with str:formula format ;i.e {"P or Q": <lnn.symbolic.logic.n_ary_neuron >}
        Only the first entry is tabulated.
    :param unique_groundings: Unique groundings that fill either the rows or columns i.e ["F", "U", "T"]
        Defaults to the sorted set of every element seen in the state keys.

    Returns
    -------
    Truth table in a dictionary format: the formula name maps to
    ``unique_groundings``, and each unique grounding maps to the state names
    of every state key whose first element equals it, in key order.
    """
    # Use the same (first) formula for the header, the keys and the states so
    # the three stay consistent.
    name, formula = list(formulae.items())[0]
    keys = [split_string_into_groundings(key) for key in formula.state().keys()]
    if unique_groundings is None:
        unique_groundings = sorted(set(itertools.chain.from_iterable(keys)))
    truth = {f"{name}": unique_groundings}
    for unique_grounding in unique_groundings:
        # Exact match on the leading element: a prefix test would also pull
        # "10", "11", ... into the row for "1".
        table_columns = [key for key in keys if key[0] == unique_grounding]
        truth[unique_grounding] = [
            formula.state(table_column).name for table_column in table_columns
        ]
    return truth

def pretty_truth_table(formulae: dict, unique_groundings: List[str] = None) -> None:
    r"""Print a truth table for the given formulae using ``tabulate``.

    Parameters
    ----------
    :param formulae: Formulae is a dictionary with str:formula format ;i.e {"P or Q": <lnn.symbolic.logic.n_ary_neuron >}
    :param unique_groundings: Unique groundings that fill either the rows or columns i.e ["F", "U", "T"]
        Only used by the grid layout.

    Returns
    -------
    None; the table is printed to stdout.
    """
    keys = list(list(formulae.values())[0].state().keys())
    # The layout is picked from the length of the first state key: short
    # keys get the column layout, longer ones (e.g. "('T','T')") the
    # operand grid.
    # NOTE(review): a single grounding name of 3+ characters would also take
    # the grid branch — confirm this heuristic is intended.
    if len(keys[0]) <= 2:
        table = get_binary_truth_table(formulae)
        print(tabulate(table, headers="keys", tablefmt="fancy_grid"))
    else:
        table = get_ternary_table(formulae, unique_groundings)
        print(tabulate(table, tablefmt="fancy_grid"))
def generate_truth_table(P: Predicate, Q: Predicate, states=None) -> None:
    r"""Load the same indexed data into two predicates.

    Parameters
    ----------
    P: Predicate P
    Q: Predicate Q
    states: Data that you would like to add to each of the predicates;
        defaults to [Fact.FALSE, Fact.UNKNOWN, Fact.TRUE]

    Returns
    -------
    Nothing, causes side effect on P and Q: grounding ``"i"`` receives
    ``states[i]`` in both predicates.
    """
    if states is None:
        states = [Fact.FALSE, Fact.UNKNOWN, Fact.TRUE]
    data = {f"{i}": state for i, state in enumerate(states)}
    P.add_data(data)
    Q.add_data(data)


def truth_table(n: int, states=None) -> List[Tuple[Fact, ...]]:
    """Return every length-``n`` combination of ``states``.

    Rows follow ``itertools.product`` order. With the default
    ``[Fact.FALSE, Fact.TRUE]`` this is the ``2**n``-row classical table.
    """
    if states is None:
        states = [Fact.FALSE, Fact.TRUE]
    return list(itertools.product(states, repeat=n))


def truth_table_dict(*args: str, states=None):
    """Yield one ``{name: state}`` dict per truth-table row over ``args``."""
    if states is None:
        states = [Fact.FALSE, Fact.TRUE]
    for instance in itertools.product(states, repeat=len(args)):
        yield dict(zip(args, instance))


def fact_to_bool(*fact: Fact) -> Union[Fact, bool, Tuple[bool, ...]]:
    """Map Fact.TRUE / Fact.FALSE to True / False; other values pass through.

    With several arguments, a tuple of the converted values is returned.
    """
    if len(fact) > 1:
        return tuple(map(fact_to_bool, fact))
    if fact[0] is Fact.TRUE:
        return True
    if fact[0] is Fact.FALSE:
        return False
    return fact[0]


def bool_to_fact(*truth: bool) -> Union[Fact, Tuple[Fact, ...]]:
    """Map truthy values to Fact.TRUE and falsy ones to Fact.FALSE.

    With several arguments, a tuple of the converted values is returned.
    """
    if len(truth) > 1:
        return tuple(map(bool_to_fact, truth))
    return Fact.TRUE if truth[0] else Fact.FALSE
def predicate_truth_table(*args: str, arity: int, model, states=None):
    """Create one predicate per name in ``args``, filled with a truth-table column.

    Example::

        predicate_truth_table("p", "q", "r", arity=1, model=model)

    Row ``i`` of ``truth_table(len(args), states)`` is stored under grounding
    ``"i"`` (or ``("i",) * arity`` when ``arity != 1``). Column ``idx`` goes to
    the predicate named ``args[idx]``. Rows are added in a shuffled order, but
    the grounding-to-row mapping itself is not randomised.

    Returns
    -------
    model : Model
    """
    if states is None:
        states = [Fact.FALSE, Fact.TRUE]
    from lnn import Predicate  # noqa: F401

    n = len(args)
    TT = np.array(truth_table(n, states))
    _range = list(range(len(TT)))
    for idx, arg in enumerate(args):
        model[arg] = Predicate(arg, arity=arity)
        random.shuffle(_range)
        for i in _range:
            grounding = f"{i}" if arity == 1 else (f"{i}",) * arity
            truth = TT[i, idx].item()
            model[arg].add_data({grounding: truth})
    return model
def plot_loss(total_loss, losses):
    """Plot the total loss and its per-loss breakdown side by side.

    :param total_loss: pair ``(loss, cumulative_loss)`` of per-epoch series.
    :param losses: items whose ``.value`` strings label the second plot.
    """
    loss, cumulative_loss = total_loss
    fig, axs = plt.subplots(1, 2)
    fig.suptitle("Model Loss")
    axs[0].plot(np.array(loss))
    for ax in axs.flat:
        ax.set(xlabel="Epochs", ylabel="Loss")
    axs[0].legend(["Total Loss"])
    axs[1].plot(np.array(cumulative_loss))
    axs[1].legend([kind.value.capitalize() for kind in losses])


Model = TypeVar("Model")


def plot_params(self: Model):
    """Plot the recorded ``parameter_history`` of every node that has one.

    Per-operand parameters (history entries that are lists) get one legend
    entry per operand, suffixed with ``_<operand>``.
    """
    legend = []
    for node in self.nodes.values():
        if not hasattr(node, "parameter_history"):
            continue
        for param, data in node.parameter_history.items():
            if isinstance(data[0], list):
                operands = list(node.operands)
                legend_idxs = [f"_{operands[i]}" for i in range(len(data[0]))]
            else:
                legend_idxs = [""]
            # NOTE(review): the label fields were blank ("{}") in this copy of
            # the file; ``node.name`` / ``self.name`` are assumed — confirm.
            legend.extend(
                f"{node.name} {_utils.param_symbols[param]}{i}" for i in legend_idxs
            )
            plt.plot(data)
    plt.xlabel("Epochs")
    plt.legend(legend)
    plt.title(f"{self.name} Parameters")


def return1(args: Union[List, Tuple]):
    """Unwrap a single-element sequence; return any other length unchanged."""
    if len(args) == 1:
        return args[0]
    return args