Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,8 @@ def _compute_and_cache_spike_vector(self) -> None:
spikes.append(spikes_in_seg)

spikes = np.concatenate(spikes)
spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
# the spikes are not lexsorted here because the previous loop ensure that the spike vector is constructucted alway the same way.
# spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]

self._cached_spike_vector = spikes
self._cached_spike_vector_segment_slices = segment_slices
Expand Down
10 changes: 7 additions & 3 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def generate_sorting(
spikes_in_seg["sample_index"] = samples
spikes_in_seg["unit_index"] = labels
spikes_in_seg["segment_index"] = segment_index
spikes.append(spikes_in_seg)

if add_spikes_on_borders:
spikes_on_borders = np.zeros(2 * num_spikes_per_border, dtype=minimum_spike_dtype)
Expand All @@ -182,10 +181,15 @@ def generate_sorting(
spikes_on_borders["sample_index"][num_spikes_per_border:] = rng.integers(
num_samples - border_size_samples, num_samples, num_spikes_per_border
)
spikes.append(spikes_on_borders)
spikes_in_seg = np.concatenate([spikes_in_seg, spikes_on_borders])
order = np.argsort(spikes_in_seg["sample_index"])
spikes_in_seg = spikes_in_seg[order]

spikes.append(spikes_in_seg)

spikes = np.concatenate(spikes)
spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
# the spikes do not need a full lexsort because synthesize_poisson_spike_vector() garanty spikes to sorted by frame already
# spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]

sorting = NumpySorting(spikes, sampling_frequency, unit_ids)

Expand Down
11 changes: 10 additions & 1 deletion src/spikeinterface/core/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,23 @@ def check_recordings_equal(


def check_sortings_equal(
SX1: BaseSorting, SX2: BaseSorting, check_annotations: bool = False, check_properties: bool = False
SX1: BaseSorting,
SX2: BaseSorting,
check_annotations: bool = False,
check_properties: bool = False,
check_exact_lexsort: bool = True,
) -> None:
assert SX1.get_num_segments() == SX2.get_num_segments()

max_spike_index = SX1.to_spike_vector()["sample_index"].max()

s1 = SX1.to_spike_vector()
s2 = SX2.to_spike_vector()
if not check_exact_lexsort:
# 2 sorting can be equal even if the internal lexsort is not the same.
# spiketrains still wiwll be the same per units
s1 = s1[np.lexsort((s1["unit_index"], s1["sample_index"], s1["segment_index"]))]
s2 = s2[np.lexsort((s2["unit_index"], s2["sample_index"], s2["segment_index"]))]
assert_array_equal(s1, s2)

for start_frame, end_frame in [
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,8 @@ def test_synthesize_random_firings_length():
# test_generate_recording()
# test_generate_single_fake_waveform()
# test_transformsorting()
test_generate_unit_locations()
# test_generate_unit_locations()
# test_generate_templates()
# test_inject_templates()
# test_generate_ground_truth_recording()
# test_generate_sorting_with_spikes_on_borders()
test_generate_sorting_with_spikes_on_borders()
15 changes: 10 additions & 5 deletions src/spikeinterface/core/tests/test_sorting_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,17 @@ def test_NpzFolderSorting(create_cache_folder):
NpzFolderSorting.write_sorting(sorting, folder)

sorting_loaded = NpzFolderSorting(folder)
check_sortings_equal(sorting_loaded, sorting)
# the NpzFolderSorting is a by unit storage and te lexsort is not maintain always so check_exact_lexsort=False
check_sortings_equal(sorting_loaded, sorting, check_exact_lexsort=False)
assert np.array_equal(sorting_loaded.unit_ids, sorting.unit_ids)
assert np.array_equal(
sorting_loaded.to_spike_vector(),
sorting.to_spike_vector(),
)

# Note changing the class do not necessarily maintain the internal internal.
# but the vectors should be the same after lexsort
s1 = sorting_loaded.to_spike_vector()
s2 = sorting.to_spike_vector()
s1 = s1[np.lexsort((s1["unit_index"], s1["sample_index"], s1["segment_index"]))]
s2 = s2[np.lexsort((s2["unit_index"], s2["sample_index"], s2["segment_index"]))]
assert np.array_equal(s1, s2)


if __name__ == "__main__":
Expand Down
30 changes: 17 additions & 13 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,17 +210,20 @@ def test_create_by_dict():
Interally, this aggregates the dicts of recordings and sortings. This test checks that the
unit structure is maintained from the dicts to the analyzer. Then checks that the function
fails if the dict keys are different for the recordings and the sortings.

Note, in this tests sparse is False because units are randomlly assign to differents of the
recording and they can have no channels
"""

rec, sort = generate_ground_truth_recording(num_channels=6)
rec, sort = generate_ground_truth_recording(num_channels=6, seed=2205)

rec.set_property(key="group", values=[1, 2, 1, 1, 2, 2])
sort.set_property(key="group", values=[2, 2, 2, 1, 2, 2, 2, 1, 2, 1])

unit_ids = sort.unit_ids
split_sort = sort.split_by("group")
split_rec = rec.split_by("group")
analyzer = create_sorting_analyzer(split_sort, split_rec)
analyzer = create_sorting_analyzer(split_sort, split_rec, sparse=False)
analyzer_unit_ids = analyzer.unit_ids

assert set(analyzer.unit_ids) == set(sort.unit_ids)
Expand All @@ -236,15 +239,15 @@ def test_create_by_dict():
}

with pytest.raises(ValueError):
analyzer = create_sorting_analyzer(split_sort_bad_keys, rec.split_by("group"))
analyzer = create_sorting_analyzer(split_sort_bad_keys, rec.split_by("group"), sparse=False)

# make a dict of sortings, in a different order than the recording. This should
# still work
split_sort_different_order = {
2: sort.select_units(unit_ids=unit_ids[sort.get_property("group") == 2]),
1: sort.select_units(unit_ids=unit_ids[sort.get_property("group") == 1]),
}
combined_analyzer = create_sorting_analyzer(split_sort_different_order, rec.split_by("group"))
combined_analyzer = create_sorting_analyzer(split_sort_different_order, rec.split_by("group"), sparse=False)
assert np.all(sort.get_unit_spike_train(unit_id="5") == combined_analyzer.sorting.get_unit_spike_train(unit_id="5"))


Expand Down Expand Up @@ -715,12 +718,13 @@ def test_runtime_dependencies(dataset):


if __name__ == "__main__":
tmp_path = Path("test_SortingAnalyzer")
dataset = get_dataset()
test_SortingAnalyzer_memory(tmp_path, dataset)
test_SortingAnalyzer_binary_folder(tmp_path, dataset)
test_SortingAnalyzer_zarr(tmp_path, dataset)
test_SortingAnalyzer_tmp_recording(dataset)
test_extension()
test_extension_params()
test_runtime_dependencies()
# tmp_path = Path("test_SortingAnalyzer")
# dataset = get_dataset()
# test_SortingAnalyzer_memory(tmp_path, dataset)
# test_SortingAnalyzer_binary_folder(tmp_path, dataset)
# test_SortingAnalyzer_zarr(tmp_path, dataset)
# test_SortingAnalyzer_tmp_recording(dataset)
# test_extension()
# test_extension_params()
# test_runtime_dependencies()
test_create_by_dict()
5 changes: 4 additions & 1 deletion src/spikeinterface/core/zarrextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,10 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None,
spikes["unit_index"] = spikes_group["unit_index"][:]
for i, (start, end) in enumerate(segment_slices_list):
spikes["segment_index"][start:end] = i
spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
# we do not need to lexsort at init (very high cost) because there already sorted by frame before to be saved.
# During version 0.104.X this was fully lexsorted in the file but this is not annoying.
# spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]

self._cached_spike_vector = spikes

for segment_index in range(num_segments):
Expand Down
Loading