Federated Learning
Federated Learning provides the tools for training a model collaboratively by coordinating local training runs and fusing their results. Although the data sources are never moved, combined, or shared among the parties or with the aggregator, every party contributes to training and to improving the quality of the global model.
Tutorials and samples for IBM Cloud
Tutorials and samples for IBM Cloud Pak for Data and IBM Watson Machine Learning Server
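To make the idea of fusing local results concrete, here is a toy simulation in plain Python. It is not how Watson Machine Learning runs training: the one-parameter linear model, the gradient-descent loop, and the helper names `local_train`, `fuse`, and `federated_training` are illustrative assumptions. What it does show is the division of labor: each party trains on data that only it holds, and the aggregator sees nothing but the resulting weights.

```python
# Toy simulation (hypothetical, for illustration only): two parties fit
# y = w * x on private data. Only the weight w travels between the parties
# and the aggregator; the (x, y) samples never leave their owner.


def local_train(w, data, lr=0.01, epochs=3):
    """Party side: a few epochs of gradient descent on private data."""
    for _ in range(epochs):
        # Gradient of the mean squared error with respect to w.
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w


def fuse(weights):
    """Aggregator side: average the weights returned by the parties."""
    return sum(weights) / len(weights)


def federated_training(party_data, rounds=20, w=0.0):
    """Alternate local training and fusion for a fixed number of rounds."""
    for _ in range(rounds):
        local_weights = [local_train(w, data) for data in party_data]
        w = fuse(local_weights)
    return w


# Each party holds its own samples of y = 3x.
party_a = [(1.0, 3.0), (2.0, 6.0)]
party_b = [(3.0, 9.0), (4.0, 12.0)]
print(federated_training([party_a, party_b]))  # approaches 3.0
```

Simple averaging is used here only because it is the easiest fusion rule to follow by hand.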
Aggregation
The aggregator process, which fuses the parties’ training results, runs as a Watson Machine Learning training job.
For more information on creating and querying a training job, see the API documentation for the client.training class.
The parameters available to configure a Federated Learning training are described in the IBM Cloud API Docs.
Configure and start aggregation
```python
from ibm_watson_machine_learning import APIClient

# `credentials` holds your Watson Machine Learning credentials (defined elsewhere).
client = APIClient(credentials)

PROJECT_ID = "8ae1a720-83ed-4c57-b719-8dd086bd7ce0"
client.set.default_project(PROJECT_ID)

# untrained_model_id, untrained_model_name, rts_1_id and rts_2_id must be defined
# beforehand: the ID and model file name of the stored untrained model, and the
# IDs of the remote training systems (parties) that take part in training.
aggregator_metadata = {
    client.training.ConfigurationMetaNames.NAME: 'Federated Tensorflow MNIST',
    client.training.ConfigurationMetaNames.DESCRIPTION: 'MNIST digit recognition with Federated Learning using Tensorflow',
    client.training.ConfigurationMetaNames.TRAINING_DATA_REFERENCES: [],
    client.training.ConfigurationMetaNames.TRAINING_RESULTS_REFERENCE: {
        'type': 'container',
        'name': 'outputData',
        'connection': {},
        'location': {
            'path': '/projects/' + PROJECT_ID + '/assets/trainings/'
        }
    },
    client.training.ConfigurationMetaNames.FEDERATED_LEARNING: {
        'model': {
            'type': 'tensorflow',
            'spec': {
                'id': untrained_model_id
            },
            'model_file': untrained_model_name
        },
        'fusion_type': 'iter_avg',
        'metrics': 'accuracy',
        'epochs': 3,
        'rounds': 99,
        'remote_training': {
            'quorum': 1.0,
            'max_timeout': 3600,
            'remote_training_systems': [{'id': rts_1_id}, {'id': rts_2_id}]
        },
        'hardware_spec': {
            'name': 'S'
        },
        'software_spec': {
            'name': 'runtime-22.1-py3.9'
        }
    }
}

# Start the aggregator as a training job without blocking, then keep its ID:
# parties pass this ID to party.run() to connect to the aggregator.
aggregator = client.training.run(aggregator_metadata, asynchronous=True)
aggregator_id = client.training.get_id(aggregator)
```
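The 'fusion_type': 'iter_avg' and 'quorum': 1.0 entries in the aggregator configuration control how the aggregator combines results and how many parties it waits for. The sketch below shows one plausible reading of those two settings, assuming that iterative averaging is an unweighted, element-wise mean of the parties' weights and that quorum is the fraction of the listed remote training systems that must reply before a round is fused. The functions `iter_avg` and `fuse_round` are hypothetical; the IBM Cloud API docs describe the service's actual semantics.

```python
import math

# Hypothetical reading of two aggregator settings, for illustration only:
#   iter_avg -> unweighted, element-wise mean of the parties' weight vectors
#   quorum   -> fraction of listed remote training systems that must reply


def iter_avg(updates):
    """Element-wise average of equally sized weight vectors."""
    n = len(updates)
    return [sum(values) / n for values in zip(*updates)]


def fuse_round(updates_by_party, registered_parties, quorum=1.0):
    """Fuse one round, or raise if fewer parties replied than the quorum requires."""
    # round() guards against floating-point noise in quorum * n before ceil().
    required = math.ceil(round(quorum * len(registered_parties), 9))
    replied = [p for p in registered_parties if p in updates_by_party]
    if len(replied) < required:
        raise RuntimeError(
            f"quorum not met: {len(replied)} of {len(registered_parties)} "
            f"parties replied, {required} required"
        )
    return iter_avg([updates_by_party[p] for p in replied])


print(fuse_round({"rts_1": [0.0, 2.0], "rts_2": [2.0, 4.0]}, ["rts_1", "rts_2"]))
# prints [1.0, 3.0]
```

With quorum set to 1.0, as in the configuration above, a round in this sketch cannot be fused until every listed party has replied.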
Local training
Training is performed locally by parties that connect to the aggregator. The parties must be members of the project or space in which the aggregator is running, and they are identified to the Federated Learning aggregator as remote training systems.
Configure and start local training
```python
from ibm_watson_machine_learning import APIClient

# `party_1_credentials` holds this party's credentials (defined elsewhere).
client = APIClient(party_1_credentials)

PROJECT_ID = "8ae1a720-83ed-4c57-b719-8dd086bd7ce0"
client.set.default_project(PROJECT_ID)

# At a minimum, the party must specify how its training data are loaded: it supplies
# a data handler class and any input to that class. Here, the "info" block contains
# a key that locates the training data relative to the current working directory.
# MNISTDataHandler is the party's own data handler class and must be defined or
# imported before this point.
party_metadata = {
    client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
        "class": MNISTDataHandler,
        "info": {
            "npz_file": "./training_data.npz"
        }
    }
}

# Create the party object for the given remote training system.
party = client.remote_training_systems.create_party(
    remote_training_system_id="d516d42c-6c59-41f2-b7ca-c63d11ea79a1",
    party_metadata=party_metadata,
)

# Send training logs to standard output.
party.monitor_logs()

# Start training. Training runs in the Python process that executes this code.
# The aggregator_id refers to the Watson Machine Learning training job that performs aggregation.
party.run(aggregator_id="564fb126-9bfd-409b-beb3-5d401e4c50ec", asynchronous=False)
```
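A data handler such as MNISTDataHandler is ordinary Python code that the party provides. The sketch below is a hypothetical handler modeled loosely on the data handler pattern used in IBM Federated Learning samples: the constructor receives the "info" block from the party metadata, and `get_data()` returns training and test splits. The missing base class, the class name `NpzDataHandler`, and the array names `x_train`, `y_train`, `x_test`, and `y_test` inside the .npz file are assumptions made so that the example runs on its own.

```python
import numpy as np


class NpzDataHandler:
    """Hypothetical data handler that loads a party's data from a local .npz file.

    The constructor receives the "info" block from the party metadata, so
    {"npz_file": "./training_data.npz"} tells it where the data live.
    """

    def __init__(self, data_config=None):
        if not data_config or "npz_file" not in data_config:
            raise ValueError('data_config must contain an "npz_file" entry')
        self.file_name = data_config["npz_file"]

    def get_data(self):
        """Return ((x_train, y_train), (x_test, y_test)) read from the .npz file."""
        # Indexing an NpzFile reads the array into memory, so the returned
        # arrays stay valid after the file is closed.
        with np.load(self.file_name) as data:
            return (data["x_train"], data["y_train"]), (data["x_test"], data["y_test"])
```

The party's data never leave its machine: only the handler's path and key are written into the party metadata.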
- class remote_training_system.RemoteTrainingSystem(client)
The RemoteTrainingSystem class represents a Federated Learning party and provides a list of identities that are permitted to join training as the RemoteTrainingSystem.
- create_party(remote_training_system_id, party_metadata)
Create a party object using the specified remote training system id and the party metadata.
- Parameters:
remote_training_system_id (str) – remote training system identifier
party_metadata (dict) – the party configuration
- Returns:
a party object with the specified remote_training_system_id and configuration
- Return type:
Party
Examples
```python
# Data handler given by name and file path, plus a local training handler and hyperparameters
party_metadata = {
    client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
        "info": {"npz_file": "./data_party0.npz"},
        "name": "MnistTFDataHandler",
        "path": "./mnist_keras_data_handler.py"
    },
    client.remote_training_systems.ConfigurationMetaNames.LOCAL_TRAINING: {
        "name": "LocalTrainingHandler",
        "path": "ibmfl.party.training.local_training_handler"
    },
    client.remote_training_systems.ConfigurationMetaNames.HYPERPARAMS: {
        "epochs": 3
    },
}
party = client.remote_training_systems.create_party(remote_training_system_id, party_metadata)

# Data handler given directly as a class object
party_metadata = {
    client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
        "info": {"npz_file": "./data_party0.npz"},
        "class": MnistTFDataHandler
    }
}
party = client.remote_training_systems.create_party(remote_training_system_id, party_metadata)
```
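The two examples identify a handler in different ways: the first gives a "name" and a "path" (a .py file for the data handler, a dotted module path for the local training handler), while the second passes the class object under "class". The helper below is a hypothetical illustration of how either form can be resolved to a class with the standard `importlib` module. It is not the client's internal loader, and `resolve_handler` is an invented name.

```python
import importlib
import importlib.util


def resolve_handler(spec):
    """Return the class described by a handler spec dictionary (hypothetical helper).

    Accepts {"class": cls} or {"name": ..., "path": ...}, where "path" is either
    a .py file or a dotted module path.
    """
    if "class" in spec:
        return spec["class"]
    name, path = spec["name"], spec["path"]
    if path.endswith(".py"):
        # Load a module directly from a source file.
        module_spec = importlib.util.spec_from_file_location(name + "_module", path)
        module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)
    else:
        # Import an installed module by its dotted name.
        module = importlib.import_module(path)
    return getattr(module, name)
```

Passing the class object directly avoids any lookup, which is why the second example is shorter.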
- create_revision(remote_training_system_id)
Create a new remote training system revision.
- Parameters:
remote_training_system_id (str) – Unique remote training system ID
- Returns:
remote training system details
- Return type:
dict
Example
client.remote_training_systems.create_revision(remote_training_system_id)
- delete(remote_training_systems_id)
Delete the given remote training system definition. Either space_id or project_id has to be provided.
- Parameters:
remote_training_systems_id (str) – remote training system identifier
- Returns:
status (“SUCCESS” or “FAILED”)
- Return type:
str
Example
client.remote_training_systems.delete(remote_training_systems_id='6213cf1-252f-424b-b52d-5cdd9814956c')
- get_details(remote_training_system_id=None, limit=None, asynchronous=False, get_all=False)
Get metadata of the given remote training system. If remote_training_system_id is not specified, metadata is returned for all remote training systems.
- Parameters:
remote_training_system_id (str, optional) – remote training system identifier
limit (int, optional) – limit number of fetched records
asynchronous (bool, optional) – if True, it will work as a generator
get_all (bool, optional) – if True, it will get all entries in ‘limited’ chunks
- Returns:
remote training system(s) metadata
- Return type:
dict (if remote_training_system_id is not None) or {“resources”: [dict]} (if remote_training_system_id is None)
Examples
```python
details = client.remote_training_systems.get_details(remote_training_system_id)
details = client.remote_training_systems.get_details()
details = client.remote_training_systems.get_details(limit=100)
details = client.remote_training_systems.get_details(limit=100, get_all=True)
details = []
for entry in client.remote_training_systems.get_details(limit=100, asynchronous=True, get_all=True):
    details.extend(entry)
```
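With asynchronous=True and get_all=True, get_details() works as a generator that yields entries in chunks of at most limit records, which the last example flattens with details.extend(entry). The sketch below reproduces that consumption pattern against a stand-in paging function; `fetch_in_chunks` and `fetch_page` are hypothetical, and no request is made to Watson Machine Learning.

```python
def fetch_in_chunks(fetch_page, limit=100):
    """Yield lists of up to `limit` records until the source is exhausted.

    `fetch_page(offset, limit)` is a hypothetical stand-in for one paged
    request; it returns a (possibly empty) list of records.
    """
    offset = 0
    while True:
        page = fetch_page(offset, limit)
        if not page:
            return
        yield page
        if len(page) < limit:
            return  # a short page means there is nothing left to fetch
        offset += len(page)


# Flatten the chunks exactly as in the get_details() example.
records = [{"metadata": {"id": str(i)}} for i in range(250)]
details = []
for entry in fetch_in_chunks(lambda offset, limit: records[offset:offset + limit]):
    details.extend(entry)
```

A generator keeps memory use bounded by one chunk at a time, which matters when a space holds many remote training systems.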
- static get_id(remote_training_system_details)
Get ID of remote training system.
- Parameters:
remote_training_system_details (dict) – metadata of the stored remote training system
- Returns:
ID of stored remote training system
- Return type:
str
Example
```python
details = client.remote_training_systems.get_details(remote_training_system_id)
rts_id = client.remote_training_systems.get_id(details)
```
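When get_details() is called without an ID, the result has the form {"resources": [dict]}, so collecting every ID means applying get_id() to each entry. In this sketch, `ids_from_listing` is a hypothetical helper; with a real client you would pass get_id=client.remote_training_systems.get_id. The default extractor assumes each resource keeps its ID at ["metadata"]["id"], which is an assumption about the response layout rather than something this page documents.

```python
def _assumed_get_id(details):
    """Assumed layout (not documented here): the ID sits at details["metadata"]["id"]."""
    return details["metadata"]["id"]


def ids_from_listing(listing, get_id=_assumed_get_id):
    """Return the ID of every resource in a {"resources": [...]} listing."""
    return [get_id(resource) for resource in listing["resources"]]
```

With a client available, ids_from_listing(client.remote_training_systems.get_details(), get_id=client.remote_training_systems.get_id) would delegate the lookup to the client itself.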
- get_revision_details(remote_training_system_id, rev_id)
Get metadata from the specific revision of a stored remote system.
- Parameters:
remote_training_system_id (str) – UID of remote training system
rev_id (str) – Unique id of the remote system revision
- Returns:
stored remote system revision metadata
- Return type:
dict
Example
details = client.remote_training_systems.get_revision_details(remote_training_system_id, rev_id)
- list(limit=None)
Print stored remote training systems in a table format. If limit is set to None, only the first 50 records are shown.
- Parameters:
limit (int, optional) – limit number of fetched records
Example
client.remote_training_systems.list()
- list_revisions(remote_training_system_id, limit=None)
Print all revisions for the given remote_training_system_id in a table format.
- Parameters:
remote_training_system_id (str) – Unique id of stored remote system
limit (int, optional) – limit number of fetched records
Example
client.remote_training_systems.list_revisions(remote_training_system_id)
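The revision methods (create_revision, get_revision_details, list_revisions) sit alongside update(). The in-memory class below is a hypothetical model of how the two may relate, assuming that a revision is a frozen snapshot of the metadata taken when it is created, so later updates change the current definition but not earlier revisions. Nothing here talks to the service, and the sequential revision IDs are invented for the example.

```python
import copy


class RevisionedRegistry:
    """Hypothetical in-memory model of an asset that supports revisions."""

    def __init__(self):
        self._current = {}    # asset ID -> current metadata
        self._revisions = {}  # asset ID -> list of frozen snapshots

    def store(self, asset_id, metadata):
        self._current[asset_id] = dict(metadata)
        self._revisions[asset_id] = []

    def update(self, asset_id, changes):
        # Changes the current definition only; snapshots are untouched.
        self._current[asset_id].update(changes)
        return self._current[asset_id]

    def create_revision(self, asset_id):
        # Deep copy so that later updates cannot leak into the snapshot.
        self._revisions[asset_id].append(copy.deepcopy(self._current[asset_id]))
        return str(len(self._revisions[asset_id]))  # invented IDs "1", "2", ...

    def get_revision_details(self, asset_id, rev_id):
        return self._revisions[asset_id][int(rev_id) - 1]

    def list_revisions(self, asset_id):
        return [str(i + 1) for i in range(len(self._revisions[asset_id]))]
```

Under this assumption, creating a revision before calling update() preserves a known-good configuration to return to.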
- store(meta_props)
Create a remote training system. Either space_id or project_id has to be provided.
- Parameters:
meta_props (dict) – metadata; to see the available meta names, use client.remote_training_systems.ConfigurationMetaNames.get()
- Returns:
response json
- Return type:
dict
Example
```python
metadata = {
    client.remote_training_systems.ConfigurationMetaNames.NAME: "my-resource",
    client.remote_training_systems.ConfigurationMetaNames.TAGS: ["tag1", "tag2"],
    client.remote_training_systems.ConfigurationMetaNames.ORGANIZATION: {"name": "name", "region": "EU"},
    client.remote_training_systems.ConfigurationMetaNames.ALLOWED_IDENTITIES: [{"id": "43689024", "type": "user"}],
    client.remote_training_systems.ConfigurationMetaNames.REMOTE_ADMIN: {"id": "43689020", "type": "user"}
}
client.set.default_space('3fc54cf1-252f-424b-b52d-5cdd9814987f')
details = client.remote_training_systems.store(meta_props=metadata)
```
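The ALLOWED_IDENTITIES entry in the store() example is a list of {"id": ..., "type": "user"} dictionaries. A small hypothetical helper can build that list from plain user IDs; dropping duplicate IDs is a choice of this sketch, not a documented requirement of the service.

```python
def user_identities(user_ids):
    """Turn user IDs into {"id": ..., "type": "user"} entries, dropping duplicates."""
    seen = set()
    identities = []
    for uid in map(str, user_ids):  # IDs appear as strings in the example
        if uid not in seen:
            seen.add(uid)
            identities.append({"id": uid, "type": "user"})
    return identities
```

It could then fill the metadata entry, for example metadata[client.remote_training_systems.ConfigurationMetaNames.ALLOWED_IDENTITIES] = user_identities(["43689024"]).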
- update(remote_training_system_id, changes)
Update existing remote training system metadata.
- Parameters:
remote_training_system_id (str) – remote training system identifier
changes (dict) – elements which should be changed, where keys are ConfigurationMetaNames
- Returns:
updated remote training system details
- Return type:
dict
Example
```python
metadata = {
    client.remote_training_systems.ConfigurationMetaNames.NAME: "updated_remote_training_system"
}
details = client.remote_training_systems.update(remote_training_system_id, changes=metadata)
```
- class party_wrapper.Party(client=None, **kwargs)
The Party class embodies a Federated Learning party, with methods to run, cancel, and query local training. Refer to the client.remote_training_systems.create_party() API for more information about creating an instance of the Party class.
- cancel()
Stop the local connection to the training on the party side.
Example
party.cancel()
- get_round()
Get the current round number.
- Returns:
the current round number
- Return type:
int
Example
party.get_round()
- is_running()
Check if the training job is running.
- Returns:
True if the job is running, otherwise False
- Return type:
bool
Example
party.is_running()
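When run() is called with asynchronous=True, training continues in the background, and get_round() and is_running() let the calling code follow its progress. The loop below sketches one way to poll a party using only those two methods; `wait_for_party`, its parameters, and the decision to raise TimeoutError instead of calling cancel() are assumptions of this example.

```python
import time


def wait_for_party(party, poll_seconds=5.0, timeout=None, on_round=print):
    """Poll a party started with run(..., asynchronous=True) until it stops.

    Reports each new round number once through `on_round` and returns the
    last round seen. Raises TimeoutError if `timeout` seconds elapse first.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    last_round = None
    while party.is_running():
        current = party.get_round()
        if current != last_round:
            on_round(current)
            last_round = current
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError("party still running after timeout")
        time.sleep(poll_seconds)
    return last_round
```

A caller that prefers to stop training on timeout could catch TimeoutError and call party.cancel().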
- monitor_logs(log_level='INFO')
Enable logging of the training job to standard output. This method should be called before calling the run() method.
- Parameters:
log_level (str, optional) – log level specified by user
Example
party.monitor_logs()
- monitor_metrics(metrics_file='-')
Enable output of training metrics.
- Parameters:
metrics_file (str, optional) – a filename specified by user to which the metrics should be written
Note
This method outputs the metrics to stdout if a filename is not specified
Example
party.monitor_metrics()
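The default metrics_file='-' follows the common command-line convention in which "-" stands for standard output, consistent with the note that metrics go to stdout when no filename is given. The context manager below is a hypothetical illustration of that convention, not part of the client; appending to an existing file rather than overwriting it is an assumption.

```python
import contextlib
import sys


@contextlib.contextmanager
def open_metrics_sink(metrics_file="-"):
    """Yield a writable stream: stdout for "-", otherwise the named file (appended)."""
    if metrics_file == "-":
        yield sys.stdout  # not closed on exit, since we do not own it
    else:
        with open(metrics_file, "a", encoding="utf-8") as sink:
            yield sink
```

Leaving stdout open on exit matters: closing it would break any later output from the same process.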
- run(aggregator_id=None, experiment_id=None, asynchronous=True, verify=True, timeout=600)
Connect to a Federated Learning aggregator and run local training. Exactly one of aggregator_id and experiment_id must be supplied.
- Parameters:
aggregator_id (str, optional) –
aggregator identifier
If aggregator_id is supplied, the party will connect to the given aggregator.
experiment_id (str, optional) –
experiment identifier
If experiment_id is supplied, the party will connect to the most recently created aggregator for the experiment.
asynchronous (bool, optional) –
True - party starts to run the job in the background and progress can be checked later
False - method will wait until training is complete and then print the job status
verify (bool, optional) – verify certificate
timeout (int, or None for no timeout) –
timeout in seconds
If the aggregator is not ready within timeout seconds, the method exits.
Examples
```python
party.run(
    aggregator_id="69500105-9fd2-4326-ad27-7231aeb37ac8",
    asynchronous=True,
    verify=True,
)
party.run(
    experiment_id="2466fa06-1110-4169-a166-01959adec995",
    asynchronous=False,
)
```
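Two rules from the run() description lend themselves to a short sketch: exactly one of aggregator_id and experiment_id must be supplied, and timeout is a number of seconds to wait for the aggregator, with None meaning no limit. The helpers `choose_target` and `wait_until_ready` are hypothetical, and `is_ready` stands in for whatever readiness check the client performs internally, which this page does not describe.

```python
import time


def choose_target(aggregator_id=None, experiment_id=None):
    """Enforce that exactly one of the two IDs is supplied."""
    if (aggregator_id is None) == (experiment_id is None):
        raise ValueError("supply exactly one of aggregator_id and experiment_id")
    if aggregator_id is not None:
        return ("aggregator", aggregator_id)
    return ("experiment", experiment_id)


def wait_until_ready(is_ready, timeout=600, poll_seconds=1.0):
    """Poll is_ready() until it returns True; timeout=None waits indefinitely."""
    deadline = None if timeout is None else time.monotonic() + timeout
    while not is_ready():
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(f"aggregator not ready within {timeout} seconds")
        time.sleep(poll_seconds)
```

Using time.monotonic() keeps the deadline unaffected by changes to the system clock while waiting.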