Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions deeplc/_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,69 @@ def forward(
output = self.final_network(concatenated)

return output


class BatchedHeads(nn.Module):
    """Parallel scalar output heads sharing one hidden projection.

    All heads are evaluated jointly: a single linear layer produces the
    hidden activations for every head at once, after which each head
    reduces its own hidden slice via a learned weight vector and bias.

    Parameters
    ----------
    input_size
        Size of the input feature vector (output of the shared trunk).
    n_heads
        Number of parallel output heads.
    hidden
        Hidden dimension per head (default: 32).
    """

    def __init__(self, input_size: int, n_heads: int, hidden: int = 32):
        super().__init__()
        self.layer1 = nn.Linear(input_size, n_heads * hidden)
        self.w2 = nn.Parameter(torch.zeros(n_heads, hidden))
        self.b2 = nn.Parameter(torch.zeros(n_heads))
        nn.init.normal_(self.w2, std=0.05)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, input_size)`` features to ``(batch, n_heads)`` outputs."""
        # Derive the head count from the bias shape rather than storing an
        # extra attribute, so previously pickled checkpoints stay loadable.
        n_heads = self.b2.shape[0]
        flat = self.layer1(x)  # (batch, n_heads * hidden)
        per_head = torch.relu(flat.reshape(flat.shape[0], n_heads, -1))
        # Per-head dot product with w2, plus per-head bias.
        return torch.einsum("bnh,nh->bn", per_head, self.w2) + self.b2


class MultitaskDeepLCModel(nn.Module):
    """Multi-task DeepLC backbone predicting RT across multiple LC systems.

    Shares the same four input branches as :class:`DeepLCModel` but replaces
    the single-output final network with a shared trunk feeding into
    :class:`BatchedHeads`, producing one RT value per LC system.

    This class is primarily used for loading pre-trained checkpoints via
    ``torch.load``. The child modules (``branch_a``, ``branch_b``,
    ``branch_c``, ``branch_d``, ``shared_trunk``, ``heads``) are restored
    from the checkpoint state dict and do not need to be constructed here.
    """

    def forward(
        self,
        x_atom: torch.Tensor,
        x_atom_sum: torch.Tensor,
        x_global: torch.Tensor,
        x_one_hot: torch.Tensor,
    ) -> torch.Tensor:
        """Return stacked per-head RT predictions for one feature batch."""
        # Sequence-shaped inputs arrive channels-last; the branch networks
        # consume them channels-first, so swap the last two axes up front.
        branch_outputs = [
            self.branch_a(x_atom.transpose(1, 2)),  # type: ignore[attr-defined]
            self.branch_b(x_atom_sum.transpose(1, 2)),  # type: ignore[attr-defined]
            self.branch_c(x_global),  # type: ignore[attr-defined]
            self.branch_d(x_one_hot.transpose(1, 2)),  # type: ignore[attr-defined]
        ]
        fused = torch.cat(branch_outputs, dim=1)
        trunk_out = self.shared_trunk(fused)  # type: ignore[attr-defined]
        return self.heads(trunk_out)  # type: ignore[attr-defined]
32 changes: 30 additions & 2 deletions deeplc/_model_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import copy
import logging
import sys
import types
from collections.abc import Callable
from os import PathLike
from pathlib import Path
Expand All @@ -17,12 +19,28 @@
)
from torch.utils.data import DataLoader, Dataset, Subset

from deeplc._architecture import DeepLCModel
from deeplc._architecture import BatchedHeads, DeepLCModel, MultitaskDeepLCModel
from deeplc.data import DeepLCDataset

logger = logging.getLogger(__name__)


def _patch_legacy_multitask_module() -> None:
    """Register a backwards-compatibility shim for multitask_model.pt.

    The bundled multitask checkpoint was saved when MultitaskDeepLCModel and
    BatchedHeads lived in a top-level module called ``multitask_model``. That
    module no longer exists; the classes now live in ``deeplc._architecture``.
    Registering a shim in ``sys.modules`` before ``torch.load`` lets pickle
    resolve the old import paths without re-saving the checkpoint.
    """
    if "multitask_model" in sys.modules:
        # Shim (or a real module of that name) already present -- nothing to do.
        return
    legacy_module = types.ModuleType("multitask_model")
    legacy_module.MultitaskDeepLCModel = MultitaskDeepLCModel  # type: ignore[attr-defined]
    legacy_module.BatchedHeads = BatchedHeads  # type: ignore[attr-defined]
    sys.modules["multitask_model"] = legacy_module


def load_model(
model: torch.nn.Module | PathLike | str | None = None,
device: str | None = None,
Expand All @@ -33,6 +51,7 @@ def load_model(

# Load model from file if a path is provided
if isinstance(model, (str, PathLike, Path)):
_patch_legacy_multitask_module()
loaded_model = torch.load(model, weights_only=False, map_location=selected_device)
elif isinstance(model, torch.nn.Module):
loaded_model = model
Expand Down Expand Up @@ -61,6 +80,8 @@ def train(
epochs: int = 25,
batch_size: int = 512,
patience: int = 10,
freeze_epochs: int = 0,
unfreeze_lr_scale: float = 0.1,
show_progress: bool = True,
) -> torch.nn.Module:
"""
Expand Down Expand Up @@ -119,6 +140,9 @@ def train(
"Validation data loader is empty. Adjust validation data or validation_split."
)

