class BatchedHeads(nn.Module):
    """Bank of parallel scalar output heads over a shared trunk output.

    All heads share a single batched linear projection (``layer1``); each
    head then reduces its ``hidden``-sized activation to a scalar through a
    learned per-head weight vector (``w2``) plus a per-head bias (``b2``).

    Parameters
    ----------
    input_size
        Size of the input feature vector (output of shared trunk).
    n_heads
        Number of parallel output heads.
    hidden
        Hidden dimension per head (default: 32).
    """

    def __init__(self, input_size: int, n_heads: int, hidden: int = 32):
        super().__init__()
        self.layer1 = nn.Linear(input_size, n_heads * hidden)
        self.w2 = nn.Parameter(torch.zeros(n_heads, hidden))
        self.b2 = nn.Parameter(torch.zeros(n_heads))
        # Small random init keeps the initial head outputs near zero.
        nn.init.normal_(self.w2, std=0.05)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a ``(batch, input_size)`` input to ``(batch, n_heads)`` scalars."""
        # Head geometry is read back from the parameter shapes rather than
        # stored as attributes, so instances restored from pickled
        # checkpoints (which bypass __init__) keep working.
        projected = self.layer1(x)  # (batch, n_heads * hidden)
        per_head = torch.relu(projected.unflatten(-1, self.w2.shape))
        # Per-head dot product with w2, then per-head bias.
        return torch.einsum("bnh,nh->bn", per_head, self.w2) + self.b2
+ """ + + def forward( + self, + x_atom: torch.Tensor, + x_atom_sum: torch.Tensor, + x_global: torch.Tensor, + x_one_hot: torch.Tensor, + ) -> torch.Tensor: + x_atom = x_atom.transpose(1, 2) + x_atom_sum = x_atom_sum.transpose(1, 2) + x_one_hot = x_one_hot.transpose(1, 2) + concatenated = torch.cat( + [ + self.branch_a(x_atom), # type: ignore[attr-defined] + self.branch_b(x_atom_sum), # type: ignore[attr-defined] + self.branch_c(x_global), # type: ignore[attr-defined] + self.branch_d(x_one_hot), # type: ignore[attr-defined] + ], + dim=1, + ) + return self.heads(self.shared_trunk(concatenated)) # type: ignore[attr-defined] diff --git a/deeplc/_model_ops.py b/deeplc/_model_ops.py index b3245cc..209a4c8 100644 --- a/deeplc/_model_ops.py +++ b/deeplc/_model_ops.py @@ -2,6 +2,8 @@ import copy import logging +import sys +import types from collections.abc import Callable from os import PathLike from pathlib import Path @@ -17,12 +19,28 @@ ) from torch.utils.data import DataLoader, Dataset, Subset -from deeplc._architecture import DeepLCModel +from deeplc._architecture import BatchedHeads, DeepLCModel, MultitaskDeepLCModel from deeplc.data import DeepLCDataset logger = logging.getLogger(__name__) +def _patch_legacy_multitask_module() -> None: + """Register a backwards-compatibility shim for multitask_model.pt. + + The bundled multitask checkpoint was saved when MultitaskDeepLCModel and + BatchedHeads lived in a top-level module called ``multitask_model``. That + module no longer exists; the classes now live in ``deeplc._architecture``. + Registering a shim in ``sys.modules`` before ``torch.load`` lets pickle + resolve the old import paths without re-saving the checkpoint. 
+ """ + if "multitask_model" not in sys.modules: + shim = types.ModuleType("multitask_model") + shim.MultitaskDeepLCModel = MultitaskDeepLCModel # type: ignore[attr-defined] + shim.BatchedHeads = BatchedHeads # type: ignore[attr-defined] + sys.modules["multitask_model"] = shim + + def load_model( model: torch.nn.Module | PathLike | str | None = None, device: str | None = None, @@ -33,6 +51,7 @@ def load_model( # Load model from file if a path is provided if isinstance(model, (str, PathLike, Path)): + _patch_legacy_multitask_module() loaded_model = torch.load(model, weights_only=False, map_location=selected_device) elif isinstance(model, torch.nn.Module): loaded_model = model @@ -61,6 +80,8 @@ def train( epochs: int = 25, batch_size: int = 512, patience: int = 10, + freeze_epochs: int = 0, + unfreeze_lr_scale: float = 0.1, show_progress: bool = True, ) -> torch.nn.Module: """ @@ -119,6 +140,9 @@ def train( "Validation data loader is empty. Adjust validation data or validation_split." ) + has_freeze = hasattr(model, "freeze_backbone") and hasattr(model, "unfreeze_backbone") + if has_freeze and freeze_epochs > 0: + model.freeze_backbone() optimizer = _get_optimizer(model, learning_rate) loss_fn = torch.nn.L1Loss() @@ -129,7 +153,11 @@ def train( with _create_progress(disable=not show_progress) as progress: epoch_task = progress.add_task("Epochs", total=epochs, status="") - for _epoch in range(epochs): + for epoch in range(epochs): + if has_freeze and freeze_epochs > 0 and epoch == freeze_epochs: + model.unfreeze_backbone() + optimizer = _get_optimizer(model, learning_rate * unfreeze_lr_scale) + avg_loss = _train_epoch(model, train_loader, optimizer, loss_fn, device) avg_val_loss = _validate_epoch(model, val_loader, loss_fn, device) diff --git a/deeplc/core.py b/deeplc/core.py index ce4555f..485f34c 100644 --- a/deeplc/core.py +++ b/deeplc/core.py @@ -16,12 +16,52 @@ SplineTransformerCalibration, ) from deeplc.data import DeepLCDataset, split_datasets +from 
deeplc.multitask import MultitaskAdapter LOGGER = logging.getLogger(__name__) DEEPLC_DIR = Path(__file__).resolve().parent DEFAULT_MODEL_NAME = "full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.pt" -DEFAULT_MODEL = DEEPLC_DIR / "package_data" / "models" / DEFAULT_MODEL_NAME +DEFAULT_MODEL_FALLBACK = DEEPLC_DIR / "package_data" / "models" / DEFAULT_MODEL_NAME +DEFAULT_MULTITASK_MODEL_PACKAGED = DEEPLC_DIR / "package_data" / "models" / "multitask_model.pt" +DEFAULT_MULTITASK_MODEL_WORKSPACE = ( + DEEPLC_DIR.parent.parent / "multitask_output_200ep" / "multitask_model.pt" +) +DEFAULT_MODEL = ( + DEFAULT_MULTITASK_MODEL_PACKAGED + if DEFAULT_MULTITASK_MODEL_PACKAGED.exists() + else ( + DEFAULT_MULTITASK_MODEL_WORKSPACE + if DEFAULT_MULTITASK_MODEL_WORKSPACE.exists() + else DEFAULT_MODEL_FALLBACK + ) +) + + +def _best_correlating_head(predictions: np.ndarray, targets: np.ndarray) -> int: + """Return the head index with highest valid Pearson correlation to targets.""" + best_idx = 0 + best_corr = float("-inf") + + for idx in range(predictions.shape[1]): + pred_col = predictions[:, idx] + mask = np.isfinite(pred_col) & np.isfinite(targets) + if mask.sum() < 3: + continue + pred_masked = pred_col[mask] + target_masked = targets[mask] + if np.std(pred_masked) < 1e-8 or np.std(target_masked) < 1e-8: + continue + corr = np.corrcoef(pred_masked, target_masked)[0, 1] + if np.isfinite(corr) and corr > best_corr: + best_corr = corr + best_idx = idx + + return best_idx + + +def _is_multitask_output(predictions: np.ndarray) -> bool: + return predictions.ndim == 2 and predictions.shape[1] > 1 def predict( @@ -110,6 +150,12 @@ def calibrate( # Fit calibration LOGGER.debug("Fitting calibration...") target_rt_cal = np.array(psm_list_reference["retention_time"], dtype=np.float32) + + if _is_multitask_output(source_rt_cal): + selected_head_idx = _best_correlating_head(source_rt_cal, target_rt_cal) + source_rt_cal = source_rt_cal[:, selected_head_idx] + setattr(calibration, 
"selected_head_idx", int(selected_head_idx)) + calibration.fit(target=target_rt_cal, source=source_rt_cal) return calibration @@ -168,6 +214,21 @@ def predict_and_calibrate( else: LOGGER.info("Calibration is already fitted, skipping fitting step.") + if _is_multitask_output(predicted_rt): + selected_head_idx = getattr(calibration, "selected_head_idx", None) + if selected_head_idx is None: + ref_pred_rt = predict( + psm_list=psm_list_reference, + model=model, + predict_kwargs=predict_kwargs, + ) + if _is_multitask_output(ref_pred_rt): + ref_targets = np.array(psm_list_reference["retention_time"], dtype=np.float32) + selected_head_idx = _best_correlating_head(ref_pred_rt, ref_targets) + else: + selected_head_idx = 0 + predicted_rt = predicted_rt[:, int(selected_head_idx)] + # Apply calibration to predictions calibrated_rt = calibration.transform(predicted_rt) @@ -271,11 +332,36 @@ def finetune( training_dataset, validation_dataset = split_datasets( training_data, validation_data=validation_data, validation_split=validation_split ) + train_kwargs_local = dict(train_kwargs or {}) + model_for_training: torch.nn.Module | PathLike | str | None = model or DEFAULT_MODEL + + loaded_model = _model_ops.load_model( + model_for_training, + device=train_kwargs_local.get("device"), + ) + + sample_features, _ = training_dataset[0] + sample_features = [feature.unsqueeze(0).to(next(loaded_model.parameters()).device) for feature in sample_features] + with torch.no_grad(): + sample_output = loaded_model(*sample_features) + + if sample_output.ndim == 2 and sample_output.shape[1] > 1: + adapter_hidden_size = int(train_kwargs_local.pop("adapter_hidden_size", 256)) + freeze_epochs = int(train_kwargs_local.pop("freeze_epochs", 5)) + model_for_training = MultitaskAdapter( + multitask_model=loaded_model, + n_heads=sample_output.shape[1], + hidden_size=adapter_hidden_size, + ) + train_kwargs_local["freeze_epochs"] = freeze_epochs + else: + model_for_training = loaded_model + finetuned_model = 
class MultitaskAdapter(nn.Module):
    """Wrap a multitask backbone and map its head vector to one RT output.

    The backbone is deep-copied so that finetuning the adapter never mutates
    the caller's model. A small MLP maps the ``n_heads``-dimensional head
    vector to a single retention-time prediction.

    Parameters
    ----------
    multitask_model
        Trained multitask backbone; called as
        ``backbone(x_atom, x_atom_sum, x_global, x_one_hot)``.
    n_heads
        Number of outputs produced by the backbone (adapter input size).
    hidden_size
        Width of the adapter's first hidden layer (default: 256). The second
        hidden layer uses half that width (at least 1).
    """

    def __init__(self, multitask_model: nn.Module, n_heads: int, hidden_size: int = 256):
        super().__init__()
        # Deep copy: freezing/finetuning must not affect the original model.
        self.backbone = copy.deepcopy(multitask_model)
        # Hoisted so the mid-layer width is computed once, not twice.
        mid_size = max(1, hidden_size // 2)
        self.adapter = nn.Sequential(
            nn.Linear(n_heads, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, mid_size),
            nn.ReLU(),
            nn.Linear(mid_size, 1),
        )

    def forward(self, x_atom, x_atom_sum, x_global, x_one_hot) -> torch.Tensor:
        """Return a ``(batch, 1)`` RT prediction from the four DeepLC inputs."""
        multitask_output = self.backbone(x_atom, x_atom_sum, x_global, x_one_hot)
        if multitask_output.ndim == 1:
            # A 1-D backbone output (batch,) becomes (batch, 1) for the MLP.
            multitask_output = multitask_output.unsqueeze(-1)
        return self.adapter(multitask_output)

    def freeze_backbone(self) -> None:
        """Disable gradients for all backbone parameters (adapter-only training)."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def unfreeze_backbone(self) -> None:
        """Re-enable gradients for all backbone parameters."""
        for param in self.backbone.parameters():
            param.requires_grad = True
def test_load_multitask_model_without_prior_shim():
    """multitask_model.pt must load even when the legacy module is not pre-registered."""
    # Drop any shim registered by an earlier test so this one is self-contained.
    sys.modules.pop("multitask_model", None)

    model = _model_ops.load_model(DEFAULT_MULTITASK_MODEL_PACKAGED, device="cpu")
    assert isinstance(model, MultitaskDeepLCModel)

    # Minimal zero-valued batch of two peptides in the DeepLC feature layout
    # (atom, summed-atom, global, one-hot inputs).
    features = (
        torch.zeros(2, 60, 6),
        torch.zeros(2, 30, 6),
        torch.zeros(2, 55),
        torch.zeros(2, 60, 20),
    )
    with torch.no_grad():
        predictions = model(*features)

    # One row per peptide, multiple multitask heads per row.
    assert predictions.ndim == 2
    assert predictions.shape[0] == 2
    assert predictions.shape[1] > 1