__all__ = [
"DataConnection",
"S3Connection",
"ConnectionAsset",
"S3Location",
"FSLocation",
"AssetLocation",
"CP4DAssetLocation",
"WMLSAssetLocation",
"WSDAssetLocation",
"CloudAssetLocation",
"DeploymentOutputAssetLocation",
"NFSConnection",
"NFSLocation",
"ConnectionAssetLocation",
"DatabaseLocation",
"ContainerLocation"
]
# -----------------------------------------------------------------------------------------
# (C) Copyright IBM Corp. 2020-2024.
# https://opensource.org/licenses/BSD-3-Clause
# -----------------------------------------------------------------------------------------
import io
import os
import uuid
import copy
import sys
from copy import deepcopy
from typing import Union, Tuple, List, TYPE_CHECKING, Optional
from warnings import warn
from ibm_boto3 import resource
from ibm_botocore.client import ClientError
from pandas import DataFrame
import pandas as pd
import ibm_watson_machine_learning._wrappers.requests as requests
from ibm_watson_machine_learning.utils.autoai.enums import PredictionType, DataConnectionTypes
from ibm_watson_machine_learning.utils.autoai.errors import (
MissingAutoPipelinesParameters, UseWMLClient, MissingCOSStudioConnection, MissingProjectLib,
HoldoutSplitNotSupported, InvalidCOSCredentials, MissingLocalAsset, InvalidIdType, NotWSDEnvironment,
NotExistingCOSResource, InvalidDataAsset, CannotReadSavedRemoteDataBeforeFit, NoAutomatedHoldoutSplit
)
import numpy as np
from ibm_watson_machine_learning.utils.autoai.utils import all_logging_disabled, try_import_autoai_libs, \
try_import_autoai_ts_libs
from ibm_watson_machine_learning.utils.autoai.watson_studio import get_project
from ibm_watson_machine_learning.data_loaders.datasets.experiment import DEFAULT_SAMPLING_TYPE, DEFAULT_SAMPLE_SIZE_LIMIT
from ibm_watson_machine_learning.wml_client_error import MissingValue, ApiRequestFailure, WMLClientError
from ibm_watson_machine_learning.utils.autoai.errors import ContainerTypeNotSupported
from ibm_watson_machine_learning.messages.messages import Messages
from .base_connection import BaseConnection
from .base_data_connection import BaseDataConnection
from .base_location import BaseLocation
if TYPE_CHECKING:
from ibm_watson_machine_learning.workspace import WorkSpace
class DataConnection(BaseDataConnection):
"""Data Storage Connection class needed for WML training metadata (input data).
:param connection: connection parameters of specific type
:type connection: NFSConnection or ConnectionAsset, optional
:param location: required location parameters of specific type
:type location: Union[S3Location, FSLocation, AssetLocation]
:param data_join_node_name: name(s) for node(s):
- `None` - data file name will be used as node name
- str - it will become the node name
- list[str] - multiple names passed, several nodes will have the same data connection
(used for excel files with multiple sheets)
:type data_join_node_name: None or str or list[str], optional
:param data_asset_id: data asset ID if the DataConnection should point to a data asset
:type data_asset_id: str, optional
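**Example**
Minimal usage sketches; the asset ID, connection asset ID, bucket name, and path below are placeholders:
.. code-block:: python
conn = DataConnection(data_asset_id='<data_asset_id>')
# or point to a file in COS through a connection asset
conn = DataConnection(connection_asset_id='<connection_asset_id>', location=S3Location(bucket='<bucket_name>', path='data/train.csv'))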
"""
def __init__(self,
location: Union['S3Location',
'FSLocation',
'AssetLocation',
'CP4DAssetLocation',
'WMLSAssetLocation',
'WSDAssetLocation',
'CloudAssetLocation',
'NFSLocation',
'DeploymentOutputAssetLocation',
'ConnectionAssetLocation',
'DatabaseLocation',
'ContainerLocation'] = None,
connection: Optional[Union['S3Connection', 'NFSConnection', 'ConnectionAsset']] = None,
data_join_node_name: Union[str, List[str]] = None,
data_asset_id: str = None,
connection_asset_id: str = None,
**kwargs):
if data_asset_id is None and location is None:
raise MissingValue('location or data_asset_id', reason="Provide 'location' or 'data_asset_id'.")
elif data_asset_id is not None and location is not None:
raise ValueError("'data_asset_id' and 'location' cannot be specified together.")
elif data_asset_id is not None:
location = AssetLocation(asset_id=data_asset_id)
if kwargs.get('model_location') is not None:
location._model_location = kwargs['model_location']
if kwargs.get('training_status') is not None:
location._training_status = kwargs['training_status']
elif connection_asset_id is not None and isinstance(location, (S3Location, DatabaseLocation, NFSLocation)):
connection = ConnectionAsset(connection_id=connection_asset_id)
elif connection_asset_id is None and connection is None and isinstance(location, (S3Location,
DatabaseLocation,
NFSLocation)):
raise ValueError("'connection_asset_id' and 'connection' cannot be empty together when 'location' is "
"[S3Location, DatabaseLocation, NFSLocation].")
super().__init__()
self.connection = connection
self.location = location
# TODO: remove S3 implementation
if isinstance(connection, S3Connection):
self.type = DataConnectionTypes.S3
elif isinstance(connection, ConnectionAsset):
self.type = DataConnectionTypes.CA
# note: We expect a `file_name` keyword for CA pointing to COS or NFS.
if isinstance(self.location, (S3Location, NFSLocation)):
self.location.file_name = self.location.path
del self.location.path
if isinstance(self.location, NFSLocation):
del self.location.id
# --- end note
elif isinstance(location, FSLocation):
self.type = DataConnectionTypes.FS
elif isinstance(location, ContainerLocation):
self.type = DataConnectionTypes.CN
elif isinstance(location, (AssetLocation, CP4DAssetLocation, WMLSAssetLocation, CloudAssetLocation,
WSDAssetLocation, DeploymentOutputAssetLocation)):
self.type = DataConnectionTypes.DS
self.auto_pipeline_params = {} # note: needed parameters for recreation of autoai holdout split
self._wml_client = None
self.__wml_client = None # only for getter/setter for AssetLocation href
self._run_id = None
self._obm = False
self._obm_cos_path = None
self._test_data = False
self._user_holdout_exists = False
# note: set the data connection id to the location path for OBM + KB
if data_join_node_name is None:
# TODO: remove S3 implementation
if self.type == DataConnectionTypes.S3 or (
self.type == DataConnectionTypes.CA and hasattr(location, 'file_name')):
self.id = location.get_location()
else:
self.id = None
else:
self.id = data_join_node_name
# --- end note
# note: client as property and setter for dynamic href creation for AssetLocation
@property
def _wml_client(self):
return self.__wml_client
@_wml_client.setter
def _wml_client(self, var):
self.__wml_client = var
if isinstance(self.location, (AssetLocation, WSDAssetLocation)):
self.location.wml_client = self.__wml_client
if getattr(var, 'project_type', None) == 'local_git_storage':
self.location.userfs = True
def set_client(self, wml_client):
"""Set initialized wml client in connection to enable write/read operations with connection to service.
:param wml_client: WML client to connect to service
:type wml_client: APIClient
**Example**
.. code-block:: python
DataConnection.set_client(wml_client)
"""
self._wml_client = wml_client
# --- end note
@classmethod
def from_studio(cls, path: str) -> List['DataConnection']:
"""Create DataConnections from the credentials stored (connected) in Watson Studio. Only for COS.
:param path: path in COS bucket to the training dataset
:type path: str
:return: list with DataConnection objects
:rtype: list[DataConnection]
**Example**
.. code-block:: python
data_connections = DataConnection.from_studio(path='iris_dataset.csv')
"""
try:
from project_lib import Project
except ModuleNotFoundError:
raise MissingProjectLib("Missing project_lib package.")
else:
data_connections = []
for name, value in globals().items():
if isinstance(value, Project):
connections = value.get_connections()
if connections:
for connection in connections:
asset_id = connection['asset_id']
connection_details = value.get_connection(asset_id)
if ('url' in connection_details and 'access_key' in connection_details and
'secret_key' in connection_details and 'bucket' in connection_details):
data_connections.append(
cls(connection=ConnectionAsset(connection_id=connection_details['id']),
location=ConnectionAssetLocation(bucket=connection_details['bucket'],
file_name=path))
)
if data_connections:
return data_connections
else:
raise MissingCOSStudioConnection(
"There is no COS Studio connection. "
"Please create a COS connection from the UI and insert "
"the cell with the project API connection (Insert project token).")
def _subdivide_connection(self):
if type(self.id) is str or not self.id:
return [self]
else:
def cpy(new_id):
child = copy.copy(self)
child.id = new_id
return child
return [cpy(id) for id in self.id]
def _to_dict(self) -> dict:
"""Convert DataConnection object to dictionary representation.
:return: DataConnection dictionary representation
:rtype: dict
"""
if self.id and type(self.id) is list:
raise InvalidIdType(list)
_dict = {"type": self.type}
# note: for OBM (the id of the DataConnection is an OBM node name)
if self.id is not None:
_dict['id'] = self.id
# --- end note
if self.connection is not None:
_dict['connection'] = deepcopy(self.connection.to_dict())
else:
_dict['connection'] = {}
try:
_dict['location'] = deepcopy(self.location.to_dict())
except AttributeError:
_dict['location'] = {}
# note: convert userfs to string - training service requires it as string
if hasattr(self.location, 'userfs'):
_dict['location']['userfs'] = str(getattr(self.location, 'userfs', False)).lower()
# end note
return _dict
def __repr__(self):
return str(self._to_dict())
def __str__(self):
return str(self._to_dict())
@classmethod
def _from_dict(cls, _dict: dict) -> 'DataConnection':
"""Create a DataConnection object from dictionary.
:param _dict: a dictionary data structure with information about data connection reference
:type _dict: dict
:return: DataConnection object
:rtype: DataConnection
"""
# TODO: remove S3 implementation
if _dict['type'] == DataConnectionTypes.S3:
warn(message="S3 DataConnection is deprecated! Please use data_asset_id instead.")
data_connection: 'DataConnection' = cls(
connection=S3Connection(
access_key_id=_dict['connection']['access_key_id'],
secret_access_key=_dict['connection']['secret_access_key'],
endpoint_url=_dict['connection']['endpoint_url']
),
location=S3Location(
bucket=_dict['location']['bucket'],
path=_dict['location']['path']
)
)
elif _dict['type'] == DataConnectionTypes.FS:
data_connection: 'DataConnection' = cls(
location=FSLocation._set_path(path=_dict['location']['path'])
)
elif _dict['type'] == DataConnectionTypes.CA:
if _dict['location'].get('file_name') is not None and _dict['location'].get('bucket'):
data_connection: 'DataConnection' = cls(
connection_asset_id=_dict['connection']['id'],
location=S3Location(
bucket=_dict['location']['bucket'],
path=_dict['location']['file_name']
)
)
elif _dict['location'].get('path') is not None and _dict['location'].get('bucket'):
data_connection: 'DataConnection' = cls(
connection_asset_id=_dict['connection']['id'],
location=S3Location(
bucket=_dict['location']['bucket'],
path=_dict['location']['path']
)
)
elif _dict['location'].get('schema_name') and _dict['location'].get('table_name'):
data_connection: 'DataConnection' = cls(
connection_asset_id=_dict['connection']['id'],
location=DatabaseLocation(schema_name=_dict['location']['schema_name'],
table_name=_dict['location']['table_name'],
catalog_name=_dict['location'].get('catalog_name')
)
)
else:
if 'asset_id' in _dict['connection']:
data_connection: 'DataConnection' = cls(
connection=NFSConnection(asset_id=_dict['connection']['asset_id']),
location=NFSLocation(path=_dict['location']['path'])
)
else:
if _dict['location'].get('file_name') is not None:
data_connection: 'DataConnection' = cls(
connection_asset_id=_dict['connection']['id'],
location=NFSLocation(path=_dict['location']['file_name'])
)
else:
data_connection: 'DataConnection' = cls(
connection_asset_id=_dict['connection']['id'],
location=NFSLocation(path=_dict['location']['path'])
)
elif _dict['type'] == DataConnectionTypes.CN:
data_connection: 'DataConnection' = cls(
location=ContainerLocation(path=_dict['location']['path'])
)
else:
data_connection: 'DataConnection' = cls(
location=AssetLocation._set_path(href=_dict['location']['href'])
)
if _dict.get('id'):
data_connection.id = _dict['id']
if _dict['location'].get('userfs'):
if str(_dict['location'].get('userfs', 'false')).lower() in ['true', '1']:
data_connection.location.userfs = True
else:
data_connection.location.userfs = False
return data_connection
def _recreate_holdout(
self,
data: 'DataFrame',
with_holdout_split: bool = True
) -> Union[Tuple['DataFrame', 'DataFrame'], Tuple['DataFrame', 'DataFrame', 'DataFrame', 'DataFrame']]:
"""This method tries to recreate holdout data."""
if self.auto_pipeline_params.get('prediction_columns') is not None:
# timeseries
try_import_autoai_ts_libs()
from autoai_ts_libs.utils.holdout_utils import make_holdout_split
# Note: when the lookback window is auto-detected, we need to get the detected value from the training details
if self.auto_pipeline_params.get('lookback_window') == -1 or self.auto_pipeline_params.get('lookback_window') is None:
ts_metrics = self._wml_client.training.get_details(self.auto_pipeline_params.get('run_id'), _internal=True)['entity']['status']['metrics']
final_ts_state_name = "after_final_pipelines_generation"
for metric in ts_metrics:
if metric['context']['intermediate_model']['process'] == final_ts_state_name:
self.auto_pipeline_params['lookback_window'] = metric['context']['timeseries']['lookback_window']
break
# Note: imputation is not supported
X_train, X_holdout, y_train, y_holdout, _, _, _, _ = make_holdout_split(
dataset=data,
target_columns=self.auto_pipeline_params.get('prediction_columns'),
learning_type="forecasting",
test_size=self.auto_pipeline_params.get('holdout_size'),
lookback_window=self.auto_pipeline_params.get('lookback_window'),
feature_columns=self.auto_pipeline_params.get('feature_columns'),
timestamp_column=self.auto_pipeline_params.get('timestamp_column_name'),
# n_jobs=None,
# tshirt_size=None,
return_only_holdout=False
)
X_columns = self.auto_pipeline_params.get('feature_columns') if self.auto_pipeline_params.get('feature_columns') else self.auto_pipeline_params['prediction_columns']
X_train = DataFrame(X_train, columns=X_columns)
X_holdout = DataFrame(X_holdout, columns=X_columns)
y_train = DataFrame(y_train, columns=self.auto_pipeline_params['prediction_columns'])
y_holdout = DataFrame(y_holdout, columns=self.auto_pipeline_params['prediction_columns'])
return X_train, X_holdout, y_train, y_holdout
elif self.auto_pipeline_params.get('feature_columns') is not None:
# timeseries anomaly detection
try_import_autoai_ts_libs()
from autoai_ts_libs.utils.holdout_utils import make_holdout_split
from autoai_ts_libs.utils.constants import LEARNING_TYPE_TIMESERIES_ANOMALY_PREDICTION
# Note: imputation is not supported
X_train, X_holdout, y_train, y_holdout, _, _, _, _ = make_holdout_split(
dataset=data,
learning_type=LEARNING_TYPE_TIMESERIES_ANOMALY_PREDICTION,
test_size=self.auto_pipeline_params.get('holdout_size'),
# lookback_window=self.auto_pipeline_params.get('lookback_window'),
feature_columns=self.auto_pipeline_params.get('feature_columns'),
timestamp_column=self.auto_pipeline_params.get('timestamp_column_name'),
# n_jobs=None,
# tshirt_size=None,
return_only_holdout=False
)
X_columns = self.auto_pipeline_params['feature_columns']
y_column = ['anomaly_label']
X_train = DataFrame(X_train, columns=X_columns)
X_holdout = DataFrame(X_holdout, columns=X_columns)
y_train = DataFrame(y_train, columns=y_column)
y_holdout = DataFrame(y_holdout, columns=y_column)
return X_train, X_holdout, y_train, y_holdout
else:
if sys.version_info >= (3, 10):
try_import_autoai_libs(minimum_version='1.14.0')
else:
try_import_autoai_libs(minimum_version='1.12.14')
from autoai_libs.utils.holdout_utils import make_holdout_split, numpy_split_on_target_values
from autoai_libs.utils.sampling_utils import numpy_sample_rows
data.replace([np.inf, -np.inf], np.nan, inplace=True)
data.drop_duplicates(inplace=True)
data.dropna(subset=[self.auto_pipeline_params['prediction_column']], inplace=True)
dfy = data[self.auto_pipeline_params['prediction_column']]
data.drop(columns=[self.auto_pipeline_params['prediction_column']], inplace=True)
y_column = [self.auto_pipeline_params['prediction_column']]
X_columns = data.columns
if self._test_data or not with_holdout_split:
return data, dfy
else:
############################
# REMOVE MISSING ROWS #
from autoai_libs.utils.holdout_utils import numpy_remove_missing_target_rows
# Remove (and save) the rows of X and y for which the target variable has missing values
data, dfy, _, _, _, _ = numpy_remove_missing_target_rows(
y=dfy, X=data
)
# End of REMOVE MISSING ROWS #
###################################
#################
# SAMPLING #
# Get a sample of the rows if requested and applicable
# (check for sampling is performed inside this function)
try:
data, dfy, _ = numpy_sample_rows(
X=data,
y=dfy,
train_sample_rows_test_size=self.auto_pipeline_params['train_sample_rows_test_size'],
learning_type=self.auto_pipeline_params['prediction_type'],
return_sampled_indices=True
)
# Note: we have a silent error here (the old core behaviour);
# sampling is not performed when 'train_sample_rows_test_size' is bigger than the data row count
# TODO: can we throw an error instead?
except ValueError as e:
if 'between' in str(e):
pass
else:
raise e
# End of SAMPLING #
########################
# Perform holdout split
try:
X_train, X_holdout, y_train, y_holdout, _, _ = make_holdout_split(
x=data,
y=dfy,
learning_type=self.auto_pipeline_params['prediction_type'],
fairness_info=self.auto_pipeline_params.get('fairness_info', None),
test_size=self.auto_pipeline_params.get('holdout_size') if self.auto_pipeline_params.get('holdout_size') is not None else 0.1,
return_only_holdout=False,
time_ordered_data=self.auto_pipeline_params.get('time_ordered_data')
)
except (TypeError, KeyError):
if self.auto_pipeline_params.get('time_ordered_data'):
warn("Outdated autoai_libs - time_ordered_data parameter is not supported. Please update autoai_libs to version >=1.16.2")
X_train, X_holdout, y_train, y_holdout, _, _ = make_holdout_split(
x=data,
y=dfy,
learning_type=self.auto_pipeline_params['prediction_type'],
fairness_info=self.auto_pipeline_params.get('fairness_info', None),
test_size=self.auto_pipeline_params.get('holdout_size') if self.auto_pipeline_params.get('holdout_size') is not None else 0.1,
return_only_holdout=False
)
X_train = DataFrame(X_train, columns=X_columns)
X_holdout = DataFrame(X_holdout, columns=X_columns)
y_train = DataFrame(y_train, columns=y_column)
y_holdout = DataFrame(y_holdout, columns=y_column)
return X_train, X_holdout, y_train, y_holdout
def read(self,
with_holdout_split: bool = False,
csv_separator: str = ',',
excel_sheet: Union[str, int] = None,
encoding: Optional[str] = 'utf-8',
raw: Optional[bool] = False,
binary: Optional[bool] = False,
read_to_file: Optional[str] = None,
number_of_batch_rows: Optional[int] = None,
sampling_type: Optional[str] = None,
sample_size_limit: Optional[int] = None,
sample_rows_limit: Optional[int] = None,
sample_percentage_limit: Optional[float] = None,
**kwargs) -> Union['DataFrame', Tuple['DataFrame', 'DataFrame'], bytes]:
"""Download dataset stored in remote data storage. Returns batch up to 1GB.
:param with_holdout_split: if `True`, data will be split to train and holdout dataset as it was by AutoAI
:type with_holdout_split: bool, optional
:param csv_separator: separator / delimiter for CSV file
:type csv_separator: str, optional
:param excel_sheet: name of the excel sheet to use, applicable only when an xlsx file is the input;
support for passing the sheet number is deprecated
:type excel_sheet: str, optional
:param encoding: encoding type of the CSV
:type encoding: str, optional
:param raw: if `False`, simple data preprocessing will be applied (the same as in the backend);
if `True`, data will not be preprocessed
:type raw: bool, optional
:param binary: indicates to retrieve data in binary mode, the result will be a Python bytes object
:type binary: bool, optional
:param read_to_file: stream the downloaded data to a file at the path given as the value of this parameter;
use this parameter to avoid keeping data in memory
:type read_to_file: str, optional
:param number_of_batch_rows: number of rows to read in each batch when reading from flight connection
:type number_of_batch_rows: int, optional
:param sampling_type: a sampling strategy defining how to read the data
:type sampling_type: str, optional
:param sample_size_limit: upper limit for overall data that should be downloaded in bytes, default: 1 GB
:type sample_size_limit: int, optional
:param sample_rows_limit: upper limit for overall data that should be downloaded in number of rows
:type sample_rows_limit: int, optional
:param sample_percentage_limit: upper limit for overall data that should be downloaded,
expressed as a fraction of the whole dataset; must be a float number between 0 and 1;
this parameter is ignored when the `sampling_type` parameter is set to `first_n_records`
:type sample_percentage_limit: float, optional
.. note::
If more than one of: `sample_size_limit`, `sample_rows_limit`, `sample_percentage_limit` are set,
then downloaded data is limited to the lowest threshold.
:return: one of:
- pandas.DataFrame containing the dataset from remote data storage : Xy_train
- Tuple[pandas.DataFrame, pandas.DataFrame, pandas.DataFrame, pandas.DataFrame] : X_train, X_holdout, y_train, y_holdout
- Tuple[pandas.DataFrame, pandas.DataFrame] : X_test, y_test containing test (holdout) data from
remote storage
- bytes object, auto holdout split from backend (only train data provided)
**Examples**
.. code-block:: python
train_data_connections = optimizer.get_data_connections()
data = train_data_connections[0].read() # all train data
# or
X_train, X_holdout, y_train, y_holdout = train_data_connections[0].read(with_holdout_split=True) # train and holdout data
User provided train and test data:
.. code-block:: python
optimizer.fit(training_data_reference=[DataConnection],
training_results_reference=DataConnection,
test_data_reference=DataConnection)
test_data_connection = optimizer.get_test_data_connections()
X_test, y_test = test_data_connection.read() # only holdout data
# and
train_data_connections = optimizer.get_data_connections()
data = train_data_connections[0].read() # only train data
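Reading with sampling limits (a sketch; the limit values below are illustrative only):
.. code-block:: python
data = train_data_connections[0].read(sampling_type='first_n_records', sample_rows_limit=100000, sample_size_limit=500000000)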
"""
# enables flight automatically for CP4D 4.0.x, 4.5.x
try:
use_flight = kwargs.get(
'use_flight',
bool((self._wml_client is not None or 'USER_ACCESS_TOKEN' in os.environ or 'RUNTIME_ENV_ACCESS_TOKEN_FILE' in os.environ) and
self._wml_client.CPD_version))
except:
use_flight = False
return_data_as_iterator = kwargs.get('return_data_as_iterator', False)
sampling_type = sampling_type if sampling_type is not None else DEFAULT_SAMPLING_TYPE
enable_sampling = kwargs.get('enable_sampling', True)
total_size_limit = sample_size_limit if sample_size_limit is not None else kwargs.get('total_size_limit', DEFAULT_SAMPLE_SIZE_LIMIT)
total_nrows_limit = sample_rows_limit
total_percentage_limit = sample_percentage_limit if sample_percentage_limit is not None else 1.0
# Deprecation of excel_sheet as number:
if isinstance(excel_sheet, int):
warn(
message="Support for specifying the excel sheet by its number (int) is deprecated! Please set the excel sheet by its name.")
flight_parameters = kwargs.get('flight_parameters', {})
impersonate_header = kwargs.get('impersonate_header', None)
if with_holdout_split and self._user_holdout_exists: # when this connection is training one
raise NoAutomatedHoldoutSplit(reason="Experiment was run based on user defined holdout dataset.")
# note: experiment metadata is used only in autogen notebooks
experiment_metadata = kwargs.get('experiment_metadata')
# note: process subsampling stats flag
_return_subsampling_stats = kwargs.get("_return_subsampling_stats", False)
if experiment_metadata is not None:
self.auto_pipeline_params['train_sample_rows_test_size'] = experiment_metadata.get(
'train_sample_rows_test_size')
self.auto_pipeline_params['prediction_column'] = experiment_metadata.get('prediction_column')
self.auto_pipeline_params['prediction_columns'] = experiment_metadata.get('prediction_columns')
self.auto_pipeline_params['holdout_size'] = experiment_metadata.get('holdout_size')
self.auto_pipeline_params['prediction_type'] = experiment_metadata['prediction_type']
self.auto_pipeline_params['fairness_info'] = experiment_metadata.get('fairness_info')
self.auto_pipeline_params['lookback_window'] = experiment_metadata.get('lookback_window')
self.auto_pipeline_params['timestamp_column_name'] = experiment_metadata.get('timestamp_column_name')
self.auto_pipeline_params['feature_columns'] = experiment_metadata.get('feature_columns')
self.auto_pipeline_params['time_ordered_data'] = experiment_metadata.get('time_ordered_data')
# note: check for cloud
if 'training_result_reference' in experiment_metadata:
if isinstance(experiment_metadata['training_result_reference'].location, (S3Location, AssetLocation)):
run_id = experiment_metadata['training_result_reference'].location._training_status.split('/')[-2]
# WMLS
else:
run_id = experiment_metadata['training_result_reference'].location.path.split('/')[-3]
self.auto_pipeline_params['run_id'] = run_id
if self._test_data:
csv_separator = experiment_metadata.get('test_data_csv_separator', csv_separator)
excel_sheet = experiment_metadata.get('test_data_excel_sheet', excel_sheet)
encoding = experiment_metadata.get('test_data_encoding', encoding)
else:
csv_separator = experiment_metadata.get('csv_separator', csv_separator)
excel_sheet = experiment_metadata.get('excel_sheet', excel_sheet)
encoding = experiment_metadata.get('encoding', encoding)
if self.type == DataConnectionTypes.DS or self.type == DataConnectionTypes.CA:
if self._wml_client is None:
try:
from project_lib import Project
except ModuleNotFoundError:
raise ConnectionError(
"This functionality can be run only on Watson Studio or with wml_client passed to connection. "
"Please initialize WML client using `DataConnection.set_client(wml_client)` function "
"to be able to use this functionality.")
if (with_holdout_split or self._test_data) and not self.auto_pipeline_params.get('prediction_type', False):
raise MissingAutoPipelinesParameters(
self.auto_pipeline_params,
reason=f"To be able to recreate an original holdout split, you need to schedule a training job or "
f"if you are using historical runs, just call historical_optimizer.get_data_connections()")
# note: allow to read data at any time
elif (('csv_separator' not in self.auto_pipeline_params and 'encoding' not in self.auto_pipeline_params)
or csv_separator != ',' or encoding != 'utf-8'):
self.auto_pipeline_params['csv_separator'] = csv_separator
self.auto_pipeline_params['encoding'] = encoding
# --- end note
# note: excel_sheet in params only if it is not None (not specified):
if excel_sheet:
self.auto_pipeline_params['excel_sheet'] = excel_sheet
# --- end note
# note: set default quote character for flight (later applicable only for csv files stored in S3)
self.auto_pipeline_params['quote_character'] = 'double_quote'
# --- end note
data = DataFrame()
headers = None
if self._wml_client is None:
token = self._get_token_from_environment()
if token is not None:
headers = {'Authorization': f'Bearer {token}'}
elif impersonate_header is not None:
headers = self._wml_client._get_headers()
headers['impersonate'] = impersonate_header
if self.type == DataConnectionTypes.S3:
raise ConnectionError(
f"S3 DataConnection is deprecated! Please use data_asset_id instead.")
elif self.type == DataConnectionTypes.DS:
if use_flight and not self._obm:
from ibm_watson_machine_learning.utils.utils import is_lib_installed
is_lib_installed(lib_name='pyarrow', minimum_version='3.0.0', install=True)
from pyarrow.flight import FlightError
_iam_id = None
if headers and headers.get('impersonate'):
_iam_id = headers.get('impersonate', {}).get('iam_id')
self._wml_client._iam_id = _iam_id
try:
if self._check_if_connection_asset_is_s3():
# note: update flight parameters only if `connection_properties` was not set earlier
# (e.g. by wml/autoai)
if not flight_parameters.get('connection_properties'):
flight_parameters = self._update_flight_parameters_with_connection_details(flight_parameters)
data = self._download_data_from_flight_service(data_location=self,
binary=binary,
read_to_file=read_to_file,
flight_parameters=flight_parameters,
headers=headers,
enable_sampling=enable_sampling,
sampling_type=sampling_type,
number_of_batch_rows=number_of_batch_rows,
return_data_as_iterator=return_data_as_iterator,
_return_subsampling_stats=_return_subsampling_stats,
total_size_limit=total_size_limit,
total_nrows_limit=total_nrows_limit,
total_percentage_limit=total_percentage_limit
)
except (ConnectionError, FlightError, ApiRequestFailure) as download_data_error:
# note: try to download normal data asset either directly from cams or from mounted NFS
# to keep backward compatibility
if (
self._wml_client
and (
(
self._is_data_asset_normal()
and self._is_size_acceptable()
)
or self._is_data_asset_nfs()
)
and (
"Found non-unique column index"
not in str(download_data_error)
)
):
import warnings
warnings.warn(str(download_data_error), Warning)
data = self._download_training_data_from_data_asset_storage()
else:
raise download_data_error
# backward compatibility
else:
try:
with all_logging_disabled():
if self._check_if_connection_asset_is_s3():
cos_client = self._init_cos_client()
if self._obm:
data = self._download_obm_data_from_cos(cos_client=cos_client)
else:
data = self._download_data_from_cos(cos_client=cos_client, binary=binary)
else:
data = self._download_training_data_from_data_asset_storage()
except NotImplementedError as e:
raise e
except FileNotFoundError as e:
raise e
except Exception as e:
# do not try Flight if we are on the cloud
if self._wml_client is not None:
if not self._wml_client.ICP:
raise e
elif os.environ.get('USER_ACCESS_TOKEN') is None and os.environ.get('RUNTIME_ENV_ACCESS_TOKEN_FILE') is None:
raise CannotReadSavedRemoteDataBeforeFit()
data = self._download_data_from_flight_service(data_location=self,
binary=binary,
read_to_file=read_to_file,
flight_parameters=flight_parameters,
headers=headers,
enable_sampling=enable_sampling,
sampling_type=sampling_type,
number_of_batch_rows=number_of_batch_rows,
return_data_as_iterator=return_data_as_iterator,
_return_subsampling_stats=_return_subsampling_stats,
total_size_limit=total_size_limit,
total_nrows_limit=total_nrows_limit,
total_percentage_limit=total_percentage_limit)
elif self.type == DataConnectionTypes.FS:
if self._obm:
data = self._download_obm_data_from_file_system()
else:
data = self._download_training_data_from_file_system()
elif self.type == DataConnectionTypes.CA or self.type == DataConnectionTypes.CN:
if getattr(self._wml_client, 'ICP', False) and self.type == DataConnectionTypes.CN:
raise ContainerTypeNotSupported() # block Container type on CPD
if use_flight and not self._obm:
# Workaround for container connection type, we need to fetch COS details from space/project
if self.type == DataConnectionTypes.CN:
# note: update flight parameters only if `connection_properties` was not set earlier
# (e.g. by wml/autoai)
if not flight_parameters.get('connection_properties'):
flight_parameters = self._update_flight_parameters_with_connection_details(flight_parameters)
data = self._download_data_from_flight_service(data_location=self,
binary=binary,
read_to_file=read_to_file,
flight_parameters=flight_parameters,
headers=headers,
enable_sampling=enable_sampling,
sampling_type=sampling_type,
number_of_batch_rows=number_of_batch_rows,
return_data_as_iterator=return_data_as_iterator,
_return_subsampling_stats=_return_subsampling_stats,
total_size_limit=total_size_limit,
total_nrows_limit=total_nrows_limit,
total_percentage_limit=total_percentage_limit)
else: # backward compatibility
try:
with all_logging_disabled():
if self._check_if_connection_asset_is_s3():
cos_client = self._init_cos_client()
try:
if self._obm:
data = self._download_obm_data_from_cos(cos_client=cos_client)
else:
data = self._download_data_from_cos(cos_client=cos_client, binary=binary)
except Exception as cos_access_exception:
raise ConnectionError(
f"Unable to access data object in cloud object storage with credentials supplied. "
f"Error: {cos_access_exception}")
else:
data = self._download_data_from_nfs_connection()
except Exception as e:
# do not try Flight if we are on the cloud
if self._wml_client is not None:
if not self._wml_client.ICP:
raise e
elif os.environ.get('USER_ACCESS_TOKEN') is None and os.environ.get('RUNTIME_ENV_ACCESS_TOKEN_FILE') is None:
raise CannotReadSavedRemoteDataBeforeFit()
data = self._download_data_from_flight_service(data_location=self,
binary=binary,
read_to_file=read_to_file,
flight_parameters=flight_parameters,
headers=headers,
enable_sampling=enable_sampling,
sampling_type=sampling_type,
number_of_batch_rows=number_of_batch_rows,
_return_subsampling_stats=_return_subsampling_stats,
total_size_limit=total_size_limit,
total_nrows_limit=total_nrows_limit,
total_percentage_limit=total_percentage_limit)
if getattr(self._wml_client, '_internal', False):
pass # don't remove additional params if client is used internally
else:
# note: remove additional params and inline credentials added by _check_if_connection_asset_is_s3:
[delattr(self.connection, attr) for attr in
['secret_access_key', 'access_key_id', 'endpoint_url', 'cos_type'] if
hasattr(self.connection, attr)]
# end note
# create data statistics if data were not downloaded with flight:
if not isinstance(data, tuple) and _return_subsampling_stats:
data = (data, {"data_batch_size": sys.getsizeof(data),
"data_batch_nrows": len(data)})
if binary:
return data
if raw or (self.auto_pipeline_params.get('prediction_column') is None
and self.auto_pipeline_params.get('prediction_columns') is None
and self.auto_pipeline_params.get('feature_columns') is None):
return data
else:
if with_holdout_split: # when this connection is training one
if return_data_as_iterator:
raise WMLClientError("The flags `return_data_as_iterator` and `with_holdout_split` cannot be set both in the same time.")
if _return_subsampling_stats:
X_train, X_holdout, y_train, y_holdout = self._recreate_holdout(data=data[0])
return X_train, X_holdout, y_train, y_holdout, data[1]
else:
X_train, X_holdout, y_train, y_holdout = self._recreate_holdout(data=data)
return X_train, X_holdout, y_train, y_holdout
else: # when this data connection is a test / holdout one
if return_data_as_iterator:
return data
if _return_subsampling_stats:
if self.auto_pipeline_params.get('prediction_columns') or \
not self.auto_pipeline_params.get('prediction_column') or \
(self.auto_pipeline_params.get('prediction_column') and self.auto_pipeline_params.get(
'prediction_column') not in data[0].columns):
# timeseries dataset does not have prediction columns. Whole data set is returned:
test_X = data
return test_X
else:
test_X, test_y = self._recreate_holdout(data=data[0], with_holdout_split=False)
test_X[self.auto_pipeline_params.get('prediction_column', 'prediction_column')] = test_y
return test_X, data[1]
else: # when this data connection is a test / holdout one and no subsampling stats are needed
if self.auto_pipeline_params.get('prediction_columns') or \
not self.auto_pipeline_params.get('prediction_column') or \
(self.auto_pipeline_params.get('prediction_column') and self.auto_pipeline_params.get(
'prediction_column') not in data.columns):
# timeseries dataset does not have prediction columns. Whole data set is returned:
test_X = data
else:
test_X, test_y = self._recreate_holdout(data=data, with_holdout_split=False)
test_X[self.auto_pipeline_params.get('prediction_column', 'prediction_column')] = test_y
return test_X # return one dataframe
def write(self, data: Union[str, 'DataFrame'], remote_name: str = None, **kwargs) -> None:
"""Upload file to a remote data storage.
:param data: local path to the dataset or pandas.DataFrame with data
:type data: str or pandas.DataFrame
:param remote_name: name under which the dataset should be stored in the remote data storage
:type remote_name: str
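**Example**
A minimal sketch; `data_connection` is assumed to be an initialized DataConnection with a wml client set:
.. code-block:: python
data_connection.set_client(wml_client)
data_connection.write(data='local_path/training_data.csv', remote_name='training_data.csv')
# or upload a pandas.DataFrame directly
data_connection.write(data=my_dataframe, remote_name='training_data.csv')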
"""
# enables flight automatically for CP4D 4.0.x
use_flight = kwargs.get(
'use_flight',
bool((self._wml_client is not None or 'USER_ACCESS_TOKEN' in os.environ or 'RUNTIME_ENV_ACCESS_TOKEN_FILE' in os.environ) and
self._wml_client.CPD_version))
flight_parameters = kwargs.get('flight_parameters', {})
impersonate_header = kwargs.get('impersonate_header', None)
headers = None
if self._wml_client is None:
token = self._get_token_from_environment()
if token is None:
raise ConnectionError("WML client missing. Please initialize WML client and pass it to "
"DataConnection._wml_client property to be able to use this functionality.")
else:
headers = {'Authorization': f'Bearer {token}'}
elif impersonate_header is not None:
headers = self._wml_client._get_headers()
headers['impersonate'] = impersonate_header
# TODO: Remove S3 implementation
if self.type == DataConnectionTypes.S3:
raise ConnectionError("S3 DataConnection is deprecated! Please use data_asset_id instead.")
elif self.type == DataConnectionTypes.CA or self.type == DataConnectionTypes.CN:
if getattr(self._wml_client, 'ICP', False) and self.type == DataConnectionTypes.CN:
raise ContainerTypeNotSupported() # block Container type on CPD
if self._check_if_connection_asset_is_s3():
# do not try Flight if we are on the cloud
if self._wml_client is not None and not self._wml_client.ICP and not use_flight: # CLOUD
if remote_name is None and self._to_dict().get('location', {}).get('path'):
updated_remote_name = data.split('/')[-1]
else:
updated_remote_name = self._get_path_with_remote_name(self._to_dict(), remote_name)
cos_resource_client = self._init_cos_client()
if isinstance(data, str):
with open(data, "rb") as file_data:
cos_resource_client.Object(self.location.bucket, updated_remote_name).upload_fileobj(
Fileobj=file_data)
elif isinstance(data, DataFrame):
# note: we are saving the csv in memory as a file and streaming it to the COS
buffer = io.StringIO()
data.to_csv(buffer, index=False)
buffer.seek(0)
with buffer as f:
cos_resource_client.Object(self.location.bucket, updated_remote_name).upload_fileobj(
Fileobj=io.BytesIO(bytes(f.read().encode())))
else:
raise TypeError("data should be either of type \"str\" or \"pandas.DataFrame\"")
# CP4D
else:
# Workaround for container connection type, we need to fetch COS details from space/project
if self.type == DataConnectionTypes.CN:
# note: update flight parameters only if `connection_properties` was not set earlier
# (e.g. by wml/autoai)
if not flight_parameters.get('connection_properties'):
flight_parameters = self._update_flight_parameters_with_connection_details(flight_parameters)
if isinstance(data, str):
self._upload_data_via_flight_service(file_path=data,
data_location=self,
remote_name=remote_name,
flight_parameters=flight_parameters,
headers=headers)
elif isinstance(data, DataFrame):
# note: we are saving the csv in memory as a file and streaming it to the COS
self._upload_data_via_flight_service(data=data,
data_location=self,
remote_name=remote_name,
flight_parameters=flight_parameters,
headers=headers)
else:
raise TypeError("data should be either of type \"str\" or \"pandas.DataFrame\"")
else:
if self._wml_client is not None and not self._wml_client.ICP and not use_flight: # CLOUD
raise ConnectionError("Connections other than COS are not supported on a cloud yet.")
# CP4D
else:
if isinstance(data, str):
self._upload_data_via_flight_service(file_path=data,
data_location=self,
remote_name=remote_name,
flight_parameters=flight_parameters,
headers=headers,
binary=kwargs.get('binary', False))
elif isinstance(data, DataFrame):
# note: we are saving the csv in memory as a file and streaming it to the COS
self._upload_data_via_flight_service(data=data,
data_location=self,
remote_name=remote_name,
flight_parameters=flight_parameters,
headers=headers)
else:
raise TypeError("data should be either of type \"str\" or \"pandas.DataFrame\"")
if getattr(self._wml_client, '_internal', False):
pass # don't remove additional params if client is used internally
else:
# note: remove additional params and inline credentials added by _check_if_connection_asset_is_s3:
[delattr(self.connection, attr) for attr in ['secret_access_key', 'access_key_id', 'endpoint_url', 'cos_type'] if
hasattr(self.connection, attr)]
# end note
elif self.type == DataConnectionTypes.DS:
if self._wml_client is not None and not self._wml_client.ICP and not use_flight: # CLOUD
raise ConnectionError("Write of data for Data Asset is not supported on Cloud.")
elif self._wml_client is not None:
if isinstance(data, str):
self._upload_data_via_flight_service(file_path=data,
data_location=self,
remote_name=remote_name,
flight_parameters=flight_parameters,
headers=headers)
elif isinstance(data, DataFrame):
# note: we are saving the csv in memory as a file and streaming it to the COS
self._upload_data_via_flight_service(data=data,
data_location=self,
remote_name=remote_name,
flight_parameters=flight_parameters,
headers=headers)
else:
raise TypeError("data should be either of type \"str\" or \"pandas.DataFrame\"")
else:
self._upload_data_via_flight_service(data=data,
data_location=self,
remote_name=remote_name,
flight_parameters=flight_parameters,
headers=headers)
def _init_cos_client(self) -> 'resource':
""" Initiate COS client for further usage. """
from ibm_botocore.client import Config
try:
if hasattr(self.connection, 'auth_endpoint') and hasattr(self.connection, 'api_key'):
cos_client = resource(
service_name='s3',
ibm_api_key_id=self.connection.api_key,
ibm_auth_endpoint=self.connection.auth_endpoint,
config=Config(signature_version="oauth"),
endpoint_url=self.connection.endpoint_url
)
else:
cos_client = resource(
service_name='s3',
endpoint_url=self.connection.endpoint_url,
aws_access_key_id=self.connection.access_key_id,
aws_secret_access_key=self.connection.secret_access_key
)
return cos_client
except ValueError:
if not self.connection.endpoint_url.startswith('https://'):
raise WMLClientError(Messages.get_message(message_id="invalid_endpoint_url"))
def _validate_cos_resource(self):
cos_client = self._init_cos_client()
try:
files = cos_client.Bucket(self.location.bucket).objects.all()
next(x for x in files if x.key == self.location.path)
except Exception as e:
raise NotExistingCOSResource(self.location.bucket, self.location.path)
def _update_flight_parameters_with_connection_details(self, flight_parameters):
with all_logging_disabled():
self._check_if_connection_asset_is_s3()
connection_properties = {
"bucket": self.location.bucket,
"url": self.connection.endpoint_url
}
if hasattr(self.connection, 'auth_endpoint') and hasattr(self.connection, 'api_key'):
connection_properties["iam_url"] = self.connection.auth_endpoint
connection_properties["api_key"] = self.connection.api_key
connection_properties["resource_instance_id"] = self.connection.resource_instance_id
else:
connection_properties["secret_key"] = self.connection.secret_access_key
connection_properties["access_key"] = self.connection.access_key_id
flight_parameters.update({"connection_properties": connection_properties})
flight_parameters.update({"datasource_type": {"entity": {"name": self._datasource_type}}})
return flight_parameters
# TODO: Remove S3 Implementation for connection
class S3Connection(BaseConnection):
"""Connection class to COS data storage in S3 format.
:param endpoint_url: S3 data storage url (COS)
:type endpoint_url: str
:param access_key_id: access_key_id of the S3 connection (COS)
:type access_key_id: str, optional
:param secret_access_key: secret_access_key of the S3 connection (COS)
:type secret_access_key: str, optional
:param api_key: API key of the S3 connection (COS)
:type api_key: str, optional
:param service_name: service name of the S3 connection (COS)
:type service_name: str, optional
:param auth_endpoint: authentication endpoint url of the S3 connection (COS)
:type auth_endpoint: str, optional
"""
def __init__(self, endpoint_url: str, access_key_id: str = None, secret_access_key: str = None,
api_key: str = None, service_name: str = None, auth_endpoint: str = None,
resource_instance_id: str = None, _internal_use=False) -> None:
if not _internal_use:
warn(message="S3 DataConnection is deprecated! Please use data_asset_id instead.")
if (access_key_id is None or secret_access_key is None) and (api_key is None or auth_endpoint is None):
raise InvalidCOSCredentials(reason='You need to specify (access_key_id and secret_access_key) or '
'(api_key and auth_endpoint)')
if secret_access_key is not None:
self.secret_access_key = secret_access_key
if api_key is not None:
self.api_key = api_key
if service_name is not None:
self.service_name = service_name
if auth_endpoint is not None:
self.auth_endpoint = auth_endpoint
if access_key_id is not None:
self.access_key_id = access_key_id
if endpoint_url is not None:
self.endpoint_url = endpoint_url
if resource_instance_id is not None:
self.resource_instance_id = resource_instance_id
class S3Location(BaseLocation):
"""Connection class to COS data storage in S3 format.
:param bucket: COS bucket name
:type bucket: str
:param path: COS data path in the bucket
:type path: str
:param excel_sheet: name of the excel sheet, if the pointed dataset is an excel file; used for Batch Deployment scoring
:type excel_sheet: str, optional
:param model_location: path to the pipeline model in the COS
:type model_location: str, optional
:param training_status: path to the training status json in COS
:type training_status: str, optional
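**Example**
A minimal sketch; the bucket name and path are placeholders:
.. code-block:: python
location = S3Location(bucket='<bucket_name>', path='datasets/train.csv')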
"""
def __init__(self, bucket: str, path: str, **kwargs) -> None:
self.bucket = bucket
self.path = path
if kwargs.get('model_location') is not None:
self._model_location = kwargs['model_location']
if kwargs.get('training_status') is not None:
self._training_status = kwargs['training_status']
if kwargs.get('excel_sheet') is not None:
self.sheet_name = kwargs['excel_sheet']
self.file_format = "xls"
def _get_file_size(self, cos_resource_client: 'resource') -> 'int':
try:
size = cos_resource_client.Object(self.bucket, self.path).content_length
except ClientError:
size = 0
return size
def get_location(self) -> str:
if hasattr(self, "file_name"):
return self.file_name
else:
return self.path
class ContainerLocation(BaseLocation):
"""Connection class to default COS in user Project/Space."""
def __init__(self, path: Optional[str] = None, **kwargs) -> None:
if path is None:
self.path = "default_autoai_out"
else:
self.path = path
self.bucket = None
if kwargs.get('model_location') is not None:
self._model_location = kwargs['model_location']
if kwargs.get('training_status') is not None:
self._training_status = kwargs['training_status']
def to_dict(self) -> dict:
_dict = super().to_dict()
if 'bucket' in _dict and _dict['bucket'] is None:
del _dict['bucket']
return _dict
@classmethod
def _set_path(cls, path: str) -> 'ContainerLocation':
location = cls()
location.path = path
return location
def _get_file_size(self):
pass
class FSLocation(BaseLocation):
"""Connection class to File Storage in CP4D."""
def __init__(self, path: Optional[str] = None) -> None:
if path is None:
self.path = "/{option}/{id}" + f"/assets/auto_ml/auto_ml.{uuid.uuid4()}/wml_data"
else:
self.path = path
@classmethod
def _set_path(cls, path: str) -> 'FSLocation':
location = cls()
location.path = path
return location
def _save_file_as_data_asset(self, workspace: 'WorkSpace') -> 'str':
asset_name = self.path.split('/')[-1]
if self.path:
data_asset_details = workspace.wml_client.data_assets.create(asset_name, self.path)
return workspace.wml_client.data_assets.get_uid(data_asset_details)
else:
raise MissingValue('path', reason="Incorrect initialization of class FSLocation")
def _get_file_size(self, workspace: 'WorkSpace') -> 'int':
# note: if path is not a file then the returned size is 0
try:
# note: try to get file size from remote server
url = workspace.wml_client.service_instance._href_definitions.get_wsd_model_attachment_href() \
+ f"/{self.path.split('/assets/')[-1]}"
path_info_response = requests.head(url, headers=workspace.wml_client._get_headers(),
params=workspace.wml_client._params())
if path_info_response.status_code != 200:
raise ApiRequestFailure(u"Failure during getting path details", path_info_response)
path_info = path_info_response.headers
if 'X-Asset-Files-Type' in path_info and path_info['X-Asset-Files-Type'] == 'file':
size = path_info['X-Asset-Files-Size']
else:
size = 0
# -- end note
except (ApiRequestFailure, AttributeError):
# note: try to get the size of the file from the local fs
size = os.stat(path=self.path).st_size if os.path.isfile(path=self.path) else 0
# -- end note
return size
class AssetLocation(BaseLocation):
def __init__(self, asset_id: str) -> None:
self._wsd = self._is_wsd()
self.href = None
self._initial_asset_id = asset_id
self.__wml_client = None
if self._wsd:
self._asset_name = None
self._asset_id = None
self._local_asset_path = None
else:
self.id = asset_id
def _get_bucket(self, client) -> str:
"""Try to get bucket from data asset."""
connection_id = self._get_connection_id(client)
conn_details = client.connections.get_details(connection_id)
bucket = conn_details.get('entity', {}).get('properties', {}).get('bucket')
if bucket is None:
asset_details = client.data_assets.get_details(self.id)
connection_path = asset_details['entity'].get('folder_asset', {}).get('connection_path')
if connection_path is None:
attachment_content = self._get_attachment_details(client)
connection_path = attachment_content.get('connection_path')
bucket = connection_path.split('/')[1]
return bucket
def _get_attachment_details(self, client) -> dict:
if self.id is None and self.href:
items = self.href.split('/')
self.id = items[-1].split('?')[0]
asset_details = client.data_assets.get_details(self.id)
if 'attachment_id' in asset_details.get('metadata'):
attachment_id = asset_details['metadata']['attachment_id']
else:
attachment_id = asset_details['attachments'][0]['id']
attachment_url = client.service_instance._href_definitions.get_data_asset_href(self.id)
attachment_url = f"{attachment_url}/attachments/{attachment_id}"
attachment = requests.get(attachment_url, headers=client._get_headers(),
params=client._params())
if attachment.status_code != 200:
raise ApiRequestFailure(u"Failure during getting attachment details", attachment)
return attachment.json()
def _get_connection_id(self, client) -> str:
attachment_content = self._get_attachment_details(client)
return attachment_content.get('connection_id')
@classmethod
def _is_wsd(cls):
if os.environ.get('USER_ACCESS_TOKEN') or os.environ.get('RUNTIME_ENV_ACCESS_TOKEN_FILE'):
return False
try:
from project_lib import Project
try:
with all_logging_disabled():
access = Project.access()
return True
except RuntimeError:
pass
except ModuleNotFoundError:
pass
return False
@classmethod
def _set_path(cls, href: str) -> 'AssetLocation':
items = href.split('/')
_id = items[-1].split('?')[0]
location = cls(_id)
location.href = href
return location
def _get_file_size(self, workspace: 'WorkSpace', *args) -> 'int':
if self._wsd:
return self._wsd_get_file_size()
else:
asset_info_response = requests.get(
workspace.wml_client.service_instance._href_definitions.get_data_asset_href(self.id),
params=workspace.wml_client._params(),
headers=workspace.wml_client._get_headers())
if asset_info_response.status_code != 200:
raise ApiRequestFailure(u"Failure during getting asset details", asset_info_response)
return asset_info_response.json()['metadata'].get('size')
def _wsd_setup_local_asset_details(self) -> None:
if not self._wsd:
raise NotWSDEnvironment()
# note: set local asset file from asset_id
project = get_project()
project_id = project.get_metadata()["metadata"]["guid"]
local_assets = project.get_files()
# note: reuse the local asset_id when the object is reused multiple times
if self._asset_id is None:
local_asset_id = self._initial_asset_id
else:
local_asset_id = self._asset_id
# --- end note
if local_asset_id not in str(local_assets):
raise MissingLocalAsset(local_asset_id, reason="Provided asset_id cannot be found on WS Desktop.")
else:
for asset in local_assets:
if asset['asset_id'] == local_asset_id:
asset_name = asset['name']
self._asset_name = asset_name
self._asset_id = local_asset_id
local_asset_path = f"{os.path.abspath('.')}/{project_id}/assets/data_asset/{asset_name}"
self._local_asset_path = local_asset_path
def _wsd_move_asset_to_server(self, workspace: 'WorkSpace') -> None:
if not self._wsd:
raise NotWSDEnvironment()
if not self._local_asset_path or self._asset_name or self._asset_id:
self._wsd_setup_local_asset_details()
remote_asset_details = workspace.wml_client.data_assets.create(self._asset_name, self._local_asset_path)
self.href = remote_asset_details['metadata']['href']
def _wsd_get_file_size(self) -> 'int':
if not self._wsd:
raise NotWSDEnvironment()
if not self._local_asset_path or self._asset_name or self._asset_id:
self._wsd_setup_local_asset_details()
return os.stat(path=self._local_asset_path).st_size if os.path.isfile(path=self._local_asset_path) else 0
@classmethod
def list_wsd_assets(cls):
if not cls._is_wsd():
raise NotWSDEnvironment
project = get_project()
return project.get_files()
def to_dict(self) -> dict:
"""Return a json dictionary representing this model."""
_dict = vars(self).copy()
if _dict.get('id', False) is None and _dict.get('href'):
items = self.href.split('/')
_dict['id'] = items[-1].split('?')[0]
del _dict['_wsd']
del _dict[f"_{self.__class__.__name__}__wml_client"]
if self._wsd:
del _dict['_asset_name']
del _dict['_asset_id']
del _dict['_local_asset_path']
del _dict['_initial_asset_id']
return _dict
@property
def wml_client(self):
return self.__wml_client
@wml_client.setter
def wml_client(self, var):
self.__wml_client = var
if self.__wml_client:
self.href = self.__wml_client.service_instance._href_definitions.get_base_asset_href(self._initial_asset_id)
else:
self.href = f'/v2/assets/{self._initial_asset_id}'
if not self._wsd:
if self.__wml_client:
if self.__wml_client.default_space_id:
self.href = f'{self.href}?space_id={self.__wml_client.default_space_id}'
else:
self.href = f'{self.href}?project_id={self.__wml_client.default_project_id}'
class ConnectionAssetLocation(BaseLocation):
"""Connection class to COS data storage.
:param bucket: COS bucket name
:type bucket: str
:param file_name: COS data path in the bucket
:type file_name: str
:param model_location: path to the pipeline model in the COS
:type model_location: str, optional
:param training_status: path to the training status json in COS
:type training_status: str, optional
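**Example**
A minimal sketch; the bucket and file names are placeholders:
.. code-block:: python
location = ConnectionAssetLocation(bucket='<bucket_name>', file_name='train.csv')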
"""
def __init__(self, bucket: str, file_name: str, **kwargs) -> None:
self.bucket = bucket
self.file_name = file_name
self.path = file_name
if kwargs.get('model_location') is not None:
self._model_location = kwargs['model_location']
if kwargs.get('training_status') is not None:
self._training_status = kwargs['training_status']
def _get_file_size(self, cos_resource_client: 'resource') -> 'int':
try:
size = cos_resource_client.Object(self.bucket, self.path).content_length
except ClientError:
size = 0
return size
def to_dict(self) -> dict:
"""Return a json dictionary representing this model."""
return vars(self)
class ConnectionAsset(BaseConnection):
"""Connection class for Connection Asset.
:param connection_id: connection asset ID
:type connection_id: str
"""
def __init__(self, connection_id: str):
self.id = connection_id
class NFSConnection(BaseConnection):
"""Connection class to file storage in CP4D of NFS format.
:param asset_id: asset ID from the project on CP4D
:type asset_id: str
"""
def __init__(self, asset_id: str):
self.asset_id = asset_id
self.id = asset_id
class NFSLocation(BaseLocation):
"""Location class to file storage in CP4D of NFS format.
:param path: data path from the project on CP4D
:type path: str
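**Example**
A minimal sketch; the connection asset ID and path are placeholders:
.. code-block:: python
conn = DataConnection(connection_asset_id='<nfs_connection_asset_id>', location=NFSLocation(path='/data/train.csv'))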
"""
def __init__(self, path: str):
self.path = path
self.id = None
self.file_name = None
def _get_file_size(self, workspace: 'WorkSpace', *args) -> 'int':
params = workspace.wml_client._params().copy()
params['path'] = self.path
params['detail'] = 'true'
href = workspace.wml_client.connections._href_definitions.get_connection_by_id_href(self.id) + '/assets'
asset_info_response = requests.get(href,
params=params, headers=workspace.wml_client._get_headers(None))
if asset_info_response.status_code != 200:
raise Exception(u"Failure during getting asset details", asset_info_response.json())
return asset_info_response.json()['details']['file_size']
def get_location(self) -> str:
if hasattr(self, "file_name"):
return self.file_name
else:
return self.path
class CP4DAssetLocation(AssetLocation):
"""Connection class to data assets in CP4D.
:param asset_id: asset ID from the project on CP4D
:type asset_id: str
"""
def __init__(self, asset_id: str) -> None:
super().__init__(asset_id)
warning_msg = ("Depreciation Warning: Class CP4DAssetLocation is no longer supported and will be removed."
"Use AssetLocation instead.")
print(warning_msg)
def _get_file_size(self, workspace: 'WorkSpace', *args) -> 'int':
return super()._get_file_size(workspace)
class WMLSAssetLocation(AssetLocation):
"""Connection class to data assets in WML Server.
:param asset_id: asset ID of the file loaded on space in WML Server
:type asset_id: str
"""
def __init__(self, asset_id: str) -> None:
super().__init__(asset_id)
warning_msg = ("Depreciation Warning: Class WMLSAssetLocation is no longer supported and will be removed."
"Use AssetLocation instead.")
print(warning_msg)
def _get_file_size(self, workspace: 'WorkSpace', *args) -> 'int':
return super()._get_file_size(workspace)
class CloudAssetLocation(AssetLocation):
"""Connection class to data assets as input data references to batch deployment job on Cloud.
:param asset_id: asset ID of the file loaded on space on Cloud
:type asset_id: str
"""
def __init__(self, asset_id: str) -> None:
super().__init__(asset_id)
self.href = self.href
warning_msg = ("Depreciation Warning: Class CloudAssetLocation is no longer supported and will be removed."
"Use AssetLocation instead.")
print(warning_msg)
def _get_file_size(self, workspace: 'WorkSpace', *args) -> 'int':
return super()._get_file_size(workspace)
class WSDAssetLocation(BaseLocation):
"""Connection class to data assets in WS Desktop.
:param asset_id: asset ID from the project on WS Desktop
:type asset_id: str
"""
def __init__(self, asset_id: str) -> None:
self.href = None
self._asset_name = None
self._asset_id = None
self._local_asset_path = None
self._initial_asset_id = asset_id
self.__wml_client = None
warning_msg = ("Depreciation Warning: Class WSDAssetLocation is no longer supported and will be removed."
"Use AssetLocation instead.")
print(warning_msg)
@classmethod
def list_assets(cls):
project = get_project()
return project.get_files()
def _setup_local_asset_details(self) -> None:
# note: set local asset file from asset_id
project = get_project()
project_id = project.get_metadata()["metadata"]["guid"]
local_assets = project.get_files()
# note: reuse the local asset_id when the object is reused multiple times
if self._asset_id is None:
local_asset_id = self.href.split('/')[3].split('?space_id')[0]
else:
local_asset_id = self._asset_id
# --- end note
if local_asset_id not in str(local_assets):
raise MissingLocalAsset(local_asset_id, reason="Provided asset_id cannot be found on WS Desktop.")
else:
for asset in local_assets:
if asset['asset_id'] == local_asset_id:
asset_name = asset['name']
self._asset_name = asset_name
self._asset_id = local_asset_id
local_asset_path = f"{os.path.abspath('.')}/{project_id}/assets/data_asset/{asset_name}"
self._local_asset_path = local_asset_path
def _move_asset_to_server(self, workspace: 'WorkSpace') -> None:
if not self._local_asset_path or self._asset_name or self._asset_id:
self._setup_local_asset_details()
remote_asset_details = workspace.wml_client.data_assets.create(self._asset_name, self._local_asset_path)
self.href = remote_asset_details['metadata']['href']
@classmethod
def _set_path(cls, href: str) -> 'WSDAssetLocation':
location = cls('.')
location.href = href
return location
@property
def wml_client(self):
return self.__wml_client
@wml_client.setter
def wml_client(self, var):
self.__wml_client = var
if self.__wml_client:
self.href = self.__wml_client.service_instance._href_definitions.get_base_asset_href(self._initial_asset_id)
else:
self.href = f'/v2/assets/{self._initial_asset_id}'
def to_dict(self) -> dict:
"""Return a json dictionary representing this model."""
_dict = vars(self).copy()
del _dict['_asset_name']
del _dict['_asset_id']
del _dict['_local_asset_path']
del _dict[f"_{self.__class__.__name__}__wml_client"]
del _dict['_initial_asset_id']
return _dict
def _get_file_size(self) -> 'int':
if not self._local_asset_path or self._asset_name or self._asset_id:
self._setup_local_asset_details()
return os.stat(path=self._local_asset_path).st_size if os.path.isfile(path=self._local_asset_path) else 0
class DeploymentOutputAssetLocation(BaseLocation):
"""Connection class to data assets where output of batch deployment will be stored.
:param name: name of .csv file which will be saved as data asset
:type name: str
:param description: description of the data asset
:type description: str, optional
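**Example**
A minimal sketch; the output file name is a placeholder:
.. code-block:: python
output_location = DeploymentOutputAssetLocation(name='batch_output.csv')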
"""
def __init__(self, name: str, description: str = "") -> None:
self.name = name
self.description = description
class DatabaseLocation(BaseLocation):
"""Location class to Database.
:param schema_name: database schema name
:type schema_name: str
:param table_name: database table name
:type table_name: str
:param catalog_name: database catalog name, required only for Presto data source
:type catalog_name: str, optional
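**Example**
A minimal sketch; the schema and table names are placeholders:
.. code-block:: python
location = DatabaseLocation(schema_name='<schema_name>', table_name='<table_name>')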
"""
def __init__(self, schema_name: str, table_name: str, catalog_name: str = None, **kwargs) -> None:
self.schema_name = schema_name
self.table_name = table_name
self.catalog_name = catalog_name
def _get_file_size(self) -> None:
raise NotImplementedError()
def to_dict(self) -> dict:
"""Get a json dictionary representing DatabaseLocation."""
return {key: value for key, value in vars(self).items() if value}