from enum import Enum
from typing import Optional, Union
from genai._types import EnumLike, ModelLike
from genai._utils.general import enum_like_to_string, to_enum, to_enum_optional, to_model_optional
from genai._utils.service import (
BaseService,
BaseServiceConfig,
BaseServiceServices,
get_service_action_metadata,
set_service_action_metadata,
)
from genai._utils.validators import assert_is_not_empty_string
from genai.schema import (
TuneAssetType,
TuneCreateEndpoint,
TuneCreateResponse,
TuneFromFileCreateEndpoint,
TuneIdContentTypeRetrieveEndpoint,
TuneIdDeleteEndpoint,
TuneIdRetrieveEndpoint,
TuneIdRetrieveResponse,
TuneParameters,
TuneRetrieveEndpoint,
TuneRetrieveResponse,
TuneStatus,
TuningTypeRetrieveEndpoint,
TuningTypeRetrieveResponse,
)
from genai.schema._api import (
TuneFromFileCreateResponse,
_TuneCreateParametersQuery,
_TuneCreateRequest,
_TuneFromFileCreateParametersQuery,
_TuneFromFileCreateRequest,
_TuneIdContentTypeRetrieveParametersQuery,
_TuneIdDeleteParametersQuery,
_TuneIdRetrieveParametersQuery,
_TuneRetrieveParametersQuery,
_TuningTypeRetrieveParametersQuery,
)
# Names exported by `from <this module> import *`: only the service class.
__all__ = ["TuneService"]
# (Sphinx "[docs]" source-link marker removed here; it is an extraction artifact, not code.)
class TuneService(BaseService[BaseServiceConfig, BaseServiceServices]):
    """Service wrapping the tune-related API endpoints.

    Each public method is bound to its endpoint through
    ``set_service_action_metadata``; the endpoint is looked up again at call
    time with ``get_service_action_metadata`` and every request is issued via
    the HTTP client returned by ``self._get_http_client()``.

    NOTE: the method order matters. ``create`` uses the builtin ``list[...]``
    in its annotations and therefore must stay defined *before* the ``list``
    method, which shadows the builtin inside the class namespace.
    """

    @set_service_action_metadata(endpoint=TuneCreateEndpoint)
    def create(
        self,
        *,
        model_id: str,
        name: str,
        task_id: str,
        training_file_ids: list[str],
        tuning_type: Union[str, Enum],
        validation_file_ids: Optional[list[str]] = None,
        parameters: Optional[ModelLike[TuneParameters]] = None,
    ) -> TuneCreateResponse:
        """
        Create a tune by POSTing the given settings to the tune-create endpoint.

        Args:
            model_id: Sent as ``model_id`` in the request body (presumably the base model to tune).
            name: Sent as ``name`` in the request body.
            task_id: Sent as ``task_id`` in the request body.
            training_file_ids: Sent as ``training_file_ids`` in the request body.
            tuning_type: Tuning type given either as a string or as an enum member;
                it is converted to a string before being sent.
            validation_file_ids: Optional, sent as ``validation_file_ids`` in the request body.
            parameters: Optional tune parameters; converted to ``TuneParameters`` when provided.

        Returns:
            The response JSON parsed into ``TuneCreateResponse``.

        Raises:
            ApiResponseException: In case of a known API error.
            ApiNetworkException: In case of unhandled network error.
            ValidationError: In case of provided parameters are invalid.
        """
        with self._get_http_client() as client:
            metadata = get_service_action_metadata(self.create)
            request_body = _TuneCreateRequest(
                model_id=model_id,
                name=name,
                parameters=to_model_optional(parameters, TuneParameters),
                task_id=task_id,
                training_file_ids=training_file_ids,
                # TODO: remove casting in the next major release
                tuning_type=enum_like_to_string(tuning_type),
                validation_file_ids=validation_file_ids,
            ).model_dump()

            self._log_method_execution("Tune Create", **request_body)
            response = client.post(
                url=self._get_endpoint(metadata.endpoint),
                params=_TuneCreateParametersQuery().model_dump(),
                json=request_body,
            )
            return TuneCreateResponse(**response.json())

    @set_service_action_metadata(endpoint=TuneFromFileCreateEndpoint)
    def create_from_file(self, *, name: str, file_id: str) -> TuneFromFileCreateResponse:
        """
        Create a tune from an already uploaded file.

        Args:
            name: Sent as ``name`` in the request body.
            file_id: Sent as ``file_id`` in the request body.

        Returns:
            The response JSON parsed into ``TuneFromFileCreateResponse``.

        Raises:
            ApiResponseException: In case of a known API error.
            ApiNetworkException: In case of unhandled network error.
            ValidationError: In case of provided parameters are invalid.
        """
        with self._get_http_client() as client:
            metadata = get_service_action_metadata(self.create_from_file)
            request_body = _TuneFromFileCreateRequest(name=name, file_id=file_id).model_dump()

            self._log_method_execution("Tune Create From File Create", **request_body)
            response = client.post(
                url=self._get_endpoint(metadata.endpoint),
                params=_TuneFromFileCreateParametersQuery().model_dump(),
                json=request_body,
            )
            return TuneFromFileCreateResponse(**response.json())

    @set_service_action_metadata(endpoint=TuneIdContentTypeRetrieveEndpoint)
    def read(self, *, id: str, type: EnumLike[TuneAssetType]) -> bytes:
        """
        Download tune assets.

        The tune is retrieved first; its asset is only downloaded when the
        tune status is 'COMPLETED'.

        Args:
            id: Identifier of the tune; must be a non-empty string.
            type: Asset type to download, converted to ``TuneAssetType`` before use.

        Returns:
            The raw content (bytes) of the download response.

        Raises:
            ValueError: if the tune status is not 'COMPLETED'.
            ApiResponseException: In case of a known API error.
            ApiNetworkException: In case of unhandled network error.
            ValidationError: In case of provided parameters are invalid.
        """
        assert_is_not_empty_string(id)
        self._log_method_execution("Tune Read", id=id)

        tune = self.retrieve(id).result
        # A missing status is also different from COMPLETED, so one comparison
        # covers both the "unknown" and the "not finished" case.
        if tune.status != TuneStatus.COMPLETED:
            raise ValueError(
                f"Tune status: '{tune.status if tune.status else 'unknown'}'. "
                f"Tune can not be downloaded if status is not '{TuneStatus.COMPLETED.value}'."
            )

        metadata = get_service_action_metadata(self.read)
        with self._get_http_client() as client:
            response = client.get(
                url=self._get_endpoint(metadata.endpoint, id=id, type=to_enum(TuneAssetType, type)),
                params=_TuneIdContentTypeRetrieveParametersQuery().model_dump(),
            )
            return response.content

    @set_service_action_metadata(endpoint=TuneIdRetrieveEndpoint)
    def retrieve(
        self,
        id: str,
    ) -> TuneIdRetrieveResponse:
        """
        Retrieve a single tune by its identifier.

        Args:
            id: Identifier of the tune; must be a non-empty string.

        Returns:
            The response JSON parsed into ``TuneIdRetrieveResponse``.

        Raises:
            ApiResponseException: In case of a known API error.
            ApiNetworkException: In case of unhandled network error.
        """
        metadata = get_service_action_metadata(self.retrieve)
        assert_is_not_empty_string(id)
        self._log_method_execution("Tune Retrieve", id=id)

        with self._get_http_client() as client:
            response = client.get(
                url=self._get_endpoint(metadata.endpoint, id=id),
                params=_TuneIdRetrieveParametersQuery().model_dump(),
            )
            return TuneIdRetrieveResponse(**response.json())

    @set_service_action_metadata(endpoint=TuneRetrieveEndpoint)
    def list(
        self,
        *,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        status: Optional[TuneStatus] = None,
        search: Optional[str] = None,
    ) -> TuneRetrieveResponse:
        """
        List tunes, optionally filtered and paginated.

        Args:
            limit: Optional ``limit`` query parameter.
            offset: Optional ``offset`` query parameter.
            status: Optional ``status`` query parameter; converted to ``TuneStatus`` when provided.
            search: Optional ``search`` query parameter.

        Returns:
            The response JSON parsed into ``TuneRetrieveResponse``.

        Raises:
            ApiResponseException: In case of a known API error.
            ApiNetworkException: In case of unhandled network error.
        """
        self._log_method_execution("Tune List")

        with self._get_http_client() as client:
            metadata = get_service_action_metadata(self.list)
            response = client.get(
                url=self._get_endpoint(metadata.endpoint),
                params=_TuneRetrieveParametersQuery(
                    limit=limit, offset=offset, status=to_enum_optional(status, TuneStatus), search=search
                ).model_dump(),
            )
            return TuneRetrieveResponse(**response.json())

    @set_service_action_metadata(endpoint=TuningTypeRetrieveEndpoint)
    def types(self) -> TuningTypeRetrieveResponse:
        """
        Retrieve the tuning types exposed by the tuning-type endpoint.

        Returns:
            The response JSON parsed into ``TuningTypeRetrieveResponse``.

        Raises:
            ApiResponseException: In case of a known API error.
            ApiNetworkException: In case of unhandled network error.
        """
        # Log before the request (as every other action does) so that a failing
        # call is still recorded.
        self._log_method_execution("Tune Types")

        with self._get_http_client() as client:
            metadata = get_service_action_metadata(self.types)
            response = client.get(
                url=self._get_endpoint(metadata.endpoint),
                params=_TuningTypeRetrieveParametersQuery().model_dump(),
            )
            return TuningTypeRetrieveResponse(**response.json())

    @set_service_action_metadata(endpoint=TuneIdDeleteEndpoint)
    def delete(
        self,
        id: str,
    ) -> None:
        """
        Delete a tune by its identifier.

        Args:
            id: Identifier of the tune; must be a non-empty string.

        Returns:
            None; the body of the delete response is not used.

        Raises:
            ApiResponseException: In case of a known API error.
            ApiNetworkException: In case of unhandled network error.
        """
        assert_is_not_empty_string(id)
        # Include the id in the log entry, matching `retrieve` and `read`.
        self._log_method_execution("Tune Delete", id=id)

        with self._get_http_client() as client:
            metadata = get_service_action_metadata(self.delete)
            client.delete(
                url=self._get_endpoint(metadata.endpoint, id=id),
                params=_TuneIdDeleteParametersQuery().model_dump(),
            )