##
# Copyright 2023 IBM Corp. All Rights Reserved.
#
# SPDX-License-Identifier: Apache-2.0
##

# flake8: noqa: E501

import logging
from typing import Union, Tuple, Set, Dict

import torch

from .connective_formula import _ConnectiveFormula
from .formula import Formula
from .node_activation import _NodeActivation
from .. import _gm
from ... import _utils
from ...constants import Fact

_utils.logger_setup()


class _NAryOperator(_ConnectiveFormula):
    r"""N-ary connective operator.

    Base for connectives whose arity is not fixed by the subclass: the arity
    forwarded to ``_ConnectiveFormula`` is the number of operand formulae
    given to the constructor.
    """

    def __init__(self, *formula, **kwds):
        # The arity is derived from the operands actually supplied.
        operand_count = len(formula)
        super().__init__(*formula, arity=operand_count, **kwds)
class Congruent(_NAryOperator):
    r"""Symbolic Congruency

    This is used to define nodes that are symbolically equivalent to one
    another (despite the possibility of neural differences).

    Because the operands are treated as equivalent, upward inference combines
    their bounds by intersection (largest lower bound, smallest upper bound),
    and downward inference pushes the operator's bounds onto every operand.
    """

    def __init__(self, *formulae: Formula, **kwds):
        # Assigned before the parent constructor runs; presumably the parent
        # reads it while initialising (e.g. for naming) -- TODO confirm.
        self.connective_str = "≅"
        super().__init__(*formulae, **kwds)
        # Build the activation with the same propositional flag the formula
        # resolved to, unless the caller passed one explicitly.
        kwds.setdefault("propositional", self.propositional)
        self.neuron = _NodeActivation()(**kwds.get("activation", {}), **kwds)

    def __contains__(self, item) -> bool:
        """Return whether ``item`` is one of ``self.congruent_nodes``.

        ``congruent_nodes`` is not assigned in this class; it is expected to
        be provided elsewhere (parent class or model) -- verify.
        """
        return item in self.congruent_nodes

    def add_data(self, facts: Union[Fact, Tuple, Set, Dict]):
        """Should not be called by the user.

        Congruent nodes evaluate their facts from their operands via
        ``upward()``, so direct data assignment is rejected.

        Raises
        ------
        AttributeError
            Always.
        """
        raise AttributeError(
            "Should not be called directly by the user, instead use "
            "`congruent_node.upward()` to evaluate the facts from the operands"
        )

    def upward(
        self, groundings: Set[Union[str, Tuple[str, ...]]] = None, **kwds
    ) -> float:
        r"""Upward inference from the operands to the operator.

        Parameters
        ----------
        groundings : set of str or tuple of str, optional
            restrict upward inference to specific groundings (rows in the
            truth table). If unspecified, all groundings are considered.

        Returns
        -------
        tightened_bounds : float
            The amount of bounds tightening or new information that is
            learned by the inference step, as reported by
            ``self.neuron.aggregate_bounds``. ``None`` when
            ``_gm.upward_bounds`` returns ``None`` (contradiction arresting).
        """
        upward_bounds = _gm.upward_bounds(self, self.operands, groundings)
        if upward_bounds is None:  # contradiction arresting
            return
        input_bounds, groundings = upward_bounds

        if self.propositional:
            grounding_rows = None
        elif groundings is None:
            grounding_rows = self.grounding_table.values()
        else:
            grounding_rows = [self.grounding_table.get(g) for g in groundings]

        # Assumes dim -2 selects the lower (0) / upper (1) bound and the last
        # dim indexes operands, as the indexing below implies -- confirm.
        # Congruent operands share one truth value, so their combined bounds
        # are the intersection: the largest lower bound and the SMALLEST
        # upper bound. (Taking the max of the upper bounds would prevent a
        # False operand from ever tightening the operator.)
        input_bounds = torch.stack(
            [
                input_bounds[..., 0, :].max(dim=-1).values,
                input_bounds[..., 1, :].min(dim=-1).values,
            ],
            dim=-1,
        )
        result = self.neuron.aggregate_bounds(grounding_rows, input_bounds)
        if result:
            logging.info(
                "↑ BOUNDS UPDATED "
                f"TIGHTENED:{result} "
                f"FOR:'{self.name}' "
                f"FORMULA:{self.formula_number} "
            )
        return result

    def downward(
        self,
        index: int = None,
        groundings: Set[Union[str, Tuple[str, ...]]] = None,
        **kwds,
    ) -> Union[torch.Tensor, None]:
        r"""Downward inference from the operator to the operands.

        Parameters
        ----------
        index : int, optional
            restricts downward inference to an operand at the specified index.
            If unspecified, all operands are updated.
        groundings : set of str or tuple of str, optional
            restrict downward inference to specific groundings (rows in the
            truth table). If unspecified, all groundings are considered.

        Returns
        -------
        tightened_bounds : float or torch.Tensor
            The summed amount of bounds tightening or new information that is
            learned by the inference step across the updated operands.
            ``None`` when ``_gm.downward_bounds`` returns ``None``
            (contradiction arresting).
        """
        downward_bounds = _gm.downward_bounds(self, self.operands, groundings)
        if downward_bounds is None:  # contradiction arresting
            return
        parent, _, groundings = downward_bounds

        op_indices = (
            enumerate(self.operands)
            if index is None
            else [(index, self.operands[index])]
        )
        result = 0
        for op_index, op in op_indices:
            if op.propositional:
                op_grounding_rows = None
            elif groundings is None:
                op_grounding_rows = op.grounding_table.values()
            else:
                # Project each operator grounding onto this operand's
                # variable slots to look up the matching operand row.
                slots = self.operand_map[op_index]
                op_grounding_rows = [
                    op.grounding_table.get(
                        tuple(str(g.partial_grounding[slot]) for slot in slots)
                    )
                    for g in groundings
                ]

            # Congruent operands receive the operator's bounds unchanged.
            op_aggregate = op.neuron.aggregate_bounds(op_grounding_rows, parent)
            if op_aggregate:
                logging.info(
                    "↓ BOUNDS UPDATED "
                    f"TIGHTENED:{op_aggregate} "
                    f"FOR:'{op.name}' "
                    f"FROM:'{self.name}' "
                    f"FORMULA:{op.formula_number} "
                    f"PARENT:{self.formula_number} "
                )
            result = result + op_aggregate
        return result