# Source code for ibm_watsonx_ai.gateway.gateway

#  -----------------------------------------------------------------------------------------
#  (C) Copyright IBM Corp. 2025.
#  https://opensource.org/licenses/BSD-3-Clause
#  -----------------------------------------------------------------------------------------
import json
from typing import Any, AsyncIterator, Iterator

import ibm_watsonx_ai._wrappers.requests as requests
from ibm_watsonx_ai import APIClient, Credentials
from ibm_watsonx_ai.gateway.models import Models
from ibm_watsonx_ai.gateway.policies import Policies
from ibm_watsonx_ai.gateway.providers import Providers
from ibm_watsonx_ai.wml_client_error import InvalidMultipleArguments, WMLClientError
from ibm_watsonx_ai.wml_resource import WMLResource


def _streaming_create(api_client: APIClient, url: str, request_json: dict) -> Iterator:
    """Stream a POST request and yield every JSON event found on ``data`` lines.

    :param api_client: client providing the request headers and the HTTP client
    :type api_client: APIClient

    :param url: endpoint the request is sent to
    :type url: str

    :param request_json: payload sent as the JSON body
    :type request_json: dict

    :returns: generator of decoded ``data`` payloads, ending at ``[DONE]``
    :rtype: Iterator

    :raises WMLClientError: when the response status code is not 200
    :raises Exception: when a ``data`` payload is not valid JSON
    """
    request_args: dict = {
        "method": "POST",
        "url": url,
        "json": request_json,
        "headers": api_client._get_headers(),
    }

    # The library's own HTTPXClient wrapper exposes ``post_stream``; any other
    # client object is expected to offer ``stream`` instead.
    open_stream = (
        api_client.httpx_client.post_stream
        if isinstance(api_client.httpx_client, requests.HTTPXClient)
        else api_client.httpx_client.stream
    )

    with open_stream(**request_args) as resp:
        if resp.status_code != 200:
            # The body has to be read before ``resp.text`` can be used below.
            resp.read()
            raise WMLClientError(
                f"Request failed with: {resp.text} ({resp.status_code})"
            )

        for line in resp.iter_lines():
            # Each line is split as "<field>:<value>" on the first colon.
            field, _, response = line.partition(":")

            # A value of "[DONE]" terminates the stream.
            if response.strip() == "[DONE]":
                break

            # Only non-empty ``data`` fields carry a JSON payload.
            if field != "data" or not response:
                continue

            try:
                event = json.loads(response)
            except json.JSONDecodeError:
                raise Exception(f"Could not parse {response} as json")
            yield event


async def _streaming_acreate(
    api_client: APIClient, url: str, request_json: dict
) -> AsyncIterator:
    """Asynchronously stream a POST request and yield each JSON ``data`` event.

    :param api_client: client providing the request headers and the async HTTP client
    :type api_client: APIClient

    :param url: endpoint the request is sent to
    :type url: str

    :param request_json: payload sent as the JSON body
    :type request_json: dict

    :returns: async generator of decoded ``data`` payloads, ending at ``[DONE]``
    :rtype: AsyncIterator

    :raises WMLClientError: when the response status code is not 200
    :raises Exception: when a ``data`` payload is not valid JSON
    """
    kw_args: dict = dict(
        method="POST",
        url=url,
        json=request_json,
        headers=api_client._get_headers(),
    )

    # The library's own HTTPXAsyncClient wrapper exposes ``post_stream``; any
    # other client object is expected to offer ``stream`` instead.
    if isinstance(api_client.async_httpx_client, requests.HTTPXAsyncClient):
        stream_function = api_client.async_httpx_client.post_stream  # type: ignore[assignment]
    else:
        stream_function = api_client.async_httpx_client.stream  # type: ignore[assignment]

    async with stream_function(**kw_args) as resp:
        if resp.status_code == 200:
            resp_iter = resp.aiter_lines()

            async for chunk in resp_iter:
                # Each line is split as "<field>:<value>" on the first colon.
                field_name, _, response = chunk.partition(":")

                # A value of "[DONE]" terminates the stream.
                if response.strip() == "[DONE]":
                    break

                # Only non-empty ``data`` fields carry a JSON payload.
                if field_name == "data" and response:
                    try:
                        parsed_response = json.loads(response)
                    except json.JSONDecodeError:
                        raise Exception(f"Could not parse {response} as json")
                    yield parsed_response

        else:
            # The body has to be read before ``resp.text`` can be used below.
            await resp.aread()
            # Same message layout as the synchronous ``_streaming_create``:
            # response text first, status code in parentheses.
            raise WMLClientError(
                f"Request failed with: {resp.text} ({resp.status_code})"
            )


