Source code for ibm_watsonx_ai.foundation_models.embeddings.base_embeddings

#  -----------------------------------------------------------------------------------------
#  (C) Copyright IBM Corp. 2024-2025.
#  https://opensource.org/licenses/BSD-3-Clause
#  -----------------------------------------------------------------------------------------

from __future__ import annotations
import warnings
import copy
import importlib
from typing import Any
from abc import ABC, abstractmethod

from ibm_watsonx_ai.wml_client_error import UnexpectedKeyWordArgument


class BaseEmbeddings(ABC):
    """LangChain-like embedding function interface.

    Concrete subclasses implement :meth:`embed_documents` and
    :meth:`embed_query`. An instance can be serialized with :meth:`to_dict`
    and reconstructed with :meth:`from_dict`.
    """

    @abstractmethod
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs.

        :param texts: texts to embed
        :type texts: list[str]

        :return: embedding vectors for ``texts``
        :rtype: list[list[float]]
        """
        raise NotImplementedError()

    @abstractmethod
    def embed_query(self, text: str) -> list[float]:
        """Embed query text.

        :param text: query text to embed
        :type text: str

        :return: embedding vector for ``text``
        :rtype: list[float]
        """
        raise NotImplementedError()

    def to_dict(self) -> dict:
        """Serialize Embeddings.

        Only the class name and the module that defines it are stored;
        subclasses may extend the returned dict with constructor arguments.

        :return: serializes this Embeddings so that it can be reconstructed by ``from_dict`` class method.
        :rtype: dict
        """
        return {"__class__": self.__class__.__name__, "__module__": self.__module__}

    @classmethod
    def from_dict(cls, data: dict, **kwargs: Any) -> BaseEmbeddings | None:
        """Deserialize ``BaseEmbeddings`` into a concrete one using arguments.

        ``data`` is expected to look like the output of :meth:`to_dict`. The
        ``"__class__"`` and ``"__module__"`` entries select the class to
        instantiate, and every remaining entry is passed to its constructor
        as a keyword argument. The caller's ``data`` dict is not modified.

        .. note::
            The module named in ``data`` is imported and the selected
            attribute is called with the remaining entries, so only pass
            data from a trusted source.

        :param data: serialized embeddings, e.g. produced by ``to_dict``
        :type data: dict

        :param kwargs: optional keyword arguments; only ``api_client`` is
            supported. When it is given and the serialized class lives in
            ``ibm_watsonx_ai.foundation_models.embeddings.embeddings``, it is
            passed to the constructor instead of any serialized
            ``credentials``.

        :raises UnexpectedKeyWordArgument: if an unsupported keyword argument is passed
        :raises AttributeError: if the module has no attribute named by ``"__class__"``

        :return: concrete Embeddings or None if data is incorrect
        :rtype: BaseEmbeddings | None
        """
        supported_kwargs = ["api_client"]
        if unsupported_kwargs := set(kwargs) - set(supported_kwargs):
            # Report a single offending name; ``min`` makes the choice
            # deterministic instead of depending on set iteration order.
            kword = min(unsupported_kwargs)
            raise UnexpectedKeyWordArgument(
                kword,
                reason=f"{kword} is not supported as a keyword argument. Supported kwargs: {supported_kwargs}",
            )
        api_client = kwargs.get("api_client")

        if not isinstance(data, dict):
            return None

        # Work on a copy so popping the metadata keys below does not mutate
        # the caller's dict.
        data = copy.deepcopy(data)
        class_type = data.pop("__class__", None)
        module_name = data.pop("__module__", None)
        if not module_name:
            return None

        module = importlib.import_module(module_name)
        if not class_type:
            return None

        try:
            # Separate name so the ``cls`` parameter is not rebound.
            target_cls = getattr(module, class_type)
        except AttributeError as err:
            raise AttributeError(
                f"Module: {module} has no attribute {class_type}"
            ) from err
        if not target_cls:
            return None

        if (
            module_name == "ibm_watsonx_ai.foundation_models.embeddings.embeddings"
            and api_client is not None
        ):
            # An explicitly supplied client replaces serialized credentials.
            data.pop("credentials", None)
            data["api_client"] = api_client

        # Silence only DeprecationWarning during construction. This is
        # deliberately not ``record=True``, which would also capture and
        # discard every other warning the constructor emits.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=DeprecationWarning)
            return target_cls(**data)