Federated Analytics on the diabetes dataset

This example demonstrates how to use the flexibility of the SubstraFL library and the base class ComputePlanBuilder to do Federated Analytics. It reproduces the diabetes example from the Substra SDK example section using SubstraFL. If you are new to SubstraFL, we recommend starting with the MNIST Example to learn how to use the library in its simplest configuration first.

We use the Diabetes dataset available from the Scikit-Learn dataset module. This dataset contains medical information such as Age, Sex or Blood pressure. The goal of this example is to compute some analytics such as Age mean, Blood pressure standard deviation or Sex percentage.

We simulate having two different data organizations, and a third organization which wants to compute aggregated analytics without having access to the raw data. The example here runs everything locally; however there is only one parameter to change to run it on a real network.


This example is provided for illustration only. In real life, you should be careful not to accidentally leak private information when doing Federated Analytics. For example, if a column contains very similar values, sharing its mean and its standard deviation is functionally equivalent to sharing the content of the column. It is strongly recommended to consider the potential security risks in your use case and to act accordingly. It is possible to use other privacy-preserving techniques, such as Differential Privacy, in addition to Substra. Because the focus of this example is Substra's capabilities, and for the sake of simplicity, such safeguards are not implemented here.
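
To make the leak concrete, here is a minimal sketch using only the Python standard library and made-up values (they are not taken from the diabetes dataset). It covers the extreme case of a constant column: the standard deviation is zero, so the mean alone reveals every individual value.

```python
import statistics

# Hypothetical column where every record holds the same value.
column = [37.0, 37.0, 37.0, 37.0]

mean = statistics.fmean(column)
std = statistics.pstdev(column)

# With a standard deviation of zero, every value must equal the mean,
# so publishing (mean, std, count) discloses the whole column.
reconstructed = [mean] * len(column) if std == 0 else None

print(mean, std, reconstructed == column)
```

Columns with a small but non-zero spread leak less exactly, but the same reasoning bounds each value tightly around the mean.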

To run this example, you have two options:

  • Recommended option: use a hosted Jupyter notebook. With this option you don’t have to install anything, just run the notebook. To access the hosted notebook, scroll down to the bottom of this page and click on the Launch Binder button.

  • Run the example locally. To do that you need to download and unzip the assets needed to run it in the same directory as this example.

    Please ensure that all the libraries are installed. A requirements.txt file is included in the zip file; run the command pip install -r requirements.txt to install them.

Instantiating the Substra clients

We work with three different organizations. Two organizations provide data, and a third one performs Federated Analytics to compute aggregated statistics without having access to the raw datasets.

This example runs in local mode, simulating a federated learning experiment.

In the following code cell, we define the different organizations needed for our FL experiment.

from substra import Client

# Choose the subprocess mode to locally simulate the FL process
client_0 = Client(client_name="org-1")
client_1 = Client(client_name="org-2")
client_2 = Client(client_name="org-3")

# Create a dictionary to easily access each client from its human-friendly id
clients = {
    client_0.organization_info().organization_id: client_0,
    client_1.organization_info().organization_id: client_1,
    client_2.organization_info().organization_id: client_2,
}

# Store organization IDs
ORGS_ID = list(clients)

# The provider of the functions for computing analytics is defined as the first organization.
ANALYTICS_PROVIDER_ORG_ID = ORGS_ID[0]
# Data providers orgs are the two last organizations.
DATA_PROVIDER_ORGS_ID = ORGS_ID[1:]

Prepare the data

The function setup_diabetes downloads the diabetes dataset if needed and splits it in two to simulate a federated setup. Each data organization has access to one chunk of the dataset.

import pathlib

from diabetes_substrafl_assets.dataset.diabetes_substrafl_dataset import setup_diabetes

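data_path = pathlib.Path.cwd() / "tmp" / "data_diabetes"
data_path.mkdir(parents=True, exist_ok=True)

# Download the dataset if needed and split it between the two data organizations
setup_diabetes(data_path=data_path)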


Registering data samples and dataset

Every asset is created according to predefined specifications imported from substra.sdk.schemas. To register assets, schemas are first instantiated, and the specs are then registered, which generates the real assets.

Permissions are defined when registering assets. In a nutshell:

  • Data cannot be seen once it’s registered on the platform.

  • Metadata are visible by all the users of a network.

  • Permissions allow you to execute a function on a certain dataset.

Next, we need to define the asset directory. You should have already downloaded the assets folder as stated above.

A dataset represents the data in Substra. It contains some metadata and an opener, a script used to load the data from files into memory. You can find more details about datasets in the API reference.

from substra.sdk.schemas import DataSampleSpec
from substra.sdk.schemas import DatasetSpec
from substra.sdk.schemas import Permissions

assets_directory = pathlib.Path.cwd() / "diabetes_substrafl_assets"
assert assets_directory.is_dir(), """Did not find the asset directory,
a directory called 'diabetes_substrafl_assets' is expected in the same location as this file"""

permissions_dataset = Permissions(public=False, authorized_ids=[ANALYTICS_PROVIDER_ORG_ID])

dataset = DatasetSpec(
    name="Diabetes dataset",
    data_opener=assets_directory / "dataset" / "diabetes_substrafl_opener.py",
    description=data_path / "description.md",
    permissions=permissions_dataset,
    logs_permission=permissions_dataset,
)

# We register the dataset for each organization
dataset_keys = {client_id: clients[client_id].add_dataset(dataset) for client_id in DATA_PROVIDER_ORGS_ID}

for client_id, key in dataset_keys.items():
    print(f"Dataset key for {client_id}: {key}")


Dataset key for MyOrg12MSP: 62146226-a39a-4c5e-8c87-a784c7d8484d
Dataset key for MyOrg13MSP: 9e2afe57-f896-493e-b12a-03e59c2503fd

The dataset object itself is an empty shell. Data samples are needed in order to add actual data. A data sample contains subfolders containing a single data file like a CSV and the key identifying the dataset it is linked to.

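datasample_keys = {
    org_id: clients[org_id].add_data_sample(
        DataSampleSpec(
            # Depending on the Substra version, a test_only flag may also be required here.
            data_manager_keys=[dataset_keys[org_id]],
            path=data_path / f"org_{i + 1}",
        ),
        local=True,
    )
    for i, org_id in enumerate(DATA_PROVIDER_ORGS_ID)
}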

The flexibility of the ComputePlanBuilder class

This example aims at explaining how to use the Compute Plan Builder class, and how to use the full power of the flexibility it provides.

Before starting, we need to have in mind that a federated computation can be represented as a graph of tasks. Some of these tasks need data to be executed (training tasks) and others are here to aggregate local results (aggregation tasks).

Substra does not store an explicit definition of this graph; instead, it gives the user full flexibility to define the compute plan (or computation graph) they need, by linking a task to its parents.
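
As a rough illustration of "linking a task to its parents", the sketch below models a two-pass plan as a plain dictionary and lets the standard library order it. The task names are hypothetical, and this is not how Substra stores or schedules tasks; it only shows that parent links are enough to recover a valid execution order.

```python
from graphlib import TopologicalSorter

# Each task maps to the set of tasks whose outputs it consumes.
# Task names are made up for this sketch.
parents = {
    "first_order_org_2": set(),
    "first_order_org_3": set(),
    "aggregate_first_order": {"first_order_org_2", "first_order_org_3"},
    "second_order_org_2": {"first_order_org_2", "aggregate_first_order"},
    "second_order_org_3": {"first_order_org_3", "aggregate_first_order"},
    "aggregate_second_order": {"second_order_org_2", "second_order_org_3"},
}

# static_order() yields every task after all of its parents.
order = list(TopologicalSorter(parents).static_order())
print(order)
```

The local tasks of a given pass have no dependency on each other, so they may run in any order (or in parallel) before the aggregation that consumes them.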

To create this graph of computations, SubstraFL provides the Node abstraction. A Node assigns to an organization (aka a Client) tasks of a given type. The type of the Node depends on the type of tasks we want to run on this organization (training or aggregation tasks).

An organization (aka Client) without data can host an Aggregation node. We will use the Aggregation node object to compute the aggregated analytics.

An organization (aka a Client) containing the data samples can host a Train data node. Each node will only have access to data from the organization hosting it. These data samples must be instantiated with the right permissions to be processed by the given Client.

from substrafl.nodes import TrainDataNode
from substrafl.nodes import AggregationNode

aggregation_node = AggregationNode(ANALYTICS_PROVIDER_ORG_ID)

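train_data_nodes = [
    TrainDataNode(
        organization_id=org_id,
        data_manager_key=dataset_keys[org_id],
        data_sample_keys=[datasample_keys[org_id]],
    )
    for org_id in DATA_PROVIDER_ORGS_ID
]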

The Compute Plan Builder is an abstract class that asks the user to implement only three methods:

  • build_compute_plan(...)

  • load_local_state(...)

  • save_local_state(...)

The build_compute_plan method is essential to create the graph of the compute plan that will be executed on Substra. Using the different Nodes we created, we will update their states by applying user-defined methods.

These methods are passed as arguments to the Node using its update_states method.

import json
from collections import defaultdict
from typing import Dict, List

import numpy as np
import pandas as pd

from substrafl import ComputePlanBuilder
from substrafl.remote import remote_data, remote

class Analytics(ComputePlanBuilder):
    def __init__(self):
        super().__init__()
        self.first_order_aggregated_state = {}
        self.second_order_aggregated_state = {}

    @remote_data
    def local_first_order_computation(self, datasamples: pd.DataFrame, shared_state=None):
        """Compute from the data samples, expected to be a pandas dataframe,
        the sums and counts of each column of the data frame. The sums are stored
        under the ``means`` key and become means once aggregated.
        These datasamples are the output of the ``get_data`` function defined
        in the ``diabetes_substrafl_opener.py`` file, available in the asset
        folder downloaded at the beginning of the example.

        The signature of a function decorated by @remote_data must contain
        the datasamples and the shared_state arguments.

        Args:
            datasamples (pd.DataFrame): Pandas dataframe provided by the opener.
            shared_state (None, optional): Unused here as this function only
                uses local information already present in the datasamples.
                Defaults to None.

        Returns:
            dict: dictionary containing the local information on means, counts
                and number of samples. This dict will be used as a state to be
                shared with an AggregationNode in order to compute the aggregation
                of the different analytics.
        """
        df = datasamples
        states = {
            "n_samples": len(df),
            "means": df.select_dtypes(include=np.number).sum().to_dict(),
            "counts": {
                name: series.value_counts().to_dict() for name, series in df.select_dtypes(include="category").items()
            },
        }
        return states

    @remote_data
    def local_second_order_computation(self, datasamples: pd.DataFrame, shared_state: Dict):
        """This function will use the output of the ``aggregation`` function to compute
        locally the sum of squared deviations from the aggregated means, from which
        the standard deviation of the different columns is derived.

        Args:
            datasamples (pd.DataFrame): Pandas dataframe provided by the opener.
            shared_state (Dict): Output of a first order analytics computation,
                that must contain the means.

        Returns:
            Dict: dictionary containing the local information on standard deviation
                and number of samples. This dict will be used as a state to be shared
                with an AggregationNode in order to compute the aggregation of the
                different analytics.
        """
        df = datasamples
        means = pd.Series(shared_state["means"])
        states = {
            "n_samples": len(df),
            # Sum of squared deviations from the aggregated means, one entry per numerical column.
            "std": np.power(df.select_dtypes(include=np.number) - means, 2).sum(),
        }
        return states

    @remote
    def aggregation(self, shared_states: List[Dict]):
        """Aggregation function that receives a list of locally computed analytics in order to
        aggregate them.
        The aggregation will be a weighted average using "n_samples" as weight coefficient.

        Args:
            shared_states (List[Dict]): list of dictionaries containing a field "n_samples",
                and the analytics to aggregate in separate fields.

        Returns:
            Dict: dictionary containing the aggregated analytics.
        """
        total_len = 0
        for state in shared_states:
            total_len += state["n_samples"]

        aggregated_values = defaultdict(lambda: defaultdict(float))
        for state in shared_states:
            for analytics_name, col_dict in state.items():
                if analytics_name == "n_samples":
                    # already aggregated in total_len
                    continue
                for col_name, v in col_dict.items():
                    if isinstance(v, dict):
                        # this column is categorical and v is a dict over
                        # the different modalities
                        if not aggregated_values[analytics_name][col_name]:
                            aggregated_values[analytics_name][col_name] = defaultdict(float)
                        for modality, vv in v.items():
                            aggregated_values[analytics_name][col_name][modality] += vv / total_len
                    else:
                        # this is a numerical column and v is numerical
                        aggregated_values[analytics_name][col_name] += v / total_len

        # transform default_dict to regular dict
        aggregated_values = json.loads(json.dumps(aggregated_values))

        return aggregated_values

    def build_compute_plan(
        self,
        train_data_nodes: List[TrainDataNode],
        aggregation_node: AggregationNode,
        num_rounds=None,
        evaluation_strategy=None,
        clean_models=False,
    ):
        """Method to build and link the different computations to execute with each other.
        We will use the ``update_states`` method of the nodes given as input to choose which
        method to apply.
        For our example, we will only use TrainDataNodes and AggregationNodes.

        Args:
            train_data_nodes (List[TrainDataNode]): Nodes linked to the data samples on which
                to compute analytics.
            aggregation_node (AggregationNode): Node on which to compute the aggregation
                of the analytics extracted from the train_data_nodes.
            num_rounds Optional[int]: Number of rounds used to iterate on the recurrent part of
                the compute plan. Defaults to None.
            evaluation_strategy Optional[substrafl.EvaluationStrategy]: Object storing the
                TestDataNode. Unused in this example. Defaults to None.
            clean_models (bool): Clean the intermediary models of this round on the
                Substra platform. Defaults to False.
        """
        first_order_shared_states = []
        local_states = {}

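        for node in train_data_nodes:
            # Call local_first_order_computation on each train data node
            next_local_state, next_shared_state = node.update_states(
                self.local_first_order_computation(
                    node.data_sample_keys,
                    shared_state=None,
                    _algo_name=f"Computing first order means with {self.__class__.__name__}",
                ),
                local_state=None,
                round_idx=0,
                authorized_ids=set([node.organization_id]),
                aggregation_id=aggregation_node.organization_id,
            )

            # All local analytics are stored in the first_order_shared_states,
            # given as input to the aggregation method.
            first_order_shared_states.append(next_shared_state)
            local_states[node.organization_id] = next_local_state

        # Call the aggregation method on the first_order_shared_states
        self.first_order_aggregated_state = aggregation_node.update_states(
            self.aggregation(
                shared_states=first_order_shared_states,
                _algo_name="Aggregating first order",
            ),
            round_idx=0,
            authorized_ids=set([train_data_node.organization_id for train_data_node in train_data_nodes]),
        )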

        second_order_shared_states = []

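        for node in train_data_nodes:
            # Call local_second_order_computation on each train data node
            _, next_shared_state = node.update_states(
                self.local_second_order_computation(
                    node.data_sample_keys,
                    shared_state=self.first_order_aggregated_state,
                    _algo_name=f"Computing second order analytics with {self.__class__.__name__}",
                ),
                local_state=local_states[node.organization_id],
                round_idx=1,
                authorized_ids=set([node.organization_id]),
                aggregation_id=aggregation_node.organization_id,
            )

            # All local analytics are stored in the second_order_shared_states,
            # given as input to the aggregation method.
            second_order_shared_states.append(next_shared_state)

        # Call the aggregation method on the second_order_shared_states
        self.second_order_aggregated_state = aggregation_node.update_states(
            self.aggregation(
                shared_states=second_order_shared_states,
                _algo_name="Aggregating second order",
            ),
            round_idx=1,
            authorized_ids=set([train_data_node.organization_id for train_data_node in train_data_nodes]),
        )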

    def save_local_state(self, path: pathlib.Path):
        """This function will save the important local state to retrieve after each new
        call to a train or test task.

        Args:
            path (pathlib.Path): Path where to save the local_state. Provided internally by
                Substra.
        """
        state_to_save = {
            "first_order": self.first_order_aggregated_state,
            "second_order": self.second_order_aggregated_state,
        }
        with open(path, "w") as f:
            json.dump(state_to_save, f)

    def load_local_state(self, path: pathlib.Path):
        """Mirror function to load the local_state from a file saved using
        ``save_local_state``.

        Args:
            path (pathlib.Path): Path where to load the local_state. Provided internally by
                Substra.

        Returns:
            ComputePlanBuilder: return self with the updated local state.
        """
        with open(path, "r") as f:
            state_to_load = json.load(f)

        self.first_order_aggregated_state = state_to_load["first_order"]
        self.second_order_aggregated_state = state_to_load["second_order"]

        return self
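
To see what the aggregation method computes, the standalone sketch below applies the same arithmetic to two made-up shared states (the numbers are illustrative and not taken from the diabetes dataset). Each organization sends column sums and category counts; dividing by the total number of samples yields the global mean and the global proportions.

```python
# Made-up shared states, shaped like the output of local_first_order_computation.
shared_states = [
    {"n_samples": 2, "means": {"age": 60.0}, "counts": {"sex": {"M": 1, "F": 1}}},
    {"n_samples": 3, "means": {"age": 150.0}, "counts": {"sex": {"M": 2, "F": 1}}},
]

total_len = sum(state["n_samples"] for state in shared_states)

# Sum of local sums divided by the total count gives the global mean.
age_mean = sum(state["means"]["age"] for state in shared_states) / total_len

# Sum of local counts divided by the total count gives the global proportions.
sex_share = {}
for state in shared_states:
    for modality, count in state["counts"]["sex"].items():
        sex_share[modality] = sex_share.get(modality, 0.0) + count / total_len

print(age_mean, sex_share)
```

Because only sums and counts leave each organization, the result is the same as a mean computed on the pooled data, without the rows ever being pooled.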

Now that we have seen the implementation of the custom Analytics class, we can add details to some of the previously introduced concepts.

The update_states method outputs the new state of the node, which can be passed as an argument to a following one. This succession of next_state passed to a new node.update_states call is how Substra builds the graph of the compute plan.

The load_local_state and save_local_state methods are used at each new iteration on a Node, in order to retrieve the previous local state that has not been shared with the other Nodes.

For instance, after updating a Train data node using its update_states method, we have access to its next local state, which we pass as an argument to the next update_states call applied to this Train data node.
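
The save/load pair can be pictured with a small standalone sketch. The class below is a hypothetical stand-in, not a SubstraFL class, and the temporary file plays the role of the path that Substra provides; it only shows that what one call writes is what the next call starts from.

```python
import json
import pathlib
import tempfile


class TinyState:
    """Hypothetical stand-in holding one piece of local state."""

    def __init__(self):
        self.first_order = {}

    def save_local_state(self, path: pathlib.Path):
        with open(path, "w") as f:
            json.dump({"first_order": self.first_order}, f)

    def load_local_state(self, path: pathlib.Path):
        with open(path, "r") as f:
            self.first_order = json.load(f)["first_order"]
        return self


with tempfile.TemporaryDirectory() as tmp:
    state_path = pathlib.Path(tmp) / "state.json"

    before = TinyState()
    before.first_order = {"means": {"age": 42.0}}
    before.save_local_state(state_path)

    # A fresh instance, as if a new task started from the saved file.
    after = TinyState().load_local_state(state_path)

print(after.first_order)
```

Serializing to JSON, as the Analytics class does, keeps the saved state readable but restricts it to JSON-compatible values.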

To summarize, a Compute Plan Builder is composed of several decorated user-defined functions, which may need data (decorated with @remote_data) or not (decorated with @remote).

See Decorator for more information on these decorators.
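
One way to picture these decorators is as a switch from "run now" to "describe the work to run later". The sketch below assumes that reading; toy_remote is a hypothetical decorator written for this illustration and is far simpler than the SubstraFL decorators, whose internals are not shown here.

```python
import functools


def toy_remote(method):
    """Hypothetical decorator: describe the call instead of running it."""

    @functools.wraps(method)
    def wrapper(self, **kwargs):
        return {"method": method.__name__, "kwargs": kwargs}

    return wrapper


class ToyPlan:
    @toy_remote
    def aggregation(self, shared_states):
        return sum(shared_states)


# Nothing is summed here; we only get a description of the work to do.
operation = ToyPlan().aggregation(shared_states=[1, 2, 3])
print(operation)
```

In this toy version the description is a plain dictionary; whatever object plays that role can then be handed to a node, which decides where and when the method actually runs.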

These user-defined functions are used to create the graph of the compute plan through the build_compute_plan method and the update_states method of the different Nodes.

The local state obtained after updating a Train data node relies on the save_local_state and load_local_state methods to restore the state the Node was in at the end of the last update.

Running the experiment

As a last step before launching our experiment, we need to specify the third-party dependencies required to run it. The Dependency object is instantiated in order to install the right libraries in the Python environment of each organization.

We now have all the necessary objects to launch our experiment. Please see a summary below of all the objects we created so far:

  • A Client to add or retrieve the assets of our experiment, using their keys to identify them.

  • A compute plan builder (our Analytics class), used in place of a Federated Strategy, to specify what compute plan we want to execute.

  • Train data nodes to indicate on which data to compute the local analytics.

  • An Evaluation Strategy, to define where and at which frequency we evaluate the model. Here this does not apply to our experiment. We set it to None.

  • An AggregationNode, to specify the organization on which the aggregation operation will be computed.

  • An experiment folder to save a summary of the operation made.

  • The Dependency to define the libraries on which the experiment needs to run.

from substrafl.dependency import Dependency
from substrafl.experiment import execute_experiment

dependencies = Dependency(pypi_dependencies=["numpy==1.23.1", "pandas==1.5.3"])

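compute_plan = execute_experiment(
    client=clients[ANALYTICS_PROVIDER_ORG_ID],
    strategy=Analytics(),
    train_data_nodes=train_data_nodes,
    evaluation_strategy=None,
    aggregation_node=aggregation_node,
    experiment_folder=str(pathlib.Path.cwd() / "tmp" / "experiment_summaries"),
    dependencies=dependencies,
    clean_models=False,
    name="Federated Analytics with SubstraFL documentation example",
)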


Compute plan progress:   0%|          | 0/6 [00:00<?, ?it/s]
Compute plan progress:  17%|#6        | 1/6 [00:00<00:03,  1.34it/s]
Compute plan progress:  33%|###3      | 2/6 [00:01<00:02,  1.34it/s]
Compute plan progress:  50%|#####     | 3/6 [00:02<00:02,  1.34it/s]
Compute plan progress:  67%|######6   | 4/6 [00:02<00:01,  1.33it/s]
Compute plan progress:  83%|########3 | 5/6 [00:03<00:00,  1.33it/s]
Compute plan progress: 100%|##########| 6/6 [00:04<00:00,  1.33it/s]
Compute plan progress: 100%|##########| 6/6 [00:04<00:00,  1.33it/s]


The output of a task can be downloaded using utility functions provided by SubstraFL, such as download_algo_state, download_train_shared_state or download_aggregate_shared_state.

These functions download, from a given Client and a given compute_plan_key, the output of a given round_idx or rank_idx.

from substrafl.model_loading import download_aggregate_shared_state

# The aggregated analytics are computed in the ANALYTICS_PROVIDER_ORG_ID client.
client_to_download_from = clients[ANALYTICS_PROVIDER_ORG_ID]

# The results will be available once the compute plan is completed
client_to_download_from.wait_compute_plan(compute_plan.key)

first_rank_analytics = download_aggregate_shared_state(
    client=client_to_download_from,
    compute_plan_key=compute_plan.key,
    round_idx=0,
)

second_rank_analytics = download_aggregate_shared_state(
    client=client_to_download_from,
    compute_plan_key=compute_plan.key,
    round_idx=1,
)

print(
    f"""Age mean: {first_rank_analytics['means']['age']:.2f} years
Sex percentage:
    Male: {100*first_rank_analytics['counts']['sex']['M']:.2f}%
    Female: {100*first_rank_analytics['counts']['sex']['F']:.2f}%
Blood pressure std: {second_rank_analytics["std"]["bp"]:.2f} mm Hg
"""
)


Age mean: 48.52 years
Sex percentage:
    Male: 53.17%
    Female: 46.83%
Blood pressure std: 190.87 mm Hg
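
Note that the second pass sums squared deviations from the global mean and the aggregation divides by the total sample count, so the aggregated "std" entry is a mean squared deviation, that is, a variance. The sketch below, on made-up numbers and with the standard library only, checks that this two-pass scheme matches the pooled population variance and that a square root turns it into a standard deviation.

```python
import math
import statistics

# Made-up local columns held by two organizations.
org_a = [1.0, 2.0, 3.0]
org_b = [4.0, 5.0, 6.0, 7.0]
total = len(org_a) + len(org_b)

# First pass: local sums, aggregated into the global mean.
global_mean = (sum(org_a) + sum(org_b)) / total

# Second pass: local sums of squared deviations from the global mean.
local_sq = [sum((x - global_mean) ** 2 for x in org) for org in (org_a, org_b)]
variance = sum(local_sq) / total

print(variance, math.sqrt(variance))
```

Applying the same square root to the aggregated blood pressure value would express it in mm Hg rather than in squared units.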

Total running time of the script: ( 0 minutes 4.562 seconds)