has_freeze = hasattr(model, "freeze_backbone") and hasattr(model, "unfreeze_backbone")
if has_freeze and freeze_epochs > 0:
model.freeze_backbone()
optimizer = _get_optimizer(model, learning_rate)
loss_fn = torch.nn.L1Loss()

Expand All @@ -129,7 +153,11 @@ def train(
with _create_progress(disable=not show_progress) as progress:
epoch_task = progress.add_task("Epochs", total=epochs, status="")

for _epoch in range(epochs):
for epoch in range(epochs):
if has_freeze and freeze_epochs > 0 and epoch == freeze_epochs:
model.unfreeze_backbone()
optimizer = _get_optimizer(model, learning_rate * unfreeze_lr_scale)

avg_loss = _train_epoch(model, train_loader, optimizer, loss_fn, device)
avg_val_loss = _validate_epoch(model, val_loader, loss_fn, device)

Expand Down
92 changes: 89 additions & 3 deletions deeplc/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,52 @@
SplineTransformerCalibration,
)
from deeplc.data import DeepLCDataset, split_datasets
from deeplc.multitask import MultitaskAdapter

LOGGER = logging.getLogger(__name__)

DEEPLC_DIR = Path(__file__).resolve().parent
DEFAULT_MODEL_NAME = "full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.pt"
DEFAULT_MODEL = DEEPLC_DIR / "package_data" / "models" / DEFAULT_MODEL_NAME
DEFAULT_MODEL_FALLBACK = DEEPLC_DIR / "package_data" / "models" / DEFAULT_MODEL_NAME
DEFAULT_MULTITASK_MODEL_PACKAGED = DEEPLC_DIR / "package_data" / "models" / "multitask_model.pt"
DEFAULT_MULTITASK_MODEL_WORKSPACE = (
DEEPLC_DIR.parent.parent / "multitask_output_200ep" / "multitask_model.pt"
)
DEFAULT_MODEL = (
DEFAULT_MULTITASK_MODEL_PACKAGED
if DEFAULT_MULTITASK_MODEL_PACKAGED.exists()
else (
DEFAULT_MULTITASK_MODEL_WORKSPACE
if DEFAULT_MULTITASK_MODEL_WORKSPACE.exists()
else DEFAULT_MODEL_FALLBACK
)
)


def _best_correlating_head(predictions: np.ndarray, targets: np.ndarray) -> int:
"""Return the head index with highest valid Pearson correlation to targets."""
best_idx = 0
best_corr = float("-inf")

for idx in range(predictions.shape[1]):
pred_col = predictions[:, idx]
mask = np.isfinite(pred_col) & np.isfinite(targets)
if mask.sum() < 3:
continue
pred_masked = pred_col[mask]
target_masked = targets[mask]
if np.std(pred_masked) < 1e-8 or np.std(target_masked) < 1e-8:
continue
corr = np.corrcoef(pred_masked, target_masked)[0, 1]
if np.isfinite(corr) and corr > best_corr:
best_corr = corr
best_idx = idx

return best_idx


def _is_multitask_output(predictions: np.ndarray) -> bool:
return predictions.ndim == 2 and predictions.shape[1] > 1


