Skip to content
30 changes: 25 additions & 5 deletions cebra/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""Base classes for datasets and loaders."""

import abc
from typing import Iterator

import literate_dataclasses as dataclasses
import torch
Expand Down Expand Up @@ -242,6 +243,12 @@ class Loader(abc.ABC, cebra.io.HasDevice):
batch_size: int = dataclasses.field(default=None,
doc="""The total batch size.""")

num_negatives: int = dataclasses.field(
default=None,
doc=("The number of negative samples to draw for each reference. "
"If not specified, the batch size is used."),
)

def __post_init__(self):
if self.num_steps is None or self.num_steps <= 0:
raise ValueError(
Expand All @@ -251,28 +258,41 @@ def __post_init__(self):
raise ValueError(
f"Batch size has to be None, or a non-negative value. Got {self.batch_size}."
)
if self.num_negatives is not None and self.num_negatives <= 0:
raise ValueError(
f"Number of negatives has to be None, or a non-negative value. Got {self.num_negatives}."
)

if self.num_negatives is None:
self.num_negatives = self.batch_size

def __len__(self):
"""The number of batches returned when calling as an iterator."""
return self.num_steps

def __iter__(self) -> Batch:
def __iter__(self) -> Iterator[Batch]:
for _ in range(len(self)):
index = self.get_indices(num_samples=self.batch_size)
index = self.get_indices()
yield self.dataset.load_batch(index)

@abc.abstractmethod
def get_indices(self, num_samples: int):
def get_indices(self, *, num_samples: int = None):
"""Sample and return the specified number of indices.

The elements of the returned `BatchIndex` will be used to index the
`dataset` of this data loader.

Args:
num_samples: The size of each of the reference, positive and
negative samples.
num_samples: Deprecated. Use ``batch_size`` on the instance level
instead.

Returns:
batch indices for the reference, positive and negative sample.

Note:
From version 0.7.0 onwards, specifying the ``num_samples``
directly is deprecated and will be removed in version 0.8.0.
Comment thread
stes marked this conversation as resolved.
Please set ``batch_size`` and ``num_negatives`` on the instance
level instead.
"""
raise NotImplementedError()
22 changes: 18 additions & 4 deletions cebra/data/multi_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,14 @@ def __post_init__(self):
super().__post_init__()
self.sampler = cebra.distributions.MultisessionSampler(
self.dataset, self.time_offset)
if self.num_negatives is None:
self.num_negatives = self.batch_size

def get_indices(self, num_samples: int) -> List[BatchIndex]:
# NOTE(stes): In the longer run, we need to unify the API here; the num_samples argument
# is not used in the multi-session case, which is different to the single session samples.
def get_indices(self) -> List[BatchIndex]:
ref_idx = self.sampler.sample_prior(self.batch_size)
neg_idx = self.sampler.sample_prior(self.batch_size)
neg_idx = self.sampler.sample_prior(self.num_negatives)
pos_idx, idx, idx_rev = self.sampler.sample_conditional(ref_idx)

ref_idx = torch.from_numpy(ref_idx)
Expand Down Expand Up @@ -192,8 +196,11 @@ class DiscreteMultiSessionDataLoader(MultiSessionLoader):
# Overwrite sampler with the discrete implementation
# Generalize MultisessionSampler to avoid doing this?
def __post_init__(self):
# NOTE(stes): __post_init__ from superclass is intentionally not called.
self.sampler = cebra.distributions.DiscreteMultisessionSampler(
self.dataset)
if self.num_negatives is None:
self.num_negatives = self.batch_size

@property
def index(self):
Expand Down Expand Up @@ -229,7 +236,14 @@ def __post_init__(self):
self.sampler = cebra.distributions.UnifiedSampler(
self.dataset, self.time_offset)

def get_indices(self, num_samples: int) -> BatchIndex:
if self.batch_size is not None and self.batch_size < 2:
raise ValueError("UnifiedLoader does not support batch_size < 2.")

if self.num_negatives is not None and self.num_negatives < 2:
raise ValueError(
"UnifiedLoader does not support num_negatives < 2.")

def get_indices(self) -> BatchIndex:
"""Sample and return the specified number of indices.

