Source code for ibm_watsonx_ai.data_loaders.datasets.documents

#  -----------------------------------------------------------------------------------------
#  (C) Copyright IBM Corp. 2024-2025.
#  https://opensource.org/licenses/BSD-3-Clause
#  -----------------------------------------------------------------------------------------
from __future__ import annotations

__all__ = [
    "DocumentsIterableDataset",
]

import logging
import sys

from ibm_watsonx_ai.helpers import ContainerLocation

logger = logging.getLogger(__name__)

import os
from copy import copy

import pandas as pd
from typing import TYPE_CHECKING, Iterator, Any, Callable

from ibm_watsonx_ai.wml_client_error import WMLClientError
from ibm_watsonx_ai.data_loaders.text_loader import (
    TextLoader,
    _asynch_download,
)
from ibm_watsonx_ai.helpers.remote_document import RemoteDocument
from ibm_watsonx_ai.utils.autoai.enums import DocumentsSamplingTypes, SamplingTypes
from ibm_watsonx_ai.utils.autoai.errors import InvalidSizeLimit

if TYPE_CHECKING:
    from ibm_watsonx_ai.helpers.connections import DataConnection

# Note: torch is treated as an optional dependency. Use its IterableDataset as
# the base class when torch is importable; otherwise fall back to plain
# `object` (this fallback was added on a torch dependency removal request).
try:
    from torch.utils.data import IterableDataset

except ImportError:
    IterableDataset: type = object  # type: ignore[no-redef]
# --- end note

# Default for both `sample_size_limit` and `total_size_limit` of
# DocumentsIterableDataset.
DEFAULT_SAMPLE_SIZE_LIMIT = (
    1073741824  # 1 GB in bytes; checked later by DocumentsIterableDataset._set_size_limit
)
# Not referenced in the visible part of this module.
DEFAULT_SAMPLING_TYPE = SamplingTypes.FIRST_VALUES
# Default `sampling_type` of DocumentsIterableDataset.
DEFAULT_DOCUMENTS_SAMPLING_TYPE = DocumentsSamplingTypes.RANDOM


