# Source code for remote_training_system

#  -----------------------------------------------------------------------------------------
#  (C) Copyright IBM Corp. 2023-2024.
#  https://opensource.org/licenses/BSD-3-Clause
#  -----------------------------------------------------------------------------------------

from __future__ import annotations

from typing import TYPE_CHECKING

from ibm_watsonx_ai._wrappers import requests

from ibm_watsonx_ai.metanames import RemoteTrainingSystemMetaNames
from ibm_watsonx_ai.party_wrapper import Party
from ibm_watsonx_ai.wml_client_error import WMLClientError
from ibm_watsonx_ai.wml_resource import WMLResource

if TYPE_CHECKING:
    from ibm_watsonx_ai import APIClient
    from pandas import DataFrame

_DEFAULT_LIST_LENGTH = 50


class RemoteTrainingSystem(WMLResource):
    """The RemoteTrainingSystem class represents a Federated Learning party and provides a list of identities
    that are permitted to join training as the RemoteTrainingSystem.
    """

    def __init__(self, client: APIClient):
        WMLResource.__init__(self, __name__, client)
        self._client = client

        self.ConfigurationMetaNames = RemoteTrainingSystemMetaNames()

    def store(self, meta_props: dict) -> dict:
        """Create a remote training system. Either `space_id` or `project_id` has to be provided.

        :param meta_props: metadata, to see available meta names use
            ``client.remote_training_systems.ConfigurationMetaNames.get()``
        :type meta_props: dict

        :return: response json
        :rtype: dict

        :raises WMLClientError: if `meta_props` lacks the mandatory ``name`` or ``allowed_identities``
            entries, or provides ``organization`` without a ``name``

        **Example:**

        .. code-block:: python

            metadata = {
                client.remote_training_systems.ConfigurationMetaNames.NAME: "my-resource",
                client.remote_training_systems.ConfigurationMetaNames.TAGS: ["tag1", "tag2"],
                client.remote_training_systems.ConfigurationMetaNames.ORGANIZATION: {"name": "name", "region": "EU"},
                client.remote_training_systems.ConfigurationMetaNames.ALLOWED_IDENTITIES: [{"id": "43689024", "type": "user"}],
                client.remote_training_systems.ConfigurationMetaNames.REMOTE_ADMIN: {"id": "43689020", "type": "user"}
            }
            client.set.default_space('3fc54cf1-252f-424b-b52d-5cdd9814987f')
            details = client.remote_training_systems.store(meta_props=metadata)
        """
        self._client._check_if_either_is_set()

        RemoteTrainingSystem._validate_type(meta_props, "meta_props", dict, True)
        self._validate_input(meta_props)

        meta = self.ConfigurationMetaNames._generate_resource_metadata(
            meta_props, with_validation=True, client=self._client
        )

        # The default space takes precedence over the default project when both are set.
        if self._client.default_space_id is not None:
            meta["space_id"] = self._client.default_space_id
        elif self._client.default_project_id is not None:
            meta["project_id"] = self._client.default_project_id

        href = (
            self._client.service_instance._href_definitions.remote_training_systems_href()
        )

        creation_response = requests.post(
            href,
            params=self._client._params(),
            headers=self._client._get_headers(),
            json=meta,
        )

        details = self._handle_response(
            expected_status_code=201,
            operationName="store remote training system specification",
            response=creation_response,
        )

        return details

    def _validate_input(self, meta_props: dict) -> None:
        """Check that `meta_props` contains the entries required to store a remote training system.

        :param meta_props: metadata passed to :meth:`store`
        :type meta_props: dict

        :raises WMLClientError: if ``name`` or ``allowed_identities`` is missing, or if
            ``organization`` is present without a ``name`` entry
        """
        if "name" not in meta_props:
            raise WMLClientError(
                "Its mandatory to provide 'NAME' in meta_props. Example: "
                "client.remote_training_systems.ConfigurationMetaNames.NAME"
            )

        if "allowed_identities" not in meta_props:
            raise WMLClientError(
                "Its mandatory to provide 'ALLOWED_IDENTITIES' in meta_props. Example: "
                "client.remote_training_systems.ConfigurationMetaNames.ALLOWED_IDENTITIES"
            )

        if "organization" in meta_props and "name" not in meta_props["organization"]:
            raise WMLClientError(
                "Its mandatory to provide 'name' for ORGANIZATION meta_prop. Eg: "
                "client.remote_training_systems.ConfigurationMetaNames.ORGANIZATION: "
                "{'name': 'org'} "
            )

    def delete(self, remote_training_systems_id: str) -> str:
        """Delete the given `remote_training_systems_id` definition. Either `space_id` or `project_id`
        has to be provided.

        :param remote_training_systems_id: identifier of the remote training system
        :type remote_training_systems_id: str

        :return: status ("SUCCESS" or "FAILED")
        :rtype: str

        **Example:**

        .. code-block:: python

            client.remote_training_systems.delete(remote_training_systems_id='6213cf1-252f-424b-b52d-5cdd9814956c')
        """
        self._client._check_if_either_is_set()

        RemoteTrainingSystem._validate_type(
            remote_training_systems_id, "remote_training_systems_id", str, True
        )

        href = (
            self._client.service_instance._href_definitions.remote_training_system_href(
                remote_training_systems_id
            )
        )

        delete_response = requests.delete(
            href, params=self._client._params(), headers=self._client._get_headers()
        )

        # The service replies 204 with no body, so the response is not parsed as JSON.
        status = self._handle_response(
            expected_status_code=204,
            operationName="delete remote training system definition",
            response=delete_response,
            json_response=False,
        )

        if status == "SUCCESS":
            print("Remote training system deleted")

        return status

    def get_details(
        self,
        remote_training_system_id: str | None = None,
        limit: int | None = None,
        asynchronous: bool = False,
        get_all: bool = False,
    ) -> dict:
        """Get metadata of a given remote training system. If `remote_training_system_id` is not
        provided, metadata is returned for all remote training systems.

        :param remote_training_system_id: identifier of the remote training system
        :type remote_training_system_id: str, optional

        :param limit: limit number of fetched records
        :type limit: int, optional

        :param asynchronous: if `True`, it will work as a generator
        :type asynchronous: bool, optional

        :param get_all: if `True`, it will get all entries in 'limited' chunks
        :type get_all: bool, optional

        :return: remote training system(s) metadata
        :rtype: dict (if remote_training_system_id is not None) or {"resources": [dict]} (if remote_training_system_id is None)

        **Examples**

        .. code-block:: python

            details = client.remote_training_systems.get_details(remote_training_system_id)
            details = client.remote_training_systems.get_details()
            details = client.remote_training_systems.get_details(limit=100)
            details = client.remote_training_systems.get_details(limit=100, get_all=True)
            details = []
            for entry in client.remote_training_systems.get_details(limit=100, asynchronous=True, get_all=True):
                details.extend(entry)
        """
        self._client._check_if_either_is_set()

        # Report the actual parameter name if validation fails.
        RemoteTrainingSystem._validate_type(
            remote_training_system_id, "remote_training_system_id", str, False
        )
        RemoteTrainingSystem._validate_type(limit, "limit", int, False)

        href = (
            self._client.service_instance._href_definitions.remote_training_systems_href()
        )

        if remote_training_system_id is None:
            # Listing all systems: paging / generator options apply only here.
            return self._get_artifact_details(
                href,
                remote_training_system_id,
                limit,
                "remote_training_systems",
                _async=asynchronous,
                _all=get_all,
            )

        return self._get_artifact_details(
            href, remote_training_system_id, limit, "remote_training_systems"
        )

    def list(self, limit: int | None = None) -> DataFrame:
        """Lists stored remote training systems in a table format.
        If limit is set to None, only the first 50 records are shown.

        :param limit: limit number of fetched records
        :type limit: int, optional

        :return: pandas.DataFrame with listed remote training systems
        :rtype: pandas.DataFrame

        **Example:**

        .. code-block:: python

            client.remote_training_systems.list()
        """
        self._client._check_if_either_is_set()

        # NOTE(review): `limit` is applied only when rendering via `_list`; `get_details()` is
        # called without it.
        resources = self.get_details()["resources"]

        values = [
            (m["metadata"]["id"], m["metadata"]["name"], m["metadata"]["created_at"])
            for m in resources
        ]

        return self._list(
            values, ["ID", "NAME", "CREATED"], limit, _DEFAULT_LIST_LENGTH
        )

    @staticmethod
    def get_id(remote_training_system_details: dict) -> str:
        """Get the ID of a remote training system.

        :param remote_training_system_details: metadata of the stored remote training system
        :type remote_training_system_details: dict

        :return: ID of the stored remote training system
        :rtype: str

        **Example:**

        .. code-block:: python

            details = client.remote_training_systems.get_details(remote_training_system_id)
            id = client.remote_training_systems.get_id(details)
        """
        RemoteTrainingSystem._validate_type(
            remote_training_system_details,
            "remote_training_system_details",
            object,
            True,
        )

        return WMLResource._get_required_element_from_dict(
            remote_training_system_details,
            "remote_training_system_details",
            ["metadata", "id"],
        )

    def update(self, remote_training_system_id: str, changes: dict) -> dict:
        """Update the existing metadata of a remote training system.

        :param remote_training_system_id: identifier of the remote training system
        :type remote_training_system_id: str

        :param changes: elements to be changed, where keys are ConfigurationMetaNames
        :type changes: dict

        :return: updated remote training system details
        :rtype: dict

        **Example:**

        .. code-block:: python

            metadata = {
                client.remote_training_systems.ConfigurationMetaNames.NAME:"updated_remote_training_system"
            }

            details = client.remote_training_systems.update(remote_training_system_id, changes=metadata)
        """
        self._client._check_if_either_is_set()

        self._validate_type(
            remote_training_system_id, "remote_training_system_id", str, True
        )
        self._validate_type(changes, "changes", dict, True)

        # The patch payload is computed against the current entity fetched from the service.
        details = self.get_details(remote_training_system_id)

        patch_payload = self.ConfigurationMetaNames._generate_patch_payload(
            details["entity"], changes, with_validation=True
        )

        href = (
            self._client.service_instance._href_definitions.remote_training_system_href(
                remote_training_system_id
            )
        )

        response = requests.patch(
            href,
            json=patch_payload,
            params=self._client._params(),
            headers=self._client._get_headers(),
        )

        updated_details = self._handle_response(
            200, "remote training system patch", response
        )

        return updated_details

    def create_revision(self, remote_training_system_id: str) -> dict:
        """Create a new remote training system revision.

        :param remote_training_system_id: unique ID of the remote training system
        :type remote_training_system_id: str

        :return: details of the remote training system
        :rtype: dict

        **Example:**

        .. code-block:: python

            client.remote_training_systems.create_revision(remote_training_system_id)
        """
        # The ID is required (see :type: above); validate it as mandatory so a missing ID
        # fails validation here.
        RemoteTrainingSystem._validate_type(
            remote_training_system_id, "remote_training_system_id", str, True
        )

        href = (
            self._client.service_instance._href_definitions.remote_training_systems_href()
        )

        return self._create_revision_artifact(
            href, remote_training_system_id, "remote training system"
        )

    def get_revision_details(self, remote_training_system_id: str, rev_id: str) -> dict:
        """Get metadata from a specific revision of a stored remote training system.

        :param remote_training_system_id: ID of the remote training system
        :type remote_training_system_id: str

        :param rev_id: unique ID of the remote training system revision
        :type rev_id: str

        :return: metadata of the stored remote training system revision
        :rtype: dict

        **Example:**

        .. code-block:: python

            details = client.remote_training_systems.get_revision_details(remote_training_system_id, rev_id)
        """
        self._client._check_if_either_is_set()

        RemoteTrainingSystem._validate_type(
            remote_training_system_id, "remote_training_system_id", str, True
        )
        RemoteTrainingSystem._validate_type(rev_id, "rev_id", str, True)

        href = (
            self._client.service_instance._href_definitions.remote_training_system_href(
                remote_training_system_id
            )
        )

        return self._get_with_or_without_limit(
            href,
            limit=None,
            op_name="remote_training_system_id",
            summary=None,
            pre_defined=None,
            revision=rev_id,
        )

    def list_revisions(
        self, remote_training_system_id: str, limit: int | None = None
    ) -> DataFrame:
        """List all revisions for a given remote_training_system_id in a table format.

        :param remote_training_system_id: unique ID of the stored remote training system
        :type remote_training_system_id: str

        :param limit: limit number of fetched records
        :type limit: int, optional

        :return: pandas.DataFrame with listed remote training system revisions
        :rtype: pandas.DataFrame

        **Example:**

        .. code-block:: python

            client.remote_training_systems.list_revisions(remote_training_system_id)
        """
        self._client._check_if_either_is_set()

        RemoteTrainingSystem._validate_type(
            remote_training_system_id, "remote_training_system_id", str, True
        )

        # Use the remote-training-system href (as delete/update/get_revision_details do),
        # not the function href, so that this system's revisions are listed.
        href = (
            self._client.service_instance._href_definitions.remote_training_system_href(
                remote_training_system_id
            )
        )

        resources = self._get_artifact_details(
            href + "/revisions", None, limit, "remote system revisions"
        )["resources"]

        values = [
            (
                m["metadata"]["id"],
                m["metadata"]["rev"],
                m["metadata"]["name"],
                m["metadata"]["created_at"],
            )
            for m in resources
        ]

        return self._list(
            values, ["ID", "REV", "NAME", "CREATED"], limit, _DEFAULT_LIST_LENGTH
        )

    def _validate_party_input(self, party_metadata: dict) -> None:
        """Check that `party_metadata` contains the mandatory ``data_handler`` entry.

        :param party_metadata: party configuration passed to :meth:`create_party`
        :type party_metadata: dict

        :raises WMLClientError: if ``data_handler`` is missing
        """
        if "data_handler" not in party_metadata:
            raise WMLClientError(
                "Its mandatory to provide 'DATA_HANDLER' in meta_props. Example: "
                "client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER"
            )

    def create_party(
        self, remote_training_system_id: str, party_metadata: dict
    ) -> Party:
        """Create a party object using the specified remote training system ID and the party metadata.

        :param remote_training_system_id: identifier of the remote training system
        :type remote_training_system_id: str

        :param party_metadata: the party configuration
        :type party_metadata: dict

        :return: a party object with the specified rts_id and configuration
        :rtype: Party

        :raises WMLClientError: if `party_metadata` has no ``data_handler`` entry

        **Examples**

        .. code-block:: python

            party_metadata = {
                client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
                    "info": {
                        "npz_file": "./data_party0.npz"
                    },
                    "name": "MnistTFDataHandler",
                    "path": "./mnist_keras_data_handler.py"
                },
                client.remote_training_systems.ConfigurationMetaNames.LOCAL_TRAINING: {
                    "name": "LocalTrainingHandler",
                    "path": "ibmfl.party.training.local_training_handler"
                },
                client.remote_training_systems.ConfigurationMetaNames.HYPERPARAMS: {
                    "epochs": 3
                },
            }
            party = client.remote_training_systems.create_party(remote_training_system_id, party_metadata)

        .. code-block:: python

            party_metadata = {
                client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
                    "info": {
                        "npz_file": "./data_party0.npz"
                    },
                    "class": MnistTFDataHandler
                }
            }
            party = client.remote_training_systems.create_party(remote_training_system_id, party_metadata)
        """
        RemoteTrainingSystem._validate_type(
            remote_training_system_id, "remote_training_system_id", str, True
        )
        RemoteTrainingSystem._validate_type(
            party_metadata, "party_metadata", dict, True
        )
        self._validate_party_input(party_metadata)

        # Aggregator address = everything after the scheme separator of the client URL.
        # NOTE(review): any path or trailing slash in the URL ends up here too — confirm intended.
        host = self._client.credentials.url.split("//")[1]

        meta_names = self._client.remote_training_systems.ConfigurationMetaNames

        party_config = {
            "aggregator": {"ip": host},
            "connection": {
                "info": {
                    "id": remote_training_system_id,
                }
            },
            "data": party_metadata[meta_names.DATA_HANDLER],
            "protocol_handler": {
                "name": "PartyProtocolHandler",
                "path": "ibmfl.party.party_protocol_handler",
            },
        }

        # Optional sections are copied only when supplied by the caller.
        if "local_training" in party_metadata:
            party_config["local_training"] = party_metadata[meta_names.LOCAL_TRAINING]

        if "hyperparams" in party_metadata:
            party_config["hyperparams"] = party_metadata[meta_names.HYPERPARAMS]

        if "model" in party_metadata:
            party_config["model"] = party_metadata[meta_names.MODEL]

        return Party(client=self._client, config_dict=party_config)