The elements of the returned ``BatchIndex`` will be used to index the
Expand All @@ -251,7 +265,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
Batch indices for the reference, positive and negative samples.
"""
ref_idx = self.sampler.sample_prior(self.batch_size)
neg_idx = self.sampler.sample_prior(self.batch_size)
neg_idx = self.sampler.sample_prior(self.num_negatives)

pos_idx = self.sampler.sample_conditional(ref_idx)

Expand Down
22 changes: 13 additions & 9 deletions cebra/data/multiobjective.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
# limitations under the License.
#

from typing import Iterator

import literate_dataclasses as dataclasses

import cebra.data as cebra_data
import cebra.distributions
from cebra.data.datatypes import Batch
from cebra.data.datatypes import BatchIndex
from cebra.distributions.continuous import Prior

Expand Down Expand Up @@ -71,9 +74,9 @@ def __post_init__(self):
def add_config(self, config):
self.labels.append(config['label'])

def get_indices(self, num_samples: int):
def get_indices(self) -> BatchIndex:
if self.sampling_mode_supervised == "ref_shared":
reference_idx = self.prior.sample_prior(num_samples)
reference_idx = self.prior.sample_prior(self.batch_size)
else:
raise ValueError(
f"Sampling mode {self.sampling_mode_supervised} is not implemented."
Expand All @@ -87,9 +90,9 @@ def get_indices(self, num_samples: int):

return batch_index

def __iter__(self):
def __iter__(self) -> Iterator[Batch]:
for _ in range(len(self)):
index = self.get_indices(num_samples=self.batch_size)
index = self.get_indices()
yield self.dataset.load_batch_supervised(index, self.labels)


Expand Down Expand Up @@ -142,13 +145,14 @@ def add_config(self, config):

self.distributions.append(distribution)

def get_indices(self, num_samples: int):
def get_indices(self) -> BatchIndex:
"""Sample and return the specified number of indices."""

if self.sampling_mode_contrastive == "refneg_shared":
ref_and_neg = self.prior.sample_prior(num_samples * 2)
reference_idx = ref_and_neg[:num_samples]
negative_idx = ref_and_neg[num_samples:]
ref_and_neg = self.prior.sample_prior(self.batch_size +
self.num_negatives)
reference_idx = ref_and_neg[:self.batch_size]
negative_idx = ref_and_neg[self.batch_size:]

positives_idx = []
for distribution in self.distributions:
Expand All @@ -169,5 +173,5 @@ def get_indices(self, num_samples: int):

def __iter__(self):
for _ in range(len(self)):
index = self.get_indices(num_samples=self.batch_size)
index = self.get_indices()
yield self.dataset.load_batch_contrastive(index)
78 changes: 36 additions & 42 deletions cebra/data/single_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import abc
import warnings
from typing import Iterator

import literate_dataclasses as dataclasses
import torch
Expand Down Expand Up @@ -138,7 +139,7 @@ def _init_distribution(self):
f"Invalid choice of prior distribution. Got '{self.prior}', but "
f"only accept 'uniform' or 'empirical' as potential values.")

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference samples will be sampled from the empirical or uniform prior
Expand All @@ -151,16 +152,13 @@ def get_indices(self, num_samples: int) -> BatchIndex:
The negative samples will be sampled from the same distribution as the
reference examples.

Args:
num_samples: The number of samples (batch size) of the returned
:py:class:`cebra.data.datatypes.BatchIndex`.

Returns:
Indices for reference, positive and negatives samples.
Comment thread
stes marked this conversation as resolved.
"""
reference_idx = self.distribution.sample_prior(num_samples * 2)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
reference_idx = self.distribution.sample_prior(self.batch_size +
self.num_negatives)
negative_idx = reference_idx[self.batch_size:]
reference_idx = reference_idx[:self.batch_size]
reference = self.index[reference_idx]
positive_idx = self.distribution.sample_conditional(reference)
return BatchIndex(reference=reference_idx,
Expand Down Expand Up @@ -246,7 +244,7 @@ def _init_distribution(self):
else:
raise ValueError(self.conditional)

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference and negative samples will be sampled uniformly from
Expand All @@ -255,16 +253,13 @@ def get_indices(self, num_samples: int) -> BatchIndex:
The positive samples will be sampled conditional on the reference
samples according to the specified ``conditional`` distribution.

Args:
num_samples: The number of samples (batch size) of the returned
:py:class:`cebra.data.datatypes.BatchIndex`.

Returns:
Indices for reference, positive and negatives samples.
"""
reference_idx = self.distribution.sample_prior(num_samples * 2)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
reference_idx = self.distribution.sample_prior(self.batch_size +
self.num_negatives)
negative_idx = reference_idx[self.batch_size:]
reference_idx = reference_idx[:self.batch_size]
positive_idx = self.distribution.sample_conditional(reference_idx)
return BatchIndex(reference=reference_idx,
positive=positive_idx,
Expand Down Expand Up @@ -305,7 +300,7 @@ def __post_init__(self):
continuous=self.cindex,
time_delta=self.time_offset)

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference and negative samples will be sampled uniformly from
Expand All @@ -316,10 +311,6 @@ def get_indices(self, num_samples: int) -> BatchIndex:
:py:class:`ContinuousDataLoader`, or just sampled based on the
conditional variable.

Args:
num_samples: The number of samples (batch size) of the returned
:py:class:`cebra.data.datatypes.BatchIndex`.

Returns:
Indices for reference, positive and negatives samples.
Comment thread
stes marked this conversation as resolved.

Expand All @@ -328,10 +319,13 @@ def get_indices(self, num_samples: int) -> BatchIndex:
class.
- Sample the negatives with matching discrete variable
"""
reference_idx = self.distribution.sample_prior(num_samples)
reference_idx = self.distribution.sample_prior(self.batch_size +
self.num_negatives)
negative_idx = reference_idx[self.batch_size:]
reference_idx = reference_idx[:self.batch_size]
return BatchIndex(
reference=reference_idx,
negative=self.distribution.sample_prior(num_samples),
negative=negative_idx,
positive=self.distribution.sample_conditional(reference_idx),
)

Expand Down Expand Up @@ -421,32 +415,29 @@ def _init_time_distribution(self):
else:
raise ValueError

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference and negative samples will be sampled uniformly from
all available time steps, and a total of ``2*num_samples`` will be
returned for both.
all available time steps, and a total of ``self.batch_size + self.num_negatives``
will be returned for both.

For the positive samples, ``num_samples`` are sampled according to the
behavior conditional distribution, and another ``num_samples`` are
sampled according to the dime contrastive distribution. The indices
For the positive samples, ``self.batch_size`` samples are sampled according to the
behavior conditional distribution, and another ``self.batch_size`` samples are
sampled according to the time contrastive distribution. The indices
for the positive samples are concatenated across the first dimension.

Args:
num_samples: The number of samples (batch size) of the returned
:py:class:`cebra.data.datatypes.BatchIndex`.

Returns:
Indices for reference, positive and negatives samples.

Todo:
Add the ``empirical`` vs. ``discrete`` sampling modes to this
class.
"""
reference_idx = self.time_distribution.sample_prior(num_samples * 2)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
reference_idx = self.time_distribution.sample_prior(self.batch_size +
self.num_negatives)
negative_idx = reference_idx[self.batch_size:]
reference_idx = reference_idx[:self.batch_size]
behavior_positive_idx = self.behavior_distribution.sample_conditional(
reference_idx)
time_positive_idx = self.time_distribution.sample_conditional(
Expand All @@ -464,13 +455,18 @@ class FullDataLoader(ContinuousDataLoader):

def __post_init__(self):
super().__post_init__()
self.batch_size = None

if self.batch_size is not None:
Comment thread
stes marked this conversation as resolved.
raise ValueError("Batch size cannot be set for FullDataLoader.")
if self.num_negatives is not None:
raise ValueError(
"Number of negatives cannot be set for FullDataLoader.")

@property
def offset(self):
return self.dataset.offset

def get_indices(self, num_samples=None) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference indices are all available (valid, according to the
Expand All @@ -490,7 +486,6 @@ def get_indices(self, num_samples=None) -> BatchIndex:
Add the ``empirical`` vs. ``discrete`` sampling modes to this
class.
"""
assert num_samples is None

reference_idx = torch.arange(
self.offset.left,
Expand All @@ -504,7 +499,6 @@ def get_indices(self, num_samples=None) -> BatchIndex:
positive=positive_idx,
negative=negative_idx)

def __iter__(self):
def __iter__(self) -> Iterator[BatchIndex]:
for _ in range(len(self)):
index = self.get_indices(num_samples=self.batch_size)
yield index
yield self.get_indices()
Loading
Loading