Index Generator
NpIndexGenerator
- class substrafl.index_generator.np_index_generator.NpIndexGenerator(batch_size: Optional[int], num_updates: int, shuffle: bool = True, drop_last: bool = False, seed: int = 42)
Bases: substrafl.index_generator.base.BaseIndexGenerator
An index-based batch generator. It returns an array of batch_size indexes. If batch_size is equal to zero, it returns an empty array. Each batch is generated and returned by the __next__() method:

```python
from substrafl.index_generator.np_index_generator import NpIndexGenerator

batch_generator = NpIndexGenerator(batch_size=32, num_updates=100)
# Number of samples in the dataset the indexes are drawn from
batch_generator.n_samples = 10

batch_1 = next(batch_generator)
batch_2 = next(batch_generator)
# ...
batch_n = next(batch_generator)
```
Because the default seed is set, the results are deterministic:

```python
batch_1 = np.array([5, 6, 0])
batch_12 = np.array([8, 4, 0])
```
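Reproducibility comes from seeding the random generator once, so every run draws the same sequence of shuffles. The following is a minimal NumPy sketch of that principle. `shuffled_epochs` is a hypothetical helper, and `np.random.default_rng` is an assumption about the RNG; it is not necessarily what NpIndexGenerator uses internally, so it will not reproduce the exact batches above.

```python
import numpy as np


def shuffled_epochs(n_samples: int, n_epochs: int, seed: int = 42) -> list:
    """Return one shuffled permutation of the indexes 0..n_samples-1 per epoch.

    The generator is seeded once, so the whole sequence of shuffles is fixed
    by `seed` (hypothetical helper, not part of substrafl).
    """
    rng = np.random.default_rng(seed)
    return [rng.permutation(n_samples) for _ in range(n_epochs)]


run_a = shuffled_epochs(n_samples=10, n_epochs=3)
run_b = shuffled_epochs(n_samples=10, n_epochs=3)

# Same seed, same sequence of per-epoch shuffles
print(all(np.array_equal(a, b) for a, b in zip(run_a, run_b)))  # True
```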
This class is stateful and can be saved and loaded with the pickle library:

```python
import pickle

# Saving (the with block closes the file automatically)
with open(indexer_path, "wb") as f:
    pickle.dump(batch_generator, f)

# Loading
with open(indexer_path, "rb") as f:
    loaded_batch_generator = pickle.load(f)
```
This index generator can be used to generate the batches for one epoch. To do so, num_updates must be equal to:
- if drop_last=True: num_updates = math.floor(num_samples / batch_size)
- if drop_last=False: num_updates = math.ceil(num_samples / batch_size)
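The two formulas above can be wrapped in a small helper that picks num_updates for a single epoch. This is a minimal sketch. `num_updates_for_one_epoch` is a hypothetical name, not part of substrafl, and treating batch_size=None as one full-dataset batch follows the batch_size parameter description.

```python
import math
from typing import Optional


def num_updates_for_one_epoch(num_samples: int, batch_size: Optional[int], drop_last: bool) -> int:
    """Number of batches that cover exactly one epoch (hypothetical helper)."""
    if batch_size is None:
        # batch_size=None means a single batch containing every sample
        return 1
    if drop_last:
        # The incomplete last batch is dropped
        return math.floor(num_samples / batch_size)
    # The incomplete last batch is kept
    return math.ceil(num_samples / batch_size)


print(num_updates_for_one_epoch(10, 3, drop_last=True))   # 3
print(num_updates_for_one_epoch(10, 3, drop_last=False))  # 4
```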
- Parameters
batch_size (Optional[int]) – The size of each batch. If set to None, the batch_size is the number of samples.
num_updates (int) – The number of updates. After num_updates calls, the generator raises StopIteration. To reset it for the next round, use the reset_counter() method.
shuffle (bool, Optional) – Set to True to shuffle the indexes before each new epoch. Defaults to True.
drop_last (bool, Optional) – Set to True to drop the last incomplete batch if the dataset size is not divisible by the batch size. If False and the dataset size is not divisible by the batch size, the last batch is smaller. Defaults to False.
seed (int, Optional) – The seed to set the randomness of the generator and have reproducible results. Defaults to 42.
- __iter__()
Required method for generators; returns self.
- __next__()
Generates the next batch.
At the start of each pass through the whole dataset, if shuffle is True, all the indexes are shuffled. If fewer than batch_size elements are left in the dataset, then: if drop_last is False, a batch containing the remaining elements is returned; otherwise, the remaining elements are dropped and the batch is drawn from the whole dataset. Each call increments counter by one, and each completed epoch increments n_epoch_generated by one.
- Raises
StopIteration – when this method has been called num_updates times.
- Returns
The batch indexes as a numpy array.
- Return type
numpy.ndarray
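The batching rules of __next__ can be illustrated with a simplified, self-contained model. This is a sketch under stated assumptions, not substrafl's implementation. `EpochBatchSketch` is a hypothetical class that takes n_samples in its constructor, shuffles with `np.random.default_rng`, and caps batch_size at the dataset size. None of these choices is confirmed by this page, and it will not reproduce the exact indexes shown for NpIndexGenerator.

```python
from typing import Optional

import numpy as np


class EpochBatchSketch:
    """Hypothetical, simplified model of the __next__ batching rules."""

    def __init__(self, n_samples: int, batch_size: Optional[int], num_updates: int,
                 shuffle: bool = True, drop_last: bool = False, seed: int = 42):
        if batch_size is not None and batch_size < 0:
            raise ValueError("batch_size must be non-negative")
        # None means one batch holding the whole dataset.
        # Assumption: a batch never exceeds the dataset size.
        self.batch_size = n_samples if batch_size is None else min(batch_size, n_samples)
        self.n_samples = n_samples
        self.num_updates = num_updates
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.rng = np.random.default_rng(seed)
        self.counter = 0
        self.n_epoch_generated = 0
        self._remaining = np.empty(0, dtype=int)  # indexes not yet served this epoch

    def __iter__(self):
        return self

    def __next__(self) -> np.ndarray:
        if self.counter == self.num_updates:
            raise StopIteration
        if self._remaining.size == 0:
            # Start a new pass over the dataset, shuffling if requested
            self._remaining = np.arange(self.n_samples)
            if self.shuffle:
                self.rng.shuffle(self._remaining)
        batch = self._remaining[: self.batch_size]
        self._remaining = self._remaining[self.batch_size:]
        epoch_done = self._remaining.size == 0 or (
            self.drop_last and self._remaining.size < self.batch_size
        )
        if epoch_done:
            # With drop_last=True the incomplete leftover is discarded here
            self._remaining = np.empty(0, dtype=int)
            self.n_epoch_generated += 1
        self.counter += 1
        return batch


gen = EpochBatchSketch(n_samples=10, batch_size=3, num_updates=4, drop_last=False)
sizes = [len(b) for b in gen]
print(sizes)                  # [3, 3, 3, 1]
print(gen.n_epoch_generated)  # 1
```

With drop_last=False, four updates cover exactly one epoch of 10 samples, which matches math.ceil(10 / 3).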
- property batch_size: int
Number of samples used per batch.
- Returns
Batch size used by the index generator
- Return type
int
- check_num_updates()
Check whether the counter is equal to num_updates, which means that num_updates batches have been generated since this instance was created or since the counter was last reset.
- Raises
exceptions.IndexGeneratorUpdateError – if the counter is different from num_updates.
- property counter: int
Number of calls made to the iterator since the last counter reset.
- Returns
Number of calls made to the iterator
- Return type
int
- property n_epoch_generated: int
Number of epochs generated.
- Returns
Number of epochs generated
- Return type
int
- property num_updates: int
Number of batches generated between resets of the counter.
- Returns
Number of updates
- Return type
int
- reset_counter()
Reset the counter to prepare for the next generation of batches.
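Across rounds, the counter is expected to reach num_updates, pass check_num_updates(), and then be reset. The following is a minimal sketch of that bookkeeping. It assumes a hypothetical `RoundCounter` class and a hypothetical `UpdateCountError` standing in for exceptions.IndexGeneratorUpdateError; in substrafl these members belong to the index generator itself.

```python
class UpdateCountError(Exception):
    """Hypothetical stand-in for exceptions.IndexGeneratorUpdateError."""


class RoundCounter:
    """Hypothetical sketch of the counter bookkeeping of an index generator."""

    def __init__(self, num_updates: int) -> None:
        self.num_updates = num_updates
        self.counter = 0

    def step(self) -> None:
        """Stand-in for one __next__ call: stop after num_updates calls."""
        if self.counter == self.num_updates:
            raise StopIteration
        self.counter += 1

    def check_num_updates(self) -> None:
        """Raise if this round did not produce exactly num_updates batches."""
        if self.counter != self.num_updates:
            raise UpdateCountError(
                f"expected {self.num_updates} updates, got {self.counter}"
            )

    def reset_counter(self) -> None:
        """Prepare for the next round of batch generation."""
        self.counter = 0


rounds = RoundCounter(num_updates=3)
for round_idx in range(2):  # two successive rounds
    for _ in range(rounds.num_updates):
        rounds.step()
    rounds.check_num_updates()  # passes: exactly num_updates steps this round
    rounds.reset_counter()      # ready for the next round
```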
Base class
- class substrafl.index_generator.base.BaseIndexGenerator(batch_size: Optional[int], num_updates: int, shuffle: bool = True, drop_last: bool = False, seed: int = 42)
Bases: abc.ABC
Base class for index generators; it must be subclassed.
- Parameters
batch_size (Optional[int]) – The size of each batch. If set to None, the batch_size will be the number of samples.
num_updates (int) – Number of local updates at each round
shuffle (bool, Optional) – Shuffle the indexes or not. Defaults to True.
drop_last (bool, Optional) – Drop the last batch if its size is smaller than the batch size. Defaults to False.
seed (int, Optional) – Random seed. Defaults to 42.
- Raises
ValueError – if batch_size is negative
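Because BaseIndexGenerator derives from abc.ABC, a concrete generator must implement the abstract parts before it can be instantiated. The sketch below mirrors the documented constructor contract, including the ValueError for a negative batch_size. `IndexGeneratorBaseSketch` and `FirstIndexesGenerator` are hypothetical names. Making __next__ the abstract method is an assumption, because this page does not say which methods are abstract.

```python
import abc
from typing import Optional

import numpy as np


class IndexGeneratorBaseSketch(abc.ABC):
    """Hypothetical abstract base mirroring the documented constructor."""

    def __init__(self, batch_size: Optional[int], num_updates: int,
                 shuffle: bool = True, drop_last: bool = False, seed: int = 42):
        # Documented contract: a negative batch_size is rejected
        if batch_size is not None and batch_size < 0:
            raise ValueError(f"batch_size must be non-negative, got {batch_size}")
        self.batch_size = batch_size
        self.num_updates = num_updates
        # Stored for subclasses that shuffle or drop incomplete batches
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.seed = seed
        self.counter = 0

    def __iter__(self):
        return self

    @abc.abstractmethod
    def __next__(self) -> np.ndarray:
        """Concrete subclasses decide how each batch of indexes is drawn."""


class FirstIndexesGenerator(IndexGeneratorBaseSketch):
    """Hypothetical subclass: always returns the first batch_size indexes."""

    def __init__(self, n_samples: int, **kwargs):
        super().__init__(**kwargs)
        self.n_samples = n_samples

    def __next__(self) -> np.ndarray:
        if self.counter == self.num_updates:
            raise StopIteration
        self.counter += 1
        # batch_size=None means one batch containing every sample
        size = self.n_samples if self.batch_size is None else self.batch_size
        return np.arange(min(size, self.n_samples))
```

Instantiating IndexGeneratorBaseSketch directly raises TypeError because __next__ is still abstract. FirstIndexesGenerator can be iterated like any other generator.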
- property batch_size: int
Number of samples used per batch.
- Returns
Batch size used by the index generator
- Return type
int
- check_num_updates()
Check whether the counter is equal to num_updates, which means that num_updates batches have been generated since this instance was created or since the counter was last reset.
- Raises
exceptions.IndexGeneratorUpdateError – if the counter is different from num_updates.
- property counter: int
Number of calls made to the iterator since the last counter reset.
- Returns
Number of calls made to the iterator
- Return type
int
- property n_epoch_generated: int
Number of epochs generated.
- Returns
Number of epochs generated
- Return type
int
- property num_updates: int
Number of batches generated between resets of the counter.
- Returns
Number of updates
- Return type
int
- reset_counter()
Reset the counter to prepare for the next generation of batches.