Using Torch FedAvg on MNIST dataset

This example illustrates the basic usage of SubstraFL. It trains a Federated Learning model with the Federated Averaging (FedAvg) strategy on the MNIST dataset of handwritten digits, using PyTorch. We work on 28x28 pixel grayscale images. The task is classification: recognizing the digit written on each image.

SubstraFL can be used with any machine learning framework (PyTorch, TensorFlow, Scikit-Learn, etc.). However, a specific interface has been developed for PyTorch, which makes writing PyTorch code simpler than with other frameworks. This example uses that PyTorch interface.

This example does not use a deployed Substra platform. It runs in local mode.


  • To run this example locally, please download and unzip the assets it needs into the same directory as this example:

    assets required to run this example

    Please make sure all the required libraries are installed. The zip file includes a requirements.txt file, so you can install them with the command: pip install -r requirements.txt

  • Substra and SubstraFL should already be installed. If they are not, follow the instructions described here: Installation


We work with three different organizations. Two organizations provide a dataset, and a third one provides the algorithm and registers the machine learning tasks.

This example runs in local mode, simulating a federated learning experiment.

In the following code cell, we define the different organizations needed for our FL experiment.

from substra import Client


# Every computation will run in `subprocess` mode, where everything runs locally in Python
# subprocesses.
# Other backend_types are:
# "docker" mode where computations run locally in docker containers
# "remote" mode where computations run remotely (you need a deployed platform for that)
client_0 = Client(backend_type="subprocess")
client_1 = Client(backend_type="subprocess")
client_2 = Client(backend_type="subprocess")
# To run in remote mode, you also have to use the function `Client.login(username, password)`

clients = {
    client_0.organization_info().organization_id: client_0,
    client_1.organization_info().organization_id: client_1,
    client_2.organization_info().organization_id: client_2,
}

# Store organization IDs
ORGS_ID = list(clients.keys())
ALGO_ORG_ID = ORGS_ID[0]  # Algo provider is defined as the first organization.
DATA_PROVIDER_ORGS_ID = ORGS_ID[1:]  # Data provider orgs are the last two organizations.

Data and metrics

Data preparation

This section downloads (if needed) the MNIST dataset using the torchvision library. It extracts the images from the raw files and locally creates two folders: one for each organization.

Each organization has access to half of the training data and half of the test data. That is 30,000 training images and 5,000 test images per organization.
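The split arithmetic can be sketched as follows. MNIST ships 60,000 training and 10,000 test images, so two equal shards give the 30,000 / 5,000 figures above. The helper `split_indices` is hypothetical and cuts the data into contiguous shards. `setup_mnist` from the assets may split the data differently (for example, after shuffling).

```python
def split_indices(n_samples, n_orgs):
    """Hypothetical helper: split range(n_samples) into n_orgs near-equal contiguous shards."""
    base, extra = divmod(n_samples, n_orgs)
    shards = []
    start = 0
    for k in range(n_orgs):
        # The first `extra` shards get one more sample when the split is uneven.
        size = base + (1 if k < extra else 0)
        shards.append(range(start, start + size))
        start += size
    return shards


train_shards = split_indices(60_000, 2)
test_shards = split_indices(10_000, 2)
print([len(s) for s in train_shards], [len(s) for s in test_shards])  # [30000, 30000] [5000, 5000]
```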

import pathlib
from torch_fedavg_assets.dataset.mnist_dataset import setup_mnist

# Create the temporary directory for generated data
(pathlib.Path.cwd() / "tmp").mkdir(exist_ok=True)
data_path = pathlib.Path.cwd() / "tmp" / "data_mnist"

setup_mnist(data_path, len(DATA_PROVIDER_ORGS_ID))



Dataset registration

A Dataset is composed of two parts: an opener, which is a Python script that loads the data from files into memory, and a description markdown file. The Dataset object itself does not contain the data. The asset that actually holds the data is the datasample.

A datasample contains a local path to the data. Linking a datasample to a dataset adds that data to the dataset.
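Here is a minimal sketch of an opener's role: it turns a datasample folder into the in-memory dict that the metric and the TorchDataset below read through the `"images"` and `"labels"` keys. The function `load_datasamples` is hypothetical, and so are the file names `images.npy` and `labels.npy`. The real opener shipped in the assets zip may be structured differently.

```python
import pathlib
import tempfile

import numpy as np


def load_datasamples(folder):
    """Hypothetical opener-like loader: read one datasample folder into memory.

    Assumes the folder holds `images.npy` and `labels.npy` (illustrative names).
    """
    folder = pathlib.Path(folder)
    return {
        "images": np.load(folder / "images.npy"),
        "labels": np.load(folder / "labels.npy"),
    }


# Build a tiny fake datasample folder to exercise the loader.
with tempfile.TemporaryDirectory() as tmp:
    np.save(pathlib.Path(tmp) / "images.npy", np.zeros((4, 28, 28), dtype=np.uint8))
    np.save(pathlib.Path(tmp) / "labels.npy", np.array([0, 1, 2, 3]))
    # np.load reads the arrays fully into memory, so they outlive the temp folder.
    datasamples = load_datasamples(tmp)

print(datasamples["images"].shape, datasamples["labels"].tolist())  # (4, 28, 28) [0, 1, 2, 3]
```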

Data privacy is a key concept for Federated Learning experiments. That is why we set Permissions on each asset to define which organizations can use it.

Note that metadata, such as an asset's creation date and owner, is visible to all the organizations of a network.
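The access rule that `Permissions(public=False, authorized_ids=[...])` expresses can be sketched with a hypothetical stand-in class. Two points are assumptions, not taken from Substra's code: the owner can always use its own asset, and any other organization needs either `public=True` or membership in `authorized_ids`. The organization IDs are illustrative.

```python
from dataclasses import dataclass, field


@dataclass
class SimplePermissions:
    """Hypothetical stand-in for substra's Permissions (public + authorized_ids)."""
    public: bool
    authorized_ids: list = field(default_factory=list)


def can_access(owner_id, requester_id, permissions):
    # Assumption: the owner of an asset can always use it.
    if requester_id == owner_id:
        return True
    return permissions.public or requester_id in permissions.authorized_ids


perms = SimplePermissions(public=False, authorized_ids=["algo_org"])
print(can_access("data_org_1", "algo_org", perms))    # True
print(can_access("data_org_1", "data_org_2", perms))  # False
```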

from substra.sdk.schemas import DatasetSpec
from substra.sdk.schemas import Permissions
from substra.sdk.schemas import DataSampleSpec

assets_directory = pathlib.Path.cwd() / "torch_fedavg_assets"
dataset_keys = {}
train_datasample_keys = {}
test_datasample_keys = {}

for i, org_id in enumerate(DATA_PROVIDER_ORGS_ID):

    client = clients[org_id]

    permissions_dataset = Permissions(public=False, authorized_ids=[ALGO_ORG_ID])

    # DatasetSpec is the specification of a dataset. It makes sure every field
    # is well defined, and that our dataset is ready to be registered.
    # The real dataset object is created in the add_dataset method.

    dataset = DatasetSpec(
        data_opener=assets_directory / "dataset" / "",
        description=assets_directory / "dataset" / "",
        permissions=permissions_dataset,
    )
    dataset_keys[org_id] = client.add_dataset(dataset)
    assert dataset_keys[org_id], "Missing dataset key"

    # Add the training data on each organization and link it to the dataset.
    data_sample = DataSampleSpec(
        data_manager_keys=[dataset_keys[org_id]],
        path=data_path / f"org_{i+1}" / "train",
    )
    train_datasample_keys[org_id] = client.add_data_sample(data_sample)

    # Add the testing data on each organization and link it to the dataset.
    data_sample = DataSampleSpec(
        data_manager_keys=[dataset_keys[org_id]],
        path=data_path / f"org_{i+1}" / "test",
    )
    test_datasample_keys[org_id] = client.add_data_sample(data_sample)

Metric registration

A metric is a function used to compute the score of predictions on one or several datasamples.

To add a metric, you need to define a function that computes and returns a performance score from the datasamples (as returned by the opener) and the predictions_path (to be loaded within the function).

When using a Torch SubstraFL algorithm, the predict function saves the predictions in NumPy format, so you can load them with np.load.

After defining the metric's dependencies and permissions, we register the metric with the add_metric function. This metric is used on the test datasamples to evaluate model performance.
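The metric contract can be exercised without Substra. The sketch below has a datasamples dict with a `"labels"` key and a `.npy` file of per-class scores, loaded with `np.load` and reduced with `argmax`. `accuracy_from_file` is a hypothetical stand-in for the `accuracy` function registered below. It swaps `sklearn.metrics.accuracy_score` for an equivalent NumPy expression so the sketch depends only on NumPy. The scores and labels are made-up toy values.

```python
import os
import tempfile

import numpy as np


def accuracy_from_file(datasamples, predictions_path):
    # Labels come from the opener's output; predictions are (n_samples, n_classes)
    # scores saved with np.save by the algorithm's predict step.
    y_true = np.asarray(datasamples["labels"])
    y_pred = np.load(predictions_path)
    return float(np.mean(np.argmax(y_pred, axis=1) == y_true))


scores = np.array([
    [0.9, 0.05, 0.05],  # predicted class 0
    [0.1, 0.8, 0.1],    # predicted class 1
    [0.2, 0.2, 0.6],    # predicted class 2
    [0.7, 0.2, 0.1],    # predicted class 0
])
labels = np.array([0, 1, 2, 2])

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "predictions.npy")
    np.save(path, scores)
    acc = accuracy_from_file({"labels": labels}, path)

print(acc)  # 0.75 (3 of 4 predictions match)
```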

from sklearn.metrics import accuracy_score
import numpy as np

from substrafl.dependency import Dependency
from substrafl.remote.register import add_metric

permissions_metric = Permissions(public=False, authorized_ids=[ALGO_ORG_ID] + DATA_PROVIDER_ORGS_ID)

# The Dependency object is instantiated in order to install the right libraries in
# the Python environment of each organization.
metric_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "scikit-learn==1.1.1"])

def accuracy(datasamples, predictions_path):
    y_true = datasamples["labels"]
    y_pred = np.load(predictions_path)

    return accuracy_score(y_true, np.argmax(y_pred, axis=1))

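metric_key = add_metric(
    client=clients[ALGO_ORG_ID],
    metric_function=accuracy,
    permissions=permissions_metric,
    dependencies=metric_deps,
)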

Specify the machine learning components

This section uses the PyTorch-based SubstraFL API to simplify the definition of the machine learning components. However, SubstraFL is compatible with any machine learning framework.

In this section, you will:

  • register a model and its dependencies

  • specify the federated learning strategy

  • specify the organizations where to train and where to aggregate

  • specify the organizations where to test the models

  • actually run the computations

Model definition

We choose to use a classic torch CNN as the model to train. The model structure is defined by the user independently of SubstraFL.

import torch
from torch import nn
import torch.nn.functional as F

seed = 42

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(3 * 3 * 64, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x, eval=False):
        x = F.relu(self.conv1(x))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.dropout(x, p=0.5, training=not eval)
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        x = F.dropout(x, p=0.5, training=not eval)
        x = x.view(-1, 3 * 3 * 64)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=not eval)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
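The hard-coded `3 * 3 * 64` in `CNN` follows from the layer shapes. Each 5x5 convolution (stride 1, no padding) shrinks the side by 4, and each `max_pool2d(..., 2)` halves it. The sketch below traces a 28x28 MNIST image through the network. `conv_out` is a helper written for this illustration using the standard output-size formula.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                           # MNIST images are 28x28
size = conv_out(size, 5)            # conv1 -> 24
size = conv_out(size, 5)            # conv2 -> 20
size = conv_out(size, 2, stride=2)  # max_pool2d(2) -> 10
size = conv_out(size, 5)            # conv3 -> 6
size = conv_out(size, 2, stride=2)  # max_pool2d(2) -> 3
flat = size * size * 64             # 64 output channels from conv3
print(size, flat)                   # 3 576
```

The result, 576, matches `in_features=576` of `fc1` in the model printout at the end of this example.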

Specifying how much data to train on

To specify how much data to train on at each round, we use the index_generator object. We specify the batch size and the number of batches to use in each round (called num_updates). See Index Generator for more details.

from substrafl.index_generator import NpIndexGenerator

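# Number of model updates between each FL strategy aggregation.
NUM_UPDATES = 100  # illustrative value (assumption)

# Number of samples per update.
BATCH_SIZE = 32  # illustrative value (assumption)

index_generator = NpIndexGenerator(
    batch_size=BATCH_SIZE,
    num_updates=NUM_UPDATES,
)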
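The meaning of `batch_size` and `num_updates` can be sketched in plain Python. Each round, every organization takes `num_updates` gradient steps, and each step uses `batch_size` sample indices, so a round sees `batch_size * num_updates` samples. `simple_index_generator` is hypothetical. It reshuffles the indices whenever a pass over the data runs out. The real `NpIndexGenerator` keeps state across rounds and may treat epoch boundaries differently.

```python
import random


def simple_index_generator(n_samples, batch_size, num_updates, seed=None):
    """Hypothetical illustration: yield num_updates batches of batch_size indices.

    Indices are drawn without replacement from a shuffled pass over the data;
    a new shuffled pass starts whenever the current one is exhausted.
    """
    rng = random.Random(seed)
    order = []
    for _ in range(num_updates):
        batch = []
        while len(batch) < batch_size:
            if not order:
                order = list(range(n_samples))
                rng.shuffle(order)
            batch.append(order.pop())
        yield batch


batches = list(simple_index_generator(n_samples=10, batch_size=4, num_updates=5, seed=42))
print(len(batches), [len(b) for b in batches])  # 5 [4, 4, 4, 4, 4]
```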

Torch Dataset definition

This torch Dataset is used to preprocess the data using the __getitem__ function.

This torch Dataset needs to have a specific __init__ signature, that must contain (self, datasamples, is_inference).

The __getitem__ function is expected to return (inputs, outputs) if is_inference is False, and only the inputs otherwise. You can change this behavior by overriding the _local_train or predict methods.

class TorchDataset(torch.utils.data.Dataset):
    def __init__(self, datasamples, is_inference: bool):
        self.x = datasamples["images"]
        self.y = datasamples["labels"]
        self.is_inference = is_inference

    def __getitem__(self, idx):
        # Add a channel dimension and scale pixel values to [0, 1].
        x = torch.FloatTensor(self.x[idx][None, ...]) / 255

        if self.is_inference:
            return x

        # One-hot encode the label over the 10 digit classes.
        y = torch.tensor(self.y[idx]).type(torch.int64)
        y = F.one_hot(y, 10)
        y = y.type(torch.float32)

        return x, y

    def __len__(self):
        return len(self.x)
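The `is_inference` contract can be checked without PyTorch using a plain-Python stand-in. `ListDataset` is a hypothetical class for illustration. It scales pixels and one-hot encodes labels like `TorchDataset`, but it returns nested lists and skips the extra channel dimension. The 2x2 "image" is a toy value.

```python
class ListDataset:
    """Plain-Python stand-in mirroring TorchDataset's contract (no torch needed)."""

    def __init__(self, datasamples, is_inference: bool):
        self.x = datasamples["images"]
        self.y = datasamples["labels"]
        self.is_inference = is_inference

    def __getitem__(self, idx):
        # Scale pixel values from [0, 255] to [0, 1].
        x = [[p / 255 for p in row] for row in self.x[idx]]
        if self.is_inference:
            return x  # inference: inputs only
        # One-hot encode the label over the 10 digit classes.
        y = [1.0 if k == self.y[idx] else 0.0 for k in range(10)]
        return x, y  # training: (inputs, outputs)

    def __len__(self):
        return len(self.x)


samples = {"images": [[[0, 255], [51, 102]]], "labels": [3]}
x, y = ListDataset(samples, is_inference=False)[0]
print(x, y.index(1.0))                                  # [[0.0, 1.0], [0.2, 0.4]] 3
print(ListDataset(samples, is_inference=True)[0] == x)  # True
```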

SubstraFL algo definition

A SubstraFL Algo gathers all the elements defined above that run locally in each organization. It is the only SubstraFL object that is framework-specific (here, PyTorch-specific).

The TorchDataset is passed to the Torch algorithm as a class, not an instance, because it is instantiated directly on each data provider organization.

from substrafl.algorithms.pytorch import TorchFedAvgAlgo

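class MyAlgo(TorchFedAvgAlgo):
    def __init__(self):
        # Bundle the model, loss, optimizer, batching and dataset class defined above.
        super().__init__(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            index_generator=index_generator,
            dataset=TorchDataset,
            seed=seed,
        )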

Federated Learning strategies

An FL strategy specifies how to train a model on distributed data. The best-known strategy is Federated Averaging: train a model locally on every organization, aggregate the weight updates from all organizations, then apply the averaged update locally at each organization.

from substrafl.strategies import FedAvg

strategy = FedAvg()
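The aggregation step can be sketched as the textbook FedAvg formula: each parameter's update is averaged across organizations, weighted by the number of local training samples. SubstraFL's `FedAvg` may differ in its details. `fedavg` below is a hypothetical helper on plain lists, and the parameter values are made up. Here both organizations hold 30,000 training images, so the weighted mean reduces to a plain mean.

```python
def fedavg(updates, n_samples):
    """Average per-organization weight updates, weighted by local sample count.

    updates: list of dicts mapping parameter name -> list of floats.
    n_samples: number of training samples each organization used.
    """
    total = sum(n_samples)
    averaged = {}
    for name in updates[0]:
        averaged[name] = [
            sum(w * u[name][i] for u, w in zip(updates, n_samples)) / total
            for i in range(len(updates[0][name]))
        ]
    return averaged


org_1 = {"fc2.bias": [0.2, -0.4]}
org_2 = {"fc2.bias": [0.6, 0.0]}
averaged = fedavg([org_1, org_2], n_samples=[30_000, 30_000])
print(averaged)  # approximately {'fc2.bias': [0.4, -0.2]}
```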

Where to train and where to aggregate

We specify on which data we want to train our model using TrainDataNode objects. Here, we train on the two datasets registered earlier.

The AggregationNode specifies the organization on which the aggregation operation will be computed.

from substrafl.nodes import TrainDataNode
from substrafl.nodes import AggregationNode

aggregation_node = AggregationNode(ALGO_ORG_ID)

train_data_nodes = list()

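for org_id in DATA_PROVIDER_ORGS_ID:

    # Create the Train Data Node (or training task) and save it in a list
    train_data_node = TrainDataNode(
        organization_id=org_id,
        data_manager_key=dataset_keys[org_id],
        data_sample_keys=[train_datasample_keys[org_id]],
    )
    train_data_nodes.append(train_data_node)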

Where and when to test

Following the same logic as the train nodes, we create TestDataNode objects to specify on which data we want to test our model.

The Evaluation Strategy defines where and how often the model is evaluated, using the metric(s) registered in a previous section.

from substrafl.nodes import TestDataNode
from substrafl.evaluation_strategy import EvaluationStrategy

test_data_nodes = list()

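for org_id in DATA_PROVIDER_ORGS_ID:

    # Create the Test Data Node (or testing task) and save it in a list
    test_data_node = TestDataNode(
        organization_id=org_id,
        data_manager_key=dataset_keys[org_id],
        test_data_sample_keys=[test_datasample_keys[org_id]],
        metric_keys=[metric_key],
    )
    test_data_nodes.append(test_data_node)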

# Test at the end of every round
my_eval_strategy = EvaluationStrategy(test_data_nodes=test_data_nodes, rounds=1)
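To make `rounds=1` concrete, the hypothetical helper below lists the rounds at which the test nodes would compute the metric. It assumes the initial model (round 0, before any training) is also evaluated, which the near-chance round_idx 0 rows in the results table later in this example suggest. How `EvaluationStrategy` handles other values of `rounds` is an assumption here.

```python
def evaluated_rounds(num_rounds, every):
    """Hypothetical helper: rounds at which the metric is computed.

    Assumes round 0 (the untrained initial model) is evaluated, then every
    `every`-th round up to and including num_rounds.
    """
    return [r for r in range(num_rounds + 1) if r % every == 0]


print(evaluated_rounds(num_rounds=3, every=1))  # [0, 1, 2, 3]
print(evaluated_rounds(num_rounds=3, every=2))  # [0, 2]
```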

Running the experiment

We now have all the necessary objects to launch our experiment. Please see a summary below of all the objects we created so far:

  • A Client to add or retrieve the assets of our experiment, using their keys to identify them.

  • A Torch algorithm to define the training parameters (optimizer, train function, predict function, etc.).

  • A Federated Strategy, to specify how to train the model on distributed data.

  • Train data nodes to indicate on which data to train.

  • An Evaluation Strategy, to define where and at which frequency we evaluate the model.

  • An AggregationNode, to specify the organization on which the aggregation operation will be computed.

  • The number of rounds, a round being defined by a local training step followed by an aggregation operation.

  • An experiment folder to save a summary of the operations performed.

  • The Dependency to define the libraries on which the experiment needs to run.
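The round structure listed above (local training on every train data node, then aggregation on the aggregation node, repeated for the number of rounds) can be simulated on a toy problem. In this hypothetical sketch, each organization minimizes `(w - target)**2` for its own `target` with a few gradient steps, and the aggregator averages the resulting weights. The targets and learning rate are illustrative. Three rounds are used, as in this experiment's results.

```python
def local_train(w, target, lr=0.1, num_updates=10):
    """A few gradient steps on the toy loss (w - target)**2 at one organization."""
    for _ in range(num_updates):
        w -= lr * 2 * (w - target)
    return w


def run_rounds(w, targets, num_rounds):
    """One round = local training on every train node, then averaging on the aggregator."""
    history = [w]
    for _ in range(num_rounds):
        local_weights = [local_train(w, t) for t in targets]  # on each data provider
        w = sum(local_weights) / len(local_weights)           # on the aggregation node
        history.append(w)
    return history


history = run_rounds(w=0.0, targets=[1.0, 3.0], num_rounds=3)
print([round(h, 3) for h in history])  # moves toward 2.0, the mean of the local optima
```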

from substrafl.experiment import execute_experiment

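# A round is defined by a local training step followed by an aggregation operation
NUM_ROUNDS = 3

# The Dependency object is instantiated in order to install the right libraries in
# the Python environment of each organization.
algo_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "torch==1.11.0"])

compute_plan = execute_experiment(
    client=clients[ALGO_ORG_ID],
    algo=MyAlgo(),
    strategy=strategy,
    train_data_nodes=train_data_nodes,
    evaluation_strategy=my_eval_strategy,
    aggregation_node=aggregation_node,
    num_rounds=NUM_ROUNDS,
    experiment_folder=str(pathlib.Path.cwd() / "tmp" / "experiment_summaries"),
    dependencies=algo_deps,
)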


UserWarning: `transient=True` is ignored in local mode
  warnings.warn("`transient=True` is ignored in local mode")
Compute plan progress: 100%|##########| 27/27 [01:11<00:00,  2.64s/it]

Explore the results

List results

import pandas as pd

performances_df = pd.DataFrame(clients[ALGO_ORG_ID].get_performances(compute_plan.key).dict())
print("\nPerformance Table: \n")
print(performances_df[["worker", "round_idx", "performance"]])


Performance Table:

      worker round_idx  performance
0  MyOrg3MSP         0       0.0990
1  MyOrg4MSP         0       0.0988
2  MyOrg3MSP         1       0.7458
3  MyOrg4MSP         1       0.8300
4  MyOrg3MSP         2       0.8912
5  MyOrg4MSP         2       0.9428
6  MyOrg3MSP         3       0.9144
7  MyOrg4MSP         3       0.9516
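To read the table per organization, group the rows by worker. The sketch below copies the rows of the table above into plain Python and reports each organization's final accuracy and its gain over round 0. Round 0 sits near 0.1, roughly chance level for 10 classes. With the DataFrame itself, `performances_df.pivot(index="round_idx", columns="worker", values="performance")` gives a similar per-organization view.

```python
# Rows copied from the performance table above: (worker, round_idx, performance).
rows = [
    ("MyOrg3MSP", 0, 0.0990), ("MyOrg4MSP", 0, 0.0988),
    ("MyOrg3MSP", 1, 0.7458), ("MyOrg4MSP", 1, 0.8300),
    ("MyOrg3MSP", 2, 0.8912), ("MyOrg4MSP", 2, 0.9428),
    ("MyOrg3MSP", 3, 0.9144), ("MyOrg4MSP", 3, 0.9516),
]

# Group the accuracies by organization, ordered by round.
by_worker = {}
for worker, round_idx, perf in sorted(rows, key=lambda r: r[1]):
    by_worker.setdefault(worker, []).append(perf)

for worker, perfs in by_worker.items():
    gain = perfs[-1] - perfs[0]
    print(f"{worker}: final accuracy {perfs[-1]:.4f}, gain over round 0 {gain:+.4f}")
```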

Plot results

import matplotlib.pyplot as plt

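plt.title("Test dataset results")
plt.xlabel("Rounds")
plt.ylabel("Accuracy")

# One curve per data provider organization.
for org_id in DATA_PROVIDER_ORGS_ID:
    df = performances_df.query(f"worker == '{org_id}'")
    plt.plot(df["round_idx"], df["performance"], label=org_id)

plt.legend(loc="lower right")
plt.show()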
Test dataset results

Download a model

After the experiment, you might want to download your trained model. To do so, you need the source code, so that the model architecture can be reloaded in memory. You can choose the client and the round to download from.

If round_idx is set to None, the last round is selected by default.

from substrafl.model_loading import download_algo_files
from substrafl.model_loading import load_algo

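client_to_download_from = DATA_PROVIDER_ORGS_ID[0]
round_idx = None

algo_files_folder = str(pathlib.Path.cwd() / "tmp" / "algo_files")

# Download the algo files of the chosen organization and round into algo_files_folder.
download_algo_files(
    client=clients[client_to_download_from],
    compute_plan_key=compute_plan.key,
    round_idx=round_idx,
    folder=algo_files_folder,
)

model = load_algo(input_folder=algo_files_folder).model

print(model)


CNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=576, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)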
Total running time of the script: ( 1 minutes 15.037 seconds)
