Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 94 additions & 55 deletions src/murfey/workflows/clem/register_preprocessing_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import logging
import traceback
from collections.abc import Collection
from functools import cached_property
from importlib.metadata import entry_points
from pathlib import Path
from typing import Literal, Optional

from pydantic import BaseModel
from pydantic import BaseModel, computed_field
from sqlmodel import Session, select

import murfey.util.db as MurfeyDB
Expand Down Expand Up @@ -53,16 +54,42 @@ class CLEMPreprocessingResult(BaseModel):
resolution: float
extent: list[float] # [x0, x1, y0, y1]


def _is_clem_atlas(result: CLEMPreprocessingResult):
# If an image has a width/height of at least 1.5 mm, it should qualify as an atlas
return (
max(
result.pixels_x * result.pixel_size,
result.pixels_y * result.pixel_size,
# Valid Pydantic decorator not supported by MyPy
@computed_field # type: ignore
@cached_property
def is_denoised(self) -> bool:
"""
The "_Lng_LVCC" suffix appended to a CLEM dataset's position name indicates
that it's a denoised image set of the same position. These results should
override or supersede the original ones once they're available.
"""
return "_Lng_LVCC" in self.series_name

# Valid Pydantic decorator not supported by MyPy
@computed_field # type: ignore
@cached_property
def site_name(self) -> str:
"""
Extract just the name of the site by removing the "_Lng_LVCC" suffix from
the series name.
"""
return self.series_name.replace("_Lng_LVCC", "")

# Valid Pydantic decorator not supported by MyPy
@computed_field # type: ignore
@cached_property
def is_atlas(self) -> bool:
"""
Incoming image sets whose width or height is greater than or equal to the
pre-set threshold qualify as an atlas.
"""
return (
max(
self.pixels_x * self.pixel_size,
self.pixels_y * self.pixel_size,
)
>= processing_params.atlas_threshold
)
>= processing_params.atlas_threshold
)


COLOR_FLAGS_MURFEY = {
Expand Down Expand Up @@ -91,51 +118,63 @@ def _register_clem_imaging_site(
result: CLEMPreprocessingResult,
murfey_db: Session,
):
output_file = list(result.output_files.values())[0]
if not (
clem_img_site := murfey_db.exec(
select(MurfeyDB.ImagingSite)
.where(MurfeyDB.ImagingSite.session_id == session_id)
.where(MurfeyDB.ImagingSite.site_name == result.series_name)
.where(MurfeyDB.ImagingSite.site_name == result.site_name)
).one_or_none()
):
clem_img_site = MurfeyDB.ImagingSite(
session_id=session_id, site_name=result.series_name
session_id=session_id,
site_name=result.site_name,
image_path=str(output_file.parent / "*tiff"),
data_type="atlas" if result.is_atlas else "grid_square",
# Shape and resolution information
image_pixels_x=result.pixels_x,
image_pixels_y=result.pixels_y,
image_pixel_size=result.pixel_size,
units=result.units,
# Extent of imaged area in real space
x0=result.extent[0],
x1=result.extent[1],
y0=result.extent[2],
y1=result.extent[3],
)

# Add metadata for this series
output_file = list(result.output_files.values())[0]
clem_img_site.image_path = str(output_file.parent / "*tiff")
clem_img_site.data_type = "atlas" if _is_clem_atlas(result) else "grid_square"
clem_img_site.number_of_members = result.number_of_members
for col_name, value in _get_color_flags(result.output_files.keys()).items():
setattr(clem_img_site, col_name, value)
clem_img_site.collection_mode = _determine_collection_mode(
result.output_files.keys()
)
clem_img_site.image_pixels_x = result.pixels_x
clem_img_site.image_pixels_y = result.pixels_y
clem_img_site.image_pixel_size = result.pixel_size
clem_img_site.units = result.units
clem_img_site.x0 = result.extent[0]
clem_img_site.x1 = result.extent[1]
clem_img_site.y0 = result.extent[2]
clem_img_site.y1 = result.extent[3]
# Register thumbnails if they are present
if result.thumbnails and result.thumbnail_size:
thumbnail = list(result.thumbnails.values())[0]
clem_img_site.thumbnail_path = str(thumbnail.parent / "*.png")

thumbnail_height, thumbnail_width = result.thumbnail_size
scaling_factor = min(
thumbnail_height / result.pixels_y, thumbnail_width / result.pixels_x
)
clem_img_site.thumbnail_pixel_size = result.pixel_size / scaling_factor
clem_img_site.thumbnail_pixels_x = (
int(round(result.pixels_x * scaling_factor)) or 1
)
clem_img_site.thumbnail_pixels_y = (
int(round(result.pixels_y * scaling_factor)) or 1
# Iteratively add colour channel information
clem_img_site.number_of_members = result.number_of_members
for col_name, value in _get_color_flags(result.output_files.keys()).items():
setattr(clem_img_site, col_name, value)
clem_img_site.collection_mode = _determine_collection_mode(
result.output_files.keys()
)

# Register thumbnails if they are present
if result.thumbnails and result.thumbnail_size:
thumbnail = list(result.thumbnails.values())[0]
clem_img_site.thumbnail_path = str(thumbnail.parent / "*.png")

thumbnail_height, thumbnail_width = result.thumbnail_size
scaling_factor = min(
thumbnail_height / result.pixels_y, thumbnail_width / result.pixels_x
)
clem_img_site.thumbnail_pixel_size = result.pixel_size / scaling_factor
clem_img_site.thumbnail_pixels_x = (
int(round(result.pixels_x * scaling_factor)) or 1
)
clem_img_site.thumbnail_pixels_y = (
int(round(result.pixels_y * scaling_factor)) or 1
)

# Overwrite file paths for existing entry if latest one is denoised
if result.is_denoised:
clem_img_site.image_path = str(output_file.parent / "*tiff")
if result.thumbnails and result.thumbnail_size:
thumbnail = list(result.thumbnails.values())[0]
clem_img_site.thumbnail_path = str(thumbnail.parent / "*.png")

murfey_db.add(clem_img_site)
murfey_db.commit()
murfey_db.close()
Expand Down Expand Up @@ -183,12 +222,12 @@ def _register_dcg_and_atlas(
visit_number = visit_name.split("-")[-1]

# Generate name/tag for data collection group based on series name
dcg_name = result.series_name.split("--")[0]
if result.series_name.split("--")[1].isdigit():
dcg_name += f"--{result.series_name.split('--')[1]}"
dcg_name = result.site_name.split("--")[0]
if result.site_name.split("--")[1].isdigit():
dcg_name += f"--{result.site_name.split('--')[1]}"

# Determine values for atlas
if _is_clem_atlas(result):
if result.is_atlas:
output_file = list(result.output_files.values())[0]
# Register the thumbnail entries if they are provided
if result.thumbnails and result.thumbnail_size is not None:
Expand Down Expand Up @@ -227,7 +266,7 @@ def _register_dcg_and_atlas(
dcg_entry = dcg_search[0]
# Update atlas if registering atlas dataset
# and data collection group already exists
if _is_clem_atlas(result):
if result.is_atlas:
atlas_message = {
"session_id": session_id,
"dcgid": dcg_entry.id,
Expand Down Expand Up @@ -287,11 +326,11 @@ def _register_dcg_and_atlas(
clem_img_site := murfey_db.exec(
select(MurfeyDB.ImagingSite)
.where(MurfeyDB.ImagingSite.session_id == session_id)
.where(MurfeyDB.ImagingSite.site_name == result.series_name)
.where(MurfeyDB.ImagingSite.site_name == result.site_name)
).one_or_none()
):
clem_img_site = MurfeyDB.ImagingSite(
session_id=session_id, site_name=result.series_name
session_id=session_id, site_name=result.site_name
)

clem_img_site.dcg_id = dcg_entry.id
Expand All @@ -311,9 +350,9 @@ def _register_grid_square(
logger.error("Unable to find transport manager")
return
# Load all entries for the current data collection group
dcg_name = result.series_name.split("--")[0]
if result.series_name.split("--")[1].isdigit():
dcg_name += f"--{result.series_name.split('--')[1]}"
dcg_name = result.site_name.split("--")[0]
if result.site_name.split("--")[1].isdigit():
dcg_name += f"--{result.site_name.split('--')[1]}"

# Check if an atlas has been registered
if not (
Expand Down
52 changes: 35 additions & 17 deletions tests/workflows/clem/test_register_preprocessing_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def generate_preprocessing_messages(
# Construct all the datasets to be tested
datasets: list[tuple[Path, bool, bool, tuple[int, int], float, list[float]]] = [
(
grid_dir / "Overview_1" / "Image_1",
grid_dir / "Overview 1" / "Image 1",
False,
True,
(2400, 2400),
Expand All @@ -59,22 +59,38 @@ def generate_preprocessing_messages(
datasets.extend(
[
(
grid_dir / "TileScan_1" / f"Position_{n}",
grid_dir / "TileScan 1" / f"Position {n + 1}",
True,
False,
(2048, 2048),
1.6e-7,
[0.003, 0.00332768, 0.003, 0.00332768],
)
for n in range(5)
for n in range(3)
]
)
datasets.extend(
[
(
grid_dir / "TileScan 1" / f"Position {n + 1}_Lng_LVCC",
True,
False,
(2048, 2048),
1.6e-7,
[0.003, 0.00332768, 0.003, 0.00332768],
)
for n in range(3)
]
)

messages: list[dict[str, Any]] = []
for dataset in datasets:
for series_path, is_stack, is_montage, shape, pixel_size, extent in datasets:
# Unpack items from list of dataset parameters
series_path = dataset[0]
series_name = str(series_path.relative_to(processed_dir)).replace("/", "--")
series_name = (
str(series_path.relative_to(processed_dir))
.replace("/", "--")
.replace(" ", "_")
)
metadata = series_path / "metadata" / f"{series_path.stem}.xml"
metadata.parent.mkdir(parents=True, exist_ok=True)
metadata.touch(exist_ok=True)
Expand All @@ -89,11 +105,6 @@ def generate_preprocessing_messages(
thumbnail.parent.mkdir(parents=True)
thumbnail.touch(exist_ok=True)
thumbnail_size = (512, 512)
is_stack = dataset[1]
is_montage = dataset[2]
shape = dataset[3]
pixel_size = dataset[4]
extent = dataset[5]

message = {
"session_id": session_id,
Expand Down Expand Up @@ -373,21 +384,23 @@ def test_run_with_db(
else:
assert mock_align_and_merge_call.call_count == len(preprocessing_messages) * 3

# Both databases should have entries for data collection group, and grid squares
# ISPyB database should additionally have an atlas entry
# Murfey's DataCollectionGroup should have an entry
murfey_dcg_search = murfey_db_session.exec(
sm_select(MurfeyDB.DataCollectionGroup).where(
MurfeyDB.DataCollectionGroup.session_id == murfey_session.id
)
).all()
assert len(murfey_dcg_search) == 1

# GridSquare entries should be half the initial number of entries due to overwrites
murfey_gs_search = murfey_db_session.exec(
sm_select(MurfeyDB.GridSquare).where(
MurfeyDB.GridSquare.session_id == murfey_session.id
)
).all()
assert len(murfey_gs_search) == len(preprocessing_messages) - 1
assert len(murfey_gs_search) == (len(preprocessing_messages) - 1) // 2

# ISPyB's DataCollectionGroup should have an entry
murfey_dcg = murfey_dcg_search[0]
ispyb_dcg_search = (
ispyb_db_session.execute(
Expand All @@ -400,6 +413,7 @@ def test_run_with_db(
)
assert len(ispyb_dcg_search) == 1

# Atlas should have an entry
ispyb_dcg = ispyb_dcg_search[0]
ispyb_atlas_search = (
ispyb_db_session.execute(
Expand All @@ -419,12 +433,13 @@ def test_run_with_db(
}
collection_mode = _determine_collection_mode(colors)

# Atlas color flags and collection mode should be set correctly
ispyb_atlas = ispyb_atlas_search[0]
# Check that the Atlas color flags and collection mode are set correctly
for flag, value in color_flags.items():
assert getattr(ispyb_atlas, flag) == value
assert ispyb_atlas.mode == collection_mode

# ISPyB's GridSquare should have half the number of initial entries
ispyb_gs_search = (
ispyb_db_session.execute(
sa_select(ISPyBDB.GridSquare).where(
Expand All @@ -434,9 +449,12 @@ def test_run_with_db(
.scalars()
.all()
)
assert len(ispyb_gs_search) == len(preprocessing_messages) - 1
assert len(ispyb_gs_search) == (len(preprocessing_messages) - 1) // 2
for gs in ispyb_gs_search:
# Check that the Atlas color flags and collection mode are set correctly
# Check that all entries point to the denoised images ("_Lng_LVCC")
assert gs.gridSquareImage is not None and "_Lng_LVCC" in gs.gridSquareImage

# Check that the GridSquare color flags and collection mode are set correctly
for flag, value in color_flags.items():
assert getattr(gs, flag) == value
assert gs.mode == collection_mode
Expand Down
Loading