# -----------------------------------------------------------------------------------------
# (C) Copyright IBM Corp. 2020-2024.
# https://opensource.org/licenses/BSD-3-Clause
# -----------------------------------------------------------------------------------------
from __future__ import print_function
import ibm_watson_machine_learning._wrappers.requests as requests
from ibm_watson_machine_learning.metanames import RemoteTrainingSystemMetaNames
from ibm_watson_machine_learning.party_wrapper import Party
from ibm_watson_machine_learning.wml_client_error import WMLClientError, UnexpectedType
from ibm_watson_machine_learning.wml_resource import WMLResource
# Default row count passed to ``_list`` by ``list()`` and ``list_revisions()``;
# per the ``list()`` docstring, only the first 50 records are shown when no limit is given.
_DEFAULT_LIST_LENGTH = 50
class RemoteTrainingSystem(WMLResource):
    """The RemoteTrainingSystem class represents a Federated Learning party and provides a list of identities
    that are permitted to join training as the RemoteTrainingSystem.
    """

    def __init__(self, client):
        WMLResource.__init__(self, __name__, client)
        self._client = client
        self.ConfigurationMetaNames = RemoteTrainingSystemMetaNames()

    def _require_spaces_platform(self):
        """Raise ``WMLClientError`` unless the client runs on a Cloud or ICP platform with spaces."""
        if not (self._client.CLOUD_PLATFORM_SPACES or self._client.ICP_PLATFORM_SPACES):
            raise WMLClientError("Not supported in this release")

    def store(self, meta_props):
        """Create a remote training system. Either `space_id` or `project_id` has to be provided.

        :param meta_props: metadata, to see available meta names use
            ``client.remote_training_systems.ConfigurationMetaNames.get()``
        :type meta_props: dict

        :return: response json
        :rtype: dict

        **Example**

        .. code-block:: python

            metadata = {
                client.remote_training_systems.ConfigurationMetaNames.NAME: "my-resource",
                client.remote_training_systems.ConfigurationMetaNames.TAGS: ["tag1", "tag2"],
                client.remote_training_systems.ConfigurationMetaNames.ORGANIZATION: {"name": "name", "region": "EU"},
                client.remote_training_systems.ConfigurationMetaNames.ALLOWED_IDENTITIES: [{"id": "43689024", "type": "user"}],
                client.remote_training_systems.ConfigurationMetaNames.REMOTE_ADMIN: {"id": "43689020", "type": "user"}
            }
            client.set.default_space('3fc54cf1-252f-424b-b52d-5cdd9814987f')
            details = client.remote_training_systems.store(meta_props=metadata)

        """
        WMLResource._chk_and_block_create_update_for_python36(self)
        self._require_spaces_platform()
        self._client._check_if_either_is_set()

        RemoteTrainingSystem._validate_type(meta_props, u'meta_props', dict, True)
        self._validate_input(meta_props)

        meta = self.ConfigurationMetaNames._generate_resource_metadata(
            meta_props,
            with_validation=True,
            client=self._client)

        # The default space takes precedence over the default project when both are set.
        if self._client.default_space_id is not None:
            meta['space_id'] = self._client.default_space_id
        elif self._client.default_project_id is not None:
            meta['project_id'] = self._client.default_project_id

        href = self._client.service_instance._href_definitions.remote_training_systems_href()
        creation_response = requests.post(href,
                                          params=self._client._params(),
                                          headers=self._client._get_headers(),
                                          json=meta)

        details = self._handle_response(expected_status_code=201,
                                        operationName=u'store remote training system specification',
                                        response=creation_response)
        return details

    def _validate_input(self, meta_props):
        """Check that the meta props needed to create a remote training system are present.

        :raises WMLClientError: if 'name' or 'allowed_identities' is missing, or if
            'organization' is given without a 'name' entry
        """
        if 'name' not in meta_props:
            raise WMLClientError("Its mandatory to provide 'NAME' in meta_props. Example: "
                                 "client.remote_training_systems.ConfigurationMetaNames.NAME")

        if 'allowed_identities' not in meta_props:
            raise WMLClientError("Its mandatory to provide 'ALLOWED_IDENTITIES' in meta_props. Example: "
                                 "client.remote_training_systems.ConfigurationMetaNames.ALLOWED_IDENTITIES")

        if 'organization' in meta_props and 'name' not in meta_props[u'organization']:
            raise WMLClientError("Its mandatory to provide 'name' for ORGANIZATION meta_prop. Eg: "
                                 "client.remote_training_systems.ConfigurationMetaNames.ORGANIZATION: "
                                 "{'name': 'org'} ")

    def delete(self, remote_training_systems_id):
        """Deletes the given `remote_training_systems_id` definition. `space_id` or `project_id` has to be provided.

        :param remote_training_systems_id: remote training system identifier
        :type remote_training_systems_id: str

        :return: status produced by the response handler ("SUCCESS" when the deletion succeeded)
        :rtype: str

        **Example**

        .. code-block:: python

            client.remote_training_systems.delete(remote_training_systems_id='6213cf1-252f-424b-b52d-5cdd9814956c')

        """
        self._require_spaces_platform()
        self._client._check_if_either_is_set()

        RemoteTrainingSystem._validate_type(remote_training_systems_id, u'remote_training_systems_id', str, True)

        href = self._client.service_instance._href_definitions.remote_training_system_href(remote_training_systems_id)
        delete_response = requests.delete(href,
                                          params=self._client._params(),
                                          headers=self._client._get_headers())

        details = self._handle_response(expected_status_code=204,
                                        operationName=u'delete remote training system definition',
                                        response=delete_response)

        if "SUCCESS" == details:
            print("Remote training system deleted")

        # The status is documented as the return value; hand it back to the caller.
        return details

    def get_details(self, remote_training_system_id=None, limit=None, asynchronous=False, get_all=False):
        """Get metadata of the given remote training system. If `remote_training_system_id` is not specified,
        metadata is returned for all remote training systems.

        :param remote_training_system_id: remote training system identifier
        :type remote_training_system_id: str, optional
        :param limit: limit number of fetched records
        :type limit: int, optional
        :param asynchronous: if `True`, it will work as a generator
        :type asynchronous: bool, optional
        :param get_all: if `True`, it will get all entries in 'limited' chunks
        :type get_all: bool, optional

        :return: remote training system(s) metadata
        :rtype: dict (if remote_training_system_id is not None) or {"resources": [dict]} (if remote_training_system_id is None)

        **Examples**

        .. code-block:: python

            details = client.remote_training_systems.get_details(remote_training_system_id)
            details = client.remote_training_systems.get_details()
            details = client.remote_training_systems.get_details(limit=100)
            details = client.remote_training_systems.get_details(limit=100, get_all=True)
            details = []
            for entry in client.remote_training_systems.get_details(limit=100, asynchronous=True, get_all=True):
                details.extend(entry)

        """
        self._require_spaces_platform()
        self._client._check_if_either_is_set()

        RemoteTrainingSystem._validate_type(remote_training_system_id, u'remote_training_system_id', str, False)
        RemoteTrainingSystem._validate_type(limit, u'limit', int, False)

        href = self._client.service_instance._href_definitions.remote_training_systems_href()

        # The generator / fetch-all options are only forwarded for the collection request.
        if remote_training_system_id is None:
            return self._get_artifact_details(href, remote_training_system_id, limit, 'remote_training_systems',
                                              _async=asynchronous, _all=get_all)
        else:
            return self._get_artifact_details(href, remote_training_system_id, limit, 'remote_training_systems')

    def list(self, limit=None):
        """Print stored remote training systems in a table format.
        If limit is set to None, only the first 50 records are shown.

        :param limit: limit number of fetched records
        :type limit: int

        **Example**

        .. code-block:: python

            client.remote_training_systems.list()

        """
        self._require_spaces_platform()
        self._client._check_if_either_is_set()

        resources = self.get_details()[u'resources']
        values = [(m[u'metadata'][u'id'],
                   m[u'metadata'][u'name'],
                   m[u'metadata'][u'created_at']) for m in resources]

        self._list(values, [u'ID', u'NAME', u'CREATED'], limit, _DEFAULT_LIST_LENGTH)

    @staticmethod
    def get_id(remote_training_system_details):
        """Get ID of remote training system.

        :param remote_training_system_details: metadata of the stored remote training system
        :type remote_training_system_details: dict

        :return: ID of stored remote training system
        :rtype: str

        **Example**

        .. code-block:: python

            details = client.remote_training_systems.get_details(remote_training_system_id)
            id = client.remote_training_systems.get_id(details)

        """
        RemoteTrainingSystem._validate_type(remote_training_system_details, u'remote_training_system_details', object, True)

        return WMLResource._get_required_element_from_dict(remote_training_system_details,
                                                           u'remote_training_system_details',
                                                           [u'metadata', u'id'])

    def update(self, remote_training_system_id, changes):
        """Updates existing remote training system metadata.

        :param remote_training_system_id: remote training system identifier
        :type remote_training_system_id: str
        :param changes: elements which should be changed, where keys are ConfigurationMetaNames
        :type changes: dict

        :return: updated remote training system details
        :rtype: dict

        **Example**

        .. code-block:: python

            metadata = {
                client.remote_training_systems.ConfigurationMetaNames.NAME:"updated_remote_training_system"
            }
            details = client.remote_training_systems.update(remote_training_system_id, changes=metadata)

        """
        WMLResource._chk_and_block_create_update_for_python36(self)
        self._require_spaces_platform()
        self._client._check_if_either_is_set()

        self._validate_type(remote_training_system_id, u'remote_training_system_id', str, True)
        self._validate_type(changes, u'changes', dict, True)

        # The patch payload is generated from the currently stored entity and the requested changes.
        details = self.get_details(remote_training_system_id)
        patch_payload = self.ConfigurationMetaNames._generate_patch_payload(details['entity'], changes,
                                                                            with_validation=True)

        href = self._client.service_instance._href_definitions.remote_training_system_href(remote_training_system_id)
        response = requests.patch(href,
                                  json=patch_payload,
                                  params=self._client._params(),
                                  headers=self._client._get_headers())

        updated_details = self._handle_response(200, u'remote training system patch', response)
        return updated_details

    def create_revision(self, remote_training_system_id):
        """Create a new remote training system revision.

        :param remote_training_system_id: Unique remote training system ID
        :type remote_training_system_id: str

        :return: remote training system details
        :rtype: dict

        **Example**

        .. code-block:: python

            client.remote_training_systems.create_revision(remote_training_system_id)

        """
        self._require_spaces_platform()

        # The identifier is required here, so validate it as mandatory like the other methods do.
        RemoteTrainingSystem._validate_type(remote_training_system_id, u'remote_training_system_id', str, True)

        href = self._client.service_instance._href_definitions.remote_training_systems_href()
        return self._create_revision_artifact(href, remote_training_system_id, 'remote training system')

    def get_revision_details(self, remote_training_system_id, rev_id):
        """Get metadata from the specific revision of a stored remote system.

        :param remote_training_system_id: UID of remote training system
        :type remote_training_system_id: str
        :param rev_id: Unique id of the remote system revision
        :type rev_id: str

        :return: stored remote system revision metadata
        :rtype: dict

        **Example**

        .. code-block:: python

            details = client.remote_training_systems.get_revision_details(remote_training_system_id, rev_id)

        """
        self._require_spaces_platform()
        self._client._check_if_either_is_set()

        RemoteTrainingSystem._validate_type(remote_training_system_id, u'remote_training_system_id', str, True)

        # With a list of accepted types the validator's result is checked here instead of
        # relying on it to raise; an int is still accepted but triggers a warning.
        if not RemoteTrainingSystem._validate_type(rev_id, u'rev_id', [str, int], True):
            raise UnexpectedType('rev_id', [str, int], type(rev_id))
        elif isinstance(rev_id, int):
            from warnings import warn
            warn("Parameter rev_id as int is deprecated, should be of type str.")

        href = self._client.service_instance._href_definitions.remote_training_system_href(remote_training_system_id)
        return self._get_with_or_without_limit(href, limit=None, op_name="remote_training_system_id",
                                               summary=None, pre_defined=None, revision=rev_id)

    def list_revisions(self, remote_training_system_id, limit=None):
        """Print all revisions for the given remote_training_system_id in a table format.

        :param remote_training_system_id: Unique id of stored remote system
        :type remote_training_system_id: str
        :param limit: limit number of fetched records
        :type limit: int, optional

        **Example**

        .. code-block:: python

            client.remote_training_systems.list_revisions(remote_training_system_id)

        """
        self._require_spaces_platform()
        self._client._check_if_either_is_set()

        RemoteTrainingSystem._validate_type(remote_training_system_id, u'remote_training_system_id', str, True)

        # Use the remote training system endpoint (not the functions endpoint) as the revisions base URL.
        href = self._client.service_instance._href_definitions.remote_training_system_href(remote_training_system_id)
        resources = self._get_artifact_details(href + '/revisions',
                                               None,
                                               limit,
                                               'remote system revisions')[u'resources']

        values = [(m[u'metadata'][u'id'],
                   m[u'metadata'][u'rev'],
                   m[u'metadata'][u'name'],
                   m[u'metadata'][u'created_at']) for m in resources]

        self._list(values, [u'ID', u'rev', u'NAME', u'CREATED'], limit, _DEFAULT_LIST_LENGTH)

    def _validate_party_input(self, party_metadata):
        """Check that the party metadata contains the mandatory data handler entry.

        :raises WMLClientError: if 'data_handler' is missing from `party_metadata`
        """
        if "data_handler" not in party_metadata:
            raise WMLClientError("Its mandatory to provide 'DATA_HANDLER' in meta_props. Example: "
                                 "client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER")

    def create_party(self, remote_training_system_id, party_metadata):
        """Create a party object using the specified remote training system id and the party metadata.

        :param remote_training_system_id: remote training system identifier
        :type remote_training_system_id: str
        :param party_metadata: the party configuration
        :type party_metadata: dict

        :return: a party object with the specified rts_id and configuration
        :rtype: Party

        **Examples**

        .. code-block:: python

            party_metadata = {
                client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
                    "info": {
                        "npz_file": "./data_party0.npz"
                    },
                    "name": "MnistTFDataHandler",
                    "path": "./mnist_keras_data_handler.py"
                },
                client.remote_training_systems.ConfigurationMetaNames.LOCAL_TRAINING: {
                    "name": "LocalTrainingHandler",
                    "path": "ibmfl.party.training.local_training_handler"
                },
                client.remote_training_systems.ConfigurationMetaNames.HYPERPARAMS: {
                    "epochs": 3
                },
            }
            party = client.remote_training_systems.create_party(remote_training_system_id, party_metadata)

        .. code-block:: python

            party_metadata = {
                client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
                    "info": {
                        "npz_file": "./data_party0.npz"
                    },
                    "class": MnistTFDataHandler
                }
            }
            party = client.remote_training_systems.create_party(remote_training_system_id, party_metadata)

        """
        RemoteTrainingSystem._validate_type(remote_training_system_id, u'remote_training_system_id', str, True)
        RemoteTrainingSystem._validate_type(party_metadata, u'party_metadata', dict, True)
        self._validate_party_input(party_metadata)

        meta_names = self._client.remote_training_systems.ConfigurationMetaNames

        # Everything after the '//' of the configured service URL is used as the aggregator address.
        host = self._client.wml_credentials['url'].split('//')[1]

        party_config = {
            "aggregator": {
                "ip": host
            },
            "connection": {
                "info": {
                    "id": remote_training_system_id,
                }
            },
            "data": party_metadata[meta_names.DATA_HANDLER],
            "protocol_handler": {
                "name": "PartyProtocolHandler",
                "path": "ibmfl.party.party_protocol_handler"
            }
        }

        # Optional sections are copied over only when the caller supplied them.
        if "local_training" in party_metadata:
            party_config["local_training"] = party_metadata[meta_names.LOCAL_TRAINING]

        if "hyperparams" in party_metadata:
            party_config["hyperparams"] = party_metadata[meta_names.HYPERPARAMS]

        if "model" in party_metadata:
            party_config["model"] = party_metadata[meta_names.MODEL]

        return Party(client=self._client, config_dict=party_config)