# Source: mixmo.loaders.batch_repetition_sampler
"""
Sampler definition for multi-input models
"""
import torch
import random
from mixmo.utils import (
config, torchutils, logger, misc)
LOGGER = logger.get_logger(__name__, level="DEBUG")
class BatchRepetitionSampler(torch.utils.data.sampler.BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of repeated indices.

    Each index drawn from the wrapped sampler is duplicated
    ``batch_repetitions`` times before being packed into fixed-size batches.
    Every batch is then turned into a list of per-sample summaries, one
    shuffled view of the batch per network member (multi-input training),
    via :meth:`output_format`.
    """

    def __init__(
        self,
        sampler,
        batch_size,
        num_members,
        config_batch_sampler,
        drop_last=False,
    ):
        """
        Args:
            sampler: base sampler (or iterable) yielding dataset indices.
            batch_size (int): number of (repeated) indices per mini-batch.
            num_members (int): number of inputs per multi-input sample.
            config_batch_sampler (dict): must contain
                ``"batch_repetitions"`` (how many times each index is
                duplicated within the batch) and
                ``"proba_input_repetition"`` (probability that members see
                the same index at a given position; forwarded to
                ``torchutils.randperm_static``).
            drop_last (bool): if True, drop the final incomplete batch.
        """
        torch.utils.data.sampler.BatchSampler.__init__(self, sampler, batch_size, drop_last)
        self.num_members = num_members
        self._batch_repetitions = config_batch_sampler["batch_repetitions"]
        self._proba_input_repetition = config_batch_sampler["proba_input_repetition"]

    def __iter__(self):
        """Yield formatted batches of repeated indices."""
        batch = []
        for idx in self.sampler:
            for _ in range(self._batch_repetitions):
                batch.append(idx)
                # Flush as soon as the batch is full so every yielded batch
                # (except possibly the last) has exactly batch_size entries.
                if len(batch) >= self.batch_size:
                    yield self.output_format(batch)
                    batch = []
        # Emit the trailing partial batch unless drop_last is requested.
        if len(batch) > 0 and not self.drop_last:
            yield self.output_format(batch)

    def output_format(self, std_batch):
        """
        Transforms standard batches into batches of sample summaries.

        Args:
            std_batch (list): a batch of raw dataset indices.

        Returns:
            list of dict: one summary per position, each holding the shared
            ``"batch_seed"`` and one ``"index_<m>"`` entry per member.
        """
        # Create M shuffled batches, one for each input. proba_static
        # controls how often a position keeps the same index across members.
        batch_size = len(std_batch)
        list_shuffled_index = [
            torchutils.randperm_static(batch_size, proba_static=self._proba_input_repetition)
            for _ in range(self.num_members)
        ]
        # Member 0's permutation defines the base ordering of the batch;
        # the other members' permutations index into this reordered batch.
        shuffled_batch = [
            std_batch[list_shuffled_index[0][count]]
            for count in range(batch_size)]
        # Sample a batch seed, shared among samples from the given batch
        # (e.g. so batch-level randomness is reproducible per batch).
        batch_seed = random.randint(0, config.cfg.RANDOM.MAX_RANDOM)
        list_index = [
            misc.clean_update(
                {
                    "batch_seed": batch_seed,
                    "index_" + str(0): shuffled_batch[count]
                }, {
                    "index_" + str(num_member):
                    shuffled_batch[list_shuffled_index[num_member][count]]
                    for num_member in range(1, self.num_members)
                }
            )
            for count in range(batch_size)
        ]
        return list_index

    def __len__(self):
        """Number of batches per epoch, accounting for index repetition."""
        len_sampler = len(self.sampler) * self._batch_repetitions
        if self.drop_last:
            return len_sampler // self.batch_size
        else:
            # Ceiling division: the final partial batch still counts.
            return (len_sampler + self.batch_size - 1) // self.batch_size