class DocumentsIterableDataset(IterableDataset):
    """
    This dataset is an Iterable stream of documents using an underneath Flight Service.
    It can download documents asynchronously and serve them to you from a generator.

    Supported types of documents:
        - **text/plain** (".txt" file extension) - plain structured text
        - **docx** (".docx" file extension) - standard Word style file
        - **pdf** (".pdf" file extension) - standard pdf document
        - **html** (".html" file extension) - saved HTML page
        - **markdown** (".md" file extension) - plain text formatted with markdown

    :param connections: list of connections to the documents
    :type connections: list[DataConnection]

    :param enable_sampling: if set to `True`, will enable sampling, default: True
    :type enable_sampling: bool

    :param sample_size_limit: upper limit for documents to be downloaded in bytes, default: 1 GB
    :type sample_size_limit: int

    :param sampling_type: a sampling strategy on how to read the data,
        check the `DocumentsSamplingTypes` enum class for more options
    :type sampling_type: str

    :param total_size_limit: upper limit for documents to be downloaded in Bytes, default: 1 GB,
        if more than one of: `total_size_limit`, `total_ndocs_limit` are set,
        then data are limited to the lower threshold.
    :type total_size_limit: int

    :param total_ndocs_limit: upper limit for documents to be downloaded in a number of documents,
        if more than one of: `total_size_limit`, `total_ndocs_limit` are set,
        then data are limited to the lower threshold.
    :type total_ndocs_limit: int, optional

    :param benchmark_dataset: dataset of benchmarking data with IDs in the
        `correct_answer_document_ids` column corresponding to the names of documents
        in the `connections` list
    :type benchmark_dataset: pd.DataFrame, optional

    :param error_callback: error callback function, to handle the exceptions from document loading,
        as arguments are passed document_id and exception
    :type error_callback: function (str, Exception) -> None, optional

    :param api_client: initialized APIClient object with set project or space ID.
        If the DataConnection object in list connections does not have a set API client,
        then the api_client object is used for reading data.
    :type api_client: APIClient, optional

    **Example: default sampling - read up to 1 GB of random documents**

    .. code-block:: python

        connections = [DataConnection(data_asset_id='5d99c11a-2060-4ef6-83d5-dc593c6455e2')]

        iterable_dataset = DocumentsIterableDataset(connections=connections,
                                                    enable_sampling=True,
                                                    sampling_type='random',
                                                    sample_size_limit=1073741824)  # 1 GB

    **Example: read all documents/no subsampling**

    .. code-block:: python

        connections = [DataConnection(data_asset_id='5d99c11a-2060-4ef6-83d5-dc593c6455e2')]

        iterable_dataset = DocumentsIterableDataset(connections=connections,
                                                    enable_sampling=False)

    **Example: context based sampling**

    .. code-block:: python

        connections = [DataConnection(data_asset_id='5d99c11a-2060-4ef6-83d5-dc593c6455e2')]

        iterable_dataset = DocumentsIterableDataset(connections=connections,
                                                    enable_sampling=True,
                                                    sampling_type='benchmark_driven',
                                                    sample_size_limit=1073741824,  # 1 GB
                                                    benchmark_dataset=pd.DataFrame(
                                                        data={
                                                            "question": [
                                                                "What foundation models are available in watsonx.ai ?"
                                                            ],
                                                            "correct_answers": [
                                                                [
                                                                    "The following models are available in watsonx.ai: ..."
                                                                ]
                                                            ],
                                                            "correct_answer_document_ids": ["sample_pdf_file.pdf"],
                                                        }))

    """

    def __init__(
        self,
        *,
        connections: list[DataConnection],
        enable_sampling: bool = True,
        sample_size_limit: int = DEFAULT_SAMPLE_SIZE_LIMIT,
        sampling_type: str = DEFAULT_DOCUMENTS_SAMPLING_TYPE,
        total_size_limit: int = DEFAULT_SAMPLE_SIZE_LIMIT,
        total_ndocs_limit: int | None = None,
        benchmark_dataset: pd.DataFrame | None = None,
        error_callback: Callable[[str, Exception], None] | None = None,
        **kwargs: Any,
    ) -> None:
        from ibm_watsonx_ai.helpers import S3Location, AssetLocation

        super().__init__()
        self.enable_sampling = enable_sampling
        self.sample_size_limit = sample_size_limit
        self.sampling_type = sampling_type
        # Reads `self.enable_sampling`, so it must run after that is assigned.
        self._set_size_limit(total_size_limit)
        self.total_ndocs_limit = total_ndocs_limit
        self.benchmark_dataset = benchmark_dataset
        self.error_callback = error_callback

        # expected values: "n_parallel", "sequential"
        self._download_strategy = kwargs.get("_download_strategy", "n_parallel")

        # A client passed here is only used for connections that have none set.
        api_client = kwargs.get("api_client", kwargs.get("_api_client"))
        if api_client is not None:
            for conn in connections:
                if conn._api_client is None:
                    conn.set_client(api_client)

        # Connections that do not need an API client may still carry `None`;
        # skip those instead of failing on `None.data_assets` below.
        api_clients = {
            conn._api_client
            for conn in connections
            if conn._api_client is not None
        }

        # Map data asset id -> file name (last segment of `resource_key`);
        # only fetched when at least one connection points to a data asset.
        data_asset_id_name_mapping = {}
        if any(isinstance(conn.location, AssetLocation) for conn in connections):
            for client in api_clients:
                for res in client.data_assets.get_details(get_all=True)["resources"]:
                    metadata = res["metadata"]
                    data_asset_id_name_mapping[metadata["asset_id"]] = metadata[
                        "resource_key"
                    ].split("/")[-1]

        def get_document_id(conn):
            """Return the document id (file name) for a single connection."""
            if not isinstance(conn.location, AssetLocation):
                return conn._get_filename()
            if conn.location.id not in data_asset_id_name_mapping:
                raise WMLClientError(
                    f"The asset with id {conn.location.id} could not be found."
                )
            return data_asset_id_name_mapping[conn.location.id]

        self.remote_documents = []
        for connection in connections:
            if isinstance(connection.location, (S3Location, ContainerLocation)):
                # One bucket/container connection is expanded into the
                # connections returned by `_get_connections_from_bucket()`.
                self.remote_documents.extend(
                    RemoteDocument(connection=c, document_id=get_document_id(c))
                    for c in connection._get_connections_from_bucket()
                )
            else:
                self.remote_documents.append(
                    RemoteDocument(
                        connection=connection,
                        document_id=get_document_id(connection),
                    )
                )

        # Document ids must be unique across all connections.
        unique_ids = {doc.document_id for doc in self.remote_documents}
        if len(unique_ids) < len(self.remote_documents):
            raise WMLClientError(
                "Not unique document file names passed in connections."
            )

    def _set_size_limit(self, size_limit: int) -> None:
        """If non-default value of total_size_limit was not passed,
        set Sample Size Limit based on T-Shirt size if code is run on training pod:
        For memory < 16 (T-Shirts: XS,S) default is 10MB,
        For memory < 32 & >= 16 (T-Shirts: M) default is 100MB,
        For memory = 32 (T-Shirt L) default is 0.7GB,
        For memory > 32 (T-Shirt XL) or runs outside pod default is 1GB.

        The thresholds above come from ``get_max_sample_size_limit``, which is
        defined outside this module and only consulted when the ``MEM``
        environment variable is set.

        :param size_limit: requested total size limit in bytes
        :type size_limit: int

        :raises InvalidSizeLimit: if sampling is enabled and a non-default
            `size_limit` exceeds the maximum allowed limit
        """
        self.total_size_limit: int | None

        from ibm_watsonx_ai.utils.autoai.connection import get_max_sample_size_limit

        max_tshirt_size_limit = (
            get_max_sample_size_limit() if os.getenv("MEM", False) else None
        )  # limit manual setting of sample size limit on autoai clusters #31527

        if not self.enable_sampling:
            # do not limit reading if sampling is disabled and the default was
            # left untouched, we want to read all data
            self.total_size_limit = (
                None if size_limit == DEFAULT_SAMPLE_SIZE_LIMIT else size_limit
            )
            return

        if not max_tshirt_size_limit:
            self.total_size_limit = size_limit
            return

        if (
            size_limit > max_tshirt_size_limit
            and size_limit != DEFAULT_SAMPLE_SIZE_LIMIT
        ):
            # An explicitly requested limit above the maximum is an error;
            # the default value is silently capped instead.
            raise InvalidSizeLimit(size_limit, max_tshirt_size_limit)
        self.total_size_limit = min(size_limit, max_tshirt_size_limit)

    @staticmethod
    def _docs_context_sampling(
        remote_documents: list[RemoteDocument],
        benchmark_document_ids: list[str],
    ) -> list[RemoteDocument]:
        """Order documents so that the benchmark documents come first.

        No document is dropped and no shuffling is done: documents whose id is
        in `benchmark_document_ids` are placed first, followed by the remaining
        ones, each group keeping its original relative order.

        :param remote_documents: documents to sample from
        :type remote_documents: list[RemoteDocument]

        :param benchmark_document_ids: IDs of documents from the benchmark dataset
        :type benchmark_document_ids: list[str]

        :return: list of sampled documents
        :rtype: list[RemoteDocument]
        """
        # Set lookup avoids rescanning the id list for every document.
        benchmark_ids = frozenset(benchmark_document_ids)

        benchmark_documents = []
        non_benchmark_documents = []
        for doc in remote_documents:
            if doc.document_id in benchmark_ids:
                benchmark_documents.append(doc)
            else:
                non_benchmark_documents.append(doc)

        return benchmark_documents + non_benchmark_documents

    @staticmethod
    def _docs_random_sampling(
        remote_documents: list[RemoteDocument],
    ) -> list[RemoteDocument]:
        """Return all `remote_documents` in random order.

        The input list is left unmodified.

        :param remote_documents: documents to sample from
        :type remote_documents: list[RemoteDocument]

        :return: list of sampled documents
        :rtype: list[RemoteDocument]
        """
        from random import shuffle

        sampling_order = list(range(len(remote_documents)))
        shuffle(sampling_order)

        return [remote_documents[i] for i in sampling_order]

    def _select_documents(self) -> list[RemoteDocument]:
        """Return the documents to download, ordered and limited per settings.

        :return: documents ordered by the sampling strategy (or a shallow copy
            of all documents if sampling is disabled), truncated to
            `total_ndocs_limit` when that is set
        :rtype: list[RemoteDocument]

        :raises ValueError: if the sampling type is unsupported, or it is
            benchmark driven and no `benchmark_dataset` was passed
        :raises WMLClientError: if the document ids cannot be collected from
            the benchmark dataset
        """
        if not self.enable_sampling:
            sampled_docs = copy(self.remote_documents)
        elif self.sampling_type == DocumentsSamplingTypes.RANDOM:
            sampled_docs = self._docs_random_sampling(self.remote_documents)
        elif self.sampling_type == DocumentsSamplingTypes.BENCHMARK_DRIVEN:
            if self.benchmark_dataset is None:
                raise ValueError(
                    "`benchmark_dataset` is mandatory for sample_type: DocumentsSamplingTypes.BENCHMARK_DRIVEN."
                )
            try:
                # Each cell is iterated, so it is expected to hold a
                # collection of ids; the ids are flattened and de-duplicated.
                benchmark_documents_ids = list(
                    {
                        doc_id
                        for ids in self.benchmark_dataset[
                            "correct_answer_document_ids"
                        ].values
                        for doc_id in ids
                    }
                )
            except TypeError as e:
                raise WMLClientError(
                    "Unable to collect `correct_answer_document_ids` from the benchmark dataset, "
                    f"due to invalid schema provided. Error: {e}"
                ) from e
            sampled_docs = self._docs_context_sampling(
                self.remote_documents, benchmark_documents_ids
            )
        else:
            raise ValueError(
                f"Unsupported documents sampling type: {self.sampling_type}"
            )

        if (
            self.total_ndocs_limit is not None
            and len(sampled_docs) > self.total_ndocs_limit
        ):
            limited_docs = sampled_docs[: self.total_ndocs_limit]
            logger.info(
                "Documents sampled with total_ndocs_limit param, "
                "%s docs chosen from %s possible.",
                len(limited_docs),
                len(sampled_docs),
            )
            sampled_docs = limited_docs

        return sampled_docs

    def _iter_parallel(
        self, sampled_docs: list[RemoteDocument], size_limit: int | None
    ) -> Iterator:
        """Download documents with a thread pool and yield them in input order.

        Failed downloads are logged and passed to `error_callback` when one is
        set; they do not stop the iteration.

        :param sampled_docs: documents to download
        :type sampled_docs: list[RemoteDocument]

        :param size_limit: stop once the accumulated size exceeds this value,
            no limit if `None`
        :type size_limit: int, optional
        """
        if not sampled_docs:
            # A pool cannot be created with zero workers; nothing to yield.
            return

        import multiprocessing.dummy as mp
        from queue import Empty

        thread_no = min(5, len(sampled_docs))

        # One shared input queue and one output queue per document, so results
        # can be consumed in the original document order.
        # NOTE(review): presumably `_asynch_download` pops `(index, doc)` from
        # the input queue and puts a loaded document or an exception into
        # `qs_output[index]` - confirm against text_loader.
        q_input = mp.Queue()
        qs_output = [mp.Queue() for _ in sampled_docs]
        args = [(q_input, qs_output)] * thread_no
        for i, doc in enumerate(sampled_docs):
            q_input.put((i, doc))

        with mp.Pool(thread_no) as pool:
            pool.map_async(_asynch_download, args)

            res_size = 0
            for remote_doc, q_output in zip(sampled_docs, qs_output):
                try:
                    result = q_output.get(timeout=10 * 60)
                except Empty as e:
                    # Timed out waiting for this document; treat as a failure.
                    result = e

                if isinstance(result, Exception):
                    doc_id = remote_doc.document_id
                    logger.error(f"Failed to download the file: `{doc_id}`")
                    if not isinstance(result, Empty):
                        logger.error(result)
                    # NOTE(review): unlike the sequential strategy, a failure
                    # without a callback is only logged here, not raised.
                    if self.error_callback:
                        self.error_callback(doc_id, result)
                    continue

                # NOTE(review): size is the in-memory size of the string here,
                # while the sequential strategy counts UTF-8 encoded bytes.
                res_size += sys.getsizeof(result.page_content)
                if size_limit is not None and res_size > size_limit:
                    return
                yield result

    def _iter_sequential(
        self, sampled_docs: list[RemoteDocument], size_limit: int | None
    ) -> Iterator:
        """Download and load documents one by one, yielding each loaded one.

        A failure is passed to `error_callback` when one is set and iteration
        continues; otherwise the exception is re-raised.

        :param sampled_docs: documents to download
        :type sampled_docs: list[RemoteDocument]

        :param size_limit: stop once the accumulated UTF-8 size of the loaded
            content exceeds this value, no limit if `None`
        :type size_limit: int, optional
        """
        res_size = 0
        for doc in sampled_docs:
            # Only downloading/loading is guarded, so exceptions thrown into
            # the generator by the consumer are not taken for load failures.
            try:
                doc.download()
                loaded_doc = TextLoader(doc).load()
                res_size += len(loaded_doc.page_content.encode("utf-8"))
            except Exception as e:
                if self.error_callback:
                    self.error_callback(doc.document_id, e)
                    continue
                raise

            if size_limit is not None and res_size > size_limit:
                return
            yield loaded_doc

    def __iter__(self) -> Iterator:
        """Iterate over documents.

        Documents are selected per the sampling settings, then downloaded with
        the configured strategy. Iteration stops once the accumulated size of
        the documents exceeds the size limit; the document that crosses the
        limit is not yielded.
        """
        # `sample_size_limit` applies only when sampling is enabled and it is
        # set; otherwise `total_size_limit` is used.
        size_limit = (
            self.sample_size_limit
            if self.sample_size_limit is not None and self.enable_sampling
            else self.total_size_limit
        )

        sampled_docs = self._select_documents()

        if self._download_strategy == "n_parallel":
            # downloading documents entirely in parallel
            yield from self._iter_parallel(sampled_docs, size_limit)
        else:
            # "sequential" - simple sequential downloading
            yield from self._iter_sequential(sampled_docs, size_limit)