diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index fc8373cdfb..b17cca6506 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -908,7 +908,8 @@ def _compute_and_cache_spike_vector(self) -> None: spikes.append(spikes_in_seg) spikes = np.concatenate(spikes) - spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] + # the spikes are not lexsorted here because the previous loop ensure that the spike vector is constructucted alway the same way. + # spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] self._cached_spike_vector = spikes self._cached_spike_vector_segment_slices = segment_slices diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 4fa68ebec0..ba50df7ed7 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -168,7 +168,6 @@ def generate_sorting( spikes_in_seg["sample_index"] = samples spikes_in_seg["unit_index"] = labels spikes_in_seg["segment_index"] = segment_index - spikes.append(spikes_in_seg) if add_spikes_on_borders: spikes_on_borders = np.zeros(2 * num_spikes_per_border, dtype=minimum_spike_dtype) @@ -182,10 +181,15 @@ def generate_sorting( spikes_on_borders["sample_index"][num_spikes_per_border:] = rng.integers( num_samples - border_size_samples, num_samples, num_spikes_per_border ) - spikes.append(spikes_on_borders) + spikes_in_seg = np.concatenate([spikes_in_seg, spikes_on_borders]) + order = np.argsort(spikes_in_seg["sample_index"]) + spikes_in_seg = spikes_in_seg[order] + + spikes.append(spikes_in_seg) spikes = np.concatenate(spikes) - spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] + # the spikes do not need a full lexsort because synthesize_poisson_spike_vector() garanty spikes to sorted by frame already + # spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] sorting = NumpySorting(spikes, sampling_frequency, unit_ids) diff --git a/src/spikeinterface/core/testing.py b/src/spikeinterface/core/testing.py index 2b5a7c5157..67516cbf68 100644 --- a/src/spikeinterface/core/testing.py +++ b/src/spikeinterface/core/testing.py @@ -104,7 +104,11 @@ def check_recordings_equal( def check_sortings_equal( - SX1: BaseSorting, SX2: BaseSorting, check_annotations: bool = False, check_properties: bool = False + SX1: BaseSorting, + SX2: BaseSorting, + check_annotations: bool = False, + check_properties: bool = False, + check_exact_lexsort: bool = True, ) -> None: assert SX1.get_num_segments() == SX2.get_num_segments() @@ -112,6 +116,11 @@ def check_sortings_equal( s1 = SX1.to_spike_vector() s2 = SX2.to_spike_vector() + if not check_exact_lexsort: + # 2 sorting can be equal even if the internal lexsort is not the same. + # spiketrains still wiwll be the same per units + s1 = s1[np.lexsort((s1["unit_index"], s1["sample_index"], s1["segment_index"]))] + s2 = s2[np.lexsort((s2["unit_index"], s2["sample_index"], s2["segment_index"]))] assert_array_equal(s1, s2) for start_frame, end_frame in [ diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index c5a0b83f87..c5a29c62d3 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -655,8 +655,8 @@ def test_synthesize_random_firings_length(): # test_generate_recording() # test_generate_single_fake_waveform() # test_transformsorting() - test_generate_unit_locations() + # test_generate_unit_locations() # test_generate_templates() # test_inject_templates() # test_generate_ground_truth_recording() - # test_generate_sorting_with_spikes_on_borders() + test_generate_sorting_with_spikes_on_borders() diff --git a/src/spikeinterface/core/tests/test_sorting_folder.py b/src/spikeinterface/core/tests/test_sorting_folder.py index a285ae8d29..f32b6d7f81 100644 --- a/src/spikeinterface/core/tests/test_sorting_folder.py +++ b/src/spikeinterface/core/tests/test_sorting_folder.py @@ -38,12 +38,17 @@ def test_NpzFolderSorting(create_cache_folder): NpzFolderSorting.write_sorting(sorting, folder) sorting_loaded = NpzFolderSorting(folder) - check_sortings_equal(sorting_loaded, sorting) + # the NpzFolderSorting is a by unit storage and te lexsort is not maintain always so check_exact_lexsort=False + check_sortings_equal(sorting_loaded, sorting, check_exact_lexsort=False) assert np.array_equal(sorting_loaded.unit_ids, sorting.unit_ids) - assert np.array_equal( - sorting_loaded.to_spike_vector(), - sorting.to_spike_vector(), - ) + + # Note changing the class do not necessarily maintain the internal internal. + # but the vectors should be the same after lexsort + s1 = sorting_loaded.to_spike_vector() + s2 = sorting.to_spike_vector() + s1 = s1[np.lexsort((s1["unit_index"], s1["sample_index"], s1["segment_index"]))] + s2 = s2[np.lexsort((s2["unit_index"], s2["sample_index"], s2["segment_index"]))] + assert np.array_equal(s1, s2) if __name__ == "__main__": diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index a9bd71b5c0..14e62384fa 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -210,9 +210,12 @@ def test_create_by_dict(): Interally, this aggregates the dicts of recordings and sortings. This test checks that the unit structure is maintained from the dicts to the analyzer. Then checks that the function fails if the dict keys are different for the recordings and the sortings. + + Note, in this tests sparse is False because units are randomlly assign to differents of the + recording and they can have no channels """ - rec, sort = generate_ground_truth_recording(num_channels=6) + rec, sort = generate_ground_truth_recording(num_channels=6, seed=2205) rec.set_property(key="group", values=[1, 2, 1, 1, 2, 2]) sort.set_property(key="group", values=[2, 2, 2, 1, 2, 2, 2, 1, 2, 1]) @@ -220,7 +223,7 @@ def test_create_by_dict(): unit_ids = sort.unit_ids split_sort = sort.split_by("group") split_rec = rec.split_by("group") - analyzer = create_sorting_analyzer(split_sort, split_rec) + analyzer = create_sorting_analyzer(split_sort, split_rec, sparse=False) analyzer_unit_ids = analyzer.unit_ids assert set(analyzer.unit_ids) == set(sort.unit_ids) @@ -236,7 +239,7 @@ def test_create_by_dict(): } with pytest.raises(ValueError): - analyzer = create_sorting_analyzer(split_sort_bad_keys, rec.split_by("group")) + analyzer = create_sorting_analyzer(split_sort_bad_keys, rec.split_by("group"), sparse=False) # make a dict of sortings, in a different order than the recording. This should # still work @@ -244,7 +247,7 @@ def test_create_by_dict(): 2: sort.select_units(unit_ids=unit_ids[sort.get_property("group") == 2]), 1: sort.select_units(unit_ids=unit_ids[sort.get_property("group") == 1]), } - combined_analyzer = create_sorting_analyzer(split_sort_different_order, rec.split_by("group")) + combined_analyzer = create_sorting_analyzer(split_sort_different_order, rec.split_by("group"), sparse=False) assert np.all(sort.get_unit_spike_train(unit_id="5") == combined_analyzer.sorting.get_unit_spike_train(unit_id="5")) @@ -715,12 +718,13 @@ def test_runtime_dependencies(dataset): if __name__ == "__main__": - tmp_path = Path("test_SortingAnalyzer") - dataset = get_dataset() - test_SortingAnalyzer_memory(tmp_path, dataset) - test_SortingAnalyzer_binary_folder(tmp_path, dataset) - test_SortingAnalyzer_zarr(tmp_path, dataset) - test_SortingAnalyzer_tmp_recording(dataset) - test_extension() - test_extension_params() - test_runtime_dependencies() + # tmp_path = Path("test_SortingAnalyzer") + # dataset = get_dataset() + # test_SortingAnalyzer_memory(tmp_path, dataset) + # test_SortingAnalyzer_binary_folder(tmp_path, dataset) + # test_SortingAnalyzer_zarr(tmp_path, dataset) + # test_SortingAnalyzer_tmp_recording(dataset) + # test_extension() + # test_extension_params() + # test_runtime_dependencies() + test_create_by_dict() diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index bbc797c693..c57af6d0d3 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -294,7 +294,10 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None, spikes["unit_index"] = spikes_group["unit_index"][:] for i, (start, end) in enumerate(segment_slices_list): spikes["segment_index"][start:end] = i - spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] + # we do not need to lexsort at init (very high cost) because there already sorted by frame before to be saved. + # During version 0.104.X this was fully lexsorted in the file but this is not annoying. + # spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] + self._cached_spike_vector = spikes for segment_index in range(num_segments):