Skip to content

Commit 8330cde

Browse files
committed
Merge branch 'perf/prepared-write-v2' of github.com:d-v-b/zarr-python into perf/prepared-write-v2-bench
2 parents a18b20f + 5fb28b9 commit 8330cde

1 file changed

Lines changed: 84 additions & 27 deletions

File tree

src/zarr/core/codec_pipeline.py

Lines changed: 84 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -763,6 +763,13 @@ class ChunkLayout:
763763
def is_sharded(self) -> bool:
764764
return False
765765

766+
def needed_coords(self, chunk_selection: SelectorTuple) -> set[tuple[int, ...]] | None:
767+
"""Compute which inner chunk coordinates overlap a selection.
768+
769+
Returns ``None`` for trivial layouts (only one inner chunk).
770+
"""
771+
return None
772+
766773
def unpack_blob(self, blob: Buffer) -> dict[tuple[int, ...], Buffer | None]:
767774
raise NotImplementedError
768775

@@ -771,18 +778,31 @@ def pack_blob(
771778
) -> Buffer | None:
772779
raise NotImplementedError
773780

774-
async def fetch_full_shard(
775-
self, byte_getter: Any
781+
async def fetch(
782+
self,
783+
byte_getter: Any,
784+
needed_coords: set[tuple[int, ...]] | None = None,
776785
) -> dict[tuple[int, ...], Buffer | None] | None:
777-
"""Fetch all inner chunk buffers. IO phase.
786+
"""Fetch inner chunk buffers from the store. IO phase.
778787
779-
For non-sharded, fetches the full blob. For sharded, fetches the
780-
index and then the needed inner chunks via byte-range reads.
788+
Parameters
789+
----------
790+
byte_getter
791+
The store path to read from.
792+
needed_coords
793+
The set of inner chunk coordinates to fetch. ``None`` means all.
794+
795+
Returns
796+
-------
797+
A mapping from inner chunk coordinates to their raw bytes, or
798+
``None`` if the blob/shard does not exist in the store.
781799
"""
782800
raise NotImplementedError
783801

784-
def fetch_full_shard_sync(
785-
self, byte_getter: Any
802+
def fetch_sync(
803+
self,
804+
byte_getter: Any,
805+
needed_coords: set[tuple[int, ...]] | None = None,
786806
) -> dict[tuple[int, ...], Buffer | None] | None:
787807
raise NotImplementedError
788808

@@ -806,8 +826,10 @@ def pack_blob(
806826
key = (0,) * len(self.chunks_per_shard)
807827
return chunk_dict.get(key)
808828

809-
async def fetch_full_shard(
810-
self, byte_getter: Any
829+
async def fetch(
830+
self,
831+
byte_getter: Any,
832+
needed_coords: set[tuple[int, ...]] | None = None,
811833
) -> dict[tuple[int, ...], Buffer | None] | None:
812834
from zarr.core.buffer import default_buffer_prototype
813835

@@ -816,8 +838,10 @@ async def fetch_full_shard(
816838
return None
817839
return self.unpack_blob(blob)
818840

819-
def fetch_full_shard_sync(
820-
self, byte_getter: Any
841+
def fetch_sync(
842+
self,
843+
byte_getter: Any,
844+
needed_coords: set[tuple[int, ...]] | None = None,
821845
) -> dict[tuple[int, ...], Buffer | None] | None:
822846
from zarr.core.buffer import default_buffer_prototype
823847

@@ -843,6 +867,19 @@ class ShardedChunkLayout(ChunkLayout):
843867

844868
chunk_shape: tuple[int, ...]
845869
inner_chunk_shape: tuple[int, ...]
870+
871+
def needed_coords(self, chunk_selection: SelectorTuple) -> set[tuple[int, ...]] | None:
872+
"""Compute which inner chunks overlap the selection."""
873+
from zarr.core.chunk_grids import ChunkGrid as _ChunkGrid
874+
from zarr.core.indexing import get_indexer
875+
876+
indexer = get_indexer(
877+
chunk_selection,
878+
shape=self.chunk_shape,
879+
chunk_grid=_ChunkGrid.from_sizes(self.chunk_shape, self.inner_chunk_shape),
880+
)
881+
return {coords for coords, *_ in indexer}
882+
846883
chunks_per_shard: tuple[int, ...]
847884
inner_transform: ChunkTransform
848885
_index_transform: ChunkTransform
@@ -919,24 +956,36 @@ def pack_blob(
919956

920957
return template.combine(buffers)
921958

922-
async def fetch_full_shard(
923-
self, byte_getter: Any
959+
async def fetch(
960+
self,
961+
byte_getter: Any,
962+
needed_coords: set[tuple[int, ...]] | None = None,
924963
) -> dict[tuple[int, ...], Buffer | None] | None:
925-
"""Fetch shard index + all inner chunks via byte-range reads."""
964+
"""Fetch shard index + inner chunks via byte-range reads.
965+
966+
If ``needed_coords`` is None, fetches all inner chunks.
967+
Otherwise fetches only the specified coordinates.
968+
"""
926969
index = await self._fetch_index(byte_getter)
927970
if index is None:
928971
return None
929-
all_coords = set(np.ndindex(self.chunks_per_shard))
930-
return await self._fetch_chunks(byte_getter, index, all_coords)
972+
coords = (
973+
needed_coords if needed_coords is not None else set(np.ndindex(self.chunks_per_shard))
974+
)
975+
return await self._fetch_chunks(byte_getter, index, coords)
931976

932-
def fetch_full_shard_sync(
933-
self, byte_getter: Any
977+
def fetch_sync(
978+
self,
979+
byte_getter: Any,
980+
needed_coords: set[tuple[int, ...]] | None = None,
934981
) -> dict[tuple[int, ...], Buffer | None] | None:
935982
index = self._fetch_index_sync(byte_getter)
936983
if index is None:
937984
return None
938-
all_coords = set(np.ndindex(self.chunks_per_shard))
939-
return self._fetch_chunks_sync(byte_getter, index, all_coords)
985+
coords = (
986+
needed_coords if needed_coords is not None else set(np.ndindex(self.chunks_per_shard))
987+
)
988+
return self._fetch_chunks_sync(byte_getter, index, coords)
940989

941990
async def _fetch_index(self, byte_getter: Any) -> Any:
942991
from zarr.abc.store import RangeByteRequest, SuffixByteRequest
@@ -1512,14 +1561,16 @@ async def _fetch_and_decode(
15121561
self,
15131562
byte_getter: Any,
15141563
chunk_spec: ArraySpec,
1564+
chunk_selection: SelectorTuple,
15151565
layout: ChunkLayout,
15161566
) -> NDBuffer | None:
1517-
"""IO + compute: fetch all inner chunk buffers, then decode into chunk-shaped array.
1567+
"""IO + compute: fetch inner chunk buffers, then decode into chunk-shaped array.
15181568
1519-
1. IO: ``layout.fetch_full_shard`` fetches the blob or byte-ranges
1569+
1. IO: ``layout.fetch`` fetches only the inner chunks that overlap the selection
15201570
2. Compute: decode each inner chunk and assemble into chunk-shaped output
15211571
"""
1522-
chunk_dict = await layout.fetch_full_shard(byte_getter)
1572+
needed = layout.needed_coords(chunk_selection)
1573+
chunk_dict = await layout.fetch(byte_getter, needed_coords=needed)
15231574
if chunk_dict is None:
15241575
return None
15251576
return self._decode_shard(chunk_dict, chunk_spec, layout)
@@ -1538,7 +1589,10 @@ async def read(
15381589
# Sharded: use selective byte-range reads per shard
15391590
decoded: list[NDBuffer | None] = list(
15401591
await concurrent_map(
1541-
[(bg, cs, self._get_layout(cs)) for bg, cs, *_ in batch],
1592+
[
1593+
(bg, cs, chunk_sel, self._get_layout(cs))
1594+
for bg, cs, chunk_sel, _, _ in batch
1595+
],
15421596
self._fetch_and_decode,
15431597
config.get("async.concurrency"),
15441598
)
@@ -1634,10 +1688,12 @@ def _fetch_and_decode_sync(
16341688
self,
16351689
byte_getter: Any,
16361690
chunk_spec: ArraySpec,
1691+
chunk_selection: SelectorTuple,
16371692
layout: ChunkLayout,
16381693
) -> NDBuffer | None:
1639-
"""Sync IO + compute: fetch all inner chunk buffers, then decode."""
1640-
chunk_dict = layout.fetch_full_shard_sync(byte_getter)
1694+
"""Sync IO + compute: fetch inner chunk buffers, then decode."""
1695+
needed = layout.needed_coords(chunk_selection)
1696+
chunk_dict = layout.fetch_sync(byte_getter, needed_coords=needed)
16411697
if chunk_dict is None:
16421698
return None
16431699
return self._decode_shard(chunk_dict, chunk_spec, layout)
@@ -1657,7 +1713,8 @@ def read_sync(
16571713
if self.layout is not None and self.layout.is_sharded:
16581714
# Sharded: selective byte-range reads per shard
16591715
decoded: list[NDBuffer | None] = [
1660-
self._fetch_and_decode_sync(bg, cs, self._get_layout(cs)) for bg, cs, *_ in batch
1716+
self._fetch_and_decode_sync(bg, cs, chunk_sel, self._get_layout(cs))
1717+
for bg, cs, chunk_sel, _, _ in batch
16611718
]
16621719
else:
16631720
# Non-sharded: fetch full blobs, decode (optionally threaded)

0 commit comments

Comments (0)