#!/usr/bin/env python3
# -----------------------------------------------------------------------------------------
# (C) Copyright IBM Corp. 2023-2025.
# https://opensource.org/licenses/BSD-3-Clause
# -----------------------------------------------------------------------------------------
from __future__ import annotations
import importlib.util
import json
import logging
import os
import platform
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any
import ibm_watsonx_ai._wrappers.requests as requests
import requests as req
from ibm_watsonx_ai.utils.utils import get_module_version
from ibm_watsonx_ai.wml_client_error import ApiRequestFailure
from ibm_watsonx_ai.wml_resource import WMLResource
if TYPE_CHECKING:
from ibm_watsonx_ai import APIClient
from ibmfl.party.party import Party as IBMFL_Party
logger = logging.getLogger(__name__)
CRYPTO_LIBRARY = "pyhelayers"
def is_crypto_supported() -> bool:
# pyhelayers is used for client side encryption
# Test is pyhelayers is installed, and if so we assume crypto is supported
pyhelayers_spec = importlib.util.find_spec(CRYPTO_LIBRARY)
return pyhelayers_spec is not None
def import_diff(module_file_path):
if "runtime-23-1" in module_file_path:
return
pathlist = Path(module_file_path).rglob("*.py")
for path in pathlist:
# because path is object not string
path_in_str = str(path)
if not path_in_str.endswith("__init__.py"):
module_name = (
"ibmfl" + path_in_str.split("ibmfl")[2].replace(os.sep, ".")
)[:-3]
module_spec = importlib.util.spec_from_file_location(
module_name, path_in_str
)
loader = importlib.util.LazyLoader(module_spec.loader) # type: ignore
module_spec.loader = loader # type: ignore[union-attr]
module = importlib.util.module_from_spec(module_spec) # type: ignore[arg-type]
sys.modules[module_name] = module
module_spec.loader.exec_module(module) # type: ignore
def check_python_framework_version(client_reqs: dict[str, Any]) -> None:
client_system = platform.system()
client_processor = platform.processor()
if "fl_extras" in client_reqs:
install_msg = (
" You can install this and other required packages by running pip install --upgrade 'ibm-watsonx-ai["
+ client_reqs["fl_extras"]
+ "]'"
)
else:
install_msg = " See documentation for more information."
py_version = platform.python_version()
logger.info("Detected Client Python Version: {}".format(py_version))
if not py_version.startswith(client_reqs["py_version"]):
raise Exception(
"The selected software spec requires python=={}.".format(
client_reqs["py_version"]
)
)
if client_system == "Darwin" and client_processor == "arm":
try:
tensorflowmacos_version = get_module_version(lib_name="tensorflow-macos")
logger.info(
"Detected tensorflow-macos Version: {}".format(tensorflowmacos_version)
)
if not tensorflowmacos_version.startswith(
client_reqs["tensorflow_version"]
):
raise Exception("Incompatible tensorflow-macos version found")
except Exception as ex:
logger.warning("{}, this may cause unexpected errors.".format(ex))
logger.warning(
"The selected software spec requires tensorflow-macos=={}.".format(
client_reqs["tensorflow_version"]
)
)
else:
try:
tensorflow_version = get_module_version(lib_name="tensorflow")
logger.info("Detected tensorflow Version: {}".format(tensorflow_version))
if not tensorflow_version.startswith(client_reqs["tensorflow_version"]):
raise Exception("Incompatible tensorflow version found")
except Exception as ex:
logger.warning("{}, this may cause unexpected errors.".format(ex))
logger.warning(
"The selected software spec requires tensorflow=={}.".format(
client_reqs["tensorflow_version"]
)
+ install_msg
)
try:
torch_version = get_module_version(lib_name="torch")
logger.info("Detected torch Version: {}".format(torch_version))
if not torch_version.startswith(client_reqs["torch_version"]):
raise Exception("Incompatible torch version found")
except Exception as ex:
logger.warning("{}, this may cause unexpected errors.".format(ex))
logger.warning(
"The selected software spec requires torch=={}.".format(
client_reqs["torch_version"]
)
+ install_msg
)
try:
scikitlearn_version = get_module_version(lib_name="scikit-learn")
logger.info("Detected scikit-learn Version: {}".format(scikitlearn_version))
if not scikitlearn_version.startswith(client_reqs["scikitlearn_version"]):
raise Exception("Incompatible scikit-learn version found")
except Exception as ex:
logger.warning("{}, this may cause unexpected errors.".format(ex))
logger.warning(
"The selected software spec requires scikit-learn=={}.".format(
client_reqs["scikitlearn_version"]
)
+ install_msg
)
def choose_software_version(software_spec: str) -> str:
logger.info("Aggregator Software Spec: {}".format(software_spec))
if (
software_spec == "runtime-23.1-py3.10"
or software_spec == "336b29df-e0e1-5e7d-b6a5-f6ab722625b2"
):
client_reqs = {
"py_version": "3.10",
"tensorflow_version": "2.12",
"torch_version": "2.0",
"scikitlearn_version": "1.1",
"fl_extras": "fl-rt23.1-py3.10",
}
check_python_framework_version(client_reqs)
return "runtime-23-1"
elif (
software_spec == "runtime-24.1-py3.11"
or software_spec == "45f12dfe-aa78-5b8d-9f38-0ee223c47309"
):
client_reqs = {
"py_version": "3.11",
"tensorflow_version": "2.14",
"torch_version": "2.1",
"scikitlearn_version": "1.3",
"fl_extras": "fl-rt24.1-py3.11",
}
check_python_framework_version(client_reqs)
return "runtime-24-1"
else:
client_reqs = {
"py_version": "3.10",
"tensorflow_version": "2.12",
"torch_version": "2.0",
"scikitlearn_version": "1.1",
"fl_extras": "fl-rt23.1-py3.10",
}
check_python_framework_version(client_reqs)
return "runtime-23-1"
fl_path = os.path.abspath(".")
if fl_path not in sys.path:
sys.path.append(fl_path)
[docs]
class Party(WMLResource):
"""The Party class embodies a Federated Learning party with methods to run, cancel, and query local training.
Refer to the ``client.remote_training_system.create_party()`` API for more information about creating an
instance of the Party class.
"""
base_platform = "runtime-23-1"
default_software_spec = "runtime-23.1-py3.10"
SUPPORTED_PLATFORMS_MAP = {
base_platform: "/runtime-23-1/ibmfl", # default
"runtime-24-1": "/runtime-24-1/ibmfl",
}
def __init__(self, client: APIClient | None = None, **kwargs: Any) -> None:
libs_module = sys.modules["ibm_watsonx_ai.libs"]
libs_location_list = libs_module.__path__
# base location string, default to cloud location
ibmfl_base_module_location = libs_location_list[0] + "/ibmfl"
# process location
ibmfl_module_location = (
ibmfl_base_module_location
+ self.SUPPORTED_PLATFORMS_MAP.get(self.base_platform)
)
# check if using old connector script which is removed
if not client:
raise Exception(
"This version of the party connector script is outdated. "
"Please download the party connector script from your current Federated Learning experiment. "
"For more details, please refer to the documentation."
)
self.module_location = ibmfl_module_location
self.args = kwargs
self.Party: IBMFL_Party | None = None
self.connection = None
self.log_level = None
self.metrics_output: str | None = None
if "ibmfl" in sys.modules:
del sys.modules["ibmfl"]
if "ibmfl.party" in sys.modules:
del sys.modules["ibmfl.party"]
if "ibmfl.party.party" in sys.modules:
del sys.modules["ibmfl.party.party"]
# install the general lib
module_name = "ibmfl"
module_spec = importlib.util.spec_from_file_location(
module_name,
ibmfl_base_module_location
+ "/"
+ self.base_platform
+ "/ibmfl/__init__.py",
)
module = importlib.util.module_from_spec(module_spec) # type: ignore[arg-type]
sys.modules[module_name] = module
module_spec.loader.exec_module(module) # type: ignore
WMLResource.__init__(self, __name__, client)
self._client = client
self.auth_token = "Bearer " + self._client.token
self.project_id = self._client.default_project_id
self.log_level = kwargs.get("log_level", "ERROR")
def start(self) -> None:
self.Party.start() # type: ignore
[docs]
def run(
self,
aggregator_id: str | None = None,
experiment_id: str | None = None,
asynchronous: bool = True,
verify: bool = True,
timeout: int = 60 * 10,
) -> None:
"""Connect to a Federated Learning aggregator and run local training.
Exactly one of `aggregator_id` and `experiment_id` must be supplied.
:param aggregator_id: aggregator identifier
* If aggregator_id is supplied, the party will connect to the given aggregator
:type aggregator_id: str, optional
:param experiment_id: experiment identifier
* If experiment_id is supplied, the party will connect to the most recently created aggregator
for the experiment.
:type experiment_id: str, optional
:param asynchronous:
* `True` - party starts to run the job in the background and progress can be checked later
* `False` - method will wait until training is complete and then print the job status
:type asynchronous: bool, optional
:param verify: verify certificate
:type verify: bool, optional
:param timeout: timeout in seconds
* If the aggregator is not ready within a provided number of seconds, there is a timeout.
:type timeout: int, or None for no timeout
**Examples**
.. code-block:: python
party.run( aggregator_id = "69500105-9fd2-4326-ad27-7231aeb37ac8", asynchronous = True, verify = True )
party.run( experiment_id = "2466fa06-1110-4169-a166-01959adec995", asynchronous = False )
"""
import time
from datetime import datetime
from ibmfl.exceptions import FLException
timeout_time = None if timeout is None else timeout + time.time()
if (experiment_id is None and aggregator_id is None) or (
experiment_id is not None and aggregator_id is not None
):
raise FLException(
"Exactly one of aggregator_id and experiment_id must be supplied"
)
if experiment_id is not None:
while True:
try:
details = self._client.training.get_details(
get_all=True,
training_definition_id=experiment_id,
_internal=True,
)["resources"]
details = [
d
for d in details
if d["entity"]["status"]["state"]
in ["accepting_parties", "pending", "running"]
]
if not details:
if timeout_time and timeout_time < time.time():
raise FLException(
"Cannot find an aggregator for experiment %s",
experiment_id,
)
else:
logger.info(
"Cannot find an aggregator for experiment %s. Retrying.",
experiment_id,
)
time.sleep(30)
continue
else:
aggregator_id = max(
[
(
t["metadata"]["id"],
datetime.strptime(
t["metadata"]["created_at"],
"%Y-%m-%dT%H:%M:%S.%fZ",
),
)
for t in details
],
key=lambda d: d[1],
)[0]
logger.info("Using aggregator id %s", aggregator_id)
break
except Exception as ex:
logger.exception(ex)
raise FLException(str(ex))
# If crypto is enabled, test if this client supports it
# WML training service uses the fusion_type field to determine if encryption is required.
try:
aggregator_details = self._client.training.get_details(
training_id=aggregator_id, _internal=True
)
except ApiRequestFailure as exc:
logger.error("Failed get training details")
raise FLException("Failed get training details") from exc
# Response might not be for a federated learning training
try:
fusion_type = aggregator_details["entity"]["federated_learning"][
"fusion_type"
]
except (KeyError, TypeError) as exc:
logger.error("Failed to read fusion type from training details")
raise FLException(
"Failed to read fusion type from training details"
) from exc
if fusion_type == "crypto_iter_avg":
logger.info("This training requires encryption")
if not is_crypto_supported():
logger.error(
"Encryption is required, but the '%s' module required for encryption was not found",
CRYPTO_LIBRARY,
)
raise FLException(f"'{CRYPTO_LIBRARY}' module not found")
# Import the changed files in the desired lib version
config_dict = self.args.get("config_dict", {})
metrics_config = {
"name": "WMLMetricsRecorder",
"path": "ibmfl.party.metrics.metrics_recorder",
"output_file": self.metrics_output,
"output_type": "json",
"compute_pre_train_eval": False,
"compute_post_train_eval": False,
}
if "metrics_recorder" not in config_dict:
config_dict["metrics_recorder"] = metrics_config
wml_services_url = config_dict.get("aggregator").get("ip").split("/")[0]
agg_info = wml_services_url + "/ml/v4/trainings/" + aggregator_id
config_dict["aggregator"]["ip"] = agg_info
self.args["config_dict"] = config_dict
# Verify ssl context
if verify:
try:
req.get(
"https://" + wml_services_url + "/wml_services/training/heartbeat",
verify=verify,
)
except requests.exceptions.SSLError as ex:
logger.error(str(ex))
raise FLException(
"No valid certificate detected. Please replace the default certificate with your own "
"TLS certificate, or set verify to False at your own risk. For more details, please see "
"https://www.ibm.com/docs/en/cloud-paks/cp-data/4.0?topic=client-using-custom-tls-certificate-connect-platform"
)
# Check for aggregator state and start job
try:
training_status = self._client.training.get_status(aggregator_id)
state = training_status["state"]
ready = False
if state == "pending":
while state == "pending" and (
not timeout_time or timeout_time > time.time()
):
logger.info("Waiting for aggregator accepting parties state..")
time.sleep(10)
training_status = self._client.training.get_status(aggregator_id)
state = training_status["state"]
if state != "accepting_parties":
raise FLException(
"The current state of training %s is %s, so the party is not able to start a job."
% (aggregator_id, state)
)
ready = True
elif state == "running" or state == "accepting_parties":
ready = True
else:
raise FLException(
"The current state of training %s is %s, so the party is not able to start a job."
% (aggregator_id, state)
)
if ready:
details = self._client.training.get_details(
aggregator_id, _internal=True
)
fl_entity = details["entity"]["federated_learning"]
if "software_spec" in fl_entity:
software_spec = (
fl_entity["software_spec"]["name"]
if "name" in fl_entity["software_spec"]
else fl_entity["software_spec"]["id"]
)
else:
software_spec = self.default_software_spec
platform_env = choose_software_version(software_spec)
logger.info("Loading {} environment..".format(platform_env))
if "system" in details and "warnings" in details["system"]:
for warning in details["system"]["warnings"]:
logger.info("Warning: {}".format(warning["message"]))
self.module_location = "/".join(
self.module_location.split("/")[0:-2]
) + self.SUPPORTED_PLATFORMS_MAP.get(platform_env)
import_diff(self.module_location)
from ibmfl.party.party import Party
self.Party = Party(
**self.args,
token=self.auth_token,
self_signed_cert=not verify,
log_level=self.log_level,
)
self.connection = self.Party.connection # type: ignore[attr-defined]
self.start()
else:
raise FLException(
"The current state of training %s is %s, so the party is not able to start a job."
% (aggregator_id, state)
)
# wait for the job to finish if synchrounous
if not asynchronous:
while (
"completed" != state
and "failed" != state
and "canceled" != state
and self.is_running()
):
training_status = self._client.training.get_status(aggregator_id)
state = training_status["state"]
time.sleep(10)
logger.info("The training finishes with %s status" % state)
except FLException as ex:
raise FLException(str(ex)) from ex
except Exception as ex:
logger.info("The party failed to start training")
logger.exception(ex)
[docs]
def monitor_logs(self, log_level: str = "INFO") -> None:
"""Enable logging of the training job to standard output.
Call this method before calling the ``run()`` method.
:param log_level: log level specified by user
:type log_level: str, optional
**Example:**
.. code-block:: python
party.monitor_logs()
"""
self.log_level = log_level
# Configure logging locally as well
from ibmfl.util.config import configure_logging_from_file
configure_logging_from_file(log_level=log_level)
logger.setLevel(log_level)
[docs]
def monitor_metrics(self, metrics_file: str = "-") -> None:
"""Enable output of training metrics.
:param metrics_file: a filename specified by user to which the metrics should be written
:type metrics_file: str, optional
.. note::
This method outputs the metrics to stdout if a filename is not specified
**Example:**
.. code-block:: python
party.monitor_metrics()
"""
self.metrics_output = metrics_file
[docs]
def is_running(self) -> bool:
"""Check if the training job is running.
:return: if the job is running
:rtype: bool
**Example:**
.. code-block:: python
party.is_running()
"""
return not self.connection.stopped # type: ignore[attr-defined]
[docs]
def get_round(self) -> int:
"""Get the current round number.
:return: the current round number
:rtype: int
**Example:**
.. code-block:: python
party.get_round()
"""
return self.Party.proto_handler.metrics_recorder.get_round_no() # type: ignore
[docs]
def cancel(self) -> None:
"""Stop the local connection to the training on the party side.
**Example:**
.. code-block:: python
party.cancel()
"""
self.Party.stop_connection() # type: ignore