Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,10 @@ It is a publicly available tool that has benefited from contributions and sugges
### 2025
- **Steffen Schneider, Rodrigo González Laiz, Anastasiia Filippova, Markus Frey, Mackenzie W. Mathis**
[*Time-series attribution maps with regularized contrastive learning.*](https://openreview.net/forum?id=aGrCXoTB4P)
AISTATS (2025)
AISTATS (2025).

- **Rodrigo González Laiz*, Tobias Schmidt*, Steffen Schneider**:
[*Self-supervised contrastive learning performs non-linear system identification*](https://arxiv.org/abs/2410.14673)
ICLR (2025).
Adds the `cebra.dynamics` module and introduces improvements to `cebra.criterions`. For advanced features on dynamics learning,
the full reference implementation for the paper is available at https://github.com/dynamical-inference/dcl.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ It can jointly use behavioral and neural data in a hypothesis- or discovery-driv
Steffen Schneider, Rodrigo González Laiz, Anastasiia Filipova, Markus Frey, Mackenzie Weygandt Mathis. AISTATS 2025.


- 📄 **Publication April 2025**:
[Self-supervised contrastive learning performs non-linear system identification](https://arxiv.org/abs/2410.14673)
Rodrigo González Laiz*, Tobias Schmidt*, Steffen Schneider. ICLR 2025.

- 📄 **Publication May 2023**:
[Learnable latent embeddings for joint behavioural and neural analysis.](https://doi.org/10.1038/s41586-023-06031-6)
Steffen Schneider*, Jin Hwa Lee* and Mackenzie Weygandt Mathis. Nature 2023.
Expand Down
7 changes: 6 additions & 1 deletion cebra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
# TODO(stes): More common integrations people care about (e.g. PyTorch lightning)
# could be added here.
from cebra.integrations.sklearn.cebra import CEBRA
from cebra.integrations.sklearn.cebra import DCL
from cebra.integrations.sklearn.decoder import KNNDecoder
from cebra.integrations.sklearn.decoder import L1LinearRegressor

Expand Down Expand Up @@ -67,7 +68,7 @@
import cebra.integrations.sklearn as sklearn

__version__ = "0.6.1"
__all__ = ["CEBRA"]
__all__ = ["CEBRA", "DCL"]
__allow_lazy_imports = False
__lazy_imports = {}

Expand All @@ -91,6 +92,10 @@ def __getattr__(key):
from cebra.integrations.sklearn.cebra import CEBRA

return CEBRA
elif key == "DCL":
from cebra.integrations.sklearn.cebra import DCL

return DCL
elif key == "KNNDecoder":
from cebra.integrations.sklearn.decoder import KNNDecoder # noqa: F811

Expand Down
18 changes: 14 additions & 4 deletions cebra/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,17 @@ class Loader(abc.ABC, cebra.io.HasDevice):
"""The total number of batches when iterating over the dataloader.""",
)

batch_size: int = dataclasses.field(default=None,
doc="""The total batch size.""")
batch_size: int = dataclasses.field(
default=None,
doc=
"""The total batch size. Number of reference and number of positive samples"""
)

batch_size_negatives: int = dataclasses.field(
default=None,
doc=
"""Number of negative samples for a given batch. If None defaults to batch size"""
)

def __post_init__(self):
if self.num_steps is None or self.num_steps <= 0:
Expand All @@ -258,11 +267,12 @@ def __len__(self):

def __iter__(self) -> Batch:
for _ in range(len(self)):
index = self.get_indices(num_samples=self.batch_size)
index = self.get_indices(num_samples=self.batch_size,
num_negatives=self.batch_size_negatives)
yield self.dataset.load_batch(index)

@abc.abstractmethod
def get_indices(self, num_samples: int):
def get_indices(self, num_samples: int, num_negatives: int = None):
"""Sample and return the specified number of indices.

The elements of the returned `BatchIndex` will be used to index the
Expand Down
8 changes: 6 additions & 2 deletions cebra/data/multi_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def __post_init__(self):
self.sampler = cebra.distributions.MultisessionSampler(
self.dataset, self.time_offset)

def get_indices(self, num_samples: int) -> List[BatchIndex]:
def get_indices(self,
num_samples: int,
num_negatives: int = None) -> List[BatchIndex]:
ref_idx = self.sampler.sample_prior(self.batch_size)
neg_idx = self.sampler.sample_prior(self.batch_size)
pos_idx, idx, idx_rev = self.sampler.sample_conditional(ref_idx)
Expand Down Expand Up @@ -229,7 +231,9 @@ def __post_init__(self):
self.sampler = cebra.distributions.UnifiedSampler(
self.dataset, self.time_offset)

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self,
num_samples: int,
num_negatives: int = None) -> BatchIndex:
"""Sample and return the specified number of indices.

The elements of the returned ``BatchIndex`` will be used to index the
Expand Down
4 changes: 2 additions & 2 deletions cebra/data/multiobjective.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __post_init__(self):
def add_config(self, config):
self.labels.append(config['label'])

def get_indices(self, num_samples: int):
def get_indices(self, num_samples: int, num_negatives: int = None):
if self.sampling_mode_supervised == "ref_shared":
reference_idx = self.prior.sample_prior(num_samples)
else:
Expand Down Expand Up @@ -142,7 +142,7 @@ def add_config(self, config):

self.distributions.append(distribution)

def get_indices(self, num_samples: int):
def get_indices(self, num_samples: int, num_negatives: int = None):
"""Sample and return the specified number of indices."""

if self.sampling_mode_contrastive == "refneg_shared":
Expand Down
20 changes: 15 additions & 5 deletions cebra/data/single_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ def _init_distribution(self):
f"Invalid choice of prior distribution. Got '{self.prior}', but "
f"only accept 'uniform' or 'empirical' as potential values.")

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self,
num_samples: int,
num_negatives: int = None) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference samples will be sampled from the empirical or uniform prior
Expand Down Expand Up @@ -246,7 +248,9 @@ def _init_distribution(self):
else:
raise ValueError(self.conditional)

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self,
num_samples: int,
num_negatives: int = None) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference and negative samples will be sampled uniformly from
Expand All @@ -262,7 +266,9 @@ def get_indices(self, num_samples: int) -> BatchIndex:
Returns:
Indices for reference, positive and negatives samples.
"""
reference_idx = self.distribution.sample_prior(num_samples * 2)
num_negatives = num_samples if num_negatives is None else num_negatives
total_samples = num_samples + num_negatives
reference_idx = self.distribution.sample_prior(total_samples)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
positive_idx = self.distribution.sample_conditional(reference_idx)
Expand Down Expand Up @@ -305,7 +311,9 @@ def __post_init__(self):
continuous=self.cindex,
time_delta=self.time_offset)

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self,
num_samples: int,
num_negatives: int = None) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference and negative samples will be sampled uniformly from
Expand Down Expand Up @@ -421,7 +429,9 @@ def _init_time_distribution(self):
else:
raise ValueError

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self,
num_samples: int,
num_negatives: int = None) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference and negative samples will be sampled uniformly from
Expand Down
7 changes: 7 additions & 0 deletions cebra/dynamics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import torch
from typing import Literal
import cebra.registry

cebra.registry.add_helper_functions(__name__)

from cebra.dynamics.linear import *
46 changes: 46 additions & 0 deletions cebra/dynamics/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Literal

import torch

from cebra.dynamics import register


@register("identity")
class Identity(torch.nn.Identity):
pass


@register("linear")
class Linear(torch.nn.Linear):

def __init__(self, latent_dim: int, bias: bool = True):
super().__init__(latent_dim, latent_dim, bias)


@register("orthogonal-linear")
class OrthogonalLinear(Linear):
"""
A LinearDynamicsModel that is parametrized to only allow orthogonal dynamics matrices.
"""

def __init__(
self,
latent_dim: int,
bias: bool = True,
orthogonal_map: Literal[
"matrix_exp",
"cayley",
"householder",
] = "matrix_exp",
use_trivialization: bool = True,
):
super().__init__(latent_dim, bias)
self.orthogonal_map = orthogonal_map
self.use_trivialization = use_trivialization

torch.nn.utils.parametrizations.orthogonal(
self,
name="weight",
orthogonal_map=self.orthogonal_map,
use_trivialization=self.use_trivialization,
)
Loading
Loading