Substrafl FedAvg on MNIST dataset

This example illustrates the basic usage of Substrafl and proposes training a model with Federated Learning, using the Federated Averaging strategy.

It is based on the MNIST Dataset of handwritten digits.

In this example, we work on the grayscale images of size 28x28 pixels. The problem considered is a classification problem aiming to recognize the number written on each image.

The objective of this example is to launch a federated learning experiment on two organizations, using the FedAvg strategy on a convolutional neural network (CNN) torch model.

This example does not use a deployed Substra platform and runs in local mode.


  • To run this example locally, make sure to download the assets needed to run it and unzip them in the same directory as this example:

    assets required to run this example

    Make sure all the required libraries are installed. A requirements.txt file is included in the zip file; you can install them with the command: pip install -r requirements.txt

  • Substra and Substrafl should already be installed. If not, follow the instructions described here: Installation

Client and data preparation


import codecs
import os
import pathlib
import sys
import zipfile

import numpy as np
from torchvision.datasets import MNIST

Creating the Substra Client

We work with two different organizations, defined by their IDs. Both organizations provide a dataset. One of them also provides the algorithm and registers the machine learning tasks.

Once these variables are defined, we can create our Substra clients.

This example runs in local mode, simulating a federated learning experiment.

from substra import Client

# Number of organizations taking part in the experiment
N_CLIENTS = 2

# Choose the subprocess mode to locally simulate the FL process
clients = [Client(backend_type="subprocess") for _ in range(N_CLIENTS)]
clients = {client.organization_info().organization_id: client for client in clients}

# Store their IDs
ORGS_ID = list(clients.keys())

# The org id on which your computation tasks are registered
# (any of the organizations can play this role; the first one is used here)
ALGO_ORG_ID = ORGS_ID[0]

# Create the temporary directory for generated data
(pathlib.Path.cwd() / "tmp").mkdir(exist_ok=True)

data_path = pathlib.Path.cwd() / "tmp" / "data"
assets_directory = pathlib.Path.cwd() / "assets"

Download and extract MNIST dataset

This section downloads (if needed) the MNIST dataset using the torchvision library. It extracts the images from the raw files and locally creates two folders: one for each organization.

Each organization has access to half of the training data and half of the test data (which corresponds to 30,000 training images and 5,000 test images each).
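The raw MNIST files use the IDX format: a 4-byte big-endian magic number whose third byte is the value type (`0x08` for unsigned bytes) and whose last byte is the number of dimensions, followed by one 4-byte big-endian size per dimension, then the raw values. The sketch below builds a tiny IDX buffer in memory with `struct` and decodes its header the same way `MNISTraw2numpy` does (`magic % 256` gives the number of dimensions). The helpers `make_idx` and `parse_idx` are hypothetical and only handle the unsigned-byte type used by MNIST.

```python
import struct

import numpy as np


def make_idx(array: np.ndarray) -> bytes:
    """Hypothetical helper: serialize a uint8 array to the IDX layout."""
    # Magic number: two zero bytes, type code 0x08 (unsigned byte), number of dimensions.
    header = struct.pack(">BBBB", 0, 0, 0x08, array.ndim)
    # One big-endian 32-bit size per dimension.
    header += struct.pack(f">{array.ndim}I", *array.shape)
    return header + array.astype(np.uint8).tobytes()


def parse_idx(data: bytes) -> np.ndarray:
    """Hypothetical helper: decode an unsigned-byte IDX buffer."""
    magic = int.from_bytes(data[0:4], "big")
    nd = magic % 256  # last byte of the magic number = number of dimensions
    shape = struct.unpack(f">{nd}I", data[4 : 4 * (nd + 1)])
    values = np.frombuffer(data, dtype=np.uint8, offset=4 * (nd + 1))
    assert values.size == int(np.prod(shape)), "truncated IDX payload"
    return values.reshape(shape)


images = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)
buffer = make_idx(images)
print(int.from_bytes(buffer[0:4], "big"))  # 2051, the magic number of MNIST image files
print(parse_idx(buffer).shape)  # (2, 3, 3)
```

Since every value is a single byte, no byte-order conversion is needed for the payload; only the header integers are big-endian.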

def get_int(b: bytes) -> int:
    return int(codecs.encode(b, "hex"), 16)

def MNISTraw2numpy(path: str, strict: bool = True) -> np.ndarray:
    # read
    with open(path, "rb") as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])
    nd = magic % 256
    assert 1 <= nd <= 3
    numpy_type = np.uint8
    s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]

    num_bytes_per_value = np.iinfo(numpy_type).bits // 8
    # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
    # we need to swap the bytes of each value after reading them with np.frombuffer().
    needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
    parsed = np.frombuffer(bytearray(data), dtype=numpy_type, offset=(4 * (nd + 1)))
    if needs_byte_reversal:
        parsed = parsed.byteswap()

    assert parsed.shape[0] == np.prod(s) or not strict
    return parsed.reshape(*s)

raw_path = pathlib.Path(data_path) / "MNIST" / "raw"

# Download the dataset
MNIST(data_path, download=True)

# Extract numpy array from raw data
train_images = MNISTraw2numpy(str(raw_path / "train-images-idx3-ubyte"))
train_labels = MNISTraw2numpy(str(raw_path / "train-labels-idx1-ubyte"))
test_images = MNISTraw2numpy(str(raw_path / "t10k-images-idx3-ubyte"))
test_labels = MNISTraw2numpy(str(raw_path / "t10k-labels-idx1-ubyte"))

# Split arrays into as many folds as there are organizations
train_images_folds = np.split(train_images, N_CLIENTS)
train_labels_folds = np.split(train_labels, N_CLIENTS)
test_images_folds = np.split(test_images, N_CLIENTS)
test_labels_folds = np.split(test_labels, N_CLIENTS)

# Save splits in different folders to simulate the different organizations
for i in range(N_CLIENTS):

    # Save train dataset on each org
    os.makedirs(str(data_path / f"org_{i+1}/train"), exist_ok=True)
    np.save(data_path / f"org_{i+1}/train/train_images.npy", train_images_folds[i])
    np.save(data_path / f"org_{i+1}/train/train_labels.npy", train_labels_folds[i])

    # Save test dataset on each org
    os.makedirs(str(data_path / f"org_{i+1}/test"), exist_ok=True)
    np.save(data_path / f"org_{i+1}/test/test_images.npy", test_images_folds[i])
    np.save(data_path / f"org_{i+1}/test/test_labels.npy", test_labels_folds[i])
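As a sanity check on the split described above, the following sketch reproduces it on placeholder arrays with the MNIST sample counts (60,000 training and 10,000 test samples) and writes them into a temporary directory with the same `org_{i+1}/train` and `org_{i+1}/test` layout. The arrays are zero-filled stand-ins (images shrunk to 4x4 to keep it light), and the temporary directory replaces `data_path`.

```python
import pathlib
import tempfile

import numpy as np

N_CLIENTS = 2

# Zero-filled placeholders with the MNIST sample counts (images shrunk to 4x4).
train_images = np.zeros((60_000, 4, 4), dtype=np.uint8)
train_labels = np.zeros(60_000, dtype=np.uint8)
test_images = np.zeros((10_000, 4, 4), dtype=np.uint8)
test_labels = np.zeros(10_000, dtype=np.uint8)

# np.split raises a ValueError if the first axis does not divide evenly by N_CLIENTS.
folds = {
    "train": (np.split(train_images, N_CLIENTS), np.split(train_labels, N_CLIENTS)),
    "test": (np.split(test_images, N_CLIENTS), np.split(test_labels, N_CLIENTS)),
}

with tempfile.TemporaryDirectory() as tmp:
    data_path = pathlib.Path(tmp)
    for i in range(N_CLIENTS):
        for split, (image_folds, label_folds) in folds.items():
            folder = data_path / f"org_{i+1}" / split
            folder.mkdir(parents=True, exist_ok=True)
            np.save(folder / f"{split}_images.npy", image_folds[i])
            np.save(folder / f"{split}_labels.npy", label_folds[i])

    # Map each saved file (relative path) to its number of samples.
    layout = {
        p.relative_to(data_path).as_posix(): np.load(p).shape[0]
        for p in sorted(data_path.rglob("*.npy"))
    }

for path, n_samples in layout.items():
    print(path, n_samples)
```

Each organization ends up with four files: 30,000 training samples and 5,000 test samples.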



Registering assets

Substra and Substrafl imports

from substra.sdk.schemas import (
    AlgoInputSpec, AlgoOutputSpec, AlgoSpec, AssetKind, DataSampleSpec, DatasetSpec, Permissions,
)
from substrafl.nodes import TestDataNode, TrainDataNode


As data cannot be seen once it is registered on the platform, we set Permissions for each asset to define its access rights.

The metadata are visible to all the users of a Channel.

permissions = Permissions(public=False, authorized_ids=ORGS_ID)

Registering dataset

A Dataset is composed of an opener, which is a Python script with the instructions for loading the data from the files into memory, and a description markdown file.
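The opener file itself is not shown in this example. As a rough illustration only, the sketch below shows the kind of loading logic an opener for this layout could contain: it reads the `*_images.npy` and `*_labels.npy` files saved in an organization folder and returns them in a dictionary with the `images` and `labels` keys that `TorchDataset` reads. The function name `load_datasamples` is hypothetical and is not part of the Substra opener interface.

```python
import pathlib
import tempfile

import numpy as np


def load_datasamples(folder: pathlib.Path) -> dict:
    """Hypothetical opener-style loader: gather images and labels from one folder."""
    # sorted() keeps image and label files in matching order when they share a prefix.
    images = [np.load(p) for p in sorted(folder.glob("*_images.npy"))]
    labels = [np.load(p) for p in sorted(folder.glob("*_labels.npy"))]
    return {"images": np.concatenate(images), "labels": np.concatenate(labels)}


with tempfile.TemporaryDirectory() as tmp:
    folder = pathlib.Path(tmp)
    np.save(folder / "train_images.npy", np.zeros((3, 28, 28), dtype=np.uint8))
    np.save(folder / "train_labels.npy", np.array([7, 2, 1], dtype=np.uint8))
    datasamples = load_datasamples(folder)

print(datasamples["images"].shape, datasamples["labels"])
```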

dataset = DatasetSpec(
    data_opener=assets_directory / "dataset" / "",
    description=assets_directory / "dataset" / "",
    permissions=permissions,
)

Adding Metrics

A metric corresponds to an algorithm used to compute the score of predictions on a datasample. Concretely, a metric corresponds to an archive (tar or zip file), automatically built from:

  • Python scripts that implement the metric computation

  • a Dockerfile to specify the required dependencies of the Python scripts
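The metric script is not shown in this example. Assuming it computes an accuracy (an assumption; the actual metric lives in the elided assets), its core computation could look like the sketch below, where predictions are per-class scores (for instance the log-probabilities returned by the CNN) and labels are one-hot encoded as in `TorchDataset`. `compute_accuracy` is a hypothetical name.

```python
import numpy as np


def compute_accuracy(y_true_one_hot: np.ndarray, y_pred_scores: np.ndarray) -> float:
    """Hypothetical metric: fraction of samples whose top-scoring class is correct."""
    y_true = np.argmax(y_true_one_hot, axis=1)
    y_pred = np.argmax(y_pred_scores, axis=1)
    return float(np.mean(y_true == y_pred))


labels = np.eye(10)[[3, 1, 4, 1]]  # one-hot labels for digits 3, 1, 4, 1
scores = np.log(np.full((4, 10), 0.05))  # log-probabilities: 0.05 everywhere except below
scores[0, 3] = scores[1, 1] = scores[2, 4] = np.log(0.55)  # three correct predictions
scores[3, 7] = np.log(0.55)  # the last sample is predicted as a 7 instead of a 1

print(compute_accuracy(labels, scores))  # 0.75
```

Taking the arg-max works the same on probabilities or log-probabilities, since the logarithm preserves the ordering of the scores.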

inputs_metrics = [
    AlgoInputSpec(
        identifier="opener", kind=AssetKind.data_manager, optional=False, multiple=False
    ),
    AlgoInputSpec(
        identifier="predictions", kind=AssetKind.model, optional=False, multiple=False
    ),
]

outputs_metrics = [
    AlgoOutputSpec(identifier="performance", kind=AssetKind.performance, multiple=False)
]

objective = AlgoSpec(
    inputs=inputs_metrics,
    outputs=outputs_metrics,
    description=assets_directory / "metric" / "",
    file=assets_directory / "metric" / "",
    permissions=permissions,
)

METRICS_DOCKERFILE_FILES = [
    assets_directory / "metric" / "",
    assets_directory / "metric" / "Dockerfile",
]

# Bundle the metric script and its Dockerfile into the archive referenced by the AlgoSpec
archive_path = objective.file
with zipfile.ZipFile(archive_path, "w") as z:
    for filepath in METRICS_DOCKERFILE_FILES:
        z.write(filepath, arcname=filepath.name)

metric_key = clients[ALGO_ORG_ID].add_algo(objective)
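The archive step bundles the metric files into a single zip. The self-contained sketch below repeats the idea in a temporary directory, storing each file at the archive root under its base name. The file names `metrics.py` and `Dockerfile` and their contents are placeholders; the actual metric script name is not shown in this example.

```python
import pathlib
import tempfile
import zipfile

with tempfile.TemporaryDirectory() as tmp:
    metric_dir = pathlib.Path(tmp) / "metric"
    metric_dir.mkdir()
    # Placeholder contents standing in for the real metric script and Dockerfile.
    files = [metric_dir / "metrics.py", metric_dir / "Dockerfile"]
    files[0].write_text("# metric computation goes here\n")
    files[1].write_text("FROM python:3.9\n")

    archive_path = metric_dir / "metrics.zip"
    with zipfile.ZipFile(archive_path, "w") as z:
        for filepath in files:
            # Store each file at the archive root under its base name.
            z.write(filepath, arcname=filepath.name)

    with zipfile.ZipFile(archive_path) as z:
        names = sorted(z.namelist())

print(names)  # ['Dockerfile', 'metrics.py']
```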

Train and test data nodes

The Dataset object itself does not contain the data. The proper asset to access them is the datasample asset.

A datasample contains a local path to the data and the key identifying the Dataset it is based on, in order to have access to the proper opener file.

Now that all our Assets are defined, we can create TrainDataNode and TestDataNode objects to gather the Dataset and the datasamples on the specified nodes.

train_data_nodes = list()
test_data_nodes = list()

for ind, org_id in enumerate(ORGS_ID):
    client = clients[org_id]

    # Add the dataset to the client to provide access to the opener in each organization.
    dataset_key = client.add_dataset(dataset)
    assert dataset_key, "Missing data manager key"

    # Add the training data on each organization.
    data_sample = DataSampleSpec(
        data_manager_keys=[dataset_key],
        path=data_path / f"org_{ind+1}" / "train",
    )
    train_datasample_key = client.add_data_sample(data_sample)

    # Create the Train Data Node (or training task) and save it in a list
    train_data_node = TrainDataNode(
        organization_id=org_id,
        data_manager_key=dataset_key,
        data_sample_keys=[train_datasample_key],
    )
    train_data_nodes.append(train_data_node)

    # Add the testing data on each organization.
    data_sample = DataSampleSpec(
        data_manager_keys=[dataset_key],
        path=data_path / f"org_{ind+1}" / "test",
    )
    test_datasample_key = client.add_data_sample(data_sample)

    # Create the Test Data Node (or testing task) and save it in a list
    test_data_node = TestDataNode(
        organization_id=org_id,
        data_manager_key=dataset_key,
        test_data_sample_keys=[test_datasample_key],
        metric_keys=[metric_key],
    )
    test_data_nodes.append(test_data_node)

Machine Learning specification

Torch imports

import torch
from torch import nn
import torch.nn.functional as F

CNN definition

We choose to use a classic torch CNN as the model to train. The model structure is defined by the user independently of Substrafl.

seed = 42
torch.manual_seed(seed)

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(3 * 3 * 64, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x, eval=False):
        x = F.relu(self.conv1(x))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.dropout(x, p=0.5, training=not eval)
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        x = F.dropout(x, p=0.5, training=not eval)
        x = x.view(-1, 3 * 3 * 64)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=not eval)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
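The input size of `fc1`, `3 * 3 * 64`, follows from the shape arithmetic of the convolution and pooling layers. The sketch below recomputes it with the output-size formula for an unpadded, stride-1 convolution, `n - k + 1`, and for `F.max_pool2d(x, 2)` (kernel 2, stride 2), `n // 2`:

```python
def conv_out(n: int, kernel_size: int) -> int:
    """Spatial size after a stride-1, unpadded convolution."""
    return n - kernel_size + 1


def pool_out(n: int) -> int:
    """Spatial size after F.max_pool2d(x, 2), i.e. kernel 2 and stride 2."""
    return n // 2


n = 28                        # MNIST images are 28x28
n = conv_out(n, 5)            # conv1: 28 -> 24
n = pool_out(conv_out(n, 5))  # conv2 + max_pool2d: 24 -> 20 -> 10
n = pool_out(conv_out(n, 5))  # conv3 + max_pool2d: 10 -> 6 -> 3
flat = n * n * 64             # conv3 has 64 output channels

print(n, flat)  # 3 576
```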

Substrafl imports

from typing import Any

from substrafl.algorithms.pytorch import TorchFedAvgAlgo
from substrafl.dependency import Dependency
from substrafl.strategies import FedAvg
from substrafl.nodes import AggregationNode
from substrafl.evaluation_strategy import EvaluationStrategy
from substrafl.index_generator import NpIndexGenerator
from substrafl.experiment import execute_experiment

Substrafl algo definition

To instantiate a Substrafl Torch algorithm, you need to define a torch Dataset with a specific __init__ signature, which must be (self, datasamples, is_inference). This torch Dataset is useful to preprocess your data in the __getitem__ function. The __getitem__ function is expected to return x and y if is_inference is False, else x. This behavior can be changed by rewriting the _local_train or predict methods.

This dataset is passed as a class to the Torch algorithm: the algorithm instantiates it internally, passing the data loaded by the opener as the datasamples parameter.
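The `__getitem__` contract described above can be illustrated without Torch. The NumPy sketch below mirrors the preprocessing of `TorchDataset` (a channel axis added, pixels scaled to [0, 1], labels one-hot encoded over 10 classes) and returns `(x, y)` in training mode and `x` alone when `is_inference` is true. `NumpyMnistDataset` is a hypothetical stand-in for illustration, not a class Substrafl accepts.

```python
import numpy as np


class NumpyMnistDataset:
    """Hypothetical NumPy analogue of TorchDataset, for illustration only."""

    def __init__(self, datasamples: dict, is_inference: bool):
        # Add a channel axis: (N, 28, 28) -> (N, 1, 28, 28), as in TorchDataset.
        self.x = datasamples["images"][:, None, ...].astype(np.float32)
        # One-hot encode the digit labels over 10 classes.
        self.y = np.eye(10, dtype=np.float32)[datasamples["labels"]]
        self.is_inference = is_inference

    def __getitem__(self, idx):
        if not self.is_inference:
            return self.x[idx] / 255, self.y[idx]
        return self.x[idx] / 255

    def __len__(self):
        return len(self.x)


datasamples = {
    "images": np.full((4, 28, 28), 255, dtype=np.uint8),
    "labels": np.array([0, 9, 2, 2]),
}
train_ds = NumpyMnistDataset(datasamples, is_inference=False)
x, y = train_ds[np.array([1, 3])]  # fancy indexing with several positions at once
print(x.shape, float(x.max()), y.argmax(axis=1))
```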

The index generator is used as the batch sampler of the dataset, in order to save the state of the seen samples during training: Federated Algorithms have a notion of num_updates (the number of local updates per round), which forces the batch sampler of the dataset to be stateful.
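To see why the batch sampler must be stateful, consider the minimal sketch below. It is not Substrafl's `NpIndexGenerator`; it is a hypothetical `StatefulBatchIndices` class that yields `num_updates` batches per round and remembers where it stopped, so the next round continues through the shuffled epoch instead of restarting from the same first samples.

```python
import numpy as np


class StatefulBatchIndices:
    """Hypothetical stateful batch sampler: num_updates batches per call to next_round()."""

    def __init__(self, n_samples: int, batch_size: int, num_updates: int, seed: int = 42):
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.num_updates = num_updates
        self.rng = np.random.default_rng(seed)
        self.order = self.rng.permutation(n_samples)
        self.position = 0  # state kept between rounds

    def _next_batch(self) -> np.ndarray:
        if self.position + self.batch_size > self.n_samples:
            # Epoch exhausted: reshuffle and start over (the incomplete tail is dropped).
            self.order = self.rng.permutation(self.n_samples)
            self.position = 0
        batch = self.order[self.position : self.position + self.batch_size]
        self.position += self.batch_size
        return batch

    def next_round(self) -> list:
        """Batches for one round of local training."""
        return [self._next_batch() for _ in range(self.num_updates)]


sampler = StatefulBatchIndices(n_samples=10, batch_size=2, num_updates=2)
round_1 = np.concatenate(sampler.next_round())
round_2 = np.concatenate(sampler.next_round())
print(len(set(round_1) & set(round_2)))  # 0: round 2 picks up where round 1 stopped
```

A stateless sampler recreated at every round would instead keep drawing the first batches of a fresh epoch.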

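# Number of model updates between each FL strategy aggregation.
NUM_UPDATES = 100  # example value, adjust to your setup

# Number of samples per update.
BATCH_SIZE = 32  # example value, adjust to your setup

index_generator = NpIndexGenerator(
    batch_size=BATCH_SIZE,
    num_updates=NUM_UPDATES,
)

class TorchDataset(torch.utils.data.Dataset):
    def __init__(self, datasamples, is_inference: bool):
        self.x = torch.FloatTensor(datasamples["images"][:, None, ...])
        # CrossEntropyLoss expects floating-point targets when they are one-hot class probabilities
        self.y = F.one_hot(
            torch.from_numpy(datasamples["labels"]).type(torch.int64), 10
        ).type(torch.float32)
        self.is_inference = is_inference

    def __getitem__(self, idx):
        if not self.is_inference:
            return self.x[idx] / 255, self.y[idx]
        else:
            return self.x[idx] / 255

    def __len__(self):
        return len(self.x)

class MyAlgo(TorchFedAvgAlgo):
    def __init__(self):
        super().__init__(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            index_generator=index_generator,
            dataset=TorchDataset,
        )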

Algo dependencies

The dependencies needed for the Torch algorithm are specified by a Dependency object, in order to install the right libraries in the Python environment of each organization.

algo_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "torch==1.11.0"])

Federated Learning strategies

For this example, we choose the Federated Averaging strategy (FedAvg), based on the FedAvg paper by McMahan et al., 2017.

strategy = FedAvg()
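In FedAvg, the aggregation step averages what the organizations send back, weighting each by its number of training samples. The sketch below shows that weighted average on plain NumPy parameter arrays; `fedavg_aggregate` is a hypothetical helper, and Substrafl's `FedAvg` may exchange parameter updates rather than full parameters. Since every organization starts a round from the same global model and the weights sum to 1, averaging updates or averaging parameters gives the same result.

```python
import numpy as np


def fedavg_aggregate(local_params: list, n_samples: list) -> list:
    """Hypothetical FedAvg step: sample-weighted average of each parameter array."""
    weights = np.asarray(n_samples, dtype=np.float64)
    weights /= weights.sum()
    return [
        sum(w * params[layer] for w, params in zip(weights, local_params))
        for layer in range(len(local_params[0]))
    ]


# Two organizations, each holding a model made of two parameter arrays.
org_a = [np.array([1.0, 1.0]), np.array([[0.0]])]
org_b = [np.array([3.0, 5.0]), np.array([[4.0]])]

# Made-up sample counts: org_a holds three times more data than org_b.
global_params = fedavg_aggregate([org_a, org_b], n_samples=[30_000, 10_000])
print(global_params[0], global_params[1])  # [1.5 2. ] [[1.]]
```

With the equal split used in this example (30,000 training images per organization), the weights are both 0.5 and FedAvg reduces to a plain mean.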

Running the experiment

We now have all the necessary objects to launch our experiment. Below is a summary of all the objects we created so far:

  • A Client to orchestrate all the assets of our project, using their keys to identify them

  • A Torch algorithm, to define the training parameters (optimizer, train function, predict function, etc.)

  • A strategy, to specify the federated learning aggregation operation

  • TrainDataNode objects, to indicate where we can process training tasks, on which data and using which opener

  • An evaluation strategy, to define where and at which frequency we evaluate the model

  • An AggregationNode, to specify the node on which the aggregation operation will be computed

  • The number of rounds, a round being defined by a local training step followed by an aggregation operation

  • An experiment folder to save a summary of the operations performed

  • The Dependency object to define the libraries the experiment needs to run.
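The notion of a round from the list above (local training on each organization, then aggregation) can be simulated end to end on a toy problem. The sketch below fits a one-parameter linear model `y = w * x` on two synthetic organizations with NumPy, running a few local gradient steps per organization and a sample-weighted average per round. The data, learning rate and counts are made up for illustration and have nothing to do with the MNIST run.

```python
import numpy as np

rng = np.random.default_rng(0)
TRUE_W = 2.0

# Two organizations with private synthetic data of different sizes.
orgs = []
for n in (80, 20):
    x = rng.normal(size=n)
    orgs.append((x, TRUE_W * x + 0.1 * rng.normal(size=n)))


def local_train(w: float, x: np.ndarray, y: np.ndarray, num_updates: int, lr: float) -> float:
    """A few full-batch gradient steps on the mean squared error, starting from the global w."""
    for _ in range(num_updates):
        grad = 2 * np.mean((w * x - y) * x)
        w -= lr * grad
    return w


w_global = 0.0
history = []
for round_idx in range(1, 4):  # 3 rounds
    local_ws = [local_train(w_global, x, y, num_updates=5, lr=0.1) for x, y in orgs]
    sizes = [len(x) for x, _ in orgs]
    w_global = float(np.average(local_ws, weights=sizes))  # FedAvg aggregation
    history.append((round_idx, w_global))
    print(f"round {round_idx}: w = {w_global:.3f}")
```

Each round moves the shared parameter closer to the value that fits both organizations, without either of them sharing its raw data.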

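aggregation_node = AggregationNode(ALGO_ORG_ID)

my_eval_strategy = EvaluationStrategy(test_data_nodes=test_data_nodes, rounds=1)

# Number of times to apply the compute plan.
NUM_ROUNDS = 3

compute_plan = execute_experiment(
    client=clients[ALGO_ORG_ID],
    algo=MyAlgo(),
    strategy=strategy,
    train_data_nodes=train_data_nodes,
    evaluation_strategy=my_eval_strategy,
    aggregation_node=aggregation_node,
    num_rounds=NUM_ROUNDS,
    experiment_folder=str(pathlib.Path.cwd() / "tmp" / "experiment_summaries"),
    dependencies=algo_deps,
)
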
Compute plan progress:   0%|          | 0/23 [00:00<?, ?it/s]/home/docs/checkouts/ UserWarning: `transient=True` is ignored in local mode
  warnings.warn("`transient=True` is ignored in local mode")

Compute plan progress: 100%|##########| 23/23 [01:02<00:00,  2.70s/it]

Listing results

import pandas as pd

performances_df = pd.DataFrame(clients[ALGO_ORG_ID].get_performances(compute_plan.key).dict())
print("\nPerformance Table: \n")
print(performances_df[["worker", "round_idx", "performance"]])


Performance Table:

      worker round_idx  performance
0  MyOrg2MSP         1       0.7458
1  MyOrg3MSP         1       0.8300
2  MyOrg2MSP         2       0.8912
3  MyOrg3MSP         2       0.9428
4  MyOrg2MSP         3       0.9144
5  MyOrg3MSP         3       0.9516

Plot results

import matplotlib.pyplot as plt

plt.title("Test dataset results")
plt.xlabel("Rounds")
plt.ylabel("Performance")

for org_id in ORGS_ID:
    df = performances_df.query(f"worker == '{org_id}'")
    plt.plot(df["round_idx"], df["performance"], label=org_id)

plt.legend(loc="lower right")
plt.show()

Figure: Test dataset results

Total running time of the script: ( 1 minutes 4.551 seconds)
