Source code for ibm_watsonx_ai.foundation_models.embeddings.embeddings

#  -----------------------------------------------------------------------------------------
#  (C) Copyright IBM Corp. 2024.
#  https://opensource.org/licenses/BSD-3-Clause
#  -----------------------------------------------------------------------------------------


from __future__ import annotations
import os
from typing import TypeAlias, TYPE_CHECKING, Any
from concurrent.futures import ThreadPoolExecutor
from functools import reduce, partial
from enum import Enum

import httpx

from ibm_watsonx_ai.wml_client_error import (
    WMLClientError,
    InvalidMultipleArguments,
    ParamOutOfRange,
)
from .base_embeddings import BaseEmbeddings
from ibm_watsonx_ai.wml_resource import WMLResource
from ibm_watsonx_ai._wrappers.requests import (
    get_httpx_client,
    get_httpx_client_transport,
)
import ibm_watsonx_ai._wrappers.requests as requests

if TYPE_CHECKING:
    from ibm_watsonx_ai import APIClient, Credentials

# Type aliases
ParamsType: TypeAlias = dict[str, str | dict[str, str]]
PayloadType: TypeAlias = dict[str, str | list[str] | ParamsType]


__all__ = ["Embeddings"]

# Defaults
MAX_INPUTS_LENGTH = 1000

# Increase read and write timeout for embeddings generation
EMBEDDINGS_HTTPX_TIMEOUT = httpx.Timeout(
    read=30 * 60, write=30 * 60, connect=10, pool=30 * 60
)

# Do not change below, required by service
_RETRY_STATUS_CODES = [429, 500, 503, 504, 520]


class Embeddings(BaseEmbeddings, WMLResource):
    """Instantiate the embeddings service.

    :param model_id: the type of model to use
    :type model_id: str, optional

    :param params: parameters to use during generate requests, use ``ibm_watsonx_ai.metanames.EmbedTextParamsMetaNames().show()`` to view the list of MetaNames
    :type params: dict, optional

    :param credentials: credentials for the Watson Machine Learning instance
    :type credentials: dict, optional

    :param project_id: ID of the Watson Studio project
    :type project_id: str, optional

    :param space_id: ID of the Watson Studio space
    :type space_id: str, optional

    :param api_client: initialized APIClient object with a set project ID or space ID. If passed, ``credentials`` and ``project_id``/``space_id`` are not required.
    :type api_client: APIClient, optional

    :param verify: You can pass one of the following as verify:

        * the path to a CA_BUNDLE file
        * the path of a directory with certificates of trusted CAs
        * `True` - default path to truststore will be taken
        * `False` - no verification will be made
    :type verify: bool or str, optional

    :param persistent_connection: defines whether to keep a persistent connection when evaluating the ``generate``, ``embed_query``, and ``embed_documents`` methods with one prompt or a batch of prompts that meets the length limit. For more details, see `Generate embeddings <https://cloud.ibm.com/apidocs/watsonx-ai#text-embeddings>`_. To close the connection, run ``embeddings.close_persistent_connection()``, defaults to True. Added in 1.1.2.
    :type persistent_connection: bool, optional

    :param batch_size: number of elements to embed in one call, defaults to 1000
    :type batch_size: int, optional

    .. note::
        When the ``credentials`` parameter is passed, one of these parameters is required: [``project_id``, ``space_id``].

    .. hint::
        You can copy the project_id from the Project's Manage tab (Project -> Manage -> General -> Details).

    **Example:**

    .. code-block:: python

        from ibm_watsonx_ai import Credentials
        from ibm_watsonx_ai.foundation_models import Embeddings
        from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
        from ibm_watsonx_ai.foundation_models.utils.enums import EmbeddingTypes

        embed_params = {
            EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
            EmbedParams.RETURN_OPTIONS: {
                'input_text': True
            }
        }

        embedding = Embeddings(
            model_id=EmbeddingTypes.IBM_SLATE_30M_ENG,
            params=embed_params,
            credentials=Credentials(
                api_key="***",
                url="https://us-south.ml.cloud.ibm.com"),
            project_id="*****"
        )
    """

    def __init__(
        self,
        *,
        model_id: str,
        params: ParamsType | None = None,
        credentials: Credentials | dict[str, str] | None = None,
        project_id: str | None = None,
        space_id: str | None = None,
        api_client: APIClient | None = None,
        verify: bool | str | None = None,
        persistent_connection: bool = True,
        batch_size: int = MAX_INPUTS_LENGTH,
    ) -> None:
        if isinstance(model_id, Enum):
            self.model_id = model_id.value
        else:
            self.model_id = model_id

        self.params = params
        Embeddings._validate_type(params, "params", dict, False)
        Embeddings._validate_type(batch_size, "batch_size", int, False)
        if batch_size > MAX_INPUTS_LENGTH or batch_size < 1:
            raise ParamOutOfRange(
                param_name="batch_size", value=batch_size, min=1, max=MAX_INPUTS_LENGTH
            )
        else:
            self.batch_size = batch_size

        if credentials:
            from ibm_watsonx_ai import APIClient

            self._client = APIClient(credentials, verify=verify)
        elif api_client:
            self._client = api_client
        else:
            raise InvalidMultipleArguments(
                params_names_list=["credentials", "api_client"],
                reason="None of the arguments were provided.",
            )

        if space_id:
            self._client.set.default_space(space_id)
        elif project_id:
            self._client.set.default_project(project_id)
        elif not api_client:
            raise InvalidMultipleArguments(
                params_names_list=["space_id", "project_id"],
                reason="None of the arguments were provided.",
            )

        if not self._client.CLOUD_PLATFORM_SPACES and self._client.CPD_version < 5.0:
            raise WMLClientError(error_msg="Operation is unsupported for this release.")

        self._persistent_connection = persistent_connection

        WMLResource.__init__(self, __name__, self._client)

        if self._persistent_connection:
            # Reuse one httpx client with extended timeouts across requests
            self._http_client = get_httpx_client(
                transport=get_httpx_client_transport(
                    _retry_status_codes=_RETRY_STATUS_CODES,
                    verify=self._client.credentials.verify,
                    limits=requests.HTTPX_DEFAULT_LIMIT,
                ),
                timeout=EMBEDDINGS_HTTPX_TIMEOUT,
            )
        else:
            # Fall back to the wrapped requests module (no keep-alive client)
            self._http_client = requests
    def generate(
        self,
        inputs: list[str],
        params: ParamsType | None = None,
        concurrency_limit: int = 5,
    ) -> dict:
        """Generate embedding vectors for the given input with the given parameters. Returns a REST API response.

        :param inputs: list of texts for which embedding vectors will be generated
        :type inputs: list[str]

        :param params: MetaProps for the embedding generation, use ``ibm_watsonx_ai.metanames.EmbedTextParamsMetaNames().show()`` to view the list of MetaNames, defaults to None
        :type params: ParamsType | None, optional

        :param concurrency_limit: number of requests to be sent in parallel, max is 10, defaults to 5
        :type concurrency_limit: int, optional

        :return: scoring results containing generated embedding vectors
        :rtype: dict
        """
        self._validate_type(inputs, "inputs", list, True)
        generate_url = (
            self._client.service_instance._href_definitions.get_fm_embeddings_href()
        )
        if concurrency_limit > 10 or concurrency_limit < 1:
            raise ParamOutOfRange(
                param_name="concurrency_limit", value=concurrency_limit, min=1, max=10
            )
        try:
            # For a batch of prompts, use a keep-alive connection even if persistent_connection=False
            http_client = (
                get_httpx_client(
                    transport=get_httpx_client_transport(
                        _retry_status_codes=_RETRY_STATUS_CODES,
                        verify=self._client.credentials.verify,
                        limits=requests.HTTPX_DEFAULT_LIMIT,
                    ),
                )
                if not self._persistent_connection
                else self._http_client
            )
            if len(inputs) > self.batch_size:
                generated_responses: list = []
                inputs_splitted = [
                    inputs[i : i + self.batch_size]
                    for i in range(0, len(inputs), self.batch_size)
                ]
                _generate_partial = partial(
                    self._generate,
                    generate_url,
                    params=params,
                    _http_client=http_client,
                )
                if (inputs_length := len(inputs_splitted)) <= concurrency_limit:
                    with ThreadPoolExecutor(max_workers=inputs_length) as executor:
                        generated_responses = list(
                            executor.map(_generate_partial, inputs_splitted)
                        )
                else:
                    with ThreadPoolExecutor(max_workers=concurrency_limit) as executor:
                        generated_responses = list(
                            executor.map(_generate_partial, inputs_splitted)
                        )

                def reduce_response(left: dict, right: dict) -> dict:
                    import copy

                    left_copy = copy.deepcopy(left)
                    left_copy["results"].extend(right["results"])
                    left_copy["input_token_count"] += right["input_token_count"]
                    return left_copy

                return reduce(
                    reduce_response, generated_responses[1:], generated_responses[0]
                )
            else:
                results = self._generate(generate_url, inputs, params)
                return results
        except Exception:
            http_client = None  # type: ignore[assignment]
            raise
        finally:
            if not self._persistent_connection and isinstance(
                http_client, httpx.Client
            ):
                http_client.close()
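    # Usage sketch (illustrative numbers, not from the source): with the
    # default batch_size of 1000, a call such as
    #
    #     response = embedding.generate(inputs=texts, concurrency_limit=5)
    #
    # on 2500 texts is split into three batches of 1000/1000/500. Because
    # 3 <= concurrency_limit, all three requests run in parallel; the partial
    # responses are then reduced into one dict whose "results" list keeps the
    # input order and whose "input_token_count" is the sum over batches.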
    def embed_documents(
        self,
        texts: list[str],
        params: ParamsType | None = None,
        concurrency_limit: int = 5,
    ) -> list[list[float]]:
        """Returns a list of embedding vectors for the provided texts.

        :param texts: list of texts for which embedding vectors will be generated
        :type texts: list[str]

        :param params: MetaProps for the embedding generation, use ``ibm_watsonx_ai.metanames.EmbedTextParamsMetaNames().show()`` to view the list of MetaNames, defaults to None
        :type params: ParamsType | None, optional

        :param concurrency_limit: number of requests to be sent in parallel, max is 10, defaults to 5
        :type concurrency_limit: int, optional

        :return: list of embedding vectors
        :rtype: list[list[float]]

        **Example:**

        .. code-block:: python

            q = [
                "What is a Generative AI?",
                "Generative AI refers to a type of artificial intelligence that can create original content.",
            ]

            embedding_vectors = embedding.embed_documents(texts=q)
            print(embedding_vectors)
        """
        return [
            vector.get("embedding")
            for vector in self.generate(
                inputs=texts, params=params, concurrency_limit=concurrency_limit
            ).get("results", [{}])
        ]
    def embed_query(self, text: str, params: ParamsType | None = None) -> list[float]:
        """Returns an embedding vector for a provided text.

        :param text: text for which the embedding vector will be generated
        :type text: str

        :param params: MetaProps for the embedding generation, use ``ibm_watsonx_ai.metanames.EmbedTextParamsMetaNames().show()`` to view the list of MetaNames, defaults to None
        :type params: ParamsType | None, optional

        :return: embedding vector
        :rtype: list[float]

        **Example:**

        .. code-block:: python

            q = "What is a Generative AI?"
            embedding_vector = embedding.embed_query(text=q)
            print(embedding_vector)
        """
        return (
            self.generate(inputs=[text], params=params)
            .get("results", [{}])[0]
            .get("embedding")
        )
    def _prepare_payload(
        self, inputs: list[str], params: ParamsType | None = None
    ) -> PayloadType:
        """Prepare the payload based on the provided inputs and params."""
        payload: PayloadType = {"model_id": self.model_id, "inputs": inputs}

        if params is not None:
            payload["parameters"] = params
        elif self.params is not None:
            payload["parameters"] = self.params

        if self._client.default_project_id:
            payload["project_id"] = self._client.default_project_id
        elif self._client.default_space_id:
            payload["space_id"] = self._client.default_space_id

        return payload

    def _generate(
        self,
        generate_url: str,
        inputs: list[str],
        params: ParamsType | None = None,
        _http_client: requests.HTTPXClient | None = None,
    ) -> dict:
        """Send a POST request and return the service response."""
        payload = self._prepare_payload(inputs, params)

        http_client = _http_client or self._http_client
        post_params: dict[str, Any] = dict(
            url=generate_url,
            json=payload,
            params=self._client._params(skip_for_create=True, skip_userfs=True),
            headers=self._client._get_headers(),
        )
        if not isinstance(http_client, httpx.Client):
            # The wrapped requests module handles retrying itself
            post_params["_retry_status_codes"] = _RETRY_STATUS_CODES

        response_scoring = http_client.post(**post_params)

        return self._handle_response(
            200,
            "generate",
            response_scoring,
            _field_to_hide="embedding",
        )
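    # For reference, _prepare_payload builds a service payload shaped like the
    # following (values are illustrative only):
    #
    #     {
    #         "model_id": "ibm/slate-30m-english-rtrvr",
    #         "inputs": ["first text", "second text"],
    #         "parameters": {"truncate_input_tokens": 3},
    #         "project_id": "<default project id>",
    #     }
    #
    # "parameters" appears only when params were given per call or at init;
    # the client's default project_id (or space_id) is attached when one is set.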
    def to_dict(self) -> dict:
        data = super().to_dict()
        embeddings_args = {
            "model_id": self.model_id,
            "params": self.params,
            "credentials": self._client.credentials.to_dict(),
            "project_id": self._client.default_project_id,
            "space_id": self._client.default_space_id,
            "verify": os.environ.get("WML_CLIENT_VERIFY_REQUESTS"),
        }
        if self.batch_size != MAX_INPUTS_LENGTH:
            embeddings_args |= {"batch_size": self.batch_size}

        data.update(embeddings_args)

        return data
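    # Sketch of the intended round trip (assumes BaseEmbeddings provides a
    # matching from_dict; treat the call below as hypothetical):
    #
    #     state = embedding.to_dict()
    #     restored = Embeddings.from_dict(state)
    #
    # Note that "verify" is taken from the WML_CLIENT_VERIFY_REQUESTS
    # environment variable, and "batch_size" is serialized only when it
    # differs from the default MAX_INPUTS_LENGTH.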
    def close_persistent_connection(self) -> None:
        """Only applicable if persistent_connection was set to True in the Embeddings initialization."""
        if self._persistent_connection is not None and isinstance(
            self._http_client, httpx.Client
        ):
            self._http_client.close()
            # Recreate the client so the instance remains usable afterwards
            self._http_client = get_httpx_client(
                transport=get_httpx_client_transport(
                    _retry_status_codes=_RETRY_STATUS_CODES,
                    verify=self._client.credentials.verify,
                    limits=requests.HTTPX_DEFAULT_LIMIT,
                )
            )
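# Usage sketch (variable names are illustrative): with the default
# persistent_connection=True, close the underlying httpx.Client once you are
# done embedding; a fresh transport is created immediately, so the instance
# can still be used afterwards.
#
#     embedding = Embeddings(model_id="...", credentials=creds, project_id="...")
#     vectors = embedding.embed_documents(texts=docs)
#     embedding.close_persistent_connection()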