Using scikit-learn FedAvg on IRIS dataset

This example illustrates an advanced usage of SubstraFL: it does not use the SubstraFL PyTorch interface, but showcases the general SubstraFL interface, with which you can use any ML framework.

This example is based on:

  • Dataset: IRIS, tabular dataset to classify iris type

  • Model type: Logistic regression using Scikit-Learn

  • FL setup: three organizations, two data providers and one algo provider

This example does not use a deployed Substra platform; it runs in local mode.

To run this example, you have two options:

  • Recommended option: use a hosted Jupyter notebook. With this option you don’t have to install anything, just run the notebook. To access the hosted notebook, scroll down to the bottom of this page and click on the Launch Binder button.

  • Run the example locally. To do that, download and unzip the assets needed to run it into the same directory as this example.

    • Make sure all the required libraries are installed. The zip file includes a requirements.txt file; run pip install -r requirements.txt to install them.

    • Substra and SubstraFL should already be installed. If not, follow the instructions described here: Installation.

Setup

We work with three different organizations. Two organizations provide a dataset, and a third one provides the algorithm and registers the machine learning tasks.

This example runs in local mode, simulating a federated learning experiment.

In the following code cell, we define the different organizations needed for our FL experiment.

import pathlib
import numpy as np

from substra import Client

SEED = 42
np.random.seed(SEED)

# Choose the subprocess mode to locally simulate the FL process
N_CLIENTS = 3
clients = [Client(backend_type="subprocess") for _ in range(N_CLIENTS)]
clients = {client.organization_info().organization_id: client for client in clients}

# Store organization IDs
ORGS_ID = list(clients.keys())
ALGO_ORG_ID = ORGS_ID[0]  # Algo provider is defined as the first organization.
DATA_PROVIDER_ORGS_ID = ORGS_ID[1:]  # Data provider orgs are the last two organizations.

Data and metrics

Data preparation

This section downloads (if needed) the IRIS dataset using the Scikit-Learn dataset module. It extracts the data locally and creates two folders: one for each organization.

Each organization will have access to half the train data, and to half the test data.

import pathlib
from sklearn_fedavg_assets.dataset.iris_dataset import setup_iris


# Create the temporary directory for generated data
(pathlib.Path.cwd() / "tmp").mkdir(exist_ok=True)
data_path = pathlib.Path.cwd() / "tmp" / "data_iris"

setup_iris(data_path=data_path, n_client=len(DATA_PROVIDER_ORGS_ID))
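
The internals of setup_iris are not shown on this page. As a rough illustration only, the sketch below shows one way such a split could work: shuffle the sample indices with a fixed seed, hold out a test fraction, and deal each part evenly across the organizations. The helper name split_indices, the 80/20 ratio and the round-robin dealing are all assumptions made for this sketch, not the actual behavior of setup_iris.

```python
import random


def split_indices(n_samples, n_orgs, test_fraction=0.2, seed=42):
    """Hypothetical helper: map each org index to its train and test indices."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)

    # Hold out the first shuffled indices for testing (assumed ratio).
    n_test = int(n_samples * test_fraction)
    test, train = indices[:n_test], indices[n_test:]

    # Deal indices round-robin so every organization gets an equal share.
    return {
        org: {"train": train[org::n_orgs], "test": test[org::n_orgs]}
        for org in range(n_orgs)
    }


# The Iris dataset has 150 samples; we split it between two data providers.
shards = split_indices(n_samples=150, n_orgs=2)
for org, parts in shards.items():
    print(f"org_{org + 1}: {len(parts['train'])} train, {len(parts['test'])} test")
```

With these assumed settings, each organization receives half of the training indices and half of the test indices, and no index is shared between organizations.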

Dataset registration

from substra.sdk.schemas import DatasetSpec
from substra.sdk.schemas import Permissions
from substra.sdk.schemas import DataSampleSpec

assets_directory = pathlib.Path.cwd() / "sklearn_fedavg_assets"

permissions_dataset = Permissions(public=False, authorized_ids=[ALGO_ORG_ID])

dataset = DatasetSpec(
    name="Iris",
    type="npy",
    data_opener=assets_directory / "dataset" / "iris_opener.py",
    description=assets_directory / "dataset" / "description.md",
    permissions=permissions_dataset,
    logs_permission=permissions_dataset,
)

dataset_keys = {}
train_datasample_keys = {}
test_datasample_keys = {}

for i, org_id in enumerate(DATA_PROVIDER_ORGS_ID):
    client = clients[org_id]

    # Add the dataset to the client to provide access to the opener in each organization.
    dataset_keys[org_id] = client.add_dataset(dataset)
    assert dataset_keys[org_id], "Missing data manager key"

    # Add the training data on each organization.
    data_sample = DataSampleSpec(
        data_manager_keys=[dataset_keys[org_id]],
        path=data_path / f"org_{i+1}" / "train",
    )
    train_datasample_keys[org_id] = client.add_data_sample(
        data_sample,
        local=True,
    )

    # Add the testing data on each organization.
    data_sample = DataSampleSpec(
        data_manager_keys=[dataset_keys[org_id]],
        path=data_path / f"org_{i+1}" / "test",
    )
    test_datasample_keys[org_id] = client.add_data_sample(
        data_sample,
        local=True,
    )

Metrics registration

from sklearn.metrics import accuracy_score
import numpy as np

from substrafl.remote.register import add_metric
from substrafl.dependency import Dependency

metric_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "scikit-learn==1.1.1"])

permissions_metric = Permissions(public=False, authorized_ids=[ALGO_ORG_ID] + DATA_PROVIDER_ORGS_ID)


def accuracy(datasamples, predictions_path):
    y_true = datasamples["targets"]
    y_pred = np.load(predictions_path)

    return accuracy_score(y_true, y_pred)


metric_key = add_metric(
    client=clients[ALGO_ORG_ID],
    metric_function=accuracy,
    permissions=permissions_metric,
    dependencies=metric_deps,
)

Specify the machine learning components

SubstraFL can be used with any machine learning framework. The framework-dependent functions are written in the Algorithm object.

In this section, you will:

  • register a model and its dependencies

  • write your own Sklearn SubstraFL algorithm

  • specify the federated learning strategy

  • specify the organizations where to train and where to aggregate

  • specify the organizations where to test the models

  • actually run the computations

Model definition

The machine learning model used here is a logistic regression. The warm_start argument is essential in this example: it tells scikit-learn to reuse the current state of the model as the initialization for the next call to fit. By default, scikit-learn uses max_iter=100, which lets the solver run for up to 100 iterations per call to fit. In FL, we don’t want to train too much locally at each round; otherwise the local training would erase what was learned from the other organizations. That is why we set max_iter=3.

import os
from sklearn import linear_model

cls = linear_model.LogisticRegression(random_state=SEED, warm_start=True, max_iter=3)

# Optional:
# Scikit-Learn raises warnings in case of non convergence, that we choose to disable here.
# As this example runs with python subprocess, the way to disable it is to use following environment
# variable:
os.environ["PYTHONWARNINGS"] = "ignore:lbfgs failed to converge (status=1):UserWarning"
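
To illustrate why an environment variable is used rather than a call to warnings.filterwarnings, here is a minimal sketch using only the standard library. It launches a child Python process that emits a stand-in UserWarning whose message starts with the same text as the filter above, once with warnings shown and once with the filter. The child code and the run_child helper are hypothetical and exist only for this demonstration; they are not part of Substra.

```python
import os
import subprocess
import sys

# Stand-in for the scikit-learn warning: a UserWarning whose message starts
# with the text targeted by the filter.
CHILD_CODE = (
    "import warnings; "
    "warnings.warn('lbfgs failed to converge (status=1): demo', UserWarning)"
)


def run_child(python_warnings):
    """Run the child process with PYTHONWARNINGS set and return its stderr."""
    env = {**os.environ, "PYTHONWARNINGS": python_warnings}
    result = subprocess.run(
        [sys.executable, "-c", CHILD_CODE],
        env=env,
        capture_output=True,
        text=True,
    )
    return result.stderr


# "default" shows the warning; the "ignore:<message>:<category>" filter hides it.
noisy = run_child("default")
quiet = run_child("ignore:lbfgs failed to converge (status=1):UserWarning")

print("warning shown without the filter:", "lbfgs failed to converge" in noisy)
print("warning shown with the filter:", "lbfgs failed to converge" in quiet)
```

A filter installed with the warnings module applies only to the current interpreter, whereas a child process inherits the parent's environment, so the variable reaches the Python subprocesses as well.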

SubstraFL algo definition

This section is the most important one for this example. We will define here the function that will run locally on each node to train the model.

As SubstraFL does not provide an SklearnFedAvgAlgo, we need to define one using the provided documentation on Base Class.

To define a custom algorithm, we will need to inherit from the base class Algo, and to define two properties and four functions:

  • strategies (property): the list of strategies our algorithm is compatible with.

  • model (property): a property that returns the model from the defined algo.

  • train (function): a function describing the training process to apply to train our model in a federated way. The train method signature must contain the datasamples and shared_state parameters.

  • predict (function): a function describing how to compute the predictions from the algo model. The predict method signature must contain the datasamples, shared_state and predictions_path parameters.

  • save (function): specify how to save the important states of our algo.

  • load (function): specify how to load the important states of our algo from a file previously saved by the save function described above.

from substrafl import algorithms
from substrafl import remote
from substrafl import schemas as fl_schemas

import joblib
from typing import Optional
import shutil

# The Iris dataset proposes four attributes to predict three different classes.
INPUT_SIZE = 4
OUTPUT_SIZE = 3


class SklearnFedAvgAlgo(algorithms.Algo):
    def __init__(self, model, seed=None):
        super().__init__(model=model, seed=seed)

        self._model = model

        if seed is not None:
            np.random.seed(seed)

    @property
    def strategies(self):
        """List of compatible strategies"""
        return [fl_schemas.StrategyName.FEDERATED_AVERAGING]

    @property
    def model(self):
        return self._model

    @remote.remote_data
    def train(
        self,
        datasamples,
        shared_state: Optional[fl_schemas.FedAvgAveragedState] = None,
    ) -> fl_schemas.FedAvgSharedState:
        """The train function to be executed on organizations containing
        data we want to train our model on. The @remote_data decorator is mandatory
        to allow this function to be sent and executed on the right organization.

        Args:
            datasamples (_type_): datasamples extracted from the organizations data using
                the given opener.
            shared_state (Optional[fl_schemas.FedAvgAveragedState], optional):
                shared_state provided by the aggregator. Defaults to None.

        Returns:
            fl_schemas.FedAvgSharedState: State to be sent to the aggregator.
        """

        if shared_state is None:
            # If shared state is None, we are at the init state of the algorithm.
            # We need all different instances of the algorithm to have the same
            # initialization.
            self._model.coef_ = np.ones((OUTPUT_SIZE, INPUT_SIZE))
            self._model.intercept_ = np.zeros(OUTPUT_SIZE)
        else:
            # If we have a shared state, we update the model parameters with
            # the average parameters updates.
            self._model.coef_ += np.reshape(
                shared_state.avg_parameters_update[:-1],
                (OUTPUT_SIZE, INPUT_SIZE),
            )
            self._model.intercept_ += shared_state.avg_parameters_update[-1]

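        # To be able to compute the delta between the parameters before and after training,
        # we need to save them in temporary variables. We copy them so that the saved
        # values cannot be affected by any in-place update during training.
        old_coef = self._model.coef_.copy()
        old_intercept = self._model.intercept_.copy()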

        # Model training.
        self._model.fit(datasamples["data"], datasamples["targets"])

        # We compute the delta.
        delta_coef = self._model.coef_ - old_coef
        delta_bias = self._model.intercept_ - old_intercept

        # We reset the model parameters to their state before training in order to remove
        # the local updates from it.
        self._model.coef_ = old_coef
        self._model.intercept_ = old_intercept

        # We output the length of the dataset, so that the aggregator can apply a
        # weighted average between the organizations according to their number of
        # samples, along with the local parameter updates.
        # These updates are sent to the aggregator to compute the average
        # parameters updates, that we will receive in the next round in the
        # `shared_state`.
        return fl_schemas.FedAvgSharedState(
            n_samples=len(datasamples["targets"]),
            parameters_update=[p for p in delta_coef] + [delta_bias],
        )

    @remote.remote_data
    def predict(self, datasamples, shared_state, predictions_path):
        """The predict function to be executed on organizations containing
        data we want to test our model on. The @remote_data decorator is mandatory
        to allow this function to be sent and executed on the right organization.

        Args:
            datasamples (_type_): datasamples extracted from the organizations data using
                the given opener.
            shared_state (_type_): shared_state provided by the aggregator.
            predictions_path (_type_, optional): Path where to save the predictions.
                This path is provided by Substra and the metric will automatically
                get access to this path to load the predictions.
        """
        predictions = self._model.predict(datasamples["data"])

        if predictions_path is not None:
            np.save(predictions_path, predictions)

            # np.save() automatically adds a ".npy" to the end of the file.
            # We rename the file produced by removing the ".npy" suffix, to make sure that
            # predictions_path is the actual file name.
            shutil.move(str(predictions_path) + ".npy", predictions_path)

    def save(self, path):
        joblib.dump(
            {
                "model": self._model,
                "coef": self._model.coef_,
                "bias": self._model.intercept_,
            },
            path,
        )

    def load(self, path):
        loaded_dict = joblib.load(path)
        self._model = loaded_dict["model"]
        self._model.coef_ = loaded_dict["coef"]
        self._model.intercept_ = loaded_dict["bias"]
        return self
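
The way train packs the parameter updates into a list and unpacks them again in the next round can be checked in isolation. The sketch below uses stand-in arrays with the same shapes as coef_ and intercept_; it assumes that the averaged update received from the aggregator keeps the same list layout as the update that was sent (one array per class row, then the bias vector last).

```python
import numpy as np

OUTPUT_SIZE, INPUT_SIZE = 3, 4

# Stand-in deltas with the same shapes as coef_ and intercept_.
delta_coef = np.arange(OUTPUT_SIZE * INPUT_SIZE, dtype=float).reshape(
    OUTPUT_SIZE, INPUT_SIZE
)
delta_bias = np.array([0.1, 0.2, 0.3])

# Packing, as at the end of train(): one array per class row, bias last.
parameters_update = [row for row in delta_coef] + [delta_bias]

# Unpacking, as in the `else` branch of train().
coef_update = np.reshape(parameters_update[:-1], (OUTPUT_SIZE, INPUT_SIZE))
bias_update = parameters_update[-1]

print(len(parameters_update))  # OUTPUT_SIZE + 1 arrays in the list
print(np.array_equal(coef_update, delta_coef))
print(np.array_equal(bias_update, delta_bias))
```

The list holds OUTPUT_SIZE + 1 arrays, and the round trip recovers both the coefficient matrix and the bias vector unchanged.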

Federated Learning strategies

from substrafl.strategies import FedAvg

strategy = FedAvg()
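
As described in the comments of the train function, the aggregator averages the parameter updates, weighting each organization by its number of samples. The following is a minimal sketch of that idea in plain Python, not SubstraFL's implementation; the weighted_average helper and the two organizations' values are hypothetical.

```python
def weighted_average(updates, n_samples):
    """Average flat parameter updates, weighting each by its sample count."""
    total = sum(n_samples)
    n_params = len(updates[0])
    return [
        sum(update[i] * n for update, n in zip(updates, n_samples)) / total
        for i in range(n_params)
    ]


# Two hypothetical organizations with different dataset sizes.
org_updates = [[1.0, 2.0], [3.0, 6.0]]
org_n_samples = [10, 30]

avg = weighted_average(org_updates, org_n_samples)
print(avg)  # [2.5, 5.0]
```

The organization with 30 samples pulls the average three times as strongly as the one with 10 samples, which is why the result is closer to the second update.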

Where to train and where to aggregate

from substrafl.nodes import TrainDataNode
from substrafl.nodes import AggregationNode


aggregation_node = AggregationNode(ALGO_ORG_ID)

train_data_nodes = list()

for org_id in DATA_PROVIDER_ORGS_ID:

    # Create the Train Data Node (or training task) and save it in a list
    train_data_node = TrainDataNode(
        organization_id=org_id,
        data_manager_key=dataset_keys[org_id],
        data_sample_keys=[train_datasample_keys[org_id]],
    )
    train_data_nodes.append(train_data_node)

Where and when to test

from substrafl.nodes import TestDataNode
from substrafl.evaluation_strategy import EvaluationStrategy


test_data_nodes = list()

for org_id in DATA_PROVIDER_ORGS_ID:

    # Create the Test Data Node (or testing task) and save it in a list
    test_data_node = TestDataNode(
        organization_id=org_id,
        data_manager_key=dataset_keys[org_id],
        test_data_sample_keys=[test_datasample_keys[org_id]],
        metric_keys=[metric_key],
    )
    test_data_nodes.append(test_data_node)

my_eval_strategy = EvaluationStrategy(test_data_nodes=test_data_nodes, rounds=1)

Running the experiment

from substrafl.experiment import execute_experiment

# Number of times to apply the compute plan.
NUM_ROUNDS = 6

algo_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "scikit-learn==1.1.1"])

compute_plan = execute_experiment(
    client=clients[ALGO_ORG_ID],
    algo=SklearnFedAvgAlgo(model=cls, seed=SEED),
    strategy=strategy,
    train_data_nodes=train_data_nodes,
    evaluation_strategy=my_eval_strategy,
    aggregation_node=aggregation_node,
    num_rounds=NUM_ROUNDS,
    experiment_folder=str(pathlib.Path.cwd() / "tmp" / "experiment_summaries"),
    dependencies=algo_deps,
)

Out:

Compute plan progress:   0%|          | 0/48 [00:00<?, ?it/s]
/home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/0.25.0/docs/src/substra/substra/sdk/backends/local/backend.py:578: UserWarning: `transient=True` is ignored in local mode
  warnings.warn("`transient=True` is ignored in local mode")
Compute plan progress: 100%|##########| 48/48 [00:38<00:00,  1.26it/s]

Explore the results

Listing results

import pandas as pd

performances_df = pd.DataFrame(clients[ALGO_ORG_ID].get_performances(compute_plan.key).dict())
print("\nPerformance Table: \n")
print(performances_df[["worker", "round_idx", "performance"]])

Out:

Performance Table:

       worker round_idx  performance
0   MyOrg6MSP         0     0.333333
1   MyOrg7MSP         0     0.133333
2   MyOrg6MSP         1     0.933333
3   MyOrg7MSP         1     1.000000
4   MyOrg6MSP         2     0.933333
5   MyOrg7MSP         2     1.000000
6   MyOrg6MSP         3     0.933333
7   MyOrg7MSP         3     1.000000
8   MyOrg6MSP         4     0.933333
9   MyOrg7MSP         4     1.000000
10  MyOrg6MSP         5     1.000000
11  MyOrg7MSP         5     1.000000
12  MyOrg6MSP         6     1.000000
13  MyOrg7MSP         6     1.000000

Plot results

import matplotlib.pyplot as plt

plt.title("Test dataset results")
plt.xlabel("Rounds")
plt.ylabel("Accuracy")

for org_id in DATA_PROVIDER_ORGS_ID:
    df = performances_df.query(f"worker == '{org_id}'")
    plt.plot(df["round_idx"], df["performance"], label=org_id)

plt.legend(loc="lower right")
plt.show()
Figure: Test dataset results (accuracy per round for each data provider organization)

Download a model

from substrafl.model_loading import download_algo_files
from substrafl.model_loading import load_algo

client_to_download_from = DATA_PROVIDER_ORGS_ID[0]
round_idx = None

algo_files_folder = str(pathlib.Path.cwd() / "tmp" / "algo_files")

download_algo_files(
    client=clients[client_to_download_from],
    compute_plan_key=compute_plan.key,
    round_idx=round_idx,
    dest_folder=algo_files_folder,
)

cls = load_algo(input_folder=algo_files_folder).model

print("Coefs: ", cls.coef_)
print("Intercepts: ", cls.intercept_)

Out:

Coefs:  [[ 1.16237637  1.80062789 -0.59844895  0.16076327]
 [ 1.02009926  0.51773141  1.04883079  0.61084198]
 [ 0.26141703  0.12553336  1.99351081  1.67228741]]
Intercepts:  [ 0.21601049  0.1066958  -0.32270629]
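
To make the printed parameters more concrete, the sketch below applies them by hand to a single sample. For a logistic regression, the predicted class is the one with the highest linear score, that is, the dot product of the features with each row of coefficients plus the matching intercept. The sample values are hypothetical, and we assume the features follow the usual Iris order (sepal length, sepal width, petal length, petal width) and that class indices follow the row order of the coefficients.

```python
# Parameters copied from the output above.
coef = [
    [1.16237637, 1.80062789, -0.59844895, 0.16076327],
    [1.02009926, 0.51773141, 1.04883079, 0.61084198],
    [0.26141703, 0.12553336, 1.99351081, 1.67228741],
]
intercept = [0.21601049, 0.1066958, -0.32270629]

# Hypothetical sample: sepal length, sepal width, petal length, petal width (cm).
sample = [5.1, 3.5, 1.4, 0.2]

# One linear score per class: dot(coefficients, features) + intercept.
scores = [
    sum(w * x for w, x in zip(row, sample)) + b
    for row, b in zip(coef, intercept)
]

# The predicted class is the index of the highest score (0 for this sample).
predicted_class = max(range(len(scores)), key=scores.__getitem__)

print("Scores:", [round(s, 2) for s in scores])
print("Predicted class index:", predicted_class)
```

For this sample the first class gets the highest score, which matches what a call to predict on the loaded model would return under the same assumptions.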

Total running time of the script: (0 minutes 38.440 seconds)
