Source code for ibm_watsonx_ai.helpers.connections.connections

__all__ = [
    "DataConnection",
    "S3Connection",
    "ConnectionAsset",
    "S3Location",
    "FSLocation",
    "AssetLocation",
    "CloudAssetLocation",
    "DeploymentOutputAssetLocation",
    "NFSConnection",
    "NFSLocation",
    "ConnectionAssetLocation",
    "DatabaseLocation",
    "ContainerLocation",
    "GithubLocation",
]

#  -----------------------------------------------------------------------------------------
#  (C) Copyright IBM Corp. 2023-2025.
#  https://opensource.org/licenses/BSD-3-Clause
#  -----------------------------------------------------------------------------------------

import copy
import io
import os
import sys
import uuid
from copy import deepcopy
from typing import Union, Tuple, List, TYPE_CHECKING, Optional, Any
from warnings import warn
import re

import numpy as np
from ibm_boto3 import resource
from ibm_botocore.client import ClientError
from pandas import DataFrame

import ibm_watsonx_ai._wrappers.requests as requests
from ibm_watsonx_ai.data_loaders.datasets.experiment import (
    DEFAULT_SAMPLING_TYPE,
    DEFAULT_SAMPLE_SIZE_LIMIT,
)
from ibm_watsonx_ai.utils.autoai.enums import DataConnectionTypes
from ibm_watsonx_ai.utils.autoai.errors import (
    ContainerTypeNotSupported,
    ConnectionAssetNotSupported,
    InvalidLocationInDataConnection,
)
from ibm_watsonx_ai.utils.autoai.errors import (
    MissingAutoPipelinesParameters,
    MissingCOSStudioConnection,
    MissingProjectLib,
    InvalidCOSCredentials,
    InvalidIdType,
    NotExistingCOSResource,
    CannotReadSavedRemoteDataBeforeFit,
    NoAutomatedHoldoutSplit,
    DirectoryHasNoFilename,
    CannotGetFilename,
)
from ibm_watsonx_ai.utils.autoai.utils import (
    all_logging_disabled,
    try_import_autoai_libs,
    try_import_autoai_ts_libs,
)
from ibm_watsonx_ai.wml_client_error import (
    MissingValue,
    ApiRequestFailure,
    WMLClientError,
)
from .base_connection import BaseConnection
from .base_data_connection import BaseDataConnection
from .base_location import BaseLocation

if TYPE_CHECKING:
    from ibm_watsonx_ai.workspace import WorkSpace


