diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index b66b553fd7..66ebf439c5 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -88,7 +88,7 @@ "censored_period_ms": 0.3, }, "quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, - "slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5}, + "slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5, "censored_period_ms": 0.2}, } @@ -309,7 +309,6 @@ def compute_merge_unit_groups( win_sizes, pair_mask=pair_mask, ) - # print(correlogram_diff) pair_mask = pair_mask & (correlogram_diff < params["corr_diff_thresh"]) outs["correlograms"] = correlograms outs["bins"] = bins @@ -373,12 +372,20 @@ def compute_merge_unit_groups( outs["pairs_decreased_score"] = pairs_decreased_score elif step == "slay_score": - - M_ij = compute_slay_matrix( - sorting_analyzer, params["k1"], params["k2"], templates_diff=outs["templates_diff"], pair_mask=pair_mask + M_ij, sigma_ij, rho_ij, eta_ij = compute_slay_matrix( + sorting_analyzer, + params["k1"], + params["k2"], + params["censored_period_ms"], + templates_diff=outs["templates_diff"], + pair_mask=pair_mask, ) pair_mask = pair_mask & (M_ij > params["slay_threshold"]) + outs["slay_M_ij"] = M_ij + outs["slay_sigma_ij"] = sigma_ij + outs["slay_rho_ij"] = rho_ij + outs["slay_eta_ij"] = eta_ij # FINAL STEP : create the final list from pair_mask boolean matrix ind1, ind2 = np.nonzero(pair_mask) @@ -1552,6 +1559,7 @@ def compute_slay_matrix( sorting_analyzer: SortingAnalyzer, k1: float, k2: float, + censor_period_ms: float, templates_diff: np.ndarray | None, pair_mask: np.ndarray | None = None, ): @@ -1569,6 +1577,9 @@ def compute_slay_matrix( Coefficient determining the importance of the cross-correlation significance k2 : float Coefficient determining the importance of the sliding rp violation + censor_period_ms : float + The censored 
def _count_coincident_spikes(t1, t2, max_samples):
    """
    Count spikes in t1 that have a matching spike in t2 within max_samples,
    split by lag direction.

    For every spike in `t1` only its two nearest neighbours in `t2` are inspected
    (the first `t2` spike at or after it, and the last `t2` spike strictly before it),
    so each `t1` spike contributes at most one count per direction.

    Parameters
    ----------
    t1 : array-like
        Spike times (in samples) of the first unit.
    t2 : array-like
        Spike times (in samples) of the second unit. Must be sorted in ascending
        order, as required by `np.searchsorted`.
    max_samples : int
        Maximum lag (in samples, inclusive) for two spikes to be considered coincident.

    Returns
    -------
    n_nonneg : int
        Number of `t1` spikes whose next `t2` spike (t2 >= t1, non-negative lag)
        is at most `max_samples` away.
    n_neg : int
        Number of `t1` spikes whose previous `t2` spike (t2 < t1, negative lag)
        is at most `max_samples` away.
    """
    # accept plain sequences as well as numpy arrays (fancy indexing below needs arrays)
    t1 = np.asarray(t1)
    t2 = np.asarray(t2)
    if t1.size == 0 or t2.size == 0:
        return 0, 0

    # insertion points: indices[i] is the first position in t2 with t2 >= t1[i]
    indices = np.searchsorted(t2, t1, side="left")
    # value that can never pass the `<= max_samples` test, used where no neighbour exists
    no_match = max_samples + 1

    # neighbour at or after each t1 spike (index clamped, invalid entries masked out)
    right_valid = indices < len(t2)
    right_diffs = np.where(right_valid, t2[np.minimum(indices, len(t2) - 1)] - t1, no_match)

    # neighbour strictly before each t1 spike (index clamped, invalid entries masked out)
    left_valid = indices > 0
    left_diffs = np.where(left_valid, t1 - t2[np.maximum(indices - 1, 0)], no_match)

    n_nonneg = int(np.sum(right_diffs <= max_samples))
    # side="left" already makes valid left lags strictly positive; the `> 0` test is kept as a guard
    n_neg = int(np.sum((left_diffs <= max_samples) & (left_diffs > 0)))
    return n_nonneg, n_neg
