From b8df39eb6f609e71060e131200ed744da7e8b2f9 Mon Sep 17 00:00:00 2001 From: mtuann Date: Wed, 11 Mar 2026 19:25:08 -0500 Subject: [PATCH 1/2] add RetinaUNet model skeleton with tests and example ablation --- docs/api/models.rst | 1 + .../api/models/pyhealth.models.RetinaUNet.rst | 9 + examples/lidc_nodule_detection_retina_unet.py | 126 +++++++++ pyhealth/models/__init__.py | 1 + pyhealth/models/retina_unet.py | 246 ++++++++++++++++++ tests/core/test_retina_unet.py | 95 +++++++ 6 files changed, 478 insertions(+) create mode 100644 docs/api/models/pyhealth.models.RetinaUNet.rst create mode 100644 examples/lidc_nodule_detection_retina_unet.py create mode 100644 pyhealth/models/retina_unet.py create mode 100644 tests/core/test_retina_unet.py diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..fdab07021 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -177,6 +177,7 @@ API Reference models/pyhealth.models.GNN models/pyhealth.models.Transformer models/pyhealth.models.TransformersModel + models/pyhealth.models.RetinaUNet models/pyhealth.models.RETAIN models/pyhealth.models.GAMENet models/pyhealth.models.GraphCare diff --git a/docs/api/models/pyhealth.models.RetinaUNet.rst b/docs/api/models/pyhealth.models.RetinaUNet.rst new file mode 100644 index 000000000..e2f8c71a6 --- /dev/null +++ b/docs/api/models/pyhealth.models.RetinaUNet.rst @@ -0,0 +1,9 @@ +pyhealth.models.RetinaUNet +=================================== + +Retina U-Net style image model with an auxiliary segmentation branch. + +.. autoclass:: pyhealth.models.RetinaUNet + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/lidc_nodule_detection_retina_unet.py b/examples/lidc_nodule_detection_retina_unet.py new file mode 100644 index 000000000..7ca272240 --- /dev/null +++ b/examples/lidc_nodule_detection_retina_unet.py @@ -0,0 +1,126 @@ +"""Synthetic ablation example for RetinaUNet. 
+ +This script follows the course naming convention: +`examples/{dataset}_{task_name}_{model}.py` + +The dataset is a tiny synthetic stand-in to validate: +- model wiring +- training/evaluation loop +- ablation workflow + +For full LIDC experiments, replace `build_synthetic_dataset` with a real loader. +""" + +from __future__ import annotations + +import argparse +import random + +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader, split_by_visit +from pyhealth.models import RetinaUNet +from pyhealth.trainer import Trainer + + +def set_seed(seed: int) -> None: + random.seed(seed) + torch.manual_seed(seed) + + +def make_circle_image(size: int, radius: int, cx: int, cy: int) -> list[list[float]]: + image = [] + for y in range(size): + row = [] + for x in range(size): + dist = ((x - cx) ** 2 + (y - cy) ** 2) ** 0.5 + row.append(1.0 if dist <= radius else 0.0) + image.append(row) + return image + + +def build_synthetic_dataset(num_samples: int, image_size: int): + samples = [] + for idx in range(num_samples): + has_nodule = idx % 2 + radius = 4 if has_nodule else 2 + cx = 10 + (idx % 8) + cy = 12 + (idx % 10) + image = make_circle_image(image_size, radius, cx, cy) + samples.append( + { + "patient_id": f"p-{idx}", + "visit_id": f"v-{idx}", + "image": image, + "label": has_nodule, + } + ) + dataset = create_sample_dataset( + samples=samples, + input_schema={"image": "tensor"}, + output_schema={"label": "binary"}, + dataset_name="lidc_synthetic", + ) + return dataset + + +def run_ablation( + dataset, + batch_size: int, + epochs: int, + device: str, + base_channels: int, + seg_loss_weight: float, +) -> dict: + train_data, val_data, test_data = split_by_visit(dataset, [0.6, 0.2, 0.2], seed=7) + train_loader = get_dataloader(train_data, batch_size=batch_size, shuffle=True) + val_loader = get_dataloader(val_data, batch_size=batch_size, shuffle=False) + test_loader = get_dataloader(test_data, batch_size=batch_size, shuffle=False) + + model = 
RetinaUNet( + dataset=dataset, + in_channels=1, + base_channels=base_channels, + seg_loss_weight=seg_loss_weight, + ) + trainer = Trainer(model=model, device=device) + trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, epochs=epochs) + return trainer.evaluate(test_loader) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--num-samples", type=int, default=80) + parser.add_argument("--image-size", type=int, default=32) + parser.add_argument("--seed", type=int, default=7) + parser.add_argument("--device", type=str, default="cpu") + args = parser.parse_args() + + set_seed(args.seed) + dataset = build_synthetic_dataset(args.num_samples, args.image_size) + + experiments = [ + {"name": "cls_only", "base_channels": 16, "seg_loss_weight": 0.0}, + {"name": "cls_plus_seg", "base_channels": 16, "seg_loss_weight": 0.1}, + ] + + print("Running RetinaUNet ablation on synthetic LIDC-style data") + for exp in experiments: + result = run_ablation( + dataset=dataset, + batch_size=args.batch_size, + epochs=args.epochs, + device=args.device, + base_channels=exp["base_channels"], + seg_loss_weight=exp["seg_loss_weight"], + ) + print("-" * 80) + print(f"Experiment: {exp['name']}") + print(f"base_channels={exp['base_channels']}, seg_loss_weight={exp['seg_loss_weight']}") + print(f"Metrics: {result}") + + +if __name__ == "__main__": + main() diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..08ed96129 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -21,6 +21,7 @@ from .molerec import MoleRec, MoleRecLayer from .retain import MultimodalRETAIN, RETAIN, RETAINLayer from .rnn import MultimodalRNN, RNN, RNNLayer +from .retina_unet import RetinaUNet from .safedrug import SafeDrug, SafeDrugLayer from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer from 
.stagenet import StageNet, StageNetLayer diff --git a/pyhealth/models/retina_unet.py b/pyhealth/models/retina_unet.py new file mode 100644 index 000000000..5b913b69f --- /dev/null +++ b/pyhealth/models/retina_unet.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +from typing import Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pyhealth.datasets import SampleDataset +from pyhealth.models.base_model import BaseModel + + +class ConvBlock(nn.Module): + """A two-layer convolutional block used by RetinaUNet.""" + + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.block(x) + + +class RetinaUNet(BaseModel): + """Retina U-Net style model with an auxiliary segmentation branch. + + This implementation is intentionally lightweight for reproducibility + experiments in PyHealth: + + - An encoder backbone extracts image features. + - A classification head predicts image-level labels. + - A U-Net-like decoder predicts an auxiliary segmentation map. + + The final training loss is: + + ``loss = cls_loss + seg_loss_weight * seg_loss`` + + where ``seg_loss`` is computed from either a provided ``seg_target`` or a + pseudo-mask created from image intensity. + + Args: + dataset: SampleDataset used to infer feature/label keys and output size. + in_channels: Expected number of input channels. Default is 1. + base_channels: Width of the first encoder stage. Default is 32. + seg_loss_weight: Weight for auxiliary segmentation loss. Default is 0.1. + dropout: Dropout used in the classification head. Default is 0.1. 
+ """ + + def __init__( + self, + dataset: SampleDataset, + in_channels: int = 1, + base_channels: int = 32, + seg_loss_weight: float = 0.1, + dropout: float = 0.1, + ): + super().__init__(dataset=dataset) + if len(self.feature_keys) != 1: + raise ValueError("RetinaUNet supports exactly one image-like feature key.") + if len(self.label_keys) != 1: + raise ValueError("RetinaUNet supports exactly one label key.") + if in_channels <= 0: + raise ValueError("in_channels must be positive.") + if base_channels <= 0: + raise ValueError("base_channels must be positive.") + if seg_loss_weight < 0: + raise ValueError("seg_loss_weight must be non-negative.") + + self.feature_key = self.feature_keys[0] + self.label_key = self.label_keys[0] + self.in_channels = in_channels + self.base_channels = base_channels + self.seg_loss_weight = seg_loss_weight + self.dropout = dropout + + c1 = base_channels + c2 = base_channels * 2 + c3 = base_channels * 4 + c4 = base_channels * 8 + + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.enc1 = ConvBlock(in_channels, c1) + self.enc2 = ConvBlock(c1, c2) + self.enc3 = ConvBlock(c2, c3) + self.bottleneck = ConvBlock(c3, c4) + + self.up3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2) + self.dec3 = ConvBlock(c3 * 2, c3) + self.up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2) + self.dec2 = ConvBlock(c2 * 2, c2) + self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2) + self.dec1 = ConvBlock(c1 * 2, c1) + self.seg_head = nn.Conv2d(c1, 1, kernel_size=1) + + output_size = self.get_output_size() + self.classifier = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Dropout(p=dropout), + nn.Linear(c4, output_size), + ) + + @staticmethod + def _to_nchw(x: torch.Tensor) -> torch.Tensor: + """Convert input tensor to NCHW format.""" + if x.dim() == 2: + x = x.unsqueeze(0).unsqueeze(0) + elif x.dim() == 3: + # Treat as NHW by default. 
+ x = x.unsqueeze(1) + elif x.dim() != 4: + raise ValueError(f"Expected 2D/3D/4D tensor, got shape {tuple(x.shape)}.") + if x.dim() == 4 and x.shape[1] not in {1, 3} and x.shape[-1] in {1, 3}: + # NHWC -> NCHW + x = x.permute(0, 3, 1, 2).contiguous() + return x + + def _align_channels(self, x: torch.Tensor) -> torch.Tensor: + """Match input channel count to model configuration.""" + if x.shape[1] == self.in_channels: + return x + if self.in_channels == 1: + return x.mean(dim=1, keepdim=True) + if x.shape[1] == 1 and self.in_channels == 3: + return x.repeat(1, 3, 1, 1) + if x.shape[1] > self.in_channels: + return x[:, : self.in_channels] + repeats = (self.in_channels + x.shape[1] - 1) // x.shape[1] + x = x.repeat(1, repeats, 1, 1) + return x[:, : self.in_channels] + + @staticmethod + def _build_pseudo_mask(x: torch.Tensor) -> torch.Tensor: + """Create a pseudo segmentation target from image intensity.""" + intensity = x.mean(dim=1, keepdim=True) + threshold = intensity.mean(dim=(2, 3), keepdim=True) + return (intensity > threshold).float() + + @staticmethod + def _resize_like(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor: + if x.shape[-2:] != ref.shape[-2:]: + x = F.interpolate(x, size=ref.shape[-2:], mode="bilinear", align_corners=False) + return x + + def _encode_decode( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run encoder-decoder and return (class_logits, seg_logits, embedding).""" + e1 = self.enc1(x) + e2 = self.enc2(self.pool(e1)) + e3 = self.enc3(self.pool(e2)) + bottleneck = self.bottleneck(self.pool(e3)) + + class_logits = self.classifier(bottleneck) + embed = F.adaptive_avg_pool2d(bottleneck, output_size=(1, 1)).flatten(1) + + d3 = self.up3(bottleneck) + d3 = self._resize_like(d3, e3) + d3 = self.dec3(torch.cat([d3, e3], dim=1)) + + d2 = self.up2(d3) + d2 = self._resize_like(d2, e2) + d2 = self.dec2(torch.cat([d2, e2], dim=1)) + + d1 = self.up1(d2) + d1 = self._resize_like(d1, e1) + d1 = 
self.dec1(torch.cat([d1, e1], dim=1)) + + seg_logits = self.seg_head(d1) + seg_logits = self._resize_like(seg_logits, x) + return class_logits, seg_logits, embed + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward pass. + + Required inputs: + - ``feature_key`` inferred from dataset schema + - ``label_key`` during training + + Optional inputs: + - ``seg_target``: tensor of shape (B, 1, H, W) or broadcastable form + - ``embed``: if True, only returns pooled encoder embedding + """ + x = kwargs[self.feature_key] + if not isinstance(x, torch.Tensor): + x = torch.as_tensor(x) + x = x.to(self.device, dtype=torch.float32) + x = self._to_nchw(x) + x = self._align_channels(x) + + class_logits, seg_logits, embed = self._encode_decode(x) + + if kwargs.get("embed", False): + return {"embed": embed} + + results: Dict[str, torch.Tensor] = { + "logit": class_logits, + "y_prob": self.prepare_y_prob(class_logits), + "seg_logit": seg_logits, + } + + if self.label_key not in kwargs: + return results + + y_true = kwargs[self.label_key].to(self.device) + cls_loss = self.get_loss_function()(class_logits, y_true) + + seg_target = kwargs.get("seg_target") + if seg_target is None: + seg_target = self._build_pseudo_mask(x) + else: + if not isinstance(seg_target, torch.Tensor): + seg_target = torch.as_tensor(seg_target) + seg_target = seg_target.to(self.device, dtype=torch.float32) + if seg_target.dim() == 3: + seg_target = seg_target.unsqueeze(1) + if seg_target.dim() == 2: + seg_target = seg_target.unsqueeze(0).unsqueeze(0) + if seg_target.shape[1] != 1: + seg_target = seg_target.mean(dim=1, keepdim=True) + if seg_target.shape[-2:] != seg_logits.shape[-2:]: + seg_target = F.interpolate( + seg_target, + size=seg_logits.shape[-2:], + mode="nearest", + ) + + seg_loss = F.binary_cross_entropy_with_logits(seg_logits, seg_target) + loss = cls_loss + self.seg_loss_weight * seg_loss + + results.update( + { + "loss": loss, + "cls_loss": cls_loss, + "seg_loss": seg_loss, + 
"y_true": y_true, + } + ) + return results diff --git a/tests/core/test_retina_unet.py b/tests/core/test_retina_unet.py new file mode 100644 index 000000000..664949c12 --- /dev/null +++ b/tests/core/test_retina_unet.py @@ -0,0 +1,95 @@ +import unittest + +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import RetinaUNet + + +class TestRetinaUNet(unittest.TestCase): + """Unit tests for RetinaUNet with synthetic image tensors.""" + + def setUp(self): + samples = [] + for idx in range(4): + image = [ + [float((r + c + idx) % 5) / 5.0 for c in range(32)] + for r in range(32) + ] + samples.append( + { + "patient_id": f"patient-{idx}", + "visit_id": f"visit-{idx}", + "image": image, + "label": idx % 2, + } + ) + + self.dataset = create_sample_dataset( + samples=samples, + input_schema={"image": "tensor"}, + output_schema={"label": "binary"}, + dataset_name="retina_unet_toy", + ) + self.model = RetinaUNet(dataset=self.dataset, in_channels=1, base_channels=16) + + def test_initialization(self): + self.assertEqual(self.model.feature_key, "image") + self.assertEqual(self.model.label_key, "label") + self.assertEqual(self.model.in_channels, 1) + self.assertEqual(self.model.base_channels, 16) + + def test_forward_train(self): + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + output = self.model(**batch) + + self.assertIn("loss", output) + self.assertIn("y_prob", output) + self.assertIn("y_true", output) + self.assertIn("logit", output) + self.assertIn("seg_logit", output) + self.assertEqual(output["y_prob"].shape[0], 2) + self.assertEqual(output["seg_logit"].shape[0], 2) + self.assertEqual(output["seg_logit"].shape[1], 1) + self.assertEqual(output["loss"].dim(), 0) + + def test_backward(self): + loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + batch = next(iter(loader)) + output = self.model(**batch) + output["loss"].backward() 
+ + has_grad = any( + parameter.requires_grad and parameter.grad is not None + for parameter in self.model.parameters() + ) + self.assertTrue(has_grad) + + def test_embed_mode(self): + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + batch["embed"] = True + + with torch.no_grad(): + output = self.model(**batch) + + self.assertIn("embed", output) + self.assertEqual(output["embed"].shape[0], 2) + + def test_custom_seg_target(self): + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + seg_target = torch.randint(0, 2, size=(2, 1, 32, 32)).float() + batch["seg_target"] = seg_target + output = self.model(**batch) + + self.assertIn("seg_loss", output) + self.assertTrue(torch.isfinite(output["seg_loss"])) + + +if __name__ == "__main__": + unittest.main() From f4a0f3297ac1203bc3b2cb96820d626b0cce4cd9 Mon Sep 17 00:00:00 2001 From: mtuann Date: Thu, 2 Apr 2026 03:25:03 -0500 Subject: [PATCH 2/2] Refine RetinaUNet example and align PR validation --- examples/lidc_nodule_detection_retina_unet.py | 351 +++++-- pyhealth/models/retina_unet.py | 877 +++++++++++++++--- tests/core/test_retina_unet.py | 127 +-- 3 files changed, 1079 insertions(+), 276 deletions(-) diff --git a/examples/lidc_nodule_detection_retina_unet.py b/examples/lidc_nodule_detection_retina_unet.py index 7ca272240..77f342fc0 100644 --- a/examples/lidc_nodule_detection_retina_unet.py +++ b/examples/lidc_nodule_detection_retina_unet.py @@ -1,26 +1,104 @@ """Synthetic ablation example for RetinaUNet. -This script follows the course naming convention: -`examples/{dataset}_{task_name}_{model}.py` +Contributor: Tuan Nguyen +NetID: tuanmn2 +Paper: Retina U-Net: Embarrassingly Simple Exploitation of Segmentation + Supervision for Medical Object Detection +Paper link: https://proceedings.mlr.press/v116/jaeger20a/jaeger20a.pdf +Description: Lightweight synthetic ablation example for the RetinaUNet model + in PyHealth. 
-The dataset is a tiny synthetic stand-in to validate: -- model wiring -- training/evaluation loop -- ablation workflow +This example is intentionally lightweight so it can be used in a PyHealth pull +request without depending on any real dataset. It demonstrates the intended +task contract for RetinaUNet: -For full LIDC experiments, replace `build_synthetic_dataset` with a real loader. +- image tensor input +- bounding boxes + box labels for detection +- segmentation mask as auxiliary supervision + +It also doubles as a minimal ablation example by comparing multiple model +configurations on the same synthetic dataset. The comparison is intentionally +small and fast; its purpose is to show how hyperparameter variations can be +tested, not to reproduce the original Retina U-Net paper benchmark. """ from __future__ import annotations import argparse import random +from dataclasses import dataclass +from typing import Dict, List import torch +from torch.utils.data import DataLoader, Dataset -from pyhealth.datasets import create_sample_dataset, get_dataloader, split_by_visit from pyhealth.models import RetinaUNet -from pyhealth.trainer import Trainer + + +class _DummyOutputProcessor: + def size(self): + return 1 + + +class _RetinaConfigDataset: + def __init__(self): + self.input_schema = {"image": "tensor"} + self.output_schema = {"label": "binary"} + self.output_processors = {"label": _DummyOutputProcessor()} + + +class SyntheticLIDCDataset(Dataset): + """Small synthetic dataset with positive and negative slices.""" + + def __init__(self, num_samples: int, image_size: int): + self.samples: List[Dict[str, torch.Tensor]] = [] + for idx in range(num_samples): + image = torch.zeros(1, image_size, image_size) + seg_target = torch.zeros(image_size, image_size, dtype=torch.long) + + if idx % 2 == 0: + x1 = 8 + (idx % 6) + y1 = 10 + (idx % 5) + x2 = x1 + 10 + y2 = y1 + 8 + image[:, y1:y2, x1:x2] = 1.0 + class_label = 1 if (idx % 4 == 0) else 2 + seg_target[y1:y2, x1:x2] = 
class_label + boxes = torch.tensor([[x1, y1, x2, y2]], dtype=torch.float32) + labels = torch.tensor([class_label], dtype=torch.long) + else: + boxes = torch.zeros((0, 4), dtype=torch.float32) + labels = torch.zeros((0,), dtype=torch.long) + + self.samples.append( + { + "image": image, + "boxes": boxes, + "labels": labels, + "seg_target": seg_target, + } + ) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + return self.samples[index] + + +def collate_detection_batch( + batch: List[Dict[str, torch.Tensor]], +) -> Dict[str, torch.Tensor | list]: + images = torch.stack([sample["image"] for sample in batch], dim=0) + seg_target = torch.stack([sample["seg_target"] for sample in batch], dim=0) + boxes = [sample["boxes"] for sample in batch] + labels = [sample["labels"] for sample in batch] + return { + "image": images, + "boxes": boxes, + "labels": labels, + "seg_target": seg_target, + } def set_seed(seed: int) -> None: @@ -28,98 +106,209 @@ def set_seed(seed: int) -> None: torch.manual_seed(seed) -def make_circle_image(size: int, radius: int, cx: int, cy: int) -> list[list[float]]: - image = [] - for y in range(size): - row = [] - for x in range(size): - dist = ((x - cx) ** 2 + (y - cy) ** 2) ** 0.5 - row.append(1.0 if dist <= radius else 0.0) - image.append(row) - return image - - -def build_synthetic_dataset(num_samples: int, image_size: int): - samples = [] - for idx in range(num_samples): - has_nodule = idx % 2 - radius = 4 if has_nodule else 2 - cx = 10 + (idx % 8) - cy = 12 + (idx % 10) - image = make_circle_image(image_size, radius, cx, cy) - samples.append( - { - "patient_id": f"p-{idx}", - "visit_id": f"v-{idx}", - "image": image, - "label": has_nodule, - } +@dataclass(frozen=True) +class AblationConfig: + """Small configuration record for the synthetic comparison.""" + + name: str + base_channels: int + learning_rate: float + + +def build_ablation_configs() -> List[AblationConfig]: + 
"""Returns a small set of fast, PR-friendly ablation configurations.""" + return [ + AblationConfig(name="small_width", base_channels=8, learning_rate=1e-3), + AblationConfig(name="default_width", base_channels=16, learning_rate=1e-3), + AblationConfig(name="lower_lr", base_channels=16, learning_rate=5e-4), + ] + + +def run_epoch( + model: RetinaUNet, + loader: DataLoader, + optimizer: torch.optim.Optimizer, + device: torch.device, +) -> Dict[str, float]: + """Runs one epoch and returns average training losses.""" + model.train() + totals = { + "loss": 0.0, + "cls_loss": 0.0, + "bbox_loss": 0.0, + "seg_loss": 0.0, + } + num_steps = 0 + + for step, batch in enumerate(loader, start=1): + batch["image"] = batch["image"].to(device) + batch["seg_target"] = batch["seg_target"].to(device) + batch["boxes"] = [box.to(device) for box in batch["boxes"]] + batch["labels"] = [label.to(device) for label in batch["labels"]] + + optimizer.zero_grad(set_to_none=True) + output = model(**batch) + output["loss"].backward() + optimizer.step() + + num_steps += 1 + totals["loss"] += output["loss"].item() + totals["cls_loss"] += output["cls_loss"].item() + totals["bbox_loss"] += output["bbox_loss"].item() + totals["seg_loss"] += output["seg_loss"].item() + + print( + "[train] " + f"step={step} " + f"loss={output['loss'].item():.5f} " + f"cls={output['cls_loss'].item():.5f} " + f"bbox={output['bbox_loss'].item():.5f} " + f"seg={output['seg_loss'].item():.5f}" ) - dataset = create_sample_dataset( - samples=samples, - input_schema={"image": "tensor"}, - output_schema={"label": "binary"}, - dataset_name="lidc_synthetic", + + return {key: value / max(num_steps, 1) for key, value in totals.items()} + + +@torch.no_grad() +def run_eval( + model: RetinaUNet, + loader: DataLoader, + device: torch.device, +) -> Dict[str, float]: + """Runs a tiny evaluation pass and summarizes prediction volume.""" + model.eval() + batch = next(iter(loader)) + batch["image"] = batch["image"].to(device) + output = 
model(image=batch["image"]) + + num_detections = [ + int(detection["boxes"].shape[0]) for detection in output["detections"] + ] + avg_detections = sum(num_detections) / max(len(num_detections), 1) + + print(f"[eval] batch_detections={len(output['detections'])}") + print( + f"[eval] first_boxes_shape={tuple(output['detections'][0]['boxes'].shape)} " + f"first_scores_shape={tuple(output['detections'][0]['scores'].shape)}" ) - return dataset + + return { + "avg_detections_per_sample": avg_detections, + "max_detections_in_batch": float(max(num_detections, default=0)), + } def run_ablation( - dataset, - batch_size: int, - epochs: int, - device: str, - base_channels: int, - seg_loss_weight: float, -) -> dict: - train_data, val_data, test_data = split_by_visit(dataset, [0.6, 0.2, 0.2], seed=7) - train_loader = get_dataloader(train_data, batch_size=batch_size, shuffle=True) - val_loader = get_dataloader(val_data, batch_size=batch_size, shuffle=False) - test_loader = get_dataloader(test_data, batch_size=batch_size, shuffle=False) + config: AblationConfig, + dataset: SyntheticLIDCDataset, + args: argparse.Namespace, + device: torch.device, +) -> Dict[str, float | int | str]: + """Trains and evaluates one ablation configuration.""" + set_seed(args.seed) + loader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + collate_fn=collate_detection_batch, + ) + + config_dataset = _RetinaConfigDataset() model = RetinaUNet( - dataset=dataset, + dataset=config_dataset, in_channels=1, - base_channels=base_channels, - seg_loss_weight=seg_loss_weight, + num_classes=2, + base_channels=config.base_channels, + ).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate) + + print("\n" + "=" * 80) + print( + f"Running config={config.name} " + f"base_channels={config.base_channels} " + f"lr={config.learning_rate}" + ) + + train_metrics: Dict[str, float] = {} + for epoch in range(args.epochs): + print(f"Epoch {epoch + 1}") + train_metrics = 
run_epoch( + model=model, + loader=loader, + optimizer=optimizer, + device=device, + ) + + eval_metrics = run_eval(model=model, loader=loader, device=device) + + return { + "config": config.name, + "base_channels": config.base_channels, + "learning_rate": config.learning_rate, + "train_loss": train_metrics["loss"], + "cls_loss": train_metrics["cls_loss"], + "bbox_loss": train_metrics["bbox_loss"], + "seg_loss": train_metrics["seg_loss"], + "avg_detections": eval_metrics["avg_detections_per_sample"], + } + + +def print_summary_table(results: List[Dict[str, float | int | str]]) -> None: + """Prints a compact comparison table for the ablation study.""" + print("\n" + "=" * 80) + print("Synthetic ablation summary") + print( + "config base_channels lr " + "train_loss cls_loss bbox_loss seg_loss avg_dets" + ) + for result in results: + print( + f"{result['config']:<16} " + f"{int(result['base_channels']):>13} " + f"{float(result['learning_rate']):<10.5f} " + f"{float(result['train_loss']):>10.5f} " + f"{float(result['cls_loss']):>10.5f} " + f"{float(result['bbox_loss']):>11.5f} " + f"{float(result['seg_loss']):>10.5f} " + f"{float(result['avg_detections']):>9.2f}" + ) + + best_result = min(results, key=lambda item: float(item["train_loss"])) + print("\nObservation:") + print( + "- Lower train loss is better in this toy setup. " + f"The best synthetic config here is {best_result['config']}." + ) + print( + "- This comparison is only a lightweight PR-facing ablation and should " + "not be interpreted as a paper benchmark." 
) - trainer = Trainer(model=model, device=device) - trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, epochs=epochs) - return trainer.evaluate(test_loader) def main(): parser = argparse.ArgumentParser() parser.add_argument("--epochs", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=8) - parser.add_argument("--num-samples", type=int, default=80) - parser.add_argument("--image-size", type=int, default=32) + parser.add_argument("--batch-size", type=int, default=4) + parser.add_argument("--num-samples", type=int, default=20) + parser.add_argument("--image-size", type=int, default=64) parser.add_argument("--seed", type=int, default=7) parser.add_argument("--device", type=str, default="cpu") args = parser.parse_args() - set_seed(args.seed) - dataset = build_synthetic_dataset(args.num_samples, args.image_size) + device = torch.device(args.device) + dataset = SyntheticLIDCDataset( + num_samples=args.num_samples, + image_size=args.image_size, + ) + configs = build_ablation_configs() - experiments = [ - {"name": "cls_only", "base_channels": 16, "seg_loss_weight": 0.0}, - {"name": "cls_plus_seg", "base_channels": 16, "seg_loss_weight": 0.1}, + print("Running synthetic RetinaUNet ablation example") + results = [ + run_ablation(config=config, dataset=dataset, args=args, device=device) + for config in configs ] - - print("Running RetinaUNet ablation on synthetic LIDC-style data") - for exp in experiments: - result = run_ablation( - dataset=dataset, - batch_size=args.batch_size, - epochs=args.epochs, - device=args.device, - base_channels=exp["base_channels"], - seg_loss_weight=exp["seg_loss_weight"], - ) - print("-" * 80) - print(f"Experiment: {exp['name']}") - print(f"base_channels={exp['base_channels']}, seg_loss_weight={exp['seg_loss_weight']}") - print(f"Metrics: {result}") + print_summary_table(results) if __name__ == "__main__": diff --git a/pyhealth/models/retina_unet.py b/pyhealth/models/retina_unet.py index 
5b913b69f..a3d748cb1 100644 --- a/pyhealth/models/retina_unet.py +++ b/pyhealth/models/retina_unet.py @@ -1,6 +1,18 @@ +"""RetinaUNet model for PyHealth. + +Contributor: Tuan Nguyen +NetID: tuanmn2 +Paper: Retina U-Net: Embarrassingly Simple Exploitation of Segmentation + Supervision for Medical Object Detection +Paper link: https://proceedings.mlr.press/v116/jaeger20a/jaeger20a.pdf +Description: Retina U-Net style medical object detection model with an + auxiliary segmentation branch for PyHealth. +""" + from __future__ import annotations -from typing import Dict +import math +from typing import Any, Dict, List, Sequence import torch import torch.nn as nn @@ -10,16 +22,173 @@ from pyhealth.models.base_model import BaseModel -class ConvBlock(nn.Module): - """A two-layer convolutional block used by RetinaUNet.""" +def _box_area(boxes: torch.Tensor) -> torch.Tensor: + return (boxes[:, 2] - boxes[:, 0]).clamp(min=0) * ( + boxes[:, 3] - boxes[:, 1] + ).clamp(min=0) + + +def _box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor: + if boxes1.numel() == 0 or boxes2.numel() == 0: + return boxes1.new_zeros((boxes1.shape[0], boxes2.shape[0])) + + area1 = _box_area(boxes1) + area2 = _box_area(boxes2) + + top_left = torch.maximum(boxes1[:, None, :2], boxes2[:, :2]) + bottom_right = torch.minimum(boxes1[:, None, 2:], boxes2[:, 2:]) + wh = (bottom_right - top_left).clamp(min=0) + inter = wh[..., 0] * wh[..., 1] + union = area1[:, None] + area2 - inter + return inter / union.clamp(min=1e-6) + + +def _nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor: + if boxes.numel() == 0: + return boxes.new_zeros((0,), dtype=torch.long) + + order = scores.argsort(descending=True) + keep: List[torch.Tensor] = [] + + while order.numel() > 0: + current = order[0] + keep.append(current) + if order.numel() == 1: + break + remaining = order[1:] + ious = _box_iou(boxes[current].unsqueeze(0), boxes[remaining]).squeeze(0) + order = remaining[ious <= 
def _weighted_box_clustering(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    iou_threshold: float,
    expected_num_predictions: float = 1.0,
    min_score: float = 0.01,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Weighted box clustering (WBC): merge overlapping boxes per cluster.

    Boxes are grouped greedily around the highest-scoring remaining box;
    each cluster is collapsed into a single box/score pair weighted by
    IoU-with-seed times box area. ``expected_num_predictions`` down-weights
    clusters with fewer members than expected by padding the denominator
    with the mean weight.

    Args:
        boxes: (N, 4) boxes in (x1, y1, x2, y2) form.
        scores: (N,) confidence scores.
        iou_threshold: IoU above which boxes join the seed's cluster.
        expected_num_predictions: expected cluster size used for the
            missing-member penalty.
        min_score: clusters whose merged score falls at or below this
            value are dropped.

    Returns:
        Tuple of (merged_scores, merged_boxes) with matching first dims.
    """
    if boxes.numel() == 0:
        return boxes.new_zeros((0,)), boxes.new_zeros((0, 4))

    # Process boxes in descending score order; highest score seeds a cluster.
    remaining = scores.argsort(descending=True)
    box_areas = _box_area(boxes).clamp(min=1e-6)
    merged_scores: List[torch.Tensor] = []
    merged_boxes: List[torch.Tensor] = []

    while remaining.numel() > 0:
        seed = remaining[0]
        overlap = _box_iou(boxes[seed].unsqueeze(0), boxes[remaining]).squeeze(0)
        in_cluster = overlap > iou_threshold
        members = remaining[in_cluster]

        member_boxes = boxes[members]
        member_scores = scores[members]
        # Weight each member by its overlap with the seed and its area.
        weights = overlap[in_cluster].clamp(min=1e-6) * box_areas[members]
        weighted = member_scores * weights

        wanted = max(float(expected_num_predictions), 1.0)
        shortfall = max(0.0, wanted - float(members.numel()))
        denom = weights.sum() + shortfall * weights.mean()
        if float(denom) <= 0.0:
            cluster_score = member_scores.mean()
        else:
            cluster_score = weighted.sum() / denom

        coord_denom = weighted.sum().clamp(min=1e-6)
        cluster_box = (member_boxes * weighted[:, None]).sum(dim=0) / coord_denom

        if float(cluster_score) > min_score:
            merged_scores.append(cluster_score)
            merged_boxes.append(cluster_box)

        remaining = remaining[~in_cluster]

    if not merged_scores:
        return boxes.new_zeros((0,)), boxes.new_zeros((0, 4))
    return torch.stack(merged_scores), torch.stack(merged_boxes)


def _encode_boxes(anchors: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor:
    """Encode target ``boxes`` as (dx, dy, dw, dh) offsets w.r.t. ``anchors``.

    Standard R-CNN style parameterization: center offsets are normalized by
    anchor size; width/height are encoded in log-space.
    """
    a_size = (anchors[:, 2:] - anchors[:, :2]).clamp(min=1e-6)
    a_center = anchors[:, :2] + 0.5 * a_size
    b_size = (boxes[:, 2:] - boxes[:, :2]).clamp(min=1e-6)
    b_center = boxes[:, :2] + 0.5 * b_size

    center_delta = (b_center - a_center) / a_size
    size_delta = torch.log(b_size / a_size)
    return torch.cat([center_delta, size_delta], dim=1)


def _decode_boxes(anchors: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
    """Invert :func:`_encode_boxes`: apply regression deltas to anchors."""
    a_size = (anchors[:, 2:] - anchors[:, :2]).clamp(min=1e-6)
    a_center = anchors[:, :2] + 0.5 * a_size

    center = deltas[:, :2] * a_size + a_center
    size = deltas[:, 2:].exp() * a_size
    return torch.cat([center - 0.5 * size, center + 0.5 * size], dim=1)


def _clip_boxes(boxes: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Clamp (x1, y1, x2, y2) boxes to the image extent; returns a copy."""
    clipped = boxes.clone()
    # x-coordinates clamp to width, y-coordinates to height.
    for column, limit in ((0, width), (1, height), (2, width), (3, height)):
        clipped[:, column] = clipped[:, column].clamp(min=0, max=limit)
    return clipped


def _multiclass_dice_loss(
    logits: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
) -> torch.Tensor:
    """Soft Dice loss averaged over foreground classes (class 0 = background).

    Args:
        logits: (B, C, H, W) raw segmentation logits.
        target: (B, H, W) integer class map; values are clamped into range.
        num_classes: number of channels C, including background.
    """
    probs = F.softmax(logits, dim=1)
    target_one_hot = F.one_hot(
        target.long().clamp(min=0, max=num_classes - 1),
        num_classes=num_classes,
    ).permute(0, 3, 1, 2).to(dtype=probs.dtype)

    per_class: List[torch.Tensor] = []
    # Skip the background channel (index 0) on purpose.
    for class_index in range(1, num_classes):
        class_probs = probs[:, class_index]
        class_mask = target_one_hot[:, class_index]
        intersection = (class_probs * class_mask).sum(dim=(-2, -1))
        union = class_probs.sum(dim=(-2, -1)) + class_mask.sum(dim=(-2, -1))
        dice = (2 * intersection + 1e-6) / (union + 1e-6)
        per_class.append(1 - dice.mean())

    if not per_class:
        return logits.new_zeros(())
    return torch.stack(per_class).mean()


def _multiclass_cross_entropy_loss(
    logits: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """Mean cross-entropy over all pixels with out-of-range targets clamped.

    Equivalent to gathering log-softmax values at the target indices and
    negating the mean; expressed via ``F.nll_loss`` for brevity.
    """
    safe_target = target.long().clamp(min=0, max=logits.shape[1] - 1)
    return F.nll_loss(F.log_softmax(logits, dim=1), safe_target)
class ConvBlock(nn.Module):
    """Two 3x3 conv -> BatchNorm -> ReLU stages at constant resolution.

    Convolutions are bias-free because each is immediately followed by a
    BatchNorm layer, which supplies the affine shift.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        stages: List[nn.Module] = []
        channels = in_channels
        for _ in range(2):
            stages.append(
                nn.Conv2d(channels, out_channels, kernel_size=3, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
            channels = out_channels
        self.block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class DownBlock(nn.Module):
    """Halve spatial resolution with 2x2 max pooling, then apply a ConvBlock."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.block = ConvBlock(in_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pooled = self.pool(x)
        return self.block(pooled)


class RetinaHead(nn.Module):
    """RetinaNet-style subnet: four 3x3 conv+ReLU stages plus a prediction conv.

    The final convolution carries no activation; callers interpret its output
    as raw logits or box deltas.
    """

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int):
        super().__init__()
        modules: List[nn.Module] = []
        width = in_channels
        for _ in range(4):
            modules.append(nn.Conv2d(width, hidden_channels, kernel_size=3, padding=1))
            modules.append(nn.ReLU(inplace=True))
            width = hidden_channels
        modules.append(
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1)
        )
        self.head = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(x)
