@@ -333,6 +333,12 @@ def __init__(
333333 # object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
334334 object .__setattr__ (self , "_get_index_chunk_spec" , lru_cache ()(self ._get_index_chunk_spec ))
335335 object .__setattr__ (self , "_get_chunks_per_shard" , lru_cache ()(self ._get_chunks_per_shard ))
336+ object .__setattr__ (
337+ self , "_get_inner_chunk_transform" , lru_cache ()(self ._get_inner_chunk_transform )
338+ )
339+ object .__setattr__ (
340+ self , "_get_index_chunk_transform" , lru_cache ()(self ._get_index_chunk_transform )
341+ )
336342
337343 # todo: typedict return type
338344 def __getstate__ (self ) -> dict [str , Any ]:
@@ -349,6 +355,12 @@ def __setstate__(self, state: dict[str, Any]) -> None:
349355 # object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
350356 object .__setattr__ (self , "_get_index_chunk_spec" , lru_cache ()(self ._get_index_chunk_spec ))
351357 object .__setattr__ (self , "_get_chunks_per_shard" , lru_cache ()(self ._get_chunks_per_shard ))
358+ object .__setattr__ (
359+ self , "_get_inner_chunk_transform" , lru_cache ()(self ._get_inner_chunk_transform )
360+ )
361+ object .__setattr__ (
362+ self , "_get_index_chunk_transform" , lru_cache ()(self ._get_index_chunk_transform )
363+ )
352364
353365 @classmethod
354366 def from_dict (cls , data : dict [str , JSON ]) -> Self :
@@ -403,6 +415,160 @@ def validate(
403415 f"needs to be divisible by the shard's inner `chunk_shape` (got { self .chunk_shape } )."
404416 )
405417
418+ def _get_inner_chunk_transform (self , shard_spec : ArraySpec ) -> Any :
419+ """Build a ChunkTransform for inner codecs, bound to the inner chunk spec."""
420+ from zarr .core .codec_pipeline import ChunkTransform
421+
422+ chunk_spec = self ._get_chunk_spec (shard_spec )
423+ evolved = tuple (c .evolve_from_array_spec (array_spec = chunk_spec ) for c in self .codecs )
424+ return ChunkTransform (codecs = evolved , array_spec = chunk_spec )
425+
426+ def _get_index_chunk_transform (self , chunks_per_shard : tuple [int , ...]) -> Any :
427+ """Build a ChunkTransform for index codecs."""
428+ from zarr .core .codec_pipeline import ChunkTransform
429+
430+ index_spec = self ._get_index_chunk_spec (chunks_per_shard )
431+ evolved = tuple (c .evolve_from_array_spec (array_spec = index_spec ) for c in self .index_codecs )
432+ return ChunkTransform (codecs = evolved , array_spec = index_spec )
433+
434+ def _decode_shard_index_sync (
435+ self , index_bytes : Buffer , chunks_per_shard : tuple [int , ...]
436+ ) -> _ShardIndex :
437+ """Decode shard index synchronously using ChunkTransform."""
438+ index_transform = self ._get_index_chunk_transform (chunks_per_shard )
439+ index_array = index_transform .decode_chunk (index_bytes )
440+ return _ShardIndex (index_array .as_numpy_array ())
441+
442+ def _encode_shard_index_sync (self , index : _ShardIndex ) -> Buffer :
443+ """Encode shard index synchronously using ChunkTransform."""
444+ index_transform = self ._get_index_chunk_transform (index .chunks_per_shard )
445+ index_nd = get_ndbuffer_class ().from_numpy_array (index .offsets_and_lengths )
446+ result = index_transform .encode_chunk (index_nd )
447+ assert result is not None
448+ return result
449+
450+ def _shard_reader_from_bytes_sync (
451+ self , buf : Buffer , chunks_per_shard : tuple [int , ...]
452+ ) -> _ShardReader :
453+ """Sync version of _ShardReader.from_bytes."""
454+ shard_index_size = self ._shard_index_size (chunks_per_shard )
455+ if self .index_location == ShardingCodecIndexLocation .start :
456+ shard_index_bytes = buf [:shard_index_size ]
457+ else :
458+ shard_index_bytes = buf [- shard_index_size :]
459+ index = self ._decode_shard_index_sync (shard_index_bytes , chunks_per_shard )
460+ reader = _ShardReader ()
461+ reader .buf = buf
462+ reader .index = index
463+ return reader
464+
465+ def _decode_sync (
466+ self ,
467+ shard_bytes : Buffer ,
468+ shard_spec : ArraySpec ,
469+ ) -> NDBuffer :
470+ """Decode a full shard synchronously."""
471+ shard_shape = shard_spec .shape
472+ chunk_shape = self .chunk_shape
473+ chunks_per_shard = self ._get_chunks_per_shard (shard_spec )
474+ chunk_spec = self ._get_chunk_spec (shard_spec )
475+ inner_transform = self ._get_inner_chunk_transform (shard_spec )
476+
477+ indexer = BasicIndexer (
478+ tuple (slice (0 , s ) for s in shard_shape ),
479+ shape = shard_shape ,
480+ chunk_grid = RegularChunkGrid (chunk_shape = chunk_shape ),
481+ )
482+
483+ out = chunk_spec .prototype .nd_buffer .empty (
484+ shape = shard_shape ,
485+ dtype = shard_spec .dtype .to_native_dtype (),
486+ order = shard_spec .order ,
487+ )
488+
489+ shard_dict = self ._shard_reader_from_bytes_sync (shard_bytes , chunks_per_shard )
490+
491+ if shard_dict .index .is_all_empty ():
492+ out .fill (shard_spec .fill_value )
493+ return out
494+
495+ for chunk_coords , chunk_selection , out_selection , _ in indexer :
496+ try :
497+ chunk_bytes = shard_dict [chunk_coords ]
498+ except KeyError :
499+ out [out_selection ] = shard_spec .fill_value
500+ continue
501+ chunk_array = inner_transform .decode_chunk (chunk_bytes )
502+ out [out_selection ] = chunk_array [chunk_selection ]
503+
504+ return out
505+
506+ def _encode_sync (
507+ self ,
508+ shard_array : NDBuffer ,
509+ shard_spec : ArraySpec ,
510+ ) -> Buffer | None :
511+ """Encode a full shard synchronously."""
512+ shard_shape = shard_spec .shape
513+ chunks_per_shard = self ._get_chunks_per_shard (shard_spec )
514+ inner_transform = self ._get_inner_chunk_transform (shard_spec )
515+
516+ indexer = BasicIndexer (
517+ tuple (slice (0 , s ) for s in shard_shape ),
518+ shape = shard_shape ,
519+ chunk_grid = RegularChunkGrid (chunk_shape = self .chunk_shape ),
520+ )
521+
522+ shard_builder : dict [tuple [int , ...], Buffer | None ] = dict .fromkeys (
523+ morton_order_iter (chunks_per_shard )
524+ )
525+
526+ for chunk_coords , chunk_selection , out_selection , _ in indexer :
527+ chunk_array = shard_array [out_selection ]
528+ encoded = inner_transform .encode_chunk (chunk_array )
529+ shard_builder [chunk_coords ] = encoded
530+
531+ return self ._encode_shard_dict_sync (
532+ shard_builder ,
533+ chunks_per_shard = chunks_per_shard ,
534+ buffer_prototype = default_buffer_prototype (),
535+ )
536+
537+ def _encode_shard_dict_sync (
538+ self ,
539+ shard_dict : ShardMapping ,
540+ chunks_per_shard : tuple [int , ...],
541+ buffer_prototype : BufferPrototype ,
542+ ) -> Buffer | None :
543+ """Sync version of _encode_shard_dict."""
544+ index = _ShardIndex .create_empty (chunks_per_shard )
545+ buffers = []
546+ template = buffer_prototype .buffer .create_zero_length ()
547+ chunk_start = 0
548+
549+ for chunk_coords in morton_order_iter (chunks_per_shard ):
550+ value = shard_dict .get (chunk_coords )
551+ if value is None or len (value ) == 0 :
552+ continue
553+ chunk_length = len (value )
554+ buffers .append (value )
555+ index .set_chunk_slice (chunk_coords , slice (chunk_start , chunk_start + chunk_length ))
556+ chunk_start += chunk_length
557+
558+ if len (buffers ) == 0 :
559+ return None
560+
561+ index_bytes = self ._encode_shard_index_sync (index )
562+ if self .index_location == ShardingCodecIndexLocation .start :
563+ empty_chunks_mask = index .offsets_and_lengths [..., 0 ] == MAX_UINT_64
564+ index .offsets_and_lengths [~ empty_chunks_mask , 0 ] += len (index_bytes )
565+ index_bytes = self ._encode_shard_index_sync (index )
566+ buffers .insert (0 , index_bytes )
567+ else :
568+ buffers .append (index_bytes )
569+
570+ return template .combine (buffers )
571+
406572 async def _decode_single (
407573 self ,
408574 shard_bytes : Buffer ,
0 commit comments