From 338fd76221c98536478b6aca06f9ae6239bced69 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Mon, 22 Jun 2026 23:23:03 +0200 Subject: [PATCH 1/6] Add dynamics contrastive learning Co-authored-by: Tobias Schmidt Co-authored-by: Rodrigo Gonzalez Laiz --- AUTHORS.md | 8 +- README.md | 4 + cebra/__init__.py | 7 +- cebra/data/base.py | 18 ++- cebra/data/multi_session.py | 8 +- cebra/data/multiobjective.py | 4 +- cebra/data/single_session.py | 20 ++- cebra/dynamics/__init__.py | 7 + cebra/dynamics/linear.py | 46 +++++++ cebra/integrations/sklearn/cebra.py | 185 ++++++++++++++++++++++++-- cebra/models/criterions.py | 46 ++++++- cebra/solver/base.py | 6 +- cebra/solver/single_session.py | 82 +++++++++++- tests/examples/dcl_example.py | 77 +++++++++++ tests/examples/dcl_sklearn_example.py | 52 ++++++++ tests/test_criterions.py | 67 ++++++++++ tests/test_dynamics.py | 41 ++++++ 17 files changed, 640 insertions(+), 38 deletions(-) create mode 100644 cebra/dynamics/__init__.py create mode 100644 cebra/dynamics/linear.py create mode 100644 tests/examples/dcl_example.py create mode 100644 tests/examples/dcl_sklearn_example.py create mode 100644 tests/test_dynamics.py diff --git a/AUTHORS.md b/AUTHORS.md index 17db8887..c29d0868 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -23,4 +23,10 @@ It is a publicly available tool that has benefited from contributions and sugges ### 2025 - **Steffen Schneider, Rodrigo González Laiz, Anastasiia Filippova, Markus Frey, Mackenzie W. Mathis** [*Time-series attribution maps with regularized contrastive learning.*](https://openreview.net/forum?id=aGrCXoTB4P) - AISTATS (2025) + AISTATS (2025). + +- **Rodrigo González Laiz*, Tobias Schmidt*, Steffen Schneider**: + [*Self-supervised contrastive learning performs non-linear system identification*](https://arxiv.org/abs/2410.14673) + ICLR (2025). + Adds the `cebra.dynamics` module and introduces improvements to `cebra.criterions`. For advanced features on dynamics learning, + the full reference implementation for the paper is available at https://github.com/dynamical-inference/dcl. diff --git a/README.md b/README.md index d75fd99b..0859de54 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,10 @@ It can jointly use behavioral and neural data in a hypothesis- or discovery-driv Steffen Schneider, Rodrigo González Laiz, Anastasiia Filipova, Markus Frey, Mackenzie Weygandt Mathis. AISTATS 2025. +- 📄 **Publication April 2025**: + [Self-supervised contrastive learning performs non-linear system identification](https://arxiv.org/abs/2410.14673) + Rodrigo González Laiz*, Tobias Schmidt*, Steffen Schneider. ICLR 2025. + - 📄 **Publication May 2023**: [Learnable latent embeddings for joint behavioural and neural analysis.](https://doi.org/10.1038/s41586-023-06031-6) Steffen Schneider*, Jin Hwa Lee* and Mackenzie Weygandt Mathis. Nature 2023. diff --git a/cebra/__init__.py b/cebra/__init__.py index 0dc6c652..6e0388af 100644 --- a/cebra/__init__.py +++ b/cebra/__init__.py @@ -29,6 +29,7 @@ # TODO(stes): More common integrations people care about (e.g. PyTorch lightning) # could be added here. from cebra.integrations.sklearn.cebra import CEBRA + from cebra.integrations.sklearn.cebra import DCL from cebra.integrations.sklearn.decoder import KNNDecoder from cebra.integrations.sklearn.decoder import L1LinearRegressor @@ -67,7 +68,7 @@ import cebra.integrations.sklearn as sklearn __version__ = "0.6.1" -__all__ = ["CEBRA"] +__all__ = ["CEBRA", "DCL"] __allow_lazy_imports = False __lazy_imports = {} @@ -91,6 +92,10 @@ def __getattr__(key): from cebra.integrations.sklearn.cebra import CEBRA return CEBRA + elif key == "DCL": + from cebra.integrations.sklearn.cebra import DCL + + return DCL elif key == "KNNDecoder": from cebra.integrations.sklearn.decoder import KNNDecoder # noqa: F811 diff --git a/cebra/data/base.py b/cebra/data/base.py index f5491e51..e8c36276 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -239,8 +239,17 @@ class Loader(abc.ABC, cebra.io.HasDevice): """The total number of batches when iterating over the dataloader.""", ) - batch_size: int = dataclasses.field(default=None, - doc="""The total batch size.""") + batch_size: int = dataclasses.field( + default=None, + doc= + """The total batch size. Number of reference and number of positive samples""" + ) + + batch_size_negatives: int = dataclasses.field( + default=None, + doc= + """Number of negative samples for a given batch. If None defaults to batch size""" + ) def __post_init__(self): if self.num_steps is None or self.num_steps <= 0: @@ -258,11 +267,12 @@ def __len__(self): def __iter__(self) -> Batch: for _ in range(len(self)): - index = self.get_indices(num_samples=self.batch_size) + index = self.get_indices(num_samples=self.batch_size, + num_negatives=self.batch_size_negatives) yield self.dataset.load_batch(index) @abc.abstractmethod - def get_indices(self, num_samples: int): + def get_indices(self, num_samples: int, num_negatives: int = None): """Sample and return the specified number of indices. The elements of the returned `BatchIndex` will be used to index the diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index f33ad6ec..bf5cc2a1 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -156,7 +156,9 @@ def __post_init__(self): self.sampler = cebra.distributions.MultisessionSampler( self.dataset, self.time_offset) - def get_indices(self, num_samples: int) -> List[BatchIndex]: + def get_indices(self, + num_samples: int, + num_negatives: int = None) -> List[BatchIndex]: ref_idx = self.sampler.sample_prior(self.batch_size) neg_idx = self.sampler.sample_prior(self.batch_size) pos_idx, idx, idx_rev = self.sampler.sample_conditional(ref_idx) @@ -229,7 +231,9 @@ def __post_init__(self): self.sampler = cebra.distributions.UnifiedSampler( self.dataset, self.time_offset) - def get_indices(self, num_samples: int) -> BatchIndex: + def get_indices(self, + num_samples: int, + num_negatives: int = None) -> BatchIndex: """Sample and return the specified number of indices. The elements of the returned ``BatchIndex`` will be used to index the diff --git a/cebra/data/multiobjective.py b/cebra/data/multiobjective.py index f700d1c4..dbbb9b86 100644 --- a/cebra/data/multiobjective.py +++ b/cebra/data/multiobjective.py @@ -71,7 +71,7 @@ def __post_init__(self): def add_config(self, config): self.labels.append(config['label']) - def get_indices(self, num_samples: int): + def get_indices(self, num_samples: int, num_negatives: int = None): if self.sampling_mode_supervised == "ref_shared": reference_idx = self.prior.sample_prior(num_samples) else: @@ -142,7 +142,7 @@ def add_config(self, config): self.distributions.append(distribution) - def get_indices(self, num_samples: int): + def get_indices(self, num_samples: int, num_negatives: int = None): """Sample and return the specified number of indices.""" if self.sampling_mode_contrastive == "refneg_shared": diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 7e4ad2fd..bc307fd5 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -138,7 +138,9 @@ def _init_distribution(self): f"Invalid choice of prior distribution. Got '{self.prior}', but " f"only accept 'uniform' or 'empirical' as potential values.") - def get_indices(self, num_samples: int) -> BatchIndex: + def get_indices(self, + num_samples: int, + num_negatives: int = None) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference samples will be sampled from the empirical or uniform prior @@ -246,7 +248,9 @@ def _init_distribution(self): else: raise ValueError(self.conditional) - def get_indices(self, num_samples: int) -> BatchIndex: + def get_indices(self, + num_samples: int, + num_negatives: int = None) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference and negative samples will be sampled uniformly from @@ -262,7 +266,9 @@ def get_indices(self, num_samples: int) -> BatchIndex: Returns: Indices for reference, positive and negatives samples. """ - reference_idx = self.distribution.sample_prior(num_samples * 2) + num_negatives = num_samples if num_negatives is None else num_negatives + total_samples = num_samples + num_negatives + reference_idx = self.distribution.sample_prior(total_samples) negative_idx = reference_idx[num_samples:] reference_idx = reference_idx[:num_samples] positive_idx = self.distribution.sample_conditional(reference_idx) @@ -305,7 +311,9 @@ def __post_init__(self): continuous=self.cindex, time_delta=self.time_offset) - def get_indices(self, num_samples: int) -> BatchIndex: + def get_indices(self, + num_samples: int, + num_negatives: int = None) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference and negative samples will be sampled uniformly from @@ -421,7 +429,9 @@ def _init_time_distribution(self): else: raise ValueError - def get_indices(self, num_samples: int) -> BatchIndex: + def get_indices(self, + num_samples: int, + num_negatives: int = None) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference and negative samples will be sampled uniformly from diff --git a/cebra/dynamics/__init__.py b/cebra/dynamics/__init__.py new file mode 100644 index 00000000..f272b92f --- /dev/null +++ b/cebra/dynamics/__init__.py @@ -0,0 +1,7 @@ +import torch +from typing import Literal +import cebra.registry + +cebra.registry.add_helper_functions(__name__) + +from cebra.dynamics.linear import * diff --git a/cebra/dynamics/linear.py b/cebra/dynamics/linear.py new file mode 100644 index 00000000..7a8d4ee8 --- /dev/null +++ b/cebra/dynamics/linear.py @@ -0,0 +1,46 @@ +from typing import Literal + +import torch + +from cebra.dynamics import register + + +@register("identity") +class Identity(torch.nn.Identity): + pass + + +@register("linear") +class Linear(torch.nn.Linear): + + def __init__(self, latent_dim: int, bias: bool = True): + super().__init__(latent_dim, latent_dim, bias) + + +@register("orthogonal-linear") +class OrthogonalLinear(Linear): + """ + A LinearDynamicsModel that is parametrized to only allow orthogonal dynamics matrices. + """ + + def __init__( + self, + latent_dim: int, + bias: bool = True, + orthogonal_map: Literal[ + "matrix_exp", + "cayley", + "householder", + ] = "matrix_exp", + use_trivialization: bool = True, + ): + super().__init__(latent_dim, bias) + self.orthogonal_map = orthogonal_map + self.use_trivialization = use_trivialization + + torch.nn.utils.parametrizations.orthogonal( + self, + name="weight", + orthogonal_map=self.orthogonal_map, + use_trivialization=self.use_trivialization, + ) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 00645523..dade1307 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -40,6 +40,7 @@ from torch import nn import cebra.data +import cebra.dynamics import cebra.integrations.sklearn import cebra.integrations.sklearn.dataset as cebra_sklearn_dataset import cebra.integrations.sklearn.utils as sklearn_utils @@ -357,6 +358,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": args, state, state_dict = cebra_info['args'], cebra_info[ 'state'], cebra_info['state_dict'] + cebra_ = cebra.CEBRA(**args) for key, value in state.items(): @@ -390,24 +392,45 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": criterion = cebra_._prepare_criterion() criterion.to(state['device_']) - optimizer = torch.optim.Adam( - itertools.chain(model.parameters(), criterion.parameters()), - lr=args['learning_rate'], - **dict(args['optimizer_kwargs']), - ) + # Create the dynamics model when the estimator was trained in DCL mode. + dynamics_model = cebra_._prepare_dynamics_model() + if dynamics_model is not None: + dynamics_model.to(state['device_']) + + if dynamics_model is not None: + optimizer = torch.optim.Adam( + itertools.chain(model.parameters(), dynamics_model.parameters(), + criterion.parameters()), + lr=args['learning_rate'], + **dict(args['optimizer_kwargs']), + ) + else: + optimizer = torch.optim.Adam( + itertools.chain(model.parameters(), criterion.parameters()), + lr=args['learning_rate'], + **dict(args['optimizer_kwargs']), + ) + + solver_kwargs = { + "model": model, + "criterion": criterion, + "optimizer": optimizer, + "tqdm_on": args['verbose'], + } + if dynamics_model is not None: + solver_kwargs["dynamics_model"] = dynamics_model solver = cebra.solver.init( state['solver_name_'], - model=model, - criterion=criterion, - optimizer=optimizer, - tqdm_on=args['verbose'], + **solver_kwargs, ) solver.load_state_dict(state_dict) solver.to(state['device_']) cebra_.model_ = model cebra_.solver_ = solver + if dynamics_model is not None: + cebra_.dynamics_model_ = dynamics_model return cebra_ @@ -504,6 +527,15 @@ class CEBRA(TransformerMixin, BaseEstimator): hybrid (bool): If ``True``, the model will be trained using both the time-contrastive and the selected behavior-constrastive loss functions. |Default:| ``False``. + full_denominator (bool): + If ``True``, the InfoNCE loss will use the full denominator formulation, which includes + the positive sample in the denominator of the softmax. This can improve numerical stability + and training dynamics in some cases. |Default:| ``False``. + dynamics_model_architecture (str): + If set, train with an auxiliary dynamics model applied to the reference samples, as in + Dynamics-aware Contrastive Learning (DCL). The value selects the dynamics model + architecture registered with :py:mod:`cebra.dynamics` (e.g. ``"linear"``). When ``None``, + no dynamics model is used and the estimator behaves as standard CEBRA. |Default:| ``None``. optimizer_kwargs (tuple): Additional optimization parameters. These have the form ``((key, value), (key, value))`` and are passed to the PyTorch optimizer specified through the ``optimizer`` argument. Refer to the @@ -574,6 +606,7 @@ def __init__( max_iterations: int = 10000, max_adapt_iterations: int = 500, batch_size: int = None, + batch_size_negatives: int = None, learning_rate: float = 3e-4, optimizer: str = "adam", output_dimension: int = 8, @@ -581,6 +614,8 @@ def __init__( num_hidden_units: int = 32, pad_before_transform: bool = True, hybrid: bool = False, + full_denominator: bool = False, + dynamics_model_architecture: Optional[str] = None, optimizer_kwargs: Tuple[Tuple[str, object], ...] = ( ("betas", (0.9, 0.999)), ("eps", 1e-08), @@ -740,6 +775,7 @@ def _prepare_loader(self, dataset: cebra.data.Dataset, max_iterations: int, shared_kwargs=dict( dataset=dataset, batch_size=self.batch_size, + batch_size_negatives=self.batch_size_negatives, num_steps=max_iterations, ), extra_kwargs=dict( @@ -761,19 +797,25 @@ def _prepare_criterion(self): return cebra.models.LearnableCosineInfoNCE( temperature=self.temperature, min_temperature=self.min_temperature, + full_denominator=self.full_denominator, ) elif self.distance == "euclidean": return cebra.models.LearnableEuclideanInfoNCE( temperature=self.temperature, min_temperature=self.min_temperature, + full_denominator=self.full_denominator, ) elif self.temperature_mode == "constant": if self.distance == "cosine": return cebra.models.FixedCosineInfoNCE( - temperature=self.temperature,) + temperature=self.temperature, + full_denominator=self.full_denominator, + ) elif self.distance == "euclidean": return cebra.models.FixedEuclideanInfoNCE( - temperature=self.temperature,) + temperature=self.temperature, + full_denominator=self.full_denominator, + ) raise ValueError(f"Unknown similarity measure '{self.distance}' for " f"criterion '{self.criterion}'.") @@ -890,6 +932,20 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None): "Labels invalid: must have the same type of features as the ones used for fitting," f"expects {label_types_idx[0]}, got {y[i].dtype}.") + def _prepare_dynamics_model(self) -> Optional[torch.nn.Module]: + """Create the dynamics model when training in DCL mode. + + Returns: + A dynamics model selected via :py:attr:`dynamics_model_architecture`, or ``None`` + when no dynamics model is requested (i.e. standard CEBRA training). + """ + if self.dynamics_model_architecture is None: + return None + return cebra.dynamics.init( + self.dynamics_model_architecture, + latent_dim=self.output_dimension, + ) + def _prepare_fit( self, X: Union[npt.NDArray, torch.Tensor], @@ -926,10 +982,22 @@ def _prepare_fit( self._configure_for_all(dataset, model, is_multisession) + dynamics_model = self._prepare_dynamics_model() + criterion = self._prepare_criterion() criterion.to(self.device_) + + trainable_parameters = [model.parameters(), criterion.parameters()] + solver_kwargs = {} + if dynamics_model is not None: + dynamics_model.to(self.device_) + trainable_parameters.insert(1, dynamics_model.parameters()) + solver_name = solver_name + "-dcl" + solver_kwargs["dynamics_model"] = dynamics_model + self.dynamics_model_ = dynamics_model + optimizer = torch.optim.Adam( - itertools.chain(model.parameters(), criterion.parameters()), + itertools.chain(*trainable_parameters), lr=self.learning_rate, **dict(self.optimizer_kwargs), ) @@ -940,6 +1008,7 @@ def _prepare_fit( criterion=criterion, optimizer=optimizer, tqdm_on=self.verbose, + **solver_kwargs, ) solver.to(self.device_) self.solver_name_ = solver_name @@ -1015,11 +1084,22 @@ def _adapt_model( self._configure_for_all(dataset, adapt_model, is_multisession) + # Reuse the existing dynamics model (it does not depend on input dimension). + dynamics_model = getattr(self, "dynamics_model_", None) + criterion = self._prepare_criterion() criterion.to(self.device_) + trainable_parameters = list(adapt_model.parameters()) + list( + criterion.parameters()) + solver_kwargs = {} + if dynamics_model is not None: + trainable_parameters += list(dynamics_model.parameters()) + solver_name = solver_name + "-dcl" + solver_kwargs["dynamics_model"] = dynamics_model + optimizer = torch.optim.Adam( - list(adapt_model.parameters()) + list(criterion.parameters()), + trainable_parameters, lr=self.learning_rate, **dict(self.optimizer_kwargs), ) @@ -1030,6 +1110,7 @@ def _adapt_model( criterion=criterion, optimizer=optimizer, tqdm_on=self.verbose, + **solver_kwargs, ) solver.to(self.device_) @@ -1585,4 +1666,82 @@ def to(self, device: Union[str, torch.device]): self.device = device self.solver_.model.to(device) + if hasattr(self, "dynamics_model_"): + self.dynamics_model_.to(device) + return self + + +class DCL(CEBRA): + """CEBRA preset for Dynamics-aware Contrastive Learning (DCL). + + This is a thin convenience wrapper around :py:class:`CEBRA` that only changes the + default hyperparameters to sensible values for DCL: it enables a ``"linear"`` + dynamics model on the reference samples, uses the full-denominator InfoNCE + formulation (positives are included in the softmax denominator), and samples more + negatives than positives per batch. All behaviour is implemented in :py:class:`CEBRA`; + instantiating ``DCL(...)`` is equivalent to calling + :py:class:`CEBRA` with these defaults, and every argument can still be overridden. + """ + + def __init__( + self, + model_architecture: str = "offset1-model", + device: str = "cuda_if_available", + criterion: str = "infonce", + distance: str = "cosine", + conditional: str = None, + temperature: float = 1.0, + temperature_mode: Literal["constant", "auto"] = "constant", + min_temperature: Optional[float] = 0.1, + time_offsets: int = 1, + delta: float = None, + max_iterations: int = 10000, + max_adapt_iterations: int = 500, + batch_size: int = 4096, + batch_size_negatives: int = 20000, + learning_rate: float = 3e-4, + optimizer: str = "adam", + output_dimension: int = 8, + verbose: bool = False, + num_hidden_units: int = 32, + pad_before_transform: bool = True, + hybrid: bool = False, + full_denominator: bool = True, + dynamics_model_architecture: Optional[str] = "linear", + optimizer_kwargs: Tuple[Tuple[str, object], ...] = ( + ("betas", (0.9, 0.999)), + ("eps", 1e-08), + ("weight_decay", 0), + ("amsgrad", False), + ), + masking_kwargs: Tuple[Tuple[str, Union[float, List[float], + Tuple[float, ...]]], ...] = None, + ): + super().__init__( + model_architecture=model_architecture, + device=device, + criterion=criterion, + distance=distance, + conditional=conditional, + temperature=temperature, + temperature_mode=temperature_mode, + min_temperature=min_temperature, + time_offsets=time_offsets, + delta=delta, + max_iterations=max_iterations, + max_adapt_iterations=max_adapt_iterations, + batch_size=batch_size, + batch_size_negatives=batch_size_negatives, + learning_rate=learning_rate, + optimizer=optimizer, + output_dimension=output_dimension, + verbose=verbose, + num_hidden_units=num_hidden_units, + pad_before_transform=pad_before_transform, + hybrid=hybrid, + full_denominator=full_denominator, + dynamics_model_architecture=dynamics_model_architecture, + optimizer_kwargs=optimizer_kwargs, + masking_kwargs=masking_kwargs, + ) diff --git a/cebra/models/criterions.py b/cebra/models/criterions.py index f78e298b..24f95994 100644 --- a/cebra/models/criterions.py +++ b/cebra/models/criterions.py @@ -113,6 +113,32 @@ def infonce( return align + uniform, align_corrected, uniform_corrected +@torch.jit.script +def infonce_full_denominator(pos_dist, neg_dist): + + with torch.no_grad(): + c, _ = neg_dist.max(dim=1, keepdim=True) + c = c.detach() + + pos_dist = pos_dist - c.squeeze(1) + neg_dist = neg_dist - c + + numerator = (-pos_dist).mean() + denominator = torch.logsumexp( + torch.concatenate([ + pos_dist.unsqueeze(1), + neg_dist, + ], dim=1), + dim=1, + ).mean() + + c_mean = c.mean() + numerator = numerator - c_mean + denominator = denominator + c_mean + + return numerator + denominator, numerator, denominator + + class ContrastiveLoss(nn.Module): """Base class for contrastive losses. @@ -149,6 +175,13 @@ class BaseInfoNCE(ContrastiveLoss): """ + def __init__(self, full_denominator: bool = False): + super().__init__() + if full_denominator: + self.infonce = infonce_full_denominator + else: + self.infonce = infonce + def _distance(self, ref: torch.Tensor, pos: torch.Tensor, neg: torch.Tensor) -> Tuple[torch.Tensor]: """The similarity measure. @@ -178,7 +211,7 @@ def forward(self, ref, pos, :py:class:`BaseInfoNCE`. """ pos_dist, neg_dist = self._distance(ref, pos, neg) - return infonce(pos_dist, neg_dist) + return self.infonce(pos_dist, neg_dist) class FixedInfoNCE(BaseInfoNCE): @@ -189,8 +222,10 @@ class FixedInfoNCE(BaseInfoNCE): The softmax temperature """ - def __init__(self, temperature: float = 1.0): - super().__init__() + def __init__(self, + temperature: float = 1.0, + full_denominator: bool = False): + super().__init__(full_denominator) self.temperature = temperature @@ -207,8 +242,9 @@ class LearnableInfoNCE(BaseInfoNCE): def __init__(self, temperature: float = 1.0, - min_temperature: Optional[float] = None): - super().__init__() + min_temperature: Optional[float] = None, + full_denominator: bool = False): + super().__init__(full_denominator) if min_temperature is None: self.max_inverse_temperature = math.inf else: diff --git a/cebra/solver/base.py b/cebra/solver/base.py index c04c3398..2766cd79 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -910,9 +910,9 @@ def transform(self, The output embedding. """ self._check_is_fitted() - model, offset = self._select_model( - inputs, session_id, use_reference_model=use_reference_model) - + model = self._get_model(session_id=session_id, + use_reference_model=use_reference_model) + offset = model.get_offset() if len(offset) < 2 and pad_before_transform: pad_before_transform = False diff --git a/cebra/solver/single_session.py b/cebra/solver/single_session.py index 61b880e4..872f9f8e 100644 --- a/cebra/solver/single_session.py +++ b/cebra/solver/single_session.py @@ -196,9 +196,9 @@ def _get_model(self, self._check_is_session_id_valid(session_id=session_id) self._check_is_fitted() if use_reference_model: - model = self.reference_model[session_id] + model = self.reference_model else: - model = self.model[session_id] + model = self.model return model def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: @@ -224,6 +224,21 @@ def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: return cebra.data.Batch(ref, pos, neg) +@register("single-session-dcl") +@dataclasses.dataclass +class SingleSessionDCLSolver(SingleSessionAuxVariableSolver): + dynamics_model: torch.nn.Module = None + + def __post_init__(self): + if self.reference_model is not None: + raise ValueError("Reference model must be None for DCL solver") + self.reference_model = torch.nn.Sequential( + self.model, + self.dynamics_model, + ) + super().__post_init__() + + @register("single-session-hybrid") @dataclasses.dataclass class SingleSessionHybridSolver(abc_.MultiobjectiveSolver, SingleSessionSolver): @@ -349,3 +364,66 @@ def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: outputs[batch.positive - self.offset.left], outputs[batch.negative - self.offset.left], ) + + +@register("single-session-full-dcl") +@dataclasses.dataclass +class BatchSingleSessionDCLSolver(BatchSingleSessionSolver, + SingleSessionAuxVariableSolver): + """Optimize a DCL model with batch gradient descent. + + This solver combines the batch gradient descent approach of + :py:class:`BatchSingleSessionSolver` with the dynamics model functionality + of :py:class:`SingleSessionDCLSolver`. It uses the full dataset for training + and applies the dynamics model to reference samples. + + Usage of this solver requires a sufficient amount of GPU memory. Using this solver + is equivalent to using a single session DCL solver with batch size set to dataset size, + but requires less computation. + """ + + dynamics_model: torch.nn.Module = None + + def __post_init__(self): + # Set up the reference model as Sequential(model, dynamics_model) for DCL + if self.reference_model is not None: + raise ValueError("Reference model must be None for DCL solver") + if self.dynamics_model is None: + raise ValueError("Dynamics model must be provided for DCL solver") + self.reference_model = torch.nn.Sequential( + self.model, + self.dynamics_model, + ) + # Call parent __post_init__ methods + super().__post_init__() + + def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: + """Given a batch of input examples, computes the feature representation/embedding. + + For DCL, reference samples are processed through model + dynamics_model, + while positive and negative samples are processed through model only. + + Args: + batch: The input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. + + Returns: + Processed batch of data. While the input data might not be aligned + across the sample dimensions, the output data should be aligned and + ``batch.index`` should be set to ``None``. + """ + # Get embeddings for all samples using the regular model + outputs = self.get_embedding(self.neural) + idc = batch.positive - self.offset.left >= len(outputs) + batch.positive[idc] = batch.reference[idc] + + # For reference samples, apply dynamics model + ref_embeddings = outputs[batch.reference - self.offset.left] + ref_embeddings = self.dynamics_model(ref_embeddings) + + # For positive and negative samples, use regular embeddings + pos_embeddings = outputs[batch.positive - self.offset.left] + neg_embeddings = outputs[batch.negative - self.offset.left] + + return cebra.data.Batch(ref_embeddings, pos_embeddings, neg_embeddings) diff --git a/tests/examples/dcl_example.py b/tests/examples/dcl_example.py new file mode 100644 index 00000000..5ee7fb64 --- /dev/null +++ b/tests/examples/dcl_example.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +""" +Example script demonstrating how to use CEBRA for time contrastive learning. +""" + +import sklearn.decomposition +import torch + +import cebra +import cebra.datasets + +if __name__ == "__main__": + + device = "mps" + + input_data = cebra.datasets.init("rat-hippocampus-single-achilles") + + latent_dim = 5 + + neural_model = cebra.models.init( + name="offset10-model-mse", + num_neurons=input_data.input_dimension, + num_units=32, + num_output=latent_dim, + ).to(device) + + dynamics_model = cebra.dynamics.init( + name="linear", + latent_dim=latent_dim, + bias=True, + ).to(device) + + input_data.configure_for(neural_model) + + crit = cebra.models.criterions.FixedEuclideanInfoNCE( + temperature=1, + full_denominator=True, + ).to(device) + + opt = torch.optim.Adam(list(neural_model.parameters()) + + list(dynamics_model.parameters()) + + list(crit.parameters()), + lr=0.001, + weight_decay=0) + + solver = cebra.solver.init( + name="single-session-dcl", + model=neural_model, + criterion=crit, + optimizer=opt, + tqdm_on=True, + dynamics_model=dynamics_model, + ).to(device) + + loader = cebra.data.single_session.ContinuousDataLoader( + dataset=input_data, + num_steps=500, + batch_size=2048, + batch_size_negatives=10000, + conditional="time", + time_offset=10, + ).to(device) + + solver.fit(loader=loader) + + x_train_emb = solver.transform(input_data.neural) + + ica = sklearn.decomposition.FastICA(n_components=2) + x_train_emb = ica.fit_transform(x_train_emb.cpu()) + + ax = cebra.plot_embedding( + x_train_emb, + embedding_labels=input_data.continuous_index[:, 0].cpu(), + markersize=10, + cmap="rainbow") + + ax.figure.savefig("dcl-example.png") diff --git a/tests/examples/dcl_sklearn_example.py b/tests/examples/dcl_sklearn_example.py new file mode 100644 index 00000000..3777b1bf --- /dev/null +++ b/tests/examples/dcl_sklearn_example.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +""" +Example script demonstrating how to use CEBRA DCL with the sklearn API. +This reproduces the functionality of dcl_example.py using the sklearn wrapper. +""" + +import sklearn.decomposition + +import cebra +import cebra.datasets + +if __name__ == "__main__": + device = "mps" + + input_data = cebra.datasets.init("rat-hippocampus-single-achilles") + + latent_dim = 5 + + # Create DCL model using sklearn API + dcl_model = cebra.CEBRA( + model_architecture="offset10-model-mse", + output_dimension=latent_dim, + num_hidden_units=32, + batch_size=4096, + learning_rate=0.001, + max_iterations=500, + time_offsets=10, + conditional="time", + distance="euclidean", + temperature=1.0, + device=device, + verbose=True, + dynamics_model_architecture="linear", + full_denominator=True, # Match the PyTorch API example + ) + + # Fit the model + dcl_model.fit(input_data.neural.cpu().numpy()) + + # Transform to get embeddings + x_train_emb = dcl_model.transform(input_data.neural.cpu().numpy()) + + ica = sklearn.decomposition.FastICA(n_components=2) + x_train_emb = ica.fit_transform(x_train_emb) + + ax = cebra.plot_embedding( + x_train_emb, + embedding_labels=input_data.continuous_index[:, 0].cpu(), + markersize=10, + cmap="rainbow") + + ax.figure.savefig("dcl-example-sklearn.png") diff --git a/tests/test_criterions.py b/tests/test_criterions.py index 0d6f8ff2..cda01b86 100644 --- a/tests/test_criterions.py +++ b/tests/test_criterions.py @@ -350,3 +350,70 @@ def test_infonce_gradients(seed, case): assert torch.allclose(grad[0], torch.zeros_like(grad[0])) assert grad[1] is not None assert torch.allclose(grad_ref[1], grad[1]) + + +def test_infonce_full_denominator(): + rng = torch.Generator().manual_seed(42) + + pos_dist = torch.randn(100, generator=rng) + neg_dist = torch.randn(100, 100, generator=rng) + + loss, align, uniform = cebra_criterions.infonce_full_denominator( + pos_dist, neg_dist) + + assert loss.dim() == 0 + assert align.dim() == 0 + assert uniform.dim() == 0 + + assert torch.allclose(loss, align + uniform) + + assert not torch.isnan(loss) + assert not torch.isinf(loss) + assert not torch.isnan(align) + assert not torch.isnan(uniform) + + def simple_infonce_full_denominator(pos_dist, neg_dist): + numerator = (-pos_dist).mean() + + all_distances = torch.concatenate([ + pos_dist.unsqueeze(1), + neg_dist, + ], + dim=1) + denominator = torch.logsumexp(all_distances, dim=1).mean() + + return numerator + denominator, numerator, denominator + + simple_loss, simple_align, simple_uniform = simple_infonce_full_denominator( + pos_dist, neg_dist) + + assert torch.allclose(loss, simple_loss, rtol=1e-3, atol=1e-3) + assert torch.allclose(align, simple_align, rtol=1e-3, atol=1e-3) + assert torch.allclose(uniform, simple_uniform, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("full_denominator", [True, False]) +@pytest.mark.parametrize("criterion", [ + cebra_criterions.FixedCosineInfoNCE, + cebra_criterions.FixedEuclideanInfoNCE, + cebra_criterions.LearnableCosineInfoNCE, + cebra_criterions.LearnableEuclideanInfoNCE, +]) +def test_infonce_full_denominator_api(criterion, full_denominator): + + ref = torch.randn(10, 5) + pos = torch.randn(10, 5) + neg = torch.randn(10, 5) + + crit = criterion(temperature=1.0, full_denominator=full_denominator) + pos_dist, neg_dist = crit._distance(ref, pos, neg) + + if full_denominator: + expected_loss = cebra_criterions.infonce_full_denominator( + pos_dist, neg_dist)[0] + else: + expected_loss = cebra_criterions.infonce(pos_dist, neg_dist)[0] + + actual_loss = crit(ref, pos, neg)[0] + + assert torch.allclose(expected_loss, actual_loss, rtol=1e-5) diff --git a/tests/test_dynamics.py b/tests/test_dynamics.py new file mode 100644 index 00000000..58b5323a --- /dev/null +++ b/tests/test_dynamics.py @@ -0,0 +1,41 @@ +import torch + +import cebra.dynamics.linear as linear_dynamics + + +def test_linear_dynamics(): + latent_dim = 5 + model = linear_dynamics.Linear(latent_dim=latent_dim, bias=True) + + assert model.weight.shape == (latent_dim, latent_dim) + assert model.bias.shape == (latent_dim,) + + x = torch.randn(10, latent_dim) + out = model(x) + assert out.shape == (10, latent_dim) + + +def test_orthogonal_linear_dynamics(): + latent_dim = 5 + model = linear_dynamics.OrthogonalLinear(latent_dim=latent_dim, bias=True) + + assert model.weight.shape == (latent_dim, latent_dim) + assert model.bias.shape == (latent_dim,) + + x = torch.randn(10, latent_dim) + out = model(x) + assert out.shape == (10, latent_dim) + + W = model.weight.detach() + WWT = W @ W.T + identity = torch.eye(latent_dim) + assert torch.allclose(WWT, identity, atol=1e-6) + + +def test_identity_dynamics(): + model = linear_dynamics.Identity() + + x = torch.randn(10, 5) + out = model(x) + assert torch.allclose(x, out) + assert x.shape == out.shape From 335936ea211a8f5f5d1e442e5c7be7522263d014 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Tue, 23 Jun 2026 00:07:50 +0200 Subject: [PATCH 2/6] Apply fix for non-writeable array --- cebra/integrations/sklearn/cebra.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index dade1307..bb422ece 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -888,6 +888,7 @@ def _configure_for_all( def _select_model(self, X: Union[npt.NDArray, torch.Tensor], session_id: int): if isinstance(X, np.ndarray): + X = cebra_sklearn_dataset._ensure_writable(X) X = torch.from_numpy(X) return self.solver_._select_model(X, session_id=session_id) From 178ac00e01f59ef96beef4fbf7b13b2e1b94dd73 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Tue, 23 Jun 2026 00:10:26 +0200 Subject: [PATCH 3/6] temporarily limit numpy<2.5 --- setup.cfg | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 0a1af183..c3eeb409 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,8 @@ python_requires = >=3.10 install_requires = joblib numpy<2.0;platform_system=="Windows" - numpy;platform_system!="Windows" and python_version>="3.10" + # temporary 2.5 pin until update issues are resolved + numpy<2.5;platform_system!="Windows" and python_version>="3.10" literate-dataclasses scikit-learn scipy From e9133484698e2077f626218876550c76ca49da50 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Tue, 23 Jun 2026 00:21:08 +0200 Subject: [PATCH 4/6] fix pickle error --- cebra/models/criterions.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/cebra/models/criterions.py b/cebra/models/criterions.py index 24f95994..6ed1c3b1 100644 --- a/cebra/models/criterions.py +++ b/cebra/models/criterions.py @@ -177,10 +177,12 @@ class BaseInfoNCE(ContrastiveLoss): def __init__(self, full_denominator: bool = False): super().__init__() - if full_denominator: - self.infonce = infonce_full_denominator - else: - self.infonce = infonce + # NOTE(stes): Store a boolean flag rather than a reference to the + # ``torch.jit.script`` function. Assigning the ScriptFunction as an + # instance attribute places it in the module ``__dict__``, which then + # cannot be pickled by ``torch.save`` ("ScriptFunction cannot be + # pickled"). See https://github.com/AdaptiveMotorControlLab/CEBRA/pull/301 + self.full_denominator = full_denominator def _distance(self, ref: torch.Tensor, pos: torch.Tensor, neg: torch.Tensor) -> Tuple[torch.Tensor]: @@ -211,7 +213,9 @@ def forward(self, ref, pos, :py:class:`BaseInfoNCE`. """ pos_dist, neg_dist = self._distance(ref, pos, neg) - return self.infonce(pos_dist, neg_dist) + if self.full_denominator: + return infonce_full_denominator(pos_dist, neg_dist) + return infonce(pos_dist, neg_dist) class FixedInfoNCE(BaseInfoNCE): From 371566f80524c93a9e41f7cff8a3766359ec0329 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Tue, 23 Jun 2026 00:29:52 +0200 Subject: [PATCH 5/6] pin docs requirements --- docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index f9fc1663..87dbdc7f 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -19,4 +19,4 @@ sphinxcontrib-serializinghtml==2.0.0 literate_dataclasses # For IPython.sphinxext.ipython_console_highlighting extension ipython -numpy +numpy<2.5 # temporary fix, remove once objects.inv are updated From bdb8fa0414db7703a83ba41f2cb661b31faef6a3 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Tue, 23 Jun 2026 00:41:54 +0200 Subject: [PATCH 6/6] add minimal docs page for dynamics --- docs/requirements.txt | 4 +++- docs/source/api.rst | 1 + docs/source/api/pytorch/dynamics.rst | 26 ++++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 docs/source/api/pytorch/dynamics.rst diff --git a/docs/requirements.txt b/docs/requirements.txt index 87dbdc7f..afff44b1 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -19,4 +19,6 @@ sphinxcontrib-serializinghtml==2.0.0 literate_dataclasses # For IPython.sphinxext.ipython_console_highlighting extension ipython -numpy<2.5 # temporary fix, remove once objects.inv are updated +# TODO(stes): temporary fix, remove once objects.inv are updated +# see https://github.com/AdaptiveMotorControlLab/CEBRA/issues/303 +numpy<2.5 diff --git a/docs/source/api.rst b/docs/source/api.rst index 846602f1..28b65729 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -37,6 +37,7 @@ these components in other contexts and research code bases. api/pytorch/datasets api/pytorch/distributions api/pytorch/models + api/pytorch/dynamics api/pytorch/helpers api/pytorch/multiobjective api/pytorch/regularized diff --git a/docs/source/api/pytorch/dynamics.rst b/docs/source/api/pytorch/dynamics.rst new file mode 100644 index 00000000..d3cd6fd2 --- /dev/null +++ b/docs/source/api/pytorch/dynamics.rst @@ -0,0 +1,26 @@ +Dynamics Models +--------------- + +Dynamics models used for Dynamics-aware Contrastive Learning (DCL). They are +applied to the reference samples during training and selected by name via +:py:attr:`cebra.CEBRA.dynamics_model_architecture`. + +.. automodule:: cebra.dynamics + :members: + :show-inheritance: + +Registration and initialization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: init + +.. autofunction:: get_options + +.. autofunction:: register + +Models +~~~~~~ + +.. automodule:: cebra.dynamics.linear + :members: + :show-inheritance: \ No newline at end of file