Source code for genai.text.generation.generation_service

from typing import Generator, Optional, Union

import httpx
from httpx import AsyncClient, HTTPStatusError
from pydantic import BaseModel

from genai._types import ModelLike
from genai._utils.api_client import ApiClient
from genai._utils.async_executor import execute_async
from genai._utils.general import cast_list, to_model_instance, to_model_optional
from genai._utils.service import (
    BaseService,
    BaseServiceConfig,
    BaseServiceServices,
    CommonExecutionOptions,
    get_service_action_metadata,
    set_service_action_metadata,
)
from genai.schema import (
    ModerationParameters,
    PromptTemplateData,
    TextGenerationComparisonCreateEndpoint,
    TextGenerationComparisonCreateRequestRequest,
    TextGenerationComparisonCreateResponse,
    TextGenerationComparisonParameters,
    TextGenerationCreateEndpoint,
    TextGenerationCreateResponse,
    TextGenerationParameters,
    TextGenerationStreamCreateEndpoint,
    TextGenerationStreamCreateResponse,
)
from genai.schema._api import (
    _TextGenerationComparisonCreateParametersQuery,
    _TextGenerationComparisonCreateRequest,
    _TextGenerationCreateParametersQuery,
    _TextGenerationCreateRequest,
    _TextGenerationStreamCreateParametersQuery,
    _TextGenerationStreamCreateRequest,
)
from genai.text.generation._generation_utils import generation_stream_handler
from genai.text.generation.limits.limit_service import LimitService as _LimitService

__all__ = ["GenerationService", "BaseConfig", "BaseServices", "CreateExecutionOptions"]


from genai._utils.http_client.retry_transport import BaseRetryTransport
from genai._utils.limiters.base_limiter import BaseLimiter
from genai._utils.limiters.external_limiter import ConcurrencyResponse, ExternalLimiter
from genai._utils.limiters.local_limiter import LocalLimiter
from genai._utils.limiters.shared_limiter import LoopBoundLimiter


class BaseServices(BaseServiceServices):
    LimitService: type[_LimitService] = _LimitService


class CreateExecutionOptions(BaseModel):
    throw_on_error: CommonExecutionOptions.throw_on_error = True
    ordered: CommonExecutionOptions.ordered = True
    concurrency_limit: CommonExecutionOptions.concurrency_limit = None
    callback: CommonExecutionOptions.callback[
        Union[TextGenerationStreamCreateResponse, TextGenerationCreateResponse]
    ] = None


class BaseConfig(BaseServiceConfig):
    create_execution_options: CreateExecutionOptions = CreateExecutionOptions()


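# Configuration sketch (illustrative, not part of the service): options passed
# to `create` are merged over `config.create_execution_options`, so a per-call
# override only needs the fields it changes. The callback below is a
# hypothetical example.
#
#     options = CreateExecutionOptions(
#         concurrency_limit=4,   # cap locally enforced in-flight requests
#         ordered=False,         # yield results as they complete
#         callback=lambda response: print(response.id),
#     )

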
class GenerationService(BaseService[BaseConfig, BaseServices]):
    Config = BaseConfig
    Services = BaseServices

    def __init__(
        self,
        *,
        api_client: ApiClient,
        services: Optional[BaseServices] = None,
        config: Optional[Union[BaseConfig, dict]] = None,
    ):
        super().__init__(api_client=api_client, config=config)

        if not services:
            services = BaseServices()

        self._concurrency_limiter = self._get_concurrency_limiter()
        self.limit = services.LimitService(api_client=api_client)

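    # Construction sketch (hedged): this service is usually reached through the
    # SDK's top-level client (e.g. `client.text.generation`) rather than built
    # directly; the credentials below are placeholders.
    #
    #     from genai import Client, Credentials
    #
    #     client = Client(credentials=Credentials(api_key="...", api_endpoint="..."))
    #     generation = client.text.generation  # a GenerationService instance
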
    def _get_concurrency_limiter(self) -> LoopBoundLimiter:
        async def handler():
            response = await self.limit.aretrieve()
            return ConcurrencyResponse(
                limit=response.result.concurrency.limit,
                remaining=response.result.concurrency.remaining,
            )

        return LoopBoundLimiter(lambda: ExternalLimiter(handler=handler))

    @set_service_action_metadata(endpoint=TextGenerationCreateEndpoint)
    def create(
        self,
        *,
        model_id: Optional[str] = None,
        prompt_id: Optional[str] = None,
        input: Optional[str] = None,
        inputs: Optional[Union[list[str], str]] = None,
        parameters: Optional[ModelLike[TextGenerationParameters]] = None,
        moderations: Optional[ModelLike[ModerationParameters]] = None,
        data: Optional[ModelLike[PromptTemplateData]] = None,
        execution_options: Optional[ModelLike[CreateExecutionOptions]] = None,
    ) -> Generator[TextGenerationCreateResponse, None, None]:
        """
        Args:
            model_id: The ID of the model.
            prompt_id: The ID of the prompt which should be used.
            input: Prompt to process. It is recommended not to leave any trailing spaces.
            inputs: Prompt/prompts to process. It is recommended not to leave any trailing spaces.
            parameters: Parameters for text generation.
            moderations: Parameters for moderation.
            data: An optional data object for the underlying prompt.
            execution_options: An optional configuration of how the SDK should work (error handling, limits, callbacks, ...).

        Yields:
            TextGenerationCreateResponse object (server response without modification).

        Raises:
            ApiResponseException: In case of a known API error.
            ApiNetworkException: In case of an unhandled network error.
            ValidationError: In case the provided parameters are invalid.

        Note:
            To limit the number of concurrent requests or change the execution procedure, see the 'execution_options' parameter.
        """
        if inputs is not None and input is not None:
            raise ValueError("Either specify 'inputs' or 'input' parameter!")

        prompts: Optional[list[str]] = (
            cast_list(inputs) if inputs is not None else cast_list(input) if input is not None else None
        )
        if not prompts and not prompt_id:
            raise ValueError("At least one of the following parameters input/inputs/prompt_id must be specified!")

        metadata = get_service_action_metadata(self.create)
        parameters_formatted = to_model_optional(parameters, TextGenerationParameters)
        moderations_formatted = to_model_optional(moderations, ModerationParameters, copy=True)
        template_formatted = to_model_optional(data, PromptTemplateData)
        execution_options_formatted = to_model_instance(
            [self.config.create_execution_options, execution_options], CreateExecutionOptions
        )
        assert execution_options_formatted

        self._log_method_execution(
            "Generate Create",
            prompts=prompts,
            prompt_id=prompt_id,
            parameters=parameters_formatted,
            moderations=moderations_formatted,
            data=template_formatted,
            execution_options=execution_options_formatted,
        )

        if prompt_id is not None:
            with self._get_http_client() as client:
                http_response = client.post(
                    url=self._get_endpoint(metadata.endpoint),
                    params=_TextGenerationCreateParametersQuery().model_dump(),
                    json=_TextGenerationCreateRequest(
                        input=None,
                        model_id=model_id,
                        moderations=moderations_formatted,
                        parameters=parameters_formatted,
                        prompt_id=prompt_id,
                        data=template_formatted,
                    ).model_dump(),
                )
                yield TextGenerationCreateResponse(**http_response.json())
                return

        async def handler(
            batch_input: str, http_client: AsyncClient, limiter: BaseLimiter
        ) -> TextGenerationCreateResponse:
            self._log_method_execution("Generate Create - processing input", input=batch_input)

            async def handle_retry(ex: Exception):
                if isinstance(ex, HTTPStatusError) and ex.response.status_code == httpx.codes.TOO_MANY_REQUESTS:
                    await limiter.report_error()

            async def handle_success(*args):
                await limiter.report_success()

            http_response = await http_client.post(
                url=self._get_endpoint(metadata.endpoint),
                extensions={
                    BaseRetryTransport.Callback.Retry: handle_retry,
                    BaseRetryTransport.Callback.Success: handle_success,
                },
                params=_TextGenerationCreateParametersQuery().model_dump(),
                json=_TextGenerationCreateRequest(
                    input=batch_input,
                    model_id=model_id,
                    moderations=moderations_formatted,
                    parameters=parameters_formatted,
                    prompt_id=prompt_id,
                    data=template_formatted,
                ).model_dump(),
            )
            response = TextGenerationCreateResponse(**http_response.json())
            if execution_options_formatted.callback:
                execution_options_formatted.callback(response)
            return response

        yield from execute_async(
            inputs=prompts,
            handler=handler,
            limiters=[
                self._concurrency_limiter,
                self._get_local_limiter(execution_options_formatted.concurrency_limit),
            ],
            http_client=self._get_async_http_client,
            ordered=execution_options_formatted.ordered,
            throw_on_error=execution_options_formatted.throw_on_error,
        )

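    # Usage sketch (hedged): `create` yields one response per prompt while
    # scheduling the underlying requests concurrently. The model id and prompts
    # are placeholders; access through `client.text.generation` is assumed.
    #
    #     for response in client.text.generation.create(
    #         model_id="google/flan-t5-xl",
    #         inputs=["What is a molecule?", "What is an atom?"],
    #         parameters=TextGenerationParameters(max_new_tokens=64),
    #     ):
    #         print(response.results[0].generated_text)
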
    @set_service_action_metadata(endpoint=TextGenerationStreamCreateEndpoint)
    def create_stream(
        self,
        *,
        input: Optional[str] = None,
        model_id: Optional[str] = None,
        prompt_id: Optional[str] = None,
        parameters: Optional[ModelLike[TextGenerationParameters]] = None,
        moderations: Optional[ModelLike[ModerationParameters]] = None,
        data: Optional[ModelLike[PromptTemplateData]] = None,
    ) -> Generator[TextGenerationStreamCreateResponse, None, None]:
        """
        Yields:
            TextGenerationStreamCreateResponse (raw server response object)

        Raises:
            ApiResponseException: In case of a known API error.
            ApiNetworkException: In case of an unhandled network error.
            ValidationError: In case the provided parameters are invalid.
        """
        metadata = get_service_action_metadata(self.create_stream)
        parameters_formatted = to_model_optional(parameters, TextGenerationParameters)
        moderations_formatted = to_model_optional(moderations, ModerationParameters, copy=True)
        template_formatted = to_model_optional(data, PromptTemplateData)

        self._log_method_execution(
            "Generate Create Stream",
            input=input,
            parameters=parameters_formatted,
            moderations=moderations_formatted,
            template=template_formatted,
        )

        with self._get_http_client() as client:
            yield from generation_stream_handler(
                ResponseModel=TextGenerationStreamCreateResponse,
                logger=self._logger,
                generator=client.post_stream(
                    url=self._get_endpoint(metadata.endpoint),
                    params=_TextGenerationStreamCreateParametersQuery().model_dump(),
                    json=_TextGenerationStreamCreateRequest(
                        input=input,
                        parameters=parameters_formatted,
                        model_id=model_id,
                        prompt_id=prompt_id,
                        moderations=moderations_formatted,
                        data=template_formatted,
                    ).model_dump(),
                ),
            )

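    # Streaming sketch (hedged): each yielded chunk carries a partial result;
    # `results` may be empty for non-text chunks (e.g. moderation events), so it
    # is guarded below. The model id and prompt are placeholders.
    #
    #     for chunk in client.text.generation.create_stream(
    #         model_id="google/flan-t5-xl",
    #         input="Write a haiku about the sea.",
    #     ):
    #         if chunk.results:
    #             print(chunk.results[0].generated_text, end="")
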
    @set_service_action_metadata(endpoint=TextGenerationComparisonCreateEndpoint)
    def compare(
        self,
        *,
        request: TextGenerationComparisonCreateRequestRequest,
        compare_parameters: Optional[ModelLike[TextGenerationComparisonParameters]] = None,
        name: Optional[str] = None,
    ) -> TextGenerationComparisonCreateResponse:
        """
        Raises:
            ApiResponseException: In case of a known API error.
            ApiNetworkException: In case of an unhandled network error.
            ValidationError: In case the provided parameters are invalid.
        """
        metadata = get_service_action_metadata(self.compare)
        request_formatted = to_model_instance(request, TextGenerationComparisonCreateRequestRequest, copy=True)
        compare_parameters_formatted = to_model_instance(compare_parameters, TextGenerationComparisonParameters)

        self._log_method_execution(
            "Text Generation Compare",
            requests=request_formatted,
            parameters=compare_parameters_formatted,
        )

        with self._get_http_client() as client:
            http_response = client.post(
                url=self._get_endpoint(metadata.endpoint),
                params=_TextGenerationComparisonCreateParametersQuery().model_dump(),
                json=_TextGenerationComparisonCreateRequest(
                    name=name,
                    compare_parameters=compare_parameters_formatted,
                    request=request_formatted,
                ).model_dump(),
            )
            return TextGenerationComparisonCreateResponse(**http_response.json())

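    # Comparison sketch (hedged): the exact request/parameter fields follow the
    # schema classes imported above; the values below, and the assumption that
    # `TextGenerationComparisonParameters` takes lists of values to sweep, are
    # illustrative only.
    #
    #     response = client.text.generation.compare(
    #         name="temperature sweep",
    #         request=TextGenerationComparisonCreateRequestRequest(
    #             input="What is a molecule?",
    #             model_id="google/flan-t5-xl",
    #         ),
    #         compare_parameters=TextGenerationComparisonParameters(
    #             temperature=[0.0, 0.5, 1.0],
    #         ),
    #     )
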
    def _get_local_limiter(self, limit: Optional[int]):
        return LoopBoundLimiter(lambda: LocalLimiter(limit=limit)) if limit else None