Source code for lnn.model

##
# Copyright 2023 IBM Corp. All Rights Reserved.
#
# SPDX-License-Identifier: Apache-2.0
##

# flake8: noqa: E501

import itertools as itls
from collections.abc import Iterable
from typing import Union, Dict, Tuple, List

from . import viz
from . import _exceptions, _utils
from .constants import Fact, World, Direction, Loss
from .symbolic.logic import Proposition, Predicate, Formula

import torch
import random
import logging
import datetime
import networkx as nx
from tqdm import tqdm
from torch import nn
import matplotlib.pyplot as plt

_utils.logger_setup(flush=True)


class Model(nn.Module):
    r"""Creates a container for logical reasoning and neural learning.

    Models define a theory or a collection of formulae, with additional
    reasoning and learning functionality that can be applied to each formula
    in the model. In contrast to standard FOL, where the existence of a
    formula symbol assumes a `True` truth value, the data associated with LNN
    formulae can take on any classical truth (Fact) or belief bounds (a range
    of real values). Models are also dynamic, instantiated as empty
    containers which are later populated with knowledge rules and data. This
    additionally allows LNNs to operate in dynamic environments whereby the
    knowledge acquired may grow as new information becomes available.

    Parameters
    ----------
    name : str, optional
        Name of contextual model, defaults to "Model".

    Attributes
    ----------
    graph : nx.DiGraph
        Directed graph that connects nodes, pointing from operator to operand
        nodes.
    nodes : dict
        Each formula is keyed by a formula_number, with the value as the
        formula object.
    query : Formula
        A formula node that is set as the current query - allows the model to
        be used in QA/theorem proving whereby inference is governed towards
        solving the query.

    Examples
    --------
    ```python
    # define the predicates
    x, y = Variables("x", "y")
    Smokes, Asthma, Cough = Predicates("Smokes", "Asthma", "Cough")
    Friends = Predicate("Friends", arity=2)

    # define the connectives/quantifiers
    Smokers_have_friends = And(Smokes(x), Friends(x, y))
    Asthmatic_smokers_cough = (
        Exists(x, Implies(And(Smokes(x), Asthma(x)), Cough(x))))
    Smokers_befriend_smokers = (
        Forall(x, y, Implies(Smokers_have_friends(x, y), Smokes(y))))

    # add root formulae to model
    model = Model()
    model.add_knowledge(
        Asthmatic_smokers_cough,
        Smokers_befriend_smokers)

    # add data to the model
    model.add_data({
        Smokes: {
            "Person_1": Fact.TRUE,
            "Person_2": Fact.UNKNOWN,
            "Person_3": Fact.UNKNOWN},
        Friends: {
            ("Person_1", "Person_2"): Fact.TRUE,
            ("Person_2", "Person_3"): Fact.UNKNOWN}})

    # reason over the model
    model.infer()

    # verify the model outputs
    model.print()
    ```

    """

    def __init__(
        self,
        knowledge: Union[Formula, Iterable[Formula]] = None,
        data: Dict = None,
        name: str = "Model",
    ):
        super(Model, self).__init__()
        self.graph = nx.DiGraph()
        self.nodes = dict()
        self.node_names = dict()
        self.node_structures = dict()
        self.num_formulae = 0
        self.name = name
        self.query = None
        self._converge = None
        if knowledge:
            if isinstance(knowledge, Iterable):
                self.add_knowledge(*knowledge)
            else:
                self.add_knowledge(knowledge)
        if data:
            self.add_data(data)
        logging.info(f" {name} {datetime.datetime.now()} ".join(["*" * 22] * 2))

    def __getitem__(
        self, formula: Union[Formula, int]
    ) -> Union[Formula, List[Formula]]:
        r"""Returns a formula object from the model.

        If the formula is in the model, return the formula. For backward
        compatibility, if multiple formulae exist in the model with the same
        structure, return a list of all the relevant nodes.

        """
        if isinstance(formula, int):
            return self.nodes[formula]
        if formula.formula_number is not None and formula.formula_number in self.nodes:
            return self.nodes[formula.formula_number]
        if formula.structure in self.node_structures:
            result = self.node_structures[formula.structure]
            return (
                result
                if len(self.node_structures[formula.structure]) > 1
                else result[0]
            )

    def __contains__(self, formula: Formula):
        if formula.formula_number and formula.formula_number in self.nodes:
            return True
        return formula.structure in self.node_structures
    def set_query(self, formula: Formula, world=World.OPEN, converge=False):
        r"""Inserts a query node into the model and maintains a handle on the node.

        Parameters
        ----------
        formula : Formula
            The formula to set as the current query.
        world : World
            Default behavior of the formula. If unspecified, assumes the open
            world assumption.

        Notes
        -----
        The query formula will be added to the model and will not be removed,
        even if a new query is defined using this function.

        """
        self.add_knowledge(formula, world=world)
        self.query = formula
        self._converge = converge
    def infer_query(self, *args, **kwds) -> Tuple[Tuple[int, int], torch.Tensor]:
        r"""Reasons only over the stored query.

        Equivalent to calling [model.infer](#lnn.Model.infer) with the source
        node set to [model.query](#lnn.Model.set_query).

        """
        if self.query:
            return self.infer(*args, **kwds, source=self.query)
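    # A minimal usage sketch for `set_query`/`infer_query` (the predicate,
    # variable and grounding names below are illustrative, not part of this
    # module):
    #
    #   x = Variable("x")
    #   Smokes = Predicate("Smokes")
    #   query = Exists(x, Smokes(x))
    #   model = Model()
    #   model.set_query(query)  # also inserts `query` into the model
    #   model.add_data({Smokes: {"Person_1": Fact.TRUE}})
    #   model.infer_query()  # reasons over the subgraph rooted at `query`
    #   print(query.state())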
    def add_formulae(self, *args, **kwds):
        raise NameError("`add_formulae` is deprecated, use `add_knowledge` instead")
    def add_knowledge(self, *formulae: Formula, world: World = None):
        r"""Extend the model to include additional formulae.

        Only root-level formulae explicitly need to be added to the model.
        Unless otherwise specified, each root formula follows the open world
        assumption.

        Examples
        --------
        ```python
        P, Q = Predicates("P", "Q")
        model.add_knowledge(P, Q)
        ```
        creates the predicates and inserts them into the model

        or

        ```python
        model = Model()
        x, y, z = Variables("x", "y", "z")
        P1 = Predicate("P1")
        P2 = Predicate("P2", 2)
        P3 = Predicate("P3", 3)
        model.add_knowledge(
            And(P1(x), P2(x, y)),
            Implies(P2(x, y), P3(x, y, z))
        )
        ```
        inserts the formula roots into the model and appropriately includes
        all subformulae in the scope of the model as well.

        Any formulae that directly require inquiry should first be created in
        the user scope and thereafter inserted into the model for reference
        after reasoning/learning, e.g.

        ```python
        model = Model()
        x, y, z = Variables("x", "y", "z")
        P1 = Predicate("P1")
        P2 = Predicate("P2", 2)
        P3 = Predicate("P3", 3)
        my_and = And(P1(x), P2(x, y))
        model.add_knowledge(
            my_and,
            Implies(P2(x, y), P3(x, y, z))
        )
        ...
        model.infer()
        ...
        my_and.state()
        ```

        """
        self._add_knowledge(*formulae, world=world)
    def add_propositions(self, *names: str, **kwds):
        ret = []
        for name in names:
            P = Proposition(name, **kwds)
            self.add_knowledge(P)
            ret.append(P)
        return ret[0] if len(ret) == 1 else ret

    def add_predicates(self, arity: int, *names: str, **kwds):
        ret = []
        for name in names:
            P = Predicate(name, arity=arity, **kwds)
            self.add_knowledge(P)
            ret.append(P)
        return ret[0] if len(ret) == 1 else ret

    def replace_graph_edge(
        self, old_edge: Tuple[Formula, Formula], new_edge: Tuple[Formula, Formula]
    ):
        self.graph.remove_edge(*old_edge)
        self.graph.add_edge(*new_edge)

    def _add_knowledge(self, *formulae: Formula, world: World = None):
        for f in formulae:
            _exceptions.AssertFormula(f)
            self.graph.add_node(f)
            self.graph.add_edges_from(f.edge_list)
            self.num_formulae = f.set_formula_number(self.num_formulae) + 1
        for node in self.graph.nodes:
            if node.structure in self.node_structures:
                if node not in self.node_structures[node.structure]:
                    self.node_structures[node.structure].append(node)
            else:
                self.node_structures.update({node.structure: [node]})
            if node.name in self.node_names:
                if node not in self.node_names[node.name]:
                    self.node_names[node.name].append(node)
            else:
                self.node_names.update({node.name: [node]})
            self.nodes[node.formula_number] = node
        if world:
            for f in formulae:
                f.reset_world(world)

    def add_facts(self, *args, **kwds):
        raise NameError("`add_facts` is deprecated, use `add_data` instead")
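    # `add_propositions`/`add_predicates` construct nodes and insert them in
    # a single call, returning the created object(s); a brief sketch with
    # illustrative names:
    #
    #   model = Model()
    #   A, B = model.add_propositions("A", "B")       # two Propositions
    #   Friends = model.add_predicates(2, "Friends")  # one binary Predicate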
    def add_data(
        self,
        data: Dict[
            Formula,
            Union[
                Union[bool, Fact, float, Tuple[float, float]],
                Dict[
                    Union[str, Tuple[str, ...]],
                    Union[bool, Fact, float, Tuple[float, float]],
                ],
            ],
        ],
    ):
        r"""Add data to select formulae in the model, in the form of classical
        facts or belief bounds.

        Data given as a Fact or belief bounds assumes a propositional formula.
        Data given in a dict assumes a first-order logic formula, keyed by the
        grounding, with a value given as a Fact or belief bounds.

        Parameters
        ----------
        data : a dict of Fact, belief bounds or dict
            The dict is keyed by the formula for which data is to be added,
            with the truths as the value. For propositional formulae, truths
            are given as either Facts or belief bounds. These beliefs can be
            given as a bool, float or a float-range, i.e. a tuple of 2 floats.
            For first-order logic formulae, input truths are given as a dict.
            This is further keyed by the grounding (a str for unary formulae
            or a tuple of strings for larger arities), with values also as
            Facts or bounds on beliefs.

        Examples
        --------
        ```python
        # propositional
        P = Proposition("Person")
        model.add_data({
            P: Fact.TRUE
        })
        ```
        ```python
        # first-order logic
        Person = Predicate("Person")
        BD = Predicate("Birthdate", 2)
        model.add_data({
            Person: {
                "Barack Obama": Fact.TRUE,
                "Bo": (.1, .4)},
            BD: {
                ("Barack Obama", "04 August 1961"): Fact.TRUE,
                ("Bo", "09 October 2008"): (.6, .75)}
        })
        ```

        Warning
        -------
        Assumes that the formulae have already been inserted into the model,
        see [add_knowledge](https://ibm.github.io/LNN/lnn/LNN.html#lnn.Model.add_knowledge)
        for more details.

        """
        for formula, fact in data.items():
            if not isinstance(formula, Formula):
                raise TypeError(
                    "formula expected of type Formula, received "
                    f"{formula.__class__.__name__}"
                )
            _exceptions.AssertFormulaInModel(self, formula)
            if formula.propositional:
                _exceptions.AssertBounds(fact)
            else:
                _exceptions.AssertFOLFacts(fact)
            formula.add_data(fact)
    def add_labels(
        self,
        labels: Dict[
            Formula,
            Union[
                Union[Fact, Tuple[float, float]],
                Dict[Union[str, Tuple[str, ...]], Union[Fact, Tuple[float, float]]],
            ],
        ],
    ):
        r"""Add labels to select formulae in the model, in the form of
        classical facts or belief bounds.

        Labels given as a Fact or belief bounds assume a propositional
        formula. Labels given in a dict assume a first-order logic formula,
        keyed by the grounding, with a value given as a Fact or belief bounds.

        Parameters
        ----------
        labels : a dict of Fact, belief bounds or dict
            The dict is keyed by the formula for which labels are to be
            added, with the truths as the value. For propositional formulae,
            truths are given as either Facts or belief bounds (a tuple of 2
            floats). For first-order logic formulae, input truths are given
            as a dict. This is further keyed by the grounding (a str for
            unary formulae or a tuple of strings for larger arities), with
            values also as Facts or bounds on beliefs.

        Examples
        --------
        ```python
        # propositional
        P = Proposition("Person")
        model.add_labels({
            P: Fact.TRUE
        })
        ```
        ```python
        # first-order logic
        Person = Predicate("Person")
        BD = Predicate("Birthdate", 2)
        model.add_labels({
            Person: {
                "Barack Obama": Fact.TRUE,
                "Bo": (.1, .4)},
            BD: {
                ("Barack Obama", "04 August 1961"): Fact.TRUE,
                ("Bo", "09 October 2008"): (.6, .75)}
        })
        ```

        Warning
        -------
        Assumes that the formulae have already been inserted into the model,
        see [add_knowledge](https://ibm.github.io/LNN/lnn/LNN.html#lnn.Model.add_knowledge)
        for more details.

        """
        for formula, label in labels.items():
            _exceptions.AssertFormulaInModel(self, formula)
            if formula.propositional:
                _exceptions.AssertBounds(label)
            else:
                _exceptions.AssertFOLFacts(label)
            formula.add_labels(label)
    def _traverse_execute(
        self,
        func: str,
        direction: Direction = Direction.UPWARD,
        source: Formula = None,
        **kwds,
    ):
        r"""Traverse over the subgraph and execute an operation per node,
        starting from the source.

        Traverses through the graph from `source` in the given `direction`
        and executes `func` at each node.

        """
        _exceptions.AssertValidDirection(direction)
        nodes = None
        if direction is Direction.UPWARD:
            nodes = list(nx.dfs_postorder_nodes(self.graph, source))
        elif direction is Direction.DOWNWARD:
            nodes = list(reversed(list(nx.dfs_postorder_nodes(self.graph, source))))

        coalesce = torch.tensor(0.0)
        for node in nodes:
            val = getattr(node, func)(**kwds) if hasattr(node, func) else None
            coalesce = coalesce + val if val is not None else coalesce
        if coalesce and func in [d.value.lower() for d in Direction]:
            logging.info(f"{direction.value} INFERENCE RESULT:{coalesce}")
        return coalesce
    def infer(
        self,
        direction: Direction = None,
        source: Formula = None,
        max_steps: int = None,
        **kwds,
    ) -> Tuple[Tuple[int, int], torch.Tensor]:
        r"""Reasons over all possible inferences until convergence.

        Parameters
        ----------
        direction : {Direction.UPWARD, Direction.DOWNWARD}, optional
            Can be specified as either UPWARD or DOWNWARD inference, in which
            case a single pass of that direction will be applied. If
            unspecified, defaults to the LNN naive inference strategy of
            doing inference until convergence.
        source : node, optional
            Specifies the starting node for [depth-first search traversal](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.traversal.depth_first_search.dfs_postorder_nodes.html#networkx.algorithms.traversal.depth_first_search.dfs_postorder_nodes).
            Specifying a node here will compute reasoning (until convergence)
            on the subgraph, with the specified source as the root of the
            subgraph.
        max_steps : int, optional
            Limits the inference to a specified number of passes of the naive
            traversal strategy. If unspecified, the steps will not be
            limited, i.e. inference will take place until convergence.

        Returns
        -------
        (steps, facts_inferred) : Tuple[tuple of 2 ints, torch.Tensor]

        """
        return self._infer(
            direction=direction,
            source=source,
            max_steps=max_steps,
            **kwds,
        )
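    # A brief sketch of the `infer` entry points (`my_root` is an
    # illustrative formula assumed to already be in the model):
    #
    #   steps, facts_inferred = model.infer()     # reason until convergence
    #   model.infer(source=my_root, max_steps=3)  # bounded passes over the
    #                                             # subgraph rooted at my_root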
    def _infer(
        self,
        direction: Direction = None,
        source: Formula = None,
        max_steps: int = None,
        **kwds,
    ) -> Tuple[Tuple[int, int], torch.Tensor]:
        r"""Implementation of model inference."""
        direction = (
            [Direction.UPWARD, Direction.DOWNWARD] if not direction else [direction]
        )
        converged = False
        additional_axioms, steps, facts_inferred = 0, 0, 0
        while not converged:
            if (
                self.query
                and self.query.is_classically_resolved
                and not self._converge
            ):
                logging.info("=" * 22)
                logging.info(
                    f"QUERY PROVED AS {self.query.world_state(True)} for "
                    f"'{self.query.name}'"
                )
                break
            logging.info("-" * 22)
            logging.info(f"REASONING STEP:{steps}")
            bounds_diff = 0.0
            for d in direction:
                bounds_diff += self._traverse_execute(
                    d.value.lower(), d, source, **kwds
                )
            converged_bounds = (
                True
                if direction in ([[Direction.UPWARD], [Direction.DOWNWARD]])
                else bounds_diff <= 1e-7
            )
            if converged_bounds:
                converged = True
                logging.info("NO UPDATES AVAILABLE, TRYING A NEW AXIOM")
            facts_inferred += bounds_diff
            steps += 1
            if max_steps and steps >= max_steps:
                break
        logging.info("=" * 22)
        logging.info(
            f"INFERENCE CONVERGED WITH {facts_inferred} BOUNDS "
            f"UPDATES IN {steps} REASONING STEPS "
        )
        logging.info("*" * 78)
        return steps, facts_inferred
    def upward(self, **kwds):
        r"""Performs upward inference for each node in the model from leaf to root."""
        return self.infer(Direction.UPWARD, **kwds)
    def downward(self, **kwds):
        r"""Performs downward inference for each node in the model from root to leaf."""
        return self.infer(Direction.DOWNWARD, **kwds)
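    # `upward()` and `downward()` are shorthands for a single directional
    # pass, e.g. `model.upward()` behaves like `model.infer(Direction.UPWARD)`.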
    def train(self, losses: Union[Loss, List[Loss], Dict[Loss, float]], **kwds):
        r"""Train the model.

        Reasons across the model until convergence using the standard
        inference strategy - equivalent to running a NN in the forward
        direction. At the end of each reasoning pass, losses are calculated
        according to a predefined or custom loss and model parameters are
        updated. An epoch constitutes all computation until parameters take a
        step.

        Parameters
        ----------
        losses : Loss, list or dict of losses
            Predefined losses expected from the fixed Loss constants. If
            given in dict form, the coefficient of each loss can be specified
            as a float value. The value can alternatively specify additional
            parameters for each loss calculation using a dict.
        optimizer : pytorch optimizer, optional
            Custom optimizers should be instantiated with the model
            parameters using `model.parameters()`. If unspecified, defaults
            to [Adam](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam).
        learning_rate : float, optional
            If unspecified, defaults to 5e-2.
        epochs : int, optional
            Number of training epochs. If unspecified, trains for 300 epochs.
        pbar : bool, optional
            Prints out a tqdm training progress bar. If unspecified, no
            progress bar is printed.

        Returns
        -------
        (total_loss, inference_history) : Tuple[Tuple[List, List], List]
            `total_loss` is a tuple of 2 values: first the `running_loss`, a
            list of the summed loss at the end of each epoch; then the
            `loss_history`, a list of the individual loss components as
            specified by the `losses` argument. `inference_history` is a list
            of the number of bounds updates made by inference at each epoch.

        Examples
        --------
        ```python
        # construct the model from formulae
        model = Model()
        p1, p2 = Predicates("P1", "P2")
        x = Variable("x")
        AB = And(p1(x), p2(x))
        model.add_knowledge(AB)

        # add data to the model
        model.add_data({
            p1: {
                "0": Fact.TRUE,
                "1": Fact.TRUE,
                "2": Fact.FALSE,
                "3": Fact.FALSE},
            p2: {
                "0": Fact.TRUE,
                "1": Fact.FALSE,
                "2": Fact.TRUE,
                "3": Fact.FALSE},
        })

        # add supervisory targets
        model.add_labels({
            AB: {
                "0": Fact.TRUE,
                "1": Fact.FALSE,
                "2": Fact.TRUE,
                "3": Fact.FALSE},
        })

        # train the model and output results
        model.train(losses=Loss.SUPERVISED)
        model.print(params=True)
        ```

        """
        optimizer = kwds.get(
            "optimizer",
            torch.optim.Adam(
                kwds.get("parameters", self.parameters()),
                lr=kwds.get("learning_rate", 5e-2),
            ),
        )
        running_loss, loss_history, inference_history = [], [], []
        for epoch in tqdm(
            range(int(kwds.get("epochs", 3e2))),
            desc="training epoch",
            disable=not kwds.get("pbar", False),
        ):
            optimizer.zero_grad()
            if epoch > 0:
                logging.info(" PARAMETER STEP ".join(["#" * 31] * 2))
            self.reset_bounds()
            self.increment_param_history(kwds.get("parameter_history"))
            _, facts_inferred = self.infer(**kwds)
            loss_fn = self.loss_fn(losses)
            loss = sum(loss_fn)
            if not loss.grad_fn:
                break
            if loss and len(loss_fn) > 1:
                logging.info(f"TOTAL LOSS: {loss}")
            loss.backward()
            optimizer.step()
            self._project_params()
            running_loss.append(loss.item())
            loss_history.append([L.clone().detach().tolist() for L in loss_fn])
            inference_history.append(facts_inferred.item())
            if loss <= 1e-7 and kwds.get("stop_at_convergence", True):
                break
        self.reset_bounds()
        self.infer(**kwds)
        self.increment_param_history(kwds.get("parameter_history"))
        return (running_loss, loss_history), inference_history
    def parameters(self):
        result = list(
            itls.chain.from_iterable([n.parameters() for n in self.nodes.values()])
        )
        return result
    def parameters_grouped_by_neuron(self):
        result = list()
        for n in self.nodes.values():
            param_group = dict()
            param_group["params"] = list()
            param_group["param_names"] = list()
            for name, param in n.named_parameters():
                param_group["params"].append(param)
                param_group["param_names"].append(name)
            param_group["neuron_type"] = n.__class__.__name__
            result.append(param_group)
        return result
    def named_parameters(self):
        result = dict()
        for n in self.nodes.values():
            result.update(
                {f"{n}.{name}": param for name, param in n.named_parameters()}
            )
        return result
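    # The parameter accessors above allow training with a custom optimizer;
    # a minimal sketch (SGD here is an arbitrary choice, not the default):
    #
    #   import torch
    #   sgd = torch.optim.SGD(model.parameters(), lr=1e-2)
    #   model.train(losses=Loss.SUPERVISED, optimizer=sgd)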
    def loss_fn(self, losses):
        if losses is None:
            raise Exception(
                "no loss function given, "
                f"expected losses from the following {[l.name for l in Loss]}"
            )
        elif isinstance(losses, Loss):
            losses = [losses]
        elif isinstance(losses, list):
            losses = {c: None for c in losses}
        result = list()
        for loss in losses:
            _exceptions.AssertLossType(loss)
            if loss == Loss.CUSTOM:
                if not isinstance(losses[loss], dict):
                    raise TypeError(
                        "custom losses expected as a dict with keys as "
                        "name of the loss and values as function "
                        "definitions"
                    )
                for loss_fn in losses[loss].values():
                    coalesce = torch.tensor(0.0)
                    for node in list(nx.dfs_postorder_nodes(self.graph)):
                        coalesce = coalesce + loss_fn(node)
                    result.append(coalesce)
            else:
                kwds = (
                    losses[loss]
                    if isinstance(losses[loss], dict)
                    else {"coeff": losses[loss]}
                )
                result.append(
                    self._traverse_execute(f"_{loss.value.lower()}_loss", **kwds)
                )
            if result[-1]:
                logging.info(f"{loss.value.upper()} LOSS {result[-1]}")
        return result

    def print(
        self,
        source: Formula = None,
        header_len: int = 50,
        roundoff: int = 5,
        params: bool = False,
        grads: bool = False,
        numbering: bool = False,
    ):
        n = header_len + 25
        print("\n" + "*" * n + f'\n{"":<{n // 2 - 5}}LNN {self.name}\n')
        self._traverse_execute(
            "print",
            Direction.DOWNWARD,
            source=source,
            header_len=header_len,
            roundoff=roundoff,
            params=params,
            grads=grads,
            numbering=numbering,
        )
        print("*" * n)

    def plot_graph(
        self, formula_number: bool = False, edge_variables: bool = False, **kwds
    ):
        options = {
            "with_labels": False,
            "arrows": False,
            "edge_color": "#d0e2ff",
            "node_color": "#ffffff",
            "node_size": 16,
            "font_size": 9,
        }
        options.update(kwds)
        pos = viz.get_pos(self)
        nx.draw(self.graph, pos, **options)
        nx.draw_networkx_labels(
            self.graph,
            pos,
            dict(
                [
                    (node, node.formula_number)
                    if formula_number
                    else (node, node.connective_str)
                    if hasattr(node, "connective_str")
                    else (node, node.name)
                    for node in self.graph
                ]
            ),
        )
        if edge_variables:
            labels = {
                edge: _utils.list_to_str(
                    edge[0].operand_map[edge[0].operands.index(edge[1])]
                )
                for edge in self.graph.edges
                if isinstance(edge[1], Predicate)
            }
            nx.draw_networkx_edge_labels(
                self.graph,
                pos,
                labels,
            )
        plt.show()

    def flush(self):
        self._traverse_execute("flush")

    def reset_bounds(self):
        self._traverse_execute("reset_bounds")

    def _project_params(self):
        self._traverse_execute("project_params")

    def increment_param_history(self, parameter_history):
        if parameter_history:
            self._traverse_execute(
                "increment_param_history", parameter_history=parameter_history
            )

    def has_contradiction(self):
        return any(node.is_contradiction() for node in self.nodes.values())

    @property
    def shape(self):
        groundings = sum(
            [
                1 if node.propositional else len(node.groundings)
                for node in self.nodes.values()
            ]
        )
        return [len(self.nodes), groundings]
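
# A short sketch of the inspection utilities on a populated model (assumes
# `model` has been built up with knowledge and data as in the examples above):
#
#   model.print(params=True)               # formatted printout per node
#   model.plot_graph(formula_number=True)  # matplotlib view of the DAG
#   if model.has_contradiction():
#       print("contradictory bounds detected")
#   num_nodes, num_groundings = model.shape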