class Gateway(WMLResource):
    """Model Gateway class.

    :param credentials: credentials used to create a new ``APIClient``; when given,
        they take precedence over ``api_client``
    :type credentials: Credentials, optional

    :param verify: forwarded as ``verify`` to ``APIClient`` when the client is
        created from ``credentials``
    :type verify: bool | str, optional

    :param api_client: already initialised client, used when ``credentials``
        is not provided
    :type api_client: APIClient, optional

    :raises InvalidMultipleArguments: when neither ``credentials`` nor
        ``api_client`` is provided
    :raises WMLClientError: when the client reports ``ICP_PLATFORM_SPACES``
        (AI Gateway is not supported on CPD)
    """

    def __init__(
        self,
        *,
        credentials: Credentials | None = None,
        verify: bool | str | None = None,
        api_client: APIClient | None = None,
    ):
        if credentials:
            api_client = APIClient(credentials, verify=verify)
        elif not api_client:
            raise InvalidMultipleArguments(
                params_names_list=["credentials", "api_client"],
                reason="None of the arguments were provided.",
            )

        WMLResource.__init__(self, __name__, api_client)

        if self._client.ICP_PLATFORM_SPACES:
            raise WMLClientError("AI Gateway is not supported on CPD.")

        self.providers = Providers(self._client)
        self.models = Models(self._client)
        self.policies = Policies(self._client)

        # Chat completions
        class _ChatCompletions(WMLResource):
            def __init__(self, api_client: APIClient):
                WMLResource.__init__(self, __name__, api_client)

            def create(
                self, model: str, messages: list[dict], **kwargs: Any
            ) -> dict | Iterator:
                """Generate chat completions for given model and messages.

                Any extra keyword arguments are added to the request body. When
                ``stream`` is passed and truthy, an iterator over the streamed
                events is returned instead of a single response.

                :param model: name of model for given provider or alias
                :type model: str

                :param messages: messages to be processed during call
                :type messages: list[dict]

                :returns: model answer
                :rtype: dict | Iterator
                """
                stream = kwargs.get("stream", False)

                # ``model`` and ``messages`` are named parameters, so they can
                # never appear in ``kwargs`` and cannot be overwritten here.
                request_json = {"messages": messages, "model": model, **kwargs}

                url = self._client._href_definitions.get_gateway_chat_completions_href()

                if stream:
                    return _streaming_create(
                        api_client=self._client, url=url, request_json=request_json
                    )
                else:
                    response = self._client.httpx_client.post(
                        url=url,
                        headers=self._client._get_headers(),
                        json=request_json,
                    )

                    return self._handle_response(
                        200, "chat completion creation", response
                    )

            async def acreate(
                self, model: str, messages: list[dict], **kwargs: Any
            ) -> dict | AsyncIterator:
                """Generate asynchronously chat completions for given model and messages.

                Any extra keyword arguments are added to the request body. When
                ``stream`` is passed and truthy, an async iterator over the
                streamed events is returned instead of a single response.

                :param model: name of model for given provider or alias
                :type model: str

                :param messages: messages to be processed during call
                :type messages: list[dict]

                :returns: model answer
                :rtype: dict | AsyncIterator
                """
                stream = kwargs.get("stream", False)

                request_json = {"messages": messages, "model": model, **kwargs}

                url = self._client._href_definitions.get_gateway_chat_completions_href()

                if stream:
                    return _streaming_acreate(
                        api_client=self._client, url=url, request_json=request_json
                    )
                else:
                    response = await self._client.async_httpx_client.post(
                        url=url,
                        headers=self._client._get_headers(),
                        json=request_json,
                    )

                    return self._handle_response(
                        200, "chat completion creation", response
                    )

        class _Chat:
            def __init__(self, api_client: APIClient):
                self.completions = _ChatCompletions(api_client)

        self.chat = _Chat(self._client)

        # Text completions
        class _Completions(WMLResource):
            def __init__(self, api_client: APIClient):
                WMLResource.__init__(self, __name__, api_client)

            def create(
                self, model: str, prompt: str | list[str] | list[int], **kwargs: Any
            ) -> dict | Iterator:
                """Generate text completions for given model and prompt.

                Any extra keyword arguments are added to the request body. When
                ``stream`` is passed and truthy, an iterator over the streamed
                events is returned instead of a single response.

                :param model: name of model for given provider or alias
                :type model: str

                :param prompt: prompt for processing
                :type prompt: str or list[str] or list[int]

                :returns: model answer
                :rtype: dict | Iterator
                """
                stream = kwargs.get("stream", False)

                request_json = {"prompt": prompt, "model": model, **kwargs}

                url = self._client._href_definitions.get_gateway_text_completions_href()

                if stream:
                    return _streaming_create(
                        api_client=self._client, url=url, request_json=request_json
                    )
                else:
                    response = self._client.httpx_client.post(
                        url=url,
                        headers=self._client._get_headers(),
                        json=request_json,
                    )

                    return self._handle_response(
                        200, "text completion creation", response
                    )

            async def acreate(
                self, model: str, prompt: str | list[str] | list[int], **kwargs: Any
            ) -> dict | AsyncIterator:
                """Generate asynchronous text completions for given model and prompt.

                Any extra keyword arguments are added to the request body. When
                ``stream`` is passed and truthy, an async iterator over the
                streamed events is returned instead of a single response.

                :param model: name of model for given provider or alias
                :type model: str

                :param prompt: prompt for processing
                :type prompt: str or list[str] or list[int]

                :returns: model answer
                :rtype: dict | AsyncIterator
                """
                stream = kwargs.get("stream", False)

                request_json = {"prompt": prompt, "model": model, **kwargs}

                url = self._client._href_definitions.get_gateway_text_completions_href()

                if stream:
                    return _streaming_acreate(
                        api_client=self._client, url=url, request_json=request_json
                    )
                else:
                    response = await self._client.async_httpx_client.post(
                        url=url,
                        headers=self._client._get_headers(),
                        json=request_json,
                    )

                    return self._handle_response(
                        200, "text completion creation", response
                    )

        self.completions = _Completions(self._client)

        # Embeddings
        class _Embeddings(WMLResource):
            def __init__(self, api_client: APIClient):
                WMLResource.__init__(self, __name__, api_client)

            def create(
                self, model: str, input: str | list[str] | list[int], **kwargs: Any
            ) -> dict:
                """Generate embeddings for given model and input.

                Any extra keyword arguments are added to the request body.

                :param model: name of model for given provider or alias
                :type model: str

                :param input: prompt for processing
                :type input: str or list[str] or list[int]

                :returns: embeddings for given model and input
                :rtype: dict
                """
                request_json = {"input": input, "model": model, **kwargs}

                response = self._client.httpx_client.post(
                    self._client._href_definitions.get_gateway_embeddings_href(),
                    headers=self._client._get_headers(),
                    json=request_json,
                )

                return self._handle_response(200, "embedding creation", response)

            async def acreate(
                self, model: str, input: str | list[str] | list[int], **kwargs: Any
            ) -> dict:
                """Generate asynchronous embeddings for given model and input.

                Any extra keyword arguments are added to the request body.

                :param model: name of model for given provider or alias
                :type model: str

                :param input: prompt for processing
                :type input: str or list[str] or list[int]

                :returns: embeddings for given model and input
                :rtype: dict
                """
                request_json = {"input": input, "model": model, **kwargs}

                response = await self._client.async_httpx_client.post(
                    self._client._href_definitions.get_gateway_embeddings_href(),
                    headers=self._client._get_headers(),
                    json=request_json,
                )

                return self._handle_response(200, "embedding creation", response)

        self.embeddings = _Embeddings(self._client)

    def set_secrets_manager(
        self, secrets_manager: str, name: str = "Watsonx AI Model Gateway configuration"
    ) -> dict:
        """Configure Model Gateway by, among others, setting Secrets Manager url.

        :param secrets_manager: Secrets Manager url
        :type secrets_manager: str

        :param name: Model Gateway configuration name
        :type name: str, optional

        :returns: result of handling the tenant response (status code 201 is expected)
        :rtype: dict
        """
        response = self._client.httpx_client.post(
            self._client._href_definitions.get_gateway_tenant_href(),
            headers=self._client._get_headers(),
            json={"name": name, "secrets_manager": secrets_manager},
        )

        return self._handle_response(201, "set secrets manager", response)

    def clear_secrets_manager(self) -> str:
        """Clear Model Gateway configuration.

        :return: status ("SUCCESS" if succeeded)
        :rtype: str
        """
        response = self._client.httpx_client.delete(
            self._client._href_definitions.get_gateway_tenant_href(),
            headers=self._client._get_headers(),
        )

        # A 204 response is expected, so it is handled as a non-JSON response.
        return self._handle_response(
            204, "tenant deletion", response, json_response=False
        )