class RetinaUNet(BaseModel):
    """Detection-oriented Retina U-Net with auxiliary segmentation supervision.

    This follows the core idea of the original Retina U-Net paper and the
    MIC-DKFZ reference implementation:

    - multi-scale detection heads on coarse pyramid levels
    - a U-FPN decoder reaching full image resolution
    - segmentation loss used as auxiliary supervision

    The current implementation is 2D-only because the LIDC preprocessing in
    this project produces 2.5D slice stacks for 2D detectors.

    Args:
        dataset: SampleDataset used to infer the single image feature key
            (and, if present, a label key).
        in_channels: Expected number of input image channels.
        num_classes: Number of foreground classes; the heads predict
            ``num_classes + 1`` classes (index 0 is background).
        base_channels: Width of the stem; deeper stages scale by powers of 2.
        anchor_sizes: Base anchor size per pyramid level (one per level).
        anchor_scales: Per-level scale multipliers applied to the base size.
        asp_ratios / aspect_ratios: Anchor height/width ratios.
        positive_iou_threshold: IoU at/above which an anchor is positive.
        negative_iou_threshold: IoU below which an anchor is background.
        negative_to_positive_ratio: Hard-negative mining ratio.
        seg_loss_weight / bbox_loss_weight / cls_loss_weight: Loss weights.
        score_threshold: Minimum score for a detection to survive.
        nms_threshold: IoU threshold for NMS/WBC merging.
            NOTE(review): the 1e-5 default merges nearly everything per
            class, which matches the reference WBC consolidation setting —
            confirm this is intended for plain NMS mode too.
        postprocess_method: ``"wbc"`` (weighted box clustering) or ``"nms"``.
        max_detections: Cap on detections kept per image.
    """

    def __init__(
        self,
        dataset: SampleDataset,
        in_channels: int = 1,
        num_classes: int = 1,
        base_channels: int = 32,
        anchor_sizes: Sequence[float] = (8.0, 16.0, 32.0, 64.0),
        anchor_scales: Sequence[float] = (1.0, 2 ** (1 / 3), 2 ** (2 / 3)),
        aspect_ratios: Sequence[float] = (0.5, 1.0, 2.0),
        positive_iou_threshold: float = 0.5,
        negative_iou_threshold: float = 0.1,
        negative_to_positive_ratio: float = 1.0,
        seg_loss_weight: float = 1.0,
        bbox_loss_weight: float = 1.0,
        cls_loss_weight: float = 1.0,
        score_threshold: float = 0.1,
        nms_threshold: float = 1e-5,
        postprocess_method: str = "wbc",
        max_detections: int = 100,
    ):
        super().__init__(dataset=dataset)
        # Validate configuration early so misconfiguration fails loudly.
        if len(self.feature_keys) != 1:
            raise ValueError("RetinaUNet supports exactly one image-like feature key.")
        if in_channels <= 0:
            raise ValueError("in_channels must be positive.")
        if num_classes <= 0:
            raise ValueError("num_classes must be positive.")
        if base_channels <= 0:
            raise ValueError("base_channels must be positive.")
        if positive_iou_threshold <= negative_iou_threshold:
            raise ValueError("positive_iou_threshold must be greater than negative_iou_threshold.")
        if postprocess_method not in {"nms", "wbc"}:
            raise ValueError("postprocess_method must be one of {'nms', 'wbc'}.")

        self.feature_key = self.feature_keys[0]
        # A label key is optional for detection-style supervision; fall back
        # to a conventional name when the dataset declares none.
        self.label_key = self.label_keys[0] if self.label_keys else "label"
        self.in_channels = in_channels
        self.num_classes = num_classes
        # Heads predict foreground classes plus one background class.
        self.num_head_classes = num_classes + 1
        self.base_channels = base_channels
        self.anchor_sizes = tuple(anchor_sizes)
        self.anchor_scales = tuple(anchor_scales)
        self.aspect_ratios = tuple(aspect_ratios)
        self.positive_iou_threshold = positive_iou_threshold
        self.negative_iou_threshold = negative_iou_threshold
        self.negative_to_positive_ratio = negative_to_positive_ratio
        self.seg_loss_weight = seg_loss_weight
        self.bbox_loss_weight = bbox_loss_weight
        self.cls_loss_weight = cls_loss_weight
        self.score_threshold = score_threshold
        self.nms_threshold = nms_threshold
        self.postprocess_method = postprocess_method
        self.max_detections = max_detections
        # One anchor per (scale, ratio) pair at every feature-map location.
        self.num_anchors = len(self.anchor_scales) * len(self.aspect_ratios)
        # Strides of the detection levels [p2, p3, p4, p5] w.r.t. the input.
        self.pyramid_strides = (4, 8, 16, 32)

        c1 = base_channels
        c2 = base_channels * 2
        c3 = base_channels * 4
        c4 = base_channels * 8
        c5 = base_channels * 16
        c6 = base_channels * 16

        # Encoder: stem at full resolution, then five 2x downsampling stages.
        self.stem = ConvBlock(in_channels, c1)
        self.enc1 = DownBlock(c1, c2)
        self.enc2 = DownBlock(c2, c3)
        self.enc3 = DownBlock(c3, c4)
        self.enc4 = DownBlock(c4, c5)
        self.bottleneck = DownBlock(c5, c6)

        # Lateral 1x1 convs project every encoder stage to a common width (c5).
        self.lat5 = nn.Conv2d(c6, c5, kernel_size=1)
        self.lat4 = nn.Conv2d(c5, c5, kernel_size=1)
        self.lat3 = nn.Conv2d(c4, c5, kernel_size=1)
        self.lat2 = nn.Conv2d(c3, c5, kernel_size=1)
        self.lat1 = nn.Conv2d(c2, c5, kernel_size=1)
        self.lat0 = nn.Conv2d(c1, c5, kernel_size=1)

        # 3x3 output convs smooth the merged top-down features. Note: there is
        # no out1 — the p1 level is only an intermediate step toward p0.
        self.out5 = nn.Conv2d(c5, c5, kernel_size=3, padding=1)
        self.out4 = nn.Conv2d(c5, c5, kernel_size=3, padding=1)
        self.out3 = nn.Conv2d(c5, c5, kernel_size=3, padding=1)
        self.out2 = nn.Conv2d(c5, c5, kernel_size=3, padding=1)
        self.out0 = nn.Conv2d(c5, c5, kernel_size=3, padding=1)

        # Shared (weight-tied across levels) classification and box subnets.
        self.cls_head = RetinaHead(
            in_channels=c5,
            hidden_channels=c5,
            out_channels=self.num_anchors * self.num_head_classes,
        )
        self.box_head = RetinaHead(
            in_channels=c5,
            hidden_channels=c5,
            out_channels=self.num_anchors * 4,
        )
        # Full-resolution segmentation head on p0 for auxiliary supervision.
        self.seg_head = nn.Conv2d(c5, self.num_head_classes, kernel_size=1)

    @staticmethod
    def _to_nchw(x: torch.Tensor) -> torch.Tensor:
        """Coerce 2D/3D/4D inputs into NCHW layout.

        2D tensors are treated as a single-channel single image; 3D tensors
        as NHW. A 4D tensor whose channel axis looks like NHWC (last dim is
        1 or 3 while dim 1 is not) is permuted to NCHW.
        """
        if x.dim() == 2:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 3:
            x = x.unsqueeze(1)
        elif x.dim() != 4:
            raise ValueError(f"Expected 2D/3D/4D tensor, got shape {tuple(x.shape)}.")
        if x.dim() == 4 and x.shape[1] not in {1, 3} and x.shape[-1] in {1, 3}:
            x = x.permute(0, 3, 1, 2).contiguous()
        return x

    def _align_channels(self, x: torch.Tensor) -> torch.Tensor:
        """Best-effort adaptation of the input channel count to ``in_channels``.

        Averages down to one channel, tiles a single channel up, or truncates
        / repeats-and-truncates for other mismatches.
        """
        if x.shape[1] == self.in_channels:
            return x
        if self.in_channels == 1:
            return x.mean(dim=1, keepdim=True)
        if x.shape[1] == 1 and self.in_channels > 1:
            return x.repeat(1, self.in_channels, 1, 1)
        if x.shape[1] > self.in_channels:
            return x[:, : self.in_channels]
        repeats = (self.in_channels + x.shape[1] - 1) // x.shape[1]
        x = x.repeat(1, repeats, 1, 1)
        return x[:, : self.in_channels]

    @staticmethod
    def _resize_like(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize ``x`` to match ``ref``'s spatial size if needed."""
        if x.shape[-2:] != ref.shape[-2:]:
            x = F.interpolate(x, size=ref.shape[-2:], mode="bilinear", align_corners=False)
        return x

    def _build_pyramid(
        self, x: torch.Tensor
    ) -> tuple[list[torch.Tensor], torch.Tensor, torch.Tensor]:
        """Run encoder + top-down FPN decoder.

        Returns:
            Tuple of (detection levels [p2, p3, p4, p5], full-resolution
            segmentation logits from p0, pooled bottleneck embedding).
        """
        c0 = self.stem(x)
        c1 = self.enc1(c0)
        c2 = self.enc2(c1)
        c3 = self.enc3(c2)
        c4 = self.enc4(c3)
        c5 = self.bottleneck(c4)

        # Top-down pathway: lateral projection + upsampled coarser level.
        p5_pre = self.lat5(c5)
        p4_pre = self.lat4(c4) + F.interpolate(
            p5_pre, size=c4.shape[-2:], mode="bilinear", align_corners=False
        )
        p3_pre = self.lat3(c3) + F.interpolate(
            p4_pre, size=c3.shape[-2:], mode="bilinear", align_corners=False
        )
        p2_pre = self.lat2(c2) + F.interpolate(
            p3_pre, size=c2.shape[-2:], mode="bilinear", align_corners=False
        )
        p1_pre = self.lat1(c1) + F.interpolate(
            p2_pre, size=c1.shape[-2:], mode="bilinear", align_corners=False
        )
        p0_pre = self.lat0(c0) + F.interpolate(
            p1_pre, size=c0.shape[-2:], mode="bilinear", align_corners=False
        )

        p5 = self.out5(p5_pre)
        p4 = self.out4(p4_pre)
        p3 = self.out3(p3_pre)
        p2 = self.out2(p2_pre)
        p0 = self.out0(p0_pre)
        seg_logit = self.seg_head(p0)
        embed = F.adaptive_avg_pool2d(c5, output_size=(1, 1)).flatten(1)
        return [p2, p3, p4, p5], seg_logit, embed

    def _reshape_cls_output(self, x: torch.Tensor) -> torch.Tensor:
        """(B, A*C, H, W) head output -> (B, H*W*A, C) per-anchor logits."""
        b, _, h, w = x.shape
        x = x.view(b, self.num_anchors, self.num_head_classes, h, w)
        x = x.permute(0, 3, 4, 1, 2).contiguous()
        return x.view(b, -1, self.num_head_classes)

    def _reshape_box_output(self, x: torch.Tensor) -> torch.Tensor:
        """(B, A*4, H, W) head output -> (B, H*W*A, 4) per-anchor deltas."""
        b, _, h, w = x.shape
        x = x.view(b, self.num_anchors, 4, h, w)
        x = x.permute(0, 3, 4, 1, 2).contiguous()
        return x.view(b, -1, 4)

    def _generate_level_anchors(
        self,
        feature_shape: tuple[int, int],
        stride: int,
        base_size: float,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Generate (H*W*A, 4) anchors for one pyramid level.

        Anchor centers sit at the middle of each feature cell in input-image
        coordinates; shapes come from ``anchor_scales`` x ``aspect_ratios``.
        Ordering matches :meth:`_reshape_cls_output` (location-major, then
        anchor index).
        """
        height, width = feature_shape
        shifts_y = (torch.arange(height, device=device, dtype=dtype) + 0.5) * stride
        shifts_x = (torch.arange(width, device=device, dtype=dtype) + 0.5) * stride
        grid_y, grid_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
        centers = torch.stack([grid_x, grid_y], dim=-1).reshape(-1, 2)

        anchor_shapes = []
        for scale in self.anchor_scales:
            scaled = base_size * scale
            for ratio in self.aspect_ratios:
                # ratio is height/width so area stays ~scaled**2.
                width_val = scaled * math.sqrt(1.0 / ratio)
                height_val = scaled * math.sqrt(ratio)
                anchor_shapes.append([width_val, height_val])

        wh = torch.tensor(anchor_shapes, device=device, dtype=dtype)
        centers = centers[:, None, :].expand(-1, wh.shape[0], -1)
        wh = wh[None, :, :].expand(centers.shape[0], -1, -1)

        top_left = centers - 0.5 * wh
        bottom_right = centers + 0.5 * wh
        return torch.cat([top_left, bottom_right], dim=-1).reshape(-1, 4)

    def _generate_anchors(
        self, features: Sequence[torch.Tensor], image_shape: tuple[int, int]
    ) -> torch.Tensor:
        """Concatenate anchors for all pyramid levels (image_shape unused)."""
        _ = image_shape
        anchors = [
            self._generate_level_anchors(
                feature_shape=(feature.shape[-2], feature.shape[-1]),
                stride=stride,
                base_size=base_size,
                device=feature.device,
                dtype=feature.dtype,
            )
            for feature, stride, base_size in zip(features, self.pyramid_strides, self.anchor_sizes)
        ]
        return torch.cat(anchors, dim=0)

    def _normalize_box_targets(
        self,
        boxes: Sequence[torch.Tensor] | torch.Tensor | None,
        labels: Sequence[torch.Tensor] | torch.Tensor | None,
        batch_size: int,
        device: torch.device,
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Normalize box/label targets to per-image tensor lists on ``device``.

        ``boxes=None`` yields empty targets (pure background). ``labels=None``
        defaults every box to class 1.
        """
        if boxes is None:
            empty_boxes = [torch.zeros((0, 4), device=device) for _ in range(batch_size)]
            empty_labels = [torch.zeros((0,), dtype=torch.long, device=device) for _ in range(batch_size)]
            return empty_boxes, empty_labels

        if isinstance(boxes, torch.Tensor):
            box_list = [boxes[i].to(device=device, dtype=torch.float32) for i in range(boxes.shape[0])]
        else:
            box_list = [box.to(device=device, dtype=torch.float32) for box in boxes]

        if labels is None:
            label_list = [
                torch.ones((box.shape[0],), dtype=torch.long, device=device) for box in box_list
            ]
        elif isinstance(labels, torch.Tensor):
            label_list = [labels[i].to(device=device, dtype=torch.long) for i in range(labels.shape[0])]
        else:
            label_list = [label.to(device=device, dtype=torch.long) for label in labels]

        if len(box_list) != batch_size or len(label_list) != batch_size:
            raise ValueError("boxes and labels must provide one entry per batch element.")
        return box_list, label_list

    def _match_anchors(
        self, anchors: torch.Tensor, gt_boxes: torch.Tensor, gt_labels: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Assign a class target and regression target to every anchor.

        Targets: -1 = ignore, 0 = background, >0 = foreground class. Anchors
        with IoU below the negative threshold are background; at/above the
        positive threshold they take the best-matching GT's label and encoded
        box. Each GT's single best anchor is then force-assigned so no GT is
        left unmatched.
        """
        cls_targets = torch.full(
            (anchors.shape[0],), -1, dtype=torch.long, device=anchors.device
        )
        box_targets = torch.zeros((anchors.shape[0], 4), dtype=anchors.dtype, device=anchors.device)

        if gt_boxes.numel() == 0:
            # No objects: everything is background.
            cls_targets.fill_(0)
            return cls_targets, box_targets

        ious = _box_iou(anchors, gt_boxes)
        max_iou, matched_gt = ious.max(dim=1)

        cls_targets[max_iou < self.negative_iou_threshold] = 0
        positive = max_iou >= self.positive_iou_threshold
        cls_targets[positive] = gt_labels[matched_gt[positive]]
        box_targets[positive] = _encode_boxes(anchors[positive], gt_boxes[matched_gt[positive]])

        # Guarantee at least one positive anchor per ground-truth box.
        best_anchor_per_gt = ious.argmax(dim=0)
        gt_indices = torch.arange(gt_boxes.shape[0], device=anchors.device)
        cls_targets[best_anchor_per_gt] = gt_labels[gt_indices]
        box_targets[best_anchor_per_gt] = _encode_boxes(
            anchors[best_anchor_per_gt], gt_boxes[gt_indices]
        )
        return cls_targets, box_targets

    def _compute_cls_loss(
        self, cls_logits: torch.Tensor, cls_targets: torch.Tensor
    ) -> tuple[torch.Tensor, int, int]:
        """Cross-entropy over positives plus hard-mined negatives.

        Negatives are mined by keeping the ``negative_to_positive_ratio``
        highest-loss background anchors. Ignored anchors (-1) contribute
        nothing. Returns (loss, #positives, #negatives).
        """
        pos_mask = cls_targets > 0
        neg_mask = cls_targets == 0

        zero = cls_logits.new_zeros(())
        pos_loss = zero
        neg_loss = zero

        pos_count = int(pos_mask.sum().item())
        neg_count = int(neg_mask.sum().item())

        if pos_count > 0:
            pos_loss = F.cross_entropy(cls_logits[pos_mask], cls_targets[pos_mask])

        if neg_count > 0:
            neg_losses = F.cross_entropy(
                cls_logits[neg_mask],
                cls_targets[neg_mask],
                reduction="none",
            )
            # Even with zero positives, keep at least one hard negative.
            keep_neg = max(1, int(max(pos_count, 1) * self.negative_to_positive_ratio))
            keep_neg = min(keep_neg, neg_losses.shape[0])
            neg_loss = neg_losses.topk(keep_neg).values.mean()

        return (pos_loss + neg_loss) / 2, pos_count, neg_count

    def _compute_training_losses(
        self,
        cls_logits: torch.Tensor,
        box_deltas: torch.Tensor,
        anchors: torch.Tensor,
        boxes: Sequence[torch.Tensor],
        labels: Sequence[torch.Tensor],
        seg_logit: torch.Tensor,
        seg_target: torch.Tensor | None,
    ) -> Dict[str, torch.Tensor]:
        """Combine per-image detection losses with the auxiliary seg loss.

        The total is ``cls_w * cls + bbox_w * bbox + seg_w * seg`` where the
        seg term is the mean of cross-entropy and Dice. All components are
        returned individually for logging.
        """
        cls_loss = cls_logits.new_zeros(())
        bbox_loss = cls_logits.new_zeros(())
        total_pos = 0
        total_neg = 0

        for image_ix in range(cls_logits.shape[0]):
            cls_targets, box_targets = self._match_anchors(
                anchors=anchors,
                gt_boxes=boxes[image_ix],
                gt_labels=labels[image_ix],
            )
            image_cls_loss, pos_count, neg_count = self._compute_cls_loss(
                cls_logits=cls_logits[image_ix], cls_targets=cls_targets
            )
            cls_loss = cls_loss + image_cls_loss
            total_pos += pos_count
            total_neg += neg_count

            # Box regression is only supervised on positive anchors.
            pos_mask = cls_targets > 0
            if pos_mask.any():
                bbox_loss = bbox_loss + F.smooth_l1_loss(
                    box_deltas[image_ix][pos_mask],
                    box_targets[pos_mask],
                )

        cls_loss = cls_loss / max(cls_logits.shape[0], 1)
        bbox_loss = bbox_loss / max(cls_logits.shape[0], 1)

        if seg_target is None:
            seg_loss = cls_logits.new_zeros(())
            seg_ce = cls_logits.new_zeros(())
            seg_dice = cls_logits.new_zeros(())
        else:
            # Normalize seg_target to a (B, H, W) long class map.
            if not isinstance(seg_target, torch.Tensor):
                seg_target = torch.as_tensor(seg_target)
            seg_target = seg_target.to(self.device, dtype=torch.long)
            if seg_target.dim() == 4 and seg_target.shape[1] == 1:
                seg_target = seg_target[:, 0]
            elif seg_target.dim() == 4 and seg_target.shape[1] > 1:
                # One-hot / multi-channel masks collapse to argmax labels.
                seg_target = seg_target.argmax(dim=1)
            if seg_target.dim() == 2:
                seg_target = seg_target.unsqueeze(0)
            if seg_target.shape[-2:] != seg_logit.shape[-2:]:
                # Nearest-neighbor keeps the mask's integer class values.
                seg_target = F.interpolate(
                    seg_target.unsqueeze(1).to(dtype=torch.float32),
                    size=seg_logit.shape[-2:],
                    mode="nearest",
                ).squeeze(1).to(dtype=torch.long)
            seg_ce = _multiclass_cross_entropy_loss(seg_logit, seg_target)
            seg_dice = _multiclass_dice_loss(
                seg_logit,
                seg_target,
                num_classes=self.num_head_classes,
            )
            seg_loss = 0.5 * (seg_ce + seg_dice)

        total_loss = (
            self.cls_loss_weight * cls_loss
            + self.bbox_loss_weight * bbox_loss
            + self.seg_loss_weight * seg_loss
        )
        return {
            "loss": total_loss,
            "cls_loss": cls_loss,
            "bbox_loss": bbox_loss,
            "seg_loss": seg_loss,
            "seg_ce_loss": seg_ce,
            # Kept as an alias of the CE term, presumably for backward
            # compatibility with earlier logging keys — confirm.
            "seg_bce_loss": seg_ce,
            "seg_dice_loss": seg_dice,
            "positive_anchors": cls_logits.new_tensor(float(total_pos)),
            "negative_anchors": cls_logits.new_tensor(float(total_neg)),
        }

    def _decode_detections(
        self,
        cls_logits: torch.Tensor,
        box_deltas: torch.Tensor,
        anchors: torch.Tensor,
        image_shape: tuple[int, int],
    ) -> list[dict[str, torch.Tensor]]:
        """Turn raw head outputs into per-image detection dicts.

        For each image: take the best foreground class per anchor, drop
        scores at/below ``score_threshold``, decode and clip boxes, then
        consolidate through :meth:`merge_detections`.
        """
        probs = F.softmax(cls_logits, dim=-1)
        height, width = image_shape
        detections: list[dict[str, torch.Tensor]] = []

        for image_ix in range(cls_logits.shape[0]):
            # Column 0 is background; foreground classes start at 1.
            class_scores = probs[image_ix][:, 1:]
            if class_scores.numel() == 0:
                detections.append(
                    {
                        "boxes": anchors.new_zeros((0, 4)),
                        "scores": anchors.new_zeros((0,)),
                        "labels": torch.zeros((0,), dtype=torch.long, device=anchors.device),
                    }
                )
                continue

            scores, labels = class_scores.max(dim=1)
            labels = labels + 1
            keep = scores > self.score_threshold
            if keep.sum() == 0:
                detections.append(
                    {
                        "boxes": anchors.new_zeros((0, 4)),
                        "scores": anchors.new_zeros((0,)),
                        "labels": torch.zeros((0,), dtype=torch.long, device=anchors.device),
                    }
                )
                continue

            decoded = _decode_boxes(anchors[keep], box_deltas[image_ix][keep])
            decoded = _clip_boxes(decoded, height=height, width=width)
            kept_scores = scores[keep]
            kept_labels = labels[keep]

            detections.append(
                self.merge_detections(
                    detections=[
                        {
                            "boxes": decoded,
                            "scores": kept_scores,
                            "labels": kept_labels,
                        }
                    ],
                    expected_num_predictions=1.0,
                )
            )

        return detections

    def merge_detections(
        self,
        detections: Sequence[dict[str, torch.Tensor]],
        *,
        expected_num_predictions: float = 1.0,
    ) -> dict[str, torch.Tensor]:
        """Consolidate one or more detection dicts into a single result.

        Per class, candidates are merged via weighted box clustering or NMS
        (``postprocess_method``), re-thresholded on score, sorted descending
        and truncated to ``max_detections``. An empty input yields empty
        tensors on the model device.
        """
        if not detections:
            device = self.device if isinstance(self.device, torch.device) else torch.device(self.device)
            return {
                "boxes": torch.zeros((0, 4), dtype=torch.float32, device=device),
                "scores": torch.zeros((0,), dtype=torch.float32, device=device),
                "labels": torch.zeros((0,), dtype=torch.long, device=device),
            }

        boxes_chunks = [item["boxes"] for item in detections if item["boxes"].numel() > 0]
        if not boxes_chunks:
            reference = detections[0]["boxes"]
            return {
                "boxes": reference.new_zeros((0, 4)),
                "scores": reference.new_zeros((0,)),
                "labels": torch.zeros((0,), dtype=torch.long, device=reference.device),
            }

        boxes = torch.cat(boxes_chunks, dim=0)
        scores = torch.cat([item["scores"] for item in detections if item["boxes"].numel() > 0], dim=0)
        labels = torch.cat([item["labels"] for item in detections if item["boxes"].numel() > 0], dim=0)

        final_boxes: List[torch.Tensor] = []
        final_scores: List[torch.Tensor] = []
        final_labels: List[torch.Tensor] = []
        # Merge independently per class so different classes never suppress
        # each other.
        for class_id in torch.unique(labels):
            class_mask = labels == class_id
            class_boxes = boxes[class_mask]
            class_scores = scores[class_mask]
            if self.postprocess_method == "wbc":
                merged_scores, merged_boxes = _weighted_box_clustering(
                    class_boxes,
                    class_scores,
                    iou_threshold=self.nms_threshold,
                    expected_num_predictions=expected_num_predictions,
                )
                class_labels = torch.full(
                    (merged_scores.shape[0],),
                    int(class_id.item()),
                    dtype=torch.long,
                    device=labels.device,
                )
                final_boxes.append(merged_boxes)
                final_scores.append(merged_scores)
                final_labels.append(class_labels)
            else:
                class_keep = _nms(
                    class_boxes,
                    class_scores,
                    iou_threshold=self.nms_threshold,
                )
                final_boxes.append(class_boxes[class_keep])
                final_scores.append(class_scores[class_keep])
                final_labels.append(labels[class_mask][class_keep])

        if final_boxes:
            boxes_out = torch.cat(final_boxes, dim=0)
            scores_out = torch.cat(final_scores, dim=0)
            labels_out = torch.cat(final_labels, dim=0)
            keep = scores_out > self.score_threshold
            boxes_out = boxes_out[keep]
            scores_out = scores_out[keep]
            labels_out = labels_out[keep]
            if scores_out.numel() > 0:
                order = scores_out.argsort(descending=True)[: self.max_detections]
                boxes_out = boxes_out[order]
                scores_out = scores_out[order]
                labels_out = labels_out[order]
            else:
                boxes_out = boxes.new_zeros((0, 4))
                scores_out = scores.new_zeros((0,))
                labels_out = torch.zeros((0,), dtype=torch.long, device=labels.device)
        else:
            boxes_out = boxes.new_zeros((0, 4))
            scores_out = scores.new_zeros((0,))
            labels_out = torch.zeros((0,), dtype=torch.long, device=labels.device)

        return {
            "boxes": boxes_out,
            "scores": scores_out,
            "labels": labels_out,
        }

    def forward(self, **kwargs) -> Dict[str, Any]:
        """Forward pass.

        Required inputs:
            - the image tensor under the dataset's feature key.

        Optional inputs:
            - ``boxes`` / ``labels``: per-image ground-truth boxes and class
              labels; triggers loss computation.
            - ``seg_target``: segmentation mask for auxiliary supervision.
            - ``embed``: if truthy, return only the pooled embedding.
            - ``return_detections`` / ``return_seg_logit`` /
              ``return_raw_outputs``: toggles for the corresponding outputs
              (all default True).

        Returns:
            Dict with raw head outputs, seg logits, per-image detections,
            and — when any target was provided — the loss components.
        """
        return_detections = bool(kwargs.pop("return_detections", True))
        return_seg_logit = bool(kwargs.pop("return_seg_logit", True))
        return_raw_outputs = bool(kwargs.pop("return_raw_outputs", True))
        x = kwargs[self.feature_key]
        if not isinstance(x, torch.Tensor):
            x = torch.as_tensor(x)
        x = x.to(self.device, dtype=torch.float32)
        x = self._to_nchw(x)
        x = self._align_channels(x)

        pyramid_features, seg_logit, embed = self._build_pyramid(x)
        # The cls/box subnets are shared across pyramid levels; outputs are
        # flattened per-anchor and concatenated over levels.
        cls_logits = torch.cat(
            [self._reshape_cls_output(self.cls_head(feature)) for feature in pyramid_features],
            dim=1,
        )
        box_deltas = torch.cat(
            [self._reshape_box_output(self.box_head(feature)) for feature in pyramid_features],
            dim=1,
        )
        anchors = self._generate_anchors(
            features=pyramid_features,
            image_shape=(x.shape[-2], x.shape[-1]),
        )

        if kwargs.get("embed", False):
            return {"embed": embed}

        results: Dict[str, Any] = {}
        if return_raw_outputs:
            results.update(
                {
                    "logit": cls_logits,
                    "y_prob": F.softmax(cls_logits, dim=-1),
                    "cls_logits": cls_logits,
                    "bbox_deltas": box_deltas,
                    "anchors": anchors,
                }
            )
        if return_seg_logit:
            results["seg_logit"] = seg_logit
        if return_detections:
            results["detections"] = self._decode_detections(
                cls_logits=cls_logits,
                box_deltas=box_deltas,
                anchors=anchors,
                image_shape=(x.shape[-2], x.shape[-1]),
            )

        # Pure inference: no targets supplied, so skip loss computation.
        if "boxes" not in kwargs and "labels" not in kwargs and "seg_target" not in kwargs:
            return results

        boxes, labels = self._normalize_box_targets(
            boxes=kwargs.get("boxes"),
            labels=kwargs.get("labels"),
            batch_size=x.shape[0],
            device=self.device,
        )
        losses = self._compute_training_losses(
            cls_logits=cls_logits,
            box_deltas=box_deltas,
            anchors=anchors,
            boxes=boxes,
            labels=labels,
            seg_logit=seg_logit,
            seg_target=kwargs.get("seg_target"),
        )
        results.update(losses)
        return results
dataset=self.dataset, + in_channels=3, + num_classes=2, + base_channels=8, ) - self.model = RetinaUNet(dataset=self.dataset, in_channels=1, base_channels=16) + self.images = torch.randn(2, 3, 64, 64) + self.boxes = [ + torch.tensor([[10.0, 10.0, 24.0, 24.0]], dtype=torch.float32), + torch.tensor([[30.0, 28.0, 48.0, 44.0]], dtype=torch.float32), + ] + self.labels = [ + torch.tensor([1], dtype=torch.long), + torch.tensor([2], dtype=torch.long), + ] + self.seg_target = torch.zeros(2, 64, 64, dtype=torch.long) + self.seg_target[0, 10:24, 10:24] = 1 + self.seg_target[1, 28:44, 30:48] = 2 def test_initialization(self): self.assertEqual(self.model.feature_key, "image") self.assertEqual(self.model.label_key, "label") - self.assertEqual(self.model.in_channels, 1) - self.assertEqual(self.model.base_channels, 16) + self.assertEqual(self.model.in_channels, 3) + self.assertEqual(self.model.num_classes, 2) def test_forward_train(self): - loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) - batch = next(iter(loader)) - - with torch.no_grad(): - output = self.model(**batch) + output = self.model( + image=self.images, + boxes=self.boxes, + labels=self.labels, + seg_target=self.seg_target, + ) self.assertIn("loss", output) - self.assertIn("y_prob", output) - self.assertIn("y_true", output) - self.assertIn("logit", output) + self.assertIn("cls_loss", output) + self.assertIn("bbox_loss", output) + self.assertIn("seg_loss", output) + self.assertIn("detections", output) self.assertIn("seg_logit", output) - self.assertEqual(output["y_prob"].shape[0], 2) - self.assertEqual(output["seg_logit"].shape[0], 2) - self.assertEqual(output["seg_logit"].shape[1], 1) + self.assertEqual(output["cls_logits"].shape[0], 2) + self.assertEqual(output["seg_logit"].shape, (2, 3, 64, 64)) + self.assertEqual(len(output["detections"]), 2) self.assertEqual(output["loss"].dim(), 0) def test_backward(self): - loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) - batch = 
next(iter(loader)) - output = self.model(**batch) + output = self.model( + image=self.images, + boxes=self.boxes, + labels=self.labels, + seg_target=self.seg_target, + ) output["loss"].backward() has_grad = any( @@ -69,26 +91,19 @@ def test_backward(self): self.assertTrue(has_grad) def test_embed_mode(self): - loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) - batch = next(iter(loader)) - batch["embed"] = True - - with torch.no_grad(): - output = self.model(**batch) - + output = self.model(image=self.images, embed=True) self.assertIn("embed", output) self.assertEqual(output["embed"].shape[0], 2) - def test_custom_seg_target(self): - loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) - batch = next(iter(loader)) - - seg_target = torch.randint(0, 2, size=(2, 1, 32, 32)).float() - batch["seg_target"] = seg_target - output = self.model(**batch) + def test_forward_inference(self): + with torch.no_grad(): + output = self.model(image=self.images) - self.assertIn("seg_loss", output) - self.assertTrue(torch.isfinite(output["seg_loss"])) + self.assertIn("detections", output) + self.assertEqual(len(output["detections"]), 2) + self.assertIn("boxes", output["detections"][0]) + self.assertIn("scores", output["detections"][0]) + self.assertIn("labels", output["detections"][0]) if __name__ == "__main__":