Source code for kif_lib.model.term.template

# Copyright (C) 2024 IBM Corp.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from ... import itertools
from ...typing import Any, cast, ClassVar, Iterator, Location, override
from .term import ClosedTerm, OpenTerm, Term, Theta
from .variable import Variable


[docs] class Template(OpenTerm): """Abstract base class for templates.""" #: Object class associated with this template class. object_class: ClassVar[type[ClosedTerm]] def _preprocess_args(self, args: tuple[Any, ...]) -> tuple[Any, ...]: return self._normalize_args(super()._preprocess_args(args)) def _normalize_args(self, args: tuple[Any, ...]) -> tuple[Any, ...]: ### # A template is *normal* iff for every two variables ``x``, ``y`` in # the template, ``x.name == y.name`` implies ``type(x) == type(y)``. # # We normalize a template by normalizing its arguments, i.e., by # replacing homonymous occurrences of a same variable by the # occurrence with the most specific type. For example, if # ItemVariable('v') and EntityVariable('v') occur in the arguments # of the template, we replace all occurrences of the latter by # ItemVariable('v'). # # Note that normalization is only possible if the homonymous # variables are inter-coercible. This method will throw an error if # that is not the case. ### vars = frozenset(itertools.chain(*map( OpenTerm.get_variables, filter(self.is_open, args)))) most_specific: dict[str, Variable] = {} for var in vars: if var.name not in most_specific: most_specific[var.name] = var else: cur = most_specific[var.name] assert var != cur # as there are no repetitions in vars most_specific[var.name] = cur.check(var, type(self)) theta: dict[Variable, Variable] = {} for var in vars: if var.name in most_specific and most_specific[var.name] != var: theta[var] = most_specific[var.name] if not theta: return args else: return tuple(map( lambda x: x._instantiate(theta, False) if self.is_open(x) else x, args)) @override def _iterate_variables(self) -> Iterator[Variable]: return self._traverse(lambda x: isinstance(x, Variable), self.is_open)
[docs] @override def instantiate(self, theta: Theta, coerce: bool = True) -> Term: return cast(Term, super().instantiate(theta, coerce))
@override def _instantiate( self, theta: Theta, coerce: bool, function: Location | None = None, name: str | None = None, position: int | None = None ) -> Term: return self.__class__(*map( lambda arg: arg._instantiate( theta, coerce, function, name, position) if isinstance(arg, OpenTerm) else arg, self.args))