diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2e10e0e..f5308e4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,7 +33,7 @@ jobs: env: PYTEST_DISABLE_PLUGIN_AUTOLOAD: "1" run: | - pytest tests/unit/ tests/test_security.py \ + pytest tests/unit/ tests/test_security.py tests/test_voiceprint_db.py tests/test_job_service.py \ -p pytest_cov \ -v --tb=short --no-header \ --cov=app --cov-report=xml:coverage.xml \ diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml index 3071329..24b6f66 100644 --- a/.github/workflows/claude-code-review.yml +++ b/.github/workflows/claude-code-review.yml @@ -8,6 +8,7 @@ on: permissions: contents: read + id-token: write pull-requests: write issues: read @@ -26,20 +27,30 @@ jobs: env: ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} ANTHROPIC_BASE_URL: ${{ secrets.ANTHROPIC_BASE_URL }} + GH_TOKEN_VALUE: ${{ secrets.GH_TOKEN }} + CLAUDE_MODEL: claude-sonnet-4-6 steps: - - name: Skip when ANTHROPIC_API_KEY is not configured - if: ${{ env.ANTHROPIC_API_KEY == '' }} - run: echo "ANTHROPIC_API_KEY is not configured; skipping Claude Code review." + - name: Skip when Claude secrets are not configured + if: ${{ env.ANTHROPIC_API_KEY == '' || env.ANTHROPIC_BASE_URL == '' || env.GH_TOKEN_VALUE == '' }} + run: echo "Claude Code review secrets are not configured; skipping Claude Code review." + - name: Checkout repository + if: ${{ env.ANTHROPIC_API_KEY != '' && env.ANTHROPIC_BASE_URL != '' && env.GH_TOKEN_VALUE != '' }} + uses: actions/checkout@v4 + with: + fetch-depth: 1 + persist-credentials: false - name: Run Claude Code review - if: ${{ env.ANTHROPIC_API_KEY != '' }} + if: ${{ env.ANTHROPIC_API_KEY != '' && env.ANTHROPIC_BASE_URL != '' && env.GH_TOKEN_VALUE != '' }} uses: anthropics/claude-code-action@v1 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + github_token: ${{ secrets.GH_TOKEN }} prompt: | Review this pull request using REVIEW.md as the review-only guide. Focus on actionable VoScript risks: privacy/security leaks, model lifecycle races, GPU/CPU fallback behavior, HTTP API compatibility, regression-test coverage, and synchronized English/Chinese documentation. Avoid formatting-only comments. - claude_args: "--max-turns 5" + claude_args: | + --model ${{ env.CLAUDE_MODEL }} env: ANTHROPIC_BASE_URL: ${{ secrets.ANTHROPIC_BASE_URL }} diff --git a/codecov.yml b/codecov.yml index 08de640..6bd9a7e 100644 --- a/codecov.yml +++ b/codecov.yml @@ -2,10 +2,12 @@ coverage: status: project: default: - target: 70% + target: 90% + threshold: 0% patch: default: - target: 60% + target: 80% + threshold: 0% ignore: - "tests/**" diff --git a/tests/unit/test_api_route_coverage.py b/tests/unit/test_api_route_coverage.py new file mode 100644 index 0000000..8fa1439 --- /dev/null +++ b/tests/unit/test_api_route_coverage.py @@ -0,0 +1,298 @@ +"""Endpoint edge coverage for transcription and voiceprint routers.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import numpy as np + + +def _seed_result(transcriptions_dir: Path, tr_id: str, *, filename: str = "audio.wav"): + tr_dir = transcriptions_dir / tr_id + tr_dir.mkdir(parents=True, exist_ok=True) + payload = { + "id": tr_id, + "filename": filename, + "created_at": "2026-04-25T00:00:00+00:00", + "segments": [ + { + "id": 1, + "start": None, + "end": float("nan"), + "speaker_label": "SPEAKER_00", + "speaker_name": "Maple\nInjected", + "speaker_id": "spk_old", + "text": "hello", + }, + { + "id": 2, + "start": 61.25, + "end": 62.0, + "speaker_label": "SPEAKER_01", + "speaker_name": "Guest", + "speaker_id": None, + "text": "world", + }, + ], + "unique_speakers": ["Maple\nInjected", "Guest"], + "speaker_map": {}, + } + result_path = tr_dir / "result.json" + result_path.write_text(json.dumps(payload), encoding="utf-8") + return result_path + + +def test_transcription_job_status_fallback_paths(app_client): + import api.routers.transcriptions as router + + router.jobs["tr_memory_done"] = { + "status": "completed", + "filename": "done.wav", + "result": {"id": "tr_memory_done"}, + } + router.jobs["tr_memory_failed"] = { + "status": "failed", + "filename": "failed.wav", + "error": "boom", + } + + done = app_client.get("/api/jobs/tr_memory_done") + assert done.status_code == 200 + assert done.json()["result"] == {"id": "tr_memory_done"} + + failed = app_client.get("/api/jobs/tr_memory_failed") + assert failed.status_code == 200 + assert failed.json()["error"] == "boom" + + completed_dir = router.TRANSCRIPTIONS_DIR / "tr_disk_done" + completed_dir.mkdir(parents=True) + (completed_dir / "status.json").write_text( + json.dumps({"status": "completed", "filename": "disk.wav"}), + encoding="utf-8", + ) + (completed_dir / "result.json").write_text("{bad-json", encoding="utf-8") + disk_done = app_client.get("/api/jobs/tr_disk_done") + assert disk_done.status_code == 200 + assert disk_done.json()["result"] is None + + queued_dir = router.TRANSCRIPTIONS_DIR / "tr_disk_queued" + queued_dir.mkdir(parents=True) + (queued_dir / "status.json").write_text( + json.dumps({"status": "queued", "filename": "queued.wav"}), + encoding="utf-8", + ) + disk_queued = app_client.get("/api/jobs/tr_disk_queued") + assert disk_queued.status_code == 200 + assert disk_queued.json()["status"] == "failed" + + failed_dir = router.TRANSCRIPTIONS_DIR / "tr_disk_failed" + failed_dir.mkdir(parents=True) + (failed_dir / "status.json").write_text( + json.dumps({"status": "failed", "filename": "failed.wav", "error": "bad"}), + encoding="utf-8", + ) + disk_failed = app_client.get("/api/jobs/tr_disk_failed") + assert disk_failed.status_code == 200 + assert disk_failed.json()["error"] == "bad" + + +def test_transcription_list_audio_export_and_reassign_paths(app_client): + import api.routers.transcriptions as router + + tr_id = "tr_route_edges" + _seed_result(router.TRANSCRIPTIONS_DIR, tr_id, filename="route_audio.wav") + bad_dir = router.TRANSCRIPTIONS_DIR / "tr_bad_listing" + bad_dir.mkdir(parents=True) + (bad_dir / "result.json").write_text("{bad-json", encoding="utf-8") + + listing = app_client.get("/api/transcriptions") + assert listing.status_code == 200 + assert any( + row["id"] == tr_id and row["segment_count"] == 2 for row in listing.json() + ) + + missing_audio = app_client.get(f"/api/transcriptions/{tr_id}/audio") + assert missing_audio.status_code == 404 + + (router.UPLOADS_DIR / "route_audio.wav").write_bytes(b"audio") + audio = app_client.get(f"/api/transcriptions/{tr_id}/audio") + assert audio.status_code == 200 + assert audio.content == b"audio" + + srt = app_client.get(f"/api/export/{tr_id}?format=srt") + assert srt.status_code == 200 + assert "00:00:00,000 --> 00:00:00,000" in srt.text + assert "[Maple Injected] hello" in srt.text + + txt = app_client.get(f"/api/export/{tr_id}?format=txt") + assert txt.status_code == 200 + assert "[01:01] Guest: world" in txt.text + + exported_json = app_client.get(f"/api/export/{tr_id}?format=json") + assert exported_json.status_code == 200 + assert exported_json.json()["id"] == tr_id + + unsupported = app_client.get(f"/api/export/{tr_id}?format=vtt") + assert unsupported.status_code == 400 + + invalid_id = app_client.put( + f"/api/transcriptions/{tr_id}/segments/1/speaker", + data={"speaker_name": "Maple", "speaker_id": "not-safe"}, + ) + assert invalid_id.status_code == 422 + + class FakeDB: + def __init__(self, found): + self.found = found + + def get_speaker(self, speaker_id): + return {"id": speaker_id} if self.found else None + + app_client.app.state.db = FakeDB(found=False) + missing_voiceprint = app_client.put( + f"/api/transcriptions/{tr_id}/segments/1/speaker", + data={"speaker_name": "Maple", "speaker_id": "spk_missing"}, + ) + assert missing_voiceprint.status_code == 404 + + app_client.app.state.db = FakeDB(found=True) + updated = app_client.put( + f"/api/transcriptions/{tr_id}/segments/1/speaker", + data={"speaker_name": "Maple", "speaker_id": "spk_known"}, + ) + assert updated.status_code == 200 + + cleared = app_client.put( + f"/api/transcriptions/{tr_id}/segments/2/speaker", + data={"speaker_name": "Maple"}, + ) + assert cleared.status_code == 200 + + result = json.loads((router.TRANSCRIPTIONS_DIR / tr_id / "result.json").read_text()) + assert result["segments"][0]["speaker_id"] == "spk_known" + assert result["segments"][1]["speaker_id"] is None + assert result["unique_speakers"] == ["Maple"] + + missing_segment = app_client.put( + f"/api/transcriptions/{tr_id}/segments/99/speaker", + data={"speaker_name": "Nobody"}, + ) + assert missing_segment.status_code == 404 + + +def test_voiceprint_management_routes(app_client): + import api.routers.voiceprints as router + + class FakeDB: + def __init__(self): + self.speakers = {} + self.updated = [] + self.deleted = [] + self.renamed = [] + self.cohort_path = None + self.last_cohort_skipped = 3 + + def add_speaker(self, name, embedding): + assert embedding.shape == (3,) + self.speakers["spk_new"] = {"id": "spk_new", "name": name} + return "spk_new" + + def update_speaker(self, speaker_id, embedding, name=None): + self.updated.append((speaker_id, name, tuple(embedding.tolist()))) + + def list_speakers(self): + return list(self.speakers.values()) + + def get_speaker(self, speaker_id): + return self.speakers.get(speaker_id) + + def delete_speaker(self, speaker_id): + if speaker_id not in self.speakers: + raise ValueError("missing speaker") + self.deleted.append(speaker_id) + self.speakers.pop(speaker_id) + + def rename_speaker(self, speaker_id, name): + if speaker_id not in self.speakers: + raise ValueError("missing speaker") + self.renamed.append((speaker_id, name)) + self.speakers[speaker_id]["name"] = name + + def build_cohort_from_transcriptions(self, transcriptions_dir): + assert transcriptions_dir == str(router.TRANSCRIPTIONS_DIR) + return 7 + + fake_db = FakeDB() + app_client.app.state.db = fake_db + + missing = app_client.post( + "/api/voiceprints/enroll", + data={ + "tr_id": "tr_voiceprint", + "speaker_label": "SPEAKER_00", + "speaker_name": "Maple", + }, + ) + assert missing.status_code == 404 + + tr_dir = router.safe_tr_dir("tr_voiceprint") + tr_dir.mkdir(parents=True, exist_ok=True) + np.save(tr_dir / "emb_SPEAKER_00.npy", np.array([1.0, 2.0, 3.0], dtype=np.float32)) + + created = app_client.post( + "/api/voiceprints/enroll", + data={ + "tr_id": "tr_voiceprint", + "speaker_label": "SPEAKER_00", + "speaker_name": "Maple", + }, + ) + assert created.status_code == 200 + assert created.json() == {"action": "created", "speaker_id": "spk_new"} + + updated = app_client.post( + "/api/voiceprints/enroll", + data={ + "tr_id": "tr_voiceprint", + "speaker_label": "SPEAKER_00", + "speaker_name": "Maple Updated", + "speaker_id": "spk_new", + }, + ) + assert updated.status_code == 200 + assert updated.json() == {"action": "updated", "speaker_id": "spk_new"} + assert fake_db.updated == [("spk_new", "Maple Updated", (1.0, 2.0, 3.0))] + + listing = app_client.get("/api/voiceprints") + assert listing.status_code == 200 + assert listing.json() == [{"id": "spk_new", "name": "Maple"}] + + found = app_client.get("/api/voiceprints/spk_new") + assert found.status_code == 200 + assert found.json()["name"] == "Maple" + + missing_get = app_client.get("/api/voiceprints/spk_missing") + assert missing_get.status_code == 404 + + renamed = app_client.put("/api/voiceprints/spk_new/name", data={"name": "Renamed"}) + assert renamed.status_code == 200 + assert fake_db.renamed == [("spk_new", "Renamed")] + + missing_rename = app_client.put( + "/api/voiceprints/spk_missing/name", data={"name": "Missing"} + ) + assert missing_rename.status_code == 404 + + cohort = app_client.post("/api/voiceprints/rebuild-cohort") + assert cohort.status_code == 200 + assert cohort.json()["cohort_size"] == 7 + assert cohort.json()["skipped"] == 3 + assert cohort.json()["saved_to"].endswith("asnorm_cohort.npy") + + deleted = app_client.delete("/api/voiceprints/spk_new") + assert deleted.status_code == 200 + assert fake_db.deleted == ["spk_new"] + + missing_delete = app_client.delete("/api/voiceprints/spk_missing") + assert missing_delete.status_code == 404 diff --git a/tests/unit/test_audio_layers.py b/tests/unit/test_audio_layers.py index 234fd73..9bb2fe8 100644 --- a/tests/unit/test_audio_layers.py +++ b/tests/unit/test_audio_layers.py @@ -3,21 +3,27 @@ from __future__ import annotations import json +import subprocess import sys from contextlib import contextmanager from inspect import signature from pathlib import Path from types import ModuleType +import numpy as np +import pytest import infra.audio.hash_index as hash_index_module import infra.audio as audio_infra import providers import providers.enhance.default as enhance_default +import providers.normalize.default as normalize_default +import providers.voiceprint_match.default as voiceprint_match_default from infra.audio import JsonAudioArtifactIndex from pipeline.contracts import ( AudioEnhancementRequest, AudioNormalizationRequest, UploadPersistenceRequest, + VoiceprintMatchRequest, ) from api.routers.transcriptions import transcribe @@ -103,6 +109,116 @@ def test_denoise_api_snr_threshold_overrides_env_default(monkeypatch, tmp_path): assert result.output_path == wav_path +def test_unknown_denoise_model_is_a_noop(tmp_path, caplog): + wav_path = tmp_path / "sample.wav" + wav_path.write_bytes(b"stub") + + with caplog.at_level("WARNING", logger=enhance_default.logger.name): + result = enhance_default.ConditionalDenoiseEnhancer().enhance( + AudioEnhancementRequest(wav_path=wav_path, model="unsupported") + ) + + assert result.applied is False + assert result.output_path == wav_path + assert result.model == "unsupported" + assert "Unknown DENOISE_MODEL='unsupported'" in caplog.text + + +def test_estimate_snr_uses_energy_heuristic(monkeypatch, tmp_path): + class FakeTensor: + def __init__(self, values): + self.values = np.asarray(values, dtype=np.float32) + + @property + def shape(self): + return self.values.shape + + def __len__(self): + return len(self.values) + + def __getitem__(self, item): + return FakeTensor(self.values[item]) + + def mean(self, dim=None, keepdim=False): + return FakeTensor(np.mean(self.values, axis=dim, keepdims=keepdim)) + + def squeeze(self, dim): + return FakeTensor(np.squeeze(self.values, axis=dim)) + + def reshape(self, *shape): + return FakeTensor(self.values.reshape(*shape)) + + def pow(self, power): + return FakeTensor(np.power(self.values, power)) + + def sqrt(self): + return FakeTensor(np.sqrt(self.values)) + + def sort(self): + return FakeTensor(np.sort(self.values)), None + + def item(self): + return float(np.asarray(self.values).item()) + + quiet = np.full(60, 0.1, dtype=np.float32) + speech = np.ones(240, dtype=np.float32) + mono = np.concatenate([quiet, speech]) + stereo = np.stack([mono, mono]) + torchaudio_module = ModuleType("torchaudio") + torchaudio_module.load = lambda path: (FakeTensor(stereo), 1000) + monkeypatch.setitem(sys.modules, "torchaudio", torchaudio_module) + + assert enhance_default._estimate_snr(tmp_path / "sample.wav") == pytest.approx(20.0) + + +def test_estimate_snr_returns_inf_for_too_short_or_silent_noise(monkeypatch, tmp_path): + class FakeTensor: + def __init__(self, values): + self.values = np.asarray(values, dtype=np.float32) + + @property + def shape(self): + return self.values.shape + + def __len__(self): + return len(self.values) + + def __getitem__(self, item): + return FakeTensor(self.values[item]) + + def squeeze(self, dim): + return FakeTensor(np.squeeze(self.values, axis=dim)) + + def reshape(self, *shape): + return FakeTensor(self.values.reshape(*shape)) + + def pow(self, power): + return FakeTensor(np.power(self.values, power)) + + def mean(self, dim=None, keepdim=False): + return FakeTensor(np.mean(self.values, axis=dim, keepdims=keepdim)) + + def sqrt(self): + return FakeTensor(np.sqrt(self.values)) + + def sort(self): + return FakeTensor(np.sort(self.values)), None + + def item(self): + return float(np.asarray(self.values).item()) + + torchaudio_module = ModuleType("torchaudio") + samples = [ + FakeTensor(np.zeros((1, 60))), + FakeTensor(np.concatenate([np.zeros(60), np.ones(240)]).reshape(1, 300)), + ] + torchaudio_module.load = lambda path: (samples.pop(0), 1000) + monkeypatch.setitem(sys.modules, "torchaudio", torchaudio_module) + + assert enhance_default._estimate_snr(tmp_path / "short.wav") == float("inf") + assert enhance_default._estimate_snr(tmp_path / "silent-noise.wav") == float("inf") + + def test_deepfilternet_lazy_load_logs_elapsed_time(monkeypatch, caplog): monkeypatch.setattr(enhance_default, "_df_model", None) monkeypatch.setattr(enhance_default, "_df_state", None) @@ -266,3 +382,121 @@ def test_hash_index_infra_requires_completed_result(monkeypatch, tmp_path): store.register("hash-b", "tr_ready") assert store.lookup("hash-b") == "tr_ready" + + +def test_ffmpeg_normalizer_reuses_existing_target_format(tmp_path): + wav_path = tmp_path / "already.wav" + wav_path.write_bytes(b"wav") + + result = normalize_default.FFmpegInputNormalizer().normalize( + AudioNormalizationRequest(input_path=wav_path) + ) + + assert result.reused_source is True + assert result.source_path == wav_path + assert result.normalized_path == wav_path + + +def test_ffmpeg_normalizer_invokes_ffmpeg_for_non_wav(monkeypatch, tmp_path): + source = tmp_path / "meeting.ogg" + source.write_bytes(b"ogg") + calls = [] + + def fake_run(args, *, check, timeout): + calls.append((args, check, timeout)) + + monkeypatch.setattr(normalize_default.subprocess, "run", fake_run) + + result = normalize_default.FFmpegInputNormalizer().normalize( + AudioNormalizationRequest(input_path=source) + ) + + args, check, timeout = calls[0] + assert check is True + assert timeout == normalize_default.FFMPEG_TIMEOUT_SEC + assert args[:6] == ["ffmpeg", "-y", "-v", "error", "-i", str(source)] + assert args[-2:] == ["--", str(tmp_path / "meeting.wav")] + assert result.reused_source is False + assert result.normalized_path == tmp_path / "meeting.wav" + + +def test_ffmpeg_normalizer_timeout_cleans_partial(monkeypatch, tmp_path): + source = tmp_path / "meeting.mp3" + source.write_bytes(b"mp3") + partial = tmp_path / "meeting.wav" + partial.write_bytes(b"partial") + + def fake_run(*args, **kwargs): + raise subprocess.TimeoutExpired(cmd="ffmpeg", timeout=1) + + monkeypatch.setattr(normalize_default.subprocess, "run", fake_run) + + with pytest.raises(Exception) as excinfo: + normalize_default.FFmpegInputNormalizer().normalize( + AudioNormalizationRequest(input_path=source) + ) + + assert getattr(excinfo.value, "status_code", None) == 504 + assert not partial.exists() + + +def test_voiceprint_match_provider_reports_no_embeddings(): + result = voiceprint_match_default.DefaultVoiceprintMatchProvider().match( + VoiceprintMatchRequest( + speaker_embeddings={}, + voiceprint_db=object(), + threshold=0.72, + ) + ) + + assert result.applied is False + assert result.speaker_map == {} + assert result.threshold == 0.72 + assert result.reason == "no_embeddings" + + +def test_voiceprint_match_provider_reports_missing_db(): + result = voiceprint_match_default.DefaultVoiceprintMatchProvider().match( + VoiceprintMatchRequest( + speaker_embeddings={"SPEAKER_00": [1.0, 0.0]}, + voiceprint_db=None, + threshold=None, + ) + ) + + assert result.applied is False + assert result.speaker_map == {} + assert result.reason == "voiceprint_db_unavailable" + + +def test_voiceprint_match_provider_uses_identify_threshold_when_supplied(): + class FakeDB: + def __init__(self): + self.calls = [] + + def identify(self, embedding, threshold=None): + self.calls.append((embedding, threshold)) + return "spk_1", "Maple", 0.87654 + + fake_db = FakeDB() + + result = voiceprint_match_default.DefaultVoiceprintMatchProvider().match( + VoiceprintMatchRequest( + speaker_embeddings={"SPEAKER_00": [0.1, 0.9]}, + voiceprint_db=fake_db, + threshold=0.75, + ) + ) + + assert fake_db.calls == [([0.1, 0.9], 0.75)] + assert result.applied is True + assert result.reason == "matched" + assert result.threshold == 0.75 + assert result.speaker_map == { + "SPEAKER_00": { + "matched_id": "spk_1", + "matched_name": "Maple", + "similarity": 0.8765, + "embedding_key": "SPEAKER_00", + } + } diff --git a/tests/unit/test_voiceprint_db.py b/tests/unit/test_voiceprint_db.py index fce3633..a65b3eb 100644 --- a/tests/unit/test_voiceprint_db.py +++ b/tests/unit/test_voiceprint_db.py @@ -16,10 +16,12 @@ import threading import time import wave +import base64 from concurrent.futures import ThreadPoolExecutor from pathlib import Path import numpy as np +import pytest _APP_DIR = Path(__file__).resolve().parents[2] / "app" @@ -223,6 +225,204 @@ def test_manual_rebuild_can_replace_existing_cohort_with_available_sources(tmp_p assert np.load(cohort_path, allow_pickle=False).shape == (1, 256) +def test_legacy_npy_voiceprint_store_migrates_to_sqlite(tmp_path): + """A pre-SQLite voiceprint store must migrate avg and sample embeddings once.""" + db_dir = tmp_path / "voiceprints" + db_dir.mkdir(parents=True) + speaker_id = "spk_legacy" + avg = _unit_vec(7000) + samples = np.stack([avg, _unit_vec(7001)]).astype(np.float32) + + (db_dir / "index.json").write_text( + json.dumps( + { + "speakers": { + speaker_id: { + "name": "Legacy", + "sample_count": 2, + "created_at": "2026-04-01T00:00:00", + "updated_at": "2026-04-02T00:00:00", + } + } + } + ), + encoding="utf-8", + ) + np.save(db_dir / f"{speaker_id}_avg.npy", avg) + np.save(db_dir / f"{speaker_id}_samples.npy", samples) + + db, _mod = _fresh_db(db_dir) + + assert db.list_speakers() == [ + { + "id": speaker_id, + "name": "Legacy", + "sample_count": 2, + "sample_spread": None, + "created_at": "2026-04-01T00:00:00", + "updated_at": "2026-04-02T00:00:00", + } + ] + migrated_samples = db._conn.execute( + "SELECT COUNT(*) FROM speaker_samples WHERE speaker_id = ?", + (speaker_id,), + ).fetchone()[0] + assert migrated_samples == 2 + assert (db_dir / "index.json.migrated.bak").exists() + + +def test_legacy_voiceprint_migration_skips_unreadable_index(tmp_path): + db_dir = tmp_path / "voiceprints" + db_dir.mkdir(parents=True) + (db_dir / "index.json").write_text("{not-json", encoding="utf-8") + + db, _mod = _fresh_db(db_dir) + + assert db.list_speakers() == [] + assert (db_dir / "index.json").exists() + + +def test_legacy_voiceprint_migration_ignores_missing_avg_and_existing_db(tmp_path): + db_dir = tmp_path / "voiceprints" + db_dir.mkdir(parents=True) + (db_dir / "index.json").write_text( + json.dumps({"speakers": {"spk_missing": {"name": "Missing"}}}), + encoding="utf-8", + ) + + db, _mod = _fresh_db(db_dir) + assert db.list_speakers() == [] + assert (db_dir / "index.json.migrated.bak").exists() + + sid = db.add_speaker("Existing", _unit_vec(7100)) + (db_dir / "index.json").write_text( + json.dumps({"speakers": {"spk_other": {"name": "Other"}}}), + encoding="utf-8", + ) + + db._storage._maybe_migrate_legacy() + + assert [speaker["id"] for speaker in db.list_speakers()] == [sid] + assert (db_dir / "index.json").exists() + + +def test_repository_crud_and_private_scan_edges(tmp_path): + db, _mod = _fresh_db(tmp_path / "voiceprints") + sid = db.add_speaker("Maple", _unit_vec(7200)) + + db.rename_speaker(sid, "Maple Renamed") + assert db.get_speaker(sid)["name"] == "Maple Renamed" + assert db.get_speaker("spk_missing") is None + + repo = db._repository + assert repo._find_best_match(_unit_vec(7200))[0] == sid + assert repo._python_cosine_scan(np.zeros(256, dtype=np.float32)) == [] + + with pytest.raises(ValueError, match="No samples"): + repo._recompute_avg_and_spread("spk_no_samples") + + db.delete_speaker(sid) + assert db.get_speaker(sid) is None + with pytest.raises(ValueError, match="not found"): + db.delete_speaker(sid) + + +def test_recompute_spread_handles_zero_average(tmp_path): + db, _mod = _fresh_db(tmp_path / "voiceprints") + zero = np.zeros(256, dtype=np.float32) + sid = db.add_speaker("Zero", zero) + + db.update_speaker(sid, zero) + + row = db.get_speaker(sid) + assert row["sample_count"] == 2 + assert row["sample_spread"] is None + + +def test_cohort_helpers_handle_paths_invalid_files_and_collectors(tmp_path): + _fresh_voiceprint_module() + cohort_mod = importlib.import_module("voiceprints.cohort") + + class DummyDB: + _asnorm = None + _cohort_generation = 3 + _cohort_built_gen = 0 + _lock = threading.RLock() + + manager = cohort_mod.VoiceprintCohortManager( + DummyDB(), + cohort_path=tmp_path / "configured.npy", + embedding_dim=3, + ) + + assert manager.cohort_path == tmp_path / "configured.npy" + assert manager.cohort_size == 0 + assert manager.resolve_path(save_path=tmp_path / "explicit.npy") == ( + tmp_path / "explicit.npy" + ) + assert manager.resolve_path(transcriptions_dir=tmp_path) == ( + tmp_path / "configured.npy" + ) + no_default = cohort_mod.VoiceprintCohortManager(DummyDB(), None, embedding_dim=3) + assert no_default.resolve_path(transcriptions_dir=tmp_path) == ( + tmp_path / "asnorm_cohort.npy" + ) + assert no_default.resolve_path() is None + + invalid_ndim = tmp_path / "invalid.npy" + np.save(invalid_ndim, np.array([1.0, 2.0, 3.0], dtype=np.float32)) + with pytest.raises(ValueError, match="Cohort must be 2D"): + manager.load(str(invalid_ndim)) + assert manager._persisted_cohort_size(invalid_ndim) == 0 + + corrupt = tmp_path / "corrupt.npy" + corrupt.write_bytes(b"not-numpy") + assert manager._persisted_cohort_size(corrupt) == 0 + assert manager._persisted_cohort_size(None) == 0 + + assert manager._should_keep_existing_cohort(source_size=1, current_size=0) is False + assert manager._should_keep_existing_cohort(source_size=1, current_size=2) is True + assert ( + manager._should_keep_existing_cohort( + source_size=1, + current_size=cohort_mod.ASNORM_MIN_COHORT_SIZE, + ) + is True + ) + + collected = [] + encoded = base64.b64encode(np.array([1, 2, 3], dtype=np.float32).tobytes()).decode() + added = manager._collect_json_embeddings( + payload={ + "speaker_embeddings": { + "list": [1, 2, 3], + "encoded": encoded, + "wrong_shape": [1, 2], + "ignored": {"bad": True}, + } + }, + expected_shape=(3,), + collected=collected, + ) + assert added == 2 + assert len(collected) == 2 + + result_path = tmp_path / "tr_collect" / "result.json" + result_path.parent.mkdir() + np.save(result_path.parent / "emb_good.npy", np.array([4, 5, 6], dtype=np.float32)) + np.save(result_path.parent / "emb_wrong.npy", np.array([1, 2], dtype=np.float32)) + (result_path.parent / "emb_bad.npy").write_bytes(b"bad") + + skipped = manager._collect_npy_embeddings( + result_path=result_path, + expected_shape=(3,), + collected=collected, + ) + + assert skipped == 1 + assert len(collected) == 3 + + def test_lifespan_loads_saved_cohort_without_rebuild(tmp_path, monkeypatch): """Startup must load an existing cohort file instead of rebuilding it again.""" transcriptions_dir = tmp_path / "transcriptions"