diff --git a/src/murfey/workflows/clem/register_preprocessing_results.py b/src/murfey/workflows/clem/register_preprocessing_results.py index b287fbff..7228f0ef 100644 --- a/src/murfey/workflows/clem/register_preprocessing_results.py +++ b/src/murfey/workflows/clem/register_preprocessing_results.py @@ -11,11 +11,12 @@ import logging import traceback from collections.abc import Collection +from functools import cached_property from importlib.metadata import entry_points from pathlib import Path from typing import Literal, Optional -from pydantic import BaseModel +from pydantic import BaseModel, computed_field from sqlmodel import Session, select import murfey.util.db as MurfeyDB @@ -53,16 +54,42 @@ class CLEMPreprocessingResult(BaseModel): resolution: float extent: list[float] # [x0, x1, y0, y1] - -def _is_clem_atlas(result: CLEMPreprocessingResult): - # If an image has a width/height of at least 1.5 mm, it should qualify as an atlas - return ( - max( - result.pixels_x * result.pixel_size, - result.pixels_y * result.pixel_size, + # Valid Pydantic decorator not supported by MyPy + @computed_field # type: ignore + @cached_property + def is_denoised(self) -> bool: + """ + The "_Lng_LVCC" suffix appended to a CLEM dataset's position name indicates + that it's a denoised image set of the same position. These results should + override or supersede the original ones once they're available. + """ + return "_Lng_LVCC" in self.series_name + + # Valid Pydantic decorator not supported by MyPy + @computed_field # type: ignore + @cached_property + def site_name(self) -> str: + """ + Extract just the name of the site by removing the "_Lng_LVCC" suffix from + the series name. + """ + return self.series_name.replace("_Lng_LVCC", "") + + # Valid Pydantic decorator not supported by MyPy + @computed_field # type: ignore + @cached_property + def is_atlas(self) -> bool: + """ + If an incoming image set has a width/height greater than or equal to the + pre-set threshold, it qualifies as an atlas.
+ """ + return ( + max( + self.pixels_x * self.pixel_size, + self.pixels_y * self.pixel_size, + ) + >= processing_params.atlas_threshold ) - >= processing_params.atlas_threshold - ) COLOR_FLAGS_MURFEY = { @@ -91,51 +118,63 @@ def _register_clem_imaging_site( result: CLEMPreprocessingResult, murfey_db: Session, ): + output_file = list(result.output_files.values())[0] if not ( clem_img_site := murfey_db.exec( select(MurfeyDB.ImagingSite) .where(MurfeyDB.ImagingSite.session_id == session_id) - .where(MurfeyDB.ImagingSite.site_name == result.series_name) + .where(MurfeyDB.ImagingSite.site_name == result.site_name) ).one_or_none() ): clem_img_site = MurfeyDB.ImagingSite( - session_id=session_id, site_name=result.series_name + session_id=session_id, + site_name=result.site_name, + image_path=str(output_file.parent / "*tiff"), + data_type="atlas" if result.is_atlas else "grid_square", + # Shape and resolution information + image_pixels_x=result.pixels_x, + image_pixels_y=result.pixels_y, + image_pixel_size=result.pixel_size, + units=result.units, + # Extent of imaged area in real space + x0=result.extent[0], + x1=result.extent[1], + y0=result.extent[2], + y1=result.extent[3], ) - # Add metadata for this series - output_file = list(result.output_files.values())[0] - clem_img_site.image_path = str(output_file.parent / "*tiff") - clem_img_site.data_type = "atlas" if _is_clem_atlas(result) else "grid_square" - clem_img_site.number_of_members = result.number_of_members - for col_name, value in _get_color_flags(result.output_files.keys()).items(): - setattr(clem_img_site, col_name, value) - clem_img_site.collection_mode = _determine_collection_mode( - result.output_files.keys() - ) - clem_img_site.image_pixels_x = result.pixels_x - clem_img_site.image_pixels_y = result.pixels_y - clem_img_site.image_pixel_size = result.pixel_size - clem_img_site.units = result.units - clem_img_site.x0 = result.extent[0] - clem_img_site.x1 = result.extent[1] - clem_img_site.y0 = result.extent[2] 
- clem_img_site.y1 = result.extent[3] - # Register thumbnails if they are present - if result.thumbnails and result.thumbnail_size: - thumbnail = list(result.thumbnails.values())[0] - clem_img_site.thumbnail_path = str(thumbnail.parent / "*.png") - - thumbnail_height, thumbnail_width = result.thumbnail_size - scaling_factor = min( - thumbnail_height / result.pixels_y, thumbnail_width / result.pixels_x - ) - clem_img_site.thumbnail_pixel_size = result.pixel_size / scaling_factor - clem_img_site.thumbnail_pixels_x = ( - int(round(result.pixels_x * scaling_factor)) or 1 - ) - clem_img_site.thumbnail_pixels_y = ( - int(round(result.pixels_y * scaling_factor)) or 1 + # Iteratively add colour channel information + clem_img_site.number_of_members = result.number_of_members + for col_name, value in _get_color_flags(result.output_files.keys()).items(): + setattr(clem_img_site, col_name, value) + clem_img_site.collection_mode = _determine_collection_mode( + result.output_files.keys() ) + + # Register thumbnails if they are present + if result.thumbnails and result.thumbnail_size: + thumbnail = list(result.thumbnails.values())[0] + clem_img_site.thumbnail_path = str(thumbnail.parent / "*.png") + + thumbnail_height, thumbnail_width = result.thumbnail_size + scaling_factor = min( + thumbnail_height / result.pixels_y, thumbnail_width / result.pixels_x + ) + clem_img_site.thumbnail_pixel_size = result.pixel_size / scaling_factor + clem_img_site.thumbnail_pixels_x = ( + int(round(result.pixels_x * scaling_factor)) or 1 + ) + clem_img_site.thumbnail_pixels_y = ( + int(round(result.pixels_y * scaling_factor)) or 1 + ) + + # Overwrite file paths for existing entry if latest one is denoised + if result.is_denoised: + clem_img_site.image_path = str(output_file.parent / "*tiff") + if result.thumbnails and result.thumbnail_size: + thumbnail = list(result.thumbnails.values())[0] + clem_img_site.thumbnail_path = str(thumbnail.parent / "*.png") + murfey_db.add(clem_img_site) 
murfey_db.commit() murfey_db.close() @@ -183,12 +222,12 @@ def _register_dcg_and_atlas( visit_number = visit_name.split("-")[-1] # Generate name/tag for data colleciton group based on series name - dcg_name = result.series_name.split("--")[0] - if result.series_name.split("--")[1].isdigit(): - dcg_name += f"--{result.series_name.split('--')[1]}" + dcg_name = result.site_name.split("--")[0] + if result.site_name.split("--")[1].isdigit(): + dcg_name += f"--{result.site_name.split('--')[1]}" # Determine values for atlas - if _is_clem_atlas(result): + if result.is_atlas: output_file = list(result.output_files.values())[0] # Register the thumbnail entries if they are provided if result.thumbnails and result.thumbnail_size is not None: @@ -227,7 +266,7 @@ def _register_dcg_and_atlas( dcg_entry = dcg_search[0] # Update atlas if registering atlas dataset # and data collection group already exists - if _is_clem_atlas(result): + if result.is_atlas: atlas_message = { "session_id": session_id, "dcgid": dcg_entry.id, @@ -287,11 +326,11 @@ def _register_dcg_and_atlas( clem_img_site := murfey_db.exec( select(MurfeyDB.ImagingSite) .where(MurfeyDB.ImagingSite.session_id == session_id) - .where(MurfeyDB.ImagingSite.site_name == result.series_name) + .where(MurfeyDB.ImagingSite.site_name == result.site_name) ).one_or_none() ): clem_img_site = MurfeyDB.ImagingSite( - session_id=session_id, site_name=result.series_name + session_id=session_id, site_name=result.site_name ) clem_img_site.dcg_id = dcg_entry.id @@ -311,9 +350,9 @@ def _register_grid_square( logger.error("Unable to find transport manager") return # Load all entries for the current data collection group - dcg_name = result.series_name.split("--")[0] - if result.series_name.split("--")[1].isdigit(): - dcg_name += f"--{result.series_name.split('--')[1]}" + dcg_name = result.site_name.split("--")[0] + if result.site_name.split("--")[1].isdigit(): + dcg_name += f"--{result.site_name.split('--')[1]}" # Check if an atlas has been 
registered if not ( diff --git a/tests/workflows/clem/test_register_preprocessing_results.py b/tests/workflows/clem/test_register_preprocessing_results.py index 3a48669f..9f835e7f 100644 --- a/tests/workflows/clem/test_register_preprocessing_results.py +++ b/tests/workflows/clem/test_register_preprocessing_results.py @@ -47,7 +47,7 @@ def generate_preprocessing_messages( # Construct all the datasets to be tested datasets: list[tuple[Path, bool, bool, tuple[int, int], float, list[float]]] = [ ( - grid_dir / "Overview_1" / "Image_1", + grid_dir / "Overview 1" / "Image 1", False, True, (2400, 2400), @@ -59,22 +59,38 @@ def generate_preprocessing_messages( datasets.extend( [ ( - grid_dir / "TileScan_1" / f"Position_{n}", + grid_dir / "TileScan 1" / f"Position {n + 1}", True, False, (2048, 2048), 1.6e-7, [0.003, 0.00332768, 0.003, 0.00332768], ) - for n in range(5) + for n in range(3) + ] + ) + datasets.extend( + [ + ( + grid_dir / "TileScan 1" / f"Position {n + 1}_Lng_LVCC", + True, + False, + (2048, 2048), + 1.6e-7, + [0.003, 0.00332768, 0.003, 0.00332768], + ) + for n in range(3) ] ) messages: list[dict[str, Any]] = [] - for dataset in datasets: + for series_path, is_stack, is_montage, shape, pixel_size, extent in datasets: # Unpack items from list of dataset parameters - series_path = dataset[0] - series_name = str(series_path.relative_to(processed_dir)).replace("/", "--") + series_name = ( + str(series_path.relative_to(processed_dir)) + .replace("/", "--") + .replace(" ", "_") + ) metadata = series_path / "metadata" / f"{series_path.stem}.xml" metadata.parent.mkdir(parents=True, exist_ok=True) metadata.touch(exist_ok=True) @@ -89,11 +105,6 @@ def generate_preprocessing_messages( thumbnail.parent.mkdir(parents=True) thumbnail.touch(exist_ok=True) thumbnail_size = (512, 512) - is_stack = dataset[1] - is_montage = dataset[2] - shape = dataset[3] - pixel_size = dataset[4] - extent = dataset[5] message = { "session_id": session_id, @@ -373,21 +384,23 @@ def 
test_run_with_db( else: assert mock_align_and_merge_call.call_count == len(preprocessing_messages) * 3 - # Both databases should have entries for data collection group, and grid squares - # ISPyB database should additionally have an atlas entry + # Murfey's DataCollectionGroup should have an entry murfey_dcg_search = murfey_db_session.exec( sm_select(MurfeyDB.DataCollectionGroup).where( MurfeyDB.DataCollectionGroup.session_id == murfey_session.id ) ).all() assert len(murfey_dcg_search) == 1 + + # GridSquare entries should be half the initial number of entries due to overwrites murfey_gs_search = murfey_db_session.exec( sm_select(MurfeyDB.GridSquare).where( MurfeyDB.GridSquare.session_id == murfey_session.id ) ).all() - assert len(murfey_gs_search) == len(preprocessing_messages) - 1 + assert len(murfey_gs_search) == (len(preprocessing_messages) - 1) // 2 + # ISPyB's DataCollectionGroup should have an entry murfey_dcg = murfey_dcg_search[0] ispyb_dcg_search = ( ispyb_db_session.execute( @@ -400,6 +413,7 @@ def test_run_with_db( ) assert len(ispyb_dcg_search) == 1 + # Atlas should have an entry ispyb_dcg = ispyb_dcg_search[0] ispyb_atlas_search = ( ispyb_db_session.execute( @@ -419,12 +433,13 @@ } collection_mode = _determine_collection_mode(colors) + # Atlas color flags and collection mode should be set correctly ispyb_atlas = ispyb_atlas_search[0] - # Check that the Atlas color flags and collection mode are set correctly for flag, value in color_flags.items(): assert getattr(ispyb_atlas, flag) == value assert ispyb_atlas.mode == collection_mode + # ISPyB's GridSquare should have half the number of initial entries ispyb_gs_search = ( ispyb_db_session.execute( sa_select(ISPyBDB.GridSquare).where( @@ -434,9 +449,12 @@ def test_run_with_db( .scalars() .all() ) - assert len(ispyb_gs_search) == len(preprocessing_messages) - 1 + assert len(ispyb_gs_search) == (len(preprocessing_messages) - 1) // 2 for gs in ispyb_gs_search: - # Check that the Atlas 
color flags and collection mode are set correctly + # Check that all entries point to the denoised images ("_Lng_LVCC") + assert gs.gridSquareImage is not None and "_Lng_LVCC" in gs.gridSquareImage + + # Check that the GridSquare color flags and collection mode are set correctly for flag, value in color_flags.items(): assert getattr(gs, flag) == value assert gs.mode == collection_mode