diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index f23b524271..eb0642b79d 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -910,9 +910,27 @@ def __init__(self, sampling_frequency=None, t_start=None, time_vector=None): self.sampling_frequency = sampling_frequency self.t_start = t_start self.time_vector = time_vector + self._num_channels = None BaseSegment.__init__(self) + @property + def num_channels(self): + # Return an explicit value if a subclass set one (via the `num_channels` kwarg + # at construction or by assigning `self._num_channels = N`). Otherwise derive from + # the container recording through the weakref established in `add_segment`. + if self._num_channels is not None: + return self._num_channels + if self._parent_extractor is None: + return None + container_recording = self._parent_extractor() + if container_recording is None: + return None + return container_recording.get_num_channels() + + def get_num_channels(self): + return self.num_channels + def get_times(self) -> np.ndarray: if self.time_vector is not None: self.time_vector = np.asarray(self.time_vector) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index b3eaa099ed..81511538ef 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -173,7 +173,7 @@ def __del__(self): class BinaryRecordingSegment(BaseRecordingSegment): def __init__(self, file_path, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset): BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency, t_start=t_start) - self.num_channels = num_channels + self._num_channels = num_channels self.dtype = np.dtype(dtype) self.file_offset = file_offset self.time_axis = time_axis diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 1c9ece728f..fcb571ab56 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1359,7 +1359,7 @@ def __init__( BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency) self.num_samples = num_samples - self.num_channels = num_channels + self._num_channels = num_channels self.noise_block_size = noise_block_size self.noise_levels = noise_levels self.cov_matrix = cov_matrix @@ -2075,9 +2075,9 @@ def get_traces( channel_indices: list | None = None, ) -> np.ndarray: if channel_indices is None: - n_channels = self.templates.shape[2] + n_channels = self.num_channels elif isinstance(channel_indices, slice): - stop = channel_indices.stop if channel_indices.stop is not None else self.templates.shape[2] + stop = channel_indices.stop if channel_indices.stop is not None else self.num_channels start = channel_indices.start if channel_indices.start is not None else 0 step = channel_indices.step if channel_indices.step is not None else 1 n_channels = math.ceil((stop - start) / step) diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 1800138dae..d5096ebb41 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -542,9 +542,9 @@ def get_traces( end_frame = self.num_samples if end_frame is None else end_frame if channel_indices is None: - n_channels = self.drifting_templates.num_channels + n_channels = self.num_channels elif isinstance(channel_indices, slice): - stop = channel_indices.stop if channel_indices.stop is not None else self.drifting_templates.num_channels + stop = channel_indices.stop if channel_indices.stop is not None else self.num_channels start = channel_indices.start if channel_indices.start is not None else 0 step = channel_indices.step if channel_indices.step is not None else 1 n_channels = math.ceil((stop - start) / step) diff --git a/src/spikeinterface/preprocessing/average_across_direction.py b/src/spikeinterface/preprocessing/average_across_direction.py index 8d1c4475cd..a6045e85db 100644 --- a/src/spikeinterface/preprocessing/average_across_direction.py +++ b/src/spikeinterface/preprocessing/average_across_direction.py @@ -103,7 +103,7 @@ def __init__( ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.parent_recording_segment = parent_recording_segment - self.num_channels = num_channels + self._num_channels = num_channels self.same_along_dim_chans = same_along_dim_chans self.n_chans_each_pos = n_chans_each_pos self._dtype = dtype diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index 45d4809cd8..ee5f80bb3f 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -35,7 +35,6 @@ def __init__(self, recording: BaseRecording, padding_start: int = 0, padding_end for segment in recording.segments: recording_segment = TracePaddedRecordingSegment( segment, - recording.get_num_channels(), self.dtype, self.padding_start, self.padding_end, @@ -55,7 +54,6 @@ class TracePaddedRecordingSegment(BasePreprocessorSegment): def __init__( self, recording_segment: BaseRecordingSegment, - num_channels, dtype, padding_left, padding_end, @@ -64,7 +62,6 @@ def __init__( self.padding_start = padding_left self.padding_end = padding_end self.fill_value = fill_value - self.num_channels = num_channels self.num_samples_in_original_segment = recording_segment.get_num_samples() self.dtype = dtype @@ -165,7 +162,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: self.parent_recording = recording self.num_channels = num_channels for segment in recording.segments: - recording_segment = ZeroChannelPaddedRecordingSegment(segment, self.num_channels, self.channel_mapping) + recording_segment = ZeroChannelPaddedRecordingSegment(segment, self.channel_mapping) self.add_recording_segment(recording_segment) # only copy relevant metadata and properties @@ -182,10 +179,9 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: class ZeroChannelPaddedRecordingSegment(BasePreprocessorSegment): - def __init__(self, recording_segment: BaseRecordingSegment, num_channels: int, channel_mapping: list): + def __init__(self, recording_segment: BaseRecordingSegment, channel_mapping: list): BasePreprocessorSegment.__init__(self, recording_segment) self.parent_recording_segment = recording_segment - self.num_channels = num_channels self.channel_mapping = channel_mapping def get_traces(self, start_frame, end_frame, channel_indices):