Using scikit-learn FedAvg on IRIS dataset
This example illustrates an advanced usage of SubstraFL: rather than the SubstraFL PyTorch interface, it uses the general SubstraFL interface, which works with any ML framework.
This example is based on:
Dataset: IRIS, tabular dataset to classify iris type
Model type: Logistic regression using Scikit-Learn
FL setup: three organizations, two data providers and one algo provider
This example does not use the deployed platform of Substra, it runs in local mode.
To run this example, download and unzip the assets it needs into the same directory as this example.
Please make sure all the required libraries are installed. The zip file includes a requirements.txt file, so you can run the command pip install -r requirements.txt
to install them.
Substra and SubstraFL should already be installed. If not, follow the instructions described here.
Setup
We work with three different organizations. Two organizations provide a dataset, and a third one provides the algorithm and registers the machine learning tasks.
This example runs in local mode, simulating a federated learning experiment.
In the following code cell, we define the different organizations needed for our FL experiment.
[1]:
import numpy as np
from substra import Client
SEED = 42
np.random.seed(SEED)
# Choose the subprocess mode to locally simulate the FL process
N_CLIENTS = 3
clients_list = [Client(client_name=f"org-{i+1}") for i in range(N_CLIENTS)]
clients = {client.organization_info().organization_id: client for client in clients_list}
# Store organization IDs
ORGS_ID = list(clients)
ALGO_ORG_ID = ORGS_ID[0] # Algo provider is defined as the first organization.
DATA_PROVIDER_ORGS_ID = ORGS_ID[1:] # Data provider orgs are the last two organizations.
Data and metrics
Data preparation
This section downloads the IRIS dataset (if needed) using the Scikit-Learn dataset module. It extracts the data locally and creates two folders, one for each organization.
Each organization will have access to half the train data, and to half the test data.
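setup_iris ships with the downloaded assets, and its implementation is not shown here. To give a rough picture of what "half the train data and half the test data" means, here is a minimal NumPy sketch. The helper name split_for_orgs, the 20% test fraction and the shuffle-then-split scheme are all assumptions, and setup_iris may do something different.

```python
import numpy as np


def split_for_orgs(X, y, n_orgs, test_fraction=0.2, seed=42):
    """Hypothetical helper: shuffle, hold out a test set, then give each organization an equal share."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(y))
    n_test = int(len(y) * test_fraction)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    shards = []
    # np.array_split gives each organization a (nearly) equal, disjoint slice.
    for org_train, org_test in zip(np.array_split(train_idx, n_orgs), np.array_split(test_idx, n_orgs)):
        shards.append({"train": (X[org_train], y[org_train]), "test": (X[org_test], y[org_test])})
    return shards


# Synthetic stand-in with the IRIS shape: 150 samples, 4 features, 3 classes.
X = np.arange(150 * 4, dtype=float).reshape(150, 4)
y = np.arange(150) % 3
shards = split_for_orgs(X, y, n_orgs=2)
for shard in shards:
    print(shard["train"][0].shape, shard["test"][0].shape)  # (60, 4) (15, 4)
```

Splitting the shuffled indices rather than the arrays themselves keeps every sample in exactly one organization's train or test set.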
[2]:
import pathlib
from sklearn_fedavg_assets.dataset.iris_dataset import setup_iris
# Create the temporary directory for generated data
(pathlib.Path.cwd() / "tmp").mkdir(exist_ok=True)
data_path = pathlib.Path.cwd() / "tmp" / "data_iris"
setup_iris(data_path=data_path, n_client=len(DATA_PROVIDER_ORGS_ID))
Dataset registration
[3]:
from substra.sdk.schemas import DatasetSpec
from substra.sdk.schemas import Permissions
from substra.sdk.schemas import DataSampleSpec
assets_directory = pathlib.Path.cwd() / "sklearn_fedavg_assets"
permissions_dataset = Permissions(public=False, authorized_ids=[ALGO_ORG_ID])
dataset = DatasetSpec(
    name="Iris",
    data_opener=assets_directory / "dataset" / "iris_opener.py",
    description=assets_directory / "dataset" / "description.md",
    permissions=permissions_dataset,
    logs_permission=permissions_dataset,
)
dataset_keys = {}
train_datasample_keys = {}
test_datasample_keys = {}
for i, org_id in enumerate(DATA_PROVIDER_ORGS_ID):
    client = clients[org_id]
    # Add the dataset to the client to provide access to the opener in each organization.
    dataset_keys[org_id] = client.add_dataset(dataset)
    assert dataset_keys[org_id], "Missing data manager key"
    # Add the training data on each organization.
    data_sample = DataSampleSpec(
        data_manager_keys=[dataset_keys[org_id]],
        path=data_path / f"org_{i+1}" / "train",
    )
    train_datasample_keys[org_id] = client.add_data_sample(
        data_sample,
        local=True,
    )
    # Add the testing data on each organization.
    data_sample = DataSampleSpec(
        data_manager_keys=[dataset_keys[org_id]],
        path=data_path / f"org_{i+1}" / "test",
    )
    test_datasample_keys[org_id] = client.add_data_sample(
        data_sample,
        local=True,
    )
Metrics registration
[4]:
from sklearn.metrics import accuracy_score
import numpy as np
def accuracy(data_from_opener, predictions):
    y_true = data_from_opener["targets"]
    return accuracy_score(y_true, predictions)
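The metric receives whatever the opener returns, together with the output of the algo's predict method. Here is a quick local check. It assumes the opener returns a dict with "data" and "targets" keys, which are the keys the code in this example reads. fake_data is a hypothetical stand-in.

```python
import numpy as np
from sklearn.metrics import accuracy_score


def accuracy(data_from_opener, predictions):
    y_true = data_from_opener["targets"]
    return accuracy_score(y_true, predictions)


# Hypothetical opener output: 4 samples with their true labels.
fake_data = {"data": np.zeros((4, 4)), "targets": np.array([0, 1, 2, 2])}
predictions = np.array([0, 1, 1, 2])  # 3 of the 4 predictions are correct
print(accuracy(fake_data, predictions))  # 0.75
```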
Specify the machine learning components
SubstraFL can be used with any machine learning framework. The framework-dependent functions are written in the Algorithm object.
In this section, you will:
register a model and its dependencies
write your own Sklearn SubstraFL algorithm
specify the federated learning strategy
specify the organizations where to train and where to aggregate
specify the organizations where to test the models
actually run the computations
Model definition
The machine learning model used here is a logistic regression. The warm_start argument is essential in this example: it makes each call to fit start from the current state of the model instead of reinitializing it, so every round of local training builds on the parameters received from the aggregator. By default, scikit-learn uses max_iter=100, which lets the solver run up to 100 iterations per call to fit. In federated learning, we don’t want to train too much locally at every round; otherwise, the local training erases what was learned from the other centers. That is why we set max_iter=3.
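The following sketch shows what warm_start changes, using scikit-learn's bundled copy of IRIS. It uses raw, unscaled features, which may differ from the preprocessing done by this example's assets. A warm-started model continues from its previous coefficients. A cold model restarts from scratch, so fitting it twice ends exactly where the first fit ended.

```python
import warnings

import numpy as np
from sklearn.datasets import load_iris
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)

with warnings.catch_warnings():
    # max_iter=3 is deliberately too small to converge, so silence the warning here.
    warnings.simplefilter("ignore", category=ConvergenceWarning)

    warm = LogisticRegression(random_state=42, warm_start=True, max_iter=3)
    warm.fit(X, y)
    coef_after_first_fit = warm.coef_.copy()
    warm.fit(X, y)  # starts from coef_after_first_fit and runs up to 3 more iterations

    cold = LogisticRegression(random_state=42, warm_start=False, max_iter=3)
    cold.fit(X, y)
    cold.fit(X, y)  # restarts from scratch, so it ends where the first fit ended

print(np.allclose(cold.coef_, coef_after_first_fit))  # True
print(np.allclose(warm.coef_, coef_after_first_fit))  # False
```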
[5]:
import os
from sklearn import linear_model
cls = linear_model.LogisticRegression(random_state=SEED, warm_start=True, max_iter=3)
# Optional:
# Scikit-Learn emits a ConvergenceWarning (a UserWarning subclass) when the solver does not converge,
# which we choose to silence here.
# As this example runs the tasks in Python subprocesses, the way to disable it is to set the following
# environment variable, which the subprocesses inherit:
os.environ["PYTHONWARNINGS"] = "ignore:lbfgs failed to converge (status=1):UserWarning"
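PYTHONWARNINGS is read when a Python interpreter starts. Setting it in os.environ therefore only affects subprocesses launched afterwards, not the current notebook kernel. Inside a single process, you can install a roughly equivalent filter with the standard warnings module. Note that filterwarnings treats message as a regular expression, hence the re.escape call for the parentheses. This is an illustrative sketch; the example itself does not require it.

```python
import re
import warnings

# Same rule as the PYTHONWARNINGS value above, expressed as an in-process filter.
LBFGS_MESSAGE = "lbfgs failed to converge (status=1)"

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning by default...
    # ...except those whose message starts with LBFGS_MESSAGE (filterwarnings prepends, so it wins).
    warnings.filterwarnings("ignore", message=re.escape(LBFGS_MESSAGE), category=UserWarning)
    warnings.warn("lbfgs failed to converge (status=1): STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.", UserWarning)
    warnings.warn("some other warning", UserWarning)

print([str(w.message) for w in caught])  # ['some other warning']
```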
SubstraFL algo definition
This section is the most important one for this example. Here we define the functions that will run locally on each organization to train and evaluate the model.
As SubstraFL does not provide an algorithm compatible with Sklearn, we need to define one, following the SubstraFL API documentation of the algorithms base class.
To define a custom algorithm, we need to inherit from the base class Algo, and to define two properties and four methods:
strategies (property): the list of strategies our algorithm is compatible with.
model (property): returns the model of the algo.
train (method): describes the process used to train our model in a federated way. The train method must accept data_from_opener and shared_state as parameters.
predict (method): describes how to compute predictions with the algo's model. The predict method must accept data_from_opener and shared_state as parameters.
save_local_state (method): specifies how to save the important states of our algo.
load_local_state (method): specifies how to load the important states of our algo from a file previously written by save_local_state.
[6]:
from substrafl import algorithms
from substrafl import remote
from substrafl.strategies import schemas as fl_schemas
import joblib
from typing import Optional
# The Iris dataset proposes four attributes to predict three different classes.
INPUT_SIZE = 4
OUTPUT_SIZE = 3
class SklearnLogisticRegression(algorithms.Algo):
    def __init__(self, model, seed=None):
        super().__init__(model=model, seed=seed)
        self._model = model
        # We need all different instances of the algorithm to have the same
        # initialization.
        self._model.coef_ = np.ones((OUTPUT_SIZE, INPUT_SIZE))
        self._model.intercept_ = np.zeros(OUTPUT_SIZE)
        self._model.classes_ = np.array([-1])
        if seed is not None:
            np.random.seed(seed)

    @property
    def strategies(self):
        """List of compatible strategies"""
        return [fl_schemas.StrategyName.FEDERATED_AVERAGING]

    @property
    def model(self):
        return self._model

    @remote.remote_data
    def train(
        self,
        data_from_opener,
        shared_state: Optional[fl_schemas.FedAvgAveragedState] = None,
    ) -> fl_schemas.FedAvgSharedState:
        """The train function to be executed on organizations containing
        data we want to train our model on. The @remote_data decorator is mandatory
        to allow this function to be sent and executed on the right organization.
        Args:
            data_from_opener: data_from_opener extracted from the organization's data using
                the given opener.
            shared_state (Optional[fl_schemas.FedAvgAveragedState], optional):
                shared_state provided by the aggregator. Defaults to None.
        Returns:
            fl_schemas.FedAvgSharedState: State to be sent to the aggregator.
        """
        if shared_state is not None:
            # If we have a shared state, we update the model parameters with
            # the average parameters updates.
            self._model.coef_ += np.reshape(
                shared_state.avg_parameters_update[:-1],
                (OUTPUT_SIZE, INPUT_SIZE),
            )
            self._model.intercept_ += shared_state.avg_parameters_update[-1]
        # To be able to compute the delta between the parameters before and after training,
        # we save a copy of them (not a reference) in temporary variables.
        old_coef = self._model.coef_.copy()
        old_intercept = self._model.intercept_.copy()
        # Model training.
        self._model.fit(data_from_opener["data"], data_from_opener["targets"])
        # We compute the delta.
        delta_coef = self._model.coef_ - old_coef
        delta_bias = self._model.intercept_ - old_intercept
        # We reset the model parameters to their state before training in order to remove
        # the local updates from it.
        self._model.coef_ = old_coef
        self._model.intercept_ = old_intercept
        # We output the length of the dataset to apply a weighted average between
        # the organizations regarding their number of samples, and the local
        # parameters updates.
        # These updates are sent to the aggregator to compute the average
        # parameters updates, that we will receive in the next round in the
        # `shared_state`.
        return fl_schemas.FedAvgSharedState(
            n_samples=len(data_from_opener["targets"]),
            parameters_update=[p for p in delta_coef] + [delta_bias],
        )

    def predict(self, data_from_opener, shared_state):
        """The predict function to be executed by the evaluation function on
        data we want to test our model on. The predict method is mandatory and is
        an `abstractmethod` of the `Algo` class.
        Args:
            data_from_opener: data_from_opener extracted from the organization's data using
                the given opener.
            shared_state: shared_state provided by the aggregator.
        """
        predictions = self._model.predict(data_from_opener["data"])
        return predictions

    def save_local_state(self, path):
        joblib.dump(
            {
                "model": self._model,
                "coef": self._model.coef_,
                "bias": self._model.intercept_,
            },
            path,
        )

    def load_local_state(self, path):
        loaded_dict = joblib.load(path)
        self._model = loaded_dict["model"]
        self._model.coef_ = loaded_dict["coef"]
        self._model.intercept_ = loaded_dict["bias"]
        return self
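train sends its update as a flat list: one array per row of coef_, followed by the intercept vector. On the next round, the same list (now averaged) is turned back into coef_ and intercept_. The sketch below shows that round trip. The helper names pack_update and unpack_update are hypothetical and simply mirror the corresponding lines of train.

```python
import numpy as np

INPUT_SIZE = 4
OUTPUT_SIZE = 3


def pack_update(delta_coef, delta_bias):
    # One array of shape (INPUT_SIZE,) per class, then the bias of shape (OUTPUT_SIZE,).
    return [row for row in delta_coef] + [delta_bias]


def unpack_update(parameters_update):
    # The first OUTPUT_SIZE entries rebuild the coefficient matrix; the last one is the bias.
    delta_coef = np.reshape(parameters_update[:-1], (OUTPUT_SIZE, INPUT_SIZE))
    delta_bias = parameters_update[-1]
    return delta_coef, delta_bias


delta_coef = np.arange(OUTPUT_SIZE * INPUT_SIZE, dtype=float).reshape(OUTPUT_SIZE, INPUT_SIZE)
delta_bias = np.array([0.1, 0.2, 0.3])
update = pack_update(delta_coef, delta_bias)
print(len(update))  # 4: three coefficient rows plus the bias
coef_back, bias_back = unpack_update(update)
```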
Federated Learning strategies
[7]:
from substrafl.strategies import FedAvg
strategy = FedAvg(algo=SklearnLogisticRegression(model=cls, seed=SEED), metric_functions=accuracy)
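As the comments in train explain, each organization reports n_samples alongside its update, so the aggregation can weight every organization by its number of training samples. Below is a plain-NumPy sketch of that weighted average. The function name and the toy numbers are illustrative, and this is not SubstraFL's own implementation.

```python
import numpy as np


def weighted_average(shared_states):
    """shared_states: list of (n_samples, parameters_update) pairs, one per organization."""
    total_samples = sum(n_samples for n_samples, _ in shared_states)
    n_entries = len(shared_states[0][1])
    # Average each entry of the update list separately, weighting by n_samples.
    return [
        sum(n_samples * update[k] for n_samples, update in shared_states) / total_samples
        for k in range(n_entries)
    ]


org_a = (60, [np.array([1.0, 1.0]), np.array([0.0])])
org_b = (20, [np.array([5.0, 5.0]), np.array([4.0])])
avg = weighted_average([org_a, org_b])
print(avg)  # [array([2., 2.]), array([1.])]
```

Because org_a holds three times as many samples as org_b, its update counts three times as much: (60 * 1 + 20 * 5) / 80 = 2.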
Where to train and where to aggregate
[8]:
from substrafl.nodes import TrainDataNode
from substrafl.nodes import AggregationNode
aggregation_node = AggregationNode(ALGO_ORG_ID)
# Create the Train Data Nodes (or training tasks) and save them in a list
train_data_nodes = [
    TrainDataNode(
        organization_id=org_id,
        data_manager_key=dataset_keys[org_id],
        data_sample_keys=[train_datasample_keys[org_id]],
    )
    for org_id in DATA_PROVIDER_ORGS_ID
]
Where and when to test
[9]:
from substrafl.nodes import TestDataNode
from substrafl.evaluation_strategy import EvaluationStrategy
# Create the Test Data Nodes (or testing tasks) and save them in a list
test_data_nodes = [
    TestDataNode(
        organization_id=org_id,
        data_manager_key=dataset_keys[org_id],
        data_sample_keys=[test_datasample_keys[org_id]],
    )
    for org_id in DATA_PROVIDER_ORGS_ID
]
my_eval_strategy = EvaluationStrategy(test_data_nodes=test_data_nodes, eval_frequency=1)
Running the experiment
[10]:
from substrafl.experiment import execute_experiment
from substrafl.dependency import Dependency
# Number of federated learning rounds (local training on each data provider, then aggregation).
NUM_ROUNDS = 6
dependencies = Dependency(pypi_dependencies=["numpy==1.24.3", "scikit-learn==1.3.1"])
compute_plan = execute_experiment(
    client=clients[ALGO_ORG_ID],
    strategy=strategy,
    train_data_nodes=train_data_nodes,
    evaluation_strategy=my_eval_strategy,
    aggregation_node=aggregation_node,
    num_rounds=NUM_ROUNDS,
    experiment_folder=str(pathlib.Path.cwd() / "tmp" / "experiment_summaries"),
    dependencies=dependencies,
    name="IRIS documentation example",
)
2024-04-08 13:38:12,774 - INFO - Building the compute plan.
Rounds progress: 100%|██████████| 6/6 [00:00<00:00, 1735.93it/s]
2024-04-08 13:38:12,781 - INFO - Registering the functions to Substra.
2024-04-08 13:38:12,816 - INFO - Registering the compute plan to Substra.
2024-04-08 13:38:12,817 - INFO - Experiment summary saved to /home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/stable/docs/source/examples/substrafl/go_further/tmp/experiment_summaries/2024_04_08_13_38_12_5dc4970a-cc7f-495e-9980-427aeec7a558.json
Compute plan progress: 100%|██████████| 36/36 [00:39<00:00, 1.09s/it]
2024-04-08 13:38:52,028 - INFO - The compute plan has been registered to Substra, its key is 5dc4970a-cc7f-495e-9980-427aeec7a558.
Explore the results
[11]:
# The results will be available once the compute plan is completed
clients[ALGO_ORG_ID].wait_compute_plan(compute_plan.key)
[11]:
{
"key": "5dc4970a-cc7f-495e-9980-427aeec7a558",
"tag": "",
"name": "IRIS documentation example",
"owner": "MyOrg1MSP",
"metadata": {
"substrafl_version": "0.45.0",
"substra_version": "0.52.0",
"substratools_version": "0.21.3",
"python_version": "3.10.14"
},
"task_count": 36,
"waiting_builder_slot_count": 0,
"building_count": 0,
"waiting_parent_tasks_count": 0,
"waiting_executor_slot_count": 0,
"executing_count": 0,
"canceled_count": 0,
"failed_count": 0,
"done_count": 36,
"failed_task_key": null,
"status": "PLAN_STATUS_DONE",
"creation_date": "2024-04-08T13:38:12.820215",
"start_date": "2024-04-08T13:38:12.820217",
"end_date": "2024-04-08T13:38:52.023306",
"estimated_end_date": "2024-04-08T13:38:52.023306",
"duration": 39,
"creator": null
}
Listing results
[12]:
import pandas as pd
performances_df = pd.DataFrame(clients[ALGO_ORG_ID].get_performances(compute_plan.key).model_dump())
print("\nPerformance Table: \n")
print(performances_df[["worker", "round_idx", "performance"]])
Performance Table:
worker round_idx performance
0 MyOrg2MSP 0 0.000000
1 MyOrg3MSP 0 0.000000
2 MyOrg2MSP 1 0.933333
3 MyOrg3MSP 1 1.000000
4 MyOrg2MSP 2 0.933333
5 MyOrg3MSP 2 1.000000
6 MyOrg2MSP 3 0.933333
7 MyOrg3MSP 3 1.000000
8 MyOrg2MSP 4 0.933333
9 MyOrg3MSP 4 1.000000
10 MyOrg2MSP 5 1.000000
11 MyOrg3MSP 5 1.000000
12 MyOrg2MSP 6 1.000000
13 MyOrg3MSP 6 1.000000
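To summarize both test sets with one number per round, you can group the table by round_idx. The sketch below rebuilds the table above by hand so that it runs on its own; in the notebook, the same groupby applies directly to performances_df. It computes an unweighted mean of the two organizations' accuracies, which matches a sample-weighted mean only when both test sets have the same size.

```python
import pandas as pd

# Values copied from the performance table above (rounds 0 to 6, two test organizations).
performances_df = pd.DataFrame(
    {
        "worker": ["MyOrg2MSP", "MyOrg3MSP"] * 7,
        "round_idx": [round_idx for round_idx in range(7) for _ in range(2)],
        "performance": [
            0.0, 0.0,
            0.933333, 1.0,
            0.933333, 1.0,
            0.933333, 1.0,
            0.933333, 1.0,
            1.0, 1.0,
            1.0, 1.0,
        ],
    }
)
mean_per_round = performances_df.groupby("round_idx")["performance"].mean()
print(mean_per_round)
```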
Plot results
[13]:
import matplotlib.pyplot as plt
plt.title("Test dataset results")
plt.xlabel("Rounds")
plt.ylabel("Accuracy")
for org_id in DATA_PROVIDER_ORGS_ID:
    df = performances_df[performances_df["worker"] == org_id]
    plt.plot(df["round_idx"], df["performance"], label=org_id)
plt.legend(loc="lower right")
plt.show()
Download a model
[14]:
from substrafl.model_loading import download_algo_state
client_to_download_from = DATA_PROVIDER_ORGS_ID[0]
round_idx = None  # None selects the last round of the compute plan
algo = download_algo_state(
    client=clients[client_to_download_from],
    compute_plan_key=compute_plan.key,
    round_idx=round_idx,
)
cls = algo.model
print("Coefs: ", cls.coef_)
print("Intercepts: ", cls.intercept_)
Coefs: [[ 1.16237637 1.80062789 -0.59844895 0.16076327]
[ 1.02009926 0.51773141 1.04883079 0.61084198]
[ 0.26141703 0.12553336 1.99351081 1.67228741]]
Intercepts: [ 0.21601049 0.1066958 -0.32270629]
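For a multinomial logistic regression, predict returns the class whose linear score x @ coef_.T + intercept_ is largest, and predict_proba applies a softmax to those scores. The sketch below reproduces this by hand from the printed parameters. The input sample is the first IRIS flower as raw measurements in cm, which assumes the opener feeds unscaled features; this page does not show that. The returned index is a position in cls.classes_, not necessarily a label value.

```python
import numpy as np

# Coefficients and intercepts printed above.
coef = np.array([
    [1.16237637, 1.80062789, -0.59844895, 0.16076327],
    [1.02009926, 0.51773141, 1.04883079, 0.61084198],
    [0.26141703, 0.12553336, 1.99351081, 1.67228741],
])
intercept = np.array([0.21601049, 0.1066958, -0.32270629])


def predict_class_index(x):
    # One linear score per class; the largest one wins.
    scores = x @ coef.T + intercept
    return int(np.argmax(scores))


def predict_proba(x):
    # Softmax over the class scores (shifted by the max for numerical stability).
    scores = x @ coef.T + intercept
    exp_scores = np.exp(scores - scores.max())
    return exp_scores / exp_scores.sum()


sample = np.array([5.1, 3.5, 1.4, 0.2])
print(predict_class_index(sample))  # 0
print(predict_proba(sample))
```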