Federated Learning¶
Federated Learning provides the tools for training a model collaboratively, by coordinating local training runs and fusing the results. Even though data sources are never moved, combined, or shared among parties or the aggregator, all of them contribute to training and improving the quality of the global model.
Tutorial and Samples for IBM watsonx.ai for IBM Cloud
Tutorial and Samples for IBM watsonx.ai software, IBM watsonx.ai Server
Aggregation¶
The aggregator process, which fuses the parties’ training results, runs as a watsonx.ai training job. For more information on creating and querying a training job, see training. The parameters available for configuring a Federated Learning training job are described in the IBM Cloud API Docs.
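To make the idea of fusing training results concrete, the following minimal sketch averages the weight vectors reported by each party. This is the general idea behind an averaging fusion such as the iter_avg fusion type used in the configuration on this page. The function name fuse_by_averaging and the flat-list representation of weights are assumptions made for illustration, and the sketch is not the aggregator's actual implementation.

```python
from typing import List, Sequence


def fuse_by_averaging(party_weights: Sequence[Sequence[float]]) -> List[float]:
    """Return the element-wise mean of the parties' weight vectors."""
    if not party_weights:
        raise ValueError("at least one party must report weights")
    length = len(party_weights[0])
    if any(len(weights) != length for weights in party_weights):
        raise ValueError("all parties must report weight vectors of the same length")
    count = len(party_weights)
    # Average each position across all parties.
    return [sum(weights[i] for weights in party_weights) / count for i in range(length)]


# Two hypothetical parties report their weights after one round of local training.
fused = fuse_by_averaging([[1.0, 0.5], [3.0, 1.5]])
print(fused)
```

Only the weight vectors reach the fusion step in this sketch, which mirrors the point that the parties' training data never leaves them.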
Configure and start aggregation¶
from ibm_watsonx_ai import APIClient

client = APIClient(credentials)

PROJECT_ID = "8ae1a720-83ed-4c57-b719-8dd086bd7ce0"
client.set.default_project(PROJECT_ID)

aggregator_metadata = {
    client.training.ConfigurationMetaNames.NAME: 'Federated Tensorflow MNIST',
    client.training.ConfigurationMetaNames.DESCRIPTION: 'MNIST digit recognition with Federated Learning using Tensorflow',
    client.training.ConfigurationMetaNames.TRAINING_DATA_REFERENCES: [],
    client.training.ConfigurationMetaNames.TRAINING_RESULTS_REFERENCE: {
        'type': 'container',
        'name': 'outputData',
        'connection': {},
        'location': {
            'path': '/projects/' + PROJECT_ID + '/assets/trainings/'
        }
    },
    client.training.ConfigurationMetaNames.FEDERATED_LEARNING: {
        'model': {
            'type': 'tensorflow',
            'spec': {
                'id': untrained_model_id
            },
            'model_file': untrained_model_name
        },
        'fusion_type': 'iter_avg',
        'metrics': 'accuracy',
        'epochs': 3,
        'rounds': 99,
        'remote_training': {
            'quorum': 1.0,
            'max_timeout': 3600,
            'remote_training_systems': [{'id': rts_1_id}, {'id': rts_2_id}]
        },
        'hardware_spec': {
            'name': 'S'
        },
        'software_spec': {
            'name': 'runtime-22.1-py3.9'
        }
    }
}

aggregator = client.training.run(aggregator_metadata, asynchronous=True)
aggregator_id = client.training.get_id(aggregator)
Local training¶
Parties that connect to the aggregator can perform local training.
To perform local training, the parties must be:
- members of the project or space in which the aggregator is running
- identified as Remote Training Systems to the Federated Learning aggregator
Configure and start local training¶
from ibm_watsonx_ai import APIClient

client = APIClient(party_1_credentials)

PROJECT_ID = "8ae1a720-83ed-4c57-b719-8dd086bd7ce0"
client.set.default_project(PROJECT_ID)

# The party needs, at minimum, to specify how the data are loaded for training. The data
# handler class and any input to the class are provided. In this case, the info block
# contains a key to locate the training data from the current working directory.
# MNISTDataHandler is the party's own data handler class; it must be defined or
# imported before this point.
party_metadata = {
    client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
        "class": MNISTDataHandler,
        "info": {
            "npz_file": "./training_data.npz"
        }
    }
}

# The party object is created
party = client.remote_training_systems.create_party(
    remote_training_system_id="d516d42c-6c59-41f2-b7ca-c63d11ea79a1",
    party_metadata=party_metadata
)

# Send training logging to standard output
party.monitor_logs()

# Start training. Training will run in the Python process that is executing this code.
# The supplied aggregator_id refers to the watsonx.ai training job that will perform aggregation.
party.run(aggregator_id="564fb126-9bfd-409b-beb3-5d401e4c50ec", asynchronous=False)
- class remote_training_system.RemoteTrainingSystem(client)[source]¶
The RemoteTrainingSystem class represents a Federated Learning party and provides a list of identities that are permitted to join training as the RemoteTrainingSystem.
- create_party(remote_training_system_id, party_metadata)[source]¶
Create a party object using the specified remote training system ID and the party metadata.
- Parameters:
remote_training_system_id (str) – identifier of the remote training system
party_metadata (dict) – the party configuration
- Returns:
a party object with the specified rts_id and configuration
- Return type:
Party
Examples
party_metadata = {
    client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
        "info": {
            "npz_file": "./data_party0.npz"
        },
        "name": "MnistTFDataHandler",
        "path": "./mnist_keras_data_handler.py"
    },
    client.remote_training_systems.ConfigurationMetaNames.LOCAL_TRAINING: {
        "name": "LocalTrainingHandler",
        "path": "ibmfl.party.training.local_training_handler"
    },
    client.remote_training_systems.ConfigurationMetaNames.HYPERPARAMS: {
        "epochs": 3
    },
}
party = client.remote_training_systems.create_party(remote_training_system_id, party_metadata)

party_metadata = {
    client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
        "info": {
            "npz_file": "./data_party0.npz"
        },
        "class": MnistTFDataHandler
    }
}
party = client.remote_training_systems.create_party(remote_training_system_id, party_metadata)
- create_revision(remote_training_system_id)[source]¶
Create a new remote training system revision.
- Parameters:
remote_training_system_id (str) – unique ID of the remote training system
- Returns:
details of the remote training system
- Return type:
dict
Example:
client.remote_training_systems.create_revision(remote_training_system_id)
- delete(remote_training_systems_id)[source]¶
Delete the remote training system definition identified by remote_training_systems_id. A space_id or project_id has to be provided.
- Parameters:
remote_training_systems_id (str) – identifier of the remote training system
- Returns:
status (“SUCCESS” or “FAILED”)
- Return type:
str
Example:
client.remote_training_systems.delete(remote_training_systems_id='6213cf1-252f-424b-b52d-5cdd9814956c')
- get_details(remote_training_system_id=None, limit=None, asynchronous=False, get_all=False)[source]¶
Get metadata of a given remote training system. If remote_training_system_id is not specified, metadata is returned for all remote training systems.
- Parameters:
remote_training_system_id (str, optional) – identifier of the remote training system
limit (int, optional) – limit number of fetched records
asynchronous (bool, optional) – if True, it will work as a generator
get_all (bool, optional) – if True, it will get all entries in ‘limited’ chunks
- Returns:
remote training system(s) metadata
- Return type:
dict (if remote_training_system_id is not None) or {“resources”: [dict]} (if remote_training_system_id is None)
Examples
details = client.remote_training_systems.get_details(remote_training_system_id)
details = client.remote_training_systems.get_details()
details = client.remote_training_systems.get_details(limit=100)
details = client.remote_training_systems.get_details(limit=100, get_all=True)
details = []
for entry in client.remote_training_systems.get_details(limit=100, asynchronous=True, get_all=True):
    details.extend(entry)
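The interplay of limit, asynchronous, and get_all can be pictured as chunked paging consumed through a generator. The sketch below is a self-contained illustration over an in-memory list. The names fetch_in_chunks and RECORDS are hypothetical, and the real client fetches records from the service rather than from a local list.

```python
from typing import Dict, Iterator, List

# Hypothetical stand-in for the records stored on the service.
RECORDS: List[Dict[str, str]] = [{"id": f"rts-{n}"} for n in range(5)]


def fetch_in_chunks(records: List[Dict[str, str]], limit: int) -> Iterator[List[Dict[str, str]]]:
    """Yield successive chunks of at most `limit` records."""
    if limit < 1:
        raise ValueError("limit must be a positive integer")
    for start in range(0, len(records), limit):
        yield records[start:start + limit]


# Consume the generator chunk by chunk and collect everything into one list.
details: List[Dict[str, str]] = []
for entry in fetch_in_chunks(RECORDS, limit=2):
    details.extend(entry)
```

Because each chunk is produced lazily, the caller can stop iterating early without paying for the remaining records.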
- static get_id(remote_training_system_details)[source]¶
Get the ID of a remote training system.
- Parameters:
remote_training_system_details (dict) – metadata of the stored remote training system
- Returns:
ID of the stored remote training system
- Return type:
str
Example:
details = client.remote_training_systems.get_details(remote_training_system_id)
id = client.remote_training_systems.get_id(details)
- get_revision_details(remote_training_system_id, rev_id)[source]¶
Get metadata from a specific revision of a stored remote system.
- Parameters:
remote_training_system_id (str) – ID of the remote training system
rev_id (str) – unique ID of the remote system revision
- Returns:
metadata of the stored remote system revision
- Return type:
dict
Example:
details = client.remote_training_systems.get_revision_details(remote_training_system_id, rev_id)
- list(limit=None)[source]¶
Lists stored remote training systems in a table format. If limit is set to None, only the first 50 records are shown.
- Parameters:
limit (int, optional) – limit number of fetched records
- Returns:
pandas.DataFrame with listed remote training systems
- Return type:
pandas.DataFrame
Example:
client.remote_training_systems.list()
- list_revisions(remote_training_system_id, limit=None)[source]¶
Print all revisions for a given remote_training_system_id in a table format.
- Parameters:
remote_training_system_id (str) – unique ID of the stored remote system
limit (int, optional) – limit number of fetched records
- Returns:
pandas.DataFrame with listed remote training system revisions
- Return type:
pandas.DataFrame
Example:
client.remote_training_systems.list_revisions(remote_training_system_id)
- store(meta_props)[source]¶
Create a remote training system. space_id or project_id has to be provided.
- Parameters:
meta_props (dict) – metadata, to see available meta names use
client.remote_training_systems.ConfigurationMetaNames.get()
- Returns:
response json
- Return type:
dict
Example:
metadata = {
    client.remote_training_systems.ConfigurationMetaNames.NAME: "my-resource",
    client.remote_training_systems.ConfigurationMetaNames.TAGS: ["tag1", "tag2"],
    client.remote_training_systems.ConfigurationMetaNames.ORGANIZATION: {"name": "name", "region": "EU"},
    client.remote_training_systems.ConfigurationMetaNames.ALLOWED_IDENTITIES: [{"id": "43689024", "type": "user"}],
    client.remote_training_systems.ConfigurationMetaNames.REMOTE_ADMIN: {"id": "43689020", "type": "user"}
}
client.set.default_space('3fc54cf1-252f-424b-b52d-5cdd9814987f')
details = client.remote_training_systems.store(meta_props=metadata)
- update(remote_training_system_id, changes)[source]¶
Update the existing metadata of a remote training system.
- Parameters:
remote_training_system_id (str) – identifier of the remote training system
changes (dict) – elements to be changed, where keys are ConfigurationMetaNames
- Returns:
updated remote training system details
- Return type:
dict
Example:
metadata = {
    client.remote_training_systems.ConfigurationMetaNames.NAME: "updated_remote_training_system"
}
details = client.remote_training_systems.update(remote_training_system_id, changes=metadata)
- class party_wrapper.Party(client=None, **kwargs)[source]¶
The Party class embodies a Federated Learning party with methods to run, cancel, and query local training. Refer to the client.remote_training_systems.create_party() API for more information about creating an instance of the Party class.
- cancel()[source]¶
Stop the local connection to the training on the party side.
Example:
party.cancel()
- get_round()[source]¶
Get the current round number.
- Returns:
the current round number
- Return type:
int
Example:
party.get_round()
- is_running()[source]¶
Check if the training job is running.
- Returns:
if the job is running
- Return type:
bool
Example:
party.is_running()
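When training runs in the background, is_running() and get_round() can be combined into a simple polling loop. The following is a minimal sketch: wait_for_party and its parameters are hypothetical helper names, and party is assumed to be any object exposing the two methods documented above.

```python
import time
from typing import List


def wait_for_party(party, poll_seconds: float = 5.0, max_polls: int = 1000) -> List[int]:
    """Poll a party until it stops running.

    Returns the round numbers observed, with consecutive repeats collapsed.
    Gives up after `max_polls` checks even if the party is still running.
    """
    rounds_seen: List[int] = []
    for _ in range(max_polls):
        if not party.is_running():
            break
        current = party.get_round()
        if not rounds_seen or rounds_seen[-1] != current:
            rounds_seen.append(current)
        time.sleep(poll_seconds)
    return rounds_seen
```

A call such as wait_for_party(party, poll_seconds=10) would then block until the local training job is no longer running or the poll budget is exhausted.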
- monitor_logs(log_level='INFO')[source]¶
Enable logging of the training job to standard output. Call this method before calling the run() method.
- Parameters:
log_level (str, optional) – log level specified by user
Example:
party.monitor_logs()
- monitor_metrics(metrics_file='-')[source]¶
Enable output of training metrics.
- Parameters:
metrics_file (str, optional) – a filename specified by user to which the metrics should be written
Note
This method outputs the metrics to stdout if a filename is not specified
Example:
party.monitor_metrics()
- run(aggregator_id=None, experiment_id=None, asynchronous=True, verify=True, timeout=600)[source]¶
Connect to a Federated Learning aggregator and run local training. Exactly one of aggregator_id and experiment_id must be supplied.
- Parameters:
aggregator_id (str, optional) – aggregator identifier; if aggregator_id is supplied, the party will connect to the given aggregator
experiment_id (str, optional) – experiment identifier; if experiment_id is supplied, the party will connect to the most recently created aggregator for the experiment
asynchronous (bool, optional) – if True, the party starts to run the job in the background and progress can be checked later; if False, the method waits until training is complete and then prints the job status
verify (bool, optional) – verify certificate
timeout (int, or None for no timeout) – timeout in seconds; if the aggregator is not ready within the provided number of seconds, the connection attempt times out
Examples
party.run(
    aggregator_id="69500105-9fd2-4326-ad27-7231aeb37ac8",
    asynchronous=True,
    verify=True
)
party.run(
    experiment_id="2466fa06-1110-4169-a166-01959adec995",
    asynchronous=False
)
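The rule that exactly one of aggregator_id and experiment_id must be supplied can be expressed as a small check. This is a sketch only: select_target is a hypothetical helper and does not reflect how run() validates its arguments internally.

```python
from typing import Optional, Tuple


def select_target(aggregator_id: Optional[str] = None,
                  experiment_id: Optional[str] = None) -> Tuple[str, str]:
    """Return ("aggregator", id) or ("experiment", id).

    Raises ValueError when neither or both identifiers are given.
    """
    # Both None or both set means the "exactly one" rule is violated.
    if (aggregator_id is None) == (experiment_id is None):
        raise ValueError("supply exactly one of aggregator_id and experiment_id")
    if aggregator_id is not None:
        return ("aggregator", aggregator_id)
    return ("experiment", experiment_id)
```

Checking the arguments up front like this surfaces a misconfigured call before any connection to an aggregator is attempted.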