From ca9a61d7e66941fdc31570c1a54c8b0ce34a2d07 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 31 Jan 2026 18:51:44 +0000 Subject: [PATCH 01/13] Update download urls from figshare to s3 bucket --- cebra/data/assets.py | 126 ++++++++++++++++++++++++++++++ cebra/data/base.py | 22 ++++-- cebra/datasets/hippocampus.py | 36 ++++++--- cebra/datasets/monkey_reaching.py | 76 ++++++++++++------ cebra/datasets/synthetic_data.py | 60 +++++++++----- tests/test_datasets.py | 109 ++++++++++++++++++++++++++ 6 files changed, 366 insertions(+), 63 deletions(-) diff --git a/cebra/data/assets.py b/cebra/data/assets.py index adea8413..683f3b6c 100644 --- a/cebra/data/assets.py +++ b/cebra/data/assets.py @@ -20,6 +20,7 @@ # limitations under the License. # +import gzip import hashlib import re import warnings @@ -140,3 +141,128 @@ def calculate_checksum(file_path: str) -> str: for chunk in iter(lambda: file.read(4096), b""): checksum.update(chunk) return checksum.hexdigest() + + +def download_and_extract_gzipped_file(url: str, + expected_checksum: str, + gzipped_checksum: str, + location: str, + file_name: str, + retry_count: int = 0) -> Optional[str]: + """Download a gzipped file from the given URL, verify checksums, and extract. + + This function downloads a gzipped file, verifies the checksum of the gzipped + file, extracts it, and then verifies the checksum of the extracted file. + + Args: + url: The URL to download the gzipped file from. + expected_checksum: The expected MD5 checksum of the extracted file. + gzipped_checksum: The expected MD5 checksum of the gzipped file. + location: The directory where the file will be saved. + file_name: The name of the final extracted file (without .gz extension). + retry_count: The number of retry attempts (default: 0). + + Returns: + The path of the extracted file if successful, None otherwise. + + Raises: + RuntimeError: If the maximum retry count is exceeded. + requests.HTTPError: If the download fails. + """ + + # Check if the final extracted file already exists with correct checksum + location_path = Path(location) + final_file_path = location_path / file_name + + if final_file_path.exists(): + existing_checksum = calculate_checksum(final_file_path) + if existing_checksum == expected_checksum: + return final_file_path + + if retry_count >= _MAX_RETRY_COUNT: + raise RuntimeError( + f"Exceeded maximum retry count ({_MAX_RETRY_COUNT}). " + f"Unable to download the file from {url}") + + # Create the directory and any necessary parent directories + location_path.mkdir(parents=True, exist_ok=True) + + # Download the gzipped file + gz_file_path = location_path / f"{file_name}.gz" + + response = requests.get(url, stream=True) + + # Check if the request was successful + if response.status_code != 200: + raise requests.HTTPError( + f"Error occurred while downloading the file. Response code: {response.status_code}" + ) + + total_size = int(response.headers.get("Content-Length", 0)) + checksum = hashlib.md5() # create checksum for gzipped file + + # Download the gzipped file + with open(gz_file_path, "wb") as file: + with tqdm.tqdm(total=total_size, + unit="B", + unit_scale=True, + desc="Downloading") as progress_bar: + for data in response.iter_content(chunk_size=1024): + file.write(data) + checksum.update(data) + progress_bar.update(len(data)) + + downloaded_gz_checksum = checksum.hexdigest() + + # Verify gzipped file checksum + if downloaded_gz_checksum != gzipped_checksum: + warnings.warn( + f"Gzipped file checksum verification failed. 
Deleting '{gz_file_path}'." + ) + gz_file_path.unlink() + warnings.warn("Gzipped file deleted. Retrying download...") + return download_and_extract_gzipped_file(url, expected_checksum, + gzipped_checksum, location, + file_name, retry_count + 1) + + print("Gzipped file checksum verified. Extracting...") + + # Extract the gzipped file + try: + with gzip.open(gz_file_path, 'rb') as f_in: + with open(final_file_path, 'wb') as f_out: + # Extract with progress (estimate based on typical compression ratio) + extracted_size = 0 + while True: + chunk = f_in.read(8192) + if not chunk: + break + f_out.write(chunk) + extracted_size += len(chunk) + except Exception as e: + warnings.warn(f"Extraction failed: {e}. Deleting files and retrying...") + if gz_file_path.exists(): + gz_file_path.unlink() + if final_file_path.exists(): + final_file_path.unlink() + return download_and_extract_gzipped_file(url, expected_checksum, + gzipped_checksum, location, + file_name, retry_count + 1) + + # Verify extracted file checksum + extracted_checksum = calculate_checksum(final_file_path) + if extracted_checksum != expected_checksum: + warnings.warn( + "Extracted file checksum verification failed. Deleting files.") + gz_file_path.unlink() + final_file_path.unlink() + warnings.warn("Files deleted. Retrying download...") + return download_and_extract_gzipped_file(url, expected_checksum, + gzipped_checksum, location, + file_name, retry_count + 1) + + # Clean up the gzipped file after successful extraction + gz_file_path.unlink() + + print(f"Extraction complete. Dataset saved in '{final_file_path}'") + return final_file_path diff --git a/cebra/data/base.py b/cebra/data/base.py index 51199cec..acdcff53 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -55,6 +55,7 @@ def __init__(self, download=False, data_url=None, data_checksum=None, + gzipped_checksum=None, location=None, file_name=None): @@ -64,6 +65,7 @@ def __init__(self, self.download = download self.data_url = data_url self.data_checksum = data_checksum + self.gzipped_checksum = gzipped_checksum self.location = location self.file_name = file_name @@ -78,11 +80,21 @@ def __init__(self, "Missing data checksum. Please provide the checksum to verify the data integrity." 
) - cebra_data_assets.download_file_with_progress_bar( - url=self.data_url, - expected_checksum=self.data_checksum, - location=self.location, - file_name=self.file_name) + # Use gzipped download if gzipped_checksum is provided + if self.gzipped_checksum is not None: + cebra_data_assets.download_and_extract_gzipped_file( + url=self.data_url, + expected_checksum=self.data_checksum, + gzipped_checksum=self.gzipped_checksum, + location=self.location, + file_name=self.file_name) + else: + # Fall back to legacy download for backward compatibility + cebra_data_assets.download_file_with_progress_bar( + url=self.data_url, + expected_checksum=self.data_checksum, + location=self.location, + file_name=self.file_name) @property @abc.abstractmethod diff --git a/cebra/datasets/hippocampus.py b/cebra/datasets/hippocampus.py index 05c47acb..a8ce12d1 100644 --- a/cebra/datasets/hippocampus.py +++ b/cebra/datasets/hippocampus.py @@ -50,27 +50,35 @@ rat_dataset_urls = { "achilles": { "url": - "https://figshare.com/ndownloader/files/40849463?private_link=9f91576cbbcc8b0d8828", + "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/achilles.jl.gz", "checksum": - "c52f9b55cbc23c66d57f3842214058b8" + "c52f9b55cbc23c66d57f3842214058b8", + "gzipped_checksum": + "5d7b243e07b24c387e5412cd5ff46f0b" }, "buddy": { "url": - "https://figshare.com/ndownloader/files/40849460?private_link=9f91576cbbcc8b0d8828", + "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/buddy.jl.gz", "checksum": - "36341322907708c466871bf04bc133c2" + "36341322907708c466871bf04bc133c2", + "gzipped_checksum": + "339290585be2188f48a176f05aaf5df6" }, "cicero": { "url": - "https://figshare.com/ndownloader/files/40849457?private_link=9f91576cbbcc8b0d8828", + "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/cicero.jl.gz", "checksum": - "a83b02dbdc884fdd7e53df362499d42f" + "a83b02dbdc884fdd7e53df362499d42f", + "gzipped_checksum": + "f262a87d2e59f164cb404cd410015f3a" }, "gatsby": { "url": - "https://figshare.com/ndownloader/files/40849454?private_link=9f91576cbbcc8b0d8828", + "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/gatsby.jl.gz", "checksum": - "2b889da48178b3155011c12555342813" + "2b889da48178b3155011c12555342813", + "gzipped_checksum": + "564e431c19e55db2286a9d64c86a94c4" } } @@ -95,11 +103,13 @@ def __init__(self, name="achilles", root=_DEFAULT_DATADIR, download=True): location = pathlib.Path(root) / "rat_hippocampus" file_path = location / f"{name}.jl" - super().__init__(download=download, - data_url=rat_dataset_urls[name]["url"], - data_checksum=rat_dataset_urls[name]["checksum"], - location=location, - file_name=f"{name}.jl") + super().__init__( + download=download, + data_url=rat_dataset_urls[name]["url"], + data_checksum=rat_dataset_urls[name]["checksum"], + gzipped_checksum=rat_dataset_urls[name].get("gzipped_checksum"), + location=location, + file_name=f"{name}.jl") data = joblib.load(file_path) self.neural = torch.from_numpy(data["spikes"]).float() diff --git a/cebra/datasets/monkey_reaching.py b/cebra/datasets/monkey_reaching.py index 05071b12..22479455 100644 --- a/cebra/datasets/monkey_reaching.py +++ b/cebra/datasets/monkey_reaching.py @@ -160,75 +160,99 @@ def _get_info(trial_info, data): monkey_reaching_urls = { "all_all.jl": { "url": - "https://figshare.com/ndownloader/files/41668764?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_all.jl.gz", "checksum": - "dea556301fa4fafa86e28cf8621cab5a" + 
"dea556301fa4fafa86e28cf8621cab5a", + "gzipped_checksum": + "399abc6e9ef0b23a0d6d057c6f508939" }, "all_train.jl": { "url": - "https://figshare.com/ndownloader/files/41668752?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_train.jl.gz", "checksum": - "e280e4cd86969e6fd8bfd3a8f402b2fe" + "e280e4cd86969e6fd8bfd3a8f402b2fe", + "gzipped_checksum": + "eb52c8641fe83ae2a278b372ddec5f69" }, "all_test.jl": { "url": - "https://figshare.com/ndownloader/files/41668761?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_test.jl.gz", "checksum": - "25d3ff2c15014db8b8bf2543482ae881" + "25d3ff2c15014db8b8bf2543482ae881", + "gzipped_checksum": + "7688245cf15e0b92503af943ce9f66aa" }, "all_valid.jl": { "url": - "https://figshare.com/ndownloader/files/41668755?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_valid.jl.gz", "checksum": - "8cd25169d31f83ae01b03f7b1b939723" + "8cd25169d31f83ae01b03f7b1b939723", + "gzipped_checksum": + "b169fc008b4d092fe2a1b7e006cd17a7" }, "active_all.jl": { "url": - "https://figshare.com/ndownloader/files/41668776?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_all.jl.gz", "checksum": - "c626acea5062122f5a68ef18d3e45e51" + "c626acea5062122f5a68ef18d3e45e51", + "gzipped_checksum": + "b7b86e2ae00bb71341de8fc352dae097" }, "active_train.jl": { "url": - "https://figshare.com/ndownloader/files/41668770?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_train.jl.gz", "checksum": - "72a48056691078eee22c36c1992b1d37" + "72a48056691078eee22c36c1992b1d37", + "gzipped_checksum": + "56687c633efcbff6c56bbcfa35597565" }, "active_test.jl": { "url": - "https://figshare.com/ndownloader/files/41668773?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_test.jl.gz", "checksum": - "35b7e060008a8722c536584c4748f2ea" + "35b7e060008a8722c536584c4748f2ea", + "gzipped_checksum": + "2057ef1846908a69486a61895d1198e8" }, "active_valid.jl": { "url": - "https://figshare.com/ndownloader/files/41668767?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_valid.jl.gz", "checksum": - "dd58eb1e589361b4132f34b22af56b79" + "dd58eb1e589361b4132f34b22af56b79", + "gzipped_checksum": + "60b8e418f234877351fe36f1efc169ad" }, "passive_all.jl": { "url": - "https://figshare.com/ndownloader/files/41668758?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_all.jl.gz", "checksum": - "bbb1bc9d8eec583a46f6673470fc98ad" + "bbb1bc9d8eec583a46f6673470fc98ad", + "gzipped_checksum": + "afb257efa0cac3ccd69ec80478d63691" }, "passive_train.jl": { "url": - "https://figshare.com/ndownloader/files/41668743?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_train.jl.gz", "checksum": - "f22e05a69f70e18ba823a0a89162a45c" + "f22e05a69f70e18ba823a0a89162a45c", + "gzipped_checksum": + "24d98d7d41a52591f838c41fe83dc2c6" }, "passive_test.jl": { "url": - "https://figshare.com/ndownloader/files/41668746?private_link=6fa4ee74a8f465ec7914", + 
"https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_test.jl.gz", "checksum": - "42453ae3e4fd27d82d297f78c13cd6b7" + "42453ae3e4fd27d82d297f78c13cd6b7", + "gzipped_checksum": + "f1ff4e9b7c4a0d7fa9dcd271893f57ab" }, "passive_valid.jl": { "url": - "https://figshare.com/ndownloader/files/41668749?private_link=6fa4ee74a8f465ec7914", + "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_valid.jl.gz", "checksum": - "2dcc10c27631b95a075eaa2d2297bb4a" + "2dcc10c27631b95a075eaa2d2297bb4a", + "gzipped_checksum": + "311fcb6a3e86022f12d78828f7bd29d5" } } @@ -270,6 +294,8 @@ def __init__(self, data_url=monkey_reaching_urls[f"{self.load_session}_all.jl"]["url"], data_checksum=monkey_reaching_urls[f"{self.load_session}_all.jl"] ["checksum"], + gzipped_checksum=monkey_reaching_urls[f"{self.load_session}_all.jl"] + .get("gzipped_checksum"), location=self.path, file_name=f"{self.load_session}_all.jl", ) @@ -297,6 +323,8 @@ def split(self, split): ["url"], data_checksum=monkey_reaching_urls[ f"{self.load_session}_{split}.jl"]["checksum"], + gzipped_checksum=monkey_reaching_urls[ + f"{self.load_session}_{split}.jl"].get("gzipped_checksum"), location=self.path, file_name=f"{self.load_session}_{split}.jl", ) diff --git a/cebra/datasets/synthetic_data.py b/cebra/datasets/synthetic_data.py index 9288a93d..eab8d6cf 100644 --- a/cebra/datasets/synthetic_data.py +++ b/cebra/datasets/synthetic_data.py @@ -33,51 +33,67 @@ synthetic_data_urls = { "continuous_label_refractory_poisson": { "url": - "https://figshare.com/ndownloader/files/41668815?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_refractory_poisson.jl.gz", "checksum": - "fcd92bd283c528d5294093190f55ceba" + "fcd92bd283c528d5294093190f55ceba", + "gzipped_checksum": + "3641eed973b9cae972493c70b364e981" }, "continuous_label_t": { "url": - "https://figshare.com/ndownloader/files/41668818?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_t.jl.gz", "checksum": - "a6e76f274da571568fd2a4bf4cf48b66" + "a6e76f274da571568fd2a4bf4cf48b66", + "gzipped_checksum": + "1dc8805e8f0836c7c99e864100a65bff" }, "continuous_label_uniform": { "url": - "https://figshare.com/ndownloader/files/41668821?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_uniform.jl.gz", "checksum": - "e67400e77ac009e8c9bc958aa5151973" + "e67400e77ac009e8c9bc958aa5151973", + "gzipped_checksum": + "71d33bc56b89bc227da0990bf16e584b" }, "continuous_label_laplace": { "url": - "https://figshare.com/ndownloader/files/41668824?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_laplace.jl.gz", "checksum": - "41d7ce4ce8901ae7a5136605ac3f5ffb" + "41d7ce4ce8901ae7a5136605ac3f5ffb", + "gzipped_checksum": + "1563e4958031392d2b2e30cc4cd79b3f" }, "continuous_label_poisson": { "url": - "https://figshare.com/ndownloader/files/41668827?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_poisson.jl.gz", "checksum": - "a789828f9cca5f3faf36d62ebc4cc8a1" + "a789828f9cca5f3faf36d62ebc4cc8a1", + "gzipped_checksum": + "7691304ee061e0bf1e9bb5f2bb6b20e7" }, "continuous_label_gaussian": { "url": - "https://figshare.com/ndownloader/files/41668830?private_link=7439c5302e99db36eebb", + 
"https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_gaussian.jl.gz", "checksum": - "18d66a2020923e2cd67d2264d20890aa" + "18d66a2020923e2cd67d2264d20890aa", + "gzipped_checksum": + "0cb97a2c1eaa526e57d2248a333ea8e0" }, "continuous_poisson_gaussian_noise": { "url": - "https://figshare.com/ndownloader/files/41668833?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_poisson_gaussian_noise.jl.gz", "checksum": - "1a51461820c24a5bcaddaff3991f0ebe" + "1a51461820c24a5bcaddaff3991f0ebe", + "gzipped_checksum": + "5aa6b6eadf2b733562864d5b67bc6b8d" }, "sim_100d_poisson_cont_label": { "url": - "https://figshare.com/ndownloader/files/41668836?private_link=7439c5302e99db36eebb", + "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/sim_100d_poisson_cont_label.npz.gz", "checksum": - "306b9c646e7b76a52cfd828612d700cb" + "306b9c646e7b76a52cfd828612d700cb", + "gzipped_checksum": + "768299435a167dedd57e29b1a6d5af63" } } @@ -98,11 +114,13 @@ def __init__(self, name, root=_DEFAULT_DATADIR, download=True): location = os.path.join(root, "synthetic") file_path = os.path.join(location, f"{name}.jl") - super().__init__(download=download, - data_url=synthetic_data_urls[name]["url"], - data_checksum=synthetic_data_urls[name]["checksum"], - location=location, - file_name=f"{name}.jl") + super().__init__( + download=download, + data_url=synthetic_data_urls[name]["url"], + data_checksum=synthetic_data_urls[name]["checksum"], + gzipped_checksum=synthetic_data_urls[name].get("gzipped_checksum"), + location=location, + file_name=f"{name}.jl") data = joblib.load(file_path) self.data = data #NOTE: making it backwards compatible with synth notebook. diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e8e03ff0..36aa77f6 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import hashlib import os import pathlib import tempfile @@ -384,6 +385,114 @@ def test_download_file_wrong_content_disposition(filename, url, file_name=filename) +def test_download_and_extract_gzipped_file(): + """Test downloading and extracting gzipped files with dual checksum verification.""" + import gzip + + with tempfile.TemporaryDirectory() as temp_dir: + # Create a test file + test_content = b"Test dataset content for gzipped download" + test_filename = "test_dataset.jl" + test_gz_filename = f"{test_filename}.gz" + + # Calculate checksums + unzipped_checksum = cebra_data_assets.calculate_checksum.__wrapped__(test_content) \ + if hasattr(cebra_data_assets.calculate_checksum, '__wrapped__') \ + else hashlib.md5(test_content).hexdigest() + + # Create gzipped content + gzipped_content = gzip.compress(test_content) + gzipped_checksum = hashlib.md5(gzipped_content).hexdigest() + + # Mock the HTTP response + with patch("requests.get") as mock_get: + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.headers = { + "Content-Length": str(len(gzipped_content)) + } + mock_response.iter_content = lambda chunk_size: [gzipped_content] + + # Test successful download and extraction + result = cebra_data_assets.download_and_extract_gzipped_file( + url="http://example.com/test.jl.gz", + expected_checksum=unzipped_checksum, + gzipped_checksum=gzipped_checksum, + location=temp_dir, + file_name=test_filename) + + # Verify the file was extracted + assert result is not None + final_path = os.path.join(temp_dir, test_filename) + assert os.path.exists(final_path) + + # Verify the content is correct + with open(final_path, 'rb') as f: + extracted_content = f.read() + assert extracted_content == test_content + + # Verify the .gz file was cleaned up + gz_path = os.path.join(temp_dir, test_gz_filename) + assert not os.path.exists(gz_path) + + +def test_download_and_extract_gzipped_file_wrong_gzipped_checksum(): + """Test that wrong gzipped checksum raises error after retries.""" + import gzip + + with tempfile.TemporaryDirectory() as temp_dir: + test_content = b"Test content" + gzipped_content = gzip.compress(test_content) + wrong_gz_checksum = "0" * 32 # Wrong checksum + + with patch("requests.get") as mock_get: + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.headers = { + "Content-Length": str(len(gzipped_content)) + } + mock_response.iter_content = lambda chunk_size: [gzipped_content] + + with pytest.raises(RuntimeError, + match="Exceeded maximum retry count"): + cebra_data_assets.download_and_extract_gzipped_file( + url="http://example.com/test.jl.gz", + expected_checksum=hashlib.md5(test_content).hexdigest(), + gzipped_checksum=wrong_gz_checksum, + location=temp_dir, + file_name="test.jl", + retry_count=2) + + +def test_download_and_extract_gzipped_file_wrong_unzipped_checksum(): + """Test that wrong unzipped checksum raises error after retries.""" + import gzip + + with tempfile.TemporaryDirectory() as temp_dir: + test_content = b"Test content" + gzipped_content = gzip.compress(test_content) + gzipped_checksum = hashlib.md5(gzipped_content).hexdigest() + wrong_unzipped_checksum = "0" * 32 # Wrong checksum + + with patch("requests.get") as mock_get: + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.headers = { + "Content-Length": str(len(gzipped_content)) + } + mock_response.iter_content = lambda chunk_size: [gzipped_content] + + with pytest.raises(RuntimeError, + match="Exceeded maximum 
retry count"): + cebra_data_assets.download_and_extract_gzipped_file( + url="http://example.com/test.jl.gz", + expected_checksum=wrong_unzipped_checksum, + gzipped_checksum=gzipped_checksum, + location=temp_dir, + file_name="test.jl", + retry_count=2) + + @pytest.mark.parametrize("neural, continuous, discrete", [ (np.random.randn(100, 30), np.random.randn( 100, 2), np.random.randint(0, 5, (100,))), From 79a25e6d4ef3849592d2bf2235d5548ceb2f732e Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 31 Jan 2026 19:02:48 +0000 Subject: [PATCH 02/13] minimize the diff --- cebra/datasets/hippocampus.py | 24 +++++------ cebra/datasets/monkey_reaching.py | 72 +++++++++++++++---------------- cebra/datasets/synthetic_data.py | 48 ++++++++++----------- 3 files changed, 72 insertions(+), 72 deletions(-) diff --git a/cebra/datasets/hippocampus.py b/cebra/datasets/hippocampus.py index a8ce12d1..aa794d45 100644 --- a/cebra/datasets/hippocampus.py +++ b/cebra/datasets/hippocampus.py @@ -51,34 +51,34 @@ "achilles": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/achilles.jl.gz", - "checksum": - "c52f9b55cbc23c66d57f3842214058b8", "gzipped_checksum": - "5d7b243e07b24c387e5412cd5ff46f0b" + "5d7b243e07b24c387e5412cd5ff46f0b", + "checksum": + "c52f9b55cbc23c66d57f3842214058b8" }, "buddy": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/buddy.jl.gz", - "checksum": - "36341322907708c466871bf04bc133c2", "gzipped_checksum": - "339290585be2188f48a176f05aaf5df6" + "339290585be2188f48a176f05aaf5df6", + "checksum": + "36341322907708c466871bf04bc133c2" }, "cicero": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/cicero.jl.gz", - "checksum": - "a83b02dbdc884fdd7e53df362499d42f", "gzipped_checksum": - "f262a87d2e59f164cb404cd410015f3a" + "f262a87d2e59f164cb404cd410015f3a", + "checksum": + "a83b02dbdc884fdd7e53df362499d42f" }, "gatsby": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/rat_hippocampus/gatsby.jl.gz", - "checksum": - "2b889da48178b3155011c12555342813", "gzipped_checksum": - "564e431c19e55db2286a9d64c86a94c4" + "564e431c19e55db2286a9d64c86a94c4", + "checksum": + "2b889da48178b3155011c12555342813" } } diff --git a/cebra/datasets/monkey_reaching.py b/cebra/datasets/monkey_reaching.py index 22479455..080e83ae 100644 --- a/cebra/datasets/monkey_reaching.py +++ b/cebra/datasets/monkey_reaching.py @@ -161,98 +161,98 @@ def _get_info(trial_info, data): "all_all.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_all.jl.gz", - "checksum": - "dea556301fa4fafa86e28cf8621cab5a", "gzipped_checksum": - "399abc6e9ef0b23a0d6d057c6f508939" + "399abc6e9ef0b23a0d6d057c6f508939", + "checksum": + "dea556301fa4fafa86e28cf8621cab5a" }, "all_train.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_train.jl.gz", - "checksum": - "e280e4cd86969e6fd8bfd3a8f402b2fe", "gzipped_checksum": - "eb52c8641fe83ae2a278b372ddec5f69" + "eb52c8641fe83ae2a278b372ddec5f69", + "checksum": + "e280e4cd86969e6fd8bfd3a8f402b2fe" }, "all_test.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_test.jl.gz", - "checksum": - "25d3ff2c15014db8b8bf2543482ae881", "gzipped_checksum": - "7688245cf15e0b92503af943ce9f66aa" + "7688245cf15e0b92503af943ce9f66aa", + "checksum": + "25d3ff2c15014db8b8bf2543482ae881" }, "all_valid.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/all_valid.jl.gz", - 
"checksum": - "8cd25169d31f83ae01b03f7b1b939723", "gzipped_checksum": - "b169fc008b4d092fe2a1b7e006cd17a7" + "b169fc008b4d092fe2a1b7e006cd17a7", + "checksum": + "8cd25169d31f83ae01b03f7b1b939723" }, "active_all.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_all.jl.gz", - "checksum": - "c626acea5062122f5a68ef18d3e45e51", "gzipped_checksum": - "b7b86e2ae00bb71341de8fc352dae097" + "b7b86e2ae00bb71341de8fc352dae097", + "checksum": + "c626acea5062122f5a68ef18d3e45e51" }, "active_train.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_train.jl.gz", - "checksum": - "72a48056691078eee22c36c1992b1d37", "gzipped_checksum": - "56687c633efcbff6c56bbcfa35597565" + "56687c633efcbff6c56bbcfa35597565", + "checksum": + "72a48056691078eee22c36c1992b1d37" }, "active_test.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_test.jl.gz", - "checksum": - "35b7e060008a8722c536584c4748f2ea", "gzipped_checksum": - "2057ef1846908a69486a61895d1198e8" + "2057ef1846908a69486a61895d1198e8", + "checksum": + "35b7e060008a8722c536584c4748f2ea" }, "active_valid.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/active_valid.jl.gz", - "checksum": - "dd58eb1e589361b4132f34b22af56b79", "gzipped_checksum": - "60b8e418f234877351fe36f1efc169ad" + "60b8e418f234877351fe36f1efc169ad", + "checksum": + "dd58eb1e589361b4132f34b22af56b79" }, "passive_all.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_all.jl.gz", - "checksum": - "bbb1bc9d8eec583a46f6673470fc98ad", "gzipped_checksum": - "afb257efa0cac3ccd69ec80478d63691" + "afb257efa0cac3ccd69ec80478d63691", + "checksum": + "bbb1bc9d8eec583a46f6673470fc98ad" }, "passive_train.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_train.jl.gz", - "checksum": - "f22e05a69f70e18ba823a0a89162a45c", "gzipped_checksum": - "24d98d7d41a52591f838c41fe83dc2c6" + "24d98d7d41a52591f838c41fe83dc2c6", + "checksum": + "f22e05a69f70e18ba823a0a89162a45c" }, "passive_test.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_test.jl.gz", - "checksum": - "42453ae3e4fd27d82d297f78c13cd6b7", "gzipped_checksum": - "f1ff4e9b7c4a0d7fa9dcd271893f57ab" + "f1ff4e9b7c4a0d7fa9dcd271893f57ab", + "checksum": + "42453ae3e4fd27d82d297f78c13cd6b7" }, "passive_valid.jl": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/monkey_reaching_preload_smth_40/passive_valid.jl.gz", - "checksum": - "2dcc10c27631b95a075eaa2d2297bb4a", "gzipped_checksum": - "311fcb6a3e86022f12d78828f7bd29d5" + "311fcb6a3e86022f12d78828f7bd29d5", + "checksum": + "2dcc10c27631b95a075eaa2d2297bb4a" } } diff --git a/cebra/datasets/synthetic_data.py b/cebra/datasets/synthetic_data.py index eab8d6cf..dc65ff0a 100644 --- a/cebra/datasets/synthetic_data.py +++ b/cebra/datasets/synthetic_data.py @@ -34,66 +34,66 @@ "continuous_label_refractory_poisson": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_refractory_poisson.jl.gz", - "checksum": - "fcd92bd283c528d5294093190f55ceba", "gzipped_checksum": - "3641eed973b9cae972493c70b364e981" + "3641eed973b9cae972493c70b364e981", + "checksum": + "fcd92bd283c528d5294093190f55ceba" }, "continuous_label_t": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_t.jl.gz", - "checksum": - 
"a6e76f274da571568fd2a4bf4cf48b66", "gzipped_checksum": - "1dc8805e8f0836c7c99e864100a65bff" + "1dc8805e8f0836c7c99e864100a65bff", + "checksum": + "a6e76f274da571568fd2a4bf4cf48b66" }, "continuous_label_uniform": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_uniform.jl.gz", - "checksum": - "e67400e77ac009e8c9bc958aa5151973", "gzipped_checksum": - "71d33bc56b89bc227da0990bf16e584b" + "71d33bc56b89bc227da0990bf16e584b", + "checksum": + "e67400e77ac009e8c9bc958aa5151973" }, "continuous_label_laplace": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_laplace.jl.gz", - "checksum": - "41d7ce4ce8901ae7a5136605ac3f5ffb", "gzipped_checksum": - "1563e4958031392d2b2e30cc4cd79b3f" + "1563e4958031392d2b2e30cc4cd79b3f", + "checksum": + "41d7ce4ce8901ae7a5136605ac3f5ffb" }, "continuous_label_poisson": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_poisson.jl.gz", - "checksum": - "a789828f9cca5f3faf36d62ebc4cc8a1", "gzipped_checksum": - "7691304ee061e0bf1e9bb5f2bb6b20e7" + "7691304ee061e0bf1e9bb5f2bb6b20e7", + "checksum": + "a789828f9cca5f3faf36d62ebc4cc8a1" }, "continuous_label_gaussian": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_label_gaussian.jl.gz", - "checksum": - "18d66a2020923e2cd67d2264d20890aa", "gzipped_checksum": - "0cb97a2c1eaa526e57d2248a333ea8e0" + "0cb97a2c1eaa526e57d2248a333ea8e0", + "checksum": + "18d66a2020923e2cd67d2264d20890aa" }, "continuous_poisson_gaussian_noise": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/continuous_poisson_gaussian_noise.jl.gz", - "checksum": - "1a51461820c24a5bcaddaff3991f0ebe", "gzipped_checksum": - "5aa6b6eadf2b733562864d5b67bc6b8d" + "5aa6b6eadf2b733562864d5b67bc6b8d", + "checksum": + "1a51461820c24a5bcaddaff3991f0ebe" }, "sim_100d_poisson_cont_label": { "url": "https://cebra.fra1.digitaloceanspaces.com/data/synthetic/sim_100d_poisson_cont_label.npz.gz", - "checksum": - "306b9c646e7b76a52cfd828612d700cb", "gzipped_checksum": - "768299435a167dedd57e29b1a6d5af63" + "768299435a167dedd57e29b1a6d5af63", + "checksum": + "306b9c646e7b76a52cfd828612d700cb" } } From 20ab07c44b6e3e3d376bc6ab2922ead31eb5dacc Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 31 Jan 2026 19:07:20 +0000 Subject: [PATCH 03/13] unify the dataset download funcs --- cebra/data/assets.py | 250 ++++++++++++++++------------------------- cebra/data/base.py | 21 +--- tests/test_datasets.py | 18 +-- 3 files changed, 113 insertions(+), 176 deletions(-) diff --git a/cebra/data/assets.py b/cebra/data/assets.py index 683f3b6c..490b798a 100644 --- a/cebra/data/assets.py +++ b/cebra/data/assets.py @@ -33,22 +33,31 @@ _MAX_RETRY_COUNT = 2 -def download_file_with_progress_bar(url: str, - expected_checksum: str, - location: str, - file_name: str, - retry_count: int = 0) -> Optional[str]: +def download_file_with_progress_bar( + url: str, + expected_checksum: str, + location: str, + file_name: str, + retry_count: int = 0, + gzipped_checksum: str = None) -> Optional[str]: """Download a file from the given URL. During download, progress is reported using a progress bar. The downloaded file's checksum is compared to the provided ``expected_checksum``. + If ``gzipped_checksum`` is provided, the file is expected to be gzipped. + The function will verify the gzipped file's checksum, extract it, and then + verify the extracted file's checksum. + Args: url: The URL to download the file from. 
- expected_checksum: The expected checksum value of the downloaded file. + expected_checksum: The expected checksum value of the downloaded file + (or extracted file if gzipped_checksum is provided). location: The directory where the file will be saved. file_name: The name of the file. retry_count: The number of retry attempts (default: 0). + gzipped_checksum: Optional MD5 checksum of the gzipped file. If provided, + the file will be extracted after download. Returns: The path of the downloaded file if the download is successful, None otherwise. @@ -79,30 +88,34 @@ def download_file_with_progress_bar(url: str, f"Error occurred while downloading the file. Response code: {response.status_code}" ) - # Check if the response headers contain the 'Content-Disposition' header - if 'Content-Disposition' not in response.headers: - raise ValueError( - "Unable to determine the filename. 'Content-Disposition' header not found." - ) - - # Extract the filename from the 'Content-Disposition' header - filename_match = re.search(r'filename="(.+)"', - response.headers.get("Content-Disposition")) - if not filename_match: - raise ValueError( - "Unable to determine the filename from the 'Content-Disposition' header." - ) + # For gzipped files, download to a .gz file first + if gzipped_checksum: + download_path = location_path / f"{file_name}.gz" + else: + # Check if the response headers contain the 'Content-Disposition' header + if 'Content-Disposition' not in response.headers: + raise ValueError( + "Unable to determine the filename. 'Content-Disposition' header not found." + ) + + # Extract the filename from the 'Content-Disposition' header + filename_match = re.search(r'filename="(.+)"', + response.headers.get("Content-Disposition")) + if not filename_match: + raise ValueError( + "Unable to determine the filename from the 'Content-Disposition' header." + ) + + filename = filename_match.group(1) + download_path = location_path / filename # Create the directory and any necessary parent directories location_path.mkdir(parents=True, exist_ok=True) - filename = filename_match.group(1) - file_path = location_path / filename - total_size = int(response.headers.get("Content-Length", 0)) checksum = hashlib.md5() # create checksum - with open(file_path, "wb") as file: + with open(download_path, "wb") as file: with tqdm.tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar: for data in response.iter_content(chunk_size=1024): @@ -112,18 +125,76 @@ def download_file_with_progress_bar(url: str, progress_bar.update(len(data)) downloaded_checksum = checksum.hexdigest() # Get the checksum value + + # If gzipped, verify gzipped checksum, extract, and verify extracted checksum + if gzipped_checksum: + if downloaded_checksum != gzipped_checksum: + warnings.warn( + f"Gzipped file checksum verification failed. Deleting '{download_path}'." + ) + download_path.unlink() + warnings.warn("Gzipped file deleted. Retrying download...") + return download_file_with_progress_bar(url, expected_checksum, + location, file_name, + retry_count + 1, + gzipped_checksum) + + print("Gzipped file checksum verified. Extracting...") + + # Extract the gzipped file + try: + with gzip.open(download_path, 'rb') as f_in: + with open(file_path, 'wb') as f_out: + while True: + chunk = f_in.read(8192) + if not chunk: + break + f_out.write(chunk) + except Exception as e: + warnings.warn( + f"Extraction failed: {e}. 
Deleting files and retrying...") + if download_path.exists(): + download_path.unlink() + if file_path.exists(): + file_path.unlink() + return download_file_with_progress_bar(url, expected_checksum, + location, file_name, + retry_count + 1, + gzipped_checksum) + + # Verify extracted file checksum + extracted_checksum = calculate_checksum(file_path) + if extracted_checksum != expected_checksum: + warnings.warn( + "Extracted file checksum verification failed. Deleting files.") + download_path.unlink() + file_path.unlink() + warnings.warn("Files deleted. Retrying download...") + return download_file_with_progress_bar(url, expected_checksum, + location, file_name, + retry_count + 1, + gzipped_checksum) + + # Clean up the gzipped file after successful extraction + download_path.unlink() + print(f"Extraction complete. Dataset saved in '{file_path}'") + return url + + # For non-gzipped files, verify checksum if downloaded_checksum != expected_checksum: - warnings.warn(f"Checksum verification failed. Deleting '{file_path}'.") - file_path.unlink() + warnings.warn( + f"Checksum verification failed. Deleting '{download_path}'.") + download_path.unlink() warnings.warn("File deleted. Retrying download...") # Retry download using a for loop for _ in range(retry_count + 1, _MAX_RETRY_COUNT + 1): return download_file_with_progress_bar(url, expected_checksum, location, file_name, - retry_count + 1) + retry_count + 1, + gzipped_checksum) else: - print(f"Download complete. Dataset saved in '{file_path}'") + print(f"Download complete. Dataset saved in '{download_path}'") return url @@ -141,128 +212,3 @@ def calculate_checksum(file_path: str) -> str: for chunk in iter(lambda: file.read(4096), b""): checksum.update(chunk) return checksum.hexdigest() - - -def download_and_extract_gzipped_file(url: str, - expected_checksum: str, - gzipped_checksum: str, - location: str, - file_name: str, - retry_count: int = 0) -> Optional[str]: - """Download a gzipped file from the given URL, verify checksums, and extract. - - This function downloads a gzipped file, verifies the checksum of the gzipped - file, extracts it, and then verifies the checksum of the extracted file. - - Args: - url: The URL to download the gzipped file from. - expected_checksum: The expected MD5 checksum of the extracted file. - gzipped_checksum: The expected MD5 checksum of the gzipped file. - location: The directory where the file will be saved. - file_name: The name of the final extracted file (without .gz extension). - retry_count: The number of retry attempts (default: 0). - - Returns: - The path of the extracted file if successful, None otherwise. - - Raises: - RuntimeError: If the maximum retry count is exceeded. - requests.HTTPError: If the download fails. - """ - - # Check if the final extracted file already exists with correct checksum - location_path = Path(location) - final_file_path = location_path / file_name - - if final_file_path.exists(): - existing_checksum = calculate_checksum(final_file_path) - if existing_checksum == expected_checksum: - return final_file_path - - if retry_count >= _MAX_RETRY_COUNT: - raise RuntimeError( - f"Exceeded maximum retry count ({_MAX_RETRY_COUNT}). 
" - f"Unable to download the file from {url}") - - # Create the directory and any necessary parent directories - location_path.mkdir(parents=True, exist_ok=True) - - # Download the gzipped file - gz_file_path = location_path / f"{file_name}.gz" - - response = requests.get(url, stream=True) - - # Check if the request was successful - if response.status_code != 200: - raise requests.HTTPError( - f"Error occurred while downloading the file. Response code: {response.status_code}" - ) - - total_size = int(response.headers.get("Content-Length", 0)) - checksum = hashlib.md5() # create checksum for gzipped file - - # Download the gzipped file - with open(gz_file_path, "wb") as file: - with tqdm.tqdm(total=total_size, - unit="B", - unit_scale=True, - desc="Downloading") as progress_bar: - for data in response.iter_content(chunk_size=1024): - file.write(data) - checksum.update(data) - progress_bar.update(len(data)) - - downloaded_gz_checksum = checksum.hexdigest() - - # Verify gzipped file checksum - if downloaded_gz_checksum != gzipped_checksum: - warnings.warn( - f"Gzipped file checksum verification failed. Deleting '{gz_file_path}'." - ) - gz_file_path.unlink() - warnings.warn("Gzipped file deleted. Retrying download...") - return download_and_extract_gzipped_file(url, expected_checksum, - gzipped_checksum, location, - file_name, retry_count + 1) - - print("Gzipped file checksum verified. Extracting...") - - # Extract the gzipped file - try: - with gzip.open(gz_file_path, 'rb') as f_in: - with open(final_file_path, 'wb') as f_out: - # Extract with progress (estimate based on typical compression ratio) - extracted_size = 0 - while True: - chunk = f_in.read(8192) - if not chunk: - break - f_out.write(chunk) - extracted_size += len(chunk) - except Exception as e: - warnings.warn(f"Extraction failed: {e}. Deleting files and retrying...") - if gz_file_path.exists(): - gz_file_path.unlink() - if final_file_path.exists(): - final_file_path.unlink() - return download_and_extract_gzipped_file(url, expected_checksum, - gzipped_checksum, location, - file_name, retry_count + 1) - - # Verify extracted file checksum - extracted_checksum = calculate_checksum(final_file_path) - if extracted_checksum != expected_checksum: - warnings.warn( - "Extracted file checksum verification failed. Deleting files.") - gz_file_path.unlink() - final_file_path.unlink() - warnings.warn("Files deleted. Retrying download...") - return download_and_extract_gzipped_file(url, expected_checksum, - gzipped_checksum, location, - file_name, retry_count + 1) - - # Clean up the gzipped file after successful extraction - gz_file_path.unlink() - - print(f"Extraction complete. Dataset saved in '{final_file_path}'") - return final_file_path diff --git a/cebra/data/base.py b/cebra/data/base.py index acdcff53..f5491e51 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -80,21 +80,12 @@ def __init__(self, "Missing data checksum. Please provide the checksum to verify the data integrity." 
) - # Use gzipped download if gzipped_checksum is provided - if self.gzipped_checksum is not None: - cebra_data_assets.download_and_extract_gzipped_file( - url=self.data_url, - expected_checksum=self.data_checksum, - gzipped_checksum=self.gzipped_checksum, - location=self.location, - file_name=self.file_name) - else: - # Fall back to legacy download for backward compatibility - cebra_data_assets.download_file_with_progress_bar( - url=self.data_url, - expected_checksum=self.data_checksum, - location=self.location, - file_name=self.file_name) + cebra_data_assets.download_file_with_progress_bar( + url=self.data_url, + expected_checksum=self.data_checksum, + location=self.location, + file_name=self.file_name, + gzipped_checksum=self.gzipped_checksum) @property @abc.abstractmethod diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 36aa77f6..88af686c 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -414,12 +414,12 @@ def test_download_and_extract_gzipped_file(): mock_response.iter_content = lambda chunk_size: [gzipped_content] # Test successful download and extraction - result = cebra_data_assets.download_and_extract_gzipped_file( + result = cebra_data_assets.download_file_with_progress_bar( url="http://example.com/test.jl.gz", expected_checksum=unzipped_checksum, - gzipped_checksum=gzipped_checksum, location=temp_dir, - file_name=test_filename) + file_name=test_filename, + gzipped_checksum=gzipped_checksum) # Verify the file was extracted assert result is not None @@ -455,13 +455,13 @@ def test_download_and_extract_gzipped_file_wrong_gzipped_checksum(): with pytest.raises(RuntimeError, match="Exceeded maximum retry count"): - cebra_data_assets.download_and_extract_gzipped_file( + cebra_data_assets.download_file_with_progress_bar( url="http://example.com/test.jl.gz", expected_checksum=hashlib.md5(test_content).hexdigest(), - gzipped_checksum=wrong_gz_checksum, location=temp_dir, file_name="test.jl", - retry_count=2) + retry_count=2, + gzipped_checksum=wrong_gz_checksum) def test_download_and_extract_gzipped_file_wrong_unzipped_checksum(): @@ -484,13 +484,13 @@ def test_download_and_extract_gzipped_file_wrong_unzipped_checksum(): with pytest.raises(RuntimeError, match="Exceeded maximum retry count"): - cebra_data_assets.download_and_extract_gzipped_file( + cebra_data_assets.download_file_with_progress_bar( url="http://example.com/test.jl.gz", expected_checksum=wrong_unzipped_checksum, - gzipped_checksum=gzipped_checksum, location=temp_dir, file_name="test.jl", - retry_count=2) + retry_count=2, + gzipped_checksum=gzipped_checksum) @pytest.mark.parametrize("neural, continuous, discrete", [ From 214bf6a6d2f7c08f5228a74878d31c06b59b5de0 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 11:24:00 +0000 Subject: [PATCH 04/13] fix accedentally skipped save/load test --- tests/test_sklearn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 7dfbac0f..145c466c 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -1027,7 +1027,7 @@ def get_offset(self) -> cebra.data.datatypes.Offset: @pytest.mark.parametrize("model_architecture", ["offset1-model", "parametrized-model-5"]) @pytest.mark.parametrize("device", ["cpu"] + - ["cuda"] if torch.cuda.is_available() else []) + (["cuda"] if torch.cuda.is_available() else [])) def test_save_and_load(action, backend_save, backend_load, model_architecture, device): original_model = cebra_sklearn_cebra.CEBRA( From 
a893974feb0c3306350030000229f967dec4ca22 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 11:33:52 +0000 Subject: [PATCH 05/13] include a numpy legacy test --- .github/workflows/build.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a0337e31..f9513351 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -29,18 +29,28 @@ jobs: # https://pytorch.org/get-started/previous-versions/ torch-version: ["2.6.0", "2.10.0"] sklearn-version: ["latest"] + numpy-version: ["latest"] + include: # windows test with standard config - os: windows-latest torch-version: 2.6.0 python-version: "3.12" sklearn-version: "latest" + numpy-version: "latest" # legacy sklearn (several API differences) - os: ubuntu-latest torch-version: 2.6.0 python-version: "3.12" sklearn-version: "legacy" + numpy-version: "latest" + + - os: ubuntu-latest + torch-version: 2.6.0 + python-version: "3.12" + sklearn-version: "latest" + numpy-version: "legacy" # TODO(stes): latest torch and python # requires a PyTables release compatible with @@ -55,6 +65,7 @@ jobs: torch-version: 2.4.0 python-version: "3.10" sklearn-version: "legacy" + numpy-version: "latest" runs-on: ${{ matrix.os }} @@ -88,6 +99,11 @@ jobs: run: | pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]' + - name: Check numpy legacy version + if: matrix.numpy-version == 'legacy' + run: | + pip install "numpy<2" '.[dev,datasets,integrations]' + - name: Run the formatter run: | make format From 3e98794f90c608188a571c1bf9cd9228d25147a2 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 12:45:12 +0000 Subject: [PATCH 06/13] fix windows compatibility for tempfile --- tests/test_sklearn.py | 74 +++++++++++++++++++++++++++++++++---------- 1 file changed, 58 insertions(+), 16 deletions(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 145c466c..7bee251e 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -19,7 +19,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import contextlib import itertools +import os import tempfile import warnings @@ -47,6 +49,34 @@ _DEVICES = ("cpu",) +@contextlib.contextmanager +def _windows_compatible_tempfile(mode="w+b", delete=True, **kwargs): + """Context manager for creating temporary files compatible with Windows. + + On Windows, files opened with delete=True cannot be accessed by other + processes or reopened. This context manager creates a temporary file + with delete=False, yields its path, and ensures cleanup in a finally block. 
+ + Args: + mode: File mode (default: "w+b") + **kwargs: Additional arguments passed to NamedTemporaryFile + + Yields: + str: Path to the temporary file + """ + if not delete: + raise ValueError("'delete' must be True") + + with tempfile.NamedTemporaryFile(mode=mode, delete=False, **kwargs) as f: + tempname = f.name + + try: + yield tempname + finally: + if os.path.exists(tempname): + os.remove(tempname) + + def test_imports(): import cebra @@ -1037,24 +1067,23 @@ def test_save_and_load(action, backend_save, backend_load, model_architecture, device=device) original_model = action(original_model) - with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: + with _windows_compatible_tempfile(mode="w+b") as tempname: if not check_if_fit(original_model): with pytest.raises(ValueError): - original_model.save(savefile.name, backend=backend_save) + original_model.save(tempname, backend=backend_save) else: if "parametrized" in original_model.model_architecture and backend_save == "torch": with pytest.raises(AttributeError): - original_model.save(savefile.name, backend=backend_save) + original_model.save(tempname, backend=backend_save) else: - original_model.save(savefile.name, backend=backend_save) + original_model.save(tempname, backend=backend_save) if (backend_load != "auto") and (backend_save != backend_load): with pytest.raises(RuntimeError): - cebra_sklearn_cebra.CEBRA.load(savefile.name, - backend_load) + cebra_sklearn_cebra.CEBRA.load(tempname, backend_load) else: loaded_model = cebra_sklearn_cebra.CEBRA.load( - savefile.name, backend_load) + tempname, backend_load) _assert_equal(original_model, loaded_model) action(loaded_model) @@ -1130,9 +1159,9 @@ def test_move_cpu_to_cuda_device(device): device_str = f'cuda:{device_model.index}' assert device_str == new_device - with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: - cebra_model.save(savefile.name) - loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name) + with _windows_compatible_tempfile(mode="w+b") as tempname: + cebra_model.save(tempname) + loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname) assert cebra_model.device == loaded_model.device assert next(cebra_model.solver_.model.parameters()).device == next( @@ -1159,9 +1188,9 @@ def test_move_cpu_to_mps_device(device): device_model = next(cebra_model.solver_.model.parameters()).device assert device_model.type == new_device - with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: - cebra_model.save(savefile.name) - loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name) + with _windows_compatible_tempfile(mode="w+b") as tempname: + cebra_model.save(tempname) + loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname) assert cebra_model.device == loaded_model.device assert next(cebra_model.solver_.model.parameters()).device == next( @@ -1198,9 +1227,9 @@ def test_move_mps_to_cuda_device(device): device_str = f'cuda:{device_model.index}' assert device_str == new_device - with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: - cebra_model.save(savefile.name) - loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name) + with _windows_compatible_tempfile(mode="w+b") as tempname: + cebra_model.save(tempname) + loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname) assert cebra_model.device == loaded_model.device assert next(cebra_model.solver_.model.parameters()).device == next( @@ -1561,3 +1590,16 @@ def test_non_writable_array(): embedding = cebra_model.transform(X) assert isinstance(embedding, np.ndarray) 
assert embedding.shape[0] == X.shape[0] + + +def test_read_write(): + X = np.random.randn(100, 10) + y = np.random.randn(100, 2) + cebra_model = cebra.CEBRA(max_iterations=2, batch_size=32, device="cpu") + cebra_model.fit(X, y) + cebra_model.transform(X) + + with _windows_compatible_tempfile(mode="w+b", delete=False) as tempname: + cebra_model.save(tempname) + loaded_model = cebra.CEBRA.load(tempname) + _assert_equal(cebra_model, loaded_model) From d71effae6ddbcef044799589236358026f34e0e8 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 12:59:30 +0000 Subject: [PATCH 07/13] fix added test --- tests/test_sklearn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 7bee251e..7d13383f 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -1599,7 +1599,7 @@ def test_read_write(): cebra_model.fit(X, y) cebra_model.transform(X) - with _windows_compatible_tempfile(mode="w+b", delete=False) as tempname: + with _windows_compatible_tempfile(mode="w+b", delete=True) as tempname: cebra_model.save(tempname) loaded_model = cebra.CEBRA.load(tempname) _assert_equal(cebra_model, loaded_model) From 7bc5ca353faba8d08a081f27f9004a61d132faf3 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 14:08:40 +0000 Subject: [PATCH 08/13] Fix legacy loading logic --- cebra/integrations/sklearn/cebra.py | 99 +++++++++++++++++++++-------- cebra/registry.py | 18 +++++- tests/test_sklearn.py | 43 ++++++++----- 3 files changed, 118 insertions(+), 42 deletions(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 474145ee..b9f9c795 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -23,6 +23,8 @@ import importlib.metadata import itertools +import pickle +import warnings from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union) @@ -1336,6 +1338,26 @@ def _get_state(self): } return state + def _get_state_dict(self): + backend = "sklearn" + return { + 'args': self.get_params(), + 'state': self._get_state(), + 'state_dict': self.solver_.state_dict(), + 'metadata': { + 'backend': + backend, + 'cebra_version': + cebra.__version__, + 'torch_version': + torch.__version__, + 'numpy_version': + np.__version__, + 'sklearn_version': + importlib.metadata.distribution("scikit-learn").version + } + } + def save(self, filename: str, backend: Literal["torch", "sklearn"] = "sklearn"): @@ -1384,28 +1406,16 @@ def save(self, """ if sklearn_utils.check_fitted(self): if backend == "torch": + warnings.warn( + "Saving with backend='torch' is deprecated and will be removed in a future version. 
" + "Please use backend='sklearn' instead.", + DeprecationWarning, + stacklevel=2, + ) checkpoint = torch.save(self, filename) elif backend == "sklearn": - checkpoint = torch.save( - { - 'args': self.get_params(), - 'state': self._get_state(), - 'state_dict': self.solver_.state_dict(), - 'metadata': { - 'backend': - backend, - 'cebra_version': - cebra.__version__, - 'torch_version': - torch.__version__, - 'numpy_version': - np.__version__, - 'sklearn_version': - importlib.metadata.distribution("scikit-learn" - ).version - } - }, filename) + checkpoint = torch.save(self._get_state_dict(), filename) else: raise NotImplementedError(f"Unsupported backend: {backend}") else: @@ -1457,15 +1467,52 @@ def load(cls, >>> tmp_file.unlink() """ supported_backends = ["auto", "sklearn", "torch"] + if backend not in supported_backends: raise NotImplementedError( f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}" ) - checkpoint = _safe_torch_load(filename, weights_only, **kwargs) + if backend not in ["auto", "sklearn"]: + warnings.warn( + "From CEBRA version 0.6.1 onwards, the 'backend' parameter in cebra.CEBRA.load is deprecated and will be ignored; " + "the sklearn backend is now always used. Models saved with the torch backend can still be loaded.", + category=DeprecationWarning, + stacklevel=2, + ) - if backend == "auto": - backend = "sklearn" if isinstance(checkpoint, dict) else "torch" + backend = "sklearn" + + # NOTE(stes): For maximum backwards compatibility, we allow to load legacy checkpoints. From 0.7.0 onwards, + # the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes + # introduced in torch 2.6.0. + try: + checkpoint = _safe_torch_load(filename, weights_only=True, **kwargs) + except pickle.UnpicklingError as e: + if weights_only is not False: + if packaging.version.parse( + cebra.__version__) < packaging.version.parse("0.7"): + warnings.warn( + "Failed to unpickle checkpoint with weights_only=True. " + "Falling back to loading with weights_only=False. " + "This is unsafe and should only be done if you trust the source of the model file. " + "In the future, loading these checkpoints will only work if weights_only=False is explicitly passed.", + category=UserWarning, + stacklevel=2, + ) + else: + raise ValueError( + "Failed to unpickle checkpoint with weights_only=True. " + "This may be due to an incompatible model file format. " + "To attempt loading this checkpoint, please pass weights_only=False to CEBRA.load. " + "Example: CEBRA.load(filename, weights_only=False)." + ) from e + + checkpoint = _safe_torch_load(filename, + weights_only=False, + **kwargs) + checkpoint = _check_type_checkpoint(checkpoint) + checkpoint = checkpoint._get_state_dict() if isinstance(checkpoint, dict) and backend == "torch": raise RuntimeError( @@ -1476,10 +1523,10 @@ def load(cls, "Cannot use 'sklearn' backend a non dictionary-based checkpoint. 
" "Please try a different backend.") - if backend == "sklearn": - cebra_ = _load_cebra_with_sklearn_backend(checkpoint) - else: - cebra_ = _check_type_checkpoint(checkpoint) + if backend != "sklearn": + raise ValueError(f"Unsupported backend: {backend}") + + cebra_ = _load_cebra_with_sklearn_backend(checkpoint) n_features = cebra_.n_features_ cebra_.solver_.n_features = ([ diff --git a/cebra/registry.py b/cebra/registry.py index 994fbd5c..1bbc5093 100644 --- a/cebra/registry.py +++ b/cebra/registry.py @@ -46,6 +46,7 @@ from __future__ import annotations import fnmatch +import functools import itertools import sys import textwrap @@ -214,14 +215,29 @@ def _zip_dict(d): yield dict(zip(keys, combination)) def _create_class(cls, **default_kwargs): + class_name = pattern.format(**default_kwargs) - @register(pattern.format(**default_kwargs), base=pattern) + @register(class_name, base=pattern) class _ParametrizedClass(cls): def __init__(self, *args, **kwargs): default_kwargs.update(kwargs) super().__init__(*args, **default_kwargs) + # Make the class pickleable by copying metadata from the base class + # and registering it in the module namespace + functools.update_wrapper(_ParametrizedClass, cls, updated=[]) + + # Set a unique qualname so pickle finds this class, not the base class + unique_name = f"{cls.__qualname__}_{class_name.replace('-', '_')}" + _ParametrizedClass.__qualname__ = unique_name + _ParametrizedClass.__name__ = unique_name + + # Register in module namespace so pickle can find it via getattr + parent_module = sys.modules.get(cls.__module__) + if parent_module is not None: + setattr(parent_module, unique_name, _ParametrizedClass) + def _parametrize(cls): for _default_kwargs in kwargs: _create_class(cls, **_default_kwargs) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 7d13383f..999bc7f3 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -1053,7 +1053,7 @@ def get_offset(self) -> cebra.data.datatypes.Offset: @pytest.mark.parametrize("action", _iterate_actions()) @pytest.mark.parametrize("backend_save", ["torch", "sklearn"]) -@pytest.mark.parametrize("backend_load", ["auto", "torch", "sklearn"]) +@pytest.mark.parametrize("backend_load", ["sklearn", "auto", "torch"]) @pytest.mark.parametrize("model_architecture", ["offset1-model", "parametrized-model-5"]) @pytest.mark.parametrize("device", ["cpu"] + @@ -1072,20 +1072,14 @@ def test_save_and_load(action, backend_save, backend_load, model_architecture, with pytest.raises(ValueError): original_model.save(tempname, backend=backend_save) else: - if "parametrized" in original_model.model_architecture and backend_save == "torch": - with pytest.raises(AttributeError): - original_model.save(tempname, backend=backend_save) - else: - original_model.save(tempname, backend=backend_save) + original_model.save(tempname, backend=backend_save) + + weights_only = None - if (backend_load != "auto") and (backend_save != backend_load): - with pytest.raises(RuntimeError): - cebra_sklearn_cebra.CEBRA.load(tempname, backend_load) - else: - loaded_model = cebra_sklearn_cebra.CEBRA.load( - tempname, backend_load) - _assert_equal(original_model, loaded_model) - action(loaded_model) + loaded_model = cebra_sklearn_cebra.CEBRA.load( + tempname, backend_load, weights_only=weights_only) + _assert_equal(original_model, loaded_model) + action(loaded_model) def get_ordered_cuda_devices(): @@ -1489,7 +1483,7 @@ def test_new_transform(model_architecture, device): X, session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, - 
atol=1e-8), "Arrays are not close enough" + atol=1e-8), " are not close enough" embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0) embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, @@ -1603,3 +1597,22 @@ def test_read_write(): cebra_model.save(tempname) loaded_model = cebra.CEBRA.load(tempname) _assert_equal(cebra_model, loaded_model) + + +def test_repro_pickle_error(): + """The torch backend for save/loading fails with python 3.14. + + See https://github.com/AdaptiveMotorControlLab/CEBRA/pull/292. + + This test is a minimal repro of the error. + """ + + model = cebra_sklearn_cebra.CEBRA(model_architecture='parametrized-model-5', + max_iterations=5, + batch_size=100, + device='cpu') + + model.fit(np.random.randn(1000, 10)) + + with _windows_compatible_tempfile(mode="w+b", delete=True) as tempname: + model.save(tempname, backend="torch") From 52023c9965952c9a4a8ca79f35573c68cec266db Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 14:20:46 +0000 Subject: [PATCH 09/13] minimize diff in tests --- tests/test_sklearn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 999bc7f3..de3cec42 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -1053,7 +1053,7 @@ def get_offset(self) -> cebra.data.datatypes.Offset: @pytest.mark.parametrize("action", _iterate_actions()) @pytest.mark.parametrize("backend_save", ["torch", "sklearn"]) -@pytest.mark.parametrize("backend_load", ["sklearn", "auto", "torch"]) +@pytest.mark.parametrize("backend_load", ["auto", "torch", "sklearn"]) @pytest.mark.parametrize("model_architecture", ["offset1-model", "parametrized-model-5"]) @pytest.mark.parametrize("device", ["cpu"] + From abdba0047219cf20cc1fc989f79b03f475d4b6df Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 14:29:07 +0000 Subject: [PATCH 10/13] Fix _assert_equal check --- tests/test_sklearn.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index de3cec42..831ad49d 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -1010,7 +1010,11 @@ def _assert_equal(original_model, loaded_model): if check_if_fit(loaded_model): _assert_same_state_dict(original_model.state_dict_, loaded_model.state_dict_) - X = np.random.normal(0, 1, (100, 1)) + + n_features = loaded_model.n_features_ + if isinstance(n_features, list): + n_features = n_features[0] + X = np.random.normal(0, 1, (100, n_features)) if loaded_model.num_sessions is not None: assert np.allclose(loaded_model.transform(X, session_id=0), From 595ffdf0ec5d130588e0f595fbb4c84e672cbd2a Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 14:31:23 +0000 Subject: [PATCH 11/13] Bump version to 0.6.1 --- Dockerfile | 2 +- Makefile | 2 +- PKGBUILD | 2 +- cebra/__init__.py | 2 +- reinstall.sh | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index 5f092805..ba57657f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -40,7 +40,7 @@ RUN make dist FROM cebra-base # install the cebra wheel -ENV WHEEL=cebra-0.6.0-py3-none-any.whl +ENV WHEEL=cebra-0.6.1-py3-none-any.whl WORKDIR /build COPY --from=wheel /build/dist/${WHEEL} . 
RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets]' diff --git a/Makefile b/Makefile index 9989d2dd..78321283 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -CEBRA_VERSION := 0.6.0 +CEBRA_VERSION := 0.6.1 dist: python3 -m pip install virtualenv diff --git a/PKGBUILD b/PKGBUILD index 401569ba..ac588deb 100644 --- a/PKGBUILD +++ b/PKGBUILD @@ -1,7 +1,7 @@ # Maintainer: Steffen Schneider pkgname=python-cebra _pkgname=cebra -pkgver=0.6.0 +pkgver=0.6.1 pkgrel=1 pkgdesc="Consistent Embeddings of high-dimensional Recordings using Auxiliary variables" url="https://cebra.ai" diff --git a/cebra/__init__.py b/cebra/__init__.py index ff36d354..0dc6c652 100644 --- a/cebra/__init__.py +++ b/cebra/__init__.py @@ -66,7 +66,7 @@ import cebra.integrations.sklearn as sklearn -__version__ = "0.6.0" +__version__ = "0.6.1" __all__ = ["CEBRA"] __allow_lazy_imports = False __lazy_imports = {} diff --git a/reinstall.sh b/reinstall.sh index d191b8f7..78034d7b 100755 --- a/reinstall.sh +++ b/reinstall.sh @@ -15,7 +15,7 @@ pip uninstall -y cebra # Get version info after uninstalling --- this will automatically get the # most recent version based on the source code in the current directory. # $(tools/get_cebra_version.sh) -VERSION=0.6.0 +VERSION=0.6.1 echo "Upgrading to CEBRA v${VERSION}" # Upgrade the build system (PEP517/518 compatible) From 2aad99a9593c31c7c7a905cf00f8e8ac6f02f5d8 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 14:50:33 +0000 Subject: [PATCH 12/13] fix loading logic for legacy torch --- cebra/integrations/sklearn/cebra.py | 45 ++++++++++++++--------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index b9f9c795..3de6d4fd 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -64,20 +64,22 @@ def check_version(estimator): sklearn.__version__) < packaging.version.parse("1.6.dev") -def _safe_torch_load(filename, weights_only, **kwargs): - if weights_only is None: - if packaging.version.parse( - torch.__version__) >= packaging.version.parse("2.6.0"): - weights_only = True - else: - weights_only = False +def _safe_torch_load(filename, weights_only=False, **kwargs): + checkpoint = None + legacy_mode = packaging.version.parse( + torch.__version__) < packaging.version.parse("2.6.0") - if not weights_only: + if legacy_mode: checkpoint = torch.load(filename, weights_only=False, **kwargs) else: - # NOTE(stes): This is only supported for torch 2.6+ with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS): - checkpoint = torch.load(filename, weights_only=True, **kwargs) + checkpoint = torch.load(filename, + weights_only=weights_only, + **kwargs) + + if not isinstance(checkpoint, dict): + _check_type_checkpoint(checkpoint) + checkpoint = checkpoint._get_state_dict() return checkpoint @@ -317,8 +319,9 @@ def _require_arg(key): def _check_type_checkpoint(checkpoint): if not isinstance(checkpoint, cebra.CEBRA): - raise RuntimeError("Model loaded from file is not compatible with " - "the current CEBRA version.") + raise RuntimeError( + "Model loaded from file is not compatible with " + f"the current CEBRA version. Got: {type(checkpoint)}") if not sklearn_utils.check_fitted(checkpoint): raise ValueError( "CEBRA model is not fitted. Loading it is not supported.") @@ -1487,7 +1490,7 @@ def load(cls, # the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes # introduced in torch 2.6.0. 
try: - checkpoint = _safe_torch_load(filename, weights_only=True, **kwargs) + checkpoint = _safe_torch_load(filename, **kwargs) except pickle.UnpicklingError as e: if weights_only is not False: if packaging.version.parse( @@ -1511,21 +1514,15 @@ def load(cls, checkpoint = _safe_torch_load(filename, weights_only=False, **kwargs) - checkpoint = _check_type_checkpoint(checkpoint) - checkpoint = checkpoint._get_state_dict() - - if isinstance(checkpoint, dict) and backend == "torch": - raise RuntimeError( - "Cannot use 'torch' backend with a dictionary-based checkpoint. " - "Please try a different backend.") - if not isinstance(checkpoint, dict) and backend == "sklearn": - raise RuntimeError( - "Cannot use 'sklearn' backend a non dictionary-based checkpoint. " - "Please try a different backend.") if backend != "sklearn": raise ValueError(f"Unsupported backend: {backend}") + if not isinstance(checkpoint, dict): + raise RuntimeError( + "Cannot use 'sklearn' backend a non dictionary-based checkpoint. " + f"Please try a different backend. Got: {type(checkpoint)}") + cebra_ = _load_cebra_with_sklearn_backend(checkpoint) n_features = cebra_.n_features_ From bd27653db9885f6dd8103937720dea037fd687d3 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 15:06:45 +0000 Subject: [PATCH 13/13] allowlist float32d --- cebra/integrations/sklearn/cebra.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 3de6d4fd..00645523 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -52,8 +52,13 @@ # windows (https://github.com/AdaptiveMotorControlLab/CEBRA/pull/281#issuecomment-3764185072) # on build (windows-latest, torch 2.6.0, python 3.12, latest sklearn) CEBRA_LOAD_SAFE_GLOBALS = [ - cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype, - np.dtypes.Int32DType, np.dtypes.Float64DType, np.dtypes.Int64DType + cebra.data.Offset, + torch.torch_version.TorchVersion, + np.dtype, + np.dtypes.Int32DType, + np.dtypes.Int64DType, + np.dtypes.Float32DType, + np.dtypes.Float64DType, ]
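
For reference, the allowlist above is what lets torch.load(..., weights_only=True) rebuild the
sklearn-backend checkpoint dict. A minimal sketch of the mechanism (assuming torch >= 2.6 and
numpy >= 1.25; the file name, the "allowed" variable, and the checkpoint contents below are
illustrative stand-ins, not the actual CEBRA state dict):

    import numpy as np
    import torch

    # Stand-in for the dict written by CEBRA.save(backend="sklearn"); the numpy
    # dtype entry is the kind of object that needs an allowlist entry.
    checkpoint = {
        "args": {"batch_size": 32},
        "state": {"dtype_": np.dtype("float32")},
        "state_dict": {"weight": torch.zeros(3, 3)},
    }
    torch.save(checkpoint, "example_checkpoint.pt")

    # Mirrors a subset of CEBRA_LOAD_SAFE_GLOBALS: only allowlisted globals may be
    # unpickled when weights_only=True.
    allowed = [np.dtype, np.dtypes.Float32DType, np.dtypes.Float64DType]
    with torch.serialization.safe_globals(allowed):
        restored = torch.load("example_checkpoint.pt", weights_only=True)

    assert restored["state"]["dtype_"] == np.dtype("float32")

If a global required by the pickle stream is missing from the allowlist, torch raises
pickle.UnpicklingError under weights_only=True, which is the failure mode the Float32DType
entry added in this patch guards against.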