
Using scikit-learn FedAvg on IRIS dataset

This example illustrates an advanced usage of SubstraFL. It does not use the SubstraFL PyTorch interface. Instead, it showcases the general SubstraFL interface, which you can use with any ML framework.

This example is based on:

  • Dataset: IRIS, tabular dataset to classify iris type

  • Model type: Logistic regression using Scikit-Learn

  • FL setup: three organizations, two data providers and one algo provider

This example does not use the deployed platform of Substra; it runs in local mode.

To run this example, download and unzip the assets it needs into the same directory as this example:

Make sure all the required libraries are installed. The zip file includes a requirements.txt file; run pip install -r requirements.txt to install them.

Substra and SubstraFL should already be installed. If not, follow the instructions described here.


We work with three different organizations. Two organizations provide a dataset, and a third one provides the algorithm and registers the machine learning tasks.

This example runs in local mode, simulating a federated learning experiment.

In the following code cell, we define the different organizations needed for our FL experiment.

import numpy as np

from substra import Client

SEED = 42
np.random.seed(SEED)

# Choose the subprocess mode to locally simulate the FL process
N_CLIENTS = 3  # one algo provider + two data providers
clients_list = [Client(client_name=f"org-{i+1}") for i in range(N_CLIENTS)]
clients = {client.organization_info().organization_id: client for client in clients_list}

# Store organization IDs
ORGS_ID = list(clients)
ALGO_ORG_ID = ORGS_ID[0]  # Algo provider is defined as the first organization.
DATA_PROVIDER_ORGS_ID = ORGS_ID[1:]  # Data provider orgs are the last two organizations.

Data and metrics

Data preparation

This section downloads the IRIS dataset with the scikit-learn dataset module if needed. It then extracts the data locally and creates two folders, one for each organization.

Each organization will have access to half the train data, and to half the test data.

import pathlib
from sklearn_fedavg_assets.dataset.iris_dataset import setup_iris

# Create the temporary directory for generated data
(pathlib.Path.cwd() / "tmp").mkdir(exist_ok=True)
data_path = pathlib.Path.cwd() / "tmp" / "data_iris"

setup_iris(data_path=data_path, n_client=len(DATA_PROVIDER_ORGS_ID))
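The split into one folder per organization is done by setup_iris from the downloaded assets, whose code is not reproduced here. As an illustration only, the sketch below shows one way to give each organization an equal share of a dataset with NumPy. split_among_orgs is a hypothetical helper, and the arrays are placeholders with the IRIS shape (150 samples, 4 features, 3 classes), not the real data. Following the text above, the same kind of split would be applied separately to the train set and to the test set.

```python
import numpy as np


def split_among_orgs(X, y, n_orgs, seed=42):
    """Shuffle the samples, then split them into n_orgs nearly equal shards.

    Hypothetical helper: it illustrates the idea of an equal share per
    organization and is not the implementation of setup_iris.
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(y))  # random order so each shard mixes classes
    shards = np.array_split(order, n_orgs)  # index arrays of nearly equal size
    # Index X and y with the same indices so features and labels stay aligned.
    return [(X[idx], y[idx]) for idx in shards]


# Placeholder arrays with the IRIS shape: 150 samples, 4 features, 3 classes.
X = np.arange(150 * 4, dtype=float).reshape(150, 4)
y = np.repeat([0, 1, 2], 50)

for org, (X_org, y_org) in enumerate(split_among_orgs(X, y, n_orgs=2), start=1):
    print(f"org_{org}: {X_org.shape[0]} samples")
```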

Dataset registration

from substra.sdk.schemas import DatasetSpec
from substra.sdk.schemas import Permissions
from substra.sdk.schemas import DataSampleSpec

assets_directory = pathlib.Path.cwd() / "sklearn_fedavg_assets"

# The dataset is private: besides its owner, only the algo provider is authorized.
permissions_dataset = Permissions(public=False, authorized_ids=[ALGO_ORG_ID])

dataset = DatasetSpec(
    name="Iris",
    data_opener=assets_directory / "dataset" / "",
    description=assets_directory / "dataset" / "",
    permissions=permissions_dataset,
    logs_permission=permissions_dataset,
)

dataset_keys = {}
train_datasample_keys = {}
test_datasample_keys = {}

for i, org_id in enumerate(DATA_PROVIDER_ORGS_ID):
    client = clients[org_id]

    # Add the dataset to the client to provide access to the opener in each organization.
    dataset_keys[org_id] = client.add_dataset(dataset)
    assert dataset_keys[org_id], "Missing data manager key"

    # Add the training data on each organization.
    data_sample = DataSampleSpec(
        data_manager_keys=[dataset_keys[org_id]],
        path=data_path / f"org_{i+1}" / "train",
    )
    train_datasample_keys[org_id] = client.add_data_sample(
        data_sample,
        local=True,
    )

    # Add the testing data on each organization.
    data_sample = DataSampleSpec(
        data_manager_keys=[dataset_keys[org_id]],
        path=data_path / f"org_{i+1}" / "test",
    )
    test_datasample_keys[org_id] = client.add_data_sample(
        data_sample,
        local=True,
    )

Metrics registration

from sklearn.metrics import accuracy_score
import numpy as np

def accuracy(data_from_opener, predictions):
    y_true = data_from_opener["targets"]

    return accuracy_score(y_true, predictions)
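A metric function receives the output of the opener and the predictions of the algorithm. The sketch below calls the same metric outside of Substra on a hand-made input. The dictionary only mimics the "data" and "targets" keys used in this example. It is an assumption about the opener output, not the real opener.

```python
import numpy as np
from sklearn.metrics import accuracy_score


def accuracy(data_from_opener, predictions):
    # Same metric as above: the true labels come from the opener output.
    y_true = data_from_opener["targets"]
    return accuracy_score(y_true, predictions)


# Hypothetical opener output with the "data" and "targets" keys used in this example.
data_from_opener = {"data": np.zeros((4, 4)), "targets": np.array([0, 1, 2, 1])}
predictions = np.array([0, 1, 2, 2])  # the last prediction is wrong

print(accuracy(data_from_opener, predictions))  # 0.75
```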

Specify the machine learning components

SubstraFL can be used with any machine learning framework. The framework-dependent functions are written in the Algorithm object.

In this section, you will:

  • register a model and its dependencies

  • write your own Sklearn SubstraFL algorithm

  • specify the federated learning strategy

  • specify the organizations where to train and where to aggregate

  • specify the organizations where to test the models

  • actually run the computations

Model definition

The machine learning model used here is a logistic regression. The warm_start argument is essential in this example: it tells scikit-learn to use the current state of the model as the initialization for the next call to fit. By default, scikit-learn uses max_iter=100, which means the solver runs for up to 100 iterations over the local data. In federated learning, we don't want to train too much locally at every round. Otherwise, the local training overwrites what was learned from the other organizations. That is why we set max_iter=3.

import os
from sklearn import linear_model

cls = linear_model.LogisticRegression(random_state=SEED, warm_start=True, max_iter=3)

# Optional:
# Scikit-learn warns when the solver does not converge; we choose to silence this warning here.
# As this example runs the tasks in Python subprocesses, the warning has to be disabled
# through the following environment variable:
os.environ["PYTHONWARNINGS"] = "ignore:lbfgs failed to converge (status=1):UserWarning"
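To see what warm_start changes, the sketch below fits the same model twice on the full IRIS dataset bundled with scikit-learn, outside of any federated setup. With warm_start=True, the second call to fit starts from the coefficients left by the first call, so the coefficients keep moving instead of restarting from scratch. max_iter=3 caps each call at three solver iterations. This local, single-process run is only an illustration of the parameter and is not part of the FL workflow.

```python
import warnings

import numpy as np
from sklearn.datasets import load_iris
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)

model = LogisticRegression(random_state=42, warm_start=True, max_iter=3)

with warnings.catch_warnings():
    # Three lbfgs iterations are not enough to converge; silence the warning.
    warnings.simplefilter("ignore", ConvergenceWarning)
    model.fit(X, y)
    coef_after_first_fit = model.coef_.copy()
    # With warm_start=True, this call starts from coef_after_first_fit
    # instead of from zeros, and runs at most 3 more iterations.
    model.fit(X, y)

print("iterations in last fit:", model.n_iter_)
print("coefficients moved:", not np.allclose(coef_after_first_fit, model.coef_))
```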

SubstraFL algo definition

This section is the most important one for this example. Here we define the functions that run locally on each organization to train the model.

As SubstraFL does not provide an algorithm compatible with scikit-learn, we need to define one. We follow the Base Class section of the SubstraFL algorithms API documentation.

To define a custom algorithm, we inherit from the base class Algo and define two properties and four methods:

  • strategies (property): the list of strategies our algorithm is compatible with.

  • model (property): a property that returns the model from the defined algo.

  • train (method): describes the training process used to train our model in a federated way. The train method must accept data_from_opener and shared_state as parameters.

  • predict (method): describes how to compute the predictions from the algo model. The predict method must accept data_from_opener and shared_state as parameters.

  • save_local_state (method): specifies how to save the important states of our algo.

  • load_local_state (method): specifies how to load the important states of our algo from a file previously written by save_local_state.

from typing import Optional

import joblib
import numpy as np

from substrafl import algorithms
from substrafl import remote
from substrafl.strategies import schemas as fl_schemas

# The Iris dataset proposes four attributes to predict three different classes.
INPUT_SIZE = 4
OUTPUT_SIZE = 3


class SklearnLogisticRegression(algorithms.Algo):
    def __init__(self, model, seed=None):
        super().__init__(model=model, seed=seed)

        self._model = model

        # We need all different instances of the algorithm to have the same
        # initialization.
        self._model.coef_ = np.ones((OUTPUT_SIZE, INPUT_SIZE))
        self._model.intercept_ = np.zeros(OUTPUT_SIZE)
        self._model.classes_ = np.array([-1])

        if seed is not None:
            np.random.seed(seed)

    @property
    def strategies(self):
        """List of compatible strategies"""
        return [fl_schemas.StrategyName.FEDERATED_AVERAGING]

    @property
    def model(self):
        return self._model

    @remote.remote_data
    def train(
        self,
        data_from_opener,
        shared_state: Optional[fl_schemas.FedAvgAveragedState] = None,
    ) -> fl_schemas.FedAvgSharedState:
        """The train function to be executed on organizations containing
        data we want to train our model on. The @remote_data decorator is mandatory
        to allow this function to be sent and executed on the right organization.

        Args:
            data_from_opener: data_from_opener extracted from the organizations data using
                the given opener.
            shared_state (Optional[fl_schemas.FedAvgAveragedState], optional):
                shared_state provided by the aggregator. Defaults to None.

        Returns:
            fl_schemas.FedAvgSharedState: State to be sent to the aggregator.
        """
        if shared_state is not None:
            # If we have a shared state, we update the model parameters with
            # the average parameters updates.
            self._model.coef_ += np.reshape(
                shared_state.avg_parameters_update[:-1],
                (OUTPUT_SIZE, INPUT_SIZE),
            )
            self._model.intercept_ += shared_state.avg_parameters_update[-1]

        # To be able to compute the delta between the parameters before and after training,
        # we need to save them in a temporary variable. We copy them so that no later
        # in-place operation can alter the saved values.
        old_coef = self._model.coef_.copy()
        old_intercept = self._model.intercept_.copy()

        # Model training.
        self._model.fit(data_from_opener["data"], data_from_opener["targets"])

        # We compute the delta.
        delta_coef = self._model.coef_ - old_coef
        delta_bias = self._model.intercept_ - old_intercept

        # We reset the model parameters to their state before training in order to remove
        # the local updates from it.
        self._model.coef_ = old_coef
        self._model.intercept_ = old_intercept

        # We output the number of samples, so that the aggregator can weight each
        # organization by the size of its dataset, together with the local
        # parameters updates.
        # These updates are sent to the aggregator to compute the average
        # parameters updates, that we will receive in the next round in the
        # `shared_state`.
        return fl_schemas.FedAvgSharedState(
            n_samples=len(data_from_opener["targets"]),
            parameters_update=[p for p in delta_coef] + [delta_bias],
        )

    def predict(self, data_from_opener, shared_state):
        """The predict function to be executed by the evaluation function on
        data we want to test our model on. The predict method is mandatory and is
        an `abstractmethod` of the `Algo` class.

        Args:
            data_from_opener: data_from_opener extracted from the organizations data using
                the given opener.
            shared_state: shared_state provided by the aggregator.
        """
        predictions = self._model.predict(data_from_opener["data"])

        return predictions

    def save_local_state(self, path):
        # Save the model and its parameters in a single file.
        joblib.dump(
            {
                "model": self._model,
                "coef": self._model.coef_,
                "bias": self._model.intercept_,
            },
            path,
        )

    def load_local_state(self, path):
        # Restore the model and its parameters from the file written by save_local_state.
        loaded_dict = joblib.load(path)
        self._model = loaded_dict["model"]
        self._model.coef_ = loaded_dict["coef"]
        self._model.intercept_ = loaded_dict["bias"]
        return self
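The train method sends its update as a flat list: one array per row of coef_, followed by the bias vector. The averaged update it receives is unpacked with np.reshape on all entries but the last, and with [-1] for the bias. The sketch below isolates this packing convention with NumPy only, so that the round trip can be checked without Substra. pack_update and apply_update are hypothetical helper names, not SubstraFL functions.

```python
import numpy as np

INPUT_SIZE, OUTPUT_SIZE = 4, 3  # IRIS: 4 features, 3 classes


def pack_update(delta_coef, delta_bias):
    # One entry per class row of coef_, followed by the whole bias vector,
    # as in `[p for p in delta_coef] + [delta_bias]` in train().
    return [row for row in delta_coef] + [delta_bias]


def apply_update(coef, intercept, update):
    # Stack the first entries back into an (OUTPUT_SIZE, INPUT_SIZE) matrix and
    # take the bias from the last entry, as train() does with the averaged update.
    new_coef = coef + np.reshape(update[:-1], (OUTPUT_SIZE, INPUT_SIZE))
    new_intercept = intercept + update[-1]
    return new_coef, new_intercept


rng = np.random.default_rng(0)
delta_coef = rng.normal(size=(OUTPUT_SIZE, INPUT_SIZE))
delta_bias = rng.normal(size=OUTPUT_SIZE)

update = pack_update(delta_coef, delta_bias)
print(len(update))  # 4: three coefficient rows and one bias vector

new_coef, new_intercept = apply_update(
    np.ones((OUTPUT_SIZE, INPUT_SIZE)), np.zeros(OUTPUT_SIZE), update
)
```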

Federated Learning strategies

from substrafl.strategies import FedAvg

strategy = FedAvg(algo=SklearnLogisticRegression(model=cls, seed=SEED), metric_functions=accuracy)
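FedAvg combines the parameters_update sent by each organization into an average weighted by n_samples, which is why train returns the number of samples along with the update. The sketch below shows such a sample-weighted average with NumPy. It illustrates the idea only and is not SubstraFL's internal implementation. weighted_average is a hypothetical name.

```python
import numpy as np


def weighted_average(updates, n_samples):
    """Sample-weighted mean of per-organization parameter updates.

    updates: one list of arrays per organization, all with the same structure.
    n_samples: number of training samples of each organization.
    """
    total = sum(n_samples)
    # Average each parameter slot separately across organizations.
    return [
        sum(n * org_update[k] for org_update, n in zip(updates, n_samples)) / total
        for k in range(len(updates[0]))
    ]


# Two organizations, each sending a single-array update.
org_a = [np.array([1.0, 2.0])]
org_b = [np.array([3.0, 6.0])]

# org_a has three times more samples, so the result leans toward its update.
print(weighted_average([org_a, org_b], n_samples=[30, 10]))
```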

Where to train and where to aggregate

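from substrafl.nodes import TrainDataNode
from substrafl.nodes import AggregationNode

# The algo provider aggregates the updates.
aggregation_node = AggregationNode(ALGO_ORG_ID)

# Create the Train Data Nodes (or training tasks) and save them in a list
train_data_nodes = [
    TrainDataNode(
        organization_id=org_id,
        data_manager_key=dataset_keys[org_id],
        data_sample_keys=[train_datasample_keys[org_id]],
    )
    for org_id in DATA_PROVIDER_ORGS_ID
]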

Where and when to test

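from substrafl.nodes import TestDataNode
from substrafl.evaluation_strategy import EvaluationStrategy

# Create the Test Data Nodes (or testing tasks) and save them in a list
test_data_nodes = [
    TestDataNode(
        organization_id=org_id,
        data_manager_key=dataset_keys[org_id],
        data_sample_keys=[test_datasample_keys[org_id]],
    )
    for org_id in DATA_PROVIDER_ORGS_ID
]

# Evaluate the model on the test data nodes at every round.
my_eval_strategy = EvaluationStrategy(test_data_nodes=test_data_nodes, eval_frequency=1)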

Running the experiment

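from substrafl.experiment import execute_experiment
from substrafl.dependency import Dependency

# Number of FL rounds: each round trains locally on every data provider, then aggregates.
NUM_ROUNDS = 6

dependencies = Dependency(pypi_dependencies=["numpy==1.24.3", "scikit-learn==1.3.1"])

compute_plan = execute_experiment(
    client=clients[ALGO_ORG_ID],
    strategy=strategy,
    train_data_nodes=train_data_nodes,
    evaluation_strategy=my_eval_strategy,
    aggregation_node=aggregation_node,
    num_rounds=NUM_ROUNDS,
    experiment_folder=str(pathlib.Path.cwd() / "tmp" / "experiment_summaries"),
    dependencies=dependencies,
    name="IRIS documentation example",
)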
2024-06-14 09:12:33,374 - INFO - Building the compute plan.
Rounds progress: 100%|██████████| 6/6 [00:00<00:00, 1608.14it/s]
2024-06-14 09:12:33,382 - INFO - Registering the functions to Substra.
2024-06-14 09:12:33,417 - INFO - Registering the compute plan to Substra.
2024-06-14 09:12:33,418 - INFO - Experiment summary saved to /home/docs/checkouts/
Compute plan progress:   0%|          | 0/36 [00:00<?, ?it/s]/home/docs/checkouts/ UserWarning: `transient=True` is ignored in local mode
  warnings.warn("`transient=True` is ignored in local mode", stacklevel=1)
Compute plan progress: 100%|██████████| 36/36 [00:39<00:00,  1.10s/it]
2024-06-14 09:13:13,130 - INFO - The compute plan has been registered to Substra, its key is e0e2e590-c19b-4bb7-86ec-0328b622d0e8.

Explore the results

# The results will be available once the compute plan is completed
client.wait_compute_plan(compute_plan.key)
{
    "key": "e0e2e590-c19b-4bb7-86ec-0328b622d0e8",
    "tag": "",
    "name": "IRIS documentation example",
    "owner": "MyOrg1MSP",
    "metadata": {
        "substrafl_version": "0.46.0",
        "substra_version": "0.53.0",
        "substratools_version": "0.21.4",
        "python_version": "3.10.14"
    },
    "task_count": 36,
    "waiting_builder_slot_count": 0,
    "building_count": 0,
    "waiting_parent_tasks_count": 0,
    "waiting_executor_slot_count": 0,
    "executing_count": 0,
    "canceled_count": 0,
    "failed_count": 0,
    "done_count": 36,
    "failed_task_key": null,
    "status": "PLAN_STATUS_DONE",
    "creation_date": "2024-06-14T09:12:33.420725",
    "start_date": "2024-06-14T09:12:33.420727",
    "end_date": "2024-06-14T09:13:13.125512",
    "estimated_end_date": "2024-06-14T09:13:13.125512",
    "duration": 39,
    "creator": null
}

Listing results

import pandas as pd

performances_df = pd.DataFrame(client.get_performances(compute_plan.key).model_dump())
print("\nPerformance Table: \n")
print(performances_df[["worker", "round_idx", "performance"]])

Performance Table:

       worker  round_idx  performance
0   MyOrg2MSP          0     0.000000
1   MyOrg3MSP          0     0.000000
2   MyOrg2MSP          1     0.933333
3   MyOrg3MSP          1     1.000000
4   MyOrg2MSP          2     0.933333
5   MyOrg3MSP          2     1.000000
6   MyOrg2MSP          3     0.933333
7   MyOrg3MSP          3     1.000000
8   MyOrg2MSP          4     0.933333
9   MyOrg3MSP          4     1.000000
10  MyOrg2MSP          5     1.000000
11  MyOrg3MSP          5     1.000000
12  MyOrg2MSP          6     1.000000
13  MyOrg3MSP          6     1.000000
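The performance table is in long format, with one row per organization and round. To compare organizations side by side, it can be pivoted with pandas. The sketch below rebuilds the table from the values printed above, without calling Substra, and adds a hypothetical mean column across organizations.

```python
import pandas as pd

# Values copied from the performance table above (7 evaluations x 2 organizations).
performances_df = pd.DataFrame(
    {
        "worker": ["MyOrg2MSP", "MyOrg3MSP"] * 7,
        "round_idx": [r for r in range(7) for _ in range(2)],
        "performance": [
            0.0, 0.0,
            0.933333, 1.0,
            0.933333, 1.0,
            0.933333, 1.0,
            0.933333, 1.0,
            1.0, 1.0,
            1.0, 1.0,
        ],
    }
)

# One row per round, one column per organization.
by_round = performances_df.pivot(index="round_idx", columns="worker", values="performance")
by_round["mean"] = by_round.mean(axis=1)
print(by_round)
```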

Plot results

import matplotlib.pyplot as plt

plt.title("Test dataset results")
plt.xlabel("Rounds")
plt.ylabel("Accuracy")

# One curve per data provider organization.
for org_id in DATA_PROVIDER_ORGS_ID:
    df = performances_df[performances_df["worker"] == org_id]
    plt.plot(df["round_idx"], df["performance"], label=org_id)

plt.legend(loc="lower right")
plt.show()

Download a model

from substrafl.model_loading import download_algo_state

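client_to_download_from = DATA_PROVIDER_ORGS_ID[0]
# round_idx=None selects the last round of the compute plan.
round_idx = None

algo = download_algo_state(
    client=clients[client_to_download_from],
    compute_plan_key=compute_plan.key,
    round_idx=round_idx,
)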

cls = algo.model

print("Coefs: ", cls.coef_)
print("Intercepts: ", cls.intercept_)
Coefs:  [[ 1.16237637  1.80062789 -0.59844895  0.16076327]
 [ 1.02009926  0.51773141  1.04883079  0.61084198]
 [ 0.26141703  0.12553336  1.99351081  1.67228741]]
Intercepts:  [ 0.21601049  0.1066958  -0.32270629]
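The downloaded model is a regular scikit-learn LogisticRegression, so its coefficients can be used directly. For a multiclass model, predict returns the class whose linear score X @ coef_.T + intercept_ is highest, mapped through classes_. The NumPy sketch below applies that rule to the coefficients printed above. The two flower measurements are illustrative inputs in IRIS feature order (sepal length, sepal width, petal length, petal width), and the returned values are indices into the model's classes_.

```python
import numpy as np

# Coefficients and intercepts printed above (3 classes x 4 features).
coef = np.array(
    [
        [1.16237637, 1.80062789, -0.59844895, 0.16076327],
        [1.02009926, 0.51773141, 1.04883079, 0.61084198],
        [0.26141703, 0.12553336, 1.99351081, 1.67228741],
    ]
)
intercept = np.array([0.21601049, 0.1066958, -0.32270629])


def predict_class_index(X):
    # Linear score of every class; the predicted class has the highest score.
    scores = X @ coef.T + intercept
    return np.argmax(scores, axis=1)


# Two illustrative flowers: (sepal length, sepal width, petal length, petal width).
samples = np.array([[5.1, 3.5, 1.4, 0.2], [6.3, 3.3, 6.0, 2.5]])
print(predict_class_index(samples))  # [0 2]
```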