11from __future__ import annotations
22
33from concurrent .futures import ThreadPoolExecutor
4- from dataclasses import dataclass , field
4+ from dataclasses import dataclass , field , replace
55from itertools import islice , pairwise
66from typing import TYPE_CHECKING , Any
77from warnings import warn
@@ -122,47 +122,78 @@ def __post_init__(self) -> None:
122122 bb_sync .append (bb_codec )
123123 self ._bb_codecs = tuple (bb_sync )
124124
125+ def _spec_for_shape (self , shape : tuple [int , ...]) -> ArraySpec :
126+ """Build an ArraySpec with the given shape, inheriting dtype/fill/config/prototype."""
127+ if shape == self ._ab_spec .shape :
128+ return self ._ab_spec
129+ return replace (self ._ab_spec , shape = shape )
130+
125131 def decode_chunk (
126132 self ,
127133 chunk_bytes : Buffer ,
134+ chunk_shape : tuple [int , ...] | None = None ,
128135 ) -> NDBuffer :
129136 """Decode a single chunk through the full codec chain, synchronously.
130137
131138 Pure compute -- no IO.
139+
140+ Parameters
141+ ----------
142+ chunk_bytes : Buffer
143+ The encoded chunk bytes.
144+ chunk_shape : tuple[int, ...] or None
145+ The shape of this chunk. If None, uses the shape from the
146+ ArraySpec provided at construction. Required for rectilinear
147+ grids where chunks have different shapes.
132148 """
149+ spec = self ._ab_spec if chunk_shape is None else self ._spec_for_shape (chunk_shape )
150+
133151 data : Buffer = chunk_bytes
134152 for bb_codec in reversed (self ._bb_codecs ):
135- data = bb_codec ._decode_sync (data , self . _ab_spec )
153+ data = bb_codec ._decode_sync (data , spec )
136154
137- chunk_array : NDBuffer = self ._ab_codec ._decode_sync (data , self . _ab_spec )
155+ chunk_array : NDBuffer = self ._ab_codec ._decode_sync (data , spec )
138156
139- for aa_codec , spec in reversed (self ._aa_codecs ):
140- chunk_array = aa_codec ._decode_sync (chunk_array , spec )
157+ for aa_codec , aa_spec in reversed (self ._aa_codecs ):
158+ aa_spec_resolved = aa_spec if chunk_shape is None else self ._spec_for_shape (chunk_shape )
159+ chunk_array = aa_codec ._decode_sync (chunk_array , aa_spec_resolved )
141160
142161 return chunk_array
143162
144163 def encode_chunk (
145164 self ,
146165 chunk_array : NDBuffer ,
166+ chunk_shape : tuple [int , ...] | None = None ,
147167 ) -> Buffer | None :
148168 """Encode a single chunk through the full codec chain, synchronously.
149169
150170 Pure compute -- no IO.
171+
172+ Parameters
173+ ----------
174+ chunk_array : NDBuffer
175+ The chunk data to encode.
176+ chunk_shape : tuple[int, ...] or None
177+ The shape of this chunk. If None, uses the shape from the
178+ ArraySpec provided at construction.
151179 """
180+ spec = self ._ab_spec if chunk_shape is None else self ._spec_for_shape (chunk_shape )
181+
152182 aa_data : NDBuffer = chunk_array
153- for aa_codec , spec in self ._aa_codecs :
154- aa_result = aa_codec ._encode_sync (aa_data , spec )
183+ for aa_codec , aa_spec in self ._aa_codecs :
184+ aa_spec_resolved = aa_spec if chunk_shape is None else self ._spec_for_shape (chunk_shape )
185+ aa_result = aa_codec ._encode_sync (aa_data , aa_spec_resolved )
155186 if aa_result is None :
156187 return None
157188 aa_data = aa_result
158189
159- ab_result = self ._ab_codec ._encode_sync (aa_data , self . _ab_spec )
190+ ab_result = self ._ab_codec ._encode_sync (aa_data , spec )
160191 if ab_result is None :
161192 return None
162193
163194 bb_data : Buffer = ab_result
164195 for bb_codec in self ._bb_codecs :
165- bb_result = bb_codec ._encode_sync (bb_data , self . _ab_spec )
196+ bb_result = bb_codec ._encode_sync (bb_data , spec )
166197 if bb_result is None :
167198 return None
168199 bb_data = bb_result
@@ -1104,7 +1135,7 @@ def _transform_read(
11041135 return self ._decode_shard (raw , chunk_spec , self .shard_layout )
11051136
11061137 assert self .chunk_transform is not None
1107- return self .chunk_transform .decode_chunk (raw )
1138+ return self .chunk_transform .decode_chunk (raw , chunk_shape = chunk_spec . shape )
11081139
11091140 def _decode_shard (self , blob : Buffer , shard_spec : ArraySpec , layout : ShardLayout ) -> NDBuffer :
11101141 """Decode a full shard blob into a shard-shaped array. Pure compute.
@@ -1163,14 +1194,18 @@ def _transform_write(
11631194
11641195 assert self .chunk_transform is not None
11651196
1197+ chunk_shape = chunk_spec .shape
1198+
11661199 if existing is not None :
1167- chunk_array : NDBuffer | None = self .chunk_transform .decode_chunk (existing )
1200+ chunk_array : NDBuffer | None = self .chunk_transform .decode_chunk (
1201+ existing , chunk_shape = chunk_shape
1202+ )
11681203 else :
11691204 chunk_array = None
11701205
11711206 if chunk_array is None :
11721207 chunk_array = chunk_spec .prototype .nd_buffer .create (
1173- shape = chunk_spec . shape ,
1208+ shape = chunk_shape ,
11741209 dtype = chunk_spec .dtype .to_native_dtype (),
11751210 fill_value = fill_value_or_default (chunk_spec ),
11761211 )
@@ -1188,7 +1223,7 @@ def _transform_write(
11881223 chunk_value = chunk_value [item ]
11891224 chunk_array [chunk_selection ] = chunk_value
11901225
1191- return self .chunk_transform .encode_chunk (chunk_array )
1226+ return self .chunk_transform .encode_chunk (chunk_array , chunk_shape = chunk_shape )
11921227
11931228 def _transform_write_shard (
11941229 self ,
0 commit comments