Federated Learning

Federated Learning provides the tools for training a model collaboratively, by coordinating local training runs and fusing the results. Even though data sources are never moved, combined, or shared among parties or the aggregator, all of them contribute to training and improving the quality of the global model.

Tutorial and Samples for IBM watsonx.ai for IBM Cloud

Tutorial and Samples for IBM watsonx.ai software, IBM watsonx.ai Server

Aggregation

The aggregator process, which fuses the parties’ training results, runs as a watsonx.ai training job. For more information about creating and querying a training job, see the training documentation. The parameters available to configure Federated Learning training are described in the IBM Cloud API Docs.
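To make "fusing the parties’ training results" concrete, the following is a minimal, conceptual sketch of iterative averaging, the idea behind the fusion_type value 'iter_avg' in the configuration example. It assumes that 'iter_avg' means an unweighted element-wise average of the parties' model weights and that each party sends its model as a flat list of floats of the same length. The helper name iter_avg_fuse is hypothetical, and this is not the aggregator's actual implementation.

```python
from typing import List, Sequence


def iter_avg_fuse(party_weights: Sequence[Sequence[float]]) -> List[float]:
    """Element-wise, unweighted average of the parties' model weights.

    Hypothetical helper: each inner sequence is one party's flattened weights.
    """
    if not party_weights:
        raise ValueError("at least one party update is required")
    size = len(party_weights[0])
    if any(len(w) != size for w in party_weights):
        raise ValueError("all parties must send weights of the same shape")
    n = len(party_weights)
    # Sum each coordinate across parties, then divide by the number of parties.
    return [sum(w[i] for w in party_weights) / n for i in range(size)]


# Two parties train locally; only their weights (never their data) reach the aggregator.
fused = iter_avg_fuse([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
print(fused)  # [2.0, 3.0, 4.0]
```

Only model parameters cross the network in this picture, which is why the data sources never have to be moved or shared.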

Configure and start aggregation

from ibm_watsonx_ai import APIClient

# `credentials` must be defined beforehand, for example as an ibm_watsonx_ai.Credentials object.
client = APIClient(credentials)

PROJECT_ID = "8ae1a720-83ed-4c57-b719-8dd086bd7ce0"
client.set.default_project(PROJECT_ID)

# untrained_model_id, untrained_model_name, rts_1_id and rts_2_id are placeholders that
# must be defined beforehand: the stored initial model and the Remote Training Systems
# that are allowed to join this aggregation.
aggregator_metadata = {
    client.training.ConfigurationMetaNames.NAME: 'Federated Tensorflow MNIST',
    client.training.ConfigurationMetaNames.DESCRIPTION: 'MNIST digit recognition with Federated Learning using Tensorflow',
    client.training.ConfigurationMetaNames.TRAINING_DATA_REFERENCES: [],
    client.training.ConfigurationMetaNames.TRAINING_RESULTS_REFERENCE: {
        'type': 'container',
        'name': 'outputData',
        'connection': {},
        'location': {
            'path': '/projects/' + PROJECT_ID + '/assets/trainings/'
        }
    },
    client.training.ConfigurationMetaNames.FEDERATED_LEARNING: {
        'model': {
            'type': 'tensorflow',
            'spec': {
                'id': untrained_model_id
            },
            'model_file': untrained_model_name
        },
        'fusion_type': 'iter_avg',
        'metrics': 'accuracy',
        'epochs': 3,
        'rounds': 99,
        'remote_training': {
            'quorum': 1.0,
            'max_timeout': 3600,
            'remote_training_systems': [{'id': rts_1_id}, {'id': rts_2_id}]
        },
        'hardware_spec': {
            'name': 'S'
        },
        'software_spec': {
            'name': 'runtime-22.1-py3.9'
        }
    }
}

# Start the aggregator as an asynchronous training job and keep its ID for the parties.
aggregator = client.training.run(aggregator_metadata, asynchronous=True)
aggregator_id = client.training.get_id(aggregator)
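The remote_training settings above control when a round can proceed. The following is a minimal sketch, assuming that quorum is the fraction of registered Remote Training Systems that must report before the aggregator fuses a round (so 1.0 means every party must respond) and that max_timeout bounds the wait in seconds. The helper quorum_reached is hypothetical and not part of the SDK.

```python
import math
from typing import Iterable, Set


def quorum_reached(registered: Iterable[str], responded: Iterable[str], quorum: float) -> bool:
    """Return True when enough registered parties have reported for this round.

    Assumes quorum is a fraction in (0, 1]; 1.0 requires every party.
    """
    if not 0.0 < quorum <= 1.0:
        raise ValueError("quorum must be in (0, 1]")
    registered_set: Set[str] = set(registered)
    # Ignore replies from parties that are not registered for this aggregator.
    replies = len(registered_set & set(responded))
    required = math.ceil(quorum * len(registered_set))
    return replies >= required


parties = ["rts-1", "rts-2"]  # stand-ins for rts_1_id and rts_2_id
print(quorum_reached(parties, ["rts-1"], 1.0))           # False: one party is still missing
print(quorum_reached(parties, ["rts-1", "rts-2"], 1.0))  # True
print(quorum_reached(parties, ["rts-2"], 0.5))           # True
```

Under this reading, a quorum below 1.0 lets training continue when some parties are slow or offline, at the cost of fusing fewer updates per round.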

Local training

Parties that connect to the aggregator can perform local training.

To perform local training, the parties must be:

  • members of the project or space in which the aggregator is running

  • identified as Remote Training Systems to the Federated Learning aggregator

Configure and start local training

from ibm_watsonx_ai import APIClient

client = APIClient(party_1_credentials)

PROJECT_ID = "8ae1a720-83ed-4c57-b719-8dd086bd7ce0"
client.set.default_project(PROJECT_ID)

# The party needs, at minimum, to specify how the data are loaded for training. The data
# handler class and any input to the class are provided. In this case, the info block
# contains a key to locate the training data relative to the current working directory.
# MNISTDataHandler is a user-defined data handler class that must be defined or imported beforehand.
party_metadata = {
    client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
        "class": MNISTDataHandler,
        "info": {
            "npz_file": "./training_data.npz"
        }
    }
}

# Create the party object.
party = client.remote_training_systems.create_party(
    remote_training_system_id="d516d42c-6c59-41f2-b7ca-c63d11ea79a1",
    party_metadata=party_metadata,
)
# Send training logs to standard output.
party.monitor_logs()
# Start training. Training runs in the Python process that executes this code.
# The supplied aggregator_id refers to the watsonx.ai training job that performs aggregation.
party.run(aggregator_id="564fb126-9bfd-409b-beb3-5d401e4c50ec", asynchronous=False)
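The example above references MNISTDataHandler without defining it. The following is a hypothetical, minimal sketch of what such a class might look like, assuming that the "info" block is passed to the handler's constructor as a dictionary and that get_data() returns ((x_train, y_train), (x_test, y_test)); the array key names inside the .npz file are also assumptions. A real handler may need to extend the data handler base class of the Federated Learning library, which this section does not show. numpy is imported only when the default loader runs, so the class can be exercised with a stand-in loader.

```python
from typing import Any, Callable, Dict, Mapping, Optional, Tuple


def _load_npz(path: str) -> Mapping[str, Any]:
    # Requires numpy; imported lazily so the sketch runs without it when a loader is injected.
    import numpy as np
    return np.load(path)


class MNISTDataHandler:
    """Hypothetical data handler that reads training and test arrays from an .npz file."""

    def __init__(self, data_config: Optional[Dict[str, Any]] = None,
                 loader: Callable[[str], Mapping[str, Any]] = _load_npz) -> None:
        # data_config mirrors the "info" block of the DATA_HANDLER party metadata.
        if not data_config or "npz_file" not in data_config:
            raise ValueError("data_config must contain an 'npz_file' key")
        self.file_name = data_config["npz_file"]
        self._loader = loader

    def get_data(self) -> Tuple[Tuple[Any, Any], Tuple[Any, Any]]:
        # Assumed key names; adjust them to the arrays your .npz file actually contains.
        data = self._loader(self.file_name)
        return (data["x_train"], data["y_train"]), (data["x_test"], data["y_test"])


# Usage with an in-memory stand-in for the .npz contents:
fake = {"x_train": [[0.0]], "y_train": [7], "x_test": [[1.0]], "y_test": [3]}
handler = MNISTDataHandler({"npz_file": "./training_data.npz"}, loader=lambda path: fake)
(x_train, y_train), (x_test, y_test) = handler.get_data()
print(y_train, y_test)  # [7] [3]
```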

class remote_training_system.RemoteTrainingSystem(client)

The RemoteTrainingSystem class represents a Federated Learning party and provides a list of identities that are permitted to join training as the RemoteTrainingSystem.

create_party(remote_training_system_id, party_metadata)

Create a party object using the specified remote training system ID and the party metadata.

Parameters:
  • remote_training_system_id (str) – identifier of the remote training system

  • party_metadata (dict) – the party configuration

Returns:

a party object with the specified remote_training_system_id and configuration

Return type:

Party

Examples

party_metadata = {
    client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
        "info": {
            "npz_file": "./data_party0.npz"
        },
        "name": "MnistTFDataHandler",
        "path": "./mnist_keras_data_handler.py"
    },
    client.remote_training_systems.ConfigurationMetaNames.LOCAL_TRAINING: {
        "name": "LocalTrainingHandler",
        "path": "ibmfl.party.training.local_training_handler"
    },
    client.remote_training_systems.ConfigurationMetaNames.HYPERPARAMS: {
        "epochs": 3
    },
}
party = client.remote_training_systems.create_party(remote_training_system_id, party_metadata)

# Alternatively, pass the data handler class object directly instead of its name and path:
party_metadata = {
    client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
        "info": {
            "npz_file": "./data_party0.npz"
        },
        "class": MnistTFDataHandler
    }
}
party = client.remote_training_systems.create_party(remote_training_system_id, party_metadata)
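The two examples above configure DATA_HANDLER in two ways: by name and module path ("name" and "path") or by passing the class object itself ("class"). A small, hypothetical validation helper can catch mix-ups before create_party() is called. It assumes, based only on these examples, that exactly one of the two forms should be used per party, and it inspects just the inner DATA_HANDLER dictionary, so it runs without the SDK.

```python
from typing import Any, Dict


def data_handler_form(spec: Dict[str, Any]) -> str:
    """Return "class" or "name/path" for a DATA_HANDLER block, or raise ValueError.

    Hypothetical helper; the exactly-one-form rule is an assumption based on the examples.
    """
    has_class = "class" in spec
    has_name_path = "name" in spec and "path" in spec
    if has_class and has_name_path:
        raise ValueError("use either 'class' or 'name'/'path', not both")
    if has_class:
        if not isinstance(spec["class"], type):
            raise ValueError("'class' must be a class object, not a string")
        return "class"
    if has_name_path:
        return "name/path"
    raise ValueError("DATA_HANDLER needs 'class', or both 'name' and 'path'")


class MnistTFDataHandler:  # stand-in for the user's real data handler class
    pass


print(data_handler_form({"info": {"npz_file": "./data_party0.npz"},
                         "name": "MnistTFDataHandler",
                         "path": "./mnist_keras_data_handler.py"}))  # name/path
print(data_handler_form({"info": {"npz_file": "./data_party0.npz"},
                         "class": MnistTFDataHandler}))  # class
```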

create_revision(remote_training_system_id)

Create a new remote training system revision.

Parameters:

remote_training_system_id (str) – unique ID of the remote training system

Returns:

details of the remote training system

Return type:

dict

Example:

client.remote_training_systems.create_revision(remote_training_system_id)

delete(remote_training_systems_id)

Delete the remote training system definition with the given remote_training_systems_id. Either a space_id or a project_id must be provided.

Parameters:

remote_training_systems_id (str) – identifier of the remote training system

Returns:

status (“SUCCESS” or “FAILED”)

Return type:

str

Example:

client.remote_training_systems.delete(remote_training_systems_id='6213cf1-252f-424b-b52d-5cdd9814956c')
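Because delete() reports its outcome as a status string ("SUCCESS" or "FAILED"), a cleanup loop over several remote training systems should check that value for each call. A minimal sketch, assuming a configured client with a default project or space; delete_remote_training_systems is a hypothetical helper name.

```python
from typing import Any, Iterable, List


def delete_remote_training_systems(client: Any, rts_ids: Iterable[str]) -> List[str]:
    """Delete each remote training system and return the IDs whose deletion failed."""
    failed: List[str] = []
    for rts_id in rts_ids:
        status = client.remote_training_systems.delete(remote_training_systems_id=rts_id)
        # delete() returns "SUCCESS" or "FAILED"; collect failures for a retry or a report.
        if status != "SUCCESS":
            failed.append(rts_id)
    return failed
```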

get_details(remote_training_system_id=None, limit=None, asynchronous=False, get_all=False)

Get metadata of a given remote training system. If remote_training_system_id is not provided, metadata is returned for all remote training systems.

Parameters:
  • remote_training_system_id (str, optional) – identifier of the remote training system

  • limit (int, optional) – limit number of fetched records

  • asynchronous (bool, optional) – if True, it will work as a generator

  • get_all (bool, optional) – if True, it will get all entries in ‘limited’ chunks

Returns:

remote training system(s) metadata

Return type:

dict (if remote_training_system_id is not None) or {“resources”: [dict]} (if remote_training_system_id is None)

Examples

details = client.remote_training_systems.get_details(remote_training_systems_id)
details = client.remote_training_systems.get_details()
details = client.remote_training_systems.get_details(limit=100)
details = client.remote_training_systems.get_details(limit=100, get_all=True)
details = []
for entry in client.remote_training_systems.get_details(limit=100, asynchronous=True, get_all=True):
    details.extend(entry)
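With asynchronous=True and get_all=True, get_details() works as a generator that yields records in chunks of at most limit entries, which is why the loop above accumulates them with details.extend(entry). The following self-contained sketch imitates that pattern over an in-memory list. The names iter_chunks and fetch_page are hypothetical, and offset-based paging is an assumption; the real method fetches pages from the service rather than slicing a list.

```python
from typing import Callable, Iterator, List, TypeVar

T = TypeVar("T")


def iter_chunks(fetch_page: Callable[[int, int], List[T]], limit: int) -> Iterator[List[T]]:
    """Yield successive pages of at most `limit` records until the source is exhausted."""
    offset = 0
    while True:
        page = fetch_page(offset, limit)
        if not page:
            return
        yield page
        if len(page) < limit:  # a short page means there is nothing left to fetch
            return
        offset += len(page)


records = [{"id": str(i)} for i in range(250)]  # stand-in for stored systems
details: List[dict] = []
for chunk in iter_chunks(lambda off, lim: records[off:off + lim], limit=100):
    details.extend(chunk)  # same accumulation pattern as the example above
print(len(details))  # 250
```

A generator keeps memory bounded per request and lets the caller stop early, while get_all=True without asynchronous would return everything at once.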

static get_id(remote_training_system_details)

Get the ID of a remote training system.

Parameters:

remote_training_system_details (dict) – metadata of the stored remote training system

Returns:

ID of the stored remote training system

Return type:

str

Example:

details = client.remote_training_systems.get_details(remote_training_system_id)
rts_id = client.remote_training_systems.get_id(details)
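When get_details() is called without an ID, it returns {"resources": [dict]}, so get_id() can be applied to each entry to collect every stored ID. A minimal sketch, assuming a configured client; list_rts_ids is a hypothetical helper name.

```python
from typing import Any, List


def list_rts_ids(client: Any, limit: int = 100) -> List[str]:
    """Return the IDs of stored remote training systems, up to `limit` records."""
    details = client.remote_training_systems.get_details(limit=limit)
    # Without an ID, get_details() returns {"resources": [dict, ...]}.
    return [client.remote_training_systems.get_id(entry) for entry in details["resources"]]
```

The resulting IDs are what the aggregator configuration expects in its remote_training_systems list, for example [{'id': rts_id} for rts_id in list_rts_ids(client)].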

get_revision_details(remote_training_system_id, rev_id)

Get metadata from a specific revision of a stored remote training system.

Parameters:
  • remote_training_system_id (str) – ID of the remote training system

  • rev_id (str) – unique ID of the remote system revision

Returns:

metadata of the stored remote training system revision

Return type:

dict

Example:

details = client.remote_training_systems.get_revision_details(remote_training_system_id, rev_id)

list(limit=None)

List stored remote training systems in a table format. If limit is set to None, only the first 50 records are shown.

Parameters:

limit (int, optional) – limit number of fetched records

Returns:

pandas.DataFrame with listed remote training systems

Return type:

pandas.DataFrame

Example:

client.remote_training_systems.list()

list_revisions(remote_training_system_id, limit=None)

Print all revisions for a given remote_training_system_id in a table format.

Parameters:
  • remote_training_system_id (str) – unique ID of the stored remote training system

  • limit (int, optional) – limit number of fetched records

Returns:

pandas.DataFrame with listed remote training system revisions

Return type:

pandas.DataFrame

Example:

client.remote_training_systems.list_revisions(remote_training_system_id)

store(meta_props)

Create a remote training system. Either a space_id or a project_id must be provided.

Parameters:

meta_props (dict) – metadata, to see available meta names use client.remote_training_systems.ConfigurationMetaNames.get()

Returns:

response JSON

Return type:

dict

Example:

metadata = {
    client.remote_training_systems.ConfigurationMetaNames.NAME: "my-resource",
    client.remote_training_systems.ConfigurationMetaNames.TAGS: ["tag1", "tag2"],
    client.remote_training_systems.ConfigurationMetaNames.ORGANIZATION: {"name": "name", "region": "EU"},
    client.remote_training_systems.ConfigurationMetaNames.ALLOWED_IDENTITIES: [{"id": "43689024", "type": "user"}],
    client.remote_training_systems.ConfigurationMetaNames.REMOTE_ADMIN: {"id": "43689020", "type": "user"}
}
client.set.default_space('3fc54cf1-252f-424b-b52d-5cdd9814987f')
details = client.remote_training_systems.store(meta_props=metadata)
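ALLOWED_IDENTITIES lists the identities that are permitted to join training as this remote training system. When several users should be allowed, the list can be built from their IDs. A minimal sketch, where build_rts_metadata is a hypothetical helper that takes the client's ConfigurationMetaNames object and assumes the {"id": ..., "type": "user"} entry shape shown in the example above.

```python
from typing import Any, Dict, Iterable, Optional


def build_rts_metadata(meta_names: Any, name: str, user_ids: Iterable[str],
                       admin_id: Optional[str] = None) -> Dict[Any, Any]:
    """Build store() metadata that allows each user ID to join as this remote training system."""
    metadata: Dict[Any, Any] = {
        meta_names.NAME: name,
        # One identity entry per permitted user, in the shape used by the example above.
        meta_names.ALLOWED_IDENTITIES: [{"id": uid, "type": "user"} for uid in user_ids],
    }
    if admin_id is not None:
        metadata[meta_names.REMOTE_ADMIN] = {"id": admin_id, "type": "user"}
    return metadata


# Usage, assuming a configured client with a default space or project:
# metadata = build_rts_metadata(client.remote_training_systems.ConfigurationMetaNames,
#                               "my-resource", ["43689024"], admin_id="43689020")
# details = client.remote_training_systems.store(meta_props=metadata)
```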

update(remote_training_system_id, changes)

Update the existing metadata of a remote training system.

Parameters:
  • remote_training_system_id (str) – identifier of the remote training system

  • changes (dict) – elements to be changed, where keys are ConfigurationMetaNames

Returns:

updated remote training system details

Return type:

dict

Example:

metadata = {
    client.remote_training_systems.ConfigurationMetaNames.NAME:"updated_remote_training_system"
}
details = client.remote_training_systems.update(remote_training_system_id, changes=metadata)

class party_wrapper.Party(client=None, **kwargs)

The Party class represents a Federated Learning party and provides methods to run, cancel, and query local training. For more information about creating an instance of the Party class, see the client.remote_training_systems.create_party() API.

cancel()

Stop the local connection to the training on the party side.

Example:

party.cancel()

get_round()

Get the current round number.

Returns:

the current round number

Return type:

int

Example:

party.get_round()

is_running()

Check if the training job is running.

Returns:

True if the job is running, otherwise False

Return type:

bool

Example:

party.is_running()

monitor_logs(log_level='INFO')

Enable logging of the training job to standard output. Call this method before calling the run() method.

Parameters:

log_level (str, optional) – log level specified by user

Example:

party.monitor_logs()

monitor_metrics(metrics_file='-')

Enable output of training metrics.

Parameters:

metrics_file (str, optional) – a filename specified by user to which the metrics should be written

Note

This method writes the metrics to stdout if no filename is specified.

Example:

party.monitor_metrics()

run(aggregator_id=None, experiment_id=None, asynchronous=True, verify=True, timeout=600)

Connect to a Federated Learning aggregator and run local training. Exactly one of aggregator_id and experiment_id must be supplied.

Parameters:
  • aggregator_id (str, optional) – aggregator identifier. If aggregator_id is supplied, the party connects to the given aggregator.

  • experiment_id (str, optional) – experiment identifier. If experiment_id is supplied, the party connects to the most recently created aggregator for the experiment.

  • asynchronous (bool, optional) –

    • True - the party runs the job in the background, and progress can be checked later

    • False - the method waits until training is complete and then prints the job status

  • verify (bool, optional) – verify certificate

  • timeout (int, or None for no timeout) – timeout in seconds. If the aggregator is not ready within the given number of seconds, the call times out.

Examples

party.run(aggregator_id="69500105-9fd2-4326-ad27-7231aeb37ac8", asynchronous=True, verify=True)
party.run(experiment_id="2466fa06-1110-4169-a166-01959adec995", asynchronous=False)
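When run() is called with asynchronous=True, control returns immediately and progress can be checked later. The following is a minimal polling sketch that uses only the documented is_running(), get_round(), and cancel() methods. The names wait_for_party, poll_interval, and max_wait are hypothetical, and max_wait=None means "wait indefinitely", mirroring the timeout convention of run().

```python
import time
from typing import Any, Optional


def wait_for_party(party: Any, poll_interval: float = 30.0,
                   max_wait: Optional[float] = None) -> bool:
    """Poll an asynchronously running party until training ends.

    Returns True if training finished on its own, or False if max_wait elapsed
    and the local connection was cancelled. max_wait=None waits indefinitely.
    """
    start = time.monotonic()
    last_round = None
    while party.is_running():
        current = party.get_round()
        if current != last_round:
            print(f"party is at round {current}")
            last_round = current
        if max_wait is not None and time.monotonic() - start >= max_wait:
            party.cancel()  # stop the local connection on the party side
            return False
        time.sleep(poll_interval)
    return True


# Usage:
# party.run(aggregator_id=aggregator_id, asynchronous=True)
# finished = wait_for_party(party, poll_interval=60, max_wait=4 * 3600)
```

Polling with a bounded wait keeps a notebook or script responsive, while asynchronous=False is simpler when the process can block until training completes.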