Using Torch FedAvg on MNIST dataset

This example illustrates the basic usage of SubstraFL: it runs Federated Learning with the Federated Averaging strategy on the MNIST dataset of handwritten digits, using PyTorch. The images are 28x28-pixel grayscale images, and the task is a classification problem: recognizing the digit written on each image.

SubstraFL can be used with any machine learning framework (PyTorch, TensorFlow, scikit-learn, etc.).

However, a specific interface has been developed for PyTorch, which makes writing PyTorch code simpler than with other frameworks. This example uses that PyTorch-specific interface.

This example does not use a deployed Substra platform; it runs in local mode.

To run this example, you have two options:

  • Recommended option: use a hosted Jupyter notebook. With this option you don’t have to install anything, just run the notebook. To access the hosted notebook, scroll down to the bottom of this page and click on the Launch Binder button.

  • Run the example locally. To do that, download and unzip the assets needed to run it into the same directory as this example.

    • Make sure all the required libraries are installed. A requirements.txt file is included in the zip file; run pip install -r requirements.txt to install them.

    • Substra and SubstraFL should already be installed. If not, follow the instructions described here: Installation.

Setup

This example runs with three organizations. Two organizations provide datasets, while a third one provides the algorithm.

In the following code cell, we define the different organizations needed for our FL experiment.

from substra import Client

N_CLIENTS = 3

client_0 = Client(backend_type="subprocess")
client_1 = Client(backend_type="subprocess")
client_2 = Client(backend_type="subprocess")

Every computation will run in subprocess mode, where everything runs locally in Python subprocesses. Other backend_types are:

  • docker mode, where computations run locally in Docker containers

  • remote mode, where computations run remotely (this requires a deployed platform)

To run in remote mode, use the following syntax:

client_remote = Client(url="MY_BACKEND_URL")
client_remote.login(username="my-username", password="my-password")

# Create a dictionary to easily access each client from its human-friendly id
clients = {
    client_0.organization_info().organization_id: client_0,
    client_1.organization_info().organization_id: client_1,
    client_2.organization_info().organization_id: client_2,
}

# Store organization IDs
ORGS_ID = list(clients)
ALGO_ORG_ID = ORGS_ID[0]  # Algo provider is defined as the first organization.
DATA_PROVIDER_ORGS_ID = ORGS_ID[1:]  # Data provider orgs are the last two organizations.

Data and metrics

Data preparation

This section downloads (if needed) the MNIST dataset using the torchvision library. It extracts the images from the raw files and locally creates a folder for each organization.

Each organization will have access to half the training data and half the test data (that is, 30,000 training images and 5,000 test images per organization).

import pathlib
from torch_fedavg_assets.dataset.mnist_dataset import setup_mnist


# Create the temporary directory for generated data
(pathlib.Path.cwd() / "tmp").mkdir(exist_ok=True)
data_path = pathlib.Path.cwd() / "tmp" / "data_mnist"

setup_mnist(data_path, len(DATA_PROVIDER_ORGS_ID))

Out:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/0.26.1/substrafl_examples/get_started/tmp/data_mnist/MNIST/raw/train-images-idx3-ubyte.gz

100%|##########| 9912422/9912422 [00:00<00:00, 92031342.68it/s]
Extracting /home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/0.26.1/substrafl_examples/get_started/tmp/data_mnist/MNIST/raw/train-images-idx3-ubyte.gz to /home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/0.26.1/substrafl_examples/get_started/tmp/data_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/0.26.1/substrafl_examples/get_started/tmp/data_mnist/MNIST/raw/train-labels-idx1-ubyte.gz

100%|##########| 28881/28881 [00:00<00:00, 152563846.13it/s]
Extracting /home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/0.26.1/substrafl_examples/get_started/tmp/data_mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/0.26.1/substrafl_examples/get_started/tmp/data_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/0.26.1/substrafl_examples/get_started/tmp/data_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz

100%|##########| 1648877/1648877 [00:00<00:00, 29648725.66it/s]
Extracting /home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/0.26.1/substrafl_examples/get_started/tmp/data_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/0.26.1/substrafl_examples/get_started/tmp/data_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/0.26.1/substrafl_examples/get_started/tmp/data_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz

100%|##########| 4542/4542 [00:00<00:00, 33598816.17it/s]
Extracting /home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/0.26.1/substrafl_examples/get_started/tmp/data_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/0.26.1/substrafl_examples/get_started/tmp/data_mnist/MNIST/raw
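Each data provider organization ends up with its own org_1 or org_2 folder holding half of MNIST. The actual splitting logic lives in setup_mnist and is not shown here. The sketch below assumes a simple contiguous split, using a hypothetical split_indices helper, to make the 30,000 / 5,000 arithmetic concrete.

```python
# Hypothetical sketch: split sample indices into contiguous, equally sized
# shards, one per data provider organization. This is NOT the code of
# setup_mnist; any remainder from the integer division is simply dropped.


def split_indices(n_samples: int, n_orgs: int) -> list[range]:
    """Return one contiguous range of sample indices per organization."""
    shard_size = n_samples // n_orgs
    return [range(i * shard_size, (i + 1) * shard_size) for i in range(n_orgs)]


# MNIST ships 60,000 training images and 10,000 test images.
train_shards = split_indices(60_000, 2)
test_shards = split_indices(10_000, 2)

for org, (train, test) in enumerate(zip(train_shards, test_shards), start=1):
    print(f"org_{org}: {len(train)} train images, {len(test)} test images")
# org_1: 30000 train images, 5000 test images
# org_2: 30000 train images, 5000 test images
```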

Dataset registration

A Dataset is composed of an opener, which is a Python script that loads the data from files into memory, and a description markdown file. The Dataset object itself does not contain the data. The asset that actually contains the data is the datasample.

A datasample contains a local path to the data. Linking a datasample to a dataset adds that data to the dataset.
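The relationship between datasets and datasamples can be pictured as a toy data model. The classes and function below (ToyDataset, ToyDataSample, add_data_sample) are hypothetical and only mirror the description above. They are not Substra's schemas, which are registered through DatasetSpec and DataSampleSpec in the code further down.

```python
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class ToyDataset:
    """Hypothetical: metadata about how to read data, never the data itself."""
    key: str
    opener: Path       # script that loads the files into memory
    description: Path  # markdown description
    data_sample_keys: list[str] = field(default_factory=list)


@dataclass
class ToyDataSample:
    """Hypothetical: points at local files and is linked to one or more datasets."""
    key: str
    path: Path
    data_manager_keys: list[str]


def add_data_sample(sample: ToyDataSample, datasets: dict[str, ToyDataset]) -> None:
    # Linking a datasample to a dataset is what "adds data" to that dataset.
    for key in sample.data_manager_keys:
        datasets[key].data_sample_keys.append(sample.key)


datasets = {"mnist": ToyDataset("mnist", Path("mnist_opener.py"), Path("description.md"))}
add_data_sample(ToyDataSample("train-1", Path("org_1/train"), ["mnist"]), datasets)
add_data_sample(ToyDataSample("test-1", Path("org_1/test"), ["mnist"]), datasets)
print(datasets["mnist"].data_sample_keys)  # ['train-1', 'test-1']
```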

Data privacy is a key concept for Federated Learning experiments. That is why we set Permissions for Assets to determine how each organization can access a specific asset. You can read more about permissions in the User Guide.

Note that metadata such as the assets’ creation date and the asset owner are visible to all the organizations of a network.

from substra.sdk.schemas import DatasetSpec
from substra.sdk.schemas import Permissions
from substra.sdk.schemas import DataSampleSpec

assets_directory = pathlib.Path.cwd() / "torch_fedavg_assets"
dataset_keys = {}
train_datasample_keys = {}
test_datasample_keys = {}

for i, org_id in enumerate(DATA_PROVIDER_ORGS_ID):
    client = clients[org_id]

    permissions_dataset = Permissions(public=False, authorized_ids=[ALGO_ORG_ID])

    # DatasetSpec is the specification of a dataset. It makes sure every field
    # is well-defined, and that our dataset is ready to be registered.
    # The real dataset object is created in the add_dataset method.

    dataset = DatasetSpec(
        name="MNIST",
        type="npy",
        data_opener=assets_directory / "dataset" / "mnist_opener.py",
        description=assets_directory / "dataset" / "description.md",
        permissions=permissions_dataset,
        logs_permission=permissions_dataset,
    )
    dataset_keys[org_id] = client.add_dataset(dataset)
    assert dataset_keys[org_id], "Missing dataset key"

    # Add the training data on each organization.
    data_sample = DataSampleSpec(
        data_manager_keys=[dataset_keys[org_id]],
        path=data_path / f"org_{i+1}" / "train",
    )
    train_datasample_keys[org_id] = client.add_data_sample(data_sample)

    # Add the testing data on each organization.
    data_sample = DataSampleSpec(
        data_manager_keys=[dataset_keys[org_id]],
        path=data_path / f"org_{i+1}" / "test",
    )
    test_datasample_keys[org_id] = client.add_data_sample(data_sample)
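In the loop above, permissions_dataset is private (public=False) and lists only the algorithm provider in authorized_ids. The simplified check below uses the hypothetical names ToyPermissions and can_access to show what this setting means for each organization. It assumes that the owner of an asset always keeps access to it. Refer to the permissions section of the User Guide for Substra's actual rules.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyPermissions:
    """Simplified, hypothetical stand-in for substra's Permissions."""
    public: bool
    authorized_ids: tuple[str, ...] = ()


def can_access(permissions: ToyPermissions, org_id: str, owner_id: str) -> bool:
    # Assumption: the asset owner always keeps access to its own asset.
    if org_id == owner_id:
        return True
    return permissions.public or org_id in permissions.authorized_ids


# Hypothetical organization IDs mirroring this example's roles.
ALGO_ORG, DATA_ORG_1, DATA_ORG_2 = "algo-org", "data-org-1", "data-org-2"
dataset_permissions = ToyPermissions(public=False, authorized_ids=(ALGO_ORG,))

for org in (ALGO_ORG, DATA_ORG_1, DATA_ORG_2):
    print(org, can_access(dataset_permissions, org, owner_id=DATA_ORG_1))
# algo-org True
# data-org-1 True
# data-org-2 False
```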

Metric registration

A metric is a function used to evaluate the performance of your model on one or several datasamples.

To add a metric, you need to define a function that computes and returns a performance score from the datasamples (as returned by the opener) and the predictions_path (to be loaded within the function).

When using a Torch SubstraFL algorithm, the predictions are saved in the predict function in numpy format so that you can simply load them using np.load.

from sklearn.metrics import accuracy_score
import numpy as np


def accuracy(datasamples, predictions_path):
    y_true = datasamples["labels"]
    y_pred = np.load(predictions_path)

    return accuracy_score(y_true, np.argmax(y_pred, axis=1))
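The call np.argmax(y_pred, axis=1) turns each row of per-class scores into a predicted digit. accuracy_score then returns the fraction of predictions that match the labels. The pure-Python sketch below does the same computation on a few made-up rows, without NumPy or scikit-learn. The name accuracy_from_scores is hypothetical.

```python
def argmax(row):
    # Index of the largest score; ties resolve to the first index, like np.argmax.
    return max(range(len(row)), key=row.__getitem__)


def accuracy_from_scores(y_true, y_scores):
    """Fraction of rows whose highest-scoring class equals the true label."""
    y_pred = [argmax(row) for row in y_scores]
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


# Toy 3-class example with made-up log-probabilities (not MNIST outputs).
scores = [
    [-0.1, -2.5, -3.0],  # predicts 0
    [-2.0, -0.2, -2.2],  # predicts 1
    [-1.5, -0.4, -2.0],  # predicts 1
    [-3.0, -2.9, -0.1],  # predicts 2
]
labels = [0, 1, 2, 2]
print(accuracy_from_scores(labels, scores))  # 0.75
```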

We also need to specify the third-party dependencies required to compute the metric. The Dependency object is instantiated in order to install the right libraries in the Python environment of each organization.

As for the dataset, we also define Permissions.

from substrafl.dependency import Dependency

metric_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "scikit-learn==1.1.1"])

permissions_metric = Permissions(public=False, authorized_ids=[ALGO_ORG_ID] + DATA_PROVIDER_ORGS_ID)

After defining the metric, its dependencies, and its permissions, we use the add_metric function to register the metric. This metric will be used on the test datasamples to evaluate the model performance.

from substrafl.remote.register import add_metric


metric_key = add_metric(
    client=clients[ALGO_ORG_ID],
    metric_function=accuracy,
    permissions=permissions_metric,
    dependencies=metric_deps,
)

Specify the machine learning components

This section uses the PyTorch-based SubstraFL API to simplify the definition of machine learning components. However, SubstraFL is compatible with any machine learning framework.

In this section, you will:

  • Register a model and its dependencies

  • Specify the federated learning strategy

  • Specify the training and aggregation nodes

  • Specify the test nodes

  • Actually run the computations

Model definition

We use a classic PyTorch CNN as the model to train. The model structure is defined by the user independently of SubstraFL.

import torch
from torch import nn
import torch.nn.functional as F

seed = 42
torch.manual_seed(seed)


class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(3 * 3 * 64, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        # Dropout follows the module mode: active after model.train(),
        # disabled after model.eval().
        x = F.relu(self.conv1(x))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        x = F.dropout(x, p=0.5, training=self.training)
        x = x.view(-1, 3 * 3 * 64)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
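The input size of fc1, 3 * 3 * 64, follows from how each layer shrinks the 28x28 input. The short computation below applies the standard output-size formula for convolutions and pooling, (size + 2 * padding - kernel) // stride + 1. The result matches the in_features=576 shown when the model is printed at the end of this example.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Output-size formula shared by Conv2d and max pooling (dilation 1).
    return (size + 2 * padding - kernel) // stride + 1


size = 28
size = out_size(size, 5)            # conv1: 28 -> 24
size = out_size(size, 5)            # conv2: 24 -> 20
size = out_size(size, 2, stride=2)  # max_pool2d(..., 2): 20 -> 10
size = out_size(size, 5)            # conv3: 10 -> 6
size = out_size(size, 2, stride=2)  # max_pool2d(..., 2): 6 -> 3
flat_features = size * size * 64    # 64 channels out of conv3
print(size, flat_features)  # 3 576
```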

Specifying how much data to train on

To specify how much data to train on at each round, we use the index_generator object. We specify the batch size and the number of batches (named num_updates) to consider for each round. See Index Generator for more details.

from substrafl.index_generator import NpIndexGenerator

# Number of model updates between each FL strategy aggregation.
NUM_UPDATES = 100

# Number of samples per update.
BATCH_SIZE = 32

index_generator = NpIndexGenerator(
    batch_size=BATCH_SIZE,
    num_updates=NUM_UPDATES,
)
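With these settings, each organization performs NUM_UPDATES = 100 optimizer steps of BATCH_SIZE = 32 samples per round, which is 3,200 samples. The class below is a hypothetical stand-in written only to illustrate that contract. NpIndexGenerator's actual shuffling and epoch handling may differ; see the Index Generator documentation for those details.

```python
import random
from collections import deque


class ToyIndexGenerator:
    """Hypothetical stand-in for an index generator (not NpIndexGenerator)."""

    def __init__(self, n_samples, batch_size, num_updates, seed=42):
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.num_updates = num_updates
        self._rng = random.Random(seed)
        self._queue = deque()  # unused indices carry over to the next round

    def _refill(self):
        # Append one freshly shuffled pass (epoch) over all sample indices.
        order = list(range(self.n_samples))
        self._rng.shuffle(order)
        self._queue.extend(order)

    def round_batches(self):
        """Yield the num_updates batches used in one round of local training."""
        for _ in range(self.num_updates):
            while len(self._queue) < self.batch_size:
                self._refill()
            yield [self._queue.popleft() for _ in range(self.batch_size)]


gen = ToyIndexGenerator(n_samples=30_000, batch_size=32, num_updates=100)
batches = list(gen.round_batches())
samples_per_round = sum(len(b) for b in batches)
print(len(batches), samples_per_round)  # 100 3200
```

With 3 rounds (the NUM_ROUNDS value set when the experiment is launched), each organization processes 3 * 3,200 = 9,600 samples, roughly a third of its 30,000 local training images.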

Torch Dataset definition

This torch Dataset is used to preprocess the data in its __getitem__ function.

This torch Dataset must have a specific __init__ signature: (self, datasamples, is_inference).

The __getitem__ function is expected to return (inputs, outputs) if is_inference is False, and only the inputs otherwise. This behavior can be changed by overriding the _local_train or predict methods.

class TorchDataset(torch.utils.data.Dataset):
    def __init__(self, datasamples, is_inference: bool):
        self.x = datasamples["images"]
        self.y = datasamples["labels"]
        self.is_inference = is_inference

    def __getitem__(self, idx):
        # Add a channel dimension and scale pixel values to [0, 1].
        x = torch.FloatTensor(self.x[idx][None, ...]) / 255

        if self.is_inference:
            return x

        # One-hot encode the label as a float vector of size 10.
        y = torch.tensor(self.y[idx]).type(torch.int64)
        y = F.one_hot(y, 10).type(torch.float32)
        return x, y

    def __len__(self):
        return len(self.x)
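Here is a plain-list version of the same preprocessing, to show the __getitem__ contract outside of PyTorch. The function preprocess is a hypothetical helper, not part of SubstraFL. It scales pixels to [0, 1] and adds a leading channel axis. Unless it is called in inference mode, it also returns a one-hot float target.

```python
def preprocess(image, label=None, num_classes=10):
    """Hypothetical plain-Python mirror of TorchDataset.__getitem__.

    image: 2D list of 0-255 pixel values; label: int digit, or None at inference.
    """
    # Scale to [0, 1] and add a leading channel axis, like x[None, ...] / 255.
    x = [[[pixel / 255 for pixel in row] for row in image]]
    if label is None:  # is_inference=True: inputs only
        return x
    y = [0.0] * num_classes
    y[label] = 1.0  # one-hot float target, like F.one_hot(...).type(torch.float32)
    return x, y


x, y = preprocess([[0, 255], [51, 102]], label=3)
print(x)  # [[[0.0, 1.0], [0.2, 0.4]]]
print(y)  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```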

SubstraFL algo definition

A SubstraFL Algo gathers all the elements defined above that run locally in each organization. This is the only SubstraFL object that is framework-specific (here, PyTorch-specific).

The TorchDataset is passed to the Torch algorithm as a class, not as an instance, because it is instantiated directly on each data provider organization.

from substrafl.algorithms.pytorch import TorchFedAvgAlgo


class MyAlgo(TorchFedAvgAlgo):
    def __init__(self):
        super().__init__(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            index_generator=index_generator,
            dataset=TorchDataset,
            seed=seed,
        )

Federated Learning strategies

An FL strategy specifies how to train a model on distributed data. The best-known strategy is Federated Averaging: train a model locally on every organization, aggregate the weight updates from all organizations, then apply the averaged update locally at each organization.

from substrafl.strategies import FedAvg

strategy = FedAvg(algo=MyAlgo())
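The averaging step can be sketched in a few lines. This version weights each organization's update by its number of training samples, as in the original FedAvg formulation. SubstraFL's FedAvg implementation may differ in its details, so treat fedavg and apply_update as hypothetical helpers that work on plain lists rather than tensors.

```python
def fedavg(updates, n_samples):
    """Weighted average of per-organization weight updates (hypothetical helper).

    updates: one dict per organization mapping parameter name -> list of floats.
    n_samples: number of training samples used by each organization.
    """
    total = sum(n_samples)
    return {
        name: [
            sum(w * u[name][i] for w, u in zip(n_samples, updates)) / total
            for i in range(len(updates[0][name]))
        ]
        for name in updates[0]
    }


def apply_update(params, update):
    """Each organization adds the averaged update to its local parameters."""
    return {name: [p + d for p, d in zip(params[name], update[name])] for name in params}


# Made-up updates from two organizations with different amounts of data.
avg = fedavg([{"w": [1.0, -2.0]}, {"w": [3.0, 2.0]}], n_samples=[30_000, 10_000])
print(avg)                                  # {'w': [1.5, -1.0]}
print(apply_update({"w": [0.5, 0.5]}, avg))  # {'w': [2.0, -0.5]}
```

In this example, both organizations hold 30,000 training samples, so the weighted average reduces to a plain mean of the two updates.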

Where to train and where to aggregate

We specify on which data we want to train our model, using the TrainDataNode object. Here we train on the two datasets that we have registered earlier.

The AggregationNode specifies the organization on which the aggregation operation will be computed.

from substrafl.nodes import TrainDataNode
from substrafl.nodes import AggregationNode


aggregation_node = AggregationNode(ALGO_ORG_ID)

# Create the Train Data Nodes (or training tasks) and save them in a list
train_data_nodes = [
    TrainDataNode(
        organization_id=org_id,
        data_manager_key=dataset_keys[org_id],
        data_sample_keys=[train_datasample_keys[org_id]],
    )
    for org_id in DATA_PROVIDER_ORGS_ID
]

Where and when to test

With the same logic as the train nodes, we create TestDataNode objects to specify on which data we want to test our model.

The Evaluation Strategy defines where and at which frequency we evaluate the model, using the given metric(s) that you registered in a previous section.

from substrafl.nodes import TestDataNode
from substrafl.evaluation_strategy import EvaluationStrategy

# Create the Test Data Nodes (or testing tasks) and save them in a list
test_data_nodes = [
    TestDataNode(
        organization_id=org_id,
        data_manager_key=dataset_keys[org_id],
        test_data_sample_keys=[test_datasample_keys[org_id]],
        metric_keys=[metric_key],
    )
    for org_id in DATA_PROVIDER_ORGS_ID
]


# Test at the end of every round
my_eval_strategy = EvaluationStrategy(test_data_nodes=test_data_nodes, eval_frequency=1)
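With eval_frequency=1, the model is evaluated at every round, so the performance table later in this example has one row per organization for each of rounds 0 to 3. The helper below is a hypothetical sketch of that schedule. It assumes evaluation happens on the initial model (round 0) and then on every eval_frequency-th round. Check the EvaluationStrategy documentation for the exact rule, for example how the last round is handled.

```python
def evaluated_rounds(num_rounds: int, eval_frequency: int) -> list[int]:
    """Assumed schedule: round 0 (initial model), then every eval_frequency-th round."""
    return [r for r in range(num_rounds + 1) if r % eval_frequency == 0]


print(evaluated_rounds(3, 1))  # [0, 1, 2, 3]
print(evaluated_rounds(3, 2))  # [0, 2]
```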

Running the experiment

We now have all the necessary objects to launch our experiment. Here is a summary of all the objects we have created so far:

  • A Client to add or retrieve the assets of our experiment, using their keys to identify them.

  • A Torch algorithm to define the training parameters (optimizer, train function, predict function, etc.).

  • A Federated Strategy, to specify how to train the model on distributed data.

  • Train data nodes to indicate on which data to train.

  • An Evaluation Strategy, to define where and at which frequency we evaluate the model.

  • An AggregationNode, to specify the organization on which the aggregation operation will be computed.

  • The number of rounds, a round being defined by a local training step followed by an aggregation operation.

  • An experiment folder to save a summary of the operations performed.

  • The Dependency to define the libraries on which the experiment needs to run.

from substrafl.experiment import execute_experiment

# A round is defined by a local training step followed by an aggregation operation
NUM_ROUNDS = 3

# The Dependency object is instantiated in order to install the right libraries in
# the Python environment of each organization.
algo_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "torch==1.11.0"])

compute_plan = execute_experiment(
    client=clients[ALGO_ORG_ID],
    strategy=strategy,
    train_data_nodes=train_data_nodes,
    evaluation_strategy=my_eval_strategy,
    aggregation_node=aggregation_node,
    num_rounds=NUM_ROUNDS,
    experiment_folder=str(pathlib.Path.cwd() / "tmp" / "experiment_summaries"),
    dependencies=algo_deps,
)

Out:

Compute plan progress:   0%|          | 0/29 [00:00<?, ?it/s]/home/docs/checkouts/readthedocs.org/user_builds/owkin-substra-documentation/checkouts/0.26.1/docs/src/substra/substra/sdk/backends/local/backend.py:580: UserWarning: `transient=True` is ignored in local mode
  warnings.warn("`transient=True` is ignored in local mode")

Compute plan progress: 100%|##########| 29/29 [01:28<00:00,  3.04s/it]

The compute plan created is composed of 29 tasks:

  • For each local training step, we create 3 tasks per organization: training + prediction + evaluation -> 3 tasks.

  • We are training on 2 data organizations; for each round, we have 3 * 2 local tasks + 1 aggregation task -> 7 tasks.

  • We are training for 3 rounds: 3 * 7 -> 21 tasks.

  • Before the first local training step, there is an initialization task on each data organization -> 2 tasks.

  • After the last aggregation step, there are three more tasks on each organization (applying the last updates from the aggregator + prediction + evaluation): 2 * 3 -> 6 tasks. In total: 21 + 2 + 6 -> 29 tasks, matching the progress bar above.
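The same count can be written as a small function. The init_tasks_per_org parameter models one initialization task per data organization, which is an assumption. With init_tasks_per_org set to 0, the function gives 27 tasks. With 1, it gives the 29 tasks reported by the progress bar.

```python
def count_tasks(n_orgs: int, n_rounds: int, init_tasks_per_org: int = 1) -> int:
    """Count compute plan tasks for the FedAvg setup of this example."""
    init = init_tasks_per_org * n_orgs  # assumed: one initialization task per data org
    per_round = 3 * n_orgs + 1          # train + predict + evaluate per org, plus 1 aggregation
    final = 3 * n_orgs                  # apply last update + predict + evaluate per org
    return init + n_rounds * per_round + final


print(count_tasks(2, 3, init_tasks_per_org=0))  # 27
print(count_tasks(2, 3))                        # 29
```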

Explore the results

List results

import pandas as pd

performances_df = pd.DataFrame(clients[ALGO_ORG_ID].get_performances(compute_plan.key).dict())
print("\nPerformance Table: \n")
print(performances_df[["worker", "round_idx", "performance"]])

Out:

Performance Table:

      worker round_idx  performance
0  MyOrg6MSP         0       0.1032
1  MyOrg7MSP         0       0.1130
2  MyOrg6MSP         1       0.7426
3  MyOrg7MSP         1       0.8272
4  MyOrg6MSP         2       0.8896
5  MyOrg7MSP         2       0.9410
6  MyOrg6MSP         3       0.9184
7  MyOrg7MSP         3       0.9512
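To compare organizations at the end of training, you can keep the last round of each worker. The sketch below does this in plain Python on the rows of the table above; last_round_scores is a hypothetical helper. With pandas, performances_df.sort_values("round_idx").groupby("worker")["performance"].last() gives the same per-worker values.

```python
# Rows copied from the performance table above: (worker, round_idx, accuracy).
rows = [
    ("MyOrg6MSP", 0, 0.1032),
    ("MyOrg7MSP", 0, 0.1130),
    ("MyOrg6MSP", 1, 0.7426),
    ("MyOrg7MSP", 1, 0.8272),
    ("MyOrg6MSP", 2, 0.8896),
    ("MyOrg7MSP", 2, 0.9410),
    ("MyOrg6MSP", 3, 0.9184),
    ("MyOrg7MSP", 3, 0.9512),
]


def last_round_scores(rows):
    """Return each worker's accuracy at its highest round index."""
    latest = {}
    for worker, round_idx, perf in rows:
        if worker not in latest or round_idx > latest[worker][0]:
            latest[worker] = (round_idx, perf)
    return {worker: perf for worker, (_, perf) in latest.items()}


final = last_round_scores(rows)
print(final)  # {'MyOrg6MSP': 0.9184, 'MyOrg7MSP': 0.9512}
mean_final = sum(final.values()) / len(final)
print(round(mean_final, 4))  # 0.9348
```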

Plot results

import matplotlib.pyplot as plt

plt.title("Test dataset results")
plt.xlabel("Rounds")
plt.ylabel("Accuracy")

for org_id in DATA_PROVIDER_ORGS_ID:
    df = performances_df[performances_df["worker"] == org_id]
    plt.plot(df["round_idx"], df["performance"], label=org_id)

plt.legend(loc="lower right")
plt.show()
[Figure: Test dataset results, accuracy per round for each data provider organization]

Download a model

After the experiment, you might want to download your trained model. To do so, you need the source code in order to reload your model architecture into memory. You can choose the client to download from and the round to download.

If round_idx is set to None, the last round will be selected by default.

from substrafl.model_loading import download_algo_files
from substrafl.model_loading import load_algo

client_to_download_from = DATA_PROVIDER_ORGS_ID[0]
round_idx = None

algo_files_folder = str(pathlib.Path.cwd() / "tmp" / "algo_files")

download_algo_files(
    client=clients[client_to_download_from],
    compute_plan_key=compute_plan.key,
    round_idx=round_idx,
    dest_folder=algo_files_folder,
)

model = load_algo(input_folder=algo_files_folder).model

print(model)

Out:

CNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=576, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)

Total running time of the script: ( 1 minutes 30.914 seconds)
