From cd3c7964d1b56e15b6e718e4896556267bafc8a8 Mon Sep 17 00:00:00 2001 From: Galen Lynch Date: Thu, 23 Apr 2026 13:50:41 -0700 Subject: [PATCH] perf: n_workers kwarg for FilterRecording + CommonReferenceRecording MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds opt-in intra-chunk thread-parallelism to two preprocessors: channel-split sosfilt/sosfiltfilt in FilterRecording, time-split median/mean in CommonReferenceRecording. Default n_workers=1 preserves existing behavior. Per-caller-thread inner pools ----------------------------- Each outer thread that calls ``get_traces()`` on a parallel-enabled segment gets its own inner ThreadPoolExecutor, stored in a ``WeakKeyDictionary`` keyed by the calling ``Thread`` object. Rationale: * Avoids the shared-pool queueing pathology that would occur if N outer workers (e.g., TimeSeriesChunkExecutor with n_jobs=N) all submitted into a single shared pool with fewer max_workers than outer callers. Under a shared pool, ``n_workers=2`` with ``n_jobs=24`` thrashed at 3.36 s on the test pipeline; per-caller pools: 1.47 s. * Keying by the Thread object (not thread-id integer) avoids the thread-id-reuse hazard: thread IDs can be reused after a thread dies, which would cause a new thread to silently inherit a dead thread's pool. * WeakKeyDictionary + weakref.finalize ensures automatic shutdown of the inner pool when the calling thread is garbage-collected. The finalizer calls ``pool.shutdown(wait=False)`` to avoid blocking the finalizer thread; in-flight tasks would be cancelled, but the owning thread submits+joins synchronously, so none exist when it exits. When useful ----------- * Direct ``get_traces()`` callers (interactive viewers, streaming consumers, mipmap-zarr tile builders) that don't use ``TimeSeriesChunkExecutor``. * Default SI users who haven't tuned job_kwargs. 
* RAM-constrained deployments that can't crank ``n_jobs`` to core count: on a 24-core host, ``n_jobs=6, n_workers=2`` gets within 12% of ``n_jobs=24, n_workers=1`` at ~1/4 the RAM. Performance (1M × 384 float32 BP+CMR pipeline, 24-core host, thread engine) --------------------------------------------------------------------------- === Component-level (scipy/numpy only) === sosfiltfilt serial → 8 threads: 7.80 s → 2.67 s (2.92x) np.median serial → 16 threads: 3.51 s → 0.33 s (10.58x) === Per-stage end-to-end (rec.get_traces) === Bandpass (5th-order, 300-6k Hz): 8.59 s → 3.20 s (2.69x) CMR median (global): 4.01 s → 0.81 s (4.95x) === CRE outer × inner Pareto, per-caller pools === outer=24, inner=1 each: 1.54 s (100% of peak) outer=24, inner=8 each: 1.42 s (108% of peak; oversubscribed) outer=12, inner=1 each: 1.59 s (97%, ~1/2 RAM of outer=24) outer=6, inner=2 each: 1.75 s (88%, ~1/4 RAM of outer=24) outer=4, inner=6 each: 1.83 s (84%, ~1/6 RAM with 24 threads) Tests ----- New ``test_parallel_pool_semantics.py`` verifies the per-caller-thread contract: single caller reuses one pool; concurrent callers get distinct pools. Existing bandpass + CMR tests still pass. Independent of the companion FIR phase-shift PR (perf/phase-shift-fir); the two can land in either order. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- benchmarks/preprocessing/bench_perf.py | 303 ++++++++++++++++++ .../preprocessing/common_reference.py | 63 +++- src/spikeinterface/preprocessing/filter.py | 78 ++++- .../tests/test_common_reference.py | 34 ++ .../preprocessing/tests/test_filter.py | 34 ++ .../tests/test_parallel_pool_semantics.py | 103 ++++++ 6 files changed, 612 insertions(+), 3 deletions(-) create mode 100644 benchmarks/preprocessing/bench_perf.py create mode 100644 src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py diff --git a/benchmarks/preprocessing/bench_perf.py b/benchmarks/preprocessing/bench_perf.py new file mode 100644 index 0000000000..b0016a9f73 --- /dev/null +++ b/benchmarks/preprocessing/bench_perf.py @@ -0,0 +1,303 @@ +"""Benchmark script for the parallel bandpass + CMR speedups. + +Runs head-to-head comparisons on synthetic NumpyRecording fixtures so the +numbers are reproducible without external ephys data: + +1. Component-level (hot operation only, no SI plumbing): + - scipy.signal.sosfiltfilt serial vs channel-parallel threads + - np.median(axis=1) serial vs time-parallel threads +2. Per-stage end-to-end (``rec.get_traces()`` path): + - BandpassFilterRecording stock vs n_workers=8 + - CommonReferenceRecording stock vs n_workers=16 +3. CRE (``TimeSeriesChunkExecutor``) × inner (n_workers) interaction at + matched chunk_duration="1s". + +FilterRecordingSegment and CommonReferenceRecordingSegment use +**per-caller-thread inner pools** (WeakKeyDictionary keyed by the calling +Thread object). Each outer thread that calls get_traces() gets its own +inner ThreadPoolExecutor, so n_workers composes cleanly with CRE's outer +parallelism — no shared-pool queueing pathology. See +``tests/test_parallel_pool_semantics.py`` for the contract. 
+ +Measured on a 24-core x86_64 host with 1M x 384 float32 chunks (SI 0.103 +dev, numpy 2.1, scipy 1.14, full get_traces() path end-to-end): + + === Component-level (hot kernel only, no SI plumbing) === + sosfiltfilt serial → 8 threads: 7.80 s → 2.67 s (2.92x) + np.median serial → 16 threads: 3.51 s → 0.33 s (10.58x) + + === Per-stage end-to-end (rec.get_traces) === + Bandpass (5th-order, 300-6k Hz): 8.59 s → 3.20 s (2.69x) + CMR median (global): 4.01 s → 0.81 s (4.95x) + + === CRE outer × inner (chunk=1s, per-caller pools) === + Bandpass: stock n=1 → stock n=8 thread: 7.42 s → 1.40 s (5.3x outer) + n_workers=8 n=1: 3.18 s (2.3x inner) + n_workers=8 n=8 thread: 1.24 s (combined) + CMR: stock n=1 → stock n=8 thread: 3.98 s → 0.61 s (6.5x outer) + n_workers=16 n=1: 1.58 s (2.5x inner) + n_workers=16 n=8 thread: 0.36 s (11.0x combined) + +Bandpass and CMR scale sub-linearly with thread count due to memory +bandwidth saturation; 2.7x / 5x per stage on 8 / 16 threads respectively +is consistent with the DRAM ceiling at these chunk sizes, not a +parallelism bug. Under CRE, the outer-vs-inner combination depends on +whether the inner pool has headroom over n_jobs — per-caller pools make +this deterministic regardless. + +Run with ``python -m benchmarks.preprocessing.bench_perf`` from repo root. 
+""" + +from __future__ import annotations + +import time + +import numpy as np +import scipy.signal + +from spikeinterface import NumpyRecording +from spikeinterface.preprocessing import ( + BandpassFilterRecording, + CommonReferenceRecording, +) + + +def _make_recording(T: int = 1_048_576, C: int = 384, fs: float = 30_000.0, dtype=np.float32): + """Synthetic NumpyRecording matching typical Neuropixels shard shape.""" + rng = np.random.default_rng(0) + traces = rng.standard_normal((T, C)).astype(dtype) * 100.0 + rec = NumpyRecording([traces], sampling_frequency=fs) + return rec + + +def _time_get_traces(rec, *, n_reps=3, warmup=1): + """Median-of-N timing of rec.get_traces() for the full single segment.""" + for _ in range(warmup): + rec.get_traces() + times = [] + for _ in range(n_reps): + t0 = time.perf_counter() + rec.get_traces() + times.append(time.perf_counter() - t0) + return float(np.median(times)) + + +def _time_callable(fn, *, n_reps=3, warmup=1): + """Best-of-N timing for a bare callable. Used for component-level benches + where we want to isolate the hot operation from surrounding glue.""" + for _ in range(warmup): + fn() + times = [] + for _ in range(n_reps): + t0 = time.perf_counter() + fn() + times.append(time.perf_counter() - t0) + return float(min(times)) + + +def _time_cre(executor, *, n_reps=2, warmup=1): + """Min-of-N timing for a TimeSeriesChunkExecutor invocation.""" + for _ in range(warmup): + executor.run() + times = [] + for _ in range(n_reps): + t0 = time.perf_counter() + executor.run() + times.append(time.perf_counter() - t0) + return float(min(times)) + + +def _cre_init(recording): + return {"recording": recording} + + +def _cre_func(segment_index, start_frame, end_frame, worker_dict): + worker_dict["recording"].get_traces( + start_frame=start_frame, end_frame=end_frame, segment_index=segment_index + ) + + +def bench_sosfiltfilt_component(): + """Component-level bench: just scipy.signal.sosfiltfilt vs channel-parallel. 
+ + Isolates the hot SOS operation from the full BandpassFilter.get_traces + path so you can see the kernel-only speedup (no margin fetch, no dtype + cast, no slice). + """ + from concurrent.futures import ThreadPoolExecutor + + print("--- [component] sosfiltfilt (1M x 384 float32) ---") + T, C = 1_048_576, 384 + rng = np.random.default_rng(0) + x = rng.standard_normal((T, C)).astype(np.float32) * 100.0 + sos = scipy.signal.butter(5, [300.0, 6000.0], btype="bandpass", fs=30_000.0, output="sos") + + pool = ThreadPoolExecutor(max_workers=8) + + def parallel_call(): + block = (C + 8 - 1) // 8 + bounds = [(c0, min(c0 + block, C)) for c0 in range(0, C, block)] + + def _work(c0, c1): + return c0, c1, scipy.signal.sosfiltfilt(sos, x[:, c0:c1], axis=0) + + results = [fut.result() for fut in [pool.submit(_work, c0, c1) for c0, c1 in bounds]] + out = np.empty((T, C), dtype=results[0][2].dtype) + for c0, c1, block_out in results: + out[:, c0:c1] = block_out + return out + + t_stock = _time_callable(lambda: scipy.signal.sosfiltfilt(sos, x, axis=0)) + t_par = _time_callable(parallel_call) + pool.shutdown() + print(f" scipy.sosfiltfilt serial: {t_stock:6.2f} s") + print(f" scipy.sosfiltfilt 8 threads: {t_par:6.2f} s ({t_stock / t_par:4.2f}x)") + print() + + +def bench_median_component(): + """Component-level bench: just np.median(axis=1) vs threaded across time blocks.""" + from concurrent.futures import ThreadPoolExecutor + + print("--- [component] np.median axis=1 (1M x 384 float32) ---") + T, C = 1_048_576, 384 + rng = np.random.default_rng(0) + x = rng.standard_normal((T, C)).astype(np.float32) * 100.0 + + pool = ThreadPoolExecutor(max_workers=16) + + def parallel_call(): + block = (T + 16 - 1) // 16 + bounds = [(t0, min(t0 + block, T)) for t0 in range(0, T, block)] + + def _work(t0, t1): + return t0, t1, np.median(x[t0:t1, :], axis=1) + + results = [fut.result() for fut in [pool.submit(_work, t0, t1) for t0, t1 in bounds]] + out = np.empty(T, dtype=results[0][2].dtype) + 
for t0, t1, block_out in results: + out[t0:t1] = block_out + return out + + t_stock = _time_callable(lambda: np.median(x, axis=1)) + t_par = _time_callable(parallel_call) + pool.shutdown() + print(f" np.median serial: {t_stock:6.2f} s") + print(f" np.median 16 threads: {t_par:6.2f} s ({t_stock / t_par:4.2f}x)") + print() + + +def bench_bandpass(): + """End-to-end bench: BandpassFilterRecording stock vs n_workers=8.""" + print("=== Bandpass (5th-order Butterworth 300-6000 Hz, 1M x 384 float32) ===") + rec = _make_recording(dtype=np.float32) + stock = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, margin_ms=40.0) + fast = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, margin_ms=40.0, n_workers=8) + + t_stock = _time_get_traces(stock) + t_fast = _time_get_traces(fast) + print(f" stock (n_workers=1): {t_stock:6.2f} s") + print(f" parallel (n_workers=8): {t_fast:6.2f} s ({t_stock / t_fast:4.2f}x)") + # Equivalence check + ref = stock.get_traces(start_frame=1000, end_frame=10_000) + out = fast.get_traces(start_frame=1000, end_frame=10_000) + assert np.allclose(out, ref, rtol=1e-5, atol=1e-4), "parallel bandpass output mismatch" + print(" output matches stock within float32 tolerance") + print() + + +def bench_cmr(): + """End-to-end bench: CommonReferenceRecording stock vs n_workers=16.""" + print("=== CMR median (global, 1M x 384 float32) ===") + rec = _make_recording(dtype=np.float32) + stock = CommonReferenceRecording(rec, operator="median", reference="global") + fast = CommonReferenceRecording(rec, operator="median", reference="global", n_workers=16) + + t_stock = _time_get_traces(stock) + t_fast = _time_get_traces(fast) + print(f" stock (n_workers=1): {t_stock:6.2f} s") + print(f" parallel (n_workers=16): {t_fast:6.2f} s ({t_stock / t_fast:4.2f}x)") + ref = stock.get_traces(start_frame=1000, end_frame=10_000) + out = fast.get_traces(start_frame=1000, end_frame=10_000) + np.testing.assert_array_equal(out, ref) + print(" output is 
bitwise-identical to stock") + print() + + +def bench_bandpass_cre_interaction(): + """Bandpass: outer (TimeSeriesChunkExecutor) × inner (n_workers) parallelism. + + At SI's default ``chunk_duration="1s"``, the intra-chunk ``n_workers`` + kwarg is only useful when outer CRE workers don't already saturate cores. + When combined, the result depends on whether inner-pool ``max_workers`` + exceeds outer ``n_jobs``. + """ + from spikeinterface.core.job_tools import TimeSeriesChunkExecutor + + print("=== Bandpass: outer (CRE) × inner (n_workers), 1M × 384 float32, chunk=1s ===") + rec = _make_recording(dtype=np.float32) + + def make_cre(bp_rec, n_jobs): + return TimeSeriesChunkExecutor( + time_series=bp_rec, func=_cre_func, init_func=_cre_init, init_args=(bp_rec,), + pool_engine="thread", n_jobs=n_jobs, chunk_duration="1s", progress_bar=False, + ) + + t_stock_n1 = _time_cre(make_cre(BandpassFilterRecording(rec), n_jobs=1)) + t_stock_n8 = _time_cre(make_cre(BandpassFilterRecording(rec), n_jobs=8)) + t_fast_n1 = _time_cre(make_cre(BandpassFilterRecording(rec, n_workers=8), n_jobs=1)) + t_fast_n8 = _time_cre(make_cre(BandpassFilterRecording(rec, n_workers=8), n_jobs=8)) + + print(f" {'config':<40} {'time':>8} {'vs baseline':>12}") + print(f" {'stock, CRE n=1 (baseline)':<40} {t_stock_n1:6.2f} s {'1.00×':>12}") + print(f" {'stock, CRE n=8 thread':<40} {t_stock_n8:6.2f} s {t_stock_n1/t_stock_n8:5.2f}× (outer only)") + print(f" {'n_workers=8, CRE n=1':<40} {t_fast_n1:6.2f} s {t_stock_n1/t_fast_n1:5.2f}× (inner only)") + print(f" {'n_workers=8, CRE n=8 thread':<40} {t_fast_n8:6.2f} s {t_stock_n1/t_fast_n8:5.2f}× (both)") + print() + + +def bench_cmr_cre_interaction(): + """CMR: outer (TimeSeriesChunkExecutor) × inner (n_workers) parallelism.""" + from spikeinterface.core.job_tools import TimeSeriesChunkExecutor + + print("=== CMR: outer (CRE) × inner (n_workers), 1M × 384 float32, chunk=1s ===") + rec = _make_recording(dtype=np.float32) + + def make_cre(cmr_rec, n_jobs): + 
return TimeSeriesChunkExecutor( + time_series=cmr_rec, func=_cre_func, init_func=_cre_init, init_args=(cmr_rec,), + pool_engine="thread", n_jobs=n_jobs, chunk_duration="1s", progress_bar=False, + ) + + t_stock_n1 = _time_cre(make_cre(CommonReferenceRecording(rec), n_jobs=1)) + t_stock_n8 = _time_cre(make_cre(CommonReferenceRecording(rec), n_jobs=8)) + t_fast_n1 = _time_cre(make_cre(CommonReferenceRecording(rec, n_workers=16), n_jobs=1)) + t_fast_n8 = _time_cre(make_cre(CommonReferenceRecording(rec, n_workers=16), n_jobs=8)) + + print(f" {'config':<40} {'time':>8} {'vs baseline':>12}") + print(f" {'stock, CRE n=1 (baseline)':<40} {t_stock_n1:6.2f} s {'1.00×':>12}") + print(f" {'stock, CRE n=8 thread':<40} {t_stock_n8:6.2f} s {t_stock_n1/t_stock_n8:5.2f}× (outer only)") + print(f" {'n_workers=16, CRE n=1':<40} {t_fast_n1:6.2f} s {t_stock_n1/t_fast_n1:5.2f}× (inner only)") + print(f" {'n_workers=16, CRE n=8 thread':<40} {t_fast_n8:6.2f} s {t_stock_n1/t_fast_n8:5.2f}× (both)") + print() + + +def main(): + print("### COMPONENT-LEVEL (hot operation only) ###") + print() + bench_sosfiltfilt_component() + bench_median_component() + + print("### PER-STAGE END-TO-END (rec.get_traces()) ###") + print() + bench_bandpass() + bench_cmr() + + print("### CRE OUTER × INNER (chunk=1s) ###") + print() + bench_bandpass_cre_interaction() + bench_cmr_cre_interaction() + + +if __name__ == "__main__": + main() diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 5a3a9b0043..555b7f21f5 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -1,4 +1,6 @@ +import threading import warnings +import weakref from typing import Literal import numpy as np @@ -88,6 +90,7 @@ def __init__( local_radius: tuple[float, float] = (30.0, 55.0), min_local_neighbors: int = 5, dtype: str | np.dtype | None = None, + n_workers: int = 1, ): num_chans = 
recording.get_num_channels() local_kernel = None @@ -154,6 +157,7 @@ def __init__( else: ref_channel_indices = None + assert int(n_workers) >= 1, "n_workers must be >= 1" for parent_segment in recording.segments: rec_segment = CommonReferenceRecordingSegment( parent_segment, @@ -163,6 +167,7 @@ def __init__( ref_channel_indices, local_kernel, dtype_, + n_workers=int(n_workers), ) self.add_recording_segment(rec_segment) @@ -175,6 +180,7 @@ def __init__( local_radius=local_radius, min_local_neighbors=min_local_neighbors, dtype=dtype_.str, + n_workers=int(n_workers), ) @@ -188,6 +194,7 @@ def __init__( ref_channel_indices, local_kernel, dtype, + n_workers=1, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -200,6 +207,59 @@ def __init__( self.dtype = dtype self.operator = operator self.operator_func = np.mean if self.operator == "average" else np.median + self.n_workers = int(n_workers) + # Per-caller-thread lazy pool map. See filter.FilterRecordingSegment + # for full rationale and WeakKeyDictionary mechanics. + self._cmr_pools = weakref.WeakKeyDictionary() + self._cmr_pools_lock = threading.Lock() + + def _get_pool(self): + """Lazy per-caller-thread thread pool for parallel median/mean across time blocks.""" + if self.n_workers <= 1: + return None + thread = threading.current_thread() + pool = self._cmr_pools.get(thread) + if pool is None: + with self._cmr_pools_lock: + pool = self._cmr_pools.get(thread) + if pool is None: + from concurrent.futures import ThreadPoolExecutor + + pool = ThreadPoolExecutor(max_workers=self.n_workers) + self._cmr_pools[thread] = pool + weakref.finalize(thread, pool.shutdown, wait=False) + return pool + + def _parallel_reduce_axis1(self, traces): + """Apply ``operator_func(..., axis=1)`` split across time blocks. + + numpy's partition-based median and BLAS-backed mean release the GIL + during per-row work, so Python-thread parallelism delivers real + speedup (measured ~10× on 16 threads for 1M × 384 median). 
+ """ + if self.n_workers == 1: + return self.operator_func(traces, axis=1) + T = traces.shape[0] + # Minimum block size per worker: below this, per-thread overhead + # outweighs the parallelism gain. + min_block = 8192 + effective = max(1, min(self.n_workers, T // min_block)) + if effective == 1: + return self.operator_func(traces, axis=1) + pool = self._get_pool() + block = (T + effective - 1) // effective + bounds = [(t0, min(t0 + block, T)) for t0 in range(0, T, block)] + + def _work(t0, t1): + return t0, t1, self.operator_func(traces[t0:t1, :], axis=1) + + futures = [pool.submit(_work, t0, t1) for t0, t1 in bounds] + results = [fut.result() for fut in futures] + out_dtype = results[0][2].dtype + out = np.empty(T, dtype=out_dtype) + for t0, t1, block_out in results: + out[t0:t1] = block_out + return out def get_traces(self, start_frame, end_frame, channel_indices): # Let's do the case with group_indices equal None as that is easy @@ -209,7 +269,8 @@ def get_traces(self, start_frame, end_frame, channel_indices): if self.reference == "global": if self.ref_channel_indices is None: - shift = self.operator_func(traces, axis=1, keepdims=True) + # Hot path: parallelizable global median/mean across all channels. 
+ shift = self._parallel_reduce_axis1(traces)[:, np.newaxis] else: shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True) re_referenced_traces = traces[:, channel_indices] - shift diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index b4ceed886e..6356cc8902 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -1,4 +1,6 @@ +import threading import warnings +import weakref import numpy as np @@ -92,10 +94,12 @@ def __init__( coeff=None, dtype=None, direction="forward-backward", + n_workers=1, ): import scipy.signal assert filter_mode in ("sos", "ba"), "'filter' mode must be 'sos' or 'ba'" + assert int(n_workers) >= 1, "n_workers must be >= 1" fs = recording.get_sampling_frequency() if coeff is None: assert btype in ("bandpass", "highpass"), "'bytpe' must be 'bandpass' or 'highpass'" @@ -140,6 +144,7 @@ def __init__( dtype, add_reflect_padding=add_reflect_padding, direction=direction, + n_workers=int(n_workers), ) ) @@ -155,6 +160,7 @@ def __init__( add_reflect_padding=add_reflect_padding, dtype=dtype.str, direction=direction, + n_workers=int(n_workers), ) @@ -168,6 +174,7 @@ def __init__( dtype, add_reflect_padding=False, direction="forward-backward", + n_workers=1, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.coeff = coeff @@ -176,6 +183,73 @@ def __init__( self.margin = margin self.add_reflect_padding = add_reflect_padding self.dtype = dtype + self.n_workers = int(n_workers) + # Per-caller-thread lazy pool map. Each outer thread that calls + # get_traces() on this segment gets its own inner pool, avoiding the + # shared-pool queueing pathology that would occur if multiple outer + # workers (e.g., a TimeSeriesChunkExecutor with n_jobs > 1) all + # dispatched into a single shared pool on the segment. 
+ # + # WeakKeyDictionary + weakref.finalize: entries are keyed by the Thread + # object itself (not by thread-id integer, which can be reused after a + # thread dies). When the calling thread is garbage-collected, its + # inner pool is shut down (non-blocking) and the dict entry drops, so + # long-running processes don't accumulate zombie pools. + self._filter_pools = weakref.WeakKeyDictionary() + self._filter_pools_lock = threading.Lock() + + def _get_pool(self): + """Lazy per-caller-thread thread pool for channel-parallel filtering.""" + if self.n_workers <= 1: + return None + thread = threading.current_thread() + pool = self._filter_pools.get(thread) + if pool is None: + with self._filter_pools_lock: + pool = self._filter_pools.get(thread) + if pool is None: + from concurrent.futures import ThreadPoolExecutor + + pool = ThreadPoolExecutor(max_workers=self.n_workers) + self._filter_pools[thread] = pool + # When the calling thread is GC'd, shut down its pool + # without blocking the finalizer thread. In-flight + # tasks would be cancelled, but the owning thread + # submits + joins synchronously, so no such tasks + # exist when the thread actually exits. + weakref.finalize(thread, pool.shutdown, wait=False) + return pool + + def _apply_sos(self, fn, traces, axis=0): + """Apply a scipy SOS function across channel blocks in parallel. + + Each channel is independent of every other channel, so splitting the + channel axis across threads is a safe parallelization. scipy's C + implementations of ``sosfiltfilt``/``sosfilt`` release the GIL during + per-column work, so Python-thread parallelism delivers real speedup + (measured ~3× on 8 threads for a 1M × 384 float32 chunk). 
+ """ + if self.n_workers == 1: + return fn(self.coeff, traces, axis=axis) + C = traces.shape[1] + if C < 2 * self.n_workers: + return fn(self.coeff, traces, axis=axis) + pool = self._get_pool() + block = (C + self.n_workers - 1) // self.n_workers + bounds = [(c0, min(c0 + block, C)) for c0 in range(0, C, block)] + + def _work(c0, c1): + return c0, c1, fn(self.coeff, traces[:, c0:c1], axis=axis) + + futures = [pool.submit(_work, c0, c1) for c0, c1 in bounds] + results = [fut.result() for fut in futures] + # Allocate the output using the first block's dtype (scipy may promote + # int input to float64). + out_dtype = results[0][2].dtype + out = np.empty((traces.shape[0], C), dtype=out_dtype) + for c0, c1, block_out in results: + out[:, c0:c1] = block_out + return out def get_traces(self, start_frame, end_frame, channel_indices): traces_chunk, left_margin, right_margin = get_chunk_with_margin( @@ -196,7 +270,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): if self.direction == "forward-backward": if self.filter_mode == "sos": - filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) + filtered_traces = self._apply_sos(scipy.signal.sosfiltfilt, traces_chunk, axis=0) elif self.filter_mode == "ba": b, a = self.coeff filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) @@ -205,7 +279,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces_chunk = np.flip(traces_chunk, axis=0) if self.filter_mode == "sos": - filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0) + filtered_traces = self._apply_sos(scipy.signal.sosfilt, traces_chunk, axis=0) elif self.filter_mode == "ba": b, a = self.coeff filtered_traces = scipy.signal.lfilter(b, a, traces_chunk, axis=0) diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index e19cad59ba..f074417e22 100644 --- 
a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -209,5 +209,39 @@ def test_local_car_vs_cmr_performance(): assert car_time < cmr_time +def test_cmr_parallel_median_matches_stock(): + """CommonReferenceRecording(n_workers=N) must produce bit-identical output.""" + from spikeinterface import NumpyRecording + + rng = np.random.default_rng(0) + T, C = 60_000, 64 + traces = (rng.standard_normal((T, C)) * 100).astype("float32") + rec = NumpyRecording([traces], sampling_frequency=30_000.0) + stock = common_reference(rec, reference="global", operator="median") + fast = common_reference(rec, reference="global", operator="median", n_workers=8) + ref = stock.get_traces(start_frame=5_000, end_frame=T - 5_000) + out = fast.get_traces(start_frame=5_000, end_frame=T - 5_000) + np.testing.assert_array_equal(out, ref) + + +def test_cmr_parallel_average_matches_stock(): + """Same invariant for the mean (CAR) operator; tolerate float rounding.""" + from spikeinterface import NumpyRecording + + rng = np.random.default_rng(0) + T, C = 60_000, 64 + traces = (rng.standard_normal((T, C)) * 100).astype("float32") + rec = NumpyRecording([traces], sampling_frequency=30_000.0) + stock = common_reference(rec, reference="global", operator="average") + fast = common_reference(rec, reference="global", operator="average", n_workers=8) + ref = stock.get_traces(start_frame=5_000, end_frame=T - 5_000) + out = fast.get_traces(start_frame=5_000, end_frame=T - 5_000) + # Mean across different block partitions can differ by 1 ULP due to + # non-associative float summation. 
+ np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-4) + + if __name__ == "__main__": test_local_car_vs_cmr_performance() + test_cmr_parallel_median_matches_stock() + test_cmr_parallel_average_matches_stock() diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index e95b456542..c5bbf7961a 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -220,5 +220,39 @@ def test_filter_opencl(): # plt.show() +def test_bandpass_parallel_matches_stock(): + """BandpassFilterRecording(n_workers=N) must produce the same output as n_workers=1. + + Locks in the invariant that channel-axis parallelism is a pure perf + optimisation — scipy's sosfiltfilt is channel-independent so splitting + the channel axis across threads cannot change per-channel output. + """ + rng = np.random.default_rng(0) + T, C = 60_000, 64 + traces = (rng.standard_normal((T, C)) * 100).astype("float32") + rec = NumpyRecording([traces], sampling_frequency=30_000.0) + stock = bandpass_filter(rec, freq_min=300.0, freq_max=5000.0, dtype="float32") + fast = bandpass_filter(rec, freq_min=300.0, freq_max=5000.0, dtype="float32", n_workers=8) + ref = stock.get_traces(start_frame=5_000, end_frame=T - 5_000) + out = fast.get_traces(start_frame=5_000, end_frame=T - 5_000) + np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-4) + + +def test_filter_parallel_fewer_channels_than_workers(): + """n_workers > C must still produce correct output (falls through to serial).""" + rng = np.random.default_rng(0) + T, C = 10_000, 4 + traces = (rng.standard_normal((T, C)) * 100).astype("float32") + rec = NumpyRecording([traces], sampling_frequency=30_000.0) + fast = bandpass_filter(rec, freq_min=300.0, freq_max=5000.0, dtype="float32", n_workers=16) + # Should not raise; should match stock. 
+ stock = bandpass_filter(rec, freq_min=300.0, freq_max=5000.0, dtype="float32") + ref = stock.get_traces(start_frame=1000, end_frame=T - 1000) + out = fast.get_traces(start_frame=1000, end_frame=T - 1000) + np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-4) + + if __name__ == "__main__": test_filter() + test_bandpass_parallel_matches_stock() + test_filter_parallel_fewer_channels_than_workers() diff --git a/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py b/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py new file mode 100644 index 0000000000..16abf00018 --- /dev/null +++ b/src/spikeinterface/preprocessing/tests/test_parallel_pool_semantics.py @@ -0,0 +1,103 @@ +"""Tests for the per-caller-thread pool semantics used by FilterRecording and +CommonReferenceRecording when ``n_workers > 1``. + +Contract: each outer thread that calls ``get_traces()`` on a parallel-enabled +segment gets its own inner ThreadPoolExecutor. Keying by thread avoids the +shared-pool queueing pathology that arises when many outer workers submit +concurrently into a single inner pool with fewer max_workers than outer +callers. See the module-level comments in filter.py and common_reference.py +for the full rationale. 
+""" + +from __future__ import annotations + +import threading + +import numpy as np +import pytest + +from spikeinterface import NumpyRecording +from spikeinterface.preprocessing import ( + BandpassFilterRecording, + CommonReferenceRecording, +) + + +def _make_recording(T: int = 50_000, C: int = 64, fs: float = 30_000.0): + rng = np.random.default_rng(0) + traces = rng.standard_normal((T, C)).astype(np.float32) * 100.0 + return NumpyRecording([traces], sampling_frequency=fs) + + +@pytest.fixture +def filter_segment(): + rec = _make_recording() + bp = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, n_workers=4) + return bp, bp._recording_segments[0] + + +@pytest.fixture +def cmr_segment(): + rec = _make_recording() + cmr = CommonReferenceRecording(rec, operator="median", reference="global", n_workers=4) + return cmr, cmr._recording_segments[0] + + +class TestPerCallerThreadPool: + """Verify each calling thread gets its own inner pool.""" + + @pytest.mark.parametrize( + "segment_fixture,pools_attr", + [ + ("filter_segment", "_filter_pools"), + ("cmr_segment", "_cmr_pools"), + ], + ) + def test_single_caller_reuses_pool(self, segment_fixture, pools_attr, request): + """Repeated calls from the same thread reuse the same inner pool.""" + rec, seg = request.getfixturevalue(segment_fixture) + rec.get_traces(start_frame=0, end_frame=50_000) + pool_a = getattr(seg, pools_attr).get(threading.current_thread()) + rec.get_traces() + pool_b = getattr(seg, pools_attr).get(threading.current_thread()) + assert pool_a is not None + assert pool_a is pool_b, "expected the same inner pool to be reused across calls from the same thread" + + @pytest.mark.parametrize( + "segment_fixture,pools_attr", + [ + ("filter_segment", "_filter_pools"), + ("cmr_segment", "_cmr_pools"), + ], + ) + def test_concurrent_callers_get_distinct_pools(self, segment_fixture, pools_attr, request): + """Two outer threads calling get_traces concurrently must receive + different inner pools — not a 
shared one that would queue their + tasks through a single bottleneck. + """ + rec, seg = request.getfixturevalue(segment_fixture) + + ready = threading.Barrier(2) + captured = {} + + def worker(name): + # Align the two threads so they're definitely live concurrently + # when they touch the pool-map, exercising the double-checked + # locking path. + ready.wait() + rec.get_traces(start_frame=0, end_frame=50_000) + captured[name] = getattr(seg, pools_attr).get(threading.current_thread()) + + t1 = threading.Thread(target=worker, args=("t1",)) + t2 = threading.Thread(target=worker, args=("t2",)) + t1.start() + t2.start() + t1.join() + t2.join() + + assert captured["t1"] is not None + assert captured["t2"] is not None + assert captured["t1"] is not captured["t2"], ( + "expected distinct inner pools for concurrent callers; Model 1 " + "shared-pool semantics would cause queueing pathology" + )