From 7d6365e268580b9cf25cfb37e3a53ba558cf6b58 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 22 Apr 2026 23:04:32 -0500 Subject: [PATCH 1/3] Finalized example and slight fixes --- docs/api/datasets.rst | 1 + .../datasets/pyhealth.datasets.DSADataset.rst | 17 + ...health.tasks.DSAActivityClassification.rst | 7 + ....tasks.DSABinaryActivityClassification.rst | 7 + examples/dsa_activity_classification.ipynb | 1390 +++++++++++++++++ pyhealth/datasets/configs/dsa.yaml | 15 + pyhealth/datasets/dsa.py | 523 +++++++ pyhealth/tasks/dsa.py | 676 ++++++++ pyproject.toml | 3 + tests/core/test_dsa_dataset.py | 375 +++++ tests/core/test_dsa_tasks.py | 418 +++++ 11 files changed, 3432 insertions(+) create mode 100644 docs/api/datasets/pyhealth.datasets.DSADataset.rst create mode 100644 docs/api/tasks/pyhealth.tasks.DSAActivityClassification.rst create mode 100644 docs/api/tasks/pyhealth.tasks.DSABinaryActivityClassification.rst create mode 100644 examples/dsa_activity_classification.ipynb create mode 100644 pyhealth/datasets/configs/dsa.yaml create mode 100644 pyhealth/datasets/dsa.py create mode 100644 pyhealth/tasks/dsa.py create mode 100644 tests/core/test_dsa_dataset.py create mode 100644 tests/core/test_dsa_tasks.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..121f6f6d6 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -231,6 +231,7 @@ Available Datasets datasets/pyhealth.datasets.MIMICExtractDataset datasets/pyhealth.datasets.OMOPDataset datasets/pyhealth.datasets.DREAMTDataset + datasets/pyhealth.datasets.DSADataset datasets/pyhealth.datasets.SHHSDataset datasets/pyhealth.datasets.SleepEDFDataset datasets/pyhealth.datasets.EHRShotDataset diff --git a/docs/api/datasets/pyhealth.datasets.DSADataset.rst b/docs/api/datasets/pyhealth.datasets.DSADataset.rst new file mode 100644 index 000000000..8b7d78ddd --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.DSADataset.rst @@ -0,0 +1,17 @@ +pyhealth.datasets.DSADataset +=================================== + +The Daily and Sports Activities (DSA) dataset contains motion sensor data of 19 daily and sports activities, each performed by 8 subjects in their own style for 5 minutes. Five Xsens MTx units are worn on the torso, arms, and legs. + +Refer to https://archive-beta.ics.uci.edu/dataset/256/daily+and+sports+activities for more information. + +.. autoclass:: pyhealth.datasets.DSADataset + :members: + :undoc-members: + :show-inheritance: + + + + + + \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.DSAActivityClassification.rst b/docs/api/tasks/pyhealth.tasks.DSAActivityClassification.rst new file mode 100644 index 000000000..b082f6a43 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.DSAActivityClassification.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.DSAActivityClassification +======================================== + +.. autoclass:: pyhealth.tasks.DSAActivityClassification + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.DSABinaryActivityClassification.rst b/docs/api/tasks/pyhealth.tasks.DSABinaryActivityClassification.rst new file mode 100644 index 000000000..88b0b5ab7 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.DSABinaryActivityClassification.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.DSABinaryActivityClassification +============================================== + +.. 
autoclass:: pyhealth.tasks.DSABinaryActivityClassification + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/dsa_activity_classification.ipynb b/examples/dsa_activity_classification.ipynb new file mode 100644 index 000000000..ed5853ffa --- /dev/null +++ b/examples/dsa_activity_classification.ipynb @@ -0,0 +1,1390 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "21cc3d1f", + "metadata": {}, + "source": [ + "# Activity Classification Using the UCI Daily and Sports Activities (DSA) Dataset\n", + "\n", + "This notebook demonstrates how to use the DSA dataset within the PyHealth framework to train an LSTM classifier for two experimental setups:\n", + "\n", + "1. **Binary classification** — one-vs-rest, replicating the paper's experimental protocol\n", + "2. **Multiclass classification** — all 19 activities simultaneously, extending the paper\n", + "\n", + "**Dataset:** UCI Daily and Sports Activities \n", + "\n", + "**Paper:** Zhang et al. \"Daily Physical Activity Monitoring — Adaptive Learning from Multi-source Motion Sensor Data\". CHIL 2024 (PMLR 248:39-54) https://raw.githubusercontent.com/mlresearch/v248/main/assets/zhang24a/zhang24a.pdf\n", + "\n", + "**Original Paper Code:** https://github.com/sunlabuiuc/PyHealth/blob/master/README.rst \n", + "\n", + "Note that there are some differences between the implementation in the linked GitHub repository and the actual paper. This replication attempts to follow the methodology outlined in the paper. Notably, the paper suggests using the IPD metric to scale epochs, order the transfer-learning domains, and adaptively adjust the learning rate; however, the provided code does not appear to do these things as described.\n", + "\n", + "**Model:** PyHealth built-in `RNN` (LSTM backbone), matching the architecture used in the paper's code (64 hidden units, dropout=0.2)" + ] + }, + { + "cell_type": "markdown", + "id": "39d6fb06", + "metadata": {}, + "source": [ + "# Paper Replication" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a78eac7f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Device: cuda\n", + "Config: 3 repeats, random_activity=True\n", + "\n", + "Loading dataset...\n", + "Downloading DSA dataset from https://archive.ics.uci.edu/static/public/256/daily+and+sports+activities.zip ...\n", + "Download complete. Extracting ...\n", + "Extraction complete.\n", + "Dataset structure verified.\n", + "Indexed 9,120 segment files → ./data/DSA\\dsa-metadata-pyhealth.csv\n", + "Initializing DSA dataset from ./data/DSA (dev mode: False)\n", + "No cache_dir provided. 
Using default cache dir: C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\n", + "RANDOM_ACTIVITY=True — will reload per repeat\n" + ] + } + ], + "source": [ + "import math\n", + "import random, time\n", + "import os\n", + "import contextlib\n", + "from typing import Dict, List, Tuple\n", + "from collections import defaultdict\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from torch.optim import Adam\n", + "from torch.utils.data import DataLoader\n", + "\n", + "from pyhealth.datasets.dsa import DSADataset\n", + "from pyhealth.datasets import collate_fn_dict_with_padding\n", + "from pyhealth.models import RNN\n", + "from pyhealth.tasks.dsa import (\n", + " DSAActivityClassification,\n", + " DSABinaryActivityClassification,\n", + " compute_all_ipd_weights,\n", + ")\n", + "\n", + "import warnings\n", + "warnings.simplefilter(\"ignore\", FutureWarning)\n", + "\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "print(f'Device: {device}')\n", + "\n", + "# ── Configuration (paper's protocol) ────────────────────────────────\n", + "ALL_DOMAINS = ['T', 'RA', 'LA', 'RL', 'LL']\n", + "TARGET_DOMAIN = 'LA'\n", + "SOURCE_DOMAINS = [d for d in ALL_DOMAINS if d != TARGET_DOMAIN]\n", + "METRIC = 'dtw_classic'\n", + "KDE_BANDWIDTH = 7.8\n", + "LR = 0.005\n", + "BATCH_SIZE = 16\n", + "SOURCE_EPOCHS_NAIVE = 30\n", + "TARGET_EPOCHS = 30\n", + "EPOCH_SCALE_FACTOR = 7.8\n", + "N_REPEATS = 3 # paper: 15; use 3 for quick debugging\n", + "N_TRAIN_SUBJ = 6\n", + "BASE_SEED = 42\n", + "DATA_ROOT = './data/DSA'\n", + "RANDOM_ACTIVITY = True \n", + "\n", + "print(f'Config: {N_REPEATS} repeats, random_activity={RANDOM_ACTIVITY}')\n", + "\n", + "# ── Load data ───────────────────────────────────────────────────────\n", + "print('\\nLoading dataset...')\n", + "dataset = DSADataset(root=DATA_ROOT, download=False,\n", + " target_domain=TARGET_DOMAIN, scale=True)\n", + "\n", + "# Will be reloaded per repeat if RANDOM_ACTIVITY=True, otherwise load once\n", + "domain_samples = {}\n", + "if not RANDOM_ACTIVITY:\n", + " for domain in ALL_DOMAINS:\n", + " task = DSABinaryActivityClassification(\n", + " positive_activity_id=12, target_domain=domain)\n", + " domain_samples[domain] = dataset.set_task(task)\n", + " print(f'Loaded: {len(domain_samples[TARGET_DOMAIN])} samples in {TARGET_DOMAIN}')\n", + "else:\n", + " print('RANDOM_ACTIVITY=True — will reload per repeat')\n" + ] + }, + { + "cell_type": "markdown", + "id": "2f167e47", + "metadata": {}, + "source": [ + "## Helper Functions\n", + "\n", + "- **`evaluate_rcc`** — computes Ratio of Correct Classifications (the paper's primary metric, equivalent to accuracy). It handles both the binary sigmoid head (`y_prob` shape `(N, 1)`) and the multiclass softmax head (`(N, C)`).\n", + "- **`upsample_positives`** / **`downsample_negatives`** — implement the paper's class-balancing strategy: the training set upsamples the minority positive class to match the negative count; the test set downsamples negatives for balanced evaluation.\n", + "- **`build_aligned_domain_arrays`** — constructs aligned `(N, T, 1)` arrays for IPD computation by joining samples from all domains on their shared `pair_id`. Only training subjects are included to prevent data leakage into the distance computation.\n", + "- **`compute_ipd_weights_for_split`** — runs the full IPD pipeline for one subject split. The `reverse` flag switches between the paper's logic (more epochs for *more* similar domains, i.e. 
smaller IPD) and the author's code logic (more epochs for *larger* IPD values)." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "80907810", + "metadata": {}, + "outputs": [], + "source": [ + "def move_to_device(batch, device):\n", + " return {k: v.to(device) if isinstance(v, torch.Tensor) else v\n", + " for k, v in batch.items()}\n", + "\n", + "\n", + "def build_model(ref_samples, device):\n", + " return RNN(dataset=ref_samples, rnn_type='LSTM',\n", + " hidden_dim=64, dropout=0.2).to(device)\n", + "\n", + "\n", + "def make_dataloader(samples, batch_size, shuffle=True):\n", + " return DataLoader(samples, batch_size=batch_size, shuffle=shuffle,\n", + " collate_fn=collate_fn_dict_with_padding)\n", + "\n", + "\n", + "def split_by_subject(samples, train_ids, test_ids):\n", + " tr = [s for s in samples if s['patient_id'] in train_ids]\n", + " te = [s for s in samples if s['patient_id'] in test_ids]\n", + " return tr, te\n", + "\n", + "\n", + "def upsample_positives(samples, rng):\n", + " pos = [s for s in samples if s['label'] == 1]\n", + " neg = [s for s in samples if s['label'] == 0]\n", + " if not pos or not neg: return samples\n", + " n = len(neg)\n", + " up = pos * (n // len(pos)) + rng.sample(pos, n % len(pos))\n", + " return up + neg\n", + "\n", + "\n", + "def downsample_negatives(samples, rng):\n", + " pos = [s for s in samples if s['label'] == 1]\n", + " neg = [s for s in samples if s['label'] == 0]\n", + " if not pos: return samples\n", + " return pos + rng.sample(neg, min(len(pos), len(neg)))\n", + "\n", + "\n", + "def get_all_subject_ids(samples):\n", + " return sorted(set(int(s['patient_id'][1:]) for s in samples))\n", + "\n", + "\n", + "def evaluate_rcc(model, loader, device) -> float:\n", + " \"\"\"RCC = ratio of correct classifications (accuracy).\n", + " \n", + " Handles PyHealth's binary sigmoid head (y_prob shape (N, 1)).\n", + " \"\"\"\n", + " model.eval()\n", + " correct = 0\n", + " total = 0\n", + " with torch.no_grad():\n", + " for batch in loader:\n", + " batch = move_to_device(batch, device)\n", + " out = model(**batch)\n", + " y_prob = out['y_prob']\n", + "\n", + " # Handle both sigmoid (N, 1) and softmax (N, C) heads\n", + " if y_prob.shape[-1] == 1:\n", + " preds = (y_prob.squeeze(-1) >= 0.5).long()\n", + " else:\n", + " preds = y_prob.argmax(dim=1)\n", + "\n", + " # Coerce label to (N,) long\n", + " raw = batch['label']\n", + " if isinstance(raw, torch.Tensor):\n", + " labels = raw.view(-1).long().to(device)\n", + " else:\n", + " labels = torch.tensor(raw, dtype=torch.long, device=device).view(-1)\n", + "\n", + " n = preds.shape[0]\n", + " correct += (preds == labels[:n]).sum().item()\n", + " total += n\n", + " return correct / max(total, 1)\n", + "\n", + "\n", + "def run_epoch(model, loader, optimizer, device):\n", + " model.train()\n", + " for batch in loader:\n", + " batch = move_to_device(batch, device)\n", + " raw = batch.get('label')\n", + " if raw is not None:\n", + " if isinstance(raw, torch.Tensor):\n", + " batch['label'] = raw.view(-1).long().to(device)\n", + " else:\n", + " batch['label'] = torch.tensor(\n", + " raw, dtype=torch.long, device=device).view(-1)\n", + " optimizer.zero_grad()\n", + " out = model(**batch)\n", + " out['loss'].backward()\n", + " optimizer.step()\n", + "\n", + "\n", + "def train_on_domain(model, loader, n_epochs, lr, device, verbose=False, label=''):\n", + " if n_epochs <= 0: return\n", + " opt = Adam(model.parameters(), lr=lr)\n", + " for _ in range(n_epochs):\n", + " run_epoch(model, loader, opt, device)\n", 
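+ "    # NOTE: a fresh Adam optimizer is created on every train_on_domain call,\n", + "    # so optimizer state is reset between transfer stages and only the model\n", + "    # weights carry over from one source domain to the next.\n",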
+ " if verbose and label:\n", + " print(f' {label}: {n_epochs} ep')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2ec9344", + "metadata": {}, + "outputs": [], + "source": [ + "def build_aligned_domain_arrays(\n", + " domain_samples_dict: Dict[str, list],\n", + " train_ids: set,\n", + " channel: int = 0,\n", + ") -> Dict[str, np.ndarray]:\n", + " \"\"\"\n", + " Extract aligned (N, T, 1) arrays for IPD computation.\n", + "\n", + " Each domain produces one array whose i-th row is the *same* physical\n", + " activity instance as every other domain's i-th row (aligned by pair_id).\n", + " Only training subjects are included.\n", + "\n", + " ``channel=0`` takes a single sensor axis, matching the author's code\n", + " which squeezes each view to 1-D before computing distances.\n", + " \"\"\"\n", + " # pair_id → time_series for each domain (training subjects only)\n", + " domain_maps = {\n", + " domain: {\n", + " s['pair_id']: s['time_series']\n", + " for s in samples if s['patient_id'] in train_ids\n", + " }\n", + " for domain, samples in domain_samples_dict.items()\n", + " }\n", + "\n", + " # Intersection of pair_ids across all domains\n", + " common = sorted(\n", + " set.intersection(*[set(m.keys()) for m in domain_maps.values()])\n", + " )\n", + " if not common:\n", + " raise ValueError('No common pair_ids found across domains.')\n", + "\n", + " arrays = {}\n", + " for domain, pm in domain_maps.items():\n", + " ts_list = [pm[pid][channel] for pid in common] # each (T,)\n", + " arr = np.array(ts_list, dtype=np.float32) # (N, T)\n", + " arrays[domain] = arr[:, :, np.newaxis] # (N, T, 1)\n", + "\n", + " return arrays\n", + "\n", + "def compute_ipd_weights_for_split(\n", + " domain_samples_dict: Dict[str, list],\n", + " train_ids: set,\n", + " target_domain: str,\n", + " metric: str = 'dtw_classic',\n", + " bandwidth: float = 7.8,\n", + " reverse: bool = True\n", + ") -> Tuple[Dict[str, float], Dict[str, int]]:\n", + " \"\"\"\n", + " Full IPD pipeline for one subject split.\n", + "\n", + " Returns\n", + " -------\n", + " ipd_weights : {domain: scalar weight} (higher = less similar)\n", + " weighted_epochs: {domain: epoch count for weighted pre-training}\n", + "\n", + " If reverse=True, allocate more epochs to *smaller* IPD values by using\n", + " inverse weights, which is more consistent with the paper's written logic.\n", + " If reverse=False, allocate more epochs to *larger* IPD values, matching\n", + " the author's released code behavior.\n", + " \"\"\"\n", + " print(f' Building aligned domain arrays (metric={metric})...')\n", + " domain_arrays = build_aligned_domain_arrays(\n", + " domain_samples_dict, train_ids, channel=0\n", + " )\n", + " N = next(iter(domain_arrays.values())).shape[0]\n", + " print(f' Aligned {N} paired samples across {len(domain_arrays)} domains')\n", + "\n", + " print(f' Computing pairwise distances (this may take a minute with DTW)...')\n", + " ipd_weights = compute_all_ipd_weights(\n", + " domain_data=domain_arrays,\n", + " target_domain=target_domain,\n", + " metric=metric,\n", + " bandwidth=bandwidth,\n", + " )\n", + "\n", + " # Epoch scaling: int(base * scale_factor * w / w_sum) + 1\n", + " # Code from author uses \"epochs=int(30*7*locals()[weight_name]/weight_all)+1\"\n", + " # This implies that higher weight → more epochs (matches *code*, opposite of paper §4.2)\n", + " if reverse:\n", + " eps = 1e-8\n", + " epoch_weights = {\n", + " d: 1.0 / max(w, eps)\n", + " for d, w in ipd_weights.items()\n", + " }\n", + " else:\n", + " 
epoch_weights = dict(ipd_weights)\n", + "\n", + " w_sum = sum(epoch_weights.values())\n", + " weighted_epochs = (\n", + " {\n", + " d: int(SOURCE_EPOCHS_NAIVE * EPOCH_SCALE_FACTOR * epoch_weights[d] / w_sum) + 1\n", + " for d in ipd_weights\n", + " }\n", + " if w_sum > 0\n", + " else {d: SOURCE_EPOCHS_NAIVE for d in ipd_weights}\n", + " )\n", + "\n", + " return ipd_weights, weighted_epochs" + ] + }, + { + "cell_type": "markdown", + "id": "01fc0563", + "metadata": {}, + "source": [ + "## Experiment Loop\n", + "\n", + "**`run_one_repeat`** trains and evaluates all three conditions for a single subject split:\n", + "\n", + "| Condition | Description |\n", + "|---|---|\n", + "| No Transfer | Train only on target domain (LA). Baseline — no cross-domain knowledge. |\n", + "| Naive Transfer | Pre-train on each source domain for a fixed epoch count, then fine-tune on target. |\n", + "| Weighted Transfer | Pre-train on sources ordered and weighted by IPD, then fine-tune on target. |\n", + "\n", + "All three conditions use the same freshly initialised model to ensure comparability. The target domain training set is upsampled (positives repeated) and the test set is downsampled (negatives randomly removed) to produce balanced evaluation, following the paper's protocol.\n", + "\n", + "**`run_experiment`** runs `run_one_repeat` for `N_REPEATS` independent random subject splits, accumulating results for mean and standard deviation reporting." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "86b99264", + "metadata": {}, + "outputs": [], + "source": [ + "paper_weight = True\n", + "\n", + "@contextlib.contextmanager\n", + "def suppress_output():\n", + " with open(os.devnull, \"w\") as fnull:\n", + " with contextlib.redirect_stdout(fnull), contextlib.redirect_stderr(fnull):\n", + " yield\n", + "\n", + "def run_one_repeat(\n", + " domain_samples: Dict[str, list],\n", + " train_ids: set,\n", + " test_ids: set,\n", + " ipd_weights: Dict[str, float],\n", + " weighted_epochs: Dict[str, int],\n", + " source_domains: List[str],\n", + " target_domain: str,\n", + " device,\n", + " rng: random.Random,\n", + ") -> Dict[str, float]:\n", + " loaders = {}\n", + " for domain, samples in domain_samples.items():\n", + " tr, te = split_by_subject(samples, train_ids, test_ids)\n", + " if domain == target_domain:\n", + " tr = upsample_positives(tr, rng)\n", + " te = downsample_negatives(te, rng)\n", + " loaders[domain] = (\n", + " make_dataloader(tr, BATCH_SIZE, shuffle=True),\n", + " make_dataloader(te, BATCH_SIZE, shuffle=False),\n", + " )\n", + "\n", + " tgt_tr_ld, tgt_te_ld = loaders[target_domain]\n", + " ref = domain_samples[target_domain]\n", + " results = {}\n", + "\n", + " # No Transfer\n", + " m = build_model(ref, device)\n", + " train_on_domain(m, tgt_tr_ld, TARGET_EPOCHS, LR, device)\n", + " results['No Transfer'] = evaluate_rcc(m, tgt_te_ld, device)\n", + "\n", + " # Naive Transfer\n", + " m = build_model(ref, device)\n", + " for src in source_domains:\n", + " train_on_domain(m, loaders[src][0], SOURCE_EPOCHS_NAIVE, LR, device)\n", + " train_on_domain(m, tgt_tr_ld, TARGET_EPOCHS, LR, device)\n", + " results['Naive Transfer'] = evaluate_rcc(m, tgt_te_ld, device)\n", + "\n", + " # Weighted Transfer:\n", + " # If paper_weight then sort descending (most similar first) and use epochs ∝ 1/IPD\n", + " m = build_model(ref, device)\n", + " sorted_sources = sorted(source_domains,\n", + " key=lambda d: ipd_weights.get(d, 0.),\n", + " reverse=paper_weight)\n", + " for src in sorted_sources:\n", + " n_ep = 
weighted_epochs.get(src, SOURCE_EPOCHS_NAIVE)\n", + " train_on_domain(m, loaders[src][0], n_ep, LR, device)\n", + " train_on_domain(m, tgt_tr_ld, TARGET_EPOCHS, LR, device)\n", + " results['Weighted Transfer'] = evaluate_rcc(m, tgt_te_ld, device)\n", + "\n", + " return results\n", + "\n", + "\n", + "def run_experiment(\n", + " n_repeats: int = N_REPEATS,\n", + " base_seed: int = BASE_SEED,\n", + " random_activity: bool = RANDOM_ACTIVITY,\n", + ") -> List[dict]:\n", + " \"\"\"Full experiment using paper's protocol.\"\"\"\n", + " all_subjects = None\n", + " rng = random.Random(base_seed)\n", + " all_activities = list(range(1, 20))\n", + " all_results = []\n", + "\n", + " for rep in range(n_repeats):\n", + " t0 = time.time()\n", + " print(f'\\n{\"=\"*62}')\n", + " print(f' Repeat {rep + 1} / {n_repeats}')\n", + " print(f'{\"=\"*62}')\n", + "\n", + " # Reload if random activity\n", + " cur_domain_samples = {}\n", + " if random_activity:\n", + " act_id = rng.choice(all_activities)\n", + " print(f' Positive activity: {act_id}')\n", + " for domain in ALL_DOMAINS:\n", + " task = DSABinaryActivityClassification(\n", + " positive_activity_id=act_id, target_domain=domain)\n", + " with warnings.catch_warnings():\n", + " warnings.simplefilter(\"ignore\")\n", + " with suppress_output():\n", + " cur_domain_samples[domain] = dataset.set_task(task)\n", + " else:\n", + " cur_domain_samples = domain_samples\n", + "\n", + " if all_subjects is None:\n", + " all_subjects = get_all_subject_ids(cur_domain_samples[TARGET_DOMAIN])\n", + "\n", + " # Subject split\n", + " train_subj = sorted(rng.sample(all_subjects, N_TRAIN_SUBJ))\n", + " test_subj = [s for s in all_subjects if s not in train_subj]\n", + " train_ids = {f'p{s}' for s in train_subj}\n", + " test_ids = {f'p{s}' for s in test_subj}\n", + " print(f' Train: {train_subj} Test: {test_subj}')\n", + "\n", + " # IPD\n", + " print(' Computing IPD...')\n", + " try:\n", + " ipd_weights, w_epochs = compute_ipd_weights_for_split(\n", + " cur_domain_samples, train_ids, TARGET_DOMAIN, METRIC, KDE_BANDWIDTH)\n", + " except Exception as e:\n", + " print(f' IPD failed: {e}')\n", + " ipd_weights = {d: 1. 
for d in SOURCE_DOMAINS}\n", + " w_epochs = {d: SOURCE_EPOCHS_NAIVE for d in SOURCE_DOMAINS}\n", + "\n", + " desc_order = sorted(SOURCE_DOMAINS, key=lambda d: ipd_weights[d], reverse=True)\n", + " print(f' IPD rank (desc): {\" > \".join(f\"{d}={ipd_weights[d]:.3f}\" for d in desc_order)}')\n", + " print(f' epochs: {w_epochs}')\n", + "\n", + " # Run experiment\n", + " print(' Training...')\n", + " rep_rng = random.Random(rng.random())\n", + " rep_res = run_one_repeat(\n", + " cur_domain_samples, train_ids, test_ids,\n", + " ipd_weights, w_epochs,\n", + " SOURCE_DOMAINS, TARGET_DOMAIN, device, rep_rng)\n", + "\n", + " rep_res.update({\n", + " 'repeat': rep + 1,\n", + " 'train_subjects': train_subj,\n", + " 'test_subjects': test_subj,\n", + " 'ipd_weights': dict(ipd_weights),\n", + " })\n", + " all_results.append(rep_res)\n", + "\n", + " elapsed = time.time() - t0\n", + " print(f' Done in {elapsed:.0f}s')\n", + " for cond in ('No Transfer', 'Naive Transfer', 'Weighted Transfer'):\n", + " print(f' {cond:<22}: {rep_res[cond]:.4f}')\n", + "\n", + " return all_results\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "84b95076", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "==============================================================\n", + " Repeat 1 / 3\n", + "==============================================================\n", + " Positive activity: 4\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_70e62abf-439f-527b-9f04-71b62644f742\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_70e62abf-439f-527b-9f04-71b62644f742\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_70e62abf-439f-527b-9f04-71b62644f742\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_f68bfef6-b72d-51d2-9cd1-2f8f3943b437\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_f68bfef6-b72d-51d2-9cd1-2f8f3943b437\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_f68bfef6-b72d-51d2-9cd1-2f8f3943b437\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_53a0952b-8fb7-595f-a4f2-efc2e8f24696\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_53a0952b-8fb7-595f-a4f2-efc2e8f24696\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + 
"Found cached processed samples at C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_53a0952b-8fb7-595f-a4f2-efc2e8f24696\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_cbcb365e-1f94-5983-a8c5-a521dc22070b\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_cbcb365e-1f94-5983-a8c5-a521dc22070b\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_cbcb365e-1f94-5983-a8c5-a521dc22070b\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_4efe18c3-8889-543d-9274-7b0b52a06658\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_4efe18c3-8889-543d-9274-7b0b52a06658\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_4efe18c3-8889-543d-9274-7b0b52a06658\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + " Train: [1, 2, 3, 5, 6, 8] Test: [4, 7]\n", + " Computing IPD...\n", + " Building aligned domain arrays (metric=dtw_classic)...\n", + " Aligned 1140 paired samples across 5 domains\n", + " Computing pairwise distances (this may take a minute with DTW)...\n", + "IPD weight T → LA: 5.9967 (metric=dtw_classic)\n", + "IPD weight RA → LA: 5.4693 (metric=dtw_classic)\n", + "IPD weight RL → LA: 7.2420 (metric=dtw_classic)\n", + "IPD weight LL → LA: 7.0441 (metric=dtw_classic)\n", + " IPD rank (desc): RL=7.242 > LL=7.044 > T=5.997 > RA=5.469\n", + " Paper epochs: {'T': 55, 'RA': 50, 'RL': 66, 'LL': 65}\n", + " Training...\n", + " Done in 1388s\n", + " No Transfer : 0.6333\n", + " Naive Transfer : 0.6708\n", + " Weighted Transfer : 0.6125\n", + "\n", + "==============================================================\n", + " Repeat 2 / 3\n", + "==============================================================\n", + " Positive activity: 18\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_2ef503bd-911b-5389-8562-4d9df70b41c3\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_2ef503bd-911b-5389-8562-4d9df70b41c3\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Applying task transformations on data with 1 workers...\n", + "Found cached event dataframe: C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\global_event_df.parquet\n", + "Detected 
Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 8 patients. (Polars threads: 16)\n", + "Worker 0 finished processing patients.\n", + "Fitting processors on the dataset...\n", + "Label label vocab: {0: 0, 1: 1}\n", + "Processing samples and saving to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_2ef503bd-911b-5389-8562-4d9df70b41c3\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld...\n", + "Applying processors on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 9120 samples. (0 to 9120)\n", + "Worker 0 finished processing samples.\n", + "Cached processed samples to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_2ef503bd-911b-5389-8562-4d9df70b41c3\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_140e1bb5-fdd0-5794-83c3-caefc4f49106\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_140e1bb5-fdd0-5794-83c3-caefc4f49106\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Applying task transformations on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 8 patients. (Polars threads: 16)\n", + "Worker 0 finished processing patients.\n", + "Fitting processors on the dataset...\n", + "Label label vocab: {0: 0, 1: 1}\n", + "Processing samples and saving to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_140e1bb5-fdd0-5794-83c3-caefc4f49106\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld...\n", + "Applying processors on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 9120 samples. (0 to 9120)\n", + "Worker 0 finished processing samples.\n", + "Cached processed samples to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_140e1bb5-fdd0-5794-83c3-caefc4f49106\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_962dcdcd-a420-5cac-951d-fb60ebc5e641\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_962dcdcd-a420-5cac-951d-fb60ebc5e641\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Applying task transformations on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 8 patients. 
(Polars threads: 16)\n", + "Worker 0 finished processing patients.\n", + "Fitting processors on the dataset...\n", + "Label label vocab: {0: 0, 1: 1}\n", + "Processing samples and saving to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_962dcdcd-a420-5cac-951d-fb60ebc5e641\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld...\n", + "Applying processors on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 9120 samples. (0 to 9120)\n", + "Worker 0 finished processing samples.\n", + "Cached processed samples to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_962dcdcd-a420-5cac-951d-fb60ebc5e641\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_2c76fed8-fcb8-5d0b-96da-b9b487a5f3bf\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_2c76fed8-fcb8-5d0b-96da-b9b487a5f3bf\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Applying task transformations on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 8 patients. (Polars threads: 16)\n", + "Worker 0 finished processing patients.\n", + "Fitting processors on the dataset...\n", + "Label label vocab: {0: 0, 1: 1}\n", + "Processing samples and saving to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_2c76fed8-fcb8-5d0b-96da-b9b487a5f3bf\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld...\n", + "Applying processors on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 9120 samples. (0 to 9120)\n", + "Worker 0 finished processing samples.\n", + "Cached processed samples to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_2c76fed8-fcb8-5d0b-96da-b9b487a5f3bf\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_5536b554-1486-53e8-aa75-8326d2cafa63\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_5536b554-1486-53e8-aa75-8326d2cafa63\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Applying task transformations on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 8 patients. 
(Polars threads: 16)\n", + "Worker 0 finished processing patients.\n", + "Fitting processors on the dataset...\n", + "Label label vocab: {0: 0, 1: 1}\n", + "Processing samples and saving to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_5536b554-1486-53e8-aa75-8326d2cafa63\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld...\n", + "Applying processors on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 9120 samples. (0 to 9120)\n", + "Worker 0 finished processing samples.\n", + "Cached processed samples to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_5536b554-1486-53e8-aa75-8326d2cafa63\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + " Train: [1, 2, 4, 5, 6, 7] Test: [3, 8]\n", + " Computing IPD...\n", + " Building aligned domain arrays (metric=dtw_classic)...\n", + " Aligned 1140 paired samples across 5 domains\n", + " Computing pairwise distances (this may take a minute with DTW)...\n", + "IPD weight T → LA: 6.2897 (metric=dtw_classic)\n", + "IPD weight RA → LA: 6.3211 (metric=dtw_classic)\n", + "IPD weight RL → LA: 6.9094 (metric=dtw_classic)\n", + "IPD weight LL → LA: 7.2497 (metric=dtw_classic)\n", + " IPD rank (desc): LL=7.250 > RL=6.909 > RA=6.321 > T=6.290\n", + " Paper epochs: {'T': 55, 'RA': 56, 'RL': 61, 'LL': 64}\n", + " Training...\n", + " Done in 1417s\n", + " No Transfer : 0.7708\n", + " Naive Transfer : 0.7875\n", + " Weighted Transfer : 0.8125\n", + "\n", + "==============================================================\n", + " Repeat 3 / 3\n", + "==============================================================\n", + " Positive activity: 17\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_d95cc56a-3b46-5822-96a8-f49948edb5df\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_d95cc56a-3b46-5822-96a8-f49948edb5df\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Applying task transformations on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 8 patients. (Polars threads: 16)\n", + "Worker 0 finished processing patients.\n", + "Fitting processors on the dataset...\n", + "Label label vocab: {0: 0, 1: 1}\n", + "Processing samples and saving to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_d95cc56a-3b46-5822-96a8-f49948edb5df\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld...\n", + "Applying processors on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 9120 samples. 
(0 to 9120)\n", + "Worker 0 finished processing samples.\n", + "Cached processed samples to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_d95cc56a-3b46-5822-96a8-f49948edb5df\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_22214b4b-7cfc-56b3-93b1-c14e945ee58f\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_22214b4b-7cfc-56b3-93b1-c14e945ee58f\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Applying task transformations on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 8 patients. (Polars threads: 16)\n", + "Worker 0 finished processing patients.\n", + "Fitting processors on the dataset...\n", + "Label label vocab: {0: 0, 1: 1}\n", + "Processing samples and saving to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_22214b4b-7cfc-56b3-93b1-c14e945ee58f\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld...\n", + "Applying processors on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 9120 samples. (0 to 9120)\n", + "Worker 0 finished processing samples.\n", + "Cached processed samples to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_22214b4b-7cfc-56b3-93b1-c14e945ee58f\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_9c9c046a-0d9f-5f0b-b625-bf574b0a8620\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_9c9c046a-0d9f-5f0b-b625-bf574b0a8620\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Applying task transformations on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 8 patients. (Polars threads: 16)\n", + "Worker 0 finished processing patients.\n", + "Fitting processors on the dataset...\n", + "Label label vocab: {0: 0, 1: 1}\n", + "Processing samples and saving to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_9c9c046a-0d9f-5f0b-b625-bf574b0a8620\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld...\n", + "Applying processors on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 9120 samples. 
(0 to 9120)\n", + "Worker 0 finished processing samples.\n", + "Cached processed samples to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_9c9c046a-0d9f-5f0b-b625-bf574b0a8620\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_2a4bb875-e605-50c5-843e-07c42c73f18d\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_2a4bb875-e605-50c5-843e-07c42c73f18d\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Applying task transformations on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 8 patients. (Polars threads: 16)\n", + "Worker 0 finished processing patients.\n", + "Fitting processors on the dataset...\n", + "Label label vocab: {0: 0, 1: 1}\n", + "Processing samples and saving to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_2a4bb875-e605-50c5-843e-07c42c73f18d\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld...\n", + "Applying processors on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 9120 samples. (0 to 9120)\n", + "Worker 0 finished processing samples.\n", + "Cached processed samples to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_2a4bb875-e605-50c5-843e-07c42c73f18d\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Setting task dsa_binary_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_e0a051b2-b2fe-5cd7-83db-a48159d45146\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_e0a051b2-b2fe-5cd7-83db-a48159d45146\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Applying task transformations on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 8 patients. (Polars threads: 16)\n", + "Worker 0 finished processing patients.\n", + "Fitting processors on the dataset...\n", + "Label label vocab: {0: 0, 1: 1}\n", + "Processing samples and saving to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_e0a051b2-b2fe-5cd7-83db-a48159d45146\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld...\n", + "Applying processors on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 9120 samples. 
(0 to 9120)\n", + "Worker 0 finished processing samples.\n", + "Cached processed samples to C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_binary_classification_e0a051b2-b2fe-5cd7-83db-a48159d45146\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + " Train: [1, 2, 4, 5, 7, 8] Test: [3, 6]\n", + " Computing IPD...\n", + " Building aligned domain arrays (metric=dtw_classic)...\n", + " Aligned 1140 paired samples across 5 domains\n", + " Computing pairwise distances (this may take a minute with DTW)...\n", + "IPD weight T → LA: 5.9967 (metric=dtw_classic)\n", + "IPD weight RA → LA: 5.4693 (metric=dtw_classic)\n", + "IPD weight RL → LA: 7.2420 (metric=dtw_classic)\n", + "IPD weight LL → LA: 7.0441 (metric=dtw_classic)\n", + " IPD rank (desc): RL=7.242 > LL=7.044 > T=5.997 > RA=5.469\n", + " Paper epochs: {'T': 55, 'RA': 50, 'RL': 66, 'LL': 65}\n", + " Training...\n", + " Done in 1388s\n", + " No Transfer : 0.7333\n", + " Naive Transfer : 0.6792\n", + " Weighted Transfer : 0.6833\n" + ] + } + ], + "source": [ + "results = run_experiment()" + ] + }, + { + "cell_type": "markdown", + "id": "49ec3270", + "metadata": {}, + "source": [ + "## Results\n", + "\n", + "The table reports mean RCC and standard deviation across repeats. The paper (Table 1) reports 0.9722 ± 0.0104 for DTW-Paired with LSTM — our replication shows lower values for two reasons:\n", + "\n", + "1. We use only 3 repeats rather than 15, so variance is higher.\n", + "2. The paper's reported results may reflect a fixed positive activity rather than a randomly selected one per repeat.\n", + "\n", + "The **Weighted Transfer gain** row shows how much the IPD-weighted pre-training adds over the no-transfer baseline on average. A negative gain indicates that, for this small sample, the transfer procedure does not reliably help — consistent with the paper's RSS appendix results where transfer does not always improve over the no-transfer baseline." 
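+ , + "\n", + "Concretely, the gain row is mean(Weighted Transfer) - mean(No Transfer); with the values printed below, that is 0.7028 - 0.7125 = -0.0097."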
+ ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "0d323de6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Paper's Implementation — RCC by condition\n", + "\n", + "Repeat No Transfer Naive Transfer Weighted Transfer\n", + "--------------------------------------------------------------\n", + "Rep 1 0.6333 0.6708 0.6125\n", + "Rep 2 0.7708 0.7875 0.8125\n", + "Rep 3 0.7333 0.6792 0.6833\n", + "--------------------------------------------------------------\n", + "Mean 0.7125 0.7125 0.7028\n", + "Std 0.0580 0.0531 0.0828\n", + "\n", + "Weighted Transfer gain over No Transfer:\n", + " -0.0097 (std: 0.0382)\n" + ] + } + ], + "source": [ + "CONDITIONS = ['No Transfer', 'Naive Transfer', 'Weighted Transfer']\n", + "\n", + "print('\\nPaper\\'s Implementation — RCC by condition\\n')\n", + "col_w = 18\n", + "header = f'{\"Repeat\":<8}' + ''.join(f'{c:>{col_w}}' for c in CONDITIONS)\n", + "print(header)\n", + "print('-' * (8 + col_w * len(CONDITIONS)))\n", + "\n", + "for res in results:\n", + " rep = res['repeat']\n", + " row = f'Rep {rep:<4}'\n", + " for cond in CONDITIONS:\n", + " row += f'{res[cond]:>{col_w}.4f}'\n", + " print(row)\n", + "\n", + "print('-' * (8 + col_w * len(CONDITIONS)))\n", + "row = f'{\"Mean\":<8}'\n", + "for cond in CONDITIONS:\n", + " vals = [r[cond] for r in results]\n", + " row += f'{np.mean(vals):>{col_w}.4f}'\n", + "print(row)\n", + "row = f'{\"Std\":<8}'\n", + "for cond in CONDITIONS:\n", + " vals = [r[cond] for r in results]\n", + " row += f'{np.std(vals):>{col_w}.4f}'\n", + "print(row)\n", + "\n", + "print(f'\\nWeighted Transfer gain over No Transfer:')\n", + "wt_vals = [r['Weighted Transfer'] for r in results]\n", + "nt_vals = [r['No Transfer'] for r in results]\n", + "gain = np.mean(wt_vals) - np.mean(nt_vals)\n", + "print(f' {gain:+.4f} (std: {np.std([w - n for w, n in zip(wt_vals, nt_vals)]):.4f})')\n" + ] + }, + { + "cell_type": "markdown", + "id": "6b7c9d74", + "metadata": {}, + "source": [ + "In my limited sample size I found that the weighted transfer did not perform better than no knowledge transfer. It should be noted that this example used only 3 repetitions (instead of the 15 of the paper) and the results are very high variance." 
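+ , + "\n", + "# Illustrative sanity check (added for this writeup, not part of the original\n", + "# protocol): re-derive the epoch counts from the IPD weights logged in repeat 1\n", + "# using the non-reversed scaling rule epochs = int(30 * 7.8 * w / sum(w)) + 1.\n", + "w = {'T': 5.9967, 'RA': 5.4693, 'RL': 7.2420, 'LL': 7.0441}\n", + "w_sum = sum(w.values())\n", + "print({d: int(SOURCE_EPOCHS_NAIVE * EPOCH_SCALE_FACTOR * v / w_sum) + 1\n", + "       for d, v in w.items()})  # -> {'T': 55, 'RA': 50, 'RL': 66, 'LL': 65}\n"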
+ ] + }, + { + "cell_type": "markdown", + "id": "6b7c9d74", + "metadata": {}, + "source": [ + "With this limited sample size, the weighted transfer did not perform better than no knowledge transfer. Note that this example used only 3 repetitions (instead of the paper's 15), so the results have very high variance." + ] + }, + { + "cell_type": "markdown", + "id": "b5293c25", + "metadata": {}, + "source": [ + "# Ablations" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "5dcf49fe", + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate_f1_binary(model, loader, device) -> float:\n", + " from sklearn.metrics import f1_score\n", + " model.eval()\n", + " all_preds, all_labels = [], []\n", + " with torch.no_grad():\n", + " for batch in loader:\n", + " batch = move_to_device(batch, device)\n", + " out = model(**batch)\n", + " y_prob = out['y_prob']\n", + "\n", + " if y_prob.shape[-1] == 1:\n", + " preds = (y_prob.squeeze(-1) >= 0.5).long()\n", + " else:\n", + " preds = y_prob.argmax(dim=1)\n", + "\n", + " raw = batch['label']\n", + " if isinstance(raw, torch.Tensor):\n", + " labels = raw.view(-1).long()\n", + " else:\n", + " labels = torch.tensor(raw, dtype=torch.long)\n", + "\n", + " all_preds.extend(preds.cpu().tolist())\n", + " all_labels.extend(labels.tolist())\n", + "\n", + " return f1_score(all_labels, all_preds, average='macro', zero_division=0)\n", + "\n", + "\n", + "def majority_baseline_rcc(samples: list) -> float:\n", + " \"\"\"RCC achieved by always predicting the most common class.\"\"\"\n", + " from collections import Counter\n", + " counts = Counter(s['label'] for s in samples)\n", + " return counts.most_common(1)[0][1] / len(samples)\n", + "\n", + "\n", + "def train_one_pass(model, loader, optimizer, device):\n", + " model.train()\n", + " for batch in loader:\n", + " batch = move_to_device(batch, device)\n", + " optimizer.zero_grad()\n", + " out = model(**batch)\n", + " out['loss'].backward()\n", + " optimizer.step()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "b728d693", + "metadata": {}, + "outputs": [], + "source": [ + "def run_one_repeat_classification(\n", + " domain_samples: Dict[str, list],\n", + " train_ids: set,\n", + " test_ids: set,\n", + " ipd_weights: Dict[str, float],\n", + " weighted_epochs: Dict[str, int],\n", + " source_domains: List[str],\n", + " target_domain: str,\n", + " balance_train: bool,\n", + " balance_test: bool,\n", + " device,\n", + " rng: random.Random,\n", + ") -> Dict[str, Dict[str, float]]:\n", + " \"\"\"\n", + " Run all three conditions for one subject split under a specific\n", + " class-balance configuration.\n", + "\n", + " Returns\n", + " -------\n", + " dict condition → {'rcc': float, 'f1': float}\n", + " \"\"\"\n", + " loaders = {}\n", + " test_sets = {} # keep raw test samples for majority baseline\n", + "\n", + " for domain, samples in domain_samples.items():\n", + " tr, te = split_by_subject(samples, train_ids, test_ids)\n", + "\n", + " if domain == target_domain:\n", + " if balance_train:\n", + " tr = upsample_positives(tr, rng)\n", + " if balance_test:\n", + " te = downsample_negatives(te, rng)\n", + " test_sets[domain] = te # save for baseline calc\n", + "\n", + " loaders[domain] = (\n", + " make_dataloader(tr, BATCH_SIZE, shuffle=True),\n", + " make_dataloader(te, BATCH_SIZE, shuffle=False),\n", + " )\n", + "\n", + " tgt_tr_ld, tgt_te_ld = loaders[target_domain]\n", + " ref = domain_samples[target_domain]\n", + " results = {}\n", + "\n", + " for cond_name, train_fn in [\n", + " ('No Transfer', lambda m: None),\n", + " ('Naive Transfer', lambda m: [train_on_domain(m, loaders[s][0],\n", + " SOURCE_EPOCHS_NAIVE, LR, device)\n", + " for s in source_domains]),\n", + " ('Weighted Transfer', lambda m: [train_on_domain(m, loaders[s][0],\n", + " weighted_epochs.get(s, 
SOURCE_EPOCHS_NAIVE),\n", + " LR, device)\n", + " for s in sorted(source_domains,\n", + " key=lambda d: ipd_weights.get(d, 0.),\n", + " reverse=True)]),\n", + " ]:\n", + " m = build_model(ref, device)\n", + " train_fn(m) # pre-train on source domains\n", + " train_on_domain(m, tgt_tr_ld, TARGET_EPOCHS, LR, device)\n", + " results[cond_name] = {\n", + " 'rcc': evaluate_rcc(m, tgt_te_ld, device),\n", + " 'f1': evaluate_f1_binary(m, tgt_te_ld, device),\n", + " }\n", + "\n", + " results['_majority_rcc'] = majority_baseline_rcc(test_sets[target_domain])\n", + " return results\n" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "8d424f0f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading 19-class samples for all domains...\n", + "Setting task dsa_activity_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_df8f03e2-23d6-50b2-bd42-742d6eb0ec29\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_df8f03e2-23d6-50b2-bd42-742d6eb0ec29\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_df8f03e2-23d6-50b2-bd42-742d6eb0ec29\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + " T: 9120 samples, 9120 classes\n", + "Setting task dsa_activity_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_a06c548b-4566-5f53-b62a-3b5be005572b\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_a06c548b-4566-5f53-b62a-3b5be005572b\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_a06c548b-4566-5f53-b62a-3b5be005572b\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + " RA: 9120 samples, 9120 classes\n", + "Setting task dsa_activity_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_e933a306-59f3-5bc4-91ba-a9124ecf23eb\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_e933a306-59f3-5bc4-91ba-a9124ecf23eb\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_e933a306-59f3-5bc4-91ba-a9124ecf23eb\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + " LA: 9120 samples, 9120 classes\n", + "Setting task dsa_activity_classification for DSA base dataset...\n", + "Task cache paths: 
task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_33ae4850-f0d1-5779-bb0f-d4ef5307c841\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_33ae4850-f0d1-5779-bb0f-d4ef5307c841\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_33ae4850-f0d1-5779-bb0f-d4ef5307c841\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + " RL: 9120 samples, 9120 classes\n", + "Setting task dsa_activity_classification for DSA base dataset...\n", + "Task cache paths: task_df=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_fa4748a0-5895-5581-8db3-1cf4ad0d1dc4\\task_df.ld, samples=C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_fa4748a0-5895-5581-8db3-1cf4ad0d1dc4\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at C:\\Users\\11400\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\64d1ae8a-7617-555e-b299-2d3ab371f7ea\\tasks\\dsa_activity_classification_fa4748a0-5895-5581-8db3-1cf4ad0d1dc4\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + " LL: 9120 samples, 9120 classes\n" + ] + } + ], + "source": [ + "print('Loading 19-class samples for all domains...')\n", + "mc_domain_samples: Dict[str, list] = {}\n", + "for domain in ALL_DOMAINS:\n", + " task = DSAActivityClassification(target_domain=domain)\n", + " mc_domain_samples[domain] = dataset.set_task(task)\n", + " print(f' {domain}: {len(mc_domain_samples[domain])} samples, '\n", + " f'{len(set(s[\"label\"] for s in mc_domain_samples[domain]))} classes')\n" + ] + }, + { + "cell_type": "markdown", + "id": "ec43b78d", + "metadata": {}, + "source": [ + "The four binary setups all frame the problem as one-vs-rest detection of a single activity, but differ in how they handle the 1:18 class imbalance:\n", + "\n", + "Balanced (paper) — upsamples positives in training, downsamples negatives in testing. This is the paper's exact protocol and produces a fair 50/50 test set.\n", + "\n", + "Unbalanced train — keeps the natural imbalance in training but still downsamples the test set. Tests whether the gains seen on a balanced test set survive an imbalanced training regime.\n", + "\n", + "Natural test — keeps the training set balanced (positives upsampled) but evaluates on the naturally imbalanced test set. Closer to real-world deployment conditions than a downsampled test set.\n", + "\n", + "Fully unbalanced — natural imbalance everywhere with no rebalancing at any stage."
+ ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "0ff3c45a", + "metadata": {}, + "outputs": [], + "source": [ + "def build_aligned_domain_arrays(domain_samples_dict, train_ids, channel=0):\n", + " domain_maps = {\n", + " dom: {s['pair_id']: s['time_series']\n", + " for s in samps if s['patient_id'] in train_ids}\n", + " for dom, samps in domain_samples_dict.items()\n", + " }\n", + " common = sorted(set.intersection(*[set(m.keys()) for m in domain_maps.values()]))\n", + " arrays = {}\n", + " for dom, pm in domain_maps.items():\n", + " arr = np.array([pm[pid][channel] for pid in common], dtype=np.float32)\n", + " arrays[dom] = arr[:, :, np.newaxis]\n", + " return arrays\n", + "\n", + "\n", + "def get_all_subject_ids(samples):\n", + " return sorted(set(int(s['patient_id'][1:]) for s in samples))\n", + "\n", + "\n", + "# ── Classification ablation setups ───────────────────────────────────\n", + "# Each entry: (label, domain_samples_dict, balance_train, balance_test)\n", + "#\n", + "# 'Binary + natural test' deliberately omits test downsampling.\n", + "# RCC there is not comparable to other setups — majority-class baseline\n", + "# gives ~0.947 (18/19 samples are negative). Watch the F1 column instead.\n", + "\n", + "CLASSIFICATION_SETUPS = [\n", + " ('Binary + balanced (paper)', mc_domain_samples, True, True),\n", + " ('Binary + unbalanced train', mc_domain_samples, False, True),\n", + " ('Binary + natural test', mc_domain_samples, True, False),\n", + " ('Binary + fully unbalanced', mc_domain_samples, False, False),\n", + " # Multiclass handled separately below — different labels\n", + "]\n", + "\n", + "\n", + "def run_classification_ablation(\n", + " n_repeats: int = N_REPEATS,\n", + " base_seed: int = BASE_SEED,\n", + ") -> Dict[str, list]:\n", + " \"\"\"\n", + " Run every classification setup for n_repeats independent subject splits.\n", + "\n", + " The subject splits and IPD weights are computed ONCE per repeat and\n", + " then reused across all setups so results are directly comparable.\n", + "\n", + " Returns\n", + " -------\n", + " dict setup_label → list of per-repeat result dicts\n", + " \"\"\"\n", + " all_subjects = get_all_subject_ids(mc_domain_samples[TARGET_DOMAIN])\n", + " rng = random.Random(base_seed)\n", + "\n", + " # One results list per setup\n", + " results = {label: [] for label, *_ in CLASSIFICATION_SETUPS}\n", + " results['Multiclass'] = []\n", + "\n", + " for rep in range(n_repeats):\n", + " print(f'\\n{\"=\"*60}')\n", + " print(f' Ablation A — Repeat {rep + 1} / {n_repeats}')\n", + " print(f'{\"=\"*60}')\n", + "\n", + " # ── Subject split (shared across all setups this repeat) ─────\n", + " train_subj = sorted(rng.sample(all_subjects, N_TRAIN_SUBJ))\n", + " test_subj = [s for s in all_subjects if s not in train_subj]\n", + " train_ids = {f'p{s}' for s in train_subj}\n", + " test_ids = {f'p{s}' for s in test_subj}\n", + " print(f' Train: {train_subj} Test: {test_subj}')\n", + "\n", + " # ── IPD (shared; computed from binary samples, training subjs) ─\n", + " print(' Computing IPD weights...')\n", + " try:\n", + " arrays = build_aligned_domain_arrays(mc_domain_samples, train_ids)\n", + " ipd_weights = compute_all_ipd_weights(\n", + " domain_data=arrays, target_domain=TARGET_DOMAIN,\n", + " metric=METRIC, bandwidth=KDE_BANDWIDTH)\n", + " w_sum = sum(ipd_weights.values())\n", + " # divided the epochs by 3 to reduce computation time\n", + " w_epochs = {\n", + " d: math.ceil((int(SOURCE_EPOCHS_NAIVE * EPOCH_SCALE_FACTOR * w / w_sum) + 1) / 3)\n", + " for 
d, w in ipd_weights.items()\n", + " }\n", + " except Exception as e:\n", + " print(f' IPD failed ({e}); using uniform weights.')\n", + " ipd_weights = {d: 1. for d in SOURCE_DOMAINS}\n", + " w_epochs = {d: SOURCE_EPOCHS_NAIVE for d in SOURCE_DOMAINS}\n", + "\n", + " # ── Binary setups ────────────────────────────────────────────\n", + " for label, samp_dict, bal_tr, bal_te in CLASSIFICATION_SETUPS:\n", + " print(f' Running: {label}')\n", + " rep_rng = random.Random(rng.random()) # deterministic but independent\n", + " res = run_one_repeat_classification(\n", + " domain_samples = samp_dict,\n", + " train_ids = train_ids,\n", + " test_ids = test_ids,\n", + " ipd_weights = ipd_weights,\n", + " weighted_epochs = w_epochs,\n", + " source_domains = SOURCE_DOMAINS,\n", + " target_domain = TARGET_DOMAIN,\n", + " balance_train = bal_tr,\n", + " balance_test = bal_te,\n", + " device = device,\n", + " rng = rep_rng,\n", + " )\n", + " results[label].append(res)\n", + " for cond in ('No Transfer', 'Naive Transfer', 'Weighted Transfer'):\n", + " r = res[cond]\n", + " print(f' {cond:<22}: RCC={r[\"rcc\"]:.4f} F1={r[\"f1\"]:.4f}')\n", + "\n", + " # ── Multiclass ───────────────────────────────────────────────\n", + " print(' Running: Multiclass (19-class)')\n", + " mc_loaders = {}\n", + " for domain, samps in mc_domain_samples.items():\n", + " tr, te = split_by_subject(samps, train_ids, test_ids)\n", + " mc_loaders[domain] = (\n", + " make_dataloader(tr, BATCH_SIZE, shuffle=True),\n", + " make_dataloader(te, BATCH_SIZE, shuffle=False),\n", + " )\n", + " mc_tgt_tr, mc_tgt_te = mc_loaders[TARGET_DOMAIN]\n", + " mc_ref = mc_domain_samples[TARGET_DOMAIN]\n", + " mc_res = {}\n", + "\n", + " for cond_name, src_list in [\n", + " ('No Transfer', []),\n", + " ('Naive Transfer', SOURCE_DOMAINS),\n", + " ('Weighted Transfer', sorted(SOURCE_DOMAINS,\n", + " key=lambda d: ipd_weights.get(d, 0.),\n", + " reverse=True)),\n", + " ]:\n", + " m = build_model(mc_ref, device)\n", + " for src in src_list:\n", + " n_ep = w_epochs.get(src, SOURCE_EPOCHS_NAIVE)\n", + " train_on_domain(m, mc_loaders[src][0], n_ep, LR, device)\n", + " train_on_domain(m, mc_tgt_tr, TARGET_EPOCHS, LR, device)\n", + " mc_res[cond_name] = {\n", + " 'rcc': evaluate_rcc(m, mc_tgt_te, device),\n", + " 'f1': float('nan'), # macro F1 for 19 classes not shown\n", + " }\n", + " print(f' {cond_name:<22}: RCC={mc_res[cond_name][\"rcc\"]:.4f}')\n", + "\n", + " results['Multiclass'].append(mc_res)\n", + "\n", + " return results\n" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "90f6a6d3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "============================================================\n", + " Ablation A — Repeat 1 / 1\n", + "============================================================\n", + " Train: [1, 2, 3, 6, 7, 8] Test: [4, 5]\n", + " Computing IPD weights...\n", + "IPD weight T → LA: 5.9967 (metric=dtw_classic)\n", + "IPD weight RA → LA: 5.4693 (metric=dtw_classic)\n", + "IPD weight RL → LA: 7.2420 (metric=dtw_classic)\n", + "IPD weight LL → LA: 7.0441 (metric=dtw_classic)\n", + " Running: Binary + balanced (paper)\n", + " No Transfer : RCC=0.5042 F1=0.5041\n", + " Naive Transfer : RCC=0.5750 F1=0.5731\n", + " Weighted Transfer : RCC=0.6000 F1=0.5982\n", + " Running: Binary + unbalanced train\n", + " No Transfer : RCC=0.2542 F1=0.0555\n", + " Naive Transfer : RCC=0.0917 F1=0.0222\n", + " Weighted Transfer : RCC=0.1792 F1=0.0342\n", + " Running: Binary + natural 
test\n", + " No Transfer : RCC=0.0513 F1=0.0104\n", + " Naive Transfer : RCC=0.1057 F1=0.0735\n", + " Weighted Transfer : RCC=0.0930 F1=0.0611\n", + " Running: Binary + fully unbalanced\n", + " No Transfer : RCC=0.4377 F1=0.4300\n", + " Naive Transfer : RCC=0.4105 F1=0.3854\n", + " Weighted Transfer : RCC=0.4237 F1=0.4084\n", + " Running: Multiclass (19-class)\n", + " No Transfer : RCC=0.4272\n", + " Naive Transfer : RCC=0.4206\n", + " Weighted Transfer : RCC=0.4241\n" + ] + } + ], + "source": [ + "classification_results = run_classification_ablation(n_repeats=1)" + ] + }, + { + "cell_type": "markdown", + "id": "d6ca88a0", + "metadata": {}, + "source": [ + "## Results" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "9b6f39b5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RCC (Ratio of Correct Classifications)\n", + "\n", + "Setup No Transfer Naive Transfer Weighted Transfer\n", + "--------------------------------------------------------------------------------------------------------\n", + "Binary + balanced (paper) 0.5042 0.5750 0.6000\n", + "Binary + unbalanced train 0.2542 0.0917 0.1792\n", + "Binary + natural test 0.0513 0.1057 0.0930\n", + "Binary + fully unbalanced 0.4377 0.4105 0.4237\n", + "Multiclass 0.4272 0.4206 0.4241\n", + "\n", + "\n", + "Macro F1 (binary setups only; not distorted by class imbalance)\n", + "\n", + "Setup No Transfer Naive Transfer Weighted Transfer\n", + "--------------------------------------------------------------------------------------------------------\n", + "Binary + balanced (paper) 0.5041 0.5731 0.5982\n", + "Binary + unbalanced train 0.0555 0.0222 0.0342\n", + "Binary + natural test 0.0104 0.0735 0.0611\n", + "Binary + fully unbalanced 0.4300 0.3854 0.4084\n", + "\n", + "\n", + "Majority-class baseline RCC (always predict most common class)\n", + "\n", + " Binary + balanced (paper) : 0.0042\n", + " Binary + unbalanced train : 0.0042\n", + " Binary + natural test : 0.0004\n", + " Binary + fully unbalanced : 0.0004\n", + "\n", + "\n", + "Interpretation guide:\n", + " Weighted-Transfer gain over No-Transfer in each setup:\n", + " Binary + balanced (paper) : Δ = +0.0958\n", + " Binary + unbalanced train : Δ = -0.0750\n", + " Binary + natural test : Δ = +0.0417\n", + " Binary + fully unbalanced : Δ = -0.0140\n", + " Multiclass : Δ = -0.0031\n" + ] + } + ], + "source": [ + "CONDITIONS = ['No Transfer', 'Naive Transfer', 'Weighted Transfer']\n", + "ALL_SETUPS = [label for label, *_ in CLASSIFICATION_SETUPS] + ['Multiclass']\n", + "\n", + "# ── RCC table ────────────────────────────────────────────────────────\n", + "print('RCC (Ratio of Correct Classifications)\\n')\n", + "col_w = 24\n", + "header = f'{\"Setup\":<32}' + ''.join(f'{c:>{col_w}}' for c in CONDITIONS)\n", + "print(header)\n", + "print('-' * (32 + col_w * len(CONDITIONS)))\n", + "\n", + "for setup in ALL_SETUPS:\n", + " reps = classification_results[setup]\n", + " row = f'{setup:<32}'\n", + " for cond in CONDITIONS:\n", + " vals = [r[cond]['rcc'] for r in reps if cond in r]\n", + " row += f'{np.mean(vals):.4f}'.rjust(col_w)\n", + " print(row)\n", + "\n", + "# ── F1 table (binary setups only) ────────────────────────────────────\n", + "print('\\n\\nMacro F1 (binary setups only; not distorted by class imbalance)\\n')\n", + "header_f1 = f'{\"Setup\":<32}' + ''.join(f'{c:>{col_w}}' for c in CONDITIONS)\n", + "print(header_f1)\n", + "print('-' * (32 + col_w * len(CONDITIONS)))\n", + "\n", + "for label, *_ in 
CLASSIFICATION_SETUPS:\n", + " reps = classification_results[label]\n", + " row = f'{label:<32}'\n", + " for cond in CONDITIONS:\n", + " vals = [r[cond]['f1'] for r in reps if cond in r]\n", + " row += f'{np.mean(vals):.4f}'.rjust(col_w)\n", + " print(row)\n", + "\n", + "# ── Majority-class baseline (for natural-test rows) ──────────────────\n", + "print('\\n\\nMajority-class baseline RCC (always predict most common class)\\n')\n", + "for label, *_ in CLASSIFICATION_SETUPS:\n", + " reps = classification_results[label]\n", + " baselines = [r['_majority_rcc'] for r in reps]\n", + " print(f' {label:<40}: {np.mean(baselines):.4f}')\n", + "\n", + "# ── Summary interpretation guide ─────────────────────────────────────\n", + "print('\\n\\nInterpretation guide:')\n", + "print(' Weighted-Transfer gain over No-Transfer in each setup:')\n", + "for setup in ALL_SETUPS:\n", + " reps = classification_results[setup]\n", + " nt_vals = [r['No Transfer']['rcc'] for r in reps if 'No Transfer' in r]\n", + " wt_vals = [r['Weighted Transfer']['rcc'] for r in reps if 'Weighted Transfer' in r]\n", + " gain = np.mean(wt_vals) - np.mean(nt_vals)\n", + " print(f' {setup:<40}: Δ = {gain:+.4f}')\n" + ] + }, + { + "cell_type": "markdown", + "id": "171a0063", + "metadata": {}, + "source": [ + "Note that the sample size is small, so the variance of these results is high.\n", + "\n", + "In this sample run, the macro F1 results show that performance degrades sharply without the upsampling/downsampling balancing that the paper employs. The paper's novel weighted knowledge transfer did give a meaningful boost under the balanced protocol (+0.0958 RCC over no transfer), but the gain did not hold under the other balancing schemes or in the multiclass result." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyhealth/datasets/configs/dsa.yaml b/pyhealth/datasets/configs/dsa.yaml new file mode 100644 index 000000000..8f9cef026 --- /dev/null +++ b/pyhealth/datasets/configs/dsa.yaml @@ -0,0 +1,15 @@ +# Author: Edward Guan (edwardg2@illinois.edu) +version: "1.0" +tables: + dsa: + file_path: "dsa-metadata-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - label + - activity_id + - activity_name + - filepath + - segment_id + - pair_id + - visit_id \ No newline at end of file diff --git a/pyhealth/datasets/dsa.py b/pyhealth/datasets/dsa.py new file mode 100644 index 000000000..836d9b8ac --- /dev/null +++ b/pyhealth/datasets/dsa.py @@ -0,0 +1,523 @@ +""" +PyHealth dataset for the UCI Daily and Sports Activities (DSA) dataset. + +Dataset link: + https://archive.ics.uci.edu/dataset/256/daily+and+sports+activities + +Dataset paper: (please cite if you use this dataset) + Kerem Altun, Billur Barshan, and Orkun Tunçel. "Comparative Study on + Classifying Human Activities with Miniature Inertial and Magnetic Sensors." + Pattern Recognition 43(10): 3605-3620, 2010.
+ +Dataset paper link: + https://doi.org/10.1016/j.patcog.2010.04.019 + +Author: + Edward Guan (edwardg2@illinois.edu) +""" + +import logging +import os +import random +import zipfile +from pathlib import Path +from typing import List, Optional +import urllib.request + +import numpy as np +import pandas as pd + +from pyhealth.datasets import BaseDataset + +logger = logging.getLogger(__name__) + + +class DSADataset(BaseDataset): + """Dataset class for the UCI Daily and Sports Activities (DSA) dataset. + + The dataset contains motion sensor data of 19 daily and sports activities, + each performed by 8 subjects for 5 minutes. Five Xsens MTx sensor units + are placed on the torso, right arm, left arm, right leg, and left leg. + Each unit records 9-channel data (x/y/z accelerometer, gyroscope, and + magnetometer) at 25 Hz, segmented into 5-second (125-timestep) windows. + + The dataset is structured to support multi-source transfer learning: all + five sensor domains are recorded simultaneously, so every segment across + domains is paired by the same subject performing the same activity at the + same moment. This pairwise structure is preserved in the ``pair_id`` field + of each indexed row. + + Attributes: + root (str): Root directory of the raw data. + dataset_name (str): Name of the dataset. + config_path (str): Path to the configuration file. + target_domain (str): Sensor placement used as the classification target. + activities (List[str]): Ordered list of activity names (1-indexed). + domains (List[str]): Ordered list of sensor domain keys. + domain_full_names (dict): Mapping from domain key to descriptive name. + + Example:: + + >>> dataset = DSADataset(root="./data/DSA") + >>> print(len(dataset)) # number of indexed rows + >>> task_ds = dataset.set_task() + """ + + # ------------------------------------------------------------------ + # Class-level constants + # ------------------------------------------------------------------ + + activities: List[str] = [ + "sitting", + "standing", + "lying_back", + "lying_right", + "ascending_stairs", + "descending_stairs", + "elevator_standing", + "elevator_moving", + "walking_parking_lot", + "walking_treadmill_flat", + "walking_treadmill_inclined", + "running", + "stepper", + "cross_trainer", + "cycling_horizontal", + "cycling_vertical", + "rowing", + "jumping", + "basketball", + ] + + domains: List[str] = ["T", "RA", "LA", "RL", "LL"] + + domain_full_names: dict = { + "T": "Torso", + "RA": "Right Arm", + "LA": "Left Arm", + "RL": "Right Leg", + "LL": "Left Leg", + } + + # Column slices within each 45-column row (0-indexed, end exclusive) + _domain_cols: dict = { + "T": (0, 9), + "RA": (9, 18), + "LA": (18, 27), + "RL": (27, 36), + "LL": (36, 45), + } + + _N_ACTIVITIES = 19 + _N_SUBJECTS = 8 + _N_SEGMENTS = 60 + _N_CHANNELS = 9 # per sensor unit + _N_TIMESTEPS = 125 # 5 sec at 25 Hz + + # ------------------------------------------------------------------ + # Constructor + # ------------------------------------------------------------------ + + def __init__( + self, + root: str = ".", + config_path: Optional[str] = str(Path(__file__).parent / "configs" / "dsa.yaml"), + download: bool = False, + target_domain: str = "LA", + scale: bool = True, + **kwargs, + ) -> None: + """Initialises the DSA dataset. + + Args: + root (str): Root directory of the raw data. Must contain folders + ``a01`` through ``a19`` after download or manual extraction. + Defaults to the working directory. + config_path (Optional[str]): Path to a PyHealth YAML configuration + file. 
Defaults to the bundled ``configs/dsa.yaml`` next to this module. + download (bool): If ``True``, download and extract the dataset from + the UCI ML Repository into ``root``. Defaults to ``False``. + target_domain (str): Sensor domain treated as the target for + classification. Must be one of ``["T", "RA", "LA", "RL", "LL"]``. + Defaults to ``"LA"`` (Left Arm, simulating a wrist wearable). + scale (bool): If ``True``, apply per-channel min-max scaling to + ``[-1, 1]`` when loading time series arrays. Defaults to ``True``. + + Raises: + ValueError: If ``target_domain`` is not a valid domain key. + FileNotFoundError: If ``root`` does not exist or lacks ``a01``. + FileNotFoundError: If any expected segment file is missing. + + Example:: + + >>> dataset = DSADataset(root="./data/DSA", target_domain="LA") + """ + if target_domain not in self.domains: + raise ValueError( + f"target_domain must be one of {self.domains}, " + f"got '{target_domain}'." + ) + + self.target_domain = target_domain + self.scale = scale + self._metadata_path = os.path.join(root, "dsa-metadata-pyhealth.csv") + + if download: + self._download(root) + + self._verify_data(root) + self._index_data(root) + + super().__init__( + root=root, + tables=["dsa"], + dataset_name="DSA", + config_path=config_path, + **kwargs, + ) + + # ------------------------------------------------------------------ + # Default task + # ------------------------------------------------------------------ + + @property + def default_task(self): + """Returns the default task for this dataset. + + Returns: + DSAActivityClassification: The default classification task using the target domain time series. + + Example:: + + >>> dataset = DSADataset(root="./data/DSA") + >>> task = dataset.default_task + """ + # Import here to avoid circular imports + from pyhealth.tasks.dsa import DSAActivityClassification + return DSAActivityClassification() + + # ------------------------------------------------------------------ + # Download + # ------------------------------------------------------------------ + + def _download(self, root: str) -> None: + """Downloads and extracts the DSA dataset from the UCI ML Repository. + + Downloads the zip archive (~163 MB), extracts it into ``root``, and + removes the archive afterwards. + + Args: + root (str): Destination directory for the extracted dataset. + + Raises: + FileNotFoundError: If extraction produces no ``a01`` folder. + """ + os.makedirs(root, exist_ok=True) + url = ( + "https://archive.ics.uci.edu/static/public/256/" + "daily+and+sports+activities.zip" + ) + zip_path = os.path.join(root, "dsa.zip") + + logger.info(f"Downloading DSA dataset from {url} ...") + urllib.request.urlretrieve(url, zip_path) + logger.info("Download complete. 
Extracting ...") + + with zipfile.ZipFile(zip_path, "r") as zf: + # Validate paths before extraction (safety check) + for member in zf.namelist(): + member_path = os.path.realpath(os.path.join(root, member)) + if not member_path.startswith(os.path.realpath(root)): + raise ValueError( + f"Unsafe path detected in zip: '{member}'" + ) + zf.extractall(root) + + os.remove(zip_path) + logger.info("Extraction complete.") + + # The zip may contain a top-level subfolder — move contents up if so + extracted_dirs = [ + d for d in os.listdir(root) + if os.path.isdir(os.path.join(root, d)) and d.startswith("data") + ] + if extracted_dirs and not os.path.isdir(os.path.join(root, "a01")): + inner = os.path.join(root, extracted_dirs[0]) + for item in os.listdir(inner): + os.rename(os.path.join(inner, item), os.path.join(root, item)) + os.rmdir(inner) + + # ------------------------------------------------------------------ + # Verification + # ------------------------------------------------------------------ + + def _verify_data(self, root: str) -> None: + """Verifies the presence and structure of the dataset directory. + + Checks that ``root`` exists and contains the expected activity folders + ``a01`` through ``a19``, each with 8 subject subdirectories containing + 60 segment files. + + Args: + root (str): Root directory of the raw data. + + Raises: + FileNotFoundError: If ``root`` does not exist. + FileNotFoundError: If the expected folder ``a01`` is missing. + """ + if not os.path.exists(root): + msg = ( + f"Dataset root '{root}' does not exist. " + "Pass download=True to download it automatically." + ) + logger.error(msg) + raise FileNotFoundError(msg) + + expected_dir = os.path.join(root, "a01") + if not os.path.isdir(expected_dir): + msg = ( + f"Expected activity folder '{expected_dir}' not found. " + "Ensure 'root' points to the directory containing a01..a19." + ) + logger.error(msg) + raise FileNotFoundError(msg) + + logger.info("Dataset structure verified.") + + # ------------------------------------------------------------------ + # Indexing + # ------------------------------------------------------------------ + + def _index_data(self, root: str) -> pd.DataFrame: + """Parses the dataset directory structure into a metadata index. + + Walks all activity, subject, and segment folders to build a flat + DataFrame where each row represents one segment file. No time series + data is loaded at this stage — only file paths and identifiers. + + The ``pair_id`` column encodes the pairwise synchronisation structure: + all five domains share the same ``pair_id`` for a given + (activity, segment) combination, regardless of subject. This field is + required for Inter-domain Pairwise Distance (IPD) computation. + + Args: + root (str): Root directory of the raw data. + + Returns: + pd.DataFrame: Metadata index saved to ``dsa-metadata-pyhealth.csv``. + + Raises: + FileNotFoundError: If any expected segment file is missing. 
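+
+        Example (an illustrative sketch; the exact values follow from the
+        naming scheme described above)::
+
+            >>> df = dataset._index_data("./data/DSA")
+            >>> # first indexed row: activity a01, subject p1, segment s01
+            >>> list(df.loc[0, ["patient_id", "visit_id", "label", "pair_id"]])
+            ['p1', 'a01_p1_s01', 0, 'a01_s01']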
+ """ + rows = [] + + for activity_idx in range(1, self._N_ACTIVITIES + 1): + activity_folder = f"a{activity_idx:02d}" + activity_name = self.activities[activity_idx - 1] + + for subject_id in range(1, self._N_SUBJECTS + 1): + subject_folder = f"p{subject_id}" + subject_dir = os.path.join(root, activity_folder, subject_folder) + + if not os.path.isdir(subject_dir): + raise FileNotFoundError( + f"Expected subject directory not found: {subject_dir}" + ) + + for segment_id in range(1, self._N_SEGMENTS + 1): + filename = f"s{segment_id:02d}.txt" + filepath = os.path.join(subject_dir, filename) + + if not os.path.isfile(filepath): + raise FileNotFoundError( + f"Expected segment file not found: {filepath}" + ) + + rows.append({ + # --- PyHealth standard fields --- + "patient_id": f"p{subject_id}", + "visit_id": f"a{activity_idx:02d}_p{subject_id}_s{segment_id:02d}", + + # --- DSA-specific fields --- + "activity_id": activity_idx, + "activity_name": activity_name, + # 0-indexed label for model output layers + "label": activity_idx - 1, + "segment_id": segment_id, + "filepath": filepath, + + # pair_id links the same (activity, segment) across all + # domains and subjects — the foundation of IPD computation + "pair_id": f"a{activity_idx:02d}_s{segment_id:02d}", + }) + + df = pd.DataFrame(rows) + df.to_csv(self._metadata_path, index=False) + logger.info( + f"Indexed {len(df):,} segment files → {self._metadata_path}" + ) + return df + + # ------------------------------------------------------------------ + # Time series loading helpers + # ------------------------------------------------------------------ + + @staticmethod + def _load_segment(filepath: str) -> np.ndarray: + """Load one segment file into a (125, 45) float32 array. + + Args: + filepath (str): Path to a ``s{segment}.txt`` file. + + Returns: + np.ndarray: Shape ``(125, 45)``, dtype ``float32``. + """ + return np.loadtxt(filepath, delimiter=",", dtype=np.float32) + + def _slice_domain(self, raw: np.ndarray, domain: str) -> np.ndarray: + """Extract one domain's channels from a raw segment array. + + Args: + raw (np.ndarray): Shape ``(125, 45)`` full segment array. + domain (str): Domain key, one of ``["T", "RA", "LA", "RL", "LL"]``. + + Returns: + np.ndarray: Shape ``(9, 125)`` — channels × timesteps. + """ + start, end = self._domain_cols[domain] + # raw[:, start:end] is (125, 9); transpose to (9, 125) + return raw[:, start:end].T + + @staticmethod + def _minmax_scale(ts: np.ndarray) -> np.ndarray: + """Scale each channel of a ``(K, T)`` array independently to ``[-1, 1]``. + + Channels with zero range (flat signal) are left as zeros. + + Args: + ts (np.ndarray): Shape ``(K, T)``. + + Returns: + np.ndarray: Shape ``(K, T)``, values in ``[-1, 1]``. + """ + scaled = np.zeros_like(ts) + for k in range(ts.shape[0]): + mn, mx = ts[k].min(), ts[k].max() + if mx > mn: + scaled[k] = 2.0 * (ts[k] - mn) / (mx - mn) - 1.0 + return scaled + + def load_time_series( + self, + filepath: str, + domain: Optional[str] = None, + ) -> dict: + """Load and preprocess all domain time series from one segment file. + + This is the primary method for retrieving time series data. It is + called by task functions when building model-ready samples. + + Args: + filepath (str): Path to the segment ``.txt`` file. + domain (Optional[str]): If provided, return only this domain's + array. If ``None``, return all five domains. + + Returns: + dict: Mapping ``{domain_key: np.ndarray (9, 125)}``. + If ``domain`` is specified, the dict has one entry. 
+ All arrays are scaled to ``[-1, 1]`` if ``self.scale=True``. + """ + raw = self._load_segment(filepath) + domains_to_load = [domain] if domain else self.domains + + result = {} + for d in domains_to_load: + ts = self._slice_domain(raw, d) + if self.scale: + ts = self._minmax_scale(ts) + result[d] = ts + + return result + + # ------------------------------------------------------------------ + # Subject-level split utilities + # ------------------------------------------------------------------ + + def get_subject_split( + self, + train_subjects: List[int], + test_subjects: List[int], + ) -> tuple: + """Return metadata DataFrames split by subject. + + Args: + train_subjects (List[int]): Subject IDs (1–8) for training. + test_subjects (List[int]): Subject IDs (1–8) for testing. + + Returns: + tuple: ``(train_df, test_df)`` as pandas DataFrames. + + Raises: + ValueError: If train and test subject sets overlap. + + Example:: + + >>> dataset = DSADataset(root="./data/DSA") + >>> train_df, test_df = dataset.get_subject_split( + ... train_subjects=[1,2,3,4,5,6], + ... test_subjects=[7,8], + ... ) + """ + if set(train_subjects) & set(test_subjects): + raise ValueError("train_subjects and test_subjects must not overlap.") + + df = pd.read_csv(self._metadata_path) + train_ids = {f"p{s}" for s in train_subjects} + test_ids = {f"p{s}" for s in test_subjects} + + train_df = df[df["patient_id"].isin(train_ids)].reset_index(drop=True) + test_df = df[df["patient_id"].isin(test_ids)].reset_index(drop=True) + + return train_df, test_df + + def random_subject_splits( + self, + n_repeats: int = 15, + n_train: int = 6, + random_seed: int = 0, + ): + """Generator yielding repeated random train/test subject splits. + + Replicates the paper's evaluation protocol: randomly choose ``n_train`` + of 8 subjects for training, reserve the rest for testing, and repeat + ``n_repeats`` times. Report mean ± std of the metric across repeats. + + Args: + n_repeats (int): Number of random repetitions. Paper uses 15. + n_train (int): Number of training subjects. Paper uses 6. + random_seed (int): Base random seed; repeat ``i`` uses + ``random_seed + i`` for reproducibility. + + Yields: + tuple: ``(repeat_idx, train_subjects, test_subjects, + train_df, test_df)`` + + Example:: + + >>> dataset = DSADataset(root="./data/DSA") + >>> results = [] + >>> for i, train_sub, test_sub, train_df, test_df in \\ + ... dataset.random_subject_splits(n_repeats=15): + ... rcc = run_experiment(train_df, test_df) + ... results.append(rcc) + >>> print(f"{np.mean(results):.4f} ± {np.std(results):.4f}") + """ + all_subjects = list(range(1, self._N_SUBJECTS + 1)) + for i in range(n_repeats): + rng = random.Random(random_seed + i) + shuffled = all_subjects.copy() + rng.shuffle(shuffled) + train_subjects = shuffled[:n_train] + test_subjects = shuffled[n_train:] + train_df, test_df = self.get_subject_split(train_subjects, test_subjects) + yield i, train_subjects, test_subjects, train_df, test_df \ No newline at end of file diff --git a/pyhealth/tasks/dsa.py b/pyhealth/tasks/dsa.py new file mode 100644 index 000000000..a64a2ca72 --- /dev/null +++ b/pyhealth/tasks/dsa.py @@ -0,0 +1,676 @@ +"""PyHealth tasks for the UCI Daily and Sports Activities (DSA) dataset. + +Implements two classification tasks and the Inter-domain Pairwise Distance +(IPD) computation from: + + Zhang et al. "Daily Physical Activity Monitoring — Adaptive Learning + from Multi-source Motion Sensor Data." CHIL 2024 (PMLR 248:39-54). 
+ https://proceedings.mlr.press/v248/zhang24a.html + +.. note:: Paper vs. Code Discrepancies + + The authors' released implementation (github.com/Oceanjinghai/ + HealthTimeSerial) diverges from the paper in several ways. This module + replicates the **code's actual behavior** and documents each divergence: + + 1. **IPD computation**: The paper describes a K-dimensional KDE with a + K×K bandwidth matrix and MCMC sampling (Algorithm 1, Eq. 1). The + code fits a *scalar* 1-D KDE (bandwidth=7.8, hardcoded) on flattened + pairwise distances and draws exactly 10 deterministic samples + (``random_state=0``). + + 2. **Transfer mechanism**: The paper describes per-epoch learning-rate + decay ``λ^{j+1} = λ^j · (1 − α_q)`` (Eq. 5). The code instead + scales the *number of training epochs* per source domain: + ``epochs = int(30 * 7 * weight / weight_all) + 1``. The learning + rate stays fixed at 0.005 throughout. + + 3. **Fine-tuning**: The paper describes k-fold cross-validation with an + adaptive LR and R=5 degeneration stopping (Appendix A, Eq. 7-9). + The code does a single fixed ``model.fit`` for 40 epochs with no + adaptive stopping. + + 4. **Learning rate**: The paper states ``λ₀ = 5×10⁻⁴`` (Appendix B). + The code uses ``lr=0.005`` — ten times larger. + + 5. **Domain ordering**: The paper sorts source domains in descending + IPD order (Algorithm 2, line 10). The code iterates domains in + their natural index order without sorting. + + 6. **Models**: The paper references PyTorch LSTM and sktime + Encoder/ResNet/TapNet. The code uses Keras LSTM (64 units, + dropout=0.2) and a single-block Keras FCN. TapNet is absent. + +Author: + Edward Guan (edwardg2@illinois.edu) +""" + +import logging +from typing import Dict, List, Optional, Tuple + +import numpy as np +from sklearn.neighbors import KernelDensity + +from pyhealth.tasks import BaseTask + +logger = logging.getLogger(__name__) + +# ===================================================================== +# Constants — matching the author's code, not the paper +# ===================================================================== + +# Column slices within each 45-column segment row (0-indexed, end exclusive). +# Mirrors DSADataset._domain_cols so tasks are self-contained. +_DOMAIN_COLS: Dict[str, Tuple[int, int]] = { + "T": (0, 9), + "RA": (9, 18), + "LA": (18, 27), + "RL": (27, 36), + "LL": (36, 45), +} + +# Distance metrics supported by compute_pairwise_distances +SUPPORTED_METRICS: List[str] = [ + "boss", + "dtw_classic", + "dtw_sakoechiba", + "dtw_itakura", + "dtw_multiscale", + "dtw_fast", + "euclidean", +] + +# KDE hyperparameters (hardcoded in author's code, not stated in paper) +DEFAULT_KDE_BANDWIDTH: float = 7.8 +DEFAULT_KDE_N_SAMPLES: int = 10 +DEFAULT_KDE_RANDOM_STATE: int = 0 + +# Training hyperparameters (from author's code, NOT from paper Appendix B) +# +# Code vs. 
Paper (Appendix B) +# lr=0.005 λ₀ = 5×10⁻⁴ +# 30 epochs/src J = 50 +# 40 epochs/tgt J_target = 100 +# batch=16 (not stated) +# no k-fold k = 10 +# no degen stop R = 5 +CODE_LEARNING_RATE: float = 0.005 +CODE_SOURCE_EPOCHS: int = 30 +CODE_TARGET_EPOCHS_NO_TRANSFER: int = 40 +CODE_TARGET_EPOCHS_NAIVE: int = 30 +CODE_TARGET_EPOCHS_WEIGHTED: int = 40 +CODE_EPOCH_SCALE_FACTOR: int = 7 # the '7' in int(30 * 7 * w / w_all) + 1 +CODE_BATCH_SIZE: int = 16 + + +# ===================================================================== +# Private helpers (replicate DSADataset internals so tasks are +# self-contained and do not need a dataset reference at call time) +# ===================================================================== + +def _load_segment(filepath: str) -> np.ndarray: + """Load one segment file into a (125, 45) float32 array.""" + arr = np.loadtxt(filepath, delimiter=",", dtype=np.float32) + if arr.shape != (125, 45): + raise ValueError( + f"Expected shape (125, 45), got {arr.shape}: {filepath}" + ) + return arr + + +def _slice_domain(raw: np.ndarray, domain: str) -> np.ndarray: + """Extract one domain's channels from a raw (125, 45) array. + + Returns shape (9, 125) — channels × timesteps. + """ + start, end = _DOMAIN_COLS[domain] + return raw[:, start:end].T + + +def _minmax_scale(ts: np.ndarray) -> np.ndarray: + """Scale each channel of a (K, T) array independently to [-1, 1]. + + Channels with zero range (flat signal) are left as zeros. + """ + scaled = np.zeros_like(ts) + for k in range(ts.shape[0]): + mn, mx = ts[k].min(), ts[k].max() + if mx > mn: + scaled[k] = 2.0 * (ts[k] - mn) / (mx - mn) - 1.0 + return scaled + + +def _load_domain_ts(filepath: str, domain: str, scale: bool) -> np.ndarray: + """Load, slice, and optionally scale one domain from a segment file.""" + raw = _load_segment(filepath) + ts = _slice_domain(raw, domain) + if scale: + ts = _minmax_scale(ts) + return ts + + +# ===================================================================== +# Task classes +# ===================================================================== + +class DSAActivityClassification(BaseTask): + """Full 19-class activity classification task for the DSA dataset. + + This is the standard multiclass formulation used in most activity + recognition literature. The paper does not evaluate this setup — + it uses binary classification only — so this task extends the + paper's scope. + + Attributes: + task_name (str): ``"dsa_activity_classification"`` + input_schema (Dict[str, str]): ``{"time_series": "tensor"}`` + output_schema (Dict[str, str]): ``{"label": "multiclass"}`` + target_domain (str): Sensor placement to load (e.g. ``"LA"``). + scale (bool): Whether to apply per-channel min-max scaling. + + Example:: + + >>> from pyhealth.datasets import DSADataset + >>> dataset = DSADataset(root="./data/DSA") + >>> task = DSAActivityClassification(target_domain="LA") + >>> samples = dataset.set_task(task) + >>> samples[0].keys() + dict_keys(['patient_id', 'visit_id', 'time_series', 'label', + 'activity_name', 'pair_id']) + """ + + task_name: str = "dsa_activity_classification" + input_schema: Dict[str, str] = {"time_series": "tensor"} + output_schema: Dict[str, str] = {"label": "multiclass"} + + def __init__( + self, + target_domain: str = "LA", + scale: bool = True, + ) -> None: + """Initialise the 19-class activity classification task. + + Args: + target_domain (str): Sensor domain to load. Must be one of + ``["T", "RA", "LA", "RL", "LL"]``. 
Defaults to ``"LA"`` + (Left Arm, simulating a wrist wearable). + scale (bool): If ``True``, apply per-channel min-max scaling + to ``[-1, 1]``. Defaults to ``True``. + + Raises: + ValueError: If ``target_domain`` is not a valid domain key. + """ + if target_domain not in _DOMAIN_COLS: + raise ValueError( + f"target_domain must be one of {list(_DOMAIN_COLS)}, " + f"got '{target_domain}'." + ) + self.target_domain = target_domain + self.scale = scale + super().__init__() + + def __call__(self, patient) -> List[Dict]: + """Extract multiclass samples from one patient. + + Each segment file in the patient's record becomes one sample. + Time series data is loaded from disk, sliced to ``target_domain``, + and optionally scaled. + + Args: + patient: A PyHealth ``Patient`` object from ``DSADataset``. + + Returns: + List of sample dicts, each containing: + + - ``patient_id`` (str): Subject identifier. + - ``visit_id`` (str): Segment identifier. + - ``time_series`` (np.ndarray): Shape ``(9, 125)``. + - ``label`` (int): Activity index in ``[0, 18]``. + - ``activity_name`` (str): Human-readable activity string. + - ``pair_id`` (str): Pairwise synchronisation key for IPD. + """ + samples = [] + for event in patient.get_events(event_type="dsa"): + ts = _load_domain_ts(event.filepath, self.target_domain, self.scale) + samples.append( + { + "patient_id": patient.patient_id, + "visit_id": event.visit_id, + "time_series": ts, + "label": int(event.label), + "activity_name": event.activity_name, + "pair_id": event.pair_id, + } + ) + return samples + + +class DSABinaryActivityClassification(BaseTask): + """Binary one-vs-rest activity classification for the DSA dataset. + + Replicates the paper's experimental setup: one activity is the + positive class (label=1), all others are negative (label=0). In the + paper's protocol, positive samples are then upsampled during training + and negative samples downsampled during evaluation to maintain class + balance; this task itself only assigns the labels, and any rebalancing + is applied downstream. + + Attributes: + task_name (str): ``"dsa_binary_classification"`` + input_schema (Dict[str, str]): ``{"time_series": "tensor"}`` + output_schema (Dict[str, str]): ``{"label": "binary"}`` + positive_activity_id (int): 1-indexed activity ID treated as + the positive class. + target_domain (str): Sensor placement to load. + scale (bool): Whether to apply per-channel min-max scaling. + + Example:: + + >>> task = DSABinaryActivityClassification( + ... positive_activity_id=12, + ... target_domain="LA", + ... ) + >>> samples = dataset.set_task(task) + >>> samples[0].keys() + dict_keys(['patient_id', 'visit_id', 'time_series', 'label', + 'activity_id', 'activity_name', 'pair_id']) + """ + + task_name: str = "dsa_binary_classification" + input_schema: Dict[str, str] = {"time_series": "tensor"} + output_schema: Dict[str, str] = {"label": "binary"} + + def __init__( + self, + positive_activity_id: int, + target_domain: str = "LA", + scale: bool = True, + ) -> None: + """Initialise the binary activity classification task. + + Args: + positive_activity_id (int): 1-indexed activity ID treated as + the positive class (1–19). All other activities are + negative. In the paper's protocol this is chosen randomly + for each repetition. + target_domain (str): Sensor domain to load. Must be one of + ``["T", "RA", "LA", "RL", "LL"]``. Defaults to ``"LA"``. + scale (bool): If ``True``, apply per-channel min-max scaling + to ``[-1, 1]``. Defaults to ``True``. + + Raises: + ValueError: If ``target_domain`` is not a valid domain key. + ValueError: If ``positive_activity_id`` is not in 1–19. 
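+
+        Example (a minimal sketch of the paper's protocol, which draws the
+        positive activity at random for each repetition)::
+
+            >>> import random
+            >>> task = DSABinaryActivityClassification(
+            ...     positive_activity_id=random.randint(1, 19),
+            ...     target_domain="LA",
+            ... )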
+ """ + if target_domain not in _DOMAIN_COLS: + raise ValueError( + f"target_domain must be one of {list(_DOMAIN_COLS)}, " + f"got '{target_domain}'." + ) + if not (1 <= positive_activity_id <= 19): + raise ValueError( + f"positive_activity_id must be in [1, 19], " + f"got {positive_activity_id}." + ) + self.positive_activity_id = positive_activity_id + self.target_domain = target_domain + self.scale = scale + super().__init__() + + def __call__(self, patient) -> List[Dict]: + """Extract binary samples from one patient. + + Each segment becomes one sample with ``label=1`` if the + segment's activity matches ``positive_activity_id``, else + ``label=0``. + + Args: + patient: A PyHealth ``Patient`` object from ``DSADataset``. + + Returns: + List of sample dicts, each containing: + + - ``patient_id`` (str): Subject identifier. + - ``visit_id`` (str): Segment identifier. + - ``time_series`` (np.ndarray): Shape ``(9, 125)``. + - ``label`` (int): ``1`` if positive class, else ``0``. + - ``activity_id`` (int): 1-indexed activity ID. + - ``activity_name`` (str): Human-readable activity string. + - ``pair_id`` (str): Pairwise synchronisation key for IPD. + """ + samples = [] + for event in patient.get_events(event_type="dsa"): + ts = _load_domain_ts(event.filepath, self.target_domain, self.scale) + is_positive = int(event.activity_id) == self.positive_activity_id + samples.append( + { + "patient_id": patient.patient_id, + "visit_id": event.visit_id, + "time_series": ts, + "label": 1 if is_positive else 0, + "activity_id": int(event.activity_id), + "activity_name": event.activity_name, + "pair_id": event.pair_id, + } + ) + return samples + + +# ===================================================================== +# Inter-domain Pairwise Distance (IPD) +# ===================================================================== + +def compute_pairwise_distances( + source_ts: np.ndarray, + target_ts: np.ndarray, + metric: str = "dtw_classic", +) -> np.ndarray: + """Compute scalar pairwise distances between paired time series. + + Replicates ``cal_similarity`` from the author's ``metrics.py``. + Each pair of univariate time series produces one scalar distance. + + .. note:: Paper vs. Code + + The paper (Algorithm 1) computes per-channel distances and + assembles them into K-dimensional vectors. The code squeezes + each sample to 1-D and computes a single scalar, collapsing + the multivariate structure entirely. + + Args: + source_ts (np.ndarray): Source domain array of shape ``(N, T, 1)`` + or ``(N, T)`` where N is the number of paired samples. + target_ts (np.ndarray): Target domain array with the same shape. + metric (str): Distance metric name. One of ``SUPPORTED_METRICS``. + + Returns: + np.ndarray: Shape ``(N,)`` containing one distance per pair. + + Raises: + ValueError: If ``metric`` is not in ``SUPPORTED_METRICS``. 
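+
+    Example (a shape-only sketch on synthetic data; real use passes the
+    aligned, paired DSA segments)::
+
+        >>> import numpy as np
+        >>> rng = np.random.default_rng(0)
+        >>> src = rng.normal(size=(4, 125))  # 4 paired univariate series
+        >>> tgt = rng.normal(size=(4, 125))
+        >>> compute_pairwise_distances(src, tgt, metric="euclidean").shape
+        (4,)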
+ """ + from pyts.metrics import boss + from pyts.metrics import dtw as _dtw + + n_samples = source_ts.shape[0] + distances = np.zeros(n_samples) + + for i in range(n_samples): + s = np.squeeze(source_ts[i]) + t = np.squeeze(target_ts[i]) + + if metric == "boss": + distances[i] = boss(s, t) + elif metric == "dtw_classic": + distances[i] = _dtw(s, t) + elif metric == "dtw_sakoechiba": + distances[i] = _dtw(s, t, method="sakoechiba", options={"window_size": 0.5}) + elif metric == "dtw_itakura": + distances[i] = _dtw(s, t, method="itakura", options={"max_slope": 1.5}) + elif metric == "dtw_multiscale": + distances[i] = _dtw(s, t, method="multiscale", options={"resolution": 2}) + elif metric == "dtw_fast": + distances[i] = _dtw(s, t, method="fast", options={"radius": 1}) + elif metric == "euclidean": + distances[i] = float(np.linalg.norm(s - t)) + else: + raise ValueError( + f"Unknown metric '{metric}'. Supported: {SUPPORTED_METRICS}" + ) + + return distances + + +def compute_ipd_weight( + distances: np.ndarray, + bandwidth: float = DEFAULT_KDE_BANDWIDTH, + n_samples: int = DEFAULT_KDE_N_SAMPLES, + random_state: int = DEFAULT_KDE_RANDOM_STATE, +) -> float: + """Compute a single IPD weight from pairwise distances via KDE. + + Replicates the author's code: fit a 1-D Gaussian KDE on the scalar + distances, draw ``n_samples`` deterministic samples, and return the + mean as the domain weight. + + .. note:: Paper vs. Code + + The paper (Section 4.1) describes fitting a multivariate Gaussian + KDE on K-dimensional difference vectors, then sampling via MCMC + and computing a matrix norm. The code fits a **scalar** KDE on + flattened distances, draws exactly 10 samples with a fixed seed, + and returns the scalar mean. The bandwidth 7.8 is hardcoded with + no justification or scale-dependence on the distance metric. + + Args: + distances (np.ndarray): Shape ``(N,)`` from + ``compute_pairwise_distances``. + bandwidth (float): KDE bandwidth. Author's code hardcodes ``7.8``. + n_samples (int): Number of points to sample from the fitted KDE. + random_state (int): Seed for deterministic KDE sampling. + + Returns: + float: Scalar weight (higher = less similar to target domain). + """ + kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit( + distances.flatten().reshape(-1, 1) + ) + weight = float( + np.mean(kde.sample(n_samples, random_state=random_state), axis=0)[0] + ) + return weight + + +def compute_all_ipd_weights( + domain_data: Dict[str, np.ndarray], + target_domain: str, + metric: str = "dtw_classic", + bandwidth: float = DEFAULT_KDE_BANDWIDTH, +) -> Dict[str, float]: + """Compute IPD weights for all source domains relative to a target. + + Top-level function that runs the full IPD pipeline for one choice + of target domain. + + Args: + domain_data (Dict[str, np.ndarray]): Mapping of domain key to + array of shape ``(N, T, 1)`` or ``(N, T)``. All domains must + have the same number of paired samples N. + target_domain (str): Key of the target domain in ``domain_data``. + metric (str): Distance metric for pairwise computation. + bandwidth (float): KDE bandwidth. + + Returns: + Dict[str, float]: Mapping of source domain key to its scalar IPD + weight. The target domain is excluded from the output. + + Example:: + + >>> weights = compute_all_ipd_weights( + ... domain_data={"T": ts_t, "RA": ts_ra, "LA": ts_la, + ... "RL": ts_rl, "LL": ts_ll}, + ... target_domain="LA", + ... metric="dtw_classic", + ... ) + >>> for domain, w in sorted(weights.items(), key=lambda x: x[1]): + ... 
print(f"{domain}: {w:.4f}") + """ + target_ts = domain_data[target_domain] + weights: Dict[str, float] = {} + + for domain, source_ts in domain_data.items(): + if domain == target_domain: + continue + distances = compute_pairwise_distances(source_ts, target_ts, metric) + weights[domain] = compute_ipd_weight(distances, bandwidth=bandwidth) + logger.info( + "IPD weight %s → %s: %.4f (metric=%s)", + domain, target_domain, weights[domain], metric, + ) + + return weights + + +# ===================================================================== +# Epoch scaling — replicating the author's transfer mechanism +# ===================================================================== + +def compute_weighted_epochs( + weights: Dict[str, float], + base_epochs: int = CODE_SOURCE_EPOCHS, + scale_factor: int = CODE_EPOCH_SCALE_FACTOR, +) -> Dict[str, int]: + """Compute per-domain epoch counts from IPD weights. + + Replicates the author's epoch-scaling formula: + ``epochs = int(base * scale_factor * weight / weight_sum) + 1`` + + .. note:: Paper vs. Code + + The paper (Eq. 5) describes learning-rate decay + ``λ^{j+1} = λ^j · (1 − α_q)`` where ``α_q = g_q / Σg_l``. + The code does NOT implement LR decay. Instead, it scales the + number of epochs per source domain proportionally to the IPD + weight, keeping the learning rate fixed at 0.005. + + Args: + weights (Dict[str, float]): Mapping of source domain key to IPD + weight, from ``compute_all_ipd_weights``. + base_epochs (int): Base epoch count before scaling (code: 30). + scale_factor (int): Multiplier in the formula (code: 7, + unexplained in the paper). + + Returns: + Dict[str, int]: Mapping of source domain key to epoch count. + """ + weight_sum = sum(weights.values()) + if weight_sum == 0: + return {d: base_epochs for d in weights} + return { + domain: int(base_epochs * scale_factor * w / weight_sum) + 1 + for domain, w in weights.items() + } + + +# ===================================================================== +# Experiment configuration +# ===================================================================== + +class ExperimentConfig: + """Configuration for one run of the DPAM replication. + + Bundles all hyperparameters needed to run the three experimental + conditions (No Transfer, Naive Transfer, Weighted Transfer) for a + given target domain and distance metric. + + Code-default values replicate the author's implementation. + Paper-stated values are provided in comments for reference. + + Args: + target_domain (str): Sensor domain to classify from. + metric (str): Distance metric for IPD computation. + learning_rate (float): Optimizer LR (code: 0.005, paper: 5e-4). + batch_size (int): Training batch size (code: 16). + source_epochs (int): Epochs per source domain (code: 30, paper J=50). + target_epochs_no_transfer (int): Epochs for No Transfer (code: 40). + target_epochs_naive (int): Epochs for Naive Transfer (code: 30). + target_epochs_weighted (int): Epochs for Weighted Transfer (code: 40). + epoch_scale_factor (int): Multiplier in epoch formula (code: 7). + kde_bandwidth (float): KDE bandwidth (code: 7.8). + kde_n_samples (int): KDE sample count (code: 10). + positive_activity_id (Optional[int]): Activity ID for binary + setup. ``None`` uses 19-class classification. + n_repeats (int): Number of random subject-split repetitions. + n_train_subjects (int): Number of subjects in training set. + + Example:: + + >>> config = ExperimentConfig( + ... target_domain="LA", + ... metric="dtw_classic", + ... positive_activity_id=12, + ... 
) + """ + + def __init__( + self, + target_domain: str = "LA", + metric: str = "dtw_classic", + learning_rate: float = CODE_LEARNING_RATE, + batch_size: int = CODE_BATCH_SIZE, + source_epochs: int = CODE_SOURCE_EPOCHS, + target_epochs_no_transfer: int = CODE_TARGET_EPOCHS_NO_TRANSFER, + target_epochs_naive: int = CODE_TARGET_EPOCHS_NAIVE, + target_epochs_weighted: int = CODE_TARGET_EPOCHS_WEIGHTED, + epoch_scale_factor: int = CODE_EPOCH_SCALE_FACTOR, + kde_bandwidth: float = DEFAULT_KDE_BANDWIDTH, + kde_n_samples: int = DEFAULT_KDE_N_SAMPLES, + positive_activity_id: Optional[int] = None, + n_repeats: int = 15, + n_train_subjects: int = 6, + ) -> None: + self.target_domain = target_domain + self.metric = metric + self.learning_rate = learning_rate + self.batch_size = batch_size + self.source_epochs = source_epochs + self.target_epochs_no_transfer = target_epochs_no_transfer + self.target_epochs_naive = target_epochs_naive + self.target_epochs_weighted = target_epochs_weighted + self.epoch_scale_factor = epoch_scale_factor + self.kde_bandwidth = kde_bandwidth + self.kde_n_samples = kde_n_samples + self.positive_activity_id = positive_activity_id + self.n_repeats = n_repeats + self.n_train_subjects = n_train_subjects + + def __repr__(self) -> str: + return ( + f"ExperimentConfig(target_domain={self.target_domain!r}, " + f"metric={self.metric!r}, " + f"positive_activity_id={self.positive_activity_id})" + ) + + +class ExperimentResult: + """Results from one repetition of the experiment. + + Args: + repeat_idx (int): Index of this repetition (0-based). + train_subjects (List[int]): Subject IDs used for training. + test_subjects (List[int]): Subject IDs used for testing. + metric (str): Distance metric used for IPD. + ipd_weights (Dict[str, float]): Per-domain IPD weights. + weighted_epochs (Dict[str, int]): Per-domain epoch counts. + accuracy_no_transfer (float): RCC for No Transfer baseline. + accuracy_naive_transfer (float): RCC for Naive Transfer. + accuracy_weighted_transfer (float): RCC for Weighted Transfer. 
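+
+    Example (illustrative values only)::
+
+        >>> res = ExperimentResult(
+        ...     repeat_idx=0,
+        ...     accuracy_no_transfer=0.50,
+        ...     accuracy_weighted_transfer=0.60,
+        ... )
+        >>> res
+        ExperimentResult(repeat=0, no_transfer=0.5000, naive=0.0000, weighted=0.6000)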
+ """ + + def __init__( + self, + repeat_idx: int = 0, + train_subjects: Optional[List[int]] = None, + test_subjects: Optional[List[int]] = None, + metric: str = "dtw_classic", + ipd_weights: Optional[Dict[str, float]] = None, + weighted_epochs: Optional[Dict[str, int]] = None, + accuracy_no_transfer: float = 0.0, + accuracy_naive_transfer: float = 0.0, + accuracy_weighted_transfer: float = 0.0, + ) -> None: + self.repeat_idx = repeat_idx + self.train_subjects = train_subjects or [] + self.test_subjects = test_subjects or [] + self.metric = metric + self.ipd_weights = ipd_weights or {} + self.weighted_epochs = weighted_epochs or {} + self.accuracy_no_transfer = accuracy_no_transfer + self.accuracy_naive_transfer = accuracy_naive_transfer + self.accuracy_weighted_transfer = accuracy_weighted_transfer + + def __repr__(self) -> str: + return ( + f"ExperimentResult(repeat={self.repeat_idx}, " + f"no_transfer={self.accuracy_no_transfer:.4f}, " + f"naive={self.accuracy_naive_transfer:.4f}, " + f"weighted={self.accuracy_weighted_transfer:.4f})" + ) diff --git a/pyproject.toml b/pyproject.toml index 98f88d47b..67fc81144 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,9 @@ nlp = [ "rouge_score~=0.1.2", "nltk~=3.9.1", ] +dsa = [ + "pyts>=0.12.0", +] [project.urls] Homepage = "https://github.com/sunlabuiuc/PyHealth" diff --git a/tests/core/test_dsa_dataset.py b/tests/core/test_dsa_dataset.py new file mode 100644 index 000000000..21bfed33f --- /dev/null +++ b/tests/core/test_dsa_dataset.py @@ -0,0 +1,375 @@ +"""Unit tests for DSADataset using fully synthetic data.""" + +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +import numpy as np + +from pyhealth.datasets.dsa import DSADataset +from pyhealth.tasks.dsa import DSAActivityClassification + +# ===================================================================== +# Synthetic data constants — kept small for fast test execution +# ===================================================================== + +_N_ACTIVITIES = 2 +_N_SUBJECTS = 3 +_N_SEGMENTS = 2 +_N_TIMESTEPS = 125 +_N_COLS = 45 # 5 domains × 9 channels + + +def _make_segment_file(path: Path, seed: int = 0) -> None: + """Write a synthetic 125×45 segment file to ``path``.""" + rng = np.random.default_rng(seed) + data = rng.uniform(-10.0, 10.0, size=(_N_TIMESTEPS, _N_COLS)).astype( + np.float32 + ) + np.savetxt(path, data, delimiter=",", fmt="%.6f") + + +def _build_synthetic_dataset(root: Path) -> None: + """Create the full activity/subject/segment folder structure.""" + for a in range(1, _N_ACTIVITIES + 1): + for p in range(1, _N_SUBJECTS + 1): + subject_dir = root / f"a{a:02d}" / f"p{p}" + subject_dir.mkdir(parents=True, exist_ok=True) + for s in range(1, _N_SEGMENTS + 1): + seed = a * 1000 + p * 100 + s + _make_segment_file(subject_dir / f"s{s:02d}.txt", seed=seed) + + +class TestDSADataset(unittest.TestCase): + """Tests for DSADataset: loading, indexing, patient/event parsing.""" + + @classmethod + def setUpClass(cls): + cls.tmp = tempfile.TemporaryDirectory() + cls.root = Path(cls.tmp.name) + _build_synthetic_dataset(cls.root) + cls.cache_dir = tempfile.TemporaryDirectory() + + with patch.multiple( + DSADataset, + _N_ACTIVITIES=_N_ACTIVITIES, + _N_SUBJECTS=_N_SUBJECTS, + _N_SEGMENTS=_N_SEGMENTS, + ): + cls.dataset = DSADataset( + root=str(cls.root), + cache_dir=cls.cache_dir.name, + target_domain="LA", + scale=True, + ) + + @classmethod + def tearDownClass(cls): + cls.tmp.cleanup() + cls.cache_dir.cleanup() + + # 
------------------------------------------------------------------ + # Metadata index + # ------------------------------------------------------------------ + + def test_metadata_csv_created(self): + """Index CSV must be written alongside the raw data.""" + csv_path = self.root / "dsa-metadata-pyhealth.csv" + self.assertTrue(csv_path.exists()) + + def test_metadata_row_count(self): + """Row count must equal N_activities × N_subjects × N_segments.""" + import pandas as pd + + df = pd.read_csv(self.root / "dsa-metadata-pyhealth.csv") + expected = _N_ACTIVITIES * _N_SUBJECTS * _N_SEGMENTS + self.assertEqual(len(df), expected) + + def test_metadata_required_columns(self): + """All columns required by dsa.yaml must be present.""" + import pandas as pd + + df = pd.read_csv(self.root / "dsa-metadata-pyhealth.csv") + required = { + "patient_id", + "visit_id", + "activity_id", + "activity_name", + "label", + "segment_id", + "filepath", + "pair_id", + } + self.assertTrue(required.issubset(df.columns)) + + def test_pair_id_format(self): + """pair_id must be 'a{act:02d}_s{seg:02d}', shared across subjects.""" + import pandas as pd + + df = pd.read_csv(self.root / "dsa-metadata-pyhealth.csv") + for _, row in df.iterrows(): + expected_pair = ( + f"a{int(row['activity_id']):02d}" + f"_s{int(row['segment_id']):02d}" + ) + self.assertEqual(row["pair_id"], expected_pair) + + def test_pair_id_shared_across_subjects(self): + """All subjects must share the same pair_id for the same (act, seg).""" + import pandas as pd + + df = pd.read_csv(self.root / "dsa-metadata-pyhealth.csv") + groups = df.groupby("pair_id")["patient_id"].nunique() + for pair_id, n_subjects in groups.items(): + self.assertEqual( + n_subjects, + _N_SUBJECTS, + f"pair_id '{pair_id}' should have {_N_SUBJECTS} subjects, " + f"got {n_subjects}", + ) + + def test_label_zero_indexed(self): + """Label must be activity_id minus 1 (0-indexed).""" + import pandas as pd + + df = pd.read_csv(self.root / "dsa-metadata-pyhealth.csv") + for _, row in df.iterrows(): + self.assertEqual(int(row["label"]), int(row["activity_id"]) - 1) + + def test_all_filepaths_exist(self): + """Every filepath stored in the index must point to a real file.""" + import pandas as pd + + df = pd.read_csv(self.root / "dsa-metadata-pyhealth.csv") + for fp in df["filepath"]: + self.assertTrue( + Path(fp).exists(), f"Missing segment file: {fp}" + ) + + # ------------------------------------------------------------------ + # Patient structure + # ------------------------------------------------------------------ + + def test_num_patients(self): + """One Patient object must exist per subject.""" + self.assertEqual( + len(self.dataset.unique_patient_ids), _N_SUBJECTS + ) + + def test_patient_id_format(self): + """Patient IDs must follow the 'p{n}' convention.""" + for pid in self.dataset.unique_patient_ids: + self.assertTrue( + pid.startswith("p"), + f"Patient ID '{pid}' does not start with 'p'", + ) + + def test_events_per_patient(self): + """Each patient must have N_activities × N_segments events.""" + expected = _N_ACTIVITIES * _N_SEGMENTS + for pid in self.dataset.unique_patient_ids: + events = self.dataset.get_patient(pid).get_events() + self.assertEqual( + len(events), + expected, + f"Patient '{pid}' has {len(events)} events, expected {expected}", + ) + + def test_event_attributes_present(self): + """Every event must expose all declared attribute columns.""" + required = { + "label", + "activity_id", + "activity_name", + "filepath", + "segment_id", + "pair_id", + "visit_id", + } + 
patient = self.dataset.get_patient( + self.dataset.unique_patient_ids[0] + ) + for event in patient.get_events(): + for attr in required: + self.assertIn( + attr, + event.attr_dict, + f"Attribute '{attr}' missing from event", + ) + + def test_event_activity_names_valid(self): + """Activity names must appear in DSADataset.activities.""" + patient = self.dataset.get_patient( + self.dataset.unique_patient_ids[0] + ) + for event in patient.get_events(): + self.assertIn(event.activity_name, DSADataset.activities) + + # ------------------------------------------------------------------ + # Time series loading + # ------------------------------------------------------------------ + + def test_load_time_series_shape(self): + """Single-domain load must return (9, 125) array.""" + import pandas as pd + + df = pd.read_csv(self.root / "dsa-metadata-pyhealth.csv") + fp = df.iloc[0]["filepath"] + ts = self.dataset.load_time_series(fp, domain="LA") + self.assertIn("LA", ts) + self.assertEqual(ts["LA"].shape, (9, _N_TIMESTEPS)) + + def test_load_time_series_all_domains(self): + """No-domain load must return all five domain arrays.""" + import pandas as pd + + df = pd.read_csv(self.root / "dsa-metadata-pyhealth.csv") + fp = df.iloc[0]["filepath"] + ts = self.dataset.load_time_series(fp) + self.assertEqual(set(ts.keys()), set(DSADataset.domains)) + for domain, arr in ts.items(): + self.assertEqual( + arr.shape, + (9, _N_TIMESTEPS), + f"Domain '{domain}' has shape {arr.shape}", + ) + + def test_minmax_scale_range(self): + """Scaled time series values must lie in [-1, 1].""" + import pandas as pd + + df = pd.read_csv(self.root / "dsa-metadata-pyhealth.csv") + fp = df.iloc[0]["filepath"] + ts = self.dataset.load_time_series(fp, domain="T") + arr = ts["T"] + self.assertGreaterEqual(arr.min(), -1.0 - 1e-6) + self.assertLessEqual(arr.max(), 1.0 + 1e-6) + + def test_scale_false_preserves_raw_values(self): + """With scale=False, raw values outside [-1, 1] must be present.""" + with patch.multiple( + DSADataset, + _N_ACTIVITIES=_N_ACTIVITIES, + _N_SUBJECTS=_N_SUBJECTS, + _N_SEGMENTS=_N_SEGMENTS, + ): + dataset_unscaled = DSADataset( + root=str(self.root), + cache_dir=self.cache_dir.name, + target_domain="LA", + scale=False, + ) + + import pandas as pd + + df = pd.read_csv(self.root / "dsa-metadata-pyhealth.csv") + fp = df.iloc[0]["filepath"] + ts = dataset_unscaled.load_time_series(fp, domain="T") + arr = ts["T"] + has_outside = (arr.min() < -1.0) or (arr.max() > 1.0) + self.assertTrue(has_outside, "Unscaled values should exceed [-1, 1]") + + # ------------------------------------------------------------------ + # Subject split utilities + # ------------------------------------------------------------------ + + def test_get_subject_split_counts(self): + """Train/test split must contain the correct number of rows.""" + train_df, test_df = self.dataset.get_subject_split( + train_subjects=[1, 2], test_subjects=[3] + ) + rows_per_subject = _N_ACTIVITIES * _N_SEGMENTS + self.assertEqual(len(train_df), 2 * rows_per_subject) + self.assertEqual(len(test_df), 1 * rows_per_subject) + + def test_get_subject_split_no_overlap(self): + """No patient_id should appear in both train and test splits.""" + train_df, test_df = self.dataset.get_subject_split( + train_subjects=[1, 2], test_subjects=[3] + ) + train_ids = set(train_df["patient_id"]) + test_ids = set(test_df["patient_id"]) + self.assertTrue(train_ids.isdisjoint(test_ids)) + + def test_get_subject_split_overlap_raises(self): + """Overlapping train/test subjects must raise 
ValueError.""" + with self.assertRaises(ValueError): + self.dataset.get_subject_split( + train_subjects=[1, 2], test_subjects=[2, 3] + ) + + def test_random_subject_splits_count(self): + """Generator must yield exactly n_repeats tuples.""" + results = list( + self.dataset.random_subject_splits( + n_repeats=3, n_train=2, random_seed=0 + ) + ) + self.assertEqual(len(results), 3) + + def test_random_subject_splits_reproducible(self): + """Same seed must produce identical splits.""" + splits_a = [ + (tr, te) + for _, tr, te, _, _ in self.dataset.random_subject_splits( + n_repeats=3, n_train=2, random_seed=42 + ) + ] + splits_b = [ + (tr, te) + for _, tr, te, _, _ in self.dataset.random_subject_splits( + n_repeats=3, n_train=2, random_seed=42 + ) + ] + self.assertEqual(splits_a, splits_b) + + def test_random_subject_splits_no_overlap(self): + """Every random split must have disjoint train/test sets.""" + for _, tr, te, _, _ in self.dataset.random_subject_splits( + n_repeats=5, n_train=2, random_seed=0 + ): + self.assertTrue(set(tr).isdisjoint(set(te))) + + # ------------------------------------------------------------------ + # Default task + # ------------------------------------------------------------------ + + def test_default_task_type(self): + """default_task must return a DSAActivityClassification instance.""" + self.assertIsInstance( + self.dataset.default_task, DSAActivityClassification + ) + + # ------------------------------------------------------------------ + # Validation + # ------------------------------------------------------------------ + + def test_invalid_target_domain_raises(self): + """Unknown target_domain must raise ValueError at construction.""" + with self.assertRaises(ValueError): + with patch.multiple( + DSADataset, + _N_ACTIVITIES=_N_ACTIVITIES, + _N_SUBJECTS=_N_SUBJECTS, + _N_SEGMENTS=_N_SEGMENTS, + ): + DSADataset( + root=str(self.root), + target_domain="WRIST", + ) + + def test_missing_root_raises(self): + """Non-existent root directory must raise FileNotFoundError.""" + with self.assertRaises(FileNotFoundError): + DSADataset(root="/nonexistent/path/dsa") + + def test_missing_a01_raises(self): + """Root without 'a01' subfolder must raise FileNotFoundError.""" + with tempfile.TemporaryDirectory() as bad_root: + with self.assertRaises(FileNotFoundError): + DSADataset(root=bad_root) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_dsa_tasks.py b/tests/core/test_dsa_tasks.py new file mode 100644 index 000000000..4275bba07 --- /dev/null +++ b/tests/core/test_dsa_tasks.py @@ -0,0 +1,418 @@ +"""Unit tests for DSA tasks and IPD computation using synthetic data.""" + +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +import numpy as np + +from pyhealth.datasets.dsa import DSADataset +from pyhealth.tasks.dsa import ( + DSAActivityClassification, + DSABinaryActivityClassification, +) +from pyhealth.tasks.dsa import ( + compute_all_ipd_weights, + compute_ipd_weight, + compute_pairwise_distances, + compute_weighted_epochs, + ExperimentConfig, + ExperimentResult, +) + +_N_ACTIVITIES = 2 +_N_SUBJECTS = 3 +_N_SEGMENTS = 2 +_N_TIMESTEPS = 125 +_N_COLS = 45 # 5 domains × 9 channels + + +def _make_segment_file(path: Path, seed: int = 0) -> None: + """Write a synthetic 125×45 segment file to ``path``.""" + rng = np.random.default_rng(seed) + data = rng.uniform(-10.0, 10.0, size=(_N_TIMESTEPS, _N_COLS)).astype( + np.float32 + ) + np.savetxt(path, data, delimiter=",", fmt="%.6f") + + +def 
_build_synthetic_dataset(root: Path) -> None: + """Create the full activity/subject/segment folder structure.""" + for a in range(1, _N_ACTIVITIES + 1): + for p in range(1, _N_SUBJECTS + 1): + subject_dir = root / f"a{a:02d}" / f"p{p}" + subject_dir.mkdir(parents=True, exist_ok=True) + for s in range(1, _N_SEGMENTS + 1): + seed = a * 1000 + p * 100 + s + _make_segment_file(subject_dir / f"s{s:02d}.txt", seed=seed) + + +class TestDSAActivityClassification(unittest.TestCase): + """Tests for the activity classification task.""" + + @classmethod + def setUpClass(cls): + cls.tmp = tempfile.TemporaryDirectory() + cls.root = Path(cls.tmp.name) + _build_synthetic_dataset(cls.root) + cls.cache_dir = tempfile.TemporaryDirectory() + + with patch.multiple( + DSADataset, + _N_ACTIVITIES=_N_ACTIVITIES, + _N_SUBJECTS=_N_SUBJECTS, + _N_SEGMENTS=_N_SEGMENTS, + ): + cls.dataset = DSADataset( + root=str(cls.root), + cache_dir=cls.cache_dir.name, + target_domain="LA", + scale=True, + ) + cls.task = DSAActivityClassification() + cls.samples = cls.dataset.set_task(cls.task) + cls.sample_list = list(cls.samples) + + @classmethod + def tearDownClass(cls): + try: + del cls.samples + except Exception: + pass + + import gc + gc.collect() + + def test_sample_count(self): + """Total samples must equal N_patients × N_activities × N_segments.""" + expected = _N_SUBJECTS * _N_ACTIVITIES * _N_SEGMENTS + self.assertEqual(len(self.samples), expected) + + def test_sample_required_keys(self): + """Every sample must contain required keys.""" + required = { + "patient_id", + "visit_id", + "time_series", + "label", + "activity_name", + "pair_id", + } + for sample in self.sample_list: + self.assertTrue(required.issubset(sample.keys())) + + def test_time_series_shape(self): + """time_series must have shape (9, 125).""" + for sample in self.sample_list: + self.assertEqual( + sample["time_series"].shape, (9, _N_TIMESTEPS) + ) + + def test_labels_range(self): + """Labels must be integers in [0, N_activities - 1].""" + labels = {int(sample["label"]) for sample in self.sample_list} + self.assertTrue(labels.issubset(set(range(_N_ACTIVITIES)))) + + def test_activity_names_valid(self): + """Activity names must come from DSADataset.activities.""" + for sample in self.sample_list: + self.assertIn(sample["activity_name"], DSADataset.activities) + + def test_time_series_scaled(self): + """Scaled time series must have all values in [-1, 1].""" + for sample in self.sample_list: + arr = sample["time_series"] + self.assertGreaterEqual(arr.min(), -1.0 - 1e-6) + self.assertLessEqual(arr.max(), 1.0 + 1e-6) + + def test_pair_ids_present(self): + """pair_id must be present and non-empty for every sample.""" + for sample in self.sample_list: + self.assertIsInstance(sample["pair_id"], str) + self.assertGreater(len(sample["pair_id"]), 0) + + +class TestDSABinaryClassification(unittest.TestCase): + """Tests for the binary one-vs-rest classification task.""" + + @classmethod + def setUpClass(cls): + cls.tmp = tempfile.TemporaryDirectory() + cls.root = Path(cls.tmp.name) + _build_synthetic_dataset(cls.root) + cls.cache_dir = tempfile.TemporaryDirectory() + + with patch.multiple( + DSADataset, + _N_ACTIVITIES=_N_ACTIVITIES, + _N_SUBJECTS=_N_SUBJECTS, + _N_SEGMENTS=_N_SEGMENTS, + ): + cls.dataset = DSADataset( + root=str(cls.root), + cache_dir=cls.cache_dir.name, + target_domain="RA", + scale=True, + ) + cls.task = DSABinaryActivityClassification(positive_activity_id=1) + cls.samples = cls.dataset.set_task(cls.task) + cls.sample_list = list(cls.samples) + + 
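# One-vs-rest relabelling, sketched for the fixture above: with
+    # positive_activity_id=1, every activity-a01 sample is expected to
+    # carry label 1 and every other activity label 0, e.g.
+    #     {"label": 1, "activity_id": 1, ...}   # an a01 segment
+    #     {"label": 0, "activity_id": 2, ...}   # any other activity
+    # The tests below verify exactly this contract.
+
+    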
def test_sample_count(self):
+        """Binary task must produce the same number of samples as multiclass."""
+        expected = _N_SUBJECTS * _N_ACTIVITIES * _N_SEGMENTS
+        self.assertEqual(len(self.samples), expected)
+
+    def test_labels_binary(self):
+        """All labels must be 0 or 1."""
+        for sample in self.sample_list:
+            self.assertIn(sample["label"], (0, 1))
+
+    def test_positive_label_count(self):
+        """Exactly N_subjects × N_segments samples must be positive."""
+        n_positive = sum(s["label"] for s in self.samples)
+        expected = _N_SUBJECTS * _N_SEGMENTS
+        self.assertEqual(n_positive, expected)
+
+    def test_positive_label_corresponds_to_activity(self):
+        """Positive samples must belong to the designated positive activity."""
+        for sample in self.sample_list:
+            if sample["label"] == 1:
+                self.assertEqual(sample["activity_id"], 1)
+            else:
+                self.assertNotEqual(sample["activity_id"], 1)
+
+    def test_missing_positive_activity_raises(self):
+        """Constructing the task without positive_activity_id must raise TypeError."""
+        with self.assertRaises(TypeError):
+            DSABinaryActivityClassification()
+
+    # def test_different_positive_class_gives_different_positives(self):
+    #     """Changing positive_activity_id must change which samples are positive."""
+    #     task_a2 = DSABinaryActivityClassification(positive_activity_id=2)
+    #     with patch.multiple(
+    #         DSADataset,
+    #         _N_ACTIVITIES=_N_ACTIVITIES,
+    #         _N_SUBJECTS=_N_SUBJECTS,
+    #         _N_SEGMENTS=_N_SEGMENTS,
+    #     ):
+    #         samples_a2 = self.dataset.set_task(task_a2)
+
+    #     positives_a1 = {s["visit_id"] for s in self.sample_list if s["label"] == 1}
+    #     positives_a2 = {s["visit_id"] for s in samples_a2 if s["label"] == 1}
+    #     self.assertTrue(positives_a1.isdisjoint(positives_a2))
+    #     del samples_a2
+
+    def test_activity_id_in_sample(self):
+        """activity_id field must be present in every binary sample."""
+        for sample in self.sample_list:
+            self.assertIn("activity_id", sample)
+
+
+class TestIPDComputation(unittest.TestCase):
+    """Tests for the IPD pipeline: pairwise distances, KDE weights, epoch scaling."""
+
+    @classmethod
+    def setUpClass(cls):
+        rng = np.random.default_rng(42)
+        n = 20
+        cls.source_ts = rng.uniform(-1, 1, (n, 125)).astype(np.float32)
+        cls.target_ts = rng.uniform(-1, 1, (n, 125)).astype(np.float32)
+        cls.n = n
+        cls.euclidean_dists = compute_pairwise_distances(
+            cls.source_ts, cls.target_ts, metric="euclidean"
+        )
+        cls.self_dists = compute_pairwise_distances(
+            cls.source_ts, cls.source_ts, metric="euclidean"
+        )
+
+    def test_euclidean_output_shape(self):
+        """Output must have one scalar per pair."""
+        dists = compute_pairwise_distances(
+            self.source_ts, self.target_ts, metric="euclidean"
+        )
+        self.assertEqual(dists.shape, (self.n,))
+
+    def test_euclidean_non_negative(self):
+        """Euclidean distances must be non-negative."""
+        self.assertTrue(np.all(self.euclidean_dists >= 0))
+
+    def test_self_distance_is_zero(self):
+        """Euclidean distance from a series to itself must be zero."""
+        np.testing.assert_allclose(self.self_dists, 0.0, atol=1e-5)
+
+    def test_invalid_metric_raises(self):
+        """Unknown metric string must raise ValueError."""
+        with self.assertRaises(ValueError):
+            compute_pairwise_distances(
+                self.source_ts, self.target_ts, metric="manhattan_city"
+            )
+
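+    # The weight tests below treat compute_ipd_weight as a KDE-based
+    # summary of a sample of distances: closer source/target segments
+    # should collapse to a smaller scalar weight (see
+    # test_identical_series_low_weight), while bandwidth and random_state
+    # control the KDE fit, hence the determinism and bandwidth-sensitivity
+    # assertions.
+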
+    def test_ipd_weight_is_scalar(self):
+        """IPD weight must be a Python float."""
+        weight = compute_ipd_weight(self.euclidean_dists)
+        self.assertIsInstance(weight, float)
+
+    def test_identical_series_low_weight(self):
+        """Identical source/target must produce a lower IPD weight than random data."""
+        near_zero = np.zeros((self.n, 125), dtype=np.float32)
+        dists_identical = compute_pairwise_distances(
+            near_zero, near_zero, metric="euclidean"
+        )
+        dists_random = self.euclidean_dists
+        weight_identical = compute_ipd_weight(dists_identical)
+        weight_random = compute_ipd_weight(dists_random)
+        self.assertLess(weight_identical, weight_random)
+
+    def test_ipd_weight_deterministic(self):
+        """Same input and random_state must give the same weight."""
+        w1 = compute_ipd_weight(self.euclidean_dists, random_state=0)
+        w2 = compute_ipd_weight(self.euclidean_dists, random_state=0)
+        self.assertAlmostEqual(w1, w2, places=10)
+
+    def test_ipd_weight_changes_with_bandwidth(self):
+        """Changing bandwidth must produce a different weight."""
+        w_narrow = compute_ipd_weight(self.euclidean_dists)
+        w_wide = compute_ipd_weight(self.euclidean_dists, bandwidth=50.0)
+        self.assertNotAlmostEqual(w_narrow, w_wide, places=3)
+
+    def test_all_ipd_weights_excludes_target(self):
+        """Target domain must not appear in the returned weights dict."""
+        rng = np.random.default_rng(0)
+        domain_data = {
+            d: rng.uniform(-1, 1, (10, 125)).astype(np.float32)
+            for d in ["T", "RA", "LA", "RL", "LL"]
+        }
+        weights = compute_all_ipd_weights(
+            domain_data, target_domain="LA", metric="euclidean"
+        )
+        self.assertNotIn("LA", weights)
+
+    def test_all_ipd_weights_contains_source_domains(self):
+        """All non-target domains must appear in the returned weights."""
+        rng = np.random.default_rng(1)
+        domain_data = {
+            d: rng.uniform(-1, 1, (10, 125)).astype(np.float32)
+            for d in ["T", "RA", "LA", "RL", "LL"]
+        }
+        weights = compute_all_ipd_weights(
+            domain_data, target_domain="LA", metric="euclidean"
+        )
+        expected_sources = {"T", "RA", "RL", "LL"}
+        self.assertEqual(set(weights.keys()), expected_sources)
+
+    def test_all_ipd_weights_are_positive(self):
+        """All weights must be positive for non-zero distance inputs."""
+        rng = np.random.default_rng(2)
+        domain_data = {
+            d: rng.uniform(0.5, 1.0, (10, 125)).astype(np.float32)
+            for d in ["T", "RA", "LA"]
+        }
+        weights = compute_all_ipd_weights(
+            domain_data, target_domain="LA", metric="euclidean"
+        )
+        for domain, w in weights.items():
+            self.assertGreater(w, 0, f"Weight for '{domain}' must be positive")
+
+    def test_weighted_epochs_all_domains_present(self):
+        """Every input domain must appear in the output."""
+        weights = {"T": 10.0, "RA": 5.0, "RL": 3.0, "LL": 2.0}
+        epochs = compute_weighted_epochs(weights)
+        self.assertEqual(set(epochs.keys()), set(weights.keys()))
+
+    def test_weighted_epochs_minimum_one(self):
+        """Every domain must get at least 1 epoch (from the +1 in the formula)."""
+        weights = {"T": 0.0001, "RA": 0.0001}
+        epochs = compute_weighted_epochs(weights)
+        for d, e in epochs.items():
+            self.assertGreaterEqual(e, 1, f"Domain '{d}' got 0 epochs")
+
+    def test_weighted_epochs_proportional(self):
+        """The higher-weight domain must receive more epochs."""
+        weights = {"high": 100.0, "low": 1.0}
+        epochs = compute_weighted_epochs(weights)
+        self.assertGreater(epochs["high"], epochs["low"])
+
+    def test_weighted_epochs_zero_sum_fallback(self):
+        """All-zero weights must not cause a division error."""
+        weights = {"T": 0.0, "RA": 0.0}
+        try:
+            epochs = compute_weighted_epochs(weights)
+            for e in epochs.values():
+                self.assertIsInstance(e, int)
+        except ZeroDivisionError:
+            
self.fail("compute_weighted_epochs raised ZeroDivisionError") + + def test_epoch_scale_factor_applied(self): + """Custom scale_factor must affect the epoch counts.""" + weights = {"T": 1.0} + e1 = compute_weighted_epochs(weights, scale_factor=7) + e2 = compute_weighted_epochs(weights, scale_factor=14) + self.assertGreater(e2["T"], e1["T"]) + + +class TestExperimentConfig(unittest.TestCase): + """Tests for ExperimentConfig and ExperimentResult dataclasses.""" + + def test_default_values(self): + """Default config must reflect author's code values, not paper.""" + config = ExperimentConfig() + self.assertEqual(config.learning_rate, 0.005) + self.assertEqual(config.source_epochs, 30) + self.assertEqual(config.target_epochs_weighted, 40) + self.assertEqual(config.kde_bandwidth, 7.8) + self.assertEqual(config.kde_n_samples, 10) + self.assertEqual(config.n_repeats, 15) + self.assertEqual(config.n_train_subjects, 6) + + def test_custom_values(self): + """Custom config must override defaults correctly.""" + config = ExperimentConfig( + target_domain="RA", + metric="dtw_classic", + positive_activity_id=12, + ) + self.assertEqual(config.target_domain, "RA") + self.assertEqual(config.metric, "dtw_classic") + self.assertEqual(config.positive_activity_id, 12) + + def test_none_positive_activity_triggers_multiclass(self): + """positive_activity_id=None must signal the 19-class setup.""" + config = ExperimentConfig() + self.assertIsNone(config.positive_activity_id) + + def test_experiment_result_defaults(self): + """ExperimentResult must initialise with zero accuracies.""" + result = ExperimentResult() + self.assertEqual(result.accuracy_no_transfer, 0.0) + self.assertEqual(result.accuracy_naive_transfer, 0.0) + self.assertEqual(result.accuracy_weighted_transfer, 0.0) + self.assertIsInstance(result.ipd_weights, dict) + self.assertIsInstance(result.weighted_epochs, dict) + + def test_experiment_result_assignment(self): + """Assigned accuracies must be retrievable.""" + result = ExperimentResult( + repeat_idx=3, + train_subjects=[1, 2, 3], + test_subjects=[4, 5], + metric="euclidean", + accuracy_no_transfer=0.82, + accuracy_naive_transfer=0.88, + accuracy_weighted_transfer=0.91, + ) + self.assertEqual(result.repeat_idx, 3) + self.assertAlmostEqual(result.accuracy_weighted_transfer, 0.91) + + +if __name__ == "__main__": + unittest.main() From 5b9d590227291edb2a731ea80efa95ea413fbb9e Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 23 Apr 2026 00:37:30 -0500 Subject: [PATCH 2/3] typo fix --- examples/dsa_activity_classification.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dsa_activity_classification.ipynb b/examples/dsa_activity_classification.ipynb index ed5853ffa..f0fc90f42 100644 --- a/examples/dsa_activity_classification.ipynb +++ b/examples/dsa_activity_classification.ipynb @@ -105,7 +105,7 @@ "\n", "# ── Load data ───────────────────────────────────────────────────────\n", "print('\\nLoading dataset...')\n", - "dataset = DSADataset(root=DATA_ROOT, download=False,\n", + "dataset = DSADataset(root=DATA_ROOT, download=True,\n", " target_domain=TARGET_DOMAIN, scale=True)\n", "\n", "# Will be reloaded per repeat if RANDOM_ACTIVITY=True, otherwise load once\n", From 5b0002c202f65a38608ef7b2195b041498d60c65 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 23 Apr 2026 01:07:49 -0500 Subject: [PATCH 3/3] typo fix 2 --- examples/dsa_activity_classification.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dsa_activity_classification.ipynb 
b/examples/dsa_activity_classification.ipynb index f0fc90f42..b2f0afbb0 100644 --- a/examples/dsa_activity_classification.ipynb +++ b/examples/dsa_activity_classification.ipynb @@ -135,7 +135,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "id": "80907810", "metadata": {}, "outputs": [], @@ -221,7 +221,7 @@ " raw = batch.get('label')\n", " if raw is not None:\n", " if isinstance(raw, torch.Tensor):\n", - " batch['label'] = raw.view(-1).long().to(device)\n", + " batch['label'] = raw.view(-1, 1).float().to(device)\n", " else:\n", " batch['label'] = torch.tensor(\n", " raw, dtype=torch.long, device=device).view(-1)\n",
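
A minimal end-to-end sketch of the API this series adds, for quick
reference. Class, function, and keyword names follow the tests above; the
root path, subject split, and scale_factor are illustrative placeholders,
not values taken from the patches.

    import numpy as np

    from pyhealth.datasets import DSADataset
    from pyhealth.tasks import DSAActivityClassification
    from pyhealth.tasks.dsa import (
        compute_all_ipd_weights,
        compute_weighted_epochs,
    )

    # Index the raw UCI segment files; scale=True min-max scales each
    # channel into [-1, 1], as asserted by test_minmax_scale_range.
    dataset = DSADataset(root="./data/DSA", target_domain="LA", scale=True)

    # 19-class task: one sample per (subject, activity, segment) triple,
    # each with a (9, 125) single-domain time series and an integer label.
    samples = dataset.set_task(DSAActivityClassification())

    # Subject-wise split; 6 train / 2 test matches ExperimentConfig's
    # n_train_subjects default for the 8-subject dataset.
    train_df, test_df = dataset.get_subject_split(
        train_subjects=[1, 2, 3, 4, 5, 6], test_subjects=[7, 8]
    )

    # IPD weighting: one transferability weight per source domain against
    # the target, then per-domain fine-tuning epochs derived from them.
    rng = np.random.default_rng(0)
    domain_data = {
        d: rng.uniform(-1, 1, (10, 125)).astype(np.float32)
        for d in ["T", "RA", "LA", "RL", "LL"]
    }
    ipd_weights = compute_all_ipd_weights(
        domain_data, target_domain="LA", metric="euclidean"
    )
    per_domain_epochs = compute_weighted_epochs(ipd_weights, scale_factor=10)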