[docs] class DataConnection(BaseDataConnection): """You need a Data Storage Connection class for Service training metadata (input data). :param connection: connection parameters of a specific type :type connection: NFSConnection or ConnectionAsset, optional :param location: required location parameters of a specific type :type location: Union[S3Location, FSLocation, AssetLocation] :param data_asset_id: data asset ID, if the DataConnection should point to a data asset :type data_asset_id: str, optional """ def __init__( self, location: Union[ "S3Location", "FSLocation", "AssetLocation", "CloudAssetLocation", "NFSLocation", "DeploymentOutputAssetLocation", "ConnectionAssetLocation", "DatabaseLocation", "ContainerLocation", "GithubLocation", ] = None, connection: Optional[ Union["S3Connection", "NFSConnection", "ConnectionAsset"] ] = None, data_asset_id: str = None, connection_asset_id: str = None, **kwargs, ): if data_asset_id is None and location is None: if connection_asset_id is not None: connection = ConnectionAsset(connection_id=connection_asset_id) else: raise MissingValue( "location or data_asset_id", reason="Provide 'location' or 'data_asset_id'.", ) elif data_asset_id is not None and location is not None: raise ValueError( "'data_asset_id' and 'location' cannot be specified together." ) elif data_asset_id is not None: location = AssetLocation(asset_id=data_asset_id) if kwargs.get("model_location") is not None: location._model_location = kwargs["model_location"] if kwargs.get("training_status") is not None: location._training_status = kwargs["training_status"] elif connection_asset_id is not None and isinstance( location, (S3Location, DatabaseLocation, NFSLocation) ): if not isinstance(connection_asset_id, str): raise InvalidIdType(type(connection_asset_id)) connection = ConnectionAsset(connection_id=connection_asset_id) elif ( connection_asset_id is None and connection is None and isinstance(location, (S3Location, DatabaseLocation, NFSLocation)) ): raise ValueError( "'connection_asset_id' and 'connection' cannot be empty together when 'location' is " "[S3Location, DatabaseLocation, NFSLocation]." ) super().__init__() self.connection = connection self.location = location # TODO: remove S3 implementation if isinstance(connection, S3Connection): self.type = DataConnectionTypes.S3 elif isinstance(connection, ConnectionAsset): self.type = DataConnectionTypes.CA # note: We expect a `file_name` keyword for CA pointing to COS or NFS. if isinstance(self.location, (S3Location, NFSLocation)): self.location.file_name = self.location.path del self.location.path if isinstance(self.location, NFSLocation): del self.location.id # --- end note elif isinstance(location, FSLocation): self.type = DataConnectionTypes.FS elif isinstance(location, ContainerLocation): self.type = DataConnectionTypes.CN elif isinstance( location, (AssetLocation, CloudAssetLocation, DeploymentOutputAssetLocation) ): self.type = DataConnectionTypes.DS elif isinstance(location, GithubLocation): self.type = DataConnectionTypes.GH self.auto_pipeline_params = ( {} ) # note: needed parameters for recreation of autoai holdout split self._api_client = None self.__api_client = None # only for getter/setter for AssetLocation href self._run_id = None self._test_data = False self._user_holdout_exists = False # note: client as property and setter for dynamic href creation for AssetLocation @property def _wml_client(self): # note: backward compatibility warn( ( "`_wml_client` is deprecated and will be removed in future. " "Instead, please use `_api_client`." ), category=DeprecationWarning, ) # --- end note return self.__api_client @_wml_client.setter def _wml_client(self, var): # note: backward compatibility warn( ( "`_wml_client` is deprecated and will be removed in future. " "Instead, please use `_api_client`." ), category=DeprecationWarning, ) # --- end note self._api_client = var @property def _api_client(self): return self.__api_client @_api_client.setter def _api_client(self, var): self.__api_client = var if isinstance(self.location, (AssetLocation)): self.location.api_client = self.__api_client if getattr(var, "project_type", None) == "local_git_storage": self.location.userfs = True
[docs] def set_client(self, api_client=None, **kwargs): """To enable write/read operations with a connection to a service, set an initialized service client in the connection. :param api_client: API client to connect to a service :type api_client: APIClient **Example:** .. code-block:: python DataConnection.set_client(api_client=api_client) """ # note: backward compatibility if (wml_client := kwargs.get("wml_client")) is None and api_client is None: raise WMLClientError("No `api_client` provided") elif wml_client is not None: if api_client is None: api_client = wml_client warn( ( "`wml_client` is deprecated and will be removed in future. " "Instead, please use `api_client`." ), category=DeprecationWarning, ) # --- end note self._api_client = api_client
# --- end note
[docs] @classmethod def from_studio(cls, path: str) -> List["DataConnection"]: """Create DataConnections from the credentials stored (connected) in Watson Studio. Only for COS. :param path: path in the COS bucket to the training dataset :type path: str :return: list with DataConnection objects :rtype: list[DataConnection] **Example:** .. code-block:: python data_connections = DataConnection.from_studio(path='iris_dataset.csv') """ try: from project_lib import Project except ModuleNotFoundError: raise MissingProjectLib("Missing project_lib package.") else: data_connections = [] for name, value in globals().items(): if isinstance(value, Project): connections = value.get_connections() if connections: for connection in connections: asset_id = connection["asset_id"] connection_details = value.get_connection(asset_id) if ( "url" in connection_details and "access_key" in connection_details and "secret_key" in connection_details and "bucket" in connection_details ): data_connections.append( cls( connection=ConnectionAsset( connection_id=connection_details["id"] ), location=ConnectionAssetLocation( bucket=connection_details["bucket"], file_name=path, ), ) ) if data_connections: return data_connections else: raise MissingCOSStudioConnection( "There is no any COS Studio connection. " "Please create a COS connection from the UI and insert " "the cell with project API connection (Insert project token)" )
def _get_file_paths_from_bucket(self) -> list[str]: """Returns file paths (keys) of objects that are stored at a bucket location. Returns all file paths under the ``self.location.get_location()`` prefix. First checks if the prefix exists as a bucket "directory" - if not, treats prefix as the file name. :return: list of file names (keys) of objects in a bucket location :rtype: list[str] """ prefix = self.location.get_location().strip("/") from ibm_watsonx_ai.data_loaders.datasets.tabular import TabularIterableDataset from .flight_service import FlightConnection try: flight_parameters = ( TabularIterableDataset._update_params_with_connection_properties( connection=self.to_dict(), flight_parameters={}, api_client=self._api_client, ) ) conn = FlightConnection( headers=self._api_client._get_headers(), sampling_type=None, label=None, learning_type=None, connection_id=( self.connection.id if hasattr(self.connection, "id") else None ), flight_parameters=flight_parameters, params={}, project_id=self._api_client.default_project_id, space_id=self._api_client.default_space_id, _api_client=self._api_client, ) paths = [ r["path"].replace(f"/{self.location.bucket}/", "", 1) for r in conn.discovery(f"/{self.location.bucket}/{prefix}")["assets"] if r["type"] == "file" ] except Exception as e: warn(f"Flight discovery didn't work, error: {e}") return self._get_file_paths_from_bucket_fallback() else: if paths: return paths else: raise InvalidLocationInDataConnection(self.location.get_location()) def _get_file_paths_from_bucket_fallback(self) -> list[str]: """Returns file paths (keys) of objects that are stored at a bucket location. Returns all file paths under the ``self.location.get_location()`` prefix. First checks if the prefix exists as a bucket "directory" - if not, treats prefix as the file name. :return: list of file names (keys) of objects in a bucket location :rtype: list[str] """ def get_keys_with_prefix(bucket_objects, prefix): regex = re.compile(prefix + "[^/]+$") return [obj["Key"] for obj in bucket_objects if re.match(regex, obj["Key"])] self._check_if_connection_asset_is_s3() cos_resource_client = self._init_cos_client() prefix = self.location.get_location().strip("/") + "/" bucket_objects = [] marker = None while True: params = {"Bucket": self.location.bucket, "Prefix": prefix} if marker is not None: params["Marker"] = marker res = cos_resource_client.meta.client.list_objects(**params) if "Contents" not in res: raise InvalidLocationInDataConnection(self.location.get_location()) marker = res.get("NextMarker") bucket_objects.extend(res["Contents"]) if marker is None: break return get_keys_with_prefix(bucket_objects, prefix) def _get_connections_from_bucket(self) -> list["DataConnection"]: """Return connections for every object in a bucket location. :raises WMLClientError: If location is not ``S3Location`` :return: list of connections to objects in a bucket location :rtype: list[DataConnection] """ if not isinstance(self.location, (S3Location, ContainerLocation)): raise WMLClientError( error_msg=f"Can't create separate connections from this DataConnection.", reason="This DataConnection's location is not pointing to a S3 bucket.", ) file_extension = os.path.splitext(self.location.get_location())[1] if file_extension: return [self] if isinstance(self.location, ContainerLocation): self._check_if_connection_asset_is_s3() file_paths = self._get_file_paths_from_bucket() new_data_connections = [] for file_path in file_paths: if isinstance(self.location, ContainerLocation): new_data_conn = DataConnection( location=ContainerLocation(path=file_path), ) else: new_data_conn = DataConnection( connection=self.connection, location=S3Location(bucket=self.location.bucket, path=file_path), ) if self._api_client: new_data_conn.set_client(self._api_client) new_data_connections.append(new_data_conn) return new_data_connections def _subdivide_connection(self): if type(self.id) is str or not self.id: return [self] else: def cpy(new_id): child = copy.copy(self) child.id = new_id return child return [cpy(id) for id in self.id] def _to_dict(self) -> dict: """Convert a DataConnection object to a dictionary representation. :return: DataConnection dictionary representation :rtype: dict """ if self.id and type(self.id) is list: raise InvalidIdType(list) _dict = {"type": self.type} # note: id of DataConnection if self.id is not None: _dict["id"] = self.id # --- end note if self.connection is not None: _dict["connection"] = deepcopy(self.connection.to_dict()) try: _dict["location"] = deepcopy(self.location.to_dict()) except AttributeError: _dict["location"] = {} # note: convert userfs to string - training service requires it as string if hasattr(self.location, "userfs"): _dict["location"]["userfs"] = str( getattr(self.location, "userfs", False) ).lower() # end note return _dict
[docs] def to_dict(self) -> dict: """Convert a DataConnection object to a dictionary representation. :return: DataConnection dictionary representation :rtype: dict """ return self._to_dict()
def __repr__(self): return str(self._to_dict()) def __str__(self): return str(self._to_dict()) @classmethod def _from_dict(cls, _dict: dict) -> "DataConnection": """Create a DataConnection object from a dictionary. :param _dict: dictionary data structure with information about the data connection reference :type _dict: dict :return: DataConnection object :rtype: DataConnection """ if _dict["type"] == DataConnectionTypes.FS: data_connection: "DataConnection" = cls( location=FSLocation._set_path(path=_dict["location"]["path"]) ) elif _dict["type"] == DataConnectionTypes.CA: if _dict["location"].get("file_name") is not None and _dict["location"].get( "bucket" ): data_connection: "DataConnection" = cls( connection_asset_id=_dict["connection"]["id"], location=S3Location( bucket=_dict["location"]["bucket"], path=_dict["location"]["file_name"], ), ) elif _dict["location"].get("path") is not None and _dict["location"].get( "bucket" ): data_connection: "DataConnection" = cls( connection_asset_id=_dict["connection"]["id"], location=S3Location( bucket=_dict["location"]["bucket"], path=_dict["location"]["path"], ), ) elif _dict["location"].get("schema_name") and _dict["location"].get( "table_name" ): data_connection: "DataConnection" = cls( connection_asset_id=_dict["connection"]["id"], location=DatabaseLocation( schema_name=_dict["location"]["schema_name"], table_name=_dict["location"]["table_name"], catalog_name=_dict["location"].get("catalog_name"), ), ) else: if "asset_id" in _dict["connection"]: data_connection: "DataConnection" = cls( connection=NFSConnection( asset_id=_dict["connection"]["asset_id"] ), location=NFSLocation(path=_dict["location"]["path"]), ) else: if _dict["location"].get("file_name") is not None: data_connection: "DataConnection" = cls( connection_asset_id=_dict["connection"]["id"], location=NFSLocation(path=_dict["location"]["file_name"]), ) elif _dict["location"].get("path") is not None: data_connection: DataConnection = cls( connection_asset_id=_dict["connection"]["id"], location=NFSLocation(path=_dict["location"]["path"]), ) else: data_connection: DataConnection = cls( connection_asset_id=_dict["connection"]["id"] ) elif _dict["type"] == DataConnectionTypes.CN: data_connection: "DataConnection" = cls( location=ContainerLocation(path=_dict["location"]["path"]) ) else: data_connection: "DataConnection" = cls( location=AssetLocation._set_path(href=_dict["location"]["href"]) ) if _dict.get("id"): data_connection.id = _dict["id"] if _dict["location"].get("userfs"): if str(_dict["location"].get("userfs", "false")).lower() in ["true", "1"]: data_connection.location.userfs = True else: data_connection.location.userfs = False return data_connection
[docs] @classmethod def from_dict(cls, connection_data: dict) -> "DataConnection": """Create a DataConnection object from a dictionary. :param connection_data: dictionary data structure with information about the data connection reference :type connection_data: dict :return: DataConnection object :rtype: DataConnection """ return DataConnection._from_dict(connection_data)
def _recreate_holdout( self, data: "DataFrame", with_holdout_split: bool = True ) -> Union[ Tuple["DataFrame", "DataFrame"], Tuple["DataFrame", "DataFrame", "DataFrame", "DataFrame"], ]: """This method tries to recreate holdout data.""" if self.auto_pipeline_params.get("prediction_columns") is not None: # timeseries try_import_autoai_ts_libs() from autoai_ts_libs.utils.holdout_utils import make_holdout_split # Note: When lookback window is auto detected there is need to get the detected value from training details if ( self.auto_pipeline_params.get("lookback_window") == -1 or self.auto_pipeline_params.get("lookback_window") is None ): ts_metrics = self._api_client.training.get_details( self.auto_pipeline_params.get("run_id"), _internal=True )["entity"]["status"]["metrics"] final_ts_state_name = "after_final_pipelines_generation" for metric in ts_metrics: if ( metric["context"]["intermediate_model"]["process"] == final_ts_state_name ): self.auto_pipeline_params["lookback_window"] = metric[ "context" ]["timeseries"]["lookback_window"] break # Note: imputation is not supported X_train, X_holdout, y_train, y_holdout, _, _, _, _ = make_holdout_split( dataset=data, target_columns=self.auto_pipeline_params.get("prediction_columns"), learning_type="forecasting", test_size=self.auto_pipeline_params.get("holdout_size"), lookback_window=self.auto_pipeline_params.get("lookback_window"), feature_columns=self.auto_pipeline_params.get("feature_columns"), timestamp_column=self.auto_pipeline_params.get("timestamp_column_name"), # n_jobs=None, # tshirt_size=None, return_only_holdout=False, ) X_columns = ( self.auto_pipeline_params.get("feature_columns") if self.auto_pipeline_params.get("feature_columns") else self.auto_pipeline_params["prediction_columns"] ) X_train = DataFrame(X_train, columns=X_columns) X_holdout = DataFrame(X_holdout, columns=X_columns) y_train = DataFrame( y_train, columns=self.auto_pipeline_params["prediction_columns"] ) y_holdout = DataFrame( y_holdout, columns=self.auto_pipeline_params["prediction_columns"] ) return X_train, X_holdout, y_train, y_holdout elif self.auto_pipeline_params.get("feature_columns") is not None: # timeseries anomaly detection try_import_autoai_ts_libs() from autoai_ts_libs.utils.holdout_utils import make_holdout_split from autoai_ts_libs.utils.constants import ( LEARNING_TYPE_TIMESERIES_ANOMALY_PREDICTION, ) # Note: imputation is not supported X_train, X_holdout, y_train, y_holdout, _, _, _, _ = make_holdout_split( dataset=data, learning_type=LEARNING_TYPE_TIMESERIES_ANOMALY_PREDICTION, test_size=self.auto_pipeline_params.get("holdout_size"), # lookback_window=self.auto_pipeline_params.get('lookback_window'), feature_columns=self.auto_pipeline_params.get("feature_columns"), timestamp_column=self.auto_pipeline_params.get("timestamp_column_name"), # n_jobs=None, # tshirt_size=None, return_only_holdout=False, ) X_columns = self.auto_pipeline_params["feature_columns"] y_column = ["anomaly_label"] X_train = DataFrame(X_train, columns=X_columns) X_holdout = DataFrame(X_holdout, columns=X_columns) y_train = DataFrame(y_train, columns=y_column) y_holdout = DataFrame(y_holdout, columns=y_column) return X_train, X_holdout, y_train, y_holdout else: if sys.version_info >= (3, 10): try_import_autoai_libs(minimum_version="1.14.0") else: try_import_autoai_libs(minimum_version="1.12.14") from autoai_libs.utils.holdout_utils import ( make_holdout_split, ) from autoai_libs.utils.sampling_utils import numpy_sample_rows data.replace([np.inf, -np.inf], np.nan, inplace=True) data.drop_duplicates(inplace=True) data.dropna( subset=[self.auto_pipeline_params["prediction_column"]], inplace=True ) dfy = data[self.auto_pipeline_params["prediction_column"]] data.drop( columns=[self.auto_pipeline_params["prediction_column"]], inplace=True ) y_column = [self.auto_pipeline_params["prediction_column"]] X_columns = data.columns if self._test_data or not with_holdout_split: return data, dfy else: ############################ # REMOVE MISSING ROWS # from autoai_libs.utils.holdout_utils import ( numpy_remove_missing_target_rows, ) # Remove (and save) the rows of X and y for which the target variable has missing values data, dfy, _, _, _, _ = numpy_remove_missing_target_rows(y=dfy, X=data) # End of REMOVE MISSING ROWS # ################################### ################# # SAMPLING # # Get a sample of the rows if requested and applicable # (check for sampling is performed inside this function) try: data, dfy, _ = numpy_sample_rows( X=data, y=dfy, train_sample_rows_test_size=self.auto_pipeline_params[ "train_sample_rows_test_size" ], learning_type=self.auto_pipeline_params["prediction_type"], return_sampled_indices=True, ) # Note: we have a silent error here (the old core behaviour) # sampling is not performed as 'train_sample_rows_test_size' is bigger than data rows count # TODO: can we throw an error instead? except ValueError as e: if "between" in str(e): pass else: raise e # End of SAMPLING # ######################## # Perform holdout split try: X_train, X_holdout, y_train, y_holdout, _, _ = make_holdout_split( x=data, y=dfy, learning_type=self.auto_pipeline_params["prediction_type"], fairness_info=self.auto_pipeline_params.get( "fairness_info", None ), test_size=( self.auto_pipeline_params.get("holdout_size") if self.auto_pipeline_params.get("holdout_size") is not None else 0.1 ), return_only_holdout=False, time_ordered_data=self.auto_pipeline_params.get( "time_ordered_data" ), ) except (TypeError, KeyError): if self.auto_pipeline_params.get("time_ordered_data"): warn( "Outdated autoai_libs - time_ordered_data parameter is not supported. Please update autoai_libs to version >=1.16.2" ) X_train, X_holdout, y_train, y_holdout, _, _ = make_holdout_split( x=data, y=dfy, learning_type=self.auto_pipeline_params["prediction_type"], fairness_info=self.auto_pipeline_params.get( "fairness_info", None ), test_size=( self.auto_pipeline_params.get("holdout_size") if self.auto_pipeline_params.get("holdout_size") is not None else 0.1 ), return_only_holdout=False, ) X_train = DataFrame(X_train, columns=X_columns) X_holdout = DataFrame(X_holdout, columns=X_columns) y_train = DataFrame(y_train, columns=y_column) y_holdout = DataFrame(y_holdout, columns=y_column) return X_train, X_holdout, y_train, y_holdout
[docs] def read( self, with_holdout_split: bool = False, csv_separator: str = ",", excel_sheet: str | int | None = None, encoding: str = "utf-8", raw: bool = False, binary: bool = False, read_to_file: str | None = None, number_of_batch_rows: int | None = None, sampling_type: str | None = None, sample_size_limit: int | None = None, sample_rows_limit: int | None = None, sample_percentage_limit: float | None = None, **kwargs: Any, ) -> "DataFrame" | Tuple["DataFrame", "DataFrame"] | bytes: """Download a dataset that is stored in a remote data storage. Returns batch up to 1 GB. :param with_holdout_split: if `True`, data will be split to train and holdout dataset as it was by AutoAI :type with_holdout_split: bool, optional :param csv_separator: separator/delimiter for the CSV file :type csv_separator: str, optional :param excel_sheet: excel file sheet name to use, use only when the xlsx file is an input, support for the number of the sheet is deprecated :type excel_sheet: str, optional :param encoding: encoding type of the CSV file :type encoding: str, optional :param raw: if `False`, simple data is preprocessed (the same as in the backend), if `True`, data is not preprocessed :type raw: bool, optional :param binary: indicates to retrieve data in binary mode, the result will be a python binary type variable :type binary: bool, optional :param read_to_file: stream read data to a file under the path specified as the value of this parameter, use this parameter to prevent keeping data in-memory :type read_to_file: str, optional :param number_of_batch_rows: number of rows to read in each batch when reading from the flight connection :type number_of_batch_rows: int, optional :param sampling_type: a sampling strategy on how to read the data :type sampling_type: str, optional :param sample_size_limit: upper limit for the overall data to be downloaded in bytes, default: 1 GB :type sample_size_limit: int, optional :param sample_rows_limit: upper limit for the overall data to be downloaded in a number of rows :type sample_rows_limit: int, optional :param sample_percentage_limit: upper limit for the overall data to be downloaded in the percent of all dataset, this parameter is ignored, when `sampling_type` parameter is set to `first_n_records`, must be a float number between 0 and 1 :type sample_percentage_limit: float, optional .. note:: If more than one of: `sample_size_limit`, `sample_rows_limit`, `sample_percentage_limit` are set, then downloaded data is limited to the lowest threshold. :return: one of the following: - pandas.DataFrame that contains dataset from remote data storage : Xy_train - Tuple[pandas.DataFrame, pandas.DataFrame, pandas.DataFrame, pandas.DataFrame] : X_train, X_holdout, y_train, y_holdout - Tuple[pandas.DataFrame, pandas.DataFrame] : X_test, y_test that contains training data and holdout data from remote storage - bytes object, auto holdout split from backend (only train data provided) **Examples** .. code-block:: python train_data_connections = optimizer.get_data_connections() data = train_data_connections[0].read() # all train data # or X_train, X_holdout, y_train, y_holdout = train_data_connections[0].read(with_holdout_split=True) # train and holdout data Your train and test data: .. code-block:: python optimizer.fit(training_data_reference=[DataConnection], training_results_reference=DataConnection, test_data_reference=DataConnection) test_data_connection = optimizer.get_test_data_connections() X_test, y_test = test_data_connection.read() # only holdout data # and train_data_connections = optimizer.get_data_connections() data = train_connections[0].read() # only train data """ # enables flight automatically for CP4D 4.0.x, 4.5.x try: use_flight = kwargs.get( "use_flight", bool( self._api_client is not None or "USER_ACCESS_TOKEN" in os.environ or "RUNTIME_ENV_ACCESS_TOKEN_FILE" in os.environ ), ) except: use_flight = False return_data_as_iterator = kwargs.get("return_data_as_iterator", False) sampling_type = ( sampling_type if sampling_type is not None else DEFAULT_SAMPLING_TYPE ) enable_sampling = kwargs.get("enable_sampling", True) total_size_limit = ( sample_size_limit if sample_size_limit is not None else kwargs.get("total_size_limit", DEFAULT_SAMPLE_SIZE_LIMIT) ) total_nrows_limit = sample_rows_limit total_percentage_limit = ( sample_percentage_limit if sample_percentage_limit is not None else 1.0 ) # Deprecation of excel_sheet as number: if isinstance(excel_sheet, int): warn( message="Support for excel sheet as number of the sheet (int) is deprecated! Please set excel sheet with name of the sheet." ) flight_parameters = kwargs.get("flight_parameters", {}) impersonate_header = kwargs.get("impersonate_header", None) if ( with_holdout_split and self._user_holdout_exists ): # when this connection is training one raise NoAutomatedHoldoutSplit( reason="Experiment was run based on user defined holdout dataset." ) # note: experiment metadata is used only in autogen notebooks experiment_metadata = kwargs.get("experiment_metadata") # note: process subsampling stats flag _return_subsampling_stats = kwargs.get("_return_subsampling_stats", False) if experiment_metadata is not None: self.auto_pipeline_params["train_sample_rows_test_size"] = ( experiment_metadata.get("train_sample_rows_test_size") ) self.auto_pipeline_params["prediction_column"] = experiment_metadata.get( "prediction_column" ) self.auto_pipeline_params["prediction_columns"] = experiment_metadata.get( "prediction_columns" ) self.auto_pipeline_params["holdout_size"] = experiment_metadata.get( "holdout_size" ) self.auto_pipeline_params["prediction_type"] = experiment_metadata[ "prediction_type" ] self.auto_pipeline_params["fairness_info"] = experiment_metadata.get( "fairness_info" ) self.auto_pipeline_params["lookback_window"] = experiment_metadata.get( "lookback_window" ) self.auto_pipeline_params["timestamp_column_name"] = ( experiment_metadata.get("timestamp_column_name") ) self.auto_pipeline_params["feature_columns"] = experiment_metadata.get( "feature_columns" ) self.auto_pipeline_params["time_ordered_data"] = experiment_metadata.get( "time_ordered_data" ) # note: check for cloud if "training_result_reference" in experiment_metadata: if isinstance( experiment_metadata["training_result_reference"].location, (S3Location, AssetLocation), ): run_id = experiment_metadata[ "training_result_reference" ].location._training_status.split("/")[-2] # WMLS else: run_id = experiment_metadata[ "training_result_reference" ].location.path.split("/")[-3] self.auto_pipeline_params["run_id"] = run_id if self._test_data: csv_separator = experiment_metadata.get( "test_data_csv_separator", csv_separator ) excel_sheet = experiment_metadata.get( "test_data_excel_sheet", excel_sheet ) encoding = experiment_metadata.get("test_data_encoding", encoding) else: csv_separator = experiment_metadata.get("csv_separator", csv_separator) excel_sheet = experiment_metadata.get("excel_sheet", excel_sheet) encoding = experiment_metadata.get("encoding", encoding) if self.type == DataConnectionTypes.DS or self.type == DataConnectionTypes.CA: if self._api_client is None: try: from project_lib import Project except ModuleNotFoundError: raise ConnectionError( "This functionality can be run only on Watson Studio or with api_client passed to connection. " "Please initialize API client using `DataConnection.set_client(api_client=api_client)` function " "to be able to use this functionality." ) if ( with_holdout_split or self._test_data ) and not self.auto_pipeline_params.get("prediction_type", False): raise MissingAutoPipelinesParameters( self.auto_pipeline_params, reason=f"To be able to recreate an original holdout split, you need to schedule a training job or " f"if you are using historical runs, just call historical_optimizer.get_data_connections()", ) # note: allow to read data at any time elif ( ( "csv_separator" not in self.auto_pipeline_params and "encoding" not in self.auto_pipeline_params ) or csv_separator != "," or encoding != "utf-8" ): self.auto_pipeline_params["csv_separator"] = csv_separator self.auto_pipeline_params["encoding"] = encoding # --- end note # note: excel_sheet in params only if it is not None (not specified): if excel_sheet: self.auto_pipeline_params["excel_sheet"] = excel_sheet # --- end note # note: set default quote character for flight (later applicable only for csv files stored in S3) self.auto_pipeline_params["quote_character"] = "double_quote" # --- end note data = DataFrame() headers = None if self._api_client is None: token = self._get_token_from_environment() if token is not None: headers = {"Authorization": f"Bearer {token}"} elif impersonate_header is not None: headers = self._api_client._get_headers() headers["impersonate"] = impersonate_header if self.type == DataConnectionTypes.S3: raise ConnectionError( f"S3 DataConnection is not supported! Please use data_asset_id instead." ) elif self.type == DataConnectionTypes.DS: if use_flight: from ibm_watsonx_ai.utils.utils import is_lib_installed is_lib_installed( lib_name="pyarrow", minimum_version="3.0.0", install=True ) from pyarrow.flight import FlightError _iam_id = None if headers and headers.get("impersonate"): _iam_id = headers.get("impersonate", {}).get("iam_id") self._api_client._iam_id = _iam_id try: if self._check_if_connection_asset_is_s3(): # note: update flight parameters only if `connection_properties` was not set earlier # (e.x. by wml/autoi) if not flight_parameters.get("connection_properties"): flight_parameters = ( self._update_flight_parameters_with_connection_details( flight_parameters ) ) data = self._download_data_from_flight_service( data_location=self, binary=binary, read_to_file=read_to_file, flight_parameters=flight_parameters, headers=headers, enable_sampling=enable_sampling, sampling_type=sampling_type, number_of_batch_rows=number_of_batch_rows, return_data_as_iterator=return_data_as_iterator, _return_subsampling_stats=_return_subsampling_stats, total_size_limit=total_size_limit, total_nrows_limit=total_nrows_limit, total_percentage_limit=total_percentage_limit, ) except ( ConnectionError, FlightError, ApiRequestFailure, ) as download_data_error: # note: try to download normal data asset either directly from cams or from mounted NFS # to keep backward compatibility if ( self._api_client and ( ( self._is_data_asset_normal() and self._is_size_acceptable() ) or self._is_data_asset_nfs() ) and ( "Found non-unique column index" not in str(download_data_error) ) ): import warnings warnings.warn(str(download_data_error), Warning) try: data = self._download_training_data_from_data_asset_storage( binary=binary, is_flight_fallback=True ) except: raise download_data_error else: raise download_data_error # backward compatibility else: try: with all_logging_disabled(): if self._check_if_connection_asset_is_s3(): cos_client = self._init_cos_client() data = self._download_data_from_cos( cos_client=cos_client, binary=binary ) else: data = self._download_training_data_from_data_asset_storage( binary=binary ) except NotImplementedError as e: raise e except FileNotFoundError as e: raise e except Exception as e: # do not try Flight if we are on the cloud if self._api_client is not None: if not self._api_client.ICP_PLATFORM_SPACES: raise e elif ( os.environ.get("USER_ACCESS_TOKEN") is None and os.environ.get("RUNTIME_ENV_ACCESS_TOKEN_FILE") is None ): raise CannotReadSavedRemoteDataBeforeFit() data = self._download_data_from_flight_service( data_location=self, binary=binary, read_to_file=read_to_file, flight_parameters=flight_parameters, headers=headers, enable_sampling=enable_sampling, sampling_type=sampling_type, number_of_batch_rows=number_of_batch_rows, return_data_as_iterator=return_data_as_iterator, _return_subsampling_stats=_return_subsampling_stats, total_size_limit=total_size_limit, total_nrows_limit=total_nrows_limit, total_percentage_limit=total_percentage_limit, ) elif self.type == DataConnectionTypes.FS: data = self._download_training_data_from_file_system(binary=binary) elif self.type == DataConnectionTypes.CA or self.type == DataConnectionTypes.CN: if ( getattr(self._api_client, "ICP_PLATFORM_SPACES", False) and self.type == DataConnectionTypes.CN ): raise ContainerTypeNotSupported() # block Container type on CPD if self.type == DataConnectionTypes.CA and self.location is None: raise ConnectionAssetNotSupported() if use_flight: # Workaround for container connection type, we need to fetch COS details from space/project if self.type == DataConnectionTypes.CN: # note: update flight parameters only if `connection_properties` was not set earlier # (e.x. by wml/autoi) if not flight_parameters.get("connection_properties"): flight_parameters = ( self._update_flight_parameters_with_connection_details( flight_parameters ) ) data = self._download_data_from_flight_service( data_location=self, binary=binary, read_to_file=read_to_file, flight_parameters=flight_parameters, headers=headers, enable_sampling=enable_sampling, sampling_type=sampling_type, number_of_batch_rows=number_of_batch_rows, return_data_as_iterator=return_data_as_iterator, _return_subsampling_stats=_return_subsampling_stats, total_size_limit=total_size_limit, total_nrows_limit=total_nrows_limit, total_percentage_limit=total_percentage_limit, ) else: # backward compatibility if isinstance(self.location, DatabaseLocation): raise ConnectionError( "Reading data from 'DatabaseLocation' is supported only with Flight Service. Please set `use_flight=True` parameter." ) try: with all_logging_disabled(): if self._check_if_connection_asset_is_s3(): cos_client = self._init_cos_client() try: data = self._download_data_from_cos( cos_client=cos_client, binary=binary ) except Exception as cos_access_exception: raise ConnectionError( f"Unable to access data object in cloud object storage with credentials supplied. " f"Error: {cos_access_exception}" ) else: data = self._download_data_from_nfs_connection( binary=binary ) except Exception as e: # do not try Flight is we are on the cloud if self._api_client is not None: if not self._api_client.ICP_PLATFORM_SPACES: raise e elif ( os.environ.get("USER_ACCESS_TOKEN") is None and os.environ.get("RUNTIME_ENV_ACCESS_TOKEN_FILE") is None ): raise CannotReadSavedRemoteDataBeforeFit() data = self._download_data_from_flight_service( data_location=self, binary=binary, read_to_file=read_to_file, flight_parameters=flight_parameters, headers=headers, enable_sampling=enable_sampling, sampling_type=sampling_type, number_of_batch_rows=number_of_batch_rows, _return_subsampling_stats=_return_subsampling_stats, total_size_limit=total_size_limit, total_nrows_limit=total_nrows_limit, total_percentage_limit=total_percentage_limit, ) if getattr(self._api_client, "_internal", False): pass # don't remove additional params if client is used internally else: # note: remove additional params and inline credentials added by _check_if_connection_asset_is_s3: [ delattr(self.connection, attr) for attr in [ "secret_access_key", "access_key_id", "endpoint_url", "is_s3", ] if hasattr(self.connection, attr) ] # end note # create data statistics if data were not downloaded with flight: if not isinstance(data, tuple) and _return_subsampling_stats: data = ( data, {"data_batch_size": sys.getsizeof(data), "data_batch_nrows": len(data)}, ) if binary: return data if raw or ( self.auto_pipeline_params.get("prediction_column") is None and self.auto_pipeline_params.get("prediction_columns") is None and self.auto_pipeline_params.get("feature_columns") is None ): return data else: if with_holdout_split: # when this connection is training one if return_data_as_iterator: raise WMLClientError( "The flags `return_data_as_iterator` and `with_holdout_split` cannot be set both in the same time." ) if _return_subsampling_stats: X_train, X_holdout, y_train, y_holdout = self._recreate_holdout( data=data[0] ) return X_train, X_holdout, y_train, y_holdout, data[1] else: X_train, X_holdout, y_train, y_holdout = self._recreate_holdout( data=data ) return X_train, X_holdout, y_train, y_holdout else: # when this data connection is a test / holdout one if return_data_as_iterator: return data if _return_subsampling_stats: if ( self.auto_pipeline_params.get("prediction_columns") or not self.auto_pipeline_params.get("prediction_column") or ( self.auto_pipeline_params.get("prediction_column") and self.auto_pipeline_params.get("prediction_column") not in data[0].columns ) ): # timeseries dataset does not have prediction columns. Whole data set is returned: test_X = data return test_X else: test_X, test_y = self._recreate_holdout( data=data[0], with_holdout_split=False ) test_X[ self.auto_pipeline_params.get( "prediction_column", "prediction_column" ) ] = test_y return test_X, data[1] else: # when this data connection is a test / holdout one and no subsampling stats are needed if ( self.auto_pipeline_params.get("prediction_columns") or not self.auto_pipeline_params.get("prediction_column") or ( self.auto_pipeline_params.get("prediction_column") and self.auto_pipeline_params.get("prediction_column") not in data.columns ) ): # timeseries dataset does not have prediction columns. Whole data set is returned: test_X = data else: test_X, test_y = self._recreate_holdout( data=data, with_holdout_split=False ) test_X[ self.auto_pipeline_params.get( "prediction_column", "prediction_column" ) ] = test_y return test_X # return one dataframe
[docs] def write( self, data: Union[str, "DataFrame"], remote_name: str = None, **kwargs ) -> None: """Upload a file to a remote data storage. :param data: local path to the dataset or pandas.DataFrame with data :type data: str :param remote_name: name of dataset to be stored in the remote data storage :type remote_name: str """ # enables flight automatically for CP4D 4.0.x use_flight = kwargs.get( "use_flight", bool( self._api_client is not None or "USER_ACCESS_TOKEN" in os.environ or "RUNTIME_ENV_ACCESS_TOKEN_FILE" in os.environ ), ) flight_parameters = kwargs.get("flight_parameters", {}) impersonate_header = kwargs.get("impersonate_header", None) headers = None if self._api_client is None: token = self._get_token_from_environment() if token is None: raise ConnectionError( "API client missing. Please initialize API client and pass it to " "DataConnection._api_client property to be able to use this functionality." ) else: headers = {"Authorization": f"Bearer {token}"} elif impersonate_header is not None: headers = self._api_client._get_headers() headers["impersonate"] = impersonate_header # TODO: Remove S3 implementation if self.type == DataConnectionTypes.S3: raise ConnectionError( "S3 DataConnection is not supported. Please use data_asset_id instead." ) elif self.type == DataConnectionTypes.CA or self.type == DataConnectionTypes.CN: if ( getattr(self._api_client, "ICP_PLATFORM_SPACES", False) and self.type == DataConnectionTypes.CN ): raise ContainerTypeNotSupported() # block Container type on CPD if self.type == DataConnectionTypes.CA and self.location is None: raise ConnectionAssetNotSupported() if self._check_if_connection_asset_is_s3(): # do not try Flight if we are on the cloud if ( self._api_client is not None and not self._api_client.ICP_PLATFORM_SPACES and not use_flight ): # CLOUD if remote_name is None and ( self._to_dict().get("location", {}).get("path") or self._to_dict().get("location", {}).get("file_name") ): updated_remote_name = data.split("/")[-1] else: updated_remote_name = self._get_path_with_remote_name( self._to_dict(), remote_name ) cos_resource_client = self._init_cos_client() if isinstance(data, str): with open(data, "rb") as file_data: cos_resource_client.Object( self.location.bucket, updated_remote_name ).upload_fileobj(Fileobj=file_data) elif isinstance(data, DataFrame): # note: we are saving csv in memory as a file and stream it to the COS buffer = io.StringIO() data.to_csv(buffer, index=False) buffer.seek(0) with buffer as f: cos_resource_client.Object( self.location.bucket, updated_remote_name ).upload_fileobj( Fileobj=io.BytesIO(bytes(f.read().encode())) ) else: raise TypeError( 'data should be either of type "str" or "pandas.DataFrame"' ) # CP4D else: # Workaround for container connection type, we need to fetch COS details from space/project if self.type == DataConnectionTypes.CN: # note: update flight parameters only if `connection_properties` was not set earlier # (e.x. by wml/autoi) if not flight_parameters.get("connection_properties"): flight_parameters = ( self._update_flight_parameters_with_connection_details( flight_parameters ) ) if isinstance(data, str): self._upload_data_via_flight_service( file_path=data, data_location=self, remote_name=remote_name, flight_parameters=flight_parameters, headers=headers, ) elif isinstance(data, DataFrame): # note: we are saving csv in memory as a file and stream it to the COS self._upload_data_via_flight_service( data=data, data_location=self, remote_name=remote_name, flight_parameters=flight_parameters, headers=headers, ) else: raise TypeError( 'data should be either of type "str" or "pandas.DataFrame"' ) else: if ( self._api_client is not None and not self._api_client.ICP_PLATFORM_SPACES and not use_flight ): # CLOUD raise ConnectionError( "Connections other than COS are not supported on a cloud yet." ) # CP4D else: if isinstance(data, str): self._upload_data_via_flight_service( file_path=data, data_location=self, remote_name=remote_name, flight_parameters=flight_parameters, headers=headers, binary=kwargs.get("binary", False), ) elif isinstance(data, DataFrame): # note: we are saving csv in memory as a file and stream it to the COS self._upload_data_via_flight_service( data=data, data_location=self, remote_name=remote_name, flight_parameters=flight_parameters, headers=headers, ) else: raise TypeError( 'data should be either of type "str" or "pandas.DataFrame"' ) if getattr(self._api_client, "_internal", False): pass # don't remove additional params if client is used internally else: # note: remove additional params and inline credentials added by _check_if_connection_asset_is_s3: [ delattr(self.connection, attr) for attr in [ "secret_access_key", "access_key_id", "endpoint_url", "is_s3", ] if hasattr(self.connection, attr) ] # end note elif self.type == DataConnectionTypes.DS: if ( self._api_client is not None and not self._api_client.ICP_PLATFORM_SPACES and not use_flight ): # CLOUD raise ConnectionError( "Write of data for Data Asset is not supported on Cloud." ) elif self._api_client is not None: if isinstance(data, str): self._upload_data_via_flight_service( file_path=data, data_location=self, remote_name=remote_name, flight_parameters=flight_parameters, headers=headers, ) elif isinstance(data, DataFrame): # note: we are saving csv in memory as a file and stream it to the COS self._upload_data_via_flight_service( data=data, data_location=self, remote_name=remote_name, flight_parameters=flight_parameters, headers=headers, ) else: raise TypeError( 'data should be either of type "str" or "pandas.DataFrame"' ) else: self._upload_data_via_flight_service( data=data, data_location=self, remote_name=remote_name, flight_parameters=flight_parameters, headers=headers, ) elif self.type == DataConnectionTypes.FS: if isinstance(data, str): with open(data, "rb") as file_data: self._upload_data_to_file_system( location=self.location.path, data=file_data, remote_name=remote_name, ) elif isinstance(data, DataFrame): buffer = io.BytesIO() data.to_csv(buffer, index=False) buffer.seek(0) self._upload_data_to_file_system( location=self.location.path, data=io.BufferedReader(buffer), remote_name=remote_name, ) else: raise TypeError( 'data should be either of type "str" or "pandas.DataFrame"' )
def _init_cos_client(self) -> "resource": """Initiate COS client for further usage.""" from ibm_botocore.client import Config # Make sure endpoint_url startswith 'https://' prefix if not self.connection.endpoint_url.startswith("https://"): self.connection.endpoint_url = "https://" + self.connection.endpoint_url try: if hasattr(self.connection, "auth_endpoint") and hasattr( self.connection, "api_key" ): cos_client = resource( service_name="s3", ibm_api_key_id=self.connection.api_key, ibm_auth_endpoint=self.connection.auth_endpoint, config=Config(signature_version="oauth"), endpoint_url=self.connection.endpoint_url, ) else: cos_client = resource( service_name="s3", endpoint_url=self.connection.endpoint_url, aws_access_key_id=self.connection.access_key_id, aws_secret_access_key=self.connection.secret_access_key, ) except ValueError as e: raise WMLClientError( "Error occurred during COS client initialisation %s".format(e) ) return cos_client def _validate_cos_resource(self): cos_client = self._init_cos_client() try: files = cos_client.Bucket(self.location.bucket).objects.all() next(x for x in files if x.key == self.location.path) except Exception as e: raise NotExistingCOSResource(self.location.bucket, self.location.path) def _update_flight_parameters_with_connection_details(self, flight_parameters): with all_logging_disabled(): self._check_if_connection_asset_is_s3() connection_properties = { "bucket": self.location.bucket, "url": self.connection.endpoint_url, } if hasattr(self.connection, "auth_endpoint") and hasattr( self.connection, "api_key" ): connection_properties["iam_url"] = self.connection.auth_endpoint connection_properties["api_key"] = self.connection.api_key connection_properties["resource_instance_id"] = ( self.connection.resource_instance_id ) else: connection_properties["secret_key"] = self.connection.secret_access_key connection_properties["access_key"] = self.connection.access_key_id flight_parameters.update({"connection_properties": connection_properties}) flight_parameters.update( {"datasource_type": {"entity": {"name": self._datasource_type}}} ) return flight_parameters
[docs] def download(self, filename: str) -> None: """Download a dataset stored in a remote data storage and save to a file. :param filename: path to the file where data will be downloaded :type filename: str **Examples** .. code-block:: python document_reference = DataConnection( connection_asset_id="<connection_id>", location=S3Location(bucket="<bucket_name>", path="path/to/file"), ) document_reference.download(filename='results.json') """ with open(filename, "wb") as file: file.write(self.read(binary=True))
def _get_filename(self): """Get file name of the file in data connection, if applicable. :return: file name :rtype: str **Examples** .. code-block:: python document_reference = DataConnection( connection_asset_id="<connection_id>", location=S3Location(bucket="<bucket_name>", path="path/to/file"), ) filename = document_reference._get_filename() """ if isinstance(self.location, AssetLocation): if self._api_client is None: raise ConnectionError( "API client missing. Please initialize API client and pass it to " "DataConnection._api_client property to be able to use this functionality." ) asset_details = self._api_client.data_assets.get_details(self.location.id) return asset_details["metadata"]["resource_key"].split("/")[-1] elif hasattr(self.location, "file_name"): filename = self.location.file_name.split("/")[-1] if "." not in filename or filename == ".": raise DirectoryHasNoFilename() return filename elif hasattr(self.location, "path"): filename = self.location.path.split("/")[-1] if "." not in filename or filename == ".": raise DirectoryHasNoFilename() return filename else: raise CannotGetFilename()
# TODO: Remove S3 Implementation for connection class S3Connection(BaseConnection): """Connection class to a COS data storage in S3 format. :param endpoint_url: URL of the S3 data storage (COS) :type endpoint_url: str :param access_key_id: access key ID of the S3 connection (COS) :type access_key_id: str, optional :param secret_access_key: secret access key of the S3 connection (COS) :type secret_access_key: str, optional :param api_key: API key of the S3 connection (COS) :type api_key: str, optional :param service_name: service name of the S3 connection (COS) :type service_name: str, optional :param auth_endpoint: authentication endpoint URL of the S3 connection (COS) :type auth_endpoint: str, optional """ def __init__( self, endpoint_url: str, access_key_id: str = None, secret_access_key: str = None, api_key: str = None, service_name: str = None, auth_endpoint: str = None, resource_instance_id: str = None, _internal_use=False, ) -> None: if not _internal_use: warn( message="S3 DataConnection is not supported. Please use data_asset_id instead." ) if (access_key_id is None or secret_access_key is None) and ( api_key is None or auth_endpoint is None ): raise InvalidCOSCredentials( reason="You need to specify (access_key_id and secret_access_key) or" "(api_key and auth_endpoint)" ) if secret_access_key is not None: self.secret_access_key = secret_access_key if api_key is not None: self.api_key = api_key if service_name is not None: self.service_name = service_name if auth_endpoint is not None: self.auth_endpoint = auth_endpoint if access_key_id is not None: self.access_key_id = access_key_id if endpoint_url is not None: self.endpoint_url = endpoint_url if resource_instance_id is not None: self.resource_instance_id = resource_instance_id
[docs] class S3Location(BaseLocation): """Connection class to a COS data storage in S3 format. :param bucket: COS bucket name :type bucket: str :param path: COS data path in the bucket :type path: str :param excel_sheet: name of the excel sheet, if the chosen dataset uses an excel file for Batched Deployment scoring :type excel_sheet: str, optional :param model_location: path to the pipeline model in the COS :type model_location: str, optional :param training_status: path to the training status JSON in the COS :type training_status: str, optional """ def __init__(self, bucket: str, path: str, **kwargs) -> None: self.bucket = bucket self.path = path if kwargs.get("model_location") is not None: self._model_location = kwargs["model_location"] if kwargs.get("training_status") is not None: self._training_status = kwargs["training_status"] if kwargs.get("excel_sheet") is not None: self.sheet_name = kwargs["excel_sheet"] self.file_format = "xls" def _get_file_size(self, cos_resource_client: "resource") -> int: try: size = cos_resource_client.Object( self.bucket, getattr(self, "path", getattr(self, "file_name")) ).content_length except ClientError: size = 0 return size
[docs] def get_location(self) -> str: if hasattr(self, "file_name"): return self.file_name else: return self.path
def _get_file_extension(self) -> str: """ Returns the file extension of the file located at the specified location. If no file extension is specified in self.path / self.file_name then empty string "" is returned. """ return os.path.splitext(self.get_location())[-1]
class ContainerLocation(BaseLocation): """Connection class to default COS in user Project/Space.""" def __init__(self, path: Optional[str] = None, **kwargs) -> None: if path is None: self.path = "default_autoai_out" else: self.path = path self.bucket = None if kwargs.get("model_location") is not None: self._model_location = kwargs["model_location"] if kwargs.get("training_status") is not None: self._training_status = kwargs["training_status"] def to_dict(self) -> dict: _dict = super().to_dict() if "bucket" in _dict and _dict["bucket"] is None: del _dict["bucket"] return _dict @classmethod def _set_path(cls, path: str) -> "ContainerLocation": location = cls() location.path = path return location def _get_file_size(self): pass def _get_file_extension(self) -> str: """ Returns the file extension of the file located at the specified location. If no file extension is specified in self.path then empty string "" is returned. """ return os.path.splitext(self.path)[-1] def get_location(self) -> str: if hasattr(self, "file_name"): return self.file_name else: return self.path class FSLocation(BaseLocation): """Connection class to File Storage in CP4D.""" def __init__(self, path: Optional[str] = None) -> None: if path is None: self.path = ( "/{option}/{id}" + f"/assets/auto_ml/auto_ml.{uuid.uuid4()}/wml_data" ) else: self.path = path @classmethod def _set_path(cls, path: str) -> "FSLocation": location = cls() location.path = path return location def _save_file_as_data_asset(self, workspace: "WorkSpace") -> "str": asset_name = self.path.split("/")[-1] if self.path: data_asset_details = workspace.api_client.data_assets.create( asset_name, self.path ) return workspace.api_client.data_assets.get_id(data_asset_details) else: raise MissingValue( "path", reason="Incorrect initialization of class FSLocation" ) def _get_file_size(self, workspace: "WorkSpace") -> "int": # note if path is not file then returned size is 0 try: # note: try to get file size from remote server url = ( workspace.api_client.service_instance._href_definitions.get_wsd_model_attachment_href() + f"/{self.path.split('/assets/')[-1]}" ) path_info_response = requests.head( url, headers=workspace.api_client._get_headers(), params=workspace.api_client._params(), ) if path_info_response.status_code != 200: raise ApiRequestFailure( "Failure during getting path details", path_info_response ) path_info = path_info_response.headers if ( "X-Asset-Files-Type" in path_info and path_info["X-Asset-Files-Type"] == "file" ): size = path_info["X-Asset-Files-Size"] else: size = 0 # -- end note except (ApiRequestFailure, AttributeError): # note try get size of file from local fs size = ( os.stat(path=self.path).st_size if os.path.isfile(path=self.path) else 0 ) # -- end note return size def _get_file_extension(self) -> str: """ Returns the file extension of the file located at the specified location. If no file extension is specified in self.path then empty string "" is returned. """ return os.path.splitext(self.path)[-1] class AssetLocation(BaseLocation): def __init__(self, asset_id: str) -> None: self.href = None self._initial_asset_id = asset_id self.__api_client = None self.id = asset_id def _get_bucket(self, client) -> str: """Try to get bucket from data asset.""" connection_id = self._get_connection_id(client) conn_details = client.connections.get_details(connection_id) bucket = conn_details.get("entity", {}).get("properties", {}).get("bucket") if bucket is None: asset_details = client.data_assets.get_details(self.id) connection_path = ( asset_details["entity"].get("folder_asset", {}).get("connection_path") ) if connection_path is None: attachment_content = self._get_attachment_details(client) connection_path = attachment_content.get("connection_path") bucket = connection_path.split("/")[1] return bucket def _get_attachment_details(self, client) -> dict: if self.id is None and self.href: items = self.href.split("/") self.id = items[-1].split("?")[0] asset_details = client.data_assets.get_details(self.id) if "attachment_id" in asset_details.get("metadata"): attachment_id = asset_details["metadata"]["attachment_id"] else: attachment_id = asset_details["attachments"][0]["id"] attachment_url = client.service_instance._href_definitions.get_data_asset_href( self.id ) attachment_url = f"{attachment_url}/attachments/{attachment_id}" if client.ICP_PLATFORM_SPACES: attachment = requests.get( attachment_url, headers=client._get_headers(), params=client._params() ) else: attachment = requests.get( attachment_url, headers=client._get_headers(), params=client._params() ) if attachment.status_code != 200: raise ApiRequestFailure( "Failure during getting attachment details", attachment ) return attachment.json() def _get_connection_id(self, client) -> str: attachment_content = self._get_attachment_details(client) return attachment_content.get("connection_id") @classmethod def _set_path(cls, href: str) -> "AssetLocation": items = href.split("/") _id = items[-1].split("?")[0] location = cls(_id) location.href = href return location def _get_file_size(self, workspace: "WorkSpace", *args) -> "int": asset_info_response = requests.get( workspace.api_client.service_instance._href_definitions.get_data_asset_href( self.id ), params=workspace.api_client._params(), headers=workspace.api_client._get_headers(), ) if asset_info_response.status_code != 200: raise ApiRequestFailure( "Failure during getting asset details", asset_info_response ) return asset_info_response.json()["metadata"].get("size") def to_dict(self) -> dict: """Return a json dictionary representing this model.""" _dict = vars(self).copy() if _dict.get("id", False) is None and _dict.get("href"): items = self.href.split("/") _dict["id"] = items[-1].split("?")[0] del _dict[f"_{self.__class__.__name__}__api_client"] del _dict["_initial_asset_id"] return _dict @property def wml_client(self): # note: backward compatibility warn( ( "`wml_client` is deprecated and will be removed in future. " "Instead, please use `api_client`." ), category=DeprecationWarning, ) # --- end note return self.__api_client @wml_client.setter def wml_client(self, var): # note: backward compatibility warn( ( "`wml_client` is deprecated and will be removed in future. " "Instead, please use `api_client`." ), category=DeprecationWarning, ) # --- end note self.api_client = var @property def api_client(self): return self.__api_client @api_client.setter def api_client(self, var): self.__api_client = var if self.__api_client: self.href = self.__api_client.service_instance._href_definitions.get_base_asset_href( self._initial_asset_id ) else: self.href = f"/v2/assets/{self._initial_asset_id}" if self.__api_client: if self.__api_client.default_space_id: self.href = f"{self.href}?space_id={self.__api_client.default_space_id}" else: self.href = ( f"{self.href}?project_id={self.__api_client.default_project_id}" ) def _get_file_extension(self) -> str: """ Returns the file extension of the file located at the specified location. """ if self.api_client: attachment_details = self._get_attachment_details(self.api_client) return os.path.splitext(attachment_details.get("name", ""))[-1] else: raise NotImplementedError class ConnectionAssetLocation(BaseLocation): """Connection class to a COS data storage. :param bucket: COS bucket name :type bucket: str :param file_name: COS data path in the bucket :type file_name: str :param model_location: path to the pipeline model in the COS :type model_location: str, optional :param training_status: path to the training status JSON in COS :type training_status: str, optional """ def __init__(self, bucket: str, file_name: str, **kwargs) -> None: self.bucket = bucket self.file_name = file_name self.path = file_name if kwargs.get("model_location") is not None: self._model_location = kwargs["model_location"] if kwargs.get("training_status") is not None: self._training_status = kwargs["training_status"] def _get_file_size(self, cos_resource_client: "resource") -> "int": try: size = cos_resource_client.Object(self.bucket, self.path).content_length except ClientError: size = 0 return size def to_dict(self) -> dict: """Return a json dictionary representing this model.""" return vars(self) def _get_file_extension(self) -> str: """ Returns the file extension of the file located at the specified location. """ return os.path.splitext(self.file_name)[-1] class GithubLocation(BaseLocation): """Connection class to a Github. :param secret_manager_url: url of Secrets Manager service where the Github PAT and url are stored. :type secret_manager_url: str :param secret_id: ID of the secret with Github PAT and url in the Secrets Manager :type secret_id: str :param path: path within github repo to the file :type path: str """ def __init__(self, secret_manager_url: str, secret_id: str, path: str) -> None: self.secret_manager_url = secret_manager_url self.secret_id = secret_id self.path = path def to_dict(self) -> dict: """Return a json dictionary representing this model.""" return vars(self) class ConnectionAsset(BaseConnection): """Connection class for a Connection Asset. :param connection_id: ID of the connection asset :type connection_id: str """ def __init__(self, connection_id: str): self.id = connection_id class NFSConnection(BaseConnection): """Connection class to file storage in Cloud Pak for Data of NFS format. :param asset_id: asset ID of the Cloud Pak for Data project :type asset_id: str """ def __init__(self, asset_id: str): self.asset_id = asset_id self.id = asset_id class NFSLocation(BaseLocation): """Location class to file storage in Cloud Pak for Data of NFS format. :param path: data path to the Cloud Pak for Data project :type path: str """ def __init__(self, path: str): self.path = path self.id = None self.file_name = None def _get_file_size(self, workspace: "Workspace", *args) -> "int": params = workspace.api_client._params().copy() params["path"] = self.path params["detail"] = "true" href = ( workspace.api_client.connections._href_definitions.get_connection_by_id_href( self.id ) + "/assets" ) asset_info_response = requests.get( href, params=params, headers=workspace.api_client._get_headers(None) ) if asset_info_response.status_code != 200: raise Exception( "Failure during getting asset details", asset_info_response.json() ) return asset_info_response.json()["details"]["file_size"] def get_location(self) -> str: if hasattr(self, "file_name"): return self.file_name else: return self.path def _get_file_extension(self) -> str: """ Returns the file extension of the file located at the specified location. """ return os.path.splitext(self.get_location())[-1]
[docs] class CloudAssetLocation(AssetLocation): """Connection class to data assets as input data references to a batch deployment job on Cloud. :param asset_id: asset ID of the file loaded on space on Cloud :type asset_id: str """ def __init__(self, asset_id: str) -> None: super().__init__(asset_id) self.href = self.href warning_msg = ( "Depreciation Warning: Class CloudAssetLocation is no longer supported and will be removed." "Use AssetLocation instead." ) print(warning_msg) def _get_file_size(self, workspace: "WorkSpace", *args) -> "int": return super()._get_file_size(workspace)
[docs] class DeploymentOutputAssetLocation(BaseLocation): """Connection class to data assets where output of batch deployment will be stored. :param name: name of CSV file to be saved as a data asset :type name: str :param description: description of the data asset :type description: str, optional """ def __init__(self, name: str, description: str = "") -> None: self.name = name self.description = description def _get_file_extension(self) -> str: """ Returns the file extension of the file located at the specified location. """ return os.path.splitext(self.name)[-1]
class DatabaseLocation(BaseLocation): """Location class to Database. :param schema_name: name of database schema :type schema_name: str :param table_name: name of database table :type table_name: str catalog_name: name of database catalog, required only for Presto data source :type catalog_name: str, optional """ def __init__( self, schema_name: str, table_name: str, catalog_name: str = None, **kwargs ) -> None: self.schema_name = schema_name self.table_name = table_name self.catalog_name = catalog_name def _get_file_size(self) -> None: raise NotImplementedError() def to_dict(self) -> dict: """Get a json dictionary representing DatabaseLocation.""" return {key: value for key, value in vars(self).items() if value}