Index Generator

NpIndexGenerator

class substrafl.index_generator.np_index_generator.NpIndexGenerator(batch_size: int | None, num_updates: int, shuffle: bool = True, drop_last: bool = False, seed: int = 42)

Bases: BaseIndexGenerator

An index-based batch generator. Each call returns an array of batch_size indexes. If batch_size is equal to zero, it returns an empty array.

Each batch is generated and returned via the method __next__():

from substrafl.index_generator.np_index_generator import NpIndexGenerator

batch_generator = NpIndexGenerator(batch_size=32, num_updates=100)
batch_generator.n_samples = 10

batch_1 = next(batch_generator)
batch_2 = next(batch_generator)
# ...
batch_n = next(batch_generator)

In that case, as the default seed is set, the results are deterministic:

batch_1 = np.array([5, 6, 0])
batch_12 = np.array([8, 4, 0])

This class is stateful and can be saved and loaded with the pickle library:

import pickle

# Saving: the "with" block closes the file on exit
with open(indexer_path, "wb") as f:
    pickle.dump(batch_generator, f)

# Loading
with open(indexer_path, "rb") as f:
    loaded_batch_generator = pickle.load(f)
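
The following is a minimal, self-contained sketch of what "stateful" means for a pickle round trip. It does not use substrafl: random.Random from the standard library is only a stand-in for any object that carries internal state, and the file name indexer.pkl is a hypothetical choice. The point it illustrates is that an object saved partway through a sequence resumes from the same position once loaded.

```python
import pickle
import random
import tempfile
from pathlib import Path

# Stand-in for a stateful generator (assumption: not substrafl code).
stateful = random.Random(42)
stateful.random()  # advance the internal state once before saving

with tempfile.TemporaryDirectory() as tmp_dir:
    indexer_path = Path(tmp_dir) / "indexer.pkl"  # hypothetical file name

    # Saving
    with open(indexer_path, "wb") as f:
        pickle.dump(stateful, f)

    # Loading
    with open(indexer_path, "rb") as f:
        restored = pickle.load(f)

# Both objects continue from the same point in the sequence.
next_original = stateful.random()
next_restored = restored.random()
```

Here next_original and next_restored are equal, because the saved state included the position reached before the dump.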

This index generator can be used to generate the batches for exactly one epoch. To do so, set num_updates as follows:

If drop_last=True:

num_updates = math.floor(num_samples / batch_size)

If drop_last=False:

num_updates = math.ceil(num_samples / batch_size)
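
As a worked example of the two formulas above, the sketch below computes num_updates for one epoch. The helper name num_updates_for_one_epoch is hypothetical and is not part of substrafl; it only wraps the two expressions given in the text.

```python
import math


def num_updates_for_one_epoch(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Return the number of batches needed to cover the dataset once.

    Hypothetical helper: it applies the floor/ceil formulas from the text.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be strictly positive")
    if drop_last:
        # The incomplete final batch is discarded.
        return math.floor(num_samples / batch_size)
    # The incomplete final batch is kept as a smaller batch.
    return math.ceil(num_samples / batch_size)


# 10 samples with batches of 3: three full batches, plus one batch of size 1.
with_drop = num_updates_for_one_epoch(10, 3, drop_last=True)      # 3
without_drop = num_updates_for_one_epoch(10, 3, drop_last=False)  # 4
```

When the dataset size is divisible by the batch size, both settings give the same value.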

Parameters:
  • batch_size (Optional[int]) – The size of each batch. If set to None, the batch_size is the number of samples.

  • num_updates (int) – The number of updates. Once num_updates batches have been generated, the generator raises StopIteration. To reset it for the next round, call reset_counter().

  • shuffle (bool, Optional) – Set to True to shuffle the indexes before each new epoch. Defaults to True.

  • drop_last (bool, Optional) – Set to True to drop the last incomplete batch when the dataset size is not divisible by the batch size. If False and the dataset size is not divisible by the batch size, the last batch is smaller. Defaults to False.

  • seed (int, Optional) – The seed to set the randomness of the generator and have reproducible results. Defaults to 42.

__iter__()

Required method for iterators; returns self.

__next__()

Generates the next batch.

At the start of each pass through the whole dataset, all the indices are shuffled if shuffle is True. If fewer elements than batch_size are left in the dataset and drop_last is False, a batch containing the remaining elements is returned. If drop_last is True, that incomplete batch is dropped and the next batch is created from the whole dataset. Each call increments the counter by one, and each completed epoch increments n_epoch_generated by one.

Raises:

StopIteration – when this function has been called num_updates times.

Returns:

The batch indexes as a numpy array.

Return type:

numpy.ndarray
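
The behaviour described for __next__() can be sketched in pure Python. This is an illustrative stand-in, not the substrafl implementation: the class name SketchIndexGenerator and its private attributes are hypothetical, it uses random.Random instead of NumPy, it returns lists rather than arrays, and it takes n_samples in the constructor for brevity.

```python
import random


class SketchIndexGenerator:
    """Toy model of the batching rules described above (not substrafl code)."""

    def __init__(self, n_samples, batch_size, num_updates,
                 shuffle=True, drop_last=False, seed=42):
        self.n_samples = n_samples
        # A batch_size of None means "use the whole dataset".
        self.batch_size = n_samples if batch_size is None else batch_size
        self.num_updates = num_updates
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.counter = 0
        self.n_epoch_generated = 0
        self._rng = random.Random(seed)
        self._indices = list(range(n_samples))
        self._start_epoch()

    def _start_epoch(self):
        # Shuffle at the start of each pass through the dataset.
        if self.shuffle:
            self._rng.shuffle(self._indices)
        self._pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.counter >= self.num_updates:
            raise StopIteration
        # Slicing copies, so later shuffles do not alter a returned batch.
        batch = self._indices[self._pos:self._pos + self.batch_size]
        self._pos += len(batch)
        self.counter += 1
        left = self.n_samples - self._pos
        # The epoch ends when nothing is left, or when only an incomplete
        # batch is left and drop_last asks for it to be discarded.
        if left == 0 or (self.drop_last and left < self.batch_size):
            self.n_epoch_generated += 1
            self._start_epoch()
        return batch


keep = SketchIndexGenerator(10, 3, num_updates=8, shuffle=False)
sizes_keep = [len(batch) for batch in keep]   # last batch of each epoch is smaller

drop = SketchIndexGenerator(10, 3, num_updates=6, shuffle=False, drop_last=True)
sizes_drop = [len(batch) for batch in drop]   # only full batches
```

With 10 samples and a batch size of 3, the first generator yields batches of sizes 3, 3, 3, 1 per epoch, while the second yields only full batches of 3; both complete two epochs.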

property batch_size: int

Number of samples used per batch.

Returns:

Batch size used by the index generator

Return type:

int

check_num_updates()

Check if the counter is equal to num_updates, which means that num_updates batches have been generated since this instance has been created or the counter has been reset.

Raises:

exceptions.IndexGeneratorUpdateError – if the counter is different from num_updates.

property counter: int

Number of calls made to the iterator since the last counter reset.

Returns:

Number of calls made to the iterator

Return type:

int

property n_epoch_generated: int

Number of epochs generated.

Returns:

Number of epochs generated.

Return type:

int

property n_samples: int | None

Returns the number of samples in the dataset.

Returns:

number of samples in the dataset.

Return type:

Optional[int]

property num_updates: int

Number of batches generated between resets of the counter.

Returns:

number of updates

Return type:

int

reset_counter()

Reset the counter to prepare for the next generation of batches.
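
The sketch below shows how counter, check_num_updates() and reset_counter() could fit together over several rounds. It is a hedged illustration only: CountingGenerator, run_round and the local IndexGeneratorUpdateError class are hypothetical stand-ins that mirror the documented behaviour, not substrafl code.

```python
class IndexGeneratorUpdateError(Exception):
    """Hypothetical stand-in for exceptions.IndexGeneratorUpdateError."""


class CountingGenerator:
    """Toy generator that only models the counter logic."""

    def __init__(self, num_updates):
        self.num_updates = num_updates
        self.counter = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.counter >= self.num_updates:
            raise StopIteration
        self.counter += 1
        return self.counter  # placeholder for a batch of indexes

    def check_num_updates(self):
        if self.counter != self.num_updates:
            raise IndexGeneratorUpdateError(
                f"expected {self.num_updates} updates, got {self.counter}"
            )

    def reset_counter(self):
        self.counter = 0


def run_round(generator):
    """Consume one round of batches, verify the count, then reset."""
    n_batches = 0
    for _batch in generator:
        n_batches += 1  # a training step would use _batch here
    generator.check_num_updates()
    generator.reset_counter()
    return n_batches


generator = CountingGenerator(num_updates=3)
batches_per_round = [run_round(generator) for _ in range(2)]
```

Each round consumes exactly num_updates batches. Without the call to reset_counter(), the second round would stop immediately because the counter would already equal num_updates.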

Base class

class substrafl.index_generator.base.BaseIndexGenerator(batch_size: int | None, num_updates: int, shuffle: bool = True, drop_last: bool = False, seed: int = 42)

Bases: ABC

Base class for the index generator, must be subclassed.

Parameters:
  • batch_size (Optional[int]) – The size of each batch. If set to None, the batch_size will be the number of samples.

  • num_updates (int) – Number of local updates at each round.

  • shuffle (bool, Optional) – Shuffle the indexes or not. Defaults to True.

  • drop_last (bool, Optional) – Drop the last batch if its size is smaller than the batch size. Defaults to False.

  • seed (int, Optional) – Random seed. Defaults to 42.

Raises:

ValueError – if batch_size is negative.

property batch_size: int

Number of samples used per batch.

Returns:

Batch size used by the index generator

Return type:

int

check_num_updates()

Check if the counter is equal to num_updates, which means that num_updates batches have been generated since this instance has been created or the counter has been reset.

Raises:

exceptions.IndexGeneratorUpdateError – if the counter is different from num_updates.

property counter: int

Number of calls made to the iterator since the last counter reset.

Returns:

Number of calls made to the iterator

Return type:

int

property n_epoch_generated: int

Number of epochs generated.

Returns:

Number of epochs generated.

Return type:

int

property n_samples: int | None

Returns the number of samples in the dataset.

Returns:

number of samples in the dataset.

Return type:

Optional[int]

property num_updates: int

Number of batches generated between resets of the counter.

Returns:

number of updates

Return type:

int

reset_counter()

Reset the counter to prepare for the next generation of batches.