diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..a45bdc1a9 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -246,3 +246,4 @@ Available Datasets datasets/pyhealth.datasets.TCGAPRADDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils + datasets/pyhealth.datasets.Wav2SleepDataset diff --git a/docs/api/datasets/pyhealth.datasets.Wav2SleepDataset.rst b/docs/api/datasets/pyhealth.datasets.Wav2SleepDataset.rst new file mode 100644 index 000000000..64cb77c1f --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.Wav2SleepDataset.rst @@ -0,0 +1,21 @@ +pyhealth.datasets.Wav2SleepDataset +=================================== + +Overview +-------- + +A unified dataset of polysomnography (PSG) recordings spanning 7 datasets +hosted on sleepdata.org: SHHS, MESA, WSC, CHAT, CFS, CCSHS, and MROS. +Used in wav2sleep: A Unified Multi-Modal Approach to Sleep Stage +Classification from Physiological Signals +(https://arxiv.org/abs/2411.04644). Requires a Data Use Agreement via +sleepdata.org. + +API Reference +------------- + +.. autoclass:: pyhealth.datasets.Wav2SleepDataset + :members: + :undoc-members: + :show-inheritance: + \ No newline at end of file diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..a7889153a 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -206,3 +206,4 @@ API Reference models/pyhealth.models.BIOT models/pyhealth.models.unified_multimodal_embedding_docs models/pyhealth.models.califorest + models/pyhealth.models.Wav2Sleep diff --git a/docs/api/models/pyhealth.models.Wav2Sleep.rst b/docs/api/models/pyhealth.models.Wav2Sleep.rst new file mode 100644 index 000000000..bd6b0f399 --- /dev/null +++ b/docs/api/models/pyhealth.models.Wav2Sleep.rst @@ -0,0 +1,45 @@ +pyhealth.models.Wav2Sleep +========================== + +Overview +-------- + +The complete Wav2Sleep model for sleep stage classification from +polysomnography biosignals. Consists of per-modality Signal Encoders, +an Epoch Mixer for cross-modal fusion, and a Sequence Mixer for temporal +modeling. A trained model can be applied to any subset of the modalities +seen during training. + +API Reference +------------- + +.. autoclass:: pyhealth.models.ResidualBlock + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.SignalEncoder + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.EpochMixer + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.SequenceMixer + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.DilatedConvBlock + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.Wav2Sleep + :members: + :undoc-members: + :show-inheritance: + \ No newline at end of file diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..d46bd2d62 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -230,3 +230,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Wav2Sleep diff --git a/docs/api/tasks/pyhealth.tasks.Wav2SleepStaging.rst b/docs/api/tasks/pyhealth.tasks.Wav2SleepStaging.rst new file mode 100644 index 000000000..ad5080dc1 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.Wav2SleepStaging.rst @@ -0,0 +1,20 @@ +pyhealth.tasks.Wav2SleepStaging +================================ + +Overview +-------- + +Multi-class sleep stage classification from heterogeneous biosignal data. 
+Prepares biosignal and annotation data from multiple polysomnography +datasets for use with the Wav2Sleep model. Classifies each 30-second +epoch into one of four stages: Wake, Light Sleep (N1+N2), Deep Sleep (N3), +or REM. + +API Reference +------------- + +.. autoclass:: pyhealth.tasks.Wav2SleepStaging + :members: + :undoc-members: + :show-inheritance: + \ No newline at end of file diff --git a/examples/wav2sleep_sleep_staging_wav2sleep.py b/examples/wav2sleep_sleep_staging_wav2sleep.py new file mode 100644 index 000000000..fb511effa --- /dev/null +++ b/examples/wav2sleep_sleep_staging_wav2sleep.py @@ -0,0 +1,139 @@ +from pyhealth.datasets import Wav2SleepDataset +from pyhealth.datasets import get_dataloader, split_by_sample +from pyhealth.tasks import Wav2SleepStaging +from pyhealth.models import Wav2Sleep +from pyhealth.trainer import Trainer + +import os +import tempfile +import numpy as np +import mne +import xml.etree.ElementTree as ET + + +def create_mock_data(tmp_dir: str, n_patients: int = 8) -> str: + """Generate synthetic PSG data matching the Wav2SleepDataset structure. + + Writes fake EDF + XML annotation pairs so the example runs without + requiring access to restricted datasets (SHHS, MESA, etc.). + Uses the SHHS dataset structure: + shhs/polysomnography/edfs/shhs1/{patient_id}.edf + shhs/polysomnography/annotations-events-profusion/shhs1/{patient_id}-profusion.xml + """ + sfreq = 128 + duration_s = 30 * 10 # 10 epochs of 30s + n_samples = int(sfreq * duration_s) + ch_names = ["ECG", "THOR RES", "ABDO RES"] + ch_types = ["ecg", "bio", "bio"] + info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) + + edf_dir = os.path.join(tmp_dir, "shhs", "polysomnography", "edfs", "shhs1") + label_dir = os.path.join( + tmp_dir, "shhs", "polysomnography", "annotations-events-profusion", "shhs1" + ) + os.makedirs(edf_dir, exist_ok=True) + os.makedirs(label_dir, exist_ok=True) + + for i in range(n_patients): + patient_id = f"shhs1-{i:06d}" + + # Fake EDF + data = np.random.randn(len(ch_names), n_samples).astype(np.float32) + raw = mne.io.RawArray(data, info, verbose=False) + edf_path = os.path.join(edf_dir, f"{patient_id}.edf") + mne.export.export_raw(edf_path, raw, fmt="edf", verbose=False) + + # Fake XML annotation (10 epochs, random stages 0-5, no -1 unscored) + root_el = ET.Element("CMPStudyConfig") + staging_el = ET.SubElement(root_el, "SleepStages") + for stage in np.random.choice([0, 1, 2, 3, 5], size=10): + el = ET.SubElement(staging_el, "SleepStage") + el.text = str(stage) + xml_path = os.path.join(label_dir, f"{patient_id}-profusion.xml") + ET.ElementTree(root_el).write(xml_path) + + return tmp_dir + + +def example(root: str = None): + """Demonstrate the full wav2sleep pipeline integrated with PyHealth. + + Runs Dataset → Task → Model → Trainer using either real PSG data from + sleepdata.org or synthetic mock data if no root is provided. + + Args: + root: Path to a directory containing PSG data in the Wav2SleepDataset + structure. If None, synthetic mock data is generated automatically. + To obtain real data, complete a Data Use Agreement at sleepdata.org. + """ + + if root is None: + tmp_dir = tempfile.mkdtemp() + root = create_mock_data(tmp_dir) + print(f"No --root provided. 
Running on synthetic mock data at {tmp_dir}") + + dataset = Wav2SleepDataset(root) + task = Wav2SleepStaging() + + samples = dataset.set_task(task) + train_dataset, val_dataset, test_dataset = split_by_sample( + samples, [0.5, 0.25, 0.25] + ) + train_loader = get_dataloader(train_dataset, batch_size=2, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=2, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=2, shuffle=False) + + wav2sleep = Wav2Sleep(samples) + + trainer = Trainer(model=wav2sleep) + trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, epochs=1) + trainer.evaluate(test_loader) + + +def demo_ablation(root: str = None): + """Demonstrate the wav2sleep stochastic signal masking ablation. + + Replicates the ablation study from Section 4.3 of the wav2sleep paper, + which investigates model robustness when signals are randomly dropped + during training. This is controlled via custom per-signal masking + probabilities passed to Wav2Sleep. + + Args: + root: Path to a directory containing PSG data in the Wav2SleepDataset + structure. If None, synthetic mock data is generated automatically. + To obtain real data, complete a Data Use Agreement at sleepdata.org. + """ + + if root is None: + tmp_dir = tempfile.mkdtemp() + root = create_mock_data(tmp_dir) + print(f"No --root provided. Running on synthetic mock data at {tmp_dir}") + + dataset = Wav2SleepDataset(root) + task = Wav2SleepStaging() + + samples = dataset.set_task(task) + train_dataset, val_dataset, test_dataset = split_by_sample( + samples, [0.5, 0.25, 0.25] + ) + train_loader = get_dataloader(train_dataset, batch_size=2, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=2, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=2, shuffle=False) + + # ABLATION Create a custom set of masking probabilities + mask_probabilities = {"ECG": 0.5, "PPG": 0.5, "THX": 0.5, "ABD": 0.5} + + wav2sleep = Wav2Sleep(samples, stochastic_mask_probabilities=mask_probabilities) + + trainer = Trainer(model=wav2sleep) + trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, epochs=1) + trainer.evaluate(test_loader) + + +if __name__ == "__main__": + # if you have Data Access from sleepdata.org, you can download and use the datasets + root = "../../full_sample_PSG" + # otherwise, rely on mock data + root = root if os.path.isdir(root) else None + example(root) + demo_ablation(root) diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 50b1b3887..7e0a1ff11 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -90,4 +90,5 @@ def __init__(self, *args, **kwargs): load_processors, save_processors, ) +from .wav2sleep import Wav2SleepDataset from .collate import collate_temporal diff --git a/pyhealth/datasets/configs/wav2sleep.yaml b/pyhealth/datasets/configs/wav2sleep.yaml new file mode 100644 index 000000000..4a45700a7 --- /dev/null +++ b/pyhealth/datasets/configs/wav2sleep.yaml @@ -0,0 +1,10 @@ +version: "1.0" +tables: + wav2sleep: + file_path: "wav2sleep-metadata.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - source_dataset + - edf_path + - label_path \ No newline at end of file diff --git a/pyhealth/datasets/wav2sleep.py b/pyhealth/datasets/wav2sleep.py new file mode 100644 index 000000000..80aefde11 --- /dev/null +++ b/pyhealth/datasets/wav2sleep.py @@ -0,0 +1,203 @@ +""" +Author(s): Bronze Frazer +NetID(s): bfrazer2 +Paper: wav2sleep: A Unified Multi-Modal Approach to Sleep 
Stage Classification from Physiological Signals +Link: https://arxiv.org/abs/2411.04644 +Desc: PyHealth Dataset for the collection of 7 datasets used to train wav2sleep +""" + +import logging +import os +import pandas as pd +from pathlib import Path +from typing import Optional + +from pyhealth.datasets import BaseDataset + +logger = logging.getLogger(__name__) + + +class Wav2SleepDataset(BaseDataset): + """Unified dataset of PSG recordings (EDF and annotation files) + + Spans 7 datasets hosted on sleepdata.org that are used in wav2sleep: A Unified Multi-Modal Approach to Sleep Stage Classification from Physiological Signals (https://arxiv.org/abs/2411.04644): + SHHS, MESA, WSC, CHAT, CFS, CCSHS, and MROS. + + Signal availability varies by source dataset: + - ECG, THX, ABD: available in all datasets + - PPG: available in MESA, CHAT, CFS, CCSHS only + + Note: + A Data Use Agreement must be completed via sleepdata.org + + Here are the steps required to download the raw data: + 1. Fill out a Data Use Agreement on sleepdata.org + 2. Receive a Data Access Token (sleepdata.org/token) + 3. Use the nsrr gem tool (https://github.com/nsrr/nsrr-gem) + + Once you have your token... + Create and enter the directory you want to use as the root: + mkdir PSG_root + cd PSG_root + + Then download each dataset using NSRR Ruby Gem (https://github.com/nsrr/nsrr-gem) + using the following command structures: + For SHHS, MESA, CHAT, CFS, CCSHS, and MROS: + nsrr download {dataset}/polysomnography/edfs --fast + nsrr download {dataset}/polysomnography/annotations-events-profusion --fast + For WSC: + nsrr download wsc/polysomnography --fast + + The resulting structure will be: + + PSG_root + ├── ccshs + │ └── polysomnography + │ ├── annotations-events-profusion + │ └── edfs + ├── cfs + │ └── polysomnography + │ ├── annotations-events-profusion + │ └── edfs + ├── chat + │ └── polysomnography + │ ├── annotations-events-profusion + │ │ ├── baseline + │ │ └── followup + │ | └── nonrandomized + │ └── edfs + │ ├── baseline + │ └── followup + │ └── nonrandomized + ├── mesa + │ └── polysomnography + │ ├── annotations-events-profusion + │ └── edfs + ├── mros + │ └── polysomnography + │ ├── annotations-events-profusion + │ │ ├── visit1 + │ │ └── visit2 + │ └── edfs + │ ├── visit1 + │ └── visit2 + ├── shhs + │ └── polysomnography + │ ├── annotations-events-profusion + │ │ ├── shhs1 + │ │ └── shhs2 + │ └── edfs + │ ├── shhs1 + │ └── shhs2 + └── wsc + └── polysomnography + + Args: + root: root directory containing one subdirectory per source dataset + config_path: optional path to YAML config, defaults to wav2sleep.yaml + + Examples: + >>> dataset = Wav2SleepDataset(root = "path/to/root") + >>> dataset.stats() + """ + + def __init__(self, root: str, config_path: Optional[str] = None) -> None: + # Validate root + if not os.path.isdir(root): + raise FileNotFoundError(f"Root directory not found: {root}") + + if config_path is None: + logger.info("No config path provided, using default config") + config_path = Path(__file__).parent / "configs" / "wav2sleep.yaml" + + # Prepare metadata file if it does not already exist + metadata_file = os.path.join(root, "wav2sleep-metadata.csv") + if not os.path.exists(metadata_file): + logger.info("Preparing Wav2Sleep metadata...") + self.prepare_metadata(root) + + super().__init__( + root=root, + tables=["wav2sleep"], + dataset_name="wav2sleep", + config_path=config_path, + ) + + def prepare_metadata(self, root: str) -> None: + """Prepares a metadata CSV file that outlines the locations of + EDF 
and label files across all datasets found in root.
+
+        Args:
+            root: root directory containing one subdirectory per source dataset
+        """
+        rows = []
+
+        for dataset_dir in Path(root).iterdir():
+            if not dataset_dir.is_dir():
+                continue
+
+            source_dataset = dataset_dir.name
+            logger.info(f"Processing {source_dataset}...")
+
+            for edf_dir, label_dir in self.get_edf_and_label_dirs(dataset_dir):
+                for edf_file in edf_dir.glob("*.edf"):
+                    patient_id = edf_file.stem
+
+                    label_file_extension = (
+                        ".stg.txt" if source_dataset == "wsc" else "-profusion.xml"
+                    )
+
+                    label_file = label_dir / f"{patient_id}{label_file_extension}"
+                    if not label_file.exists():
+                        logger.warning(
+                            f"Label file not found for "
+                            f"{patient_id} in {source_dataset}, skipping"
+                        )
+                        continue
+
+                    rows.append(
+                        {
+                            "patient_id": patient_id,
+                            "source_dataset": source_dataset,
+                            "edf_path": str(edf_file),
+                            "label_path": str(label_file),
+                        }
+                    )
+
+        output_path = Path(root) / "wav2sleep-metadata.csv"
+        pd.DataFrame(rows).to_csv(output_path, index=False)
+        logger.info(f"Metadata saved to {output_path}")
+
+    def get_edf_and_label_dirs(self, dataset_dir: Path) -> list[tuple[Path, Path]]:
+        """Retrieves the EDF and label directories for a given dataset.
+
+        Handles datasets that have an extra subdirectory layer (e.g. SHHS).
+
+        Args:
+            dataset_dir: path to the dataset directory (e.g. root/shhs)
+
+        Returns:
+            list[tuple[Path, Path]]: A list of (edf_dir, label_dir) pairs,
+            one per subdirectory if subdirectories exist, otherwise a single pair.
+        """
+
+        if dataset_dir.name == "wsc":
+            edf_dir = dataset_dir / "polysomnography"
+            label_dir = (
+                dataset_dir / "polysomnography"
+            )  # the annotations for WSC are not in a separate directory
+        else:
+            edf_dir = dataset_dir / "polysomnography" / "edfs"
+            label_dir = dataset_dir / "polysomnography" / "annotations-events-profusion"
+
+        subdirs = [d for d in edf_dir.iterdir() if d.is_dir()]
+        if subdirs:
+            return [(edf_dir / d.name, label_dir / d.name) for d in sorted(subdirs)]
+
+        return [(edf_dir, label_dir)]
+
+
+if __name__ == "__main__":
+    # ../../../full_sample_PSG/ has one edf and annotation per dataset (currently not per subdir)
+    dataset = Wav2SleepDataset(root="../../../full_sample_PSG/")
+    dataset.stats()
diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py
index 4c168d3e3..f5e1871b3 100644
--- a/pyhealth/models/__init__.py
+++ b/pyhealth/models/__init__.py
@@ -45,4 +45,5 @@
 from .sdoh import SdohClassifier
 from .medlink import MedLink
 from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding
-from .califorest import CaliForest
\ No newline at end of file
+from .califorest import CaliForest
+from .wav2sleep import Wav2Sleep
\ No newline at end of file
diff --git a/pyhealth/models/wav2sleep.py b/pyhealth/models/wav2sleep.py
new file mode 100644
index 000000000..5270e62f4
--- /dev/null
+++ b/pyhealth/models/wav2sleep.py
@@ -0,0 +1,528 @@
+"""
+Author(s): Bronze Frazer
+NetID(s): bfrazer2
+Paper: wav2sleep: A Unified Multi-Modal Approach to Sleep Stage Classification from Physiological Signals
+Link: https://arxiv.org/abs/2411.04644
+Desc: PyHealth Model implementation of wav2sleep for sleep stage classification
+"""
+
+from typing import Dict, List, Optional
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+from pyhealth.datasets import SampleDataset
+from pyhealth.models.base_model import BaseModel
+
+# Global hyperparameters (as used in the paper)
+FEATURE_DIM = 128
+ACTIVATION_FUNCTION = nn.GELU()
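+# Worked sampling arithmetic (illustrative, derived from the preprocessing
+# constants used by the staging task): recordings are padded or truncated to
+# T = 1200 epochs of 30 s each. ECG/PPG keep 1024 samples per epoch
+# (1024 / 30 ≈ 34.1 Hz) and THX/ABD keep 256 samples per epoch (256 / 30 ≈ 8.5 Hz),
+# so a full ECG input holds 1200 * 1024 = 1,228,800 samples and a THX input
+# holds 1200 * 256 = 307,200 samples.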
+DROPOUT_RATE = 0.1 + + +class ResidualBlock(nn.Module): + """Residual Convolution Block to encode a signal + + Args: + c_in: Number of input channels + c_out: Number of output channels + pool_size: downsampling factor + """ + + def __init__(self, c_in: int, c_out: int, pool_size: int = 2) -> None: + super().__init__() + + def create_conv_block( + input_dim: int, output_dim: int, kernel_size: int = 3 + ) -> nn.Sequential: + """Create a Convolution Block + + Args: + input_dim: Dimension of the input + output_dim: Dimension of the output + kernel_size: Size of the convolutional kernel + + Returns: + nn.Sequential: A convolutional block with instance normalization + """ + pad = kernel_size // 2 + return nn.Sequential( + nn.Conv1d(input_dim, output_dim, kernel_size=kernel_size, padding=pad), + nn.InstanceNorm1d(output_dim), + ACTIVATION_FUNCTION, + nn.Dropout(DROPOUT_RATE), + ) + + self.conv1 = create_conv_block(c_in, c_out) + self.conv2 = create_conv_block(c_out, c_out) + self.conv3 = create_conv_block(c_out, c_out) + self.pool = nn.MaxPool1d(pool_size) + self.skip = nn.Conv1d(c_in, c_out, 1) if c_in != c_out else nn.Identity() + self.activation = ACTIVATION_FUNCTION + self.dropout = nn.Dropout(DROPOUT_RATE) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass of the residual block + + Args: + x: Tensor input; shape = (batch_size, c_in, length) + + Returns: + Tensor: Output tensor; shape = (batch_size, c_out, length//2) + """ + residual = self.skip(x) + out = self.conv1(x) + out = self.conv2(out) + out = self.conv3(out) + out += residual + out = self.pool(out) + out = self.activation(out) + out = self.dropout(out) + return out + + +class SignalEncoder(nn.Module): + """Architecture for a Signal Encoder as described in Section 3.1 + + Turns a raw input signal into a per-epoch feature vector sequence. + + A Signal Encoder consists of a stack of residual layers. + Each layer contains three convolutional layers followed by a + max pooling layer to downsample the signal by a factor of 2. + Residual layers are followed by a reshape operation and a + time-distributed dense layer to produce the sequence of feature vectors. + + Args: + signal_sample_rate: The original sample rate (Hz) used when measuring the signal + + Raises: + ValueError: If signal_sample_rate is not 1024 or 256. + """ + + def __init__(self, signal_sample_rate: int) -> None: + super().__init__() + + self.T = 1200 # total epochs after preprocessing + self.signal_sample_rate = signal_sample_rate + + if signal_sample_rate == 1024: + channels = [1, 16, 16, 32, 32, 64, 64, 128, 128] + elif signal_sample_rate == 256: + channels = [1, 16, 32, 64, 64, 128, 128] + else: + raise ValueError( + f"{signal_sample_rate} is not a valid resample rate. 
" + "Channel progression cannot be assigned" + ) + + channel_progression = list(zip(channels, channels[1:])) + + blocks = [ + layer + for input_dim, output_dim in channel_progression + for layer in ( + ResidualBlock(c_in=input_dim, c_out=output_dim), + nn.InstanceNorm1d(output_dim, affine=True), + ) + ] + + self.encoder = nn.Sequential(*blocks) + self.epoch_dim = ( + channels[-1] * 4 + ) # Flattened dimension for time-distributed dense layer + self.dense = nn.Linear(self.epoch_dim, FEATURE_DIM) + self.activation = ACTIVATION_FUNCTION + + def forward(self, x: Tensor) -> Tensor: + """Encode biosignal to a sequence of features + + Args: + x: A raw signal Tensor; shape = (batch_size, 1, signal_measurements) + + Returns: + Tensor: A sequence of per-epoch feature vectors; shape = (batch_size, T, feature_dim) + """ + batch_size = x.shape[0] + + # Split into epochs — treat each epoch independently (time-distributed) + x = x.view( + batch_size * self.T, 1, self.signal_sample_rate + ) # (batch_size*T, 1, k) + + z = self.encoder(x) # (batch_size*T, feature_dim, 4) + # Flatten spatial dim for the dense layer + z = z.view(batch_size * self.T, -1) # (batch_size*T, 512) + # Time-distributed dense: same weights applied to every epoch + z = self.activation(self.dense(z)) + # Reassemble the time axis + z = z.view(batch_size, self.T, FEATURE_DIM) + return z + + +class EpochMixer(nn.Module): + """Architecture for the Epoch Mixer as described in Section 3.2 + + Provides a unified representation of sleep epochs. + Uses a transformer encoder with a learnable CLS vector + that fuses information among a set of modalities. + + Args: + num_transformer_layers: The number of transformer layers to use + hidden_dimension: The hidden dimension of a transformer layer + num_attention_heads: The number of attention heads to use for a transformer layer + modalities: The list of selected modalities (order of modalities matters) + stochastic_mask_probabilities: Probability that a modality will be masked during training + """ + + def __init__( + self, + num_transformer_layers: int = 2, + hidden_dimension: int = 512, + num_attention_heads: int = 8, + modalities: Optional[List[str]] = None, + stochastic_mask_probabilities: Optional[Dict[str, float]] = None, + ) -> None: + super().__init__() + if modalities is None: + modalities = ["ECG", "PPG", "THX", "ABD"] + if stochastic_mask_probabilities is None: + stochastic_mask_probabilities = { + "ECG": 0.5, + "PPG": 0.1, + "THX": 0.7, + "ABD": 0.7, + } + encoder_layer = nn.TransformerEncoderLayer( + d_model=FEATURE_DIM, + nhead=num_attention_heads, + dim_feedforward=hidden_dimension, + dropout=DROPOUT_RATE, + activation=ACTIVATION_FUNCTION, + batch_first=True, + ) + self.transformer = nn.TransformerEncoder(encoder_layer, num_transformer_layers) + self.cls_token = nn.Parameter(torch.zeros(1, 1, FEATURE_DIM)) + self.stochastic_mask_probabilities = stochastic_mask_probabilities + self.modalities = modalities + + def _build_attention_mask( + self, batch: int, T: int, availability_mask: Tensor + ) -> Tensor: + """Builds an attention mask for the transformer + + Builds a stochastic mask first, then joins it with `availability_mask`. + Ensures that the CLS token remains unmasked, and that + at least one modality is available for the transformer. 
+ + Args: + batch: The batch size (number of sequences of the input batch) + T: The total number of sleep epochs + availability_mask: Mask indicating which modalities are available + + Returns: + Tensor: The complete mask to pass to the transformer + """ + + device = self.cls_token.device + + probs = torch.tensor( + [self.stochastic_mask_probabilities[m] for m in self.modalities] + ) + probs = probs.unsqueeze(0).expand(batch, -1) + stochastic_mask = torch.bernoulli(probs).bool() + + # Combine availability and stochastics masks + complete_mask = availability_mask | stochastic_mask + + # Guarantee at least one modality is visible per recording + all_masked = complete_mask.all(dim=1) + if all_masked.any(): # if all modalities get masked + # unmask the first available modality (index 0) + complete_mask[all_masked, 0] = False + + # Expand across T epochs and fold into batch dimension + complete_mask = ( + complete_mask.unsqueeze(1).expand(-1, T, -1).reshape(batch * T, -1) + ) + + # Prepend False for CLS — never masked + cls_mask = torch.zeros(batch * T, 1, dtype=torch.bool, device=device) + + mask = torch.cat([cls_mask, complete_mask], dim=1) + return mask + + def forward(self, x: Tensor, availability_mask: Tensor) -> Tensor: + """Fuse modalities into one unified representation of sleep epoch sequences + + Args: + x: Stacked modality encodings per epoch; + shape = (batch_size, T, num_modalities, feature_dim) + availability_mask: Mask indicating which modalities are available + + Returns: + Tensor: A unified feature sequence; shape = (batch_size, T, feature_dim) + """ + batch_size = x.shape[0] + T = x.shape[1] + + x = x.reshape(batch_size * T, -1, FEATURE_DIM) + cls_tokens = self.cls_token.expand(batch_size * T, 1, FEATURE_DIM) + x = torch.cat( + [cls_tokens, x], dim=1 + ) # (batch*T, num_modalities + 1, feature_dim) + + # Create a per-recording mask + mask = ( + self._build_attention_mask(batch_size, T, availability_mask) + if self.training + else None + ) # (batch*T, num_modalities + 1, feature_dim) + + out = self.transformer(x, src_key_padding_mask=mask) + + # Slice CLS position — the unified summary for each epoch + z = out[:, 0, :] # (batch*T, feature_dim) + z = z.reshape(batch_size, T, FEATURE_DIM) + + return z + + +class SequenceMixer(nn.Module): + """Architecture for the Sequence Mixer as described in Section 3.3 + + Captures temporal dependencies among encoded sequences using a stack + of dilated convolutional blocks with increasing dilation factors. + + Args: + dilated_blocks: Number of dilated blocks to use + kernel_size: kernel_size for the dilated blocks + """ + + def __init__(self, dilated_blocks: int = 2, kernel_size: int = 7) -> None: + super().__init__() + dilations = [1, 2, 4, 8, 16, 32] + + blocks = [ + DilatedConvBlock(d, kernel_size) + for _ in range(dilated_blocks) + for d in dilations + ] + + self.dilated_cnns = nn.Sequential(*blocks) + + def forward(self, x: Tensor) -> Tensor: + """Processes a sequence of feature vectors through the mixer + + Applies 1D dilated convolutions over the time dimension to capture + long‑range temporal dependencies, while preserving the sequence + length and feature dimension. 
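+        Worked receptive-field estimate (assuming the default configuration of
+        two dilation stacks with kernel_size=7 and dilations 1, 2, 4, 8, 16, 32):
+        each stack adds (7 - 1) * (1 + 2 + 4 + 8 + 16 + 32) = 378 epochs of
+        context, so the two stacks together span roughly 1 + 2 * 378 = 757
+        epochs, i.e. about 757 * 30 s ≈ 6.3 hours of the night around each
+        prediction.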
+
+        Args:
+            x: The unified sequence of feature vectors; shape = (batch_size, T, feature_dim)
+
+        Returns:
+            Tensor: The transformed sequence; shape = (batch_size, T, feature_dim)
+        """
+
+        out = x.transpose(1, 2)
+        out = self.dilated_cnns(out)
+        out = out.transpose(1, 2)
+        return out
+
+
+class DilatedConvBlock(nn.Module):
+    """Dilated Convolution Block
+
+    Applies a 1D dilated convolution followed by layer normalization,
+    an activation function, and dropout, with a residual connection
+    that adds the input back to the output. Used to capture long-range
+    temporal dependencies while preserving the sequence length.
+
+    Args:
+        dilation: Dilation factor for convolution
+        kernel_size: Size of the convolutional kernel
+    """
+
+    def __init__(self, dilation: int, kernel_size: int = 7) -> None:
+        super().__init__()
+        padding = (kernel_size - 1) * dilation // 2
+        self.conv = nn.Conv1d(
+            FEATURE_DIM, FEATURE_DIM, kernel_size, dilation=dilation, padding=padding
+        )
+        self.norm = nn.LayerNorm(FEATURE_DIM)
+        self.activation = ACTIVATION_FUNCTION
+        self.dropout = nn.Dropout(DROPOUT_RATE)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass of the dilated convolution block
+
+        Args:
+            x: Input sequence of feature vectors; shape = (batch_size, feature_dim, T)
+
+        Returns:
+            Tensor: Output sequence of feature vectors; shape = (batch_size, feature_dim, T)
+        """
+        out = self.conv(x)
+        out = out.transpose(1, 2)
+        out = self.norm(out)
+        out = out.transpose(1, 2)
+        out = self.activation(out)
+        out = self.dropout(out)
+        out += x
+        return out
+
+
+class Wav2Sleep(BaseModel):
+    """The wav2sleep model
+
+    Classifies sleep stage sequences from sets of time-series biosignals.
+    A trained model can be applied to any subset of the signal modalities seen during training.
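+    Shape flow through the components listed below (sketch, assuming the
+    default modalities, T = 1200 epochs and feature_dim = 128): each raw
+    signal of shape (batch, T * k) is encoded to (batch, 1200, 128), the
+    modalities are stacked to (batch, 1200, 4, 128), fused by the Epoch Mixer
+    back to (batch, 1200, 128), passed through the Sequence Mixer with the
+    shape unchanged, and projected by the classifier to (batch, 1200, 4)
+    class scores.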
+ + The model consists of + Signal Encoders for each modality + An Epoch Mixer to fuse cross-modal information for each sleep epoch + A Sequence Mixer to mix temporal information + + Args: + dataset: The dataset used to train the model + modalities: The list of modalities to train the model with + + Example: + >>> from pyhealth.datasets import Wav2SleepDataset + >>> from pyhealth.tasks import Wav2SleepStaging + >>> wav2sleep_dataset = Wav2SleepDataset(root = "path/to/root") + >>> task = Wav2SleepStaging() + >>> samples = wav2sleep_dataset.set_task(task) + >>> wav2sleep_model = Wav2Sleep(samples) # train with all modalities (default) + >>> train_loader = get_dataloader(samples, batch_size=2, shuffle=False) + >>> data_batch = next(iter(train_loader)) + >>> output = wav2sleep_model(**data_batch) + >>> print(output) + """ + + def __init__( + self, + dataset: SampleDataset, + modalities: Optional[List[str]] = None, + stochastic_mask_probabilities: Optional[Dict[str, float]] = None, + ) -> None: + if modalities is None: + modalities = ["ECG", "PPG", "THX", "ABD"] + + # Validate stochastic_mask_probabilities: all modalities present, values are numbers in [0, 1] + if stochastic_mask_probabilities is not None: + for modality in modalities: + if modality not in stochastic_mask_probabilities: + raise ValueError( + f"Missing mask probability for modality '{modality}'" + ) + probability = stochastic_mask_probabilities[modality] + if probability is None: + raise ValueError("The mask probability must not be None") + if (probability < 0) or (probability > 1): + raise ValueError( + f"The mask probability must be in [0, 1], got {probability}" + ) + + super(Wav2Sleep, self).__init__(dataset=dataset) + + # signal_type : resample_rate + self.all_modalities = {"ECG": 1024, "PPG": 1024, "THX": 256, "ABD": 256} + + self.selected_modalities = { + k: self.all_modalities[k] for k in modalities if k in self.all_modalities + } + + # Initialize Signal Encoders for each modality + self.signal_encoders = nn.ModuleDict( + { + signal_type: SignalEncoder(signal_sample_rate=resample_rate) + for signal_type, resample_rate in self.selected_modalities.items() + } + ) + # Initialize Epoch Mixer to learn attention between the modalities at each epoch + self.epoch_mixer = EpochMixer( + modalities=modalities, + stochastic_mask_probabilities=stochastic_mask_probabilities, + ) + # Initialize Sequence Mixer to mix temporal information and output predicted sleep stages + self.sequence_mixer = SequenceMixer() + + # 4 classes total (Wake, Light Sleep, Deep Sleep, REM) + self.num_classes = 4 + self.classifier = nn.Linear( + in_features=FEATURE_DIM, out_features=self.num_classes + ) + + def forward(self, **kwargs) -> Dict[str, Tensor]: + """Forward pass for the Wav2Sleep model + + Transforms a set of raw polysomnography biosignals into sleep-stage predictions. 
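+        Example batch (synthetic shapes only; `model` is assumed to be an
+        already constructed Wav2Sleep instance):
+
+            >>> batch = {
+            ...     "ECG": torch.randn(2, 1200 * 1024),
+            ...     "PPG": torch.randn(2, 1200 * 1024),
+            ...     "THX": torch.randn(2, 1200 * 256),
+            ...     "ABD": torch.randn(2, 1200 * 256),
+            ...     "availability_mask": torch.zeros(2, 4),  # 0 = signal present
+            ...     "stages": torch.zeros(2, 1200, dtype=torch.long),
+            ... }
+            >>> out = model(**batch)
+            >>> out["y_prob"].shape  # (2, 1200, 4); "loss" is included because stages was given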
+
+        Args:
+            **kwargs: Batch dictionary containing:
+                - signals (Dict[str, Tensor]):
+                    Input modality tensors; shape = (batch, signal_length)
+                - availability_mask (Tensor):
+                    Indicates unavailable modalities; shape = (batch, num_modalities)
+                - stages (Tensor, optional): Sleep stage labels; shape = (batch, T)
+
+        Returns:
+            Dict[str, Tensor]:
+                y_prob: predicted class probabilities (softmax over the logits); shape = (batch, T, num_classes)
+                y_hat: sleep stage predictions; shape = (batch, T)
+                loss: cross-entropy loss (if ground-truth labels `stages` were provided)
+        """
+
+        signals = {m: kwargs[m] for m in self.selected_modalities.keys()}
+        availability_mask = kwargs["availability_mask"]
+        stages = kwargs.get("stages", None)
+
+        encoded_signals = torch.stack(
+            [
+                self.signal_encoders[m](signals[m])
+                for m in self.selected_modalities.keys()
+            ],
+            dim=2,
+        )  # → (batch, T, num_modalities, feature_dim)
+
+        selected_indices = [
+            list(self.all_modalities.keys()).index(m)
+            for m in self.selected_modalities.keys()
+        ]
+        filtered_mask = availability_mask[:, selected_indices].bool()
+
+        mixed = self.epoch_mixer(encoded_signals, filtered_mask)
+
+        logits = self.sequence_mixer(mixed)
+
+        y_T = self.classifier(logits)
+
+        y_prob = torch.softmax(y_T, dim=-1)
+        output = {"y_prob": y_prob, "y_hat": y_prob.argmax(dim=-1)}
+        if stages is not None:
+            loss = F.cross_entropy(
+                y_T.view(-1, self.num_classes), stages.view(-1).long(), ignore_index=-1
+            )
+            output["loss"] = loss
+        return output
+
+
+if __name__ == "__main__":
+    batch_size = 2
+    T = 1200
+
+    from pyhealth.datasets import Wav2SleepDataset, get_dataloader
+    from pyhealth.tasks import Wav2SleepStaging
+
+    wav2sleep_dataset = Wav2SleepDataset(root="../../../full_sample_PSG/")
+    task = Wav2SleepStaging()
+    samples = wav2sleep_dataset.set_task(task)
+    wav2sleep_model = Wav2Sleep(samples)
+
+    train_loader = get_dataloader(samples, batch_size=2, shuffle=False)
+    data_batch = next(iter(train_loader))
+
+    output = wav2sleep_model(**data_batch)
+    print(output)
diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py
index a32618f9c..11896094e 100644
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -67,3 +67,4 @@
     VariantClassificationClinVar,
 )
 from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task
+from .wav2sleep_staging import Wav2SleepStaging
\ No newline at end of file
diff --git a/pyhealth/tasks/wav2sleep_staging.py b/pyhealth/tasks/wav2sleep_staging.py
new file mode 100644
index 000000000..2face22ff
--- /dev/null
+++ b/pyhealth/tasks/wav2sleep_staging.py
@@ -0,0 +1,294 @@
+"""
+Author(s): Bronze Frazer
+NetID(s): bfrazer2
+Paper: wav2sleep: A Unified Multi-Modal Approach to Sleep Stage Classification from Physiological Signals
+Link: https://arxiv.org/abs/2411.04644
+Desc: PyHealth Task for sleep stage classification using wav2sleep
+"""
+
+import logging
+from typing import Any, Dict
+
+from pyhealth.data import Patient
+from pyhealth.tasks import BaseTask
+
+import mne
+import numpy as np
+import pandas as pd
+import scipy.signal
+import xml.etree.ElementTree as ET
+
+logger = logging.getLogger(__name__)
+
+
+class Wav2SleepStaging(BaseTask):
+    """Multi-class sleep stage classification from heterogeneous biosignal data
+
+    This task prepares biosignal and annotation data from multiple
+    polysomnography datasets.
+
+    Attributes:
+        task_name: The name of the task.
Set to "Wav2SleepStaging" + input_schema: Schema for the task input + output_schema: Schema for the task output + XML_STAGE_MAP: A label map for annotaion-events-profusion files to expected labels + WSC_STAGE_MAP: A label map for WSC dataset annotations to expected labels + CHANNEL_MAPS: A map to extract the expected signal names as recorded in the EDF for a dataset + + Examples: + >>> from pyhealth.datasets import Wav2SleepDataset + >>> from pyhealth.tasks import Wav2SleepStaging + >>> wav2sleep_dataset = Wav2SleepDataset(root = "path/to/root") + >>> task = Wav2SleepStaging() + >>> samples = wav2sleep_dataset.set_task(task) + """ + + task_name: str = "Wav2SleepStaging" + input_schema: Dict[str, str] = { + "ECG": "tensor", + "PPG": "tensor", + "THX": "tensor", + "ABD": "tensor", + "availability_mask": "tensor", + } + output_schema: Dict[str, str] = {"stages": "tensor"} + + # 0=Wake, 1=Light(N1+N2), 2=Deep(N3), 3=REM, -1=Unscored + XML_STAGE_MAP: Dict[int, int] = {0: 0, 1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 9: -1} + WSC_STAGE_MAP: Dict[int, int] = {0: 0, 1: 1, 2: 1, 3: 2, 5: 3, 7: -1} + + # Mappings for the different names of biosignal across the datasets + CHANNEL_MAPS: Dict[str, Dict[str, Any]] = { + "shhs": {"ECG": "ECG", "PPG": None, "THX": "THOR RES", "ABD": "ABDO RES"}, + "mesa": {"ECG": "EKG", "PPG": "Pleth", "THX": "Thor", "ABD": "Abdo"}, + "cfs": { + "ECG": "ECG1", + "PPG": "PlethWV", + "THX": "THOR EFFORT", + "ABD": "ABDO EFFORT", + }, + "chat": {"ECG": "ECG1", "PPG": "PlethNellcor", "THX": "Chest", "ABD": "ABD"}, + "mros": {"ECG": "ECG L", "PPG": None, "THX": "Thoracic", "ABD": "Abdominal"}, + "ccshs": { + "ECG": "ECG1", + "PPG": "PlethWV", + "THX": "THOR EFFORT", + "ABD": "ABDO EFFORT", + }, + "wsc": {"ECG": "ECG", "PPG": None, "THX": "thorax", "ABD": "abdomen"}, + } + + def __init__(self) -> None: + super().__init__() + + def __call__(self, patient: Patient) -> list[Dict[str, Any]]: + """Process patient polysomnography biosignals for sleep stage prediction + + Args: + patient: A patient object containing biosignals (ECG, PPG, THX, ABD), + a mask to indicate which signals are available, + and sleep stage annotations + + Returns: + list[Dict[str, Any]]: A list of samples for a patient, where each sample + is a dict containing the patient's id, their preprocessed biosignals (ECG, PPG, THX, ABD), + an availability mask for the signals, and ground-truth sleep stages + """ + + samples = [] + for event in patient.get_events(): + signals, availability_mask = self.load_signals( + event.edf_path, event.source_dataset + ) + stages = self.load_stages(event.label_path, event.source_dataset) + sample = { + "patient_id": patient.patient_id, + **signals, + "availability_mask": availability_mask, + "stages": stages, + } + samples.append(sample) + + return samples + + def load_signals( + self, edf_path: str, dataset: str + ) -> tuple[Dict[str, np.ndarray], np.ndarray]: + """Extract raw biosignals (EDF, PPG, THX, ABD) from and EDF file + + Args: + edf_path: Path to the EDF file + dataset: name of the source dataset (e.g. SHHS, CFS, ...) 
+ + Returns: + tuple[Dict[str, np.ndarray], np.ndarray]: + A dictionary of raw biosignals, and a boolean array indicating + which signals were not found in a recording + """ + signals = {"ECG": None, "PPG": None, "THX": None, "ABD": None} + + channel_map = self.CHANNEL_MAPS[dataset] + src_names = [src for src in channel_map.values() if src is not None] + + raw = mne.io.read_raw_edf( + edf_path, include=src_names, preload=True, verbose=False + ) + + src_to_canonical = { + src: sig for sig, src in channel_map.items() if src is not None + } + + data = raw.get_data() + + signals.update( + { + src_to_canonical[ch]: data[i] + for i, ch in enumerate(raw.ch_names) + if ch in src_to_canonical + } + ) + + original_sample_rate = raw.info["sfreq"] + availability_mask = [] + for signal_type, signal in signals.items(): + mask_value = signal is None + availability_mask.append(mask_value) + preprocessed_signal = self.preprocess_signal( + signal_type, signal, original_sample_rate + ) + signals[signal_type] = preprocessed_signal + + return signals, availability_mask + + def load_stages(self, annotation_path: str, dataset: str) -> np.ndarray: + """Parse ground-truth sleep stage labels from an annotation file + + The categories are: + 0 = Wake + 1 = Light(N1+N2) + 2 = Deep(N3) + 3 = REM + -1 = Unscored + + Args: + annotation_path: path to a sleep stage annotation file + dataset: name of the source dataset (e.g. SHHS, CFS, ...) + + Returns: + np.ndarray: An integer array of ground truth sleep stage cateogies + """ + + def load_stages_xml(xml_path: str) -> np.ndarray: + """Parse the XML annotation format""" + root = ET.parse(xml_path).getroot() + stages = [self.XML_STAGE_MAP[int(e.text)] for e in root.iter("SleepStage")] + return np.array(stages, dtype=np.int8) + + def load_stages_wsc(stg_path: str) -> np.ndarray: + """Parse the TXT annotation format (only used for the WSC dataset)""" + df = pd.read_csv(stg_path, sep="\t") + return ( + df["User-Defined Stage"].map(self.WSC_STAGE_MAP).to_numpy(dtype=np.int8) + ) + + STAGE_LOADERS = { + "shhs": load_stages_xml, + "mesa": load_stages_xml, + "cfs": load_stages_xml, + "chat": load_stages_xml, + "mros": load_stages_xml, + "ccshs": load_stages_xml, + "wsc": load_stages_wsc, + } + + stages = STAGE_LOADERS[dataset](annotation_path) + fixed_epoch_stages = self._pad_or_truncate( + stages, is_label=True, target_length=1200 + ) # T=1200 epochs + return fixed_epoch_stages + + def preprocess_signal( + self, signal_type: str, signal: np.ndarray, original_sample_rate: float + ) -> np.ndarray: + """Pre-processing for raw biosignal data + + The steps are outlined in section 4.1. 
of wav2sleep + Pad or truncate to 10 hours (T = 1200) + Resample each biosignal to target frequency + (~34Hz for ECG & PPG, ~8Hz for THX & ABD) + Apply unit normalisation + + Args: + signal_type: name of the biosignal to be preprocessed + signal: the raw biosignal + original_sample_rate: the sample rate (in Hertz) of the EDF recording + + Returns: + np.ndarray: A preprocessed biosignal + """ + T = 1200 # target number of epochs + seconds_per_epoch = 30 # epochs are 30 seconds each + + # k = total number of measurements to retain per epoch + if signal_type in ["ECG", "PPG"]: + k = 1024 + elif signal_type in ["ABD", "THX"]: + k = 256 + + target_raw_samples = T * seconds_per_epoch * original_sample_rate + target_output_samples = T * k + + # return a zeros array of target_output_samples if a signal is None + # (signal_type is not present in the dataset) + if signal is None: + return np.zeros(target_output_samples, dtype=np.float32) + + # Step 1: Pad or truncate to exactly 10 hours + signal = self._pad_or_truncate( + signal, is_label=False, target_length=int(target_raw_samples) + ) + # Step 2: Resample to target frequency + signal = scipy.signal.resample(signal, target_output_samples) + # Step 3: Unit normalization + signal = (signal - signal.mean()) / signal.std() + + return signal + + def _pad_or_truncate( + self, array: np.ndarray, is_label: bool, target_length: int + ) -> np.ndarray: + """Pad or truncate an array to a desired size + + Args: + array: an array to be padded/truncated + is_label: A flag indicating if `array` should be treated as a list of labels + target_length: desired length of the array + + Returns: + np.ndarray: A padded or truncated array + """ + if len(array) >= target_length: + return array[:target_length] + pad_length = target_length - len(array) + # zero-pad signal arrays, but pad label arrays with -1 (unscored) + constant_values = -1 if is_label else 0 + return np.pad( + array, (0, pad_length), mode="constant", constant_values=constant_values + ) + + +if __name__ == "__main__": + from pyhealth.datasets import Wav2SleepDataset + from pyhealth.tasks import Wav2SleepStaging + + wav2sleep_dataset = Wav2SleepDataset(root="../../../full_sample_PSG/") + task = Wav2SleepStaging() + sample_dataset = wav2sleep_dataset.set_task(task) + + # test the DataLoader and discover shape of inputs to model + # data loader + from pyhealth.datasets import get_dataloader + + train_loader = get_dataloader(sample_dataset, batch_size=7, shuffle=False) + data_batch = next(iter(train_loader)) + print(data_batch) diff --git a/tests/core/test_wav2sleep.py b/tests/core/test_wav2sleep.py new file mode 100644 index 000000000..ab0071883 --- /dev/null +++ b/tests/core/test_wav2sleep.py @@ -0,0 +1,708 @@ +"""Tests for Wav2SleepDataset and Wav2SleepStaging.""" + +import shutil +import tempfile +import unittest +import xml.etree.ElementTree as ET +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np +import pandas as pd +import torch +from torch import nn + +from pyhealth.models.base_model import BaseModel +from pyhealth.models.wav2sleep import ( + DilatedConvBlock, + EpochMixer, + FEATURE_DIM, + ResidualBlock, + SequenceMixer, + SignalEncoder, + Wav2Sleep, +) +from pyhealth.tasks.wav2sleep_staging import Wav2SleepStaging + + +# ───────────────────────────────────────────── +# Helpers +# ───────────────────────────────────────────── + + +def make_xml_annotation(path: Path, stages: list[int]) -> None: + """Write a minimal profusion-style XML annotation file.""" + root = 
ET.Element("CMPStudyConfig") + stage_container = ET.SubElement(root, "SleepStages") + for s in stages: + el = ET.SubElement(stage_container, "SleepStage") + el.text = str(s) + ET.ElementTree(root).write(str(path)) + + +def make_wsc_annotation(path: Path, stages: list[int]) -> None: + """Write a minimal WSC-style TSV annotation file.""" + df = pd.DataFrame({"User-Defined Stage": stages}) + df.to_csv(str(path), sep="\t", index=False) + + +def make_mock_raw(ch_names: list[str], sfreq: float = 256.0, n_samples: int = 7680): + """Return a mock MNE Raw object with synthetic data.""" + mock_raw = MagicMock() + mock_raw.ch_names = ch_names + mock_raw.info = {"sfreq": sfreq} + mock_raw.get_data.return_value = np.random.randn(len(ch_names), n_samples).astype( + np.float32 + ) + return mock_raw + + +def build_tmp_root(root: Path) -> None: + """ + Builds a minimal synthetic PSG root directory for dataset tests. + + shhs/ + polysomnography/ + edfs/shhs1/ + patient-001.edf (empty placeholder) + annotations-events-profusion/shhs1/ + patient-001-profusion.xml + + mesa/ + polysomnography/ + edfs/ + patient-002.edf + annotations-events-profusion/ + patient-002-profusion.xml + + wsc/ + polysomnography/ + patient-003.edf + patient-003.stg.txt + """ + # SHHS (has subdirectory layer) + shhs_edf = root / "shhs" / "polysomnography" / "edfs" / "shhs1" + shhs_lbl = ( + root / "shhs" / "polysomnography" / "annotations-events-profusion" / "shhs1" + ) + shhs_edf.mkdir(parents=True) + shhs_lbl.mkdir(parents=True) + (shhs_edf / "patient-001.edf").touch() + make_xml_annotation(shhs_lbl / "patient-001-profusion.xml", [0, 1, 2, 5, 9]) + + # MESA (flat layout, has PPG) + mesa_edf = root / "mesa" / "polysomnography" / "edfs" + mesa_lbl = root / "mesa" / "polysomnography" / "annotations-events-profusion" + mesa_edf.mkdir(parents=True) + mesa_lbl.mkdir(parents=True) + (mesa_edf / "patient-002.edf").touch() + make_xml_annotation(mesa_lbl / "patient-002-profusion.xml", [0, 0, 5, 3]) + + # WSC (flat layout, TSV annotations) + wsc_psg = root / "wsc" / "polysomnography" + wsc_psg.mkdir(parents=True) + (wsc_psg / "patient-003.edf").touch() + make_wsc_annotation(wsc_psg / "patient-003.stg.txt", [0, 1, 2, 5]) + + +# ───────────────────────────────────────────── +# Dataset Tests +# ───────────────────────────────────────────── + + +class TestWav2SleepDatasetMetadata(unittest.TestCase): + def setUp(self): + self.tmp_dir = tempfile.mkdtemp() + self.tmp_root = Path(self.tmp_dir) / "psg_root" + self.tmp_root.mkdir() + build_tmp_root(self.tmp_root) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + + def _make_dataset_instance(self): + """Return a Wav2SleepDataset with __init__ bypassed.""" + from pyhealth.datasets.wav2sleep import Wav2SleepDataset + + with patch.object(Wav2SleepDataset, "__init__", lambda self, **kw: None): + ds = Wav2SleepDataset.__new__(Wav2SleepDataset) + return ds + + def test_prepare_metadata_creates_csv(self): + """prepare_metadata should produce a CSV at root/wav2sleep-metadata.csv.""" + ds = self._make_dataset_instance() + ds.prepare_metadata(str(self.tmp_root)) + self.assertTrue((self.tmp_root / "wav2sleep-metadata.csv").exists()) + + def test_metadata_columns(self): + """CSV should have patient_id, source_dataset, edf_path, label_path.""" + ds = self._make_dataset_instance() + ds.prepare_metadata(str(self.tmp_root)) + df = pd.read_csv(self.tmp_root / "wav2sleep-metadata.csv") + self.assertTrue( + {"patient_id", "source_dataset", "edf_path", "label_path"}.issubset( + df.columns + ) + ) + + def 
test_metadata_patient_count(self): + """Should find exactly 3 patients across the 3 synthetic datasets.""" + ds = self._make_dataset_instance() + ds.prepare_metadata(str(self.tmp_root)) + df = pd.read_csv(self.tmp_root / "wav2sleep-metadata.csv") + self.assertEqual(len(df), 3) + + def test_metadata_skips_missing_label(self): + """If an EDF has no matching annotation, it should be skipped.""" + (self.tmp_root / "mesa" / "polysomnography" / "edfs" / "orphan.edf").touch() + ds = self._make_dataset_instance() + ds.prepare_metadata(str(self.tmp_root)) + df = pd.read_csv(self.tmp_root / "wav2sleep-metadata.csv") + self.assertNotIn("orphan", df["patient_id"].values) + + def test_get_edf_and_label_dirs_flat(self): + """Flat datasets (mesa) should return a single (edf_dir, label_dir) pair.""" + ds = self._make_dataset_instance() + pairs = ds.get_edf_and_label_dirs(self.tmp_root / "mesa") + self.assertEqual(len(pairs), 1) + + def test_get_edf_and_label_dirs_subdirs(self): + """Datasets with subdirectory layers (shhs) should return one pair per subdir.""" + ds = self._make_dataset_instance() + pairs = ds.get_edf_and_label_dirs(self.tmp_root / "shhs") + self.assertEqual(len(pairs), 1) + self.assertEqual(pairs[0][0].name, "shhs1") + + def test_invalid_root_raises(self): + """Passing a nonexistent root should raise FileNotFoundError.""" + from pyhealth.datasets.wav2sleep import Wav2SleepDataset + + with self.assertRaises(FileNotFoundError): + Wav2SleepDataset(root=str(self.tmp_root / "does_not_exist")) + + +# ───────────────────────────────────────────── +# Task Tests — annotation parsing +# ───────────────────────────────────────────── + + +class TestLoadStages(unittest.TestCase): + def setUp(self): + self.task = Wav2SleepStaging() + self.tmp_dir = tempfile.mkdtemp() + self.tmp_path = Path(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + + def test_xml_wake_maps_correctly(self): + """XML stage 0 → category 0 (Wake).""" + p = self.tmp_path / "p.xml" + make_xml_annotation(p, [0]) + result = self.task.load_stages(str(p), "shhs") + self.assertEqual(result[0], 0) + + def test_xml_n1_n2_map_to_light(self): + """XML stages 1 and 2 → category 1 (Light).""" + p = self.tmp_path / "p.xml" + make_xml_annotation(p, [1, 2]) + result = self.task.load_stages(str(p), "shhs") + self.assertEqual(result[0], 1) + self.assertEqual(result[1], 1) + + def test_xml_n3_maps_to_deep(self): + """XML stages 3 and 4 → category 2 (Deep).""" + p = self.tmp_path / "p.xml" + make_xml_annotation(p, [3, 4]) + result = self.task.load_stages(str(p), "shhs") + self.assertEqual(result[0], 2) + self.assertEqual(result[1], 2) + + def test_xml_rem_maps_correctly(self): + """XML stage 5 → category 3 (REM).""" + p = self.tmp_path / "p.xml" + make_xml_annotation(p, [5]) + result = self.task.load_stages(str(p), "shhs") + self.assertEqual(result[0], 3) + + def test_xml_unscored_maps_to_minus_one(self): + """XML stage 9 → category -1 (Unscored).""" + p = self.tmp_path / "p.xml" + make_xml_annotation(p, [9]) + result = self.task.load_stages(str(p), "shhs") + self.assertEqual(result[0], -1) + + def test_wsc_stage_mapping(self): + """WSC TSV stages map correctly via WSC_STAGE_MAP.""" + p = self.tmp_path / "p.stg.txt" + make_wsc_annotation(p, [0, 1, 5, 7]) + result = self.task.load_stages(str(p), "wsc") + self.assertEqual(list(result[:4]), [0, 1, 3, -1]) + + def test_output_padded_to_1200(self): + """Short recordings should be padded to T=1200 epochs.""" + p = self.tmp_path / "p.xml" + make_xml_annotation(p, [0, 5]) + result = 
self.task.load_stages(str(p), "shhs") + self.assertEqual(len(result), 1200) + self.assertEqual(result[1199], -1) + + def test_output_truncated_to_1200(self): + """Recordings longer than 1200 epochs should be truncated.""" + p = self.tmp_path / "p.xml" + make_xml_annotation(p, [0] * 1500) + result = self.task.load_stages(str(p), "shhs") + self.assertEqual(len(result), 1200) + + +# ───────────────────────────────────────────── +# Task Tests — signal preprocessing +# ───────────────────────────────────────────── + + +class TestPadOrTruncate(unittest.TestCase): + def setUp(self): + self.task = Wav2SleepStaging() + + def test_truncates_signal(self): + arr = np.ones(200) + result = self.task._pad_or_truncate(arr, is_label=False, target_length=100) + self.assertEqual(len(result), 100) + + def test_pads_signal_with_zeros(self): + arr = np.ones(50) + result = self.task._pad_or_truncate(arr, is_label=False, target_length=100) + self.assertEqual(len(result), 100) + self.assertEqual(result[99], 0.0) + + def test_pads_labels_with_minus_one(self): + arr = np.array([0, 1, 2], dtype=np.int8) + result = self.task._pad_or_truncate(arr, is_label=True, target_length=10) + self.assertEqual(result[9], -1) + + def test_exact_length_unchanged(self): + arr = np.arange(100) + result = self.task._pad_or_truncate(arr, is_label=False, target_length=100) + self.assertEqual(len(result), 100) + np.testing.assert_array_equal(result, arr) + + +class TestPreprocessSignal(unittest.TestCase): + def setUp(self): + self.task = Wav2SleepStaging() + + def test_none_signal_returns_zeros(self): + """Missing signals (None) should return a zero-filled array.""" + result = self.task.preprocess_signal("ECG", None, 256.0) + self.assertTrue(np.all(result == 0)) + self.assertEqual(result.dtype, np.float32) + + def test_output_is_unit_normalized(self): + """Output should have mean ≈ 0 and std ≈ 1.""" + signal = np.random.randn(7680).astype(np.float32) + result = self.task.preprocess_signal("ECG", signal, 256.0) + self.assertAlmostEqual(result.mean(), 0.0, delta=0.05) + self.assertAlmostEqual(result.std(), 1.0, delta=0.05) + + def test_ecg_output_length(self): + """ECG output should have T * k = 1200 * 1024 samples.""" + signal = np.random.randn(7680).astype(np.float32) + result = self.task.preprocess_signal("ECG", signal, 256.0) + self.assertEqual(len(result), 1200 * 1024) + + def test_thx_output_length(self): + """THX output should have T * k = 1200 * 256 samples.""" + signal = np.random.randn(7680).astype(np.float32) + result = self.task.preprocess_signal("THX", signal, 256.0) + self.assertEqual(len(result), 1200 * 256) + + +# ───────────────────────────────────────────── +# Task Tests — load_signals (mocked MNE) +# ───────────────────────────────────────────── + + +class TestLoadSignals(unittest.TestCase): + def setUp(self): + self.task = Wav2SleepStaging() + + def test_shhs_loads_no_ppg(self): + """SHHS has no PPG — availability_mask should mark PPG as missing.""" + mock_raw = make_mock_raw(["ECG", "THOR RES", "ABDO RES"]) + with patch("mne.io.read_raw_edf", return_value=mock_raw): + signals, mask = self.task.load_signals("fake.edf", "shhs") + ppg_idx = list(signals.keys()).index("PPG") + self.assertTrue(mask[ppg_idx]) + + def test_mesa_loads_ppg(self): + """MESA has PPG — availability_mask should mark PPG as present.""" + mock_raw = make_mock_raw(["EKG", "Pleth", "Thor", "Abdo"]) + with patch("mne.io.read_raw_edf", return_value=mock_raw): + signals, mask = self.task.load_signals("fake.edf", "mesa") + ppg_idx = 
list(signals.keys()).index("PPG") + self.assertFalse(mask[ppg_idx]) + + def test_signal_keys(self): + """Returned signals dict must contain ECG, PPG, THX, ABD.""" + mock_raw = make_mock_raw(["ECG", "THOR RES", "ABDO RES"]) + with patch("mne.io.read_raw_edf", return_value=mock_raw): + signals, _ = self.task.load_signals("fake.edf", "shhs") + self.assertEqual(set(signals.keys()), {"ECG", "PPG", "THX", "ABD"}) + + def test_availability_mask_length(self): + """Availability mask should have one entry per signal (4 total).""" + mock_raw = make_mock_raw(["ECG", "THOR RES", "ABDO RES"]) + with patch("mne.io.read_raw_edf", return_value=mock_raw): + _, mask = self.task.load_signals("fake.edf", "shhs") + self.assertEqual(len(mask), 4) + + +class TestPatientEvents(unittest.TestCase): + """Tests that Wav2SleepStaging.__call__ correctly parses patient events.""" + + def setUp(self): + self.task = Wav2SleepStaging() + + def _make_event(self, edf_path: str, label_path: str, source_dataset: str): + """Build a mock event with the fields __call__ accesses.""" + event = MagicMock() + event.edf_path = edf_path + event.label_path = label_path + event.source_dataset = source_dataset + return event + + def _make_patient(self, events: list) -> MagicMock: + """Build a mock Patient whose get_events() returns the given events.""" + patient = MagicMock() + patient.patient_id = "test-patient" + patient.get_events.return_value = events + return patient + + def test_call_returns_one_sample_per_event(self): + """__call__ should produce exactly one sample per event on the patient.""" + fake_signals = { + "ECG": np.zeros(1200 * 1024, dtype=np.float32), + "PPG": np.zeros(1200 * 1024, dtype=np.float32), + "THX": np.zeros(1200 * 256, dtype=np.float32), + "ABD": np.zeros(1200 * 256, dtype=np.float32), + } + fake_mask = [False, True, False, False] + fake_stages = np.zeros(1200, dtype=np.int8) + + event = self._make_event("fake.edf", "fake.xml", "shhs") + patient = self._make_patient([event]) + + with ( + patch.object( + self.task, "load_signals", return_value=(fake_signals, fake_mask) + ), + patch.object(self.task, "load_stages", return_value=fake_stages), + ): + samples = self.task(patient) + + self.assertEqual(len(samples), 1) + + def test_sample_contains_required_keys(self): + """Each sample should contain patient_id, all signals, availability_mask, stages.""" + fake_signals = { + "ECG": np.zeros(1200 * 1024, dtype=np.float32), + "PPG": np.zeros(1200 * 1024, dtype=np.float32), + "THX": np.zeros(1200 * 256, dtype=np.float32), + "ABD": np.zeros(1200 * 256, dtype=np.float32), + } + fake_mask = [False, False, False, False] + fake_stages = np.zeros(1200, dtype=np.int8) + + event = self._make_event("fake.edf", "fake.xml", "mesa") + patient = self._make_patient([event]) + + with ( + patch.object( + self.task, "load_signals", return_value=(fake_signals, fake_mask) + ), + patch.object(self.task, "load_stages", return_value=fake_stages), + ): + samples = self.task(patient) + + expected_keys = { + "patient_id", + "ECG", + "PPG", + "THX", + "ABD", + "availability_mask", + "stages", + } + self.assertEqual(set(samples[0].keys()), expected_keys) + + def test_sample_patient_id_matches(self): + """patient_id in each sample should match the patient object.""" + fake_signals = { + k: np.zeros(10, dtype=np.float32) for k in ["ECG", "PPG", "THX", "ABD"] + } + fake_mask = [False, False, False, False] + fake_stages = np.zeros(1200, dtype=np.int8) + + event = self._make_event("fake.edf", "fake.xml", "shhs") + patient = self._make_patient([event]) + 
+ with ( + patch.object( + self.task, "load_signals", return_value=(fake_signals, fake_mask) + ), + patch.object(self.task, "load_stages", return_value=fake_stages), + ): + samples = self.task(patient) + + self.assertEqual(samples[0]["patient_id"], "test-patient") + + def test_event_fields_forwarded_correctly(self): + """load_signals and load_stages should be called with the event's path and dataset.""" + fake_signals = { + k: np.zeros(10, dtype=np.float32) for k in ["ECG", "PPG", "THX", "ABD"] + } + fake_mask = [False, False, False, False] + fake_stages = np.zeros(1200, dtype=np.int8) + + event = self._make_event("path/to/recording.edf", "path/to/labels.xml", "mesa") + patient = self._make_patient([event]) + + with ( + patch.object( + self.task, "load_signals", return_value=(fake_signals, fake_mask) + ) as mock_signals, + patch.object( + self.task, "load_stages", return_value=fake_stages + ) as mock_stages, + ): + self.task(patient) + + mock_signals.assert_called_once_with("path/to/recording.edf", "mesa") + mock_stages.assert_called_once_with("path/to/labels.xml", "mesa") + + def test_multiple_events_produce_multiple_samples(self): + """A patient with multiple events should produce one sample per event.""" + fake_signals = { + k: np.zeros(10, dtype=np.float32) for k in ["ECG", "PPG", "THX", "ABD"] + } + fake_mask = [False, False, False, False] + fake_stages = np.zeros(1200, dtype=np.int8) + + events = [ + self._make_event(f"recording_{i}.edf", f"labels_{i}.xml", "shhs") + for i in range(3) + ] + patient = self._make_patient(events) + + with ( + patch.object( + self.task, "load_signals", return_value=(fake_signals, fake_mask) + ), + patch.object(self.task, "load_stages", return_value=fake_stages), + ): + samples = self.task(patient) + + self.assertEqual(len(samples), 3) + + +# ───────────────────────────────────────────── +# Model Tests — sub-components +# ───────────────────────────────────────────── + + +class TestResidualBlock(unittest.TestCase): + def test_output_shape(self): + """Output should be (batch, c_out, length // pool_size).""" + block = ResidualBlock(c_in=1, c_out=16, pool_size=2) + x = torch.randn(2, 1, 64) + out = block(x) + self.assertEqual(out.shape, (2, 16, 32)) + + def test_same_channel_skip_is_identity(self): + """c_in == c_out should use an Identity skip connection.""" + block = ResidualBlock(c_in=16, c_out=16) + self.assertIsInstance(block.skip, nn.Identity) + + def test_different_channel_skip_is_conv(self): + """c_in != c_out should use a Conv1d skip connection.""" + block = ResidualBlock(c_in=1, c_out=16) + self.assertIsInstance(block.skip, nn.Conv1d) + + +class TestSignalEncoder(unittest.TestCase): + def test_invalid_sample_rate_raises(self): + """Sample rates other than 1024 or 256 should raise ValueError.""" + with self.assertRaises(ValueError): + SignalEncoder(signal_sample_rate=512) + + def test_ecg_output_feature_dim(self): + """ECG encoder dense layer should project to FEATURE_DIM.""" + encoder = SignalEncoder(signal_sample_rate=1024) + self.assertEqual(encoder.dense.out_features, FEATURE_DIM) + + def test_thx_output_feature_dim(self): + """THX encoder dense layer should project to FEATURE_DIM.""" + encoder = SignalEncoder(signal_sample_rate=256) + self.assertEqual(encoder.dense.out_features, FEATURE_DIM) + + +class TestDilatedConvBlock(unittest.TestCase): + def test_output_shape_preserved(self): + """Dilated conv block should preserve sequence length and feature dim.""" + block = DilatedConvBlock(dilation=2, kernel_size=7) + x = torch.randn(2, FEATURE_DIM, 
50) + out = block(x) + self.assertEqual(out.shape, (2, FEATURE_DIM, 50)) + + +class TestEpochMixer(unittest.TestCase): + def setUp(self): + self.T = 4 + self.batch = 2 + self.mixer = EpochMixer() + + def test_output_shape(self): + """Output should collapse modality dim, preserving (batch, T, feature_dim).""" + x = torch.randn(self.batch, self.T, 4, FEATURE_DIM) + mask = torch.zeros(self.batch, 4, dtype=torch.bool) + self.mixer.eval() + with torch.no_grad(): + out = self.mixer(x, mask) + self.assertEqual(out.shape, (self.batch, self.T, FEATURE_DIM)) + + def test_attention_mask_cls_always_unmasked(self): + """CLS token (index 0) should never be masked.""" + mask = torch.ones(self.batch, 4, dtype=torch.bool) + attn_mask = self.mixer._build_attention_mask(self.batch, self.T, mask) + self.assertFalse(attn_mask[:, 0].any()) + + def test_attention_mask_at_least_one_modality_visible(self): + """Even if all modalities are masked, at least one should be forced visible.""" + mask = torch.ones(self.batch, 4, dtype=torch.bool) + attn_mask = self.mixer._build_attention_mask(self.batch, self.T, mask) + modality_cols = attn_mask[:, 1:] + self.assertFalse(modality_cols.all()) + + +class TestSequenceMixer(unittest.TestCase): + def test_output_shape(self): + """Sequence mixer should preserve (batch, T, feature_dim).""" + mixer = SequenceMixer() + x = torch.randn(2, 10, FEATURE_DIM) + out = mixer(x) + self.assertEqual(out.shape, (2, 10, FEATURE_DIM)) + + +# ───────────────────────────────────────────── +# Model Tests — Wav2Sleep +# ───────────────────────────────────────────── + + +class _FakeEncoder(nn.Module): + """Minimal nn.Module stand-in for SignalEncoder in forward pass tests.""" + + def __init__(self, output: torch.Tensor) -> None: + super().__init__() + self._output = output + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self._output + + +class TestWav2Sleep(unittest.TestCase): + """Tests for the full Wav2Sleep model. + + Signal encoders are mocked to return pre-computed (batch, T, feature_dim) + tensors, keeping tests fast while still exercising the epoch mixer, + sequence mixer, and classifier. 
+ """ + + T = 4 + BATCH = 2 + + def _make_model(self, modalities=None): + with patch.object( + BaseModel, "__init__", lambda self, **kw: nn.Module.__init__(self) + ): + return Wav2Sleep(dataset=MagicMock(), modalities=modalities) + + def _mock_encoders(self, model): + """Replace signal encoders with lightweight fakes that return synthetic tensors.""" + for m in model.selected_modalities: + model.signal_encoders[m] = _FakeEncoder( + torch.randn(self.BATCH, self.T, FEATURE_DIM) + ) + + def _make_batch(self, model, with_labels=False): + """Build a minimal input batch (real signal content is irrelevant — encoders are mocked).""" + batch = {m: torch.zeros(self.BATCH, 1) for m in model.selected_modalities} + batch["availability_mask"] = torch.zeros(self.BATCH, 4, dtype=torch.bool) + if with_labels: + batch["stages"] = torch.randint(0, 4, (self.BATCH, self.T)) + return batch + + def test_default_modalities(self): + """Default model should include all four modalities.""" + model = self._make_model() + self.assertEqual( + set(model.selected_modalities.keys()), {"ECG", "PPG", "THX", "ABD"} + ) + + def test_subset_modalities(self): + """Model should only include the modalities it was initialised with.""" + model = self._make_model(modalities=["ECG", "THX"]) + self.assertEqual(set(model.selected_modalities.keys()), {"ECG", "THX"}) + + def test_forward_output_keys_without_labels(self): + """Without ground-truth stages, output should have y_prob and y_hat only.""" + model = self._make_model() + self._mock_encoders(model) + model.eval() + with torch.no_grad(): + output = model(**self._make_batch(model)) + self.assertIn("y_prob", output) + self.assertIn("y_hat", output) + self.assertNotIn("loss", output) + + def test_forward_output_shapes(self): + """y_prob should be (batch, T, 4) and y_hat should be (batch, T).""" + model = self._make_model() + self._mock_encoders(model) + model.eval() + with torch.no_grad(): + output = model(**self._make_batch(model)) + self.assertEqual(output["y_prob"].shape, (self.BATCH, self.T, 4)) + self.assertEqual(output["y_hat"].shape, (self.BATCH, self.T)) + + def test_forward_with_labels_returns_loss(self): + """Providing ground-truth stages should add a scalar loss to the output.""" + model = self._make_model() + self._mock_encoders(model) + model.eval() + with torch.no_grad(): + output = model(**self._make_batch(model, with_labels=True)) + self.assertIn("loss", output) + self.assertIsInstance(output["loss"].item(), float) + + def test_y_prob_sums_to_one(self): + """Softmax output should sum to 1 across the class dimension.""" + model = self._make_model() + self._mock_encoders(model) + model.eval() + with torch.no_grad(): + output = model(**self._make_batch(model)) + sums = output["y_prob"].sum(dim=-1) + self.assertTrue(torch.allclose(sums, torch.ones_like(sums), atol=1e-5)) + + def test_gradient_computation(self): + """Loss should be differentiable — classifier weights should receive gradients.""" + model = self._make_model() + self._mock_encoders(model) + model.train() + output = model(**self._make_batch(model, with_labels=True)) + output["loss"].backward() + self.assertIsNotNone(model.classifier.weight.grad) + + +if __name__ == "__main__": + unittest.main() diff --git a/wav2sleep.yaml b/wav2sleep.yaml new file mode 100644 index 000000000..0967ef424 --- /dev/null +++ b/wav2sleep.yaml @@ -0,0 +1 @@ +{}
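
For readers following the EpochMixer mask tests above, the sketch below illustrates, outside the unittest harness, how a key-padding mask with the two properties asserted there can be built: an always-visible CLS column, and at least one modality forced visible when every modality of a sample is missing. It is a minimal illustration under assumed shapes — `availability_mask` as (batch, n_modalities) with True meaning "missing" — and it ignores the epoch dimension T that the real `EpochMixer._build_attention_mask` also handles; it is not the library's implementation.

import torch


def build_attention_mask(availability_mask: torch.Tensor) -> torch.Tensor:
    """Toy version of the masking behaviour asserted in TestEpochMixer.

    True entries mean "ignore this key". Column 0 is a CLS slot that is
    never masked; if every modality of a sample is missing, the first
    modality is forced visible so attention always has at least one key.
    """
    batch, _ = availability_mask.shape
    modality_mask = availability_mask.clone()
    all_masked = modality_mask.all(dim=1)
    modality_mask[all_masked, 0] = False
    cls_col = torch.zeros(batch, 1, dtype=torch.bool)
    return torch.cat([cls_col, modality_mask], dim=1)


# Both samples report every modality as missing, mirroring the tests above.
mask = build_attention_mask(torch.ones(2, 4, dtype=torch.bool))
assert not mask[:, 0].any()   # CLS column is never masked
assert not mask[:, 1:].all()  # at least one modality stays visible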
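
Because the signal encoders are swapped for _FakeEncoder and all file I/O is patched, this suite exercises the epoch mixer, sequence mixer, and classifier without any sleepdata.org downloads. A hypothetical way to run it via unittest discovery — the tests/ directory and module name pattern are assumptions, since the test file's diff header lies outside this excerpt:

import unittest

# Start directory and pattern are placeholders; adjust to wherever the
# test module actually lands in the repository.
suite = unittest.defaultTestLoader.discover("tests", pattern="*wav2sleep*.py")
unittest.TextTestRunner(verbosity=2).run(suite)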