Source code for ibm_watsonx_ai.foundation_models.classifications.text_classification

#  -----------------------------------------------------------------------------------------
#  (C) Copyright IBM Corp. 2025.
#  https://opensource.org/licenses/BSD-3-Clause
#  -----------------------------------------------------------------------------------------
from __future__ import annotations

from typing import TYPE_CHECKING, Literal, cast

from ibm_watsonx_ai.foundation_models.schema import TextClassificationParameters
from ibm_watsonx_ai.helpers import DataConnection
from ibm_watsonx_ai.messages.messages import Messages
from ibm_watsonx_ai.utils import get_from_json
from ibm_watsonx_ai.wml_client_error import (
    InvalidMultipleArguments,
    InvalidValue,
    WMLClientError,
)
from ibm_watsonx_ai.wml_resource import WMLResource

if TYPE_CHECKING:
    import pandas

    from ibm_watsonx_ai import APIClient, Credentials


class TextClassification(WMLResource):
    """Instantiate the text classification service.

    :param credentials: credentials to the watsonx.ai instance
    :type credentials: Credentials, optional

    :param project_id: ID of the project, defaults to None
    :type project_id: str, optional

    :param space_id: ID of the space, defaults to None
    :type space_id: str, optional

    :param api_client: initialized APIClient object with a set project ID or space ID.
        If passed, ``credentials`` and ``project_id``/``space_id`` are not required,
        defaults to None
    :type api_client: APIClient, optional

    :raises InvalidMultipleArguments: raised when neither `api_client` nor `credentials`
        alongside `space_id` or `project_id` are provided
    :raises WMLClientError: raised when ``ICP_PLATFORM_SPACES`` is set on the client and
        its ``CPD_version`` is lower than 5.3

    .. code-block:: python

        from ibm_watsonx_ai import Credentials
        from ibm_watsonx_ai.foundation_models.classifications import TextClassification

        text_classification = TextClassification(
            credentials=Credentials(
                api_key = IAM_API_KEY,
                url = "https://us-south.ml.cloud.ibm.com"),
            project_id="*****"
        )

    """

    def __init__(
        self,
        credentials: Credentials | None = None,
        project_id: str | None = None,
        space_id: str | None = None,
        api_client: APIClient | None = None,
    ) -> None:
        # ``credentials`` takes precedence over ``api_client`` when both are given.
        if credentials is not None:
            # Runtime import: at module level APIClient is only imported
            # under TYPE_CHECKING.
            from ibm_watsonx_ai import APIClient

            self._client = APIClient(credentials)
        elif api_client is not None:
            self._client = api_client
        else:
            raise InvalidMultipleArguments(
                params_names_list=["credentials", "api_client"],
                reason="None of the arguments were provided.",
            )

        # ``space_id`` wins over ``project_id``. A caller-supplied ``api_client``
        # is allowed to arrive without either, since it may already carry a
        # default space or project.
        if space_id is not None:
            self._client.set.default_space(space_id)
        elif project_id is not None:
            self._client.set.default_project(project_id)
        elif not api_client:
            raise InvalidMultipleArguments(
                params_names_list=["space_id", "project_id"],
                reason="None of the arguments were provided.",
            )

        # Version gate: refuse to run when the platform flag is set and the
        # reported CPD version is below 5.3.
        if self._client.ICP_PLATFORM_SPACES and self._client.CPD_version < 5.3:
            raise WMLClientError(
                Messages.get_message(">= 5.3", message_id="invalid_cpd_version")
            )

        super().__init__(__name__, self._client)

    def run_job(
        self,
        document_reference: DataConnection,
        parameters: TextClassificationParameters | dict,
        custom: dict | None = None,
    ) -> dict:
        """Start a request to classify text in the document.

        :param document_reference: reference to the document in the bucket from which
            text will be classified
        :type document_reference: DataConnection

        :param parameters: the parameters for the text classification
        :type parameters: TextClassificationParameters or dict

        :param custom: user defined properties for the text classification,
            defaults to None
        :type custom: dict, optional

        :return: text classification response
        :rtype: dict

        **Example:**

        .. code-block:: python

            from ibm_watsonx_ai.helpers import DataConnection, S3Location
            from ibm_watsonx_ai.foundation_models.schema import (
                TextClassificationParameters,
                ClassificationMode,
                OCRMode,
            )

            # NOTE: TextClassificationSemanticConfig and SchemasMergeStrategy used
            # below must be imported as well (presumably from the same schema
            # module - verify against the installed package).

            document_reference = DataConnection(
                connection_asset_id="<connection_id>",
                location=S3Location(bucket="<bucket_name>", path="path/to/file"),
            )

            response = text_classification.run_job(
                document_reference=document_reference,
                parameters=TextClassificationParameters(
                    ocr_mode=OCRMode.ENABLED,
                    classification_mode=ClassificationMode.EXACT,
                    auto_rotation_correction=True,
                    languages=["en"],
                    semantic_config=TextClassificationSemanticConfig(
                        schemas_merge_strategy=SchemasMergeStrategy.MERGE,
                        schemas=[...],
                    ),
                ),
                custom={},
            )

        """
        self._validate_type(
            document_reference, "document_reference", DataConnection, True
        )
        self._validate_type(
            parameters, "parameters", [TextClassificationParameters, dict], True
        )
        self._validate_type(custom, "custom", dict, False)

        # The request body needs plain dicts, so unwrap the schema object.
        if isinstance(parameters, TextClassificationParameters):
            parameters = parameters.to_dict()

        payload: dict = {
            "document_reference": document_reference.to_dict(),
            "parameters": parameters,
        }

        # An empty ``custom`` dict is deliberately left out of the payload.
        if custom:
            payload["custom"] = custom

        # The job is scoped to the default space when one is set, otherwise to
        # the default project.
        if self._client.default_space_id is not None:
            payload["space_id"] = self._client.default_space_id
        elif self._client.default_project_id is not None:
            payload["project_id"] = self._client.default_project_id

        response = self._client.httpx_client.post(
            url=self._client._href_definitions.get_text_classifications_href(),
            json=payload,
            params=self._client._params(skip_for_create=True),
            headers=self._client._get_headers(),
        )

        return self._handle_response(201, "run_job", response)

    def list_jobs(self, limit: int | None = None) -> pandas.DataFrame:
        """List text classification jobs.

        When ``limit`` is None, no ``limit`` query parameter is sent with the request.

        :param limit: limit number of fetched records, defaults to None
        :type limit: int, optional

        :return: text classification jobs information as a pandas DataFrame with the
            columns ``CLASSIFICATION_JOB_ID``, ``CREATED`` and ``STATUS``
        :rtype: pandas.DataFrame

        **Example:**

        .. code-block:: python

            text_classification.list_jobs()

        """
        # Imported here because at module level pandas is only imported
        # under TYPE_CHECKING.
        import pandas

        # Flattened JSON path -> column name of the returned table.
        columns_mapping = {
            "metadata.id": "CLASSIFICATION_JOB_ID",
            "metadata.created_at": "CREATED",
            "entity.results.status": "STATUS",
        }

        details = self.get_job_details(limit=limit)

        # ``reindex`` keeps only the mapped columns (in mapping order) and
        # creates any that are absent from the response.
        flattened = pandas.json_normalize(details["resources"])
        selected = flattened.reindex(columns=list(columns_mapping))
        return selected.rename(columns=columns_mapping)

    def get_results(self, classification_job_id: str) -> dict:
        """Get the text classification results.

        :param classification_job_id: ID of text classification job
        :type classification_job_id: str

        :return: text classification job results
        :rtype: dict

        **Example:**

        .. code-block:: python

            results = text_classification.get_results(classification_job_id="<CLASSIFICATION_JOB_ID>")

        """
        self._validate_type(classification_job_id, "classification_job_id", str, True)

        job_details = self.get_job_details(classification_job_id)
        return get_from_json(job_details, ["entity", "results"])

    def get_status(self, classification_job_id: str) -> str:
        """Get the text classification status.

        :param classification_job_id: ID of text classification job
        :type classification_job_id: str

        :return: text classification job status, possible values:
            [submitted, uploading, running, downloading, downloaded, completed, failed]
        :rtype: str

        **Example:**

        .. code-block:: python

            status = text_classification.get_status(classification_job_id="<CLASSIFICATION_JOB_ID>")

        """
        self._validate_type(classification_job_id, "classification_job_id", str, True)

        job_details = self.get_job_details(classification_job_id)
        return get_from_json(job_details, ["entity", "results", "status"])

    def get_job_details(
        self, classification_job_id: str | None = None, limit: int | None = None
    ) -> dict:
        """Return text classification job details.

        If `classification_job_id` is None, return the details of all text
        classification jobs. ``limit`` is only used in that case; it is ignored when
        a specific job is requested.

        :param classification_job_id: ID of the text classification job,
            defaults to None
        :type classification_job_id: str, optional

        :param limit: limit number of fetched records, must be between 1 and 200,
            defaults to None
        :type limit: int, optional

        :raises InvalidValue: raised when ``limit`` is outside of the range 1-200

        :return: details of the text classification job
        :rtype: dict

        **Example:**

        .. code-block:: python

            text_classification.get_job_details(classification_job_id="<CLASSIFICATION_JOB_ID>")

        """
        self._validate_type(classification_job_id, "classification_job_id", str, False)

        if classification_job_id is not None:
            # Single job lookup.
            response = self._client.httpx_client.get(
                url=self._client._href_definitions.get_text_classification_href(
                    classification_job_id
                ),
                params=self._client._params(skip_userfs=True),
                headers=self._client._get_headers(),
            )
            return self._handle_response(200, "get_job_details", response)

        # Listing all jobs: check the type first so that a non-integer limit
        # is reported by the shared validator instead of surfacing as a bare
        # TypeError from the range comparison below.
        self._validate_type(limit, "limit", int, False)

        if limit is not None and not 1 <= limit <= 200:
            raise InvalidValue(
                value_name="limit",
                reason=f"The given value {limit} is not in between 1 and 200",
            )

        params = self._client._params(skip_userfs=True)
        if limit is not None:
            params["limit"] = limit

        response = self._client.httpx_client.get(
            url=self._client._href_definitions.get_text_classifications_href(),
            params=params,
            headers=self._client._get_headers(),
        )
        return self._handle_response(200, "get_job_details", response)

    def delete_job(self, classification_job_id: str) -> Literal["SUCCESS"]:
        """Delete a text classification job.

        Sends a DELETE request for the job with the ``hard_delete`` query parameter
        set to True.

        :param classification_job_id: ID of text classification job
        :type classification_job_id: str

        :return: "SUCCESS" if the deletion succeeds
        :rtype: str

        **Example:**

        .. code-block:: python

            text_classification.delete_job(classification_job_id="<CLASSIFICATION_JOB_ID>")

        """
        self._validate_type(classification_job_id, "classification_job_id", str, True)

        # ``hard_delete`` is the only difference from the request sent
        # by ``cancel_job``.
        params = self._client._params(skip_userfs=True)
        params["hard_delete"] = True

        response = self._client.httpx_client.delete(
            url=self._client._href_definitions.get_text_classification_href(
                classification_job_id
            ),
            params=params,
            headers=self._client._get_headers(),
        )

        return self._handle_response(204, "delete_job", response)  # type: ignore[return-value]

    def cancel_job(self, classification_job_id: str) -> Literal["SUCCESS"]:
        """Cancel a text classification job.

        Sends a DELETE request for the job without the ``hard_delete``
        query parameter.

        :param classification_job_id: ID of text classification job
        :type classification_job_id: str

        :return: "SUCCESS" if the cancel succeeds
        :rtype: str

        **Example:**

        .. code-block:: python

            text_classification.cancel_job(classification_job_id="<CLASSIFICATION_JOB_ID>")

        """
        self._validate_type(classification_job_id, "classification_job_id", str, True)

        response = self._client.httpx_client.delete(
            url=self._client._href_definitions.get_text_classification_href(
                classification_job_id
            ),
            params=self._client._params(skip_userfs=True),
            headers=self._client._get_headers(),
        )

        return self._handle_response(204, "cancel_job", response)  # type: ignore[return-value]

    @classmethod
    def get_job_id(cls, classification_details: dict) -> str:
        """Get the unique ID of a stored classification request.

        :param classification_details: metadata of the stored classification
        :type classification_details: dict

        :return: unique ID of the stored classification request
        :rtype: str

        **Example:**

        .. code-block:: python

            classification_details = text_classification.run_job(...)
            classification_job_id = text_classification.get_job_id(classification_details)

        """
        cls._validate_type(classification_details, "classification_details", dict, True)

        # The ID is read from the ``metadata.id`` path of the details.
        job_id = cls._get_required_element_from_dict(
            classification_details, "classification_details", ["metadata", "id"]
        )
        return cast(str, job_id)