@@ -1610,14 +1642,36 @@ def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarra The sorting analyzer object containing the spike sorting data pair_mask : np.ndarray A bool matrix describing which pairs are possible merges based on previous steps + censor_period_ms : float + The censored period to exclude from the refractory period computation to discard + duplicated spikes. + + Returns + ------- + rho_ij : np.ndarray + The cross-correlation significance measure for each pair of units. + eta_ij : np.ndarray + The sliding refractory period violation measure for each pair of units. """ correlograms_extension = sorting_analyzer.get_extension("correlograms") - ccgs, _ = correlograms_extension.get_data() + ccgs, bin_edges = correlograms_extension.get_data() # convert to seconds for SLAy functions bin_size_ms = correlograms_extension.params["bin_ms"] + # pre-fetch spike trains for duplicate counting (sub-bin resolution) + if censor_period_ms > 0: + sorting = sorting_analyzer.sorting + censor_period_samples = int(censor_period_ms / 1000 * sorting_analyzer.sampling_frequency) + n_segments = sorting_analyzer.get_num_segments() + spike_trains = [ + [sorting.get_unit_spike_train(unit_id=uid, segment_index=seg) for seg in range(n_segments)] + for uid in sorting_analyzer.unit_ids + ] + # lag=0 spike pairs land in the bin starting at 0: xgram[num_half_bins] + center_bin = ccgs.shape[2] // 2 + rho_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)]) eta_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)]) @@ -1630,10 +1684,39 @@ def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarra xgram = ccgs[unit_index_1, unit_index_2, :] + # Merged ACG approximation: sum of individual ACGs and both CCG directions. 
+ # _sliding_RP_viol_pair expects the ACG of the merged unit; the merged ACG + # has a large center bin when duplicates are present, which the LP filter + # attenuates so bin_rate_max reflects the flank rate — making RP violations + # detectable (unlike using the CCG alone where bin_rate_max is dominated by + # the duplicate peak and masks violations). + merged_acg = ( + ccgs[unit_index_1, unit_index_1, :] + + ccgs[unit_index_2, unit_index_2, :] + + ccgs[unit_index_1, unit_index_2, :] + + ccgs[unit_index_2, unit_index_1, :] + ) + + if censor_period_ms > 0: + # count number of spikes within the censor period from the two units + n_right, n_left = 0, 0 + for seg in range(n_segments): + r, l = _count_coincident_spikes( + spike_trains[unit_index_1][seg], spike_trains[unit_index_2][seg], censor_period_samples + ) + n_right += r + n_left += l + # subtract number of duplicates from central bin(s) of the merged ACG: + # n_right pairs land in center_bin (lag ≥ 0), n_left in center_bin-1 (lag < 0); + # each direction is counted in both ccgs[i,j] and ccgs[j,i], hence the factor 2 + merged_acg = merged_acg.copy() + merged_acg[center_bin] = max(0, merged_acg[center_bin] - 2 * n_right) + merged_acg[center_bin - 1] = max(0, merged_acg[center_bin - 1] - 2 * n_left) + rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair( xgram, bin_size_s=bin_size_ms / 1000, min_xcorr_rate=0 ) - eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size_ms=bin_size_ms) + eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(merged_acg, bin_size_ms=bin_size_ms) return rho_ij, eta_ij diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index a665b074a6..4b1c1095fb 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -35,12 +35,9 @@ def make_sorting_analyzer(sparse=True, num_units=5, durations=[300.0]): sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) 
sorting_analyzer = create_sorting_analyzer( - sorting=sorting, - recording=recording, - format="memory", - sparse=sparse, + sorting=sorting, recording=recording, format="memory", sparse=sparse, n_jobs=-1 ) - sorting_analyzer.compute(extensions) + sorting_analyzer.compute(extensions, n_jobs=-1) return sorting_analyzer @@ -58,9 +55,9 @@ def make_sorting_analyzer_with_splits(sorting_analyzer, num_unit_splitted=1, num ) sorting_analyzer_with_splits = create_sorting_analyzer( - sorting=sorting_with_split, recording=sorting_analyzer.recording, format="memory", sparse=True + sorting=sorting_with_split, recording=sorting_analyzer.recording, format="memory", sparse=True, n_jobs=-1 ) - sorting_analyzer_with_splits.compute(extensions) + sorting_analyzer_with_splits.compute(extensions, n_jobs=-1) return sorting_analyzer_with_splits, num_unit_splitted, other_ids @@ -78,8 +75,8 @@ def sorting_analyzer_for_unitrefine_curation(): recording, sorting_1 = generate_ground_truth_recording(num_channels=4, seed=1, num_units=6) _, sorting_2 = generate_ground_truth_recording(num_channels=4, seed=2, num_units=6) both_sortings = aggregate_units([sorting_1, sorting_2]) - analyzer = create_sorting_analyzer(sorting=both_sortings, recording=recording) - analyzer.compute(["random_spikes", "noise_levels", "templates"]) + analyzer = create_sorting_analyzer(sorting=both_sortings, recording=recording, n_jobs=-1) + analyzer.compute(["random_spikes", "noise_levels", "templates"], n_jobs=-1) return analyzer diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index cab508f4fb..eeafdc9a07 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -1,12 +1,14 @@ import pytest +import numpy as np -from spikeinterface.core import create_sorting_analyzer +from spikeinterface.core import create_sorting_analyzer, NumpySorting from spikeinterface.curation import 
def test_slay_discard_duplicated_spikes(sorting_analyzer_with_splits):
    """
    Check that a non-zero SLAy censored period lowers the refractory-violation
    term (eta) when the two halves of a split unit share duplicated spikes.

    Spikes from the first half of every split unit are copied into the second
    half, the SLAy preset is run with `censored_period_ms` set to 0.0 and to 0.5,
    and the summed `slay_eta_ij` outputs are compared.
    """
    sorting_analyzer, num_unit_splitted, split_ids = sorting_analyzer_with_splits
    sorting = sorting_analyzer.sorting

    # fraction of the smaller half of each split that is injected as duplicates
    percent_duplicated = 0.7
    # seeded generator so the injected duplicates (and hence the test) are reproducible
    rng = np.random.default_rng(seed=0)

    # every unit id that is part of a split pair
    split_unit_ids = {unit_id for key in split_ids for unit_id in split_ids[key]}

    # keep the spike trains of the units that were not split untouched
    new_spiketrains = {
        unit_id: sorting.get_unit_spike_train(unit_id=unit_id)
        for unit_id in sorting_analyzer.unit_ids
        if unit_id not in split_unit_ids
    }

    # add duplicated spikes for split units
    for key in split_ids:
        unit_0, unit_1 = split_ids[key][0], split_ids[key][1]
        spiketrain0 = sorting.get_unit_spike_train(unit_id=unit_0)
        spiketrain1 = sorting.get_unit_spike_train(unit_id=unit_1)
        num_duplicated = int(percent_duplicated * min(len(spiketrain0), len(spiketrain1)))
        duplicated_spikes0 = rng.choice(spiketrain0, size=num_duplicated, replace=False)

        new_spiketrains[unit_0] = spiketrain0
        # the second half receives exact copies (same sample) of spikes from the first half
        new_spiketrains[unit_1] = np.sort(np.concatenate([spiketrain1, duplicated_spikes0]))

    sorting_duplicated = NumpySorting.from_unit_dict(
        new_spiketrains, sampling_frequency=sorting_analyzer.sampling_frequency
    )

    sorting_analyzer_duplicated = create_sorting_analyzer(
        sorting_duplicated, sorting_analyzer.recording, format="memory"
    )
    sorting_analyzer_duplicated.compute(extensions)

    # Without a censored period the duplicated spikes are expected to inflate the
    # refractory-violation term; only eta is asserted here, the merge lists are not checked.
    _, outs_no_censor_period = compute_merge_unit_groups(
        sorting_analyzer_duplicated,
        preset="slay",
        steps_params={"slay_score": {"censored_period_ms": 0.0}},
        extra_outputs=True,
    )
    _, outs_censor_period = compute_merge_unit_groups(
        sorting_analyzer_duplicated,
        preset="slay",
        steps_params={"slay_score": {"censored_period_ms": 0.5}},
        extra_outputs=True,
    )
    assert np.sum(outs_censor_period["slay_eta_ij"]) < np.sum(outs_no_censor_period["slay_eta_ij"])