diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..fccb6127b 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -206,3 +206,4 @@ API Reference models/pyhealth.models.BIOT models/pyhealth.models.unified_multimodal_embedding_docs models/pyhealth.models.califorest + models/pyhealth.models.cadre diff --git a/docs/api/models/pyhealth.models.cadre.rst b/docs/api/models/pyhealth.models.cadre.rst new file mode 100644 index 000000000..d32c51169 --- /dev/null +++ b/docs/api/models/pyhealth.models.cadre.rst @@ -0,0 +1,13 @@ +pyhealth.models.CADRE +===================== + +CADRE model for multilabel drug-response prediction. + +This implementation is inspired by the CADRE paper and original reference +repository. It provides a PyHealth-compatible model with optional contextual +attention over gene embeddings and drug-specific decoding. + +.. autoclass:: pyhealth.models.CADRE +  :members: +  :undoc-members: +  :show-inheritance: \ No newline at end of file diff --git a/examples/benchmark_cadre_vs_mlp.py b/examples/benchmark_cadre_vs_mlp.py new file mode 100644 index 000000000..4b6bb58bf --- /dev/null +++ b/examples/benchmark_cadre_vs_mlp.py @@ -0,0 +1,264 @@ +import random +import time +from typing import Dict, List + +import numpy as np +import torch +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import CADRE, MLP +from sklearn.metrics import f1_score + + +SEED = 42 +random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) + + +def build_dataset( +  n_samples: int = 60, +  seq_len: int = 8, +  num_genes: int = 50, +  num_labels: int = 2, +): +  """Create a small synthetic multilabel dataset. + +  The labels are intentionally correlated with simple gene-pattern rules so +  models have something learnable. 
+ """ + samples: List[Dict] = [] + + for i in range(n_samples): + genes = np.random.randint(1, num_genes + 1, size=seq_len).tolist() + + # Simple synthetic label rules: + # label 0 active if more even genes than odd + even_count = sum(g % 2 == 0 for g in genes) + label_tokens = [] + if even_count >= seq_len // 2: + label_tokens.append(0) + + # label 1 active if mean gene id is relatively high + if np.mean(genes) > (num_genes / 2): + label_tokens.append(1) + + # Guarantee at least one active label + if not label_tokens: + label_tokens = [0] + + samples.append( + { + "patient_id": f"patient-{i}", + "visit_id": f"visit-{i}", + "gene_idx": genes, + "label": label_tokens, + } + ) + + input_schema = { + "gene_idx": "sequence", + } + output_schema = { + "label": "multilabel", + } + + dataset = create_sample_dataset( + samples=samples, + input_schema=input_schema, + output_schema=output_schema, + dataset_name="cadre_benchmark", + ) + return dataset + + +def split_dataset(dataset, train_ratio: float = 0.8): + """Deterministic split for reproducibility.""" + n = len(dataset) + n_train = int(n * train_ratio) + indices = list(range(n)) + train_ds = torch.utils.data.Subset(dataset, indices[:n_train]) + test_ds = torch.utils.data.Subset(dataset, indices[n_train:]) + return train_ds, test_ds + + +def evaluate_metrics(model, dataloader, device): + """Compute average loss and micro F1.""" + model.eval() + + losses = [] + y_true_all = [] + y_pred_all = [] + + with torch.no_grad(): + for batch in dataloader: + batch = { + k: (v.to(device) if isinstance(v, torch.Tensor) else v) + for k, v in batch.items() + } + + ret = model(**batch) + + losses.append(ret["loss"].item()) + + y_true = ret["y_true"].cpu().numpy() + y_prob = ret["y_prob"].cpu().numpy() + + y_pred = (y_prob >= 0.5).astype(int) + + y_true_all.append(y_true) + y_pred_all.append(y_pred) + + y_true_all = np.vstack(y_true_all) + y_pred_all = np.vstack(y_pred_all) + + f1 = f1_score( + y_true_all, + y_pred_all, + 
average="micro", + zero_division=0, + ) + + return float(np.mean(losses)), float(f1) + + +def train_model(model, train_loader, test_loader, device, epochs: int = 5, lr: float = 1e-3): + """Simple training loop for quick benchmarking.""" + model = model.to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + + train_losses = [] + test_losses = [] + + start = time.time() + for epoch in range(1, epochs + 1): + model.train() + batch_losses = [] + + for batch in train_loader: + batch = { + k: (v.to(device) if isinstance(v, torch.Tensor) else v) + for k, v in batch.items() + } + + optimizer.zero_grad() + ret = model(**batch) + loss = ret["loss"] + loss.backward() + optimizer.step() + + batch_losses.append(loss.item()) + + train_loss = float(np.mean(batch_losses)) + test_loss, test_f1 = evaluate_metrics(model, test_loader, device) + + train_losses.append(train_loss) + test_losses.append(test_loss) + + print( + f" epoch={epoch} train_loss={train_loss:.4f} " + f"test_loss={test_loss:.4f} " + f"test_f1={test_f1:.4f}" + ) + + elapsed = time.time() - start + + return { + "final_train_loss": train_losses[-1], + "final_test_loss": test_losses[-1], + "best_test_loss": min(test_losses), + "final_test_f1": test_f1, + "runtime_sec": elapsed, + } + + +def benchmark_cadre(use_attention: bool, train_loader, test_loader, dataset, device): + """Benchmark one CADRE configuration.""" + model = CADRE( + dataset=dataset, + feature_key="gene_idx", + label_key="label", + num_genes=dataset.input_processors["gene_idx"].size(), + num_drugs=2, + embedding_dim=16, + hidden_dim=16, + attention_size=8, + attention_head=2, + dropout=0.1, + use_attention=use_attention, + use_cntx_attn=use_attention, + ) + + label = "CADRE(attention=True)" if use_attention else "CADRE(attention=False)" + print(f"\n=== {label} ===") + result = train_model(model, train_loader, test_loader, device, epochs=5, lr=1e-3) + return label, result + + +def benchmark_mlp(train_loader, test_loader, dataset, device): + 
"""Benchmark MLP baseline.""" + model = MLP( + dataset=dataset, + embedding_dim=16, + hidden_dim=16, + n_layers=2, + ) + + print("\n=== MLP ===") + result = train_model(model, train_loader, test_loader, device, epochs=5, lr=1e-3) + return "MLP", result + + +def main(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Running on device: {device}") + + dataset = build_dataset() + + train_loader = get_dataloader(dataset, batch_size=8, shuffle=True) + test_loader = get_dataloader(dataset, batch_size=8, shuffle=False) + + results = {} + + name, result = benchmark_cadre( + use_attention=True, + train_loader=train_loader, + test_loader=test_loader, + dataset=dataset, + device=device, + ) + results[name] = result + + name, result = benchmark_cadre( + use_attention=False, + train_loader=train_loader, + test_loader=test_loader, + dataset=dataset, + device=device, + ) + results[name] = result + + name, result = benchmark_mlp( + train_loader=train_loader, + test_loader=test_loader, + dataset=dataset, + device=device, + ) + results[name] = result + + print("\n=== Benchmark Summary ===") + for model_name, metrics in results.items(): + print( + f"{model_name:24s} " + f"final_test_loss={metrics['final_test_loss']:.4f} " + f"final_test_f1={metrics['final_test_f1']:.4f} " + f"runtime_sec={metrics['runtime_sec']:.2f}" + ) + + print("\nInterpretation:") + print("- Lower test loss is better in this quick benchmark.") + print("- This is a synthetic-data sanity benchmark, not a real-world claim.") + print("- It shows that CADRE can be benchmarked against an existing PyHealth baseline.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/sample_multilabel_cadre.py b/examples/sample_multilabel_cadre.py new file mode 100644 index 000000000..7678fffd4 --- /dev/null +++ b/examples/sample_multilabel_cadre.py @@ -0,0 +1,124 @@ +import random +import numpy as np +import torch + +from pyhealth.datasets import 
create_sample_dataset, get_dataloader +from pyhealth.models import CADRE + + +SEED = 42 +random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) + + +def build_dataset(): + """Create a tiny synthetic multilabel dataset for CADRE.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "gene_idx": [1, 2, 3, 4, 0, 0], + "label": [0], + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "gene_idx": [5, 6, 7, 8, 9, 0], + "label": [1], + }, + { + "patient_id": "patient-2", + "visit_id": "visit-2", + "gene_idx": [2, 4, 6, 8, 0, 0], + "label": [0], + }, + { + "patient_id": "patient-3", + "visit_id": "visit-3", + "gene_idx": [1, 3, 5, 7, 9, 0], + "label": [1], + }, + ] + + input_schema = { + "gene_idx": "sequence", + } + output_schema = { + "label": "multilabel", + } + + dataset = create_sample_dataset( + samples=samples, + input_schema=input_schema, + output_schema=output_schema, + dataset_name="cadre_demo", + ) + return dataset + + +def run_config(name: str, use_attention: bool): + """Run one CADRE configuration and print summary stats.""" + dataset = build_dataset() + loader = get_dataloader(dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + model = CADRE( + dataset=dataset, + feature_key="gene_idx", + label_key="label", + num_genes=20, + num_drugs=2, + embedding_dim=16, + hidden_dim=16, + attention_size=8, + attention_head=2, + dropout=0.1, + use_attention=use_attention, + use_cntx_attn=use_attention, + ) + + model.eval() + with torch.no_grad(): + ret = model(**batch) + + print(f"\n=== {name} ===") + print(f"use_attention={use_attention}") + print(f"loss={ret['loss'].item():.6f}") + print(f"logit shape={tuple(ret['logit'].shape)}") + print(f"y_prob shape={tuple(ret['y_prob'].shape)}") + print(f"mean y_prob={ret['y_prob'].mean().item():.6f}") + + return { + "name": name, + "use_attention": use_attention, + "loss": ret["loss"].item(), + "mean_y_prob": ret["y_prob"].mean().item(), + } + + +def main(): + 
"""Simple ablation: attention on vs off.""" + print("Running CADRE ablation on synthetic data...") + print("Ablation: use_attention=True vs use_attention=False") + + result_attn = run_config("CADRE with attention", use_attention=True) + result_no_attn = run_config("CADRE without attention", use_attention=False) + + print("\n=== Summary ===") + print( + f"with attention: loss={result_attn['loss']:.6f}, " + f"mean_y_prob={result_attn['mean_y_prob']:.6f}" + ) + print( + f"without attention: loss={result_no_attn['loss']:.6f}, " + f"mean_y_prob={result_no_attn['mean_y_prob']:.6f}" + ) + print( + "\nThis example demonstrates a runnable CADRE ablation using " + "synthetic multilabel data in PyHealth." + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 4c168d3e3..5bd681a16 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -45,4 +45,5 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding -from .califorest import CaliForest \ No newline at end of file +from .califorest import CaliForest +from .cadre import CADRE diff --git a/pyhealth/models/cadre.py b/pyhealth/models/cadre.py new file mode 100644 index 000000000..e8f15f246 --- /dev/null +++ b/pyhealth/models/cadre.py @@ -0,0 +1,325 @@ +from __future__ import annotations + +from typing import Dict, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pyhealth.datasets import SampleDataset +from pyhealth.models import BaseModel + + +class CADRE(BaseModel): + """CADRE model for multilabel drug-response prediction. + + This is a clean PyHealth-style implementation inspired by the original + CADRE architecture. It preserves the core ideas: + + 1. Embed input gene indices. + 2. Optionally compute drug-contextual attention over genes. + 3. 
Build a drug-specific hidden representation for each sample. + 4. Decode to per-drug logits using learned drug embeddings. + + Expected input + -------------- + This model expects one feature key containing integer gene indices, + padded with 0 if needed. The tensor shape should be: + + gene_idx: [batch_size, num_selected_genes] + + It expects one multilabel target key containing per-drug binary labels: + + label: [batch_size, num_drugs] + + Optionally, a mask tensor can be provided to ignore missing labels: + + label_mask: [batch_size, num_drugs] + + Notes + ----- + - This implementation is intentionally cleaner than the original training + code and is designed to fit PyHealth's BaseModel API. + - The first version focuses on expression/indexed gene input. + - If contextual attention is enabled, the model uses drug IDs as context + tokens, analogous to the original CADRE pathway/drug-context mechanism. + + Args: + dataset: PyHealth sample dataset. + feature_key: Name of the feature containing integer gene indices. + label_key: Name of the multilabel drug-response target. + mask_key: Optional mask key for missing labels. + num_genes: Number of gene IDs excluding padding. Padding index is 0. + num_drugs: Number of drug outputs. + embedding_dim: Dimension of gene embeddings. + hidden_dim: Hidden dimension used for drug decoding. + attention_size: Intermediate size for attention scoring. + attention_head: Number of attention heads. + dropout: Dropout probability. + use_attention: Whether to use attention over gene embeddings. + use_cntx_attn: Whether to use drug-contextual attention. + init_gene_emb: Optional pretrained gene embedding tensor of shape + [num_genes + 1, embedding_dim]. + use_relu: Whether to apply ReLU before dropout on the encoded + drug-specific representation. 
+ """ + + def __init__( + self, + dataset: SampleDataset, + feature_key: str, + label_key: str, + num_genes: int, + num_drugs: int, + mask_key: Optional[str] = None, + embedding_dim: int = 200, + hidden_dim: int = 200, + attention_size: int = 128, + attention_head: int = 8, + dropout: float = 0.6, + use_attention: bool = True, + use_cntx_attn: bool = True, + init_gene_emb: Optional[torch.Tensor] = None, + use_relu: bool = False, + ): + super().__init__(dataset) + + if label_key not in self.label_keys: + raise ValueError( + f"label_key='{label_key}' not found in dataset output schema. " + f"Available label keys: {self.label_keys}" + ) + if feature_key not in self.feature_keys: + raise ValueError( + f"feature_key='{feature_key}' not found in dataset input schema. " + f"Available feature keys: {self.feature_keys}" + ) + + self.feature_key = feature_key + self.label_key = label_key + self.mask_key = mask_key + + self.num_genes = num_genes + self.num_drugs = num_drugs + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + self.attention_size = attention_size + self.attention_head = attention_head + self.dropout_rate = dropout + self.use_attention = use_attention + self.use_cntx_attn = use_cntx_attn + self.use_relu = use_relu + + # Force multilabel behavior for per-drug binary outputs. 
+ self.mode = "multilabel" + + if init_gene_emb is not None: + expected_shape = (num_genes + 1, embedding_dim) + if tuple(init_gene_emb.shape) != expected_shape: + raise ValueError( + f"init_gene_emb must have shape {expected_shape}, " + f"got {tuple(init_gene_emb.shape)}" + ) + self.gene_embedding = nn.Embedding.from_pretrained( + init_gene_emb.float(), + freeze=True, + padding_idx=0, + ) + else: + self.gene_embedding = nn.Embedding( + num_embeddings=num_genes + 1, + embedding_dim=embedding_dim, + padding_idx=0, + ) + + self.dropout = nn.Dropout(p=dropout) + + if use_attention: + self.attn_proj = nn.Linear(embedding_dim, attention_size, bias=True) + self.attn_beta = nn.Linear(attention_size, attention_head, bias=True) + + if use_cntx_attn: + self.drug_context_embedding = nn.Embedding( + num_embeddings=num_drugs, + embedding_dim=attention_size, + ) + + # Optional projection to hidden_dim if embedding_dim != hidden_dim. + if embedding_dim != hidden_dim: + self.hidden_proj = nn.Linear(embedding_dim, hidden_dim) + else: + self.hidden_proj = nn.Identity() + + # Drug decoder, analogous to original DrugDecoder. + self.drug_embedding = nn.Embedding( + num_embeddings=num_drugs, + embedding_dim=hidden_dim, + ) + self.drug_bias = nn.Parameter(torch.zeros(num_drugs)) + + # Loss: masked multilabel BCE with logits. + self.loss_fn = nn.BCEWithLogitsLoss(reduction="none") + self.relu = nn.ReLU() + + def _encode( + self, + gene_idx: torch.Tensor, + ) -> torch.Tensor: + """Encodes gene indices into drug-specific hidden representations. + + Args: + gene_idx: Tensor of shape [batch_size, num_selected_genes]. + + Returns: + Tensor of shape [batch_size, num_drugs, hidden_dim]. 
+ """ + if gene_idx.dtype != torch.long: + gene_idx = gene_idx.long() + + # E_t: [B, G, E] + gene_emb = self.gene_embedding(gene_idx) + + if self.use_attention: + # Expand genes over drugs: + # [B, 1, G, E] -> [B, D, G, E] + gene_emb_exp = gene_emb.unsqueeze(1).repeat(1, self.num_drugs, 1, 1) + + # Base attention projection: + # [B, D, G, A] + attn_input = self.attn_proj(gene_emb_exp) + + if self.use_cntx_attn: + # Drug-context embeddings: [D, A] -> [1, D, 1, A] + drug_ids = torch.arange( + self.num_drugs, device=gene_idx.device, dtype=torch.long + ) + drug_ctx = self.drug_context_embedding(drug_ids) + drug_ctx = drug_ctx.unsqueeze(0).unsqueeze(2) + attn_input = attn_input + drug_ctx + + attn_hidden = torch.tanh(attn_input) + + # [B, D, G, H] + attn_scores = self.attn_beta(attn_hidden) + + # Softmax across genes, then collapse heads by summation. + # [B, D, G, H] -> [B, D, G, 1] + attn_weights = F.softmax(attn_scores, dim=2).sum(dim=3, keepdim=True) + + # Weighted sum across genes: + # [B, D, 1, G] x [B, D, G, E] -> [B, D, 1, E] -> [B, D, E] + drug_specific = torch.matmul( + attn_weights.permute(0, 1, 3, 2), gene_emb_exp + ).squeeze(2) + else: + # Mean pooling over genes, then repeat across drugs. + # [B, G, E] -> [B, E] -> [B, D, E] + pooled = gene_emb.mean(dim=1) + drug_specific = pooled.unsqueeze(1).repeat(1, self.num_drugs, 1) + + drug_specific = self.hidden_proj(drug_specific) + + if self.use_relu: + drug_specific = self.relu(drug_specific) + + drug_specific = self.dropout(drug_specific) + return drug_specific + + def _decode( + self, + drug_hidden: torch.Tensor, + ) -> torch.Tensor: + """Decodes drug-specific hidden states into per-drug logits. + + Args: + drug_hidden: Tensor of shape [batch_size, num_drugs, hidden_dim]. + + Returns: + Logits tensor of shape [batch_size, num_drugs]. 
+ """ + batch_size = drug_hidden.shape[0] + + # [D, H] -> [B, D, H] + drug_ids = torch.arange( + self.num_drugs, device=drug_hidden.device, dtype=torch.long + ) + drug_emb = self.drug_embedding(drug_ids).unsqueeze(0).repeat(batch_size, 1, 1) + + # Dot product across hidden dimension. + logits = (drug_hidden * drug_emb).sum(dim=-1) + self.drug_bias + return logits + + def _compute_loss( + self, + logits: torch.Tensor, + y_true: torch.Tensor, + y_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Computes masked multilabel BCE loss.""" + y_true = y_true.float() + loss = self.loss_fn(logits, y_true) + + if y_mask is not None: + y_mask = y_mask.float() + denom = y_mask.sum().clamp_min(1.0) + return (loss * y_mask).sum() / denom + + return loss.mean() + + def forward( + self, + **kwargs: torch.Tensor | tuple[torch.Tensor, ...], + ) -> Dict[str, torch.Tensor]: + """Forward pass of CADRE. + + Required kwargs: + feature_key: gene index tensor [B, G] + + Optional kwargs: + label_key: multilabel target tensor [B, D] + mask_key: mask tensor [B, D] + + Returns: + Dictionary containing: + - logit: [B, D] + - y_prob: [B, D] + - loss: scalar tensor, if labels are provided + - y_true: [B, D], if labels are provided + """ + if self.feature_key not in kwargs: + raise KeyError( + f"Missing required feature key '{self.feature_key}' in forward input." 
+ ) + + gene_idx = kwargs[self.feature_key] + if isinstance(gene_idx, tuple): + gene_idx = gene_idx[0] + assert isinstance(gene_idx, torch.Tensor) + + drug_hidden = self._encode(gene_idx) + logits = self._decode(drug_hidden) + y_prob = self.prepare_y_prob(logits) + + results: Dict[str, torch.Tensor] = { + "logit": logits, + "y_prob": y_prob, + } + + if self.label_key in kwargs: + y_true = kwargs[self.label_key] + if isinstance(y_true, tuple): + y_true = y_true[0] + assert isinstance(y_true, torch.Tensor) + + y_mask: Optional[torch.Tensor] = None + if self.mask_key is not None and self.mask_key in kwargs: + y_mask = kwargs[self.mask_key] + if isinstance(y_mask, tuple): + y_mask = y_mask[0] + assert isinstance(y_mask, torch.Tensor) + + loss = self._compute_loss(logits, y_true, y_mask) + results["loss"] = loss + results["y_true"] = y_true.float() + + return results diff --git a/tests/core/test_cadre.py b/tests/core/test_cadre.py new file mode 100644 index 000000000..ae6921331 --- /dev/null +++ b/tests/core/test_cadre.py @@ -0,0 +1,110 @@ +import unittest +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import CADRE + + +class TestCADRE(unittest.TestCase): + """Test cases for the CADRE model.""" + + def setUp(self): + self.num_drugs = 2 + self.num_genes = 20 + + # gene_idx values must be integers; 0 reserved for padding + self.samples = [ + { + "patient_id": "patient-0", + "gene_idx": [1, 2, 3, 4, 0, 0], + "label": [0,1], + }, + { + "patient_id": "patient-1", + "gene_idx": [5, 6, 7, 8, 9, 0], + "label": [0,1], + }, + ] + + self.input_schema = { + "gene_idx": "sequence", + } + + self.output_schema = { + "label": "multilabel", + } + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="cadre_test", + ) + + self.model = CADRE( + dataset=self.dataset, + feature_key="gene_idx", + label_key="label", + 
num_genes=self.num_genes, + num_drugs=self.num_drugs, + embedding_dim=16, + hidden_dim=16, + attention_size=8, + attention_head=2, + dropout=0.1, + ) + + def test_model_initialization(self): + self.assertIsInstance(self.model, CADRE) + self.assertEqual(self.model.num_genes, self.num_genes) + self.assertEqual(self.model.num_drugs, self.num_drugs) + + def test_forward_input_format(self): + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + self.assertIsInstance(batch["gene_idx"], torch.Tensor) + self.assertIsInstance(batch["label"], torch.Tensor) + + def test_model_forward(self): + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + + self.assertEqual(ret["logit"].shape, (2, self.num_drugs)) + self.assertEqual(ret["y_prob"].shape, (2, self.num_drugs)) + self.assertEqual(ret["y_true"].shape, (2, self.num_drugs)) + self.assertEqual(ret["loss"].dim(), 0) + + def test_model_backward(self): + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + ret = self.model(**batch) + ret["loss"].backward() + + has_grad = any( + p.requires_grad and p.grad is not None + for p in self.model.parameters() + ) + self.assertTrue(has_grad) + + def test_loss_is_finite(self): + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertTrue(torch.isfinite(ret["loss"]).all()) + + +if __name__ == "__main__": + unittest.main()