"""Source code for ``genai.prompt.prompt_service``."""

from typing import Optional, Union

from genai._types import EnumLike, EnumLikeOrEnumLikeList, ModelLike
from genai._utils.general import (
    cast_list,
    cast_list_optional,
    to_enum,
    to_enum_optional,
    to_model_instance,
    to_model_optional,
)
from genai._utils.service import (
    BaseService,
    BaseServiceConfig,
    BaseServiceServices,
    get_service_action_metadata,
    set_service_action_metadata,
)
from genai._utils.validators import assert_is_not_empty_string
from genai.schema import (
    BaseMessage,
    ModerationParameters,
    PromptCreateResponse,
    PromptIdRetrieveResponse,
    PromptIdUpdateResponse,
    PromptListSource,
    PromptRetrieveResponse,
    PromptTemplateData,
    PromptType,
    TextGenerationParameters,
)
from genai.schema._api import (
    PromptListSortBy,
    SortDirection,
    _PromptCreateParametersQuery,
    _PromptCreateRequest,
    _PromptIdDeleteParametersQuery,
    _PromptIdRetrieveParametersQuery,
    _PromptIdUpdateParametersQuery,
    _PromptIdUpdateRequest,
    _PromptRetrieveParametersQuery,
)
from genai.schema._endpoints import (
    PromptCreateEndpoint,
    PromptIdDeleteEndpoint,
    PromptIdRetrieveEndpoint,
    PromptIdUpdateEndpoint,
    PromptRetrieveEndpoint,
)

__all__ = ["PromptService"]


class PromptService(BaseService[BaseServiceConfig, BaseServiceServices]):
    """Service providing CRUD operations over stored prompts."""

    @set_service_action_metadata(endpoint=PromptCreateEndpoint)
    def create(
        self,
        *,
        name: str,
        model_id: str,
        prompt_id: Optional[str] = None,
        messages: Optional[list[ModelLike[BaseMessage]]] = None,
        task_id: Optional[str] = None,
        description: Optional[str] = None,
        moderations: Optional[ModelLike[ModerationParameters]] = None,
        data: Optional[ModelLike[PromptTemplateData]] = None,
        type: Optional[EnumLike[PromptType]] = None,
        input: Optional[str] = None,
        output: Optional[str] = None,
        parameters: Optional[ModelLike[TextGenerationParameters]] = None,
        folder_id: Optional[str] = None,
    ) -> PromptCreateResponse:
        """Create a new prompt.

        Raises:
            ApiResponseException: In case of a known API error.
            ApiNetworkException: In case of unhandled network error.
            ValidationError: In case of provided parameters are invalid.
        """
        # Coerce dict-or-model inputs into concrete schema instances before serializing.
        normalized_messages = None if not messages else [to_model_instance(item, BaseMessage) for item in messages]
        payload = _PromptCreateRequest(
            name=name,
            model_id=model_id,
            prompt_id=prompt_id,
            messages=normalized_messages,
            task_id=task_id,
            description=description,
            moderations=to_model_optional(moderations, ModerationParameters),
            data=to_model_optional(data, PromptTemplateData),
            input=input,
            output=output,
            parameters=to_model_optional(parameters, TextGenerationParameters),
            type=to_enum_optional(type, PromptType),
            folder_id=folder_id,
        ).model_dump()

        self._log_method_execution("Prompt Create", **payload)

        action = get_service_action_metadata(self.create)
        with self._get_http_client() as http_client:
            raw = http_client.post(
                url=self._get_endpoint(action.endpoint),
                params=_PromptCreateParametersQuery().model_dump(),
                json=payload,
            )
        return PromptCreateResponse(**raw.json())
