Skip to content

Commit c731cf2

Browse files
committed
fix: handle rectilinear chunks
1 parent 9b834a4 commit c731cf2

2 files changed

Lines changed: 54 additions & 14 deletions

File tree

src/zarr/core/array.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,15 @@ def create_codec_pipeline(metadata: ArrayMetadata, *, store: Store | None = None
234234
if hasattr(pipeline, "chunk_transform") and pipeline.chunk_transform is None:
235235
from zarr.core.metadata.v3 import RegularChunkGridMetadata
236236

237+
# Use the regular chunk shape if available, otherwise use a
238+
# placeholder shape. The ChunkTransform is shape-agnostic —
239+
# the actual chunk shape is passed per-call at decode/encode time.
237240
if isinstance(metadata.chunk_grid, RegularChunkGridMetadata):
238241
chunk_shape = metadata.chunk_grid.chunk_shape
239242
else:
240-
chunk_shape = metadata.shape # fallback for rectilinear
243+
# Rectilinear: use a 1-element shape per dimension as placeholder.
244+
# Only dtype/fill_value/config matter for codec evolution.
245+
chunk_shape = (1,) * len(metadata.shape)
241246
chunk_spec = ArraySpec(
242247
shape=chunk_shape,
243248
dtype=metadata.data_type,

src/zarr/core/codec_pipeline.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from concurrent.futures import ThreadPoolExecutor
4-
from dataclasses import dataclass, field
4+
from dataclasses import dataclass, field, replace
55
from itertools import islice, pairwise
66
from typing import TYPE_CHECKING, Any
77
from warnings import warn
@@ -122,47 +122,78 @@ def __post_init__(self) -> None:
122122
bb_sync.append(bb_codec)
123123
self._bb_codecs = tuple(bb_sync)
124124

125+
def _spec_for_shape(self, shape: tuple[int, ...]) -> ArraySpec:
126+
"""Build an ArraySpec with the given shape, inheriting dtype/fill/config/prototype."""
127+
if shape == self._ab_spec.shape:
128+
return self._ab_spec
129+
return replace(self._ab_spec, shape=shape)
130+
125131
def decode_chunk(
    self,
    chunk_bytes: Buffer,
    chunk_shape: tuple[int, ...] | None = None,
) -> NDBuffer:
    """Decode a single chunk through the full codec chain, synchronously.

    Pure compute -- no IO.

    Parameters
    ----------
    chunk_bytes : Buffer
        The encoded chunk bytes.
    chunk_shape : tuple[int, ...] or None
        The shape of this chunk. If None, every codec receives the spec
        captured at construction. Required for rectilinear grids where
        chunks have different shapes.

    Returns
    -------
    NDBuffer
        The decoded chunk array.
    """
    spec = self._ab_spec if chunk_shape is None else self._spec_for_shape(chunk_shape)

    # Codecs are undone in the reverse of encode order:
    # bytes-to-bytes first, then array-to-bytes, then array-to-array.
    data: Buffer = chunk_bytes
    for bb_codec in reversed(self._bb_codecs):
        data = bb_codec._decode_sync(data, spec)

    chunk_array: NDBuffer = self._ab_codec._decode_sync(data, spec)

    for aa_codec, aa_spec in reversed(self._aa_codecs):
        # Keep each array-to-array codec's own spec and override only its
        # shape; substituting the array-to-bytes spec here would discard
        # whatever per-codec fields were stored alongside the codec.
        # NOTE(review): a single chunk_shape is applied to every step; an
        # array-to-array codec that changes the chunk shape would need
        # per-step shapes -- confirm against the codecs in use.
        step_spec = aa_spec
        if chunk_shape is not None and chunk_shape != aa_spec.shape:
            step_spec = replace(aa_spec, shape=chunk_shape)
        chunk_array = aa_codec._decode_sync(chunk_array, step_spec)

    return chunk_array
143162

144163
def encode_chunk(
145164
self,
146165
chunk_array: NDBuffer,
166+
chunk_shape: tuple[int, ...] | None = None,
147167
) -> Buffer | None:
148168
"""Encode a single chunk through the full codec chain, synchronously.
149169
150170
Pure compute -- no IO.
171+
172+
Parameters
173+
----------
174+
chunk_array : NDBuffer
175+
The chunk data to encode.
176+
chunk_shape : tuple[int, ...] or None
177+
The shape of this chunk. If None, uses the shape from the
178+
ArraySpec provided at construction.
151179
"""
180+
spec = self._ab_spec if chunk_shape is None else self._spec_for_shape(chunk_shape)
181+
152182
aa_data: NDBuffer = chunk_array
153-
for aa_codec, spec in self._aa_codecs:
154-
aa_result = aa_codec._encode_sync(aa_data, spec)
183+
for aa_codec, aa_spec in self._aa_codecs:
184+
aa_spec_resolved = aa_spec if chunk_shape is None else self._spec_for_shape(chunk_shape)
185+
aa_result = aa_codec._encode_sync(aa_data, aa_spec_resolved)
155186
if aa_result is None:
156187
return None
157188
aa_data = aa_result
158189

159-
ab_result = self._ab_codec._encode_sync(aa_data, self._ab_spec)
190+
ab_result = self._ab_codec._encode_sync(aa_data, spec)
160191
if ab_result is None:
161192
return None
162193

163194
bb_data: Buffer = ab_result
164195
for bb_codec in self._bb_codecs:
165-
bb_result = bb_codec._encode_sync(bb_data, self._ab_spec)
196+
bb_result = bb_codec._encode_sync(bb_data, spec)
166197
if bb_result is None:
167198
return None
168199
bb_data = bb_result
@@ -1104,7 +1135,7 @@ def _transform_read(
11041135
return self._decode_shard(raw, chunk_spec, self.shard_layout)
11051136

11061137
assert self.chunk_transform is not None
1107-
return self.chunk_transform.decode_chunk(raw)
1138+
return self.chunk_transform.decode_chunk(raw, chunk_shape=chunk_spec.shape)
11081139

11091140
def _decode_shard(self, blob: Buffer, shard_spec: ArraySpec, layout: ShardLayout) -> NDBuffer:
11101141
"""Decode a full shard blob into a shard-shaped array. Pure compute.
@@ -1163,14 +1194,18 @@ def _transform_write(
11631194

11641195
assert self.chunk_transform is not None
11651196

1197+
chunk_shape = chunk_spec.shape
1198+
11661199
if existing is not None:
1167-
chunk_array: NDBuffer | None = self.chunk_transform.decode_chunk(existing)
1200+
chunk_array: NDBuffer | None = self.chunk_transform.decode_chunk(
1201+
existing, chunk_shape=chunk_shape
1202+
)
11681203
else:
11691204
chunk_array = None
11701205

11711206
if chunk_array is None:
11721207
chunk_array = chunk_spec.prototype.nd_buffer.create(
1173-
shape=chunk_spec.shape,
1208+
shape=chunk_shape,
11741209
dtype=chunk_spec.dtype.to_native_dtype(),
11751210
fill_value=fill_value_or_default(chunk_spec),
11761211
)
@@ -1188,7 +1223,7 @@ def _transform_write(
11881223
chunk_value = chunk_value[item]
11891224
chunk_array[chunk_selection] = chunk_value
11901225

1191-
return self.chunk_transform.encode_chunk(chunk_array)
1226+
return self.chunk_transform.encode_chunk(chunk_array, chunk_shape=chunk_shape)
11921227

11931228
def _transform_write_shard(
11941229
self,

0 commit comments

Comments
 (0)