Source code for genai.text.tokenization.tokenization_service

from typing import Generator, Optional, Union

from genai._types import ModelLike
from genai._utils.async_executor import execute_async
from genai._utils.general import (
    batch_by_size_constraint,
    cast_list,
    merge_objects,
    to_model_instance,
)
from genai._utils.http_client.httpx_client import AsyncHttpxClient
from genai._utils.limiters.local_limiter import LocalLimiter
from genai._utils.limiters.shared_limiter import LoopBoundLimiter
from genai._utils.service import (
    BaseService,
    BaseServiceConfig,
    BaseServiceServices,
    CommonExecutionOptions,
    get_service_action_metadata,
    set_service_action_metadata,
)
from genai.schema import (
    TextTokenizationCreateEndpoint,
    TextTokenizationCreateResponse,
    TextTokenizationParameters,
)
from genai.schema._api import (
    _TextTokenizationCreateParametersQuery,
    _TextTokenizationCreateRequest,
)

__all__ = ["CreateExecutionOptions", "BaseConfig", "TokenizationService"]


class CreateExecutionOptions(BaseServiceConfig):
    """Execution options for the tokenization process."""

    throw_on_error: CommonExecutionOptions.throw_on_error = True
    ordered: CommonExecutionOptions.ordered = True
    concurrency_limit: CommonExecutionOptions.concurrency_limit = None
    batch_size: CommonExecutionOptions.batch_size = None
    rate_limit_options: CommonExecutionOptions.rate_limit_options = None
    callback: CommonExecutionOptions.callback[TextTokenizationCreateResponse] = None
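
# A hypothetical configuration sketch (the keyword names mirror the fields
# declared above; the values shown are illustrative, not defaults):
#
#   options = CreateExecutionOptions(
#       concurrency_limit=4,      # at most 4 requests in flight at once
#       batch_size=100,           # up to 100 inputs per API call
#       callback=lambda response: print(len(response.results)),
#   )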


class BaseConfig(BaseServiceConfig):
    create_execution_options: CreateExecutionOptions = CreateExecutionOptions()


class TokenizationService(BaseService[BaseConfig, BaseServiceServices]):
    Config = BaseConfig

    @set_service_action_metadata(endpoint=TextTokenizationCreateEndpoint)
    def create(
        self,
        *,
        input: Union[str, list[str]],
        model_id: Optional[str] = None,
        prompt_id: Optional[str] = None,
        parameters: Optional[ModelLike[TextTokenizationParameters]] = None,
        execution_options: Optional[ModelLike[CreateExecutionOptions]] = None,
    ) -> Generator[TextTokenizationCreateResponse, None, None]:
        """
        Args:
            input: The input data for tokenization. It can be a single string or a list of strings.
            model_id: The ID of the model to use for tokenization. Either 'model_id' or 'prompt_id' must be provided.
            prompt_id: The ID of the prompt to use for tokenization. Either 'model_id' or 'prompt_id' must be provided.
            parameters: The parameters for tokenization, like return options.
            execution_options: The execution options for tokenization, like batch size, callbacks, etc.

        Raises:
            ApiResponseException: In case of a known API error.
            ApiNetworkException: In case of an unhandled network error.
            ValidationError: In case the provided parameters are invalid.
        """
        metadata = get_service_action_metadata(self.create)
        prompts = cast_list(input)
        options = to_model_instance([self.config.create_execution_options, execution_options], CreateExecutionOptions)
        parameters_validated = to_model_instance(parameters, TextTokenizationParameters)
        # Split inputs into chunks that respect both the payload size limit and the batch size.
        batches = batch_by_size_constraint(
            prompts,
            max_size_bytes=self._api_client.config.max_payload_size_bytes,
            max_chunk_size=options.batch_size or len(prompts),
        )

        self._log_method_execution(
            "Tokenization Create",
            prompts=prompts,
            parameters=parameters_validated,
            options=options,
        )

        async def handler(inputs_chunk: list[str], http_client: AsyncHttpxClient, *_) -> TextTokenizationCreateResponse:
            http_response = await http_client.post(
                url=self._get_endpoint(metadata.endpoint),
                params=_TextTokenizationCreateParametersQuery().model_dump(),
                json=_TextTokenizationCreateRequest(
                    input=inputs_chunk,
                    model_id=model_id,
                    prompt_id=prompt_id,
                    parameters=parameters_validated,
                ).model_dump(),
            )
            response = TextTokenizationCreateResponse(**http_response.json())

            if options.callback:
                options.callback(response)

            return response

        yield from execute_async(
            inputs=list(batches),
            handler=handler,
            ordered=options.ordered,
            limiters=[self._get_local_limiter(options.concurrency_limit)],
            throw_on_error=options.throw_on_error,
            http_client=lambda: self._get_async_http_client(
                # Client-side rate-limit defaults, merged with any user-provided overrides.
                rate_limit_options=merge_objects(
                    {"max_rate": 25, "time_period": 1, "disable_rate_limit_no_header": True},
                    options.rate_limit_options,
                )
            ),
        )

    def _get_local_limiter(self, limit: Optional[int]):
        return LoopBoundLimiter(lambda: LocalLimiter(limit=limit)) if limit else None
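
# A minimal usage sketch (assumptions: a `Client` from the genai package that
# exposes this service as `client.text.tokenization`, valid credentials in the
# environment, and `results` entries exposing `token_count`; the model id is
# illustrative):
#
#   from genai import Client, Credentials
#
#   client = Client(credentials=Credentials.from_env())
#   for response in client.text.tokenization.create(
#       model_id="google/flan-t5-xl",
#       input=["How are you?", "Write a tagline for an alumni association."],
#   ):
#       for result in response.results:
#           print(result.token_count)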