[docs] @set_service_action_metadata(endpoint=PromptIdRetrieveEndpoint) def retrieve( self, id: str, ) -> PromptIdRetrieveResponse: """ Raises: ApiResponseException: In case of a known API error. ApiNetworkException: In case of unhandled network error. ValidationError: In case of provided parameters are invalid. """ assert_is_not_empty_string(id) self._log_method_execution("Prompts Retrieve", id=id) with self._get_http_client() as client: metadata = get_service_action_metadata(self.retrieve) response = client.get( url=self._get_endpoint(metadata.endpoint, id=id), params=_PromptIdRetrieveParametersQuery().model_dump(), ) return PromptIdRetrieveResponse(**response.json())
[docs] @set_service_action_metadata(endpoint=PromptIdUpdateEndpoint) def update( self, id: str, *, name: str, model_id: str, folder_id: Optional[str] = None, industry_id: Optional[str] = None, language_id: Optional[str] = None, description: Optional[str] = None, input: Optional[str] = None, output: Optional[str] = None, task_id: Optional[str] = None, type: Optional[EnumLike[PromptType]] = None, messages: Optional[list[ModelLike[BaseMessage]]] = None, moderations: Optional[ModelLike[ModerationParameters]] = None, parameters: Optional[ModelLike[TextGenerationParameters]] = None, data: Optional[ModelLike[PromptTemplateData]] = None, ) -> PromptIdUpdateResponse: """ Raises: ApiResponseException: In case of a known API error. ApiNetworkException: In case of unhandled network error. ValidationError: In case of provided parameters are invalid. """ assert_is_not_empty_string(id) request_body = _PromptIdUpdateRequest( name=name, model_id=model_id, folder_id=folder_id, industry_id=industry_id, language_id=language_id, messages=[to_model_instance(msg, BaseMessage) for msg in messages] if messages else None, task_id=task_id, description=description, moderations=to_model_optional(moderations, ModerationParameters), data=to_model_optional(data, PromptTemplateData), input=input, output=output, parameters=to_model_optional(parameters, TextGenerationParameters), type=to_enum_optional(type, PromptType), ).model_dump() self._log_method_execution("Prompt Update", **request_body) with self._get_http_client() as client: metadata = get_service_action_metadata(self.update) response = client.post( url=self._get_endpoint(metadata.endpoint, id=id), params=_PromptIdUpdateParametersQuery().model_dump(), json=request_body, ) return PromptIdUpdateResponse(**response.json())
[docs] @set_service_action_metadata(endpoint=PromptRetrieveEndpoint) def list( self, *, limit: Optional[int] = None, offset: Optional[int] = None, sort_by: Optional[EnumLike[PromptListSortBy]] = None, direction: Optional[EnumLike[SortDirection]] = None, search: Optional[str] = None, task_id: Optional[Union[str, list[str]]] = None, model_id: Optional[Union[str, list[str]]] = None, source: Optional[EnumLikeOrEnumLikeList[PromptListSource]] = None, model_family_id: Optional[float] = None, industry_id: Optional[Union[str, list[str]]] = None, prompt_language_id: Optional[Union[str, list[str]]] = None, model_type_id: Optional[Union[str, list[str]]] = None, avg_time_min: Optional[int] = None, avg_time_max: Optional[int] = None, context_window_min: Optional[int] = None, context_window_max: Optional[int] = None, folder_id: Optional[str] = None, ) -> PromptRetrieveResponse: """ Raises: ApiResponseException: In case of a known API error. ApiNetworkException: In case of unhandled network error. ValidationError: In case of provided parameters are invalid. 
""" request_parameters = _PromptRetrieveParametersQuery( limit=limit, offset=offset, sort_by=to_enum_optional(sort_by, PromptListSortBy), direction=to_enum_optional(direction, SortDirection), search=search, task_id=cast_list_optional(task_id), model_id=cast_list_optional(model_id), source=[to_enum(PromptListSource, s) for s in cast_list(source)] if source else None, model_family_id=model_family_id, industry_id=cast_list_optional(industry_id), prompt_language_id=cast_list_optional(prompt_language_id), model_type_id=cast_list_optional(model_type_id), avg_time_min=avg_time_min, avg_time_max=avg_time_max, context_window_min=context_window_min, context_window_max=context_window_max, folder_id=folder_id, ).model_dump() self._log_method_execution("Prompt List", **request_parameters) with self._get_http_client() as client: metadata = get_service_action_metadata(self.list) response = client.get(url=self._get_endpoint(metadata.endpoint), params=request_parameters) return PromptRetrieveResponse(**response.json())
[docs] @set_service_action_metadata(endpoint=PromptIdDeleteEndpoint) def delete( self, id: str, ) -> None: """ Raises: ApiResponseException: In case of a known API error. ApiNetworkException: In case of unhandled network error. ValidationError: In case of provided parameters are invalid. """ assert_is_not_empty_string(id) self._log_method_execution("Prompt Delete", id=id) with self._get_http_client() as client: metadata = get_service_action_metadata(self.delete) client.delete( url=self._get_endpoint(metadata.endpoint, id=id), params=_PromptIdDeleteParametersQuery().model_dump(), )