def predict(
Expand Down Expand Up @@ -110,6 +150,12 @@
# Fit calibration
LOGGER.debug("Fitting calibration...")
target_rt_cal = np.array(psm_list_reference["retention_time"], dtype=np.float32)

if _is_multitask_output(source_rt_cal):
selected_head_idx = _best_correlating_head(source_rt_cal, target_rt_cal)
source_rt_cal = source_rt_cal[:, selected_head_idx]
setattr(calibration, "selected_head_idx", int(selected_head_idx))

Check failure on line 157 in deeplc/core.py

View workflow job for this annotation

GitHub Actions / lint-python-package

ruff (B010)

deeplc/core.py:157:9: B010 Do not call `setattr` with a constant attribute value. It is not any safer than normal property access. help: Replace `setattr` with assignment

calibration.fit(target=target_rt_cal, source=source_rt_cal)

return calibration
Expand Down Expand Up @@ -168,6 +214,21 @@
else:
LOGGER.info("Calibration is already fitted, skipping fitting step.")

if _is_multitask_output(predicted_rt):
selected_head_idx = getattr(calibration, "selected_head_idx", None)
if selected_head_idx is None:
ref_pred_rt = predict(
psm_list=psm_list_reference,
model=model,
predict_kwargs=predict_kwargs,
)
if _is_multitask_output(ref_pred_rt):
ref_targets = np.array(psm_list_reference["retention_time"], dtype=np.float32)
selected_head_idx = _best_correlating_head(ref_pred_rt, ref_targets)
else:
selected_head_idx = 0
predicted_rt = predicted_rt[:, int(selected_head_idx)]

# Apply calibration to predictions
calibrated_rt = calibration.transform(predicted_rt)

Expand Down Expand Up @@ -271,11 +332,36 @@
training_dataset, validation_dataset = split_datasets(
training_data, validation_data=validation_data, validation_split=validation_split
)
train_kwargs_local = dict(train_kwargs or {})
model_for_training: torch.nn.Module | PathLike | str | None = model or DEFAULT_MODEL

loaded_model = _model_ops.load_model(
model_for_training,
device=train_kwargs_local.get("device"),
)

sample_features, _ = training_dataset[0]
sample_features = [feature.unsqueeze(0).to(next(loaded_model.parameters()).device) for feature in sample_features]

Check failure on line 344 in deeplc/core.py

View workflow job for this annotation

GitHub Actions / lint-python-package

ruff (E501)

deeplc/core.py:344:100: E501 Line too long (118 > 99)
with torch.no_grad():
sample_output = loaded_model(*sample_features)

if sample_output.ndim == 2 and sample_output.shape[1] > 1:
adapter_hidden_size = int(train_kwargs_local.pop("adapter_hidden_size", 256))
freeze_epochs = int(train_kwargs_local.pop("freeze_epochs", 5))
model_for_training = MultitaskAdapter(
multitask_model=loaded_model,
n_heads=sample_output.shape[1],
hidden_size=adapter_hidden_size,
)
train_kwargs_local["freeze_epochs"] = freeze_epochs
else:
model_for_training = loaded_model

finetuned_model = _model_ops.train(
model=model or DEFAULT_MODEL,
model=model_for_training,
train_dataset=training_dataset,
validation_dataset=validation_dataset,
**(train_kwargs or {}),
**train_kwargs_local,
)
return finetuned_model

Expand Down
37 changes: 37 additions & 0 deletions deeplc/multitask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Utilities for adapting multitask RT models to single-task outputs."""

from __future__ import annotations

import copy

import torch

Check failure on line 7 in deeplc/multitask.py

View workflow job for this annotation

GitHub Actions / lint-python-package

ruff (F401)

deeplc/multitask.py:7:8: F401 `torch` imported but unused help: Remove unused import: `torch`
import torch.nn as nn


class MultitaskAdapter(nn.Module):
    """Wrap a multitask backbone and map its head vector to one RT output."""

    def __init__(self, multitask_model: nn.Module, n_heads: int, hidden_size: int = 256):
        super().__init__()
        # Deep-copy so fine-tuning never mutates the caller's model object.
        self.backbone = copy.deepcopy(multitask_model)
        half = max(1, hidden_size // 2)
        self.adapter = nn.Sequential(
            nn.Linear(n_heads, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, half),
            nn.ReLU(),
            nn.Linear(half, 1),
        )

    def forward(self, x_atom, x_atom_sum, x_global, x_one_hot):
        """Run the backbone, then collapse its head vector to a single value."""
        head_vector = self.backbone(x_atom, x_atom_sum, x_global, x_one_hot)
        if head_vector.ndim == 1:
            # A 1-D backbone output gets a trailing feature axis for the MLP.
            head_vector = head_vector.unsqueeze(-1)
        return self.adapter(head_vector)

    def freeze_backbone(self):
        """Stop gradient flow into the wrapped backbone."""
        for parameter in self.backbone.parameters():
            parameter.requires_grad = False

    def unfreeze_backbone(self):
        """Re-enable gradient flow into the wrapped backbone."""
        for parameter in self.backbone.parameters():
            parameter.requires_grad = True
Binary file added deeplc/package_data/models/multitask_model.pt
Binary file not shown.
25 changes: 25 additions & 0 deletions tests/test_model_ops.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

import sys

import pytest
import torch
from torch.utils.data import Dataset

from deeplc import _model_ops
from deeplc._architecture import BatchedHeads, MultitaskDeepLCModel
from deeplc.core import DEFAULT_MULTITASK_MODEL_PACKAGED
from deeplc.data import split_datasets


Expand Down Expand Up @@ -58,3 +62,24 @@ def test_train_rejects_empty_validation_loader():
batch_size=2,
show_progress=False,
)


def test_load_multitask_model_without_prior_shim():
    """multitask_model.pt must load even when the legacy module is not pre-registered."""
    # Drop any shim left behind by earlier tests so this one is self-contained.
    sys.modules.pop("multitask_model", None)

    loaded = _model_ops.load_model(DEFAULT_MULTITASK_MODEL_PACKAGED, device="cpu")
    assert isinstance(loaded, MultitaskDeepLCModel)

    # A forward pass over zero-filled dummy features exercises all branches.
    batch = 2
    features = (
        torch.zeros(batch, 60, 6),
        torch.zeros(batch, 30, 6),
        torch.zeros(batch, 55),
        torch.zeros(batch, 60, 20),
    )
    with torch.no_grad():
        predictions = loaded(*features)

    assert predictions.ndim == 2
    assert predictions.shape[0] == batch
    assert predictions.shape[1] > 1  # multiple heads
Loading