From 66c8872118dda9c565c9ac7df35748d7548d401e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Fri, 22 Aug 2025 13:34:29 +0200 Subject: [PATCH 1/6] Parametrize offset models --- cebra/models/model.py | 183 ++++++++---------------------------------- docs/source/usage.rst | 2 +- 2 files changed, 34 insertions(+), 151 deletions(-) diff --git a/cebra/models/model.py b/cebra/models/model.py index 77423532..ef80d57b 100644 --- a/cebra/models/model.py +++ b/cebra/models/model.py @@ -260,59 +260,62 @@ def num_trainable_parameters(self) -> int: param.numel() for param in self.parameters() if param.requires_grad) -@register("offset10-model") -class Offset10Model(_OffsetModel, ConvolutionalModelMixin): - """CEBRA model with a 10 sample receptive field.""" +@parametrize("offset{n_offset}-model", + n_offset=(5, 10, 15, 18, 20, 31, 36, 40, 50)) +class OffsetNModel(_OffsetModel, ConvolutionalModelMixin): + """CEBRA model with a `n_offset` sample receptive field. - def __init__(self, num_neurons, num_units, num_output, normalize=True): + n_offset: The size of the receptive field. + """ + + def __init__(self, + num_neurons, + num_units, + num_output, + n_offset, + normalize=True): if num_units < 1: raise ValueError( f"Hidden dimension needs to be at least 1, but got {num_units}." 
) + + self.n_offset = n_offset + + def _compute_num_layers(n_offset): + """Compute the number of layers to add on top of the first and last conv layers.""" + return (n_offset - 4) // 2 + self.n_offset % 2 + + last_layer_kernel = 3 if (self.n_offset % 2) == 0 else 2 super().__init__( nn.Conv1d(num_neurons, num_units, 2), nn.GELU(), - *self._make_layers(num_units, num_layers=3), - nn.Conv1d(num_units, num_output, 3), + *self._make_layers(num_units, + num_layers=_compute_num_layers(self.n_offset)), + nn.Conv1d(num_units, num_output, last_layer_kernel), num_input=num_neurons, num_output=num_output, normalize=normalize, ) def get_offset(self) -> cebra.data.datatypes.Offset: - """See :py:meth:`~.Model.get_offset`""" - return cebra.data.Offset(5, 5) + """See `:py:meth:Model.get_offset`""" + return cebra.data.Offset(self.n_offset // 2, + self.n_offset // 2 + self.n_offset % 2) @register("offset10-model-mse") -class Offset10ModelMSE(Offset10Model): +class Offset10ModelMSE(OffsetNModel): """Symmetric model with 10 sample receptive field, without normalization. Suitable for use with InfoNCE metrics for Euclidean space. 
""" def __init__(self, num_neurons, num_units, num_output, normalize=False): - super().__init__(num_neurons, num_units, num_output, normalize) - - -@register("offset5-model") -class Offset5Model(_OffsetModel, ConvolutionalModelMixin): - """CEBRA model with a 5 sample receptive field and output normalization.""" - - def __init__(self, num_neurons, num_units, num_output, normalize=True): - super().__init__( - nn.Conv1d(num_neurons, num_units, 2), - nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - nn.Conv1d(num_units, num_output, 2), - num_input=num_neurons, - num_output=num_output, - normalize=normalize, - ) - - def get_offset(self) -> cebra.data.datatypes.Offset: - """See :py:meth:`~.Model.get_offset`""" - return cebra.data.Offset(2, 3) + super().__init__(num_neurons, + num_units, + num_output, + n_offset=10, + normalize=normalize) @register("offset1-model-mse") @@ -666,30 +669,6 @@ def get_offset(self) -> cebra.data.datatypes.Offset: return cebra.data.Offset(0, 1) -@register("offset36-model") -class Offset36(_OffsetModel, ConvolutionalModelMixin): - """CEBRA model with a 10 sample receptive field.""" - - def __init__(self, num_neurons, num_units, num_output, normalize=True): - if num_units < 1: - raise ValueError( - f"Hidden dimension needs to be at least 1, but got {num_units}." - ) - super().__init__( - nn.Conv1d(num_neurons, num_units, 2), - nn.GELU(), - *self._make_layers(num_units, num_layers=16), - nn.Conv1d(num_units, num_output, 3), - num_input=num_neurons, - num_output=num_output, - normalize=normalize, - ) - - def get_offset(self) -> cebra.data.datatypes.Offset: - """See `:py:meth:Model.get_offset`""" - return cebra.data.Offset(18, 18) - - @_register_conditionally("offset36-model-dropout") class Offset36Dropout(_OffsetModel, ConvolutionalModelMixin): """CEBRA model with a 10 sample receptive field. 
@@ -767,102 +746,6 @@ def get_offset(self) -> cebra.data.datatypes.Offset: return cebra.data.Offset(18, 18) -@register("offset40-model") -class Offset40(_OffsetModel, ConvolutionalModelMixin): - """CEBRA model with a 40 samples receptive field.""" - - def __init__(self, num_neurons, num_units, num_output, normalize=True): - if num_units < 1: - raise ValueError( - f"Hidden dimension needs to be at least 1, but got {num_units}." - ) - super().__init__( - nn.Conv1d(num_neurons, num_units, 2), - nn.GELU(), - *self._make_layers(num_units, 18), - nn.Conv1d(num_units, num_output, 3), - num_input=num_neurons, - num_output=num_output, - normalize=normalize, - ) - - def get_offset(self) -> cebra.data.datatypes.Offset: - """See `:py:meth:Model.get_offset`""" - return cebra.data.Offset(20, 20) - - -@register("offset50-model") -class Offset50(_OffsetModel, ConvolutionalModelMixin): - """CEBRA model with a sample receptive field.""" - - def __init__(self, num_neurons, num_units, num_output, normalize=True): - if num_units < 1: - raise ValueError( - f"Hidden dimension needs to be at least 1, but got {num_units}." - ) - super().__init__( - nn.Conv1d(num_neurons, num_units, 2), - nn.GELU(), - *self._make_layers(num_units, 23), - nn.Conv1d(num_units, num_output, 3), - num_input=num_neurons, - num_output=num_output, - normalize=normalize, - ) - - def get_offset(self) -> cebra.data.datatypes.Offset: - """See `:py:meth:Model.get_offset`""" - return cebra.data.Offset(25, 25) - - -@register("offset15-model") -class Offset15Model(_OffsetModel, ConvolutionalModelMixin): - """CEBRA model with a 15 sample receptive field.""" - - def __init__(self, num_neurons, num_units, num_output, normalize=True): - if num_units < 1: - raise ValueError( - f"Hidden dimension needs to be at least 1, but got {num_units}." 
- ) - super().__init__( - nn.Conv1d(num_neurons, num_units, 2), - nn.GELU(), - *self._make_layers(num_units, num_layers=6), - nn.Conv1d(num_units, num_output, 2), - num_input=num_neurons, - num_output=num_output, - normalize=normalize, - ) - - def get_offset(self) -> cebra.data.datatypes.Offset: - """See `:py:meth:Model.get_offset`""" - return cebra.data.Offset(7, 8) - - -@register("offset20-model") -class Offset20Model(_OffsetModel, ConvolutionalModelMixin): - """CEBRA model with a 15 sample receptive field.""" - - def __init__(self, num_neurons, num_units, num_output, normalize=True): - if num_units < 1: - raise ValueError( - f"Hidden dimension needs to be at least 1, but got {num_units}." - ) - super().__init__( - nn.Conv1d(num_neurons, num_units, 2), - nn.GELU(), - *self._make_layers(num_units, num_layers=8), - nn.Conv1d(num_units, num_output, 3), - num_input=num_neurons, - num_output=num_output, - normalize=normalize, - ) - - def get_offset(self) -> cebra.data.datatypes.Offset: - """See `:py:meth:Model.get_offset`""" - return cebra.data.Offset(10, 10) - - @register("offset10-model-mse-tanh") class Offset10Model(_OffsetModel, ConvolutionalModelMixin): """CEBRA model with a 10 sample receptive field.""" diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 82e45a0b..38bbd633 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -175,7 +175,7 @@ We provide a set of pre-defined models. You can access (and search) a list of av .. testoutput:: - ['offset10-model', 'offset10-model-mse', 'offset5-model', 'offset1-model-mse'] + ['offset5-model', 'offset10-model', 'offset15-model', 'offset18-model'] Then, you can choose the one that fits best with your needs and provide it to the CEBRA model as the :py:attr:`~.CEBRA.model_architecture` parameter. 
From 7212cc4c29ccbd5726058129f4b47e5cb1af9001 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Thu, 7 May 2026 16:02:57 +0200 Subject: [PATCH 2/6] Implement copilot suggestions --- cebra/models/model.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/cebra/models/model.py b/cebra/models/model.py index ef80d57b..118748d8 100644 --- a/cebra/models/model.py +++ b/cebra/models/model.py @@ -281,16 +281,15 @@ def __init__(self, self.n_offset = n_offset - def _compute_num_layers(n_offset): + def _compute_num_layers(): """Compute the number of layers to add on top of the first and last conv layers.""" - return (n_offset - 4) // 2 + self.n_offset % 2 + return (self.n_offset - 4) // 2 + self.n_offset % 2 last_layer_kernel = 3 if (self.n_offset % 2) == 0 else 2 super().__init__( nn.Conv1d(num_neurons, num_units, 2), nn.GELU(), - *self._make_layers(num_units, - num_layers=_compute_num_layers(self.n_offset)), + *self._make_layers(num_units, num_layers=_compute_num_layers()), nn.Conv1d(num_units, num_output, last_layer_kernel), num_input=num_neurons, num_output=num_output, @@ -298,7 +297,7 @@ def _compute_num_layers(n_offset): ) def get_offset(self) -> cebra.data.datatypes.Offset: - """See `:py:meth:Model.get_offset`""" + """See :py:meth:`~.Model.get_offset`""" return cebra.data.Offset(self.n_offset // 2, self.n_offset // 2 + self.n_offset % 2) @@ -700,7 +699,7 @@ def __init__(self, ) def get_offset(self) -> cebra.data.datatypes.Offset: - """See `:py:meth:Model.get_offset`""" + """See :py:meth:`~.Model.get_offset`""" return cebra.data.Offset(18, 18) @@ -742,12 +741,12 @@ def __init__(self, ) def get_offset(self) -> cebra.data.datatypes.Offset: - """See `:py:meth:Model.get_offset`""" + """See :py:meth:`~.Model.get_offset`""" return cebra.data.Offset(18, 18) @register("offset10-model-mse-tanh") -class Offset10Model(_OffsetModel, ConvolutionalModelMixin): +class 
Offset10ModelMSETanh(_OffsetModel, ConvolutionalModelMixin): """CEBRA model with a 10 sample receptive field.""" def __init__(self, num_neurons, num_units, num_output, normalize=False): From 2d48ef8de483f64133fc9c16e9da33ea4ccd83bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Thu, 7 May 2026 16:04:50 +0200 Subject: [PATCH 3/6] Restructure deprecation functions --- tests/_reference_implementations/__init__.py | 32 ++++++++ .../deprecated_transforms.py} | 10 +-- tests/test_integration_xcebra.py | 4 +- tests/test_sklearn.py | 76 +++++++++---------- 4 files changed, 73 insertions(+), 49 deletions(-) create mode 100644 tests/_reference_implementations/__init__.py rename tests/{_utils_deprecated.py => _reference_implementations/deprecated_transforms.py} (93%) diff --git a/tests/_reference_implementations/__init__.py b/tests/_reference_implementations/__init__.py new file mode 100644 index 00000000..15e6755b --- /dev/null +++ b/tests/_reference_implementations/__init__.py @@ -0,0 +1,32 @@ +"""Reference implementations for testing consistency and backward compatibility. + +This package contains reference implementations of previously deprecated or +parametrized model components, used for testing consistency and backward compatibility +in the test suite. 
+""" + +from .deprecated_transforms import ( + cebra_transform_deprecated, + multiobjective_transform_deprecated, +) +from .reference_offset_models import ( + Offset5ModelReference, + Offset10ModelReference, + Offset15ModelReference, + Offset20ModelReference, + Offset36Reference, + Offset40Reference, + Offset50Reference, +) + +__all__ = [ + "cebra_transform_deprecated", + "multiobjective_transform_deprecated", + "Offset5ModelReference", + "Offset10ModelReference", + "Offset15ModelReference", + "Offset20ModelReference", + "Offset36Reference", + "Offset40Reference", + "Offset50Reference", +] diff --git a/tests/_utils_deprecated.py b/tests/_reference_implementations/deprecated_transforms.py similarity index 93% rename from tests/_utils_deprecated.py rename to tests/_reference_implementations/deprecated_transforms.py index a5a2aaaf..01be92dd 100644 --- a/tests/_utils_deprecated.py +++ b/tests/_reference_implementations/deprecated_transforms.py @@ -11,8 +11,8 @@ import cebra.models -#NOTE: Deprecated: transform is now handled in the solver but the original -# method is kept here for testing. +#NOTE(celia): Deprecated: transform is now handled in the solver but the original +# method is kept here for testing consistency. def cebra_transform_deprecated(cebra_model, X: Union[npt.NDArray, torch.Tensor], session_id: Optional[int] = None) -> npt.NDArray: @@ -72,9 +72,9 @@ def cebra_transform_deprecated(cebra_model, return output -# NOTE: Deprecated: batched transform can now be performed (more memory efficient) +# NOTE(celia): Deprecated: batched transform can now be performed (more memory efficient) # using the transform method of the model, and handling padding is implemented -# directly in the base Solver. This method is kept for testing purposes. +# directly in the base Solver. This method is kept for testing consistency. 
@torch.no_grad() def multiobjective_transform_deprecated(solver: "cebra.solvers.Solver", inputs: torch.Tensor) -> torch.Tensor: @@ -90,7 +90,7 @@ def multiobjective_transform_deprecated(solver: "cebra.solvers.Solver", warnings.warn( "The method is deprecated " - "but kept for testing puroposes." + "but kept for testing purposes." "We recommend using `transform` instead.", DeprecationWarning, stacklevel=2) diff --git a/tests/test_integration_xcebra.py b/tests/test_integration_xcebra.py index 760e26ef..e213be61 100644 --- a/tests/test_integration_xcebra.py +++ b/tests/test_integration_xcebra.py @@ -1,6 +1,6 @@ import pickle -import _utils_deprecated +import _reference_implementations import numpy as np import pytest import torch @@ -174,7 +174,7 @@ def test_synthetic_data_training(synthetic_data, device): atol=1e-4) # Test and compare the previous transform (transform_deprecated) - deprecated_transform_embedding = _utils_deprecated.multiobjective_transform_deprecated( + deprecated_transform_embedding = _reference_implementations.multiobjective_transform_deprecated( solver, data.neural.to(device)) assert np.allclose(embedding, deprecated_transform_embedding, diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 831ad49d..bdecdeef 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -25,8 +25,8 @@ import tempfile import warnings +import _reference_implementations import _util -import _utils_deprecated import numpy as np import packaging.version import pytest @@ -1412,7 +1412,8 @@ def test_new_transform(model_architecture, device): # time contrastive cebra_model.fit(X) embedding1 = cebra_model.transform(X) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, X) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" @@ -1421,20 +1422,20 @@ def test_new_transform(model_architecture, device): assert 
cebra_model.num_sessions is None embedding1 = cebra_model.transform(X) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, X) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(torch.Tensor(X)) - embedding2 = _utils_deprecated.cebra_transform_deprecated( + embedding2 = _reference_implementations.cebra_transform_deprecated( cebra_model, torch.Tensor(X)) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, - torch.Tensor(X), - session_id=0) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, torch.Tensor(X), session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" @@ -1444,14 +1445,16 @@ def test_new_transform(model_architecture, device): # discrete behavior contrastive cebra_model.fit(X, y_d) embedding1 = cebra_model.transform(X) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, X) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" # mixed cebra_model.fit(X, y_c1, y_c2, y_d) embedding1 = cebra_model.transform(X) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, X) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" @@ -1459,23 +1462,20 @@ def test_new_transform(model_architecture, device): cebra_model.fit([X, X_s2], [y_d, y_d_s2]) embedding1 = cebra_model.transform(X, session_id=0) - embedding2 = 
_utils_deprecated.cebra_transform_deprecated(cebra_model, - X, - session_id=0) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, X, session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, - torch.Tensor(X), - session_id=0) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, torch.Tensor(X), session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(X_s2, session_id=1) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, - X_s2, - session_id=1) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, X_s2, session_id=1) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" @@ -1483,16 +1483,14 @@ def test_new_transform(model_architecture, device): cebra_model.fit([X, X_s2], [y_c1, y_c1_s2]) embedding1 = cebra_model.transform(X, session_id=0) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, - X, - session_id=0) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, X, session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), " are not close enough" embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, - torch.Tensor(X), - session_id=0) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, torch.Tensor(X), session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" @@ -1511,23 +1509,20 @@ def test_new_transform(model_architecture, device): cebra_model.fit([X, X_s2, X], [y_d, y_d_s2, y_d]) embedding1 = 
cebra_model.transform(X, session_id=0) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, - X, - session_id=0) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, X, session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(X_s2, session_id=1) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, - X_s2, - session_id=1) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, X_s2, session_id=1) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(X, session_id=2) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, - X, - session_id=2) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, X, session_id=2) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" @@ -1535,23 +1530,20 @@ def test_new_transform(model_architecture, device): cebra_model.fit([X, X_s2, X], [y_c1, y_c1_s2, y_c1]) embedding1 = cebra_model.transform(X, session_id=0) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, - X, - session_id=0) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, X, session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(X_s2, session_id=1) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, - X_s2, - session_id=1) + embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, X_s2, session_id=1) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(X, session_id=2) - embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, - X, - session_id=2) 
+ embedding2 = _reference_implementations.cebra_transform_deprecated( + cebra_model, X, session_id=2) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" From dbae49a6baf1b115741af2084a423b355098d048 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Thu, 7 May 2026 16:05:10 +0200 Subject: [PATCH 4/6] Add deprecation tests on offsets models --- .../reference_offset_models.py | 163 ++++++++++++++ tests/test_models.py | 201 ++++++++++++++++++ 2 files changed, 364 insertions(+) create mode 100644 tests/_reference_implementations/reference_offset_models.py diff --git a/tests/_reference_implementations/reference_offset_models.py b/tests/_reference_implementations/reference_offset_models.py new file mode 100644 index 00000000..f687f9fb --- /dev/null +++ b/tests/_reference_implementations/reference_offset_models.py @@ -0,0 +1,163 @@ +"""Reference implementations of previously hardcoded offset models before parametrization. + +These models are used to verify that the parametrized versions produce identical outputs. +""" + +import torch +from torch import nn + +import cebra.data +import cebra.data.datatypes +import cebra.models.layers as cebra_layers +from cebra.models.model import _OffsetModel +from cebra.models.model import ConvolutionalModelMixin + + +class Offset10ModelReference(_OffsetModel, ConvolutionalModelMixin): + """Reference: CEBRA model with a 10 sample receptive field (offset10-model).""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." 
+ ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, num_layers=3), + nn.Conv1d(num_units, num_output, 3), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + return cebra.data.Offset(5, 5) + + +class Offset5ModelReference(_OffsetModel, ConvolutionalModelMixin): + """Reference: CEBRA model with a 5 sample receptive field (offset5-model).""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + nn.Conv1d(num_units, num_output, 2), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + return cebra.data.Offset(2, 3) + + +class Offset15ModelReference(_OffsetModel, ConvolutionalModelMixin): + """Reference: CEBRA model with a 15 sample receptive field (offset15-model).""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." + ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, num_layers=6), + nn.Conv1d(num_units, num_output, 2), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + return cebra.data.Offset(7, 8) + + +class Offset20ModelReference(_OffsetModel, ConvolutionalModelMixin): + """Reference: CEBRA model with a 20 sample receptive field (offset20-model).""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." 
+ ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, num_layers=8), + nn.Conv1d(num_units, num_output, 3), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + return cebra.data.Offset(10, 10) + + +class Offset36Reference(_OffsetModel, ConvolutionalModelMixin): + """Reference: CEBRA model with a 36 sample receptive field (offset36-model).""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." + ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, num_layers=16), + nn.Conv1d(num_units, num_output, 3), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + return cebra.data.Offset(18, 18) + + +class Offset40Reference(_OffsetModel, ConvolutionalModelMixin): + """Reference: CEBRA model with a 40 sample receptive field (offset40-model).""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." + ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, 18), + nn.Conv1d(num_units, num_output, 3), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + return cebra.data.Offset(20, 20) + + +class Offset50Reference(_OffsetModel, ConvolutionalModelMixin): + """Reference: CEBRA model with a 50 sample receptive field (offset50-model).""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." 
+ ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, 23), + nn.Conv1d(num_units, num_output, 3), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + return cebra.data.Offset(25, 25) diff --git a/tests/test_models.py b/tests/test_models.py index 658cc467..78779107 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -165,3 +165,204 @@ def test_version_check_dropout_available(): assert len(cebra.models.get_options("*dropout*")) == 0 else: assert len(cebra.models.get_options("*dropout*")) > 0 + + +# Tests for parametrized offset models backward compatibility +from _reference_implementations import Offset5ModelReference +from _reference_implementations import Offset10ModelReference +from _reference_implementations import Offset15ModelReference +from _reference_implementations import Offset20ModelReference +from _reference_implementations import Offset36Reference +from _reference_implementations import Offset40Reference +from _reference_implementations import Offset50Reference + + +@pytest.mark.parametrize("offset_n,reference_class", [ + (5, Offset5ModelReference), + (10, Offset10ModelReference), + (15, Offset15ModelReference), + (20, Offset20ModelReference), + (36, Offset36Reference), + (40, Offset40Reference), + (50, Offset50Reference), +]) +def test_parametrized_offset_models_match_reference(offset_n, reference_class): + """Test that parametrized offset models produce identical output to reference hardcoded models.""" + + num_neurons = 5 + num_units = 8 + num_output = 3 + normalize = True + + # Create reference model + ref_model = reference_class(num_neurons, + num_units, + num_output, + normalize=normalize) + + # Create parametrized model using OffsetNModel + param_model = cebra.models.init(f"offset{offset_n}-model", + num_neurons=num_neurons, + num_units=num_units, + num_output=num_output) + + # Test 1: Check offsets 
match + ref_offset = ref_model.get_offset() + param_offset = param_model.get_offset() + assert ref_offset.left == param_offset.left, \ + f"Offset left mismatch for offset{offset_n}: {ref_offset.left} != {param_offset.left}" + assert ref_offset.right == param_offset.right, \ + f"Offset right mismatch for offset{offset_n}: {ref_offset.right} != {param_offset.right}" + + # Test 2: Check model architecture - same number of parameters + ref_params = sum(p.numel() for p in ref_model.parameters()) + param_params = sum(p.numel() for p in param_model.parameters()) + assert ref_params == param_params, \ + f"Parameter count mismatch for offset{offset_n}: {ref_params} != {param_params}" + + # Test 3: Check output shape consistency + batch_size = 2 + input_length = 100 + offset_len = len(ref_offset) + + test_input = torch.randn(batch_size, num_neurons, offset_len) + + with torch.no_grad(): + ref_output = ref_model.net(test_input) + param_output = param_model.net(test_input) + + assert ref_output.shape == param_output.shape, \ + f"Output shape mismatch for offset{offset_n}: {ref_output.shape} != {param_output.shape}" + + # Test 4: For convolutional models, test on full length input + if isinstance(param_model, cebra.models.ConvolutionalModelMixin): + test_input_full = torch.randn(batch_size, num_neurons, input_length) + + with torch.no_grad(): + ref_output_full = ref_model.net(test_input_full) + param_output_full = param_model.net(test_input_full) + + expected_length = input_length - len(ref_offset) + 1 + assert ref_output_full.shape == (batch_size, num_output, expected_length), \ + f"Reference model output shape unexpected for offset{offset_n}" + assert param_output_full.shape == (batch_size, num_output, expected_length), \ + f"Parametrized model output shape unexpected for offset{offset_n}" + + +@pytest.mark.parametrize("offset_n", [5, 10, 15, 18, 20, 31, 36, 40, 50]) +def test_parametrized_offset_models_exist(offset_n): + """Test that all parametrized offset models can be 
instantiated.""" + model = cebra.models.init(f"offset{offset_n}-model", + num_neurons=5, + num_units=4, + num_output=3) + assert isinstance(model, cebra.models.Model) + assert isinstance(model, cebra.models.HasFeatureEncoder) + assert isinstance(model, cebra.models.ConvolutionalModelMixin) + + +@pytest.mark.parametrize("offset_n,reference_class", [ + (5, Offset5ModelReference), + (10, Offset10ModelReference), + (15, Offset15ModelReference), + (20, Offset20ModelReference), + (36, Offset36Reference), + (40, Offset40Reference), + (50, Offset50Reference), +]) +def test_parametrized_offset_models_forward_pass_identical( + offset_n, reference_class): + """Test that parametrized and reference models produce identical forward pass outputs. + + This test verifies that when both models are initialized with the same seed and weights, + they produce identical outputs. + """ + + num_neurons = 5 + num_units = 8 + num_output = 3 + normalize = True + batch_size = 2 + + # Set seed for reproducibility + torch.manual_seed(42) + + # Create reference model and get its state dict + ref_model = reference_class(num_neurons, + num_units, + num_output, + normalize=normalize) + ref_state_dict = {k: v.clone() for k, v in ref_model.state_dict().items()} + + # Create parametrized model + param_model = cebra.models.init(f"offset{offset_n}-model", + num_neurons=num_neurons, + num_units=num_units, + num_output=num_output) + + # Load the same weights into parametrized model + param_model.load_state_dict(ref_state_dict) + + # Test with multiple input sizes + offset = ref_model.get_offset() + offset_len = len(offset) + + for input_length in [offset_len, offset_len * 2, 100]: + test_input = torch.randn(batch_size, num_neurons, input_length) + + with torch.no_grad(): + ref_output = ref_model.net(test_input) + param_output = param_model.net(test_input) + + # Check that outputs are identical + assert torch.allclose(ref_output, param_output, rtol=1e-5, atol=1e-7), \ + f"Output mismatch for 
offset{offset_n} with input_length={input_length}"
+
+        # Check that outputs have same device and dtype
+        assert ref_output.device == param_output.device, \
+            f"Device mismatch for offset{offset_n}"
+        assert ref_output.dtype == param_output.dtype, \
+            f"Dtype mismatch for offset{offset_n}"
+
+
+@pytest.mark.parametrize("offset_n", [5, 10, 15, 18, 20, 31, 36, 40, 50])
+def test_parametrized_offset_models_layer_structure(offset_n):
+    """Test that parametrized models have the correct layer structure."""
+    num_neurons = 4
+    num_units = 8
+    num_output = 3
+
+    model = cebra.models.init(f"offset{offset_n}-model",
+                              num_neurons=num_neurons,
+                              num_units=num_units,
+                              num_output=num_output)
+
+    # Model should have Conv1d -> GELU -> Skip layers -> Conv1d structure
+    # Extract the actual network layers
+    layers = list(model.net.children())
+
+    # First layer should be Conv1d
+    assert isinstance(layers[0], nn.Conv1d), \
+        f"First layer of offset{offset_n} model should be Conv1d"
+    assert layers[0].in_channels == num_neurons
+    assert layers[0].out_channels == num_units
+    assert layers[0].kernel_size == (2,)
+
+    # Last meaningful layer (before Norm and Squeeze) should be Conv1d
+    # Find the last Conv1d layer (the output projection)
+    conv_layers = [l for l in layers if isinstance(l, nn.Conv1d)]
+    assert len(conv_layers) >= 2, \
+        f"offset{offset_n} model should have at least 2 Conv1d layers"
+
+    last_conv = conv_layers[-1]
+    assert last_conv.out_channels == num_output
+
+    # Check that offset is computed correctly
+    offset = model.get_offset()
+    expected_left = offset_n // 2
+    expected_right = offset_n // 2 + offset_n % 2
+
+    assert offset.left == expected_left, \
+        f"Offset left for offset{offset_n} should be {expected_left}, got {offset.left}"
+    assert offset.right == expected_right, \
+        f"Offset right for offset{offset_n} should be {expected_right}, got {offset.right}"

From 484c75f874bf470e578bffd32b1ec51f39701e5e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9lia=20Benquet?=
<32598028+CeliaBenquet@users.noreply.github.com> Date: Thu, 7 May 2026 16:10:48 +0200 Subject: [PATCH 5/6] Fix codespell --- cebra/distributions/index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cebra/distributions/index.py b/cebra/distributions/index.py index 724e86e4..020e674a 100644 --- a/cebra/distributions/index.py +++ b/cebra/distributions/index.py @@ -215,7 +215,7 @@ def search(self, continuous, discrete=None): Samples from the continuous index discrete: Optionally matching samples from the discrete index, - used to pre-select matching indices. + used to preselect matching indices. """ if continuous.shape[1] != self.continuous.shape[1]: raise ValueError(f"Shape of continuous index does not match along " From 4081fe796d97a169c7f4811108b0d8a5ed269eed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Thu, 7 May 2026 16:32:35 +0200 Subject: [PATCH 6/6] Tentative fix to parametrized class TypeError --- cebra/registry.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cebra/registry.py b/cebra/registry.py index 1bbc5093..03bffa65 100644 --- a/cebra/registry.py +++ b/cebra/registry.py @@ -192,7 +192,7 @@ def _register(cls): def parametrize(pattern: str, *, - kwargs: List[Dict[str, Any]] = [], + kwargs: List[Dict[str, Any]] = None, **all_kwargs): """Decorator to add parametrizations of a new class to the registry. 
@@ -221,8 +221,8 @@ def _create_class(cls, **default_kwargs): class _ParametrizedClass(cls): def __init__(self, *args, **kwargs): - default_kwargs.update(kwargs) - super().__init__(*args, **default_kwargs) + merged_kwargs = {**default_kwargs, **kwargs} + super().__init__(*args, **merged_kwargs) # Make the class pickleable by copying metadata from the base class # and registering it in the module namespace @@ -239,7 +239,7 @@ def __init__(self, *args, **kwargs): setattr(parent_module, unique_name, _ParametrizedClass) def _parametrize(cls): - for _default_kwargs in kwargs: + for _default_kwargs in (kwargs or []): _create_class(cls, **_default_kwargs) if len(all_kwargs) > 0: for _default_kwargs in _product_dict(all_kwargs):