Skip to content

Commit 47a407f

Browse files
committed
feat: new codec pipeline that uses sync path
1 parent a072c31 commit 47a407f

5 files changed

Lines changed: 1011 additions & 0 deletions

File tree

src/zarr/abc/codec.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def decode_chunk(self, chunk_bytes: Buffer) -> NDBuffer: ...
9999
def encode_chunk(self, chunk_array: NDBuffer) -> Buffer | None: ...
100100

101101

102+
@runtime_checkable
102103
class SupportsChunkPacking(Protocol):
103104
"""Protocol for codecs that can pack/unpack inner chunks into a storage blob
104105
and manage the prepare/finalize IO lifecycle.

src/zarr/codecs/sharding.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,12 @@ def __init__(
333333
# object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
334334
object.__setattr__(self, "_get_index_chunk_spec", lru_cache()(self._get_index_chunk_spec))
335335
object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard))
336+
object.__setattr__(
337+
self, "_get_inner_chunk_transform", lru_cache()(self._get_inner_chunk_transform)
338+
)
339+
object.__setattr__(
340+
self, "_get_index_chunk_transform", lru_cache()(self._get_index_chunk_transform)
341+
)
336342

337343
# todo: typedict return type
338344
def __getstate__(self) -> dict[str, Any]:
@@ -349,6 +355,12 @@ def __setstate__(self, state: dict[str, Any]) -> None:
349355
# object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
350356
object.__setattr__(self, "_get_index_chunk_spec", lru_cache()(self._get_index_chunk_spec))
351357
object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard))
358+
object.__setattr__(
359+
self, "_get_inner_chunk_transform", lru_cache()(self._get_inner_chunk_transform)
360+
)
361+
object.__setattr__(
362+
self, "_get_index_chunk_transform", lru_cache()(self._get_index_chunk_transform)
363+
)
352364

353365
@classmethod
354366
def from_dict(cls, data: dict[str, JSON]) -> Self:
@@ -403,6 +415,160 @@ def validate(
403415
f"needs to be divisible by the shard's inner `chunk_shape` (got {self.chunk_shape})."
404416
)
405417

418+
def _get_inner_chunk_transform(self, shard_spec: ArraySpec) -> Any:
    """Return a ChunkTransform bound to this shard's inner codec chain.

    The transform is built against the inner-chunk ArraySpec derived from
    *shard_spec*; callers see memoized results because this method is
    wrapped in an instance-level ``lru_cache``.
    """
    # Imported locally; presumably to avoid a circular import with the
    # pipeline module — TODO confirm.
    from zarr.core.codec_pipeline import ChunkTransform

    inner_spec = self._get_chunk_spec(shard_spec)
    bound = []
    for codec in self.codecs:
        bound.append(codec.evolve_from_array_spec(array_spec=inner_spec))
    return ChunkTransform(codecs=tuple(bound), array_spec=inner_spec)
425+
426+
def _get_index_chunk_transform(self, chunks_per_shard: tuple[int, ...]) -> Any:
    """Return a ChunkTransform bound to the shard-index codec chain.

    Bound to the index ArraySpec for *chunks_per_shard*; memoized via the
    instance-level ``lru_cache`` installed on this method.
    """
    # Imported locally; presumably to avoid a circular import with the
    # pipeline module — TODO confirm.
    from zarr.core.codec_pipeline import ChunkTransform

    spec = self._get_index_chunk_spec(chunks_per_shard)
    bound = tuple(
        codec.evolve_from_array_spec(array_spec=spec) for codec in self.index_codecs
    )
    return ChunkTransform(codecs=bound, array_spec=spec)
433+
434+
def _decode_shard_index_sync(
    self, index_bytes: Buffer, chunks_per_shard: tuple[int, ...]
) -> _ShardIndex:
    """Synchronously decode *index_bytes* into a ``_ShardIndex``.

    Uses the cached index ChunkTransform for *chunks_per_shard*.
    """
    transform = self._get_index_chunk_transform(chunks_per_shard)
    decoded = transform.decode_chunk(index_bytes)
    # _ShardIndex wraps the raw offsets/lengths array, hence the
    # conversion to a NumPy array.
    return _ShardIndex(decoded.as_numpy_array())
441+
442+
def _encode_shard_index_sync(self, index: _ShardIndex) -> Buffer:
    """Synchronously encode *index* into its on-disk byte form.

    The index codec chain is expected to always produce bytes; a ``None``
    result would indicate a misconfigured chain, hence the assert.
    """
    transform = self._get_index_chunk_transform(index.chunks_per_shard)
    as_ndbuffer = get_ndbuffer_class().from_numpy_array(index.offsets_and_lengths)
    encoded = transform.encode_chunk(as_ndbuffer)
    assert encoded is not None
    return encoded
449+
450+
def _shard_reader_from_bytes_sync(
    self, buf: Buffer, chunks_per_shard: tuple[int, ...]
) -> _ShardReader:
    """Synchronous counterpart of ``_ShardReader.from_bytes``.

    Slices the encoded shard index off the start or end of *buf*
    (depending on ``self.index_location``), decodes it, and returns a
    ``_ShardReader`` holding both the raw buffer and the decoded index.
    """
    index_size = self._shard_index_size(chunks_per_shard)
    at_start = self.index_location == ShardingCodecIndexLocation.start
    index_bytes = buf[:index_size] if at_start else buf[-index_size:]
    reader = _ShardReader()
    reader.buf = buf
    reader.index = self._decode_shard_index_sync(index_bytes, chunks_per_shard)
    return reader
464+
465+
def _decode_sync(
    self,
    shard_bytes: Buffer,
    shard_spec: ArraySpec,
) -> NDBuffer:
    """Decode a full shard synchronously.

    Parameters
    ----------
    shard_bytes : Buffer
        The encoded shard blob (inner chunks plus shard index).
    shard_spec : ArraySpec
        Spec describing the full shard (shape, dtype, order, fill value).

    Returns
    -------
    NDBuffer
        A buffer of ``shard_spec.shape``; chunks missing from the shard
        are filled with ``shard_spec.fill_value``.
    """
    shard_shape = shard_spec.shape
    chunk_shape = self.chunk_shape
    chunks_per_shard = self._get_chunks_per_shard(shard_spec)
    chunk_spec = self._get_chunk_spec(shard_spec)
    inner_transform = self._get_inner_chunk_transform(shard_spec)

    # Iterate every inner chunk of the shard by selecting the full
    # extent on each axis.
    indexer = BasicIndexer(
        tuple(slice(0, s) for s in shard_shape),
        shape=shard_shape,
        chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
    )

    # Output buffer for the whole shard; allocated via the inner chunk
    # spec's prototype so decoded chunks land in a compatible buffer type.
    out = chunk_spec.prototype.nd_buffer.empty(
        shape=shard_shape,
        dtype=shard_spec.dtype.to_native_dtype(),
        order=shard_spec.order,
    )

    shard_dict = self._shard_reader_from_bytes_sync(shard_bytes, chunks_per_shard)

    # Fast path: no chunk is present, so the whole shard is fill value.
    if shard_dict.index.is_all_empty():
        out.fill(shard_spec.fill_value)
        return out

    for chunk_coords, chunk_selection, out_selection, _ in indexer:
        # EAFP: a missing chunk raises KeyError and becomes fill value.
        try:
            chunk_bytes = shard_dict[chunk_coords]
        except KeyError:
            out[out_selection] = shard_spec.fill_value
            continue
        chunk_array = inner_transform.decode_chunk(chunk_bytes)
        # chunk_selection trims partial edge chunks down to the region
        # that actually belongs in the shard.
        out[out_selection] = chunk_array[chunk_selection]

    return out
505+
506+
def _encode_sync(
    self,
    shard_array: NDBuffer,
    shard_spec: ArraySpec,
) -> Buffer | None:
    """Synchronously encode *shard_array* into a single shard blob.

    Each inner chunk is encoded with the inner codec chain, then the
    encoded chunks and the shard index are assembled by
    ``_encode_shard_dict_sync``.  Returns ``None`` when that assembly
    yields nothing (every chunk empty).
    """
    chunks_per_shard = self._get_chunks_per_shard(shard_spec)
    transform = self._get_inner_chunk_transform(shard_spec)

    # Walk every inner chunk by selecting the shard's full extent.
    full_selection = tuple(slice(0, dim) for dim in shard_spec.shape)
    indexer = BasicIndexer(
        full_selection,
        shape=shard_spec.shape,
        chunk_grid=RegularChunkGrid(chunk_shape=self.chunk_shape),
    )

    # Pre-seed every chunk slot in Morton order with None so absent /
    # empty chunks are represented explicitly for the dict encoder.
    encoded_chunks: dict[tuple[int, ...], Buffer | None] = dict.fromkeys(
        morton_order_iter(chunks_per_shard)
    )
    for chunk_coords, _chunk_selection, out_selection, _ in indexer:
        encoded_chunks[chunk_coords] = transform.encode_chunk(shard_array[out_selection])

    return self._encode_shard_dict_sync(
        encoded_chunks,
        chunks_per_shard=chunks_per_shard,
        buffer_prototype=default_buffer_prototype(),
    )
536+
537+
def _encode_shard_dict_sync(
    self,
    shard_dict: ShardMapping,
    chunks_per_shard: tuple[int, ...],
    buffer_prototype: BufferPrototype,
) -> Buffer | None:
    """Sync version of _encode_shard_dict.

    Lays out the non-empty chunk buffers from *shard_dict* in Morton
    order, records their (offset, length) slices in a shard index, and
    concatenates chunks plus encoded index into one blob.  Returns
    ``None`` when every chunk is empty.
    """
    index = _ShardIndex.create_empty(chunks_per_shard)
    buffers = []
    # Zero-length buffer of the requested prototype; used only as the
    # object whose ``combine`` assembles the final blob.
    template = buffer_prototype.buffer.create_zero_length()
    chunk_start = 0

    # Chunk payloads are laid out in Morton order; offsets are relative
    # to the start of the chunk region (adjusted below if the index is
    # stored first).
    for chunk_coords in morton_order_iter(chunks_per_shard):
        value = shard_dict.get(chunk_coords)
        # Skip absent or zero-length chunks; their index slots keep the
        # "empty" sentinel from create_empty.
        if value is None or len(value) == 0:
            continue
        chunk_length = len(value)
        buffers.append(value)
        index.set_chunk_slice(chunk_coords, slice(chunk_start, chunk_start + chunk_length))
        chunk_start += chunk_length

    if len(buffers) == 0:
        return None

    index_bytes = self._encode_shard_index_sync(index)
    if self.index_location == ShardingCodecIndexLocation.start:
        # Index precedes the chunk data, so shift every non-empty
        # chunk's offset by the index size, then re-encode the index
        # with the corrected offsets.
        # NOTE(review): this assumes the re-encoded index has the same
        # byte length as the first encoding (true for fixed-size index
        # codecs) — confirm for compressed index codec chains.
        empty_chunks_mask = index.offsets_and_lengths[..., 0] == MAX_UINT_64
        index.offsets_and_lengths[~empty_chunks_mask, 0] += len(index_bytes)
        index_bytes = self._encode_shard_index_sync(index)
        buffers.insert(0, index_bytes)
    else:
        buffers.append(index_bytes)

    return template.combine(buffers)
571+
406572
async def _decode_single(
407573
self,
408574
shard_bytes: Buffer,

0 commit comments

Comments
 (0)