Federated Learning

Federated Learning provides the tools for training a model collaboratively by coordinating local training runs and fusing their results. Even though data sources are never moved, combined, or shared among the parties or with the aggregator, every party contributes to training and improving the quality of the global model.

Tutorial and Samples for IBM Cloud

Tutorial and Samples for IBM Cloud Pak for Data, IBM Watson Machine Learning Server

Aggregation

The aggregator process, which fuses the parties’ training results, runs as a Watson Machine Learning training job. For more information on creating and querying a training job, see the API documentation for the client.training class. The parameters available to configure a Federated Learning training job are described in the IBM Cloud API Docs.
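
Conceptually, a fusion method such as the `iter_avg` fusion type in the aggregator configuration combines the parties' model updates into a single global update. The following is a minimal sketch of simple iterative averaging, assuming every party reports a flat list of weights of the same length and that all parties are weighted equally. It illustrates the idea only and is not the library's fusion implementation; the function name `iter_avg` is reused here purely for readability.

```python
from typing import List, Sequence


def iter_avg(party_weights: Sequence[Sequence[float]]) -> List[float]:
    """Average model weights element-wise across parties (illustrative only).

    Assumes each party sends a flat list of floats of identical length and
    that every party counts equally, regardless of its dataset size.
    """
    if not party_weights:
        raise ValueError("at least one party update is required")
    length = len(party_weights[0])
    if any(len(w) != length for w in party_weights):
        raise ValueError("all party updates must have the same length")
    n_parties = len(party_weights)
    # Sum each weight position across parties, then divide by the party count.
    return [sum(w[i] for w in party_weights) / n_parties for i in range(length)]


# Two parties; the global model receives the element-wise mean.
print(iter_avg([[1.0, 2.0], [3.0, 4.0]]))  # [2.0, 3.0]
```

Real fusion methods often weight parties by the number of samples they trained on; equal weighting is kept here to make the averaging step easy to follow.
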

Configure and start aggregation

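from ibm_watson_machine_learning import APIClient

# `credentials` is assumed to hold the Watson Machine Learning credentials
# (for example, a dict with "url" and "apikey") created beforehand.
client = APIClient(credentials)

PROJECT_ID = "8ae1a720-83ed-4c57-b719-8dd086bd7ce0"
client.set.default_project(PROJECT_ID)

# untrained_model_id, untrained_model_name, rts_1_id and rts_2_id are assumed to
# refer to a stored untrained model and two stored remote training systems.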
aggregator_metadata = {
    client.training.ConfigurationMetaNames.NAME: 'Federated Tensorflow MNIST',
    client.training.ConfigurationMetaNames.DESCRIPTION: 'MNIST digit recognition with Federated Learning using Tensorflow',
    client.training.ConfigurationMetaNames.TRAINING_DATA_REFERENCES: [],
    client.training.ConfigurationMetaNames.TRAINING_RESULTS_REFERENCE: {
        'type': 'container',
        'name': 'outputData',
        'connection': {},
        'location': {
            'path': '/projects/' + PROJECT_ID + '/assets/trainings/'
        }
    },
    client.training.ConfigurationMetaNames.FEDERATED_LEARNING: {
        'model': {
            'type': 'tensorflow',
            'spec': {
                'id': untrained_model_id
            },
            'model_file': untrained_model_name
        },
        'fusion_type': 'iter_avg',
        'metrics': 'accuracy',
        'epochs': 3,
        'rounds': 99,
        'remote_training': {
            'quorum': 1.0,
            'max_timeout': 3600,
            'remote_training_systems': [{'id': rts_1_id}, {'id': rts_2_id}]
        },
        'hardware_spec': {
            'name': 'S'
        },
        'software_spec': {
            'name': 'runtime-22.1-py3.9'
        }
    }
}


aggregator = client.training.run(aggregator_metadata, asynchronous=True)
aggregator_id = client.training.get_id(aggregator)
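
The `remote_training` block sets `quorum` alongside `max_timeout`. Assuming `quorum` is the fraction of the registered remote training systems that must respond before a round can be fused (so `1.0` requires every party), the check can be sketched as follows. The helper name `quorum_met` and these exact semantics are assumptions for illustration; the IBM Cloud API Docs hold the authoritative definition.

```python
import math


def quorum_met(responded: int, registered: int, quorum: float) -> bool:
    """Return True if enough parties responded (hypothetical helper).

    Assumes `quorum` is a fraction in (0, 1] of the registered parties,
    mirroring the 'quorum' key of the 'remote_training' block.
    """
    if registered <= 0:
        raise ValueError("registered must be positive")
    if not 0.0 < quorum <= 1.0:
        raise ValueError("quorum must be in (0, 1]")
    # Round before ceil so that e.g. 0.3 * 10 (= 3.0000000000000004) needs 3, not 4.
    required = math.ceil(round(quorum * registered, 9))
    return responded >= required


# With two remote training systems and quorum 1.0, both must respond.
print(quorum_met(responded=1, registered=2, quorum=1.0))  # False
print(quorum_met(responded=2, registered=2, quorum=1.0))  # True
```
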

Local training

Training is performed locally by parties that connect to the aggregator. The parties must be members of the project or space in which the aggregator is running, and they are identified to the Federated Learning aggregator as Remote Training Systems.

Configure and start local training

from ibm_watson_machine_learning import APIClient

# MNISTDataHandler is the party's own data handler class; it is assumed to be
# defined or imported here. party_1_credentials is assumed to hold this party's credentials.
client = APIClient(party_1_credentials)

PROJECT_ID = "8ae1a720-83ed-4c57-b719-8dd086bd7ce0"
client.set.default_project(PROJECT_ID)

# At minimum, the party must specify how the data are loaded for training: the data
# handler class and any input to the class are provided.  In this case, the info block
# contains a key that locates the training data relative to the current working directory.
party_metadata = {
    client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
        "class": MNISTDataHandler,
        "info": {
            "npz_file": "./training_data.npz"
        }
    }
}
# Create the party object
party = client.remote_training_systems.create_party(
    remote_training_system_id="d516d42c-6c59-41f2-b7ca-c63d11ea79a1",
    party_metadata=party_metadata
)
# Send training logs to standard output
party.monitor_logs()
# Start training.  Training runs in the Python process that is executing this code.
# The supplied aggregator_id refers to the Watson Machine Learning training job that performs aggregation.
party.run(aggregator_id="564fb126-9bfd-409b-beb3-5d401e4c50ec", asynchronous=False)

class remote_training_system.RemoteTrainingSystem(client)

The RemoteTrainingSystem class represents a Federated Learning party and provides a list of identities that are permitted to join training as the RemoteTrainingSystem.
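
The identities a remote training system accepts are stored as a list of `{"id": ..., "type": ...}` entries, as in the `ALLOWED_IDENTITIES` value of the `store()` example. A minimal sketch of the membership check this implies, assuming a caller is admitted when both its `id` and `type` match an entry. The function `is_allowed` is hypothetical and does not describe how the service actually performs authorization.

```python
from typing import Dict, Iterable


def is_allowed(caller: Dict[str, str], allowed_identities: Iterable[Dict[str, str]]) -> bool:
    """Check whether a caller matches one of the allowed identities.

    Hypothetical helper: assumes identities are dicts with 'id' and 'type'
    keys, the shape used for ALLOWED_IDENTITIES in the store() example.
    """
    caller_key = (caller.get("id"), caller.get("type"))
    if None in caller_key:
        return False  # an incomplete identity never matches
    return any((entry.get("id"), entry.get("type")) == caller_key for entry in allowed_identities)


allowed = [{"id": "43689024", "type": "user"}]
print(is_allowed({"id": "43689024", "type": "user"}, allowed))  # True
print(is_allowed({"id": "43689020", "type": "user"}, allowed))  # False
```
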

create_party(remote_training_system_id, party_metadata)

Create a party object using the specified remote training system ID and party metadata.

Parameters:
  • remote_training_system_id (str) – remote training system identifier

  • party_metadata (dict) – the party configuration

Returns:

a party object with the specified remote training system ID and configuration

Return type:

Party

Examples

# Data handler and local training handler given by name and module path
party_metadata = {
    client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
        "info": {
            "npz_file": "./data_party0.npz"
        },
        "name": "MnistTFDataHandler",
        "path": "./mnist_keras_data_handler.py"
    },
    client.remote_training_systems.ConfigurationMetaNames.LOCAL_TRAINING: {
        "name": "LocalTrainingHandler",
        "path": "ibmfl.party.training.local_training_handler"
    },
    client.remote_training_systems.ConfigurationMetaNames.HYPERPARAMS: {
        "epochs": 3
    },
}
party = client.remote_training_systems.create_party(remote_training_system_id, party_metadata)

# Data handler given as a class object (MnistTFDataHandler must be defined or imported first)
party_metadata = {
    client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
        "info": {
            "npz_file": "./data_party0.npz"
        },
        "class": MnistTFDataHandler
    }
}
party = client.remote_training_systems.create_party(remote_training_system_id, party_metadata)
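
The two examples describe the data handler in two ways: by the `name` and `path` of the module that defines it, or by passing the `class` object directly, with an optional `info` dictionary in both cases. A hypothetical pre-flight check for the inner data handler dictionary, assuming exactly one of the two forms must be used (an assumption drawn from the examples, not a documented rule), could look like this:

```python
from typing import Any, Dict


def check_data_handler(block: Dict[str, Any]) -> str:
    """Classify a DATA_HANDLER block as 'class' or 'module' (hypothetical).

    Assumes a block either holds a class object under 'class', or gives a
    'name' plus a 'path' to the module that defines it; 'info' is optional.
    """
    has_class = "class" in block
    has_module = "name" in block and "path" in block
    if has_class == has_module:
        raise ValueError("give either 'class', or both 'name' and 'path'")
    if "info" in block and not isinstance(block["info"], dict):
        raise TypeError("'info' must be a dict")
    return "class" if has_class else "module"


class MnistTFDataHandler:  # stand-in for the party's real data handler class
    pass


print(check_data_handler({"class": MnistTFDataHandler, "info": {"npz_file": "./data_party0.npz"}}))  # class
print(check_data_handler({"name": "MnistTFDataHandler", "path": "./mnist_keras_data_handler.py"}))  # module
```
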

create_revision(remote_training_system_id)

Create a new remote training system revision.

Parameters:

remote_training_system_id (str) – Unique remote training system ID

Returns:

remote training system details

Return type:

dict

Example

client.remote_training_systems.create_revision(remote_training_system_id)

delete(remote_training_systems_id)

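Deletes the remote training system definition identified by remote_training_systems_id. Either space_id or project_id has to be provided, for example by setting a default space or project on the client.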

Parameters:

remote_training_systems_id (str) – remote training system identifier

Returns:

status (“SUCCESS” or “FAILED”)

Return type:

str

Example

client.remote_training_systems.delete(remote_training_systems_id='6213cf1-252f-424b-b52d-5cdd9814956c')

get_details(remote_training_system_id=None, limit=None, asynchronous=False, get_all=False)

Get metadata of the given remote training system. If remote_training_system_id is not specified, metadata is returned for all remote training systems.

Parameters:
  • remote_training_system_id (str, optional) – remote training system identifier

  • limit (int, optional) – limit number of fetched records

  • asynchronous (bool, optional) – if True, it will work as a generator

  • get_all (bool, optional) – if True, it will get all entries in ‘limited’ chunks

Returns:

remote training system(s) metadata

Return type:

dict (if remote_training_system_id is not None) or {“resources”: [dict]} (if remote_training_system_id is None)

Examples

details = client.remote_training_systems.get_details(remote_training_system_id)
details = client.remote_training_systems.get_details()
details = client.remote_training_systems.get_details(limit=100)
details = client.remote_training_systems.get_details(limit=100, get_all=True)
details = []
for entry in client.remote_training_systems.get_details(limit=100, asynchronous=True, get_all=True):
    details.extend(entry)
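
With `asynchronous=True` and `get_all=True`, `get_details()` yields records in chunks of at most `limit` entries, which the last example collects with `details.extend(entry)`. The paging pattern behind this can be sketched with a hypothetical `fetch_page(offset, limit)` function standing in for one request; the real client's paging mechanism may differ (it may, for instance, follow continuation links rather than offsets).

```python
from typing import Callable, Dict, Iterator, List

Record = Dict[str, str]


def iter_chunks(fetch_page: Callable[[int, int], List[Record]], limit: int) -> Iterator[List[Record]]:
    """Yield successive pages of at most `limit` records until none remain.

    `fetch_page(offset, limit)` is a hypothetical stand-in for one request.
    """
    offset = 0
    while True:
        page = fetch_page(offset, limit)
        if not page:
            return
        yield page
        if len(page) < limit:  # a short page means this was the last one
            return
        offset += len(page)


# Simulated store of 250 records, fetched 100 at a time.
records = [{"id": str(i)} for i in range(250)]
details: List[Record] = []
for entry in iter_chunks(lambda off, lim: records[off:off + lim], limit=100):
    details.extend(entry)
print(len(details))  # 250
```
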

static get_id(remote_training_system_details)

Get ID of remote training system.

Parameters:

remote_training_system_details (dict) – metadata of the stored remote training system

Returns:

ID of stored remote training system

Return type:

str

Example

details = client.remote_training_systems.get_details(remote_training_system_id)
rts_id = client.remote_training_systems.get_id(details)
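
`get_id()` reads the identifier out of the details dictionary returned by `get_details()` or `store()`. A minimal sketch of that lookup, assuming the details follow the common Watson Machine Learning layout in which the ID sits under `metadata.id`. Both the helper `extract_id` and that layout are assumptions here, so the client's own `get_id()` remains the method to use in practice.

```python
from typing import Any, Dict


def extract_id(details: Dict[str, Any]) -> str:
    """Return details['metadata']['id'] (hypothetical helper).

    Assumes the metadata/entity layout; raises ValueError if it is missing.
    """
    try:
        return details["metadata"]["id"]
    except (KeyError, TypeError) as exc:
        raise ValueError("details do not contain metadata.id") from exc


sample = {"metadata": {"id": "d516d42c-6c59-41f2-b7ca-c63d11ea79a1"}, "entity": {"name": "my-resource"}}
print(extract_id(sample))  # d516d42c-6c59-41f2-b7ca-c63d11ea79a1
```
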

get_revision_details(remote_training_system_id, rev_id)

Get metadata of a specific revision of a stored remote training system.

Parameters:
  • remote_training_system_id (str) – UID of remote training system

  • rev_id (str) – Unique id of the remote system revision

Returns:

stored remote system revision metadata

Return type:

dict

Example

details = client.remote_training_systems.get_revision_details(remote_training_system_id, rev_id)

list(limit=None)

Print stored remote training systems in a table format. If limit is set to None, only the first 50 records are shown.

Parameters:

limit (int) – limit number of fetched records

Example

client.remote_training_systems.list()

list_revisions(remote_training_system_id, limit=None)

Print all revisions for the given remote_training_system_id in a table format.

Parameters:
  • remote_training_system_id (str) – Unique id of stored remote system

  • limit (int, optional) – limit number of fetched records

Example

client.remote_training_systems.list_revisions(remote_training_system_id)

store(meta_props)

Create a remote training system. Either space_id or project_id has to be provided.

Parameters:

meta_props (dict) – metadata, to see available meta names use client.remote_training_systems.ConfigurationMetaNames.get()

Returns:

response json

Return type:

dict

Example

metadata = {
    client.remote_training_systems.ConfigurationMetaNames.NAME: "my-resource",
    client.remote_training_systems.ConfigurationMetaNames.TAGS: ["tag1", "tag2"],
    client.remote_training_systems.ConfigurationMetaNames.ORGANIZATION: {"name": "name", "region": "EU"},
    client.remote_training_systems.ConfigurationMetaNames.ALLOWED_IDENTITIES: [{"id": "43689024", "type": "user"}],
    client.remote_training_systems.ConfigurationMetaNames.REMOTE_ADMIN: {"id": "43689020", "type": "user"}
}
client.set.default_space('3fc54cf1-252f-424b-b52d-5cdd9814987f')
details = client.remote_training_systems.store(meta_props=metadata)

update(remote_training_system_id, changes)

Updates existing remote training system metadata.

Parameters:
  • remote_training_system_id (str) – remote training system identifier

  • changes (dict) – elements which should be changed, where keys are ConfigurationMetaNames

Returns:

updated remote training system details

Return type:

dict

Example

metadata = {
    client.remote_training_systems.ConfigurationMetaNames.NAME: "updated_remote_training_system"
}
details = client.remote_training_systems.update(remote_training_system_id, changes=metadata)

class party_wrapper.Party(client=None, **kwargs)

The Party class embodies a Federated Learning party, with methods to run, cancel, and query local training. Refer to the client.remote_training_systems.create_party() API for more information about creating an instance of the Party class.

cancel()

Stop the local connection to the training on the party side.

Example

party.cancel()

get_round()

Get the current round number.

Returns:

the current round number

Return type:

int

Example

party.get_round()

is_running()

Check if the training job is running.

Returns:

True if the job is running, otherwise False

Return type:

bool

Example

party.is_running()
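
When `run()` is called with `asynchronous=True`, training continues in the background, and `is_running()` and `get_round()` can be used to follow it. A minimal polling sketch that relies only on those two documented methods; `wait_for_training` and its `poll_seconds` default are hypothetical, and a stub object stands in for a real `Party` so that the example runs on its own.

```python
import time
from typing import Callable, List, Protocol


class PartyLike(Protocol):
    def is_running(self) -> bool: ...

    def get_round(self) -> int: ...


def wait_for_training(party: PartyLike, poll_seconds: float = 10.0,
                      sleep: Callable[[float], None] = time.sleep) -> List[int]:
    """Poll until the party stops running; return the distinct rounds seen."""
    rounds_seen: List[int] = []
    while party.is_running():
        current = party.get_round()
        if not rounds_seen or rounds_seen[-1] != current:
            rounds_seen.append(current)
            print(f"round {current} in progress")
        sleep(poll_seconds)
    return rounds_seen


class StubParty:
    """Pretends to run three rounds, observed twice each (illustration only)."""

    def __init__(self) -> None:
        self._polls = 0

    def is_running(self) -> bool:
        return self._polls < 6

    def get_round(self) -> int:
        self._polls += 1
        return (self._polls + 1) // 2


print(wait_for_training(StubParty(), sleep=lambda s: None))  # [1, 2, 3]
```

With a real party, the `sleep` argument is left at its default so that the loop waits `poll_seconds` between checks.
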

monitor_logs(log_level='INFO')

Enable logging of the training job to standard output. This method should be called before calling the run() method.

Parameters:

log_level (str, optional) – log level specified by user

Example

party.monitor_logs()

monitor_metrics(metrics_file='-')

Enable output of training metrics.

Parameters:

metrics_file (str, optional) – a filename specified by user to which the metrics should be written

Note

This method writes the metrics to standard output if no filename is specified (the default '-').

Example

party.monitor_metrics()
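
The default `metrics_file='-'` follows the common command-line convention that `-` stands for standard output. A hypothetical sketch of how such an argument can be resolved to a writable stream; `open_metrics_sink` mirrors the documented behavior but is not the library's code, and the record written is illustrative rather than real training output.

```python
import contextlib
import sys
from typing import Iterator, TextIO


@contextlib.contextmanager
def open_metrics_sink(metrics_file: str = "-") -> Iterator[TextIO]:
    """Yield sys.stdout for '-', otherwise a file opened for appending.

    Hypothetical helper: stdout is left open on exit, a file is closed.
    """
    if metrics_file == "-":
        yield sys.stdout
    else:
        with open(metrics_file, "a", encoding="utf-8") as handle:
            yield handle


with open_metrics_sink() as sink:
    sink.write('{"round": 1}\n')  # illustrative record, not real training output
```

Appending rather than truncating keeps metrics from earlier runs when the same file is reused; that choice is an assumption of the sketch.
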

run(aggregator_id=None, experiment_id=None, asynchronous=True, verify=True, timeout=600)

Connect to a Federated Learning aggregator and run local training. Exactly one of aggregator_id and experiment_id must be supplied.

Parameters:
  • aggregator_id (str, optional) –

    aggregator identifier

    • If aggregator_id is supplied, the party will connect to the given aggregator.

  • experiment_id (str, optional) –

    experiment identifier

    • If experiment_id is supplied, the party will connect to the most recently created aggregator for the experiment.

  • asynchronous (bool, optional) –

    • True - party starts to run the job in the background and progress can be checked later

    • False - method will wait until training is complete and then print the job status

  • verify (bool, optional) – verify certificate

  • timeout (int, or None for no timeout) –

    timeout in seconds

    • If the aggregator is not ready within timeout seconds from now, exit.

Examples

party.run( aggregator_id = "69500105-9fd2-4326-ad27-7231aeb37ac8", asynchronous = True, verify = True )
party.run( experiment_id = "2466fa06-1110-4169-a166-01959adec995", asynchronous = False )
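
Two documented rules of `run()` lend themselves to a short sketch: exactly one of `aggregator_id` and `experiment_id` must be supplied, and with a numeric `timeout` the party gives up if the aggregator is not ready within that many seconds, while `None` waits indefinitely. The sketch below assumes a hypothetical `is_ready()` callable in place of the readiness check that the client performs internally; `wait_for_aggregator` and `poll_seconds` are likewise illustrative names.

```python
import time
from typing import Callable, Optional


def wait_for_aggregator(is_ready: Callable[[], bool],
                        aggregator_id: Optional[str] = None,
                        experiment_id: Optional[str] = None,
                        timeout: Optional[float] = 600,
                        poll_seconds: float = 5.0,
                        clock: Callable[[], float] = time.monotonic,
                        sleep: Callable[[float], None] = time.sleep) -> str:
    """Return the supplied ID once is_ready() is True (hypothetical helper).

    Raises ValueError unless exactly one ID is given, and TimeoutError if a
    numeric timeout elapses first; timeout=None waits without limit.
    """
    if (aggregator_id is None) == (experiment_id is None):
        raise ValueError("supply exactly one of aggregator_id and experiment_id")
    target = aggregator_id if aggregator_id is not None else experiment_id
    assert target is not None  # guaranteed by the check above
    deadline = None if timeout is None else clock() + timeout
    while not is_ready():
        if deadline is not None and clock() >= deadline:
            raise TimeoutError(f"aggregator not ready within {timeout} seconds")
        sleep(poll_seconds)
    return target


# Simulated readiness: the aggregator becomes ready on the third check.
checks = iter([False, False, True])
print(wait_for_aggregator(lambda: next(checks), aggregator_id="69500105-9fd2-4326-ad27-7231aeb37ac8",
                          sleep=lambda s: None))  # 69500105-9fd2-4326-ad27-7231aeb37ac8
```

Injecting `clock` and `sleep` keeps the deadline logic testable without real waiting; `time.monotonic` is used so that wall-clock adjustments cannot shorten or extend the timeout.
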