# -----------------------------------------------------------------------------------------
# (C) Copyright IBM Corp. 2024.
# https://opensource.org/licenses/BSD-3-Clause
# -----------------------------------------------------------------------------------------
from __future__ import annotations
from typing import TYPE_CHECKING, TypedDict
from copy import deepcopy
from ibm_watsonx_ai.wml_client_error import InvalidMultipleArguments
from ibm_watsonx_ai.wml_resource import WMLResource
from ibm_watsonx_ai._wrappers import requests
from ibm_watsonx_ai.foundation_models.schema import (
RerankParameters,
BaseSchema,
)
if TYPE_CHECKING:
from ibm_watsonx_ai import APIClient, Credentials
class TextDict(TypedDict):
text: str
[docs]
class Rerank(WMLResource):
"""
Rerank texts based on some queries.
:param model_id: type of model to use
:type model_id: str
:param params: parameters to use during request generation
:type params: dict, RerankParameters, optional
:param credentials: credentials for the Watson Machine Learning instance
:type credentials: Credentials or dict, optional
:param project_id: ID of the Watson Studio project
:type project_id: str, optional
:param space_id: ID of the Watson Studio space
:type space_id: str, optional
:param verify: You can pass one of the following as verify:
* the path to a CA_BUNDLE file
* the path of directory with certificates of trusted CAs
* `True` - default path to truststore will be taken
* `False` - no verification will be made
:type verify: bool or str, optional
:param api_client: initialized APIClient object with a set project ID or space ID. If passed, ``credentials`` and ``project_id``/``space_id`` are not required.
:type api_client: APIClient, optional
**Example:**
.. code-block:: python
from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.foundation_models import Rerank
generate_params = {
"truncate_input_tokens": 10
}
wx_ranker = Rerank(
model_id="<RERANK MODEL>",
params=generate_params,
credentials=Credentials(
api_key = "***",
url = "https://us-south.ml.cloud.ibm.com"),
project_id=project_id
)
"""
def __init__(
self,
model_id: str,
params: dict | RerankParameters | None = None,
credentials: Credentials | None = None,
project_id: str | None = None,
space_id: str | None = None,
verify: bool | str | None = None,
api_client: APIClient | None = None,
) -> None:
self.model_id = model_id
Rerank._validate_type(model_id, "model_id", str, True)
self.params = params
Rerank._validate_type(params, "params", [dict, RerankParameters], False, True)
if credentials:
from ibm_watsonx_ai import APIClient
self._client = APIClient(credentials, verify=verify)
elif api_client:
self._client = api_client
else:
raise InvalidMultipleArguments(
params_names_list=["credentials", "api_client"],
reason="None of the arguments were provided.",
)
if space_id:
self._client.set.default_space(space_id)
elif project_id:
self._client.set.default_project(project_id)
elif not api_client:
raise InvalidMultipleArguments(
params_names_list=["space_id", "project_id"],
reason="None of the arguments were provided.",
)
WMLResource.__init__(self, __name__, self._client)
[docs]
def generate(
self,
query: str,
inputs: list[str | TextDict],
params: dict | RerankParameters | None = None,
) -> dict:
"""
Calling this method generates the following auditing event.
:param query: The rank query.
:type query: str
:param inputs: The rank input strings.
:type inputs: list[str], list[dict['text', str]]
:param params:
:type params: dict, RerankParameters, optional
**Example:**
.. code-block:: python
query = "As a Youth, I craved excitement while in adulthood I followed Enthusiastic Pursuit."
inputs = [
"In my younger years, I often reveled in the excitement of spontaneous adventures and embraced the thrill of the unknown, whereas in my grownup life, I have come to appreciate the comforting stability of a well-established routine.",
"As a young man, I frequently sought out exhilarating experiences, craving the adrenaline rush of lifes novelties, while as a responsible adult, I have come to understand the profound value of accumulated wisdom and life experience."
]
response = wx_ranker.generate(query=query, inputs=inputs)
# Print all response
print(response)
"""
self._client._check_if_either_is_set()
self._validate_type(query, "query", str, True)
self._validate_type(inputs, "inputs", list, True)
if all(isinstance(el, str) for el in inputs):
inputs_payload = [{"text": el_input} for el_input in inputs]
else:
inputs_payload = inputs # type: ignore
payload: dict = {
"model_id": self.model_id,
"query": query,
"inputs": inputs_payload,
}
if params is not None:
parameters = params
elif self.params is not None:
parameters = deepcopy(self.params)
else:
parameters = None
if isinstance(parameters, BaseSchema):
parameters = parameters.to_dict()
if parameters:
payload["parameters"] = parameters
if self._client.default_project_id:
payload["project_id"] = self._client.default_project_id
elif self._client.default_space_id:
payload["space_id"] = self._client.default_space_id
response = requests.post(
url=self._client.service_instance._href_definitions.get_rerank_href(),
json=payload,
params=self._client._params(skip_for_create=True, skip_userfs=True),
headers=self._client._get_headers(),
)
return self._handle_response(200, "generate_rerank", response)