Skip to content

Commit 59491e1

Browse files
committed
add prepared write logic
1 parent 5a2a884 commit 59491e1

1 file changed

Lines changed: 181 additions & 2 deletions

File tree

src/zarr/abc/codec.py

Lines changed: 181 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from abc import abstractmethod
44
from collections.abc import Mapping
5-
from typing import TYPE_CHECKING, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable
5+
from dataclasses import dataclass, field
6+
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable
67

78
from typing_extensions import ReadOnly, TypedDict
89

@@ -19,7 +20,7 @@
1920
from zarr.core.array_spec import ArraySpec
2021
from zarr.core.chunk_grids import ChunkGrid
2122
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
22-
from zarr.core.indexing import SelectorTuple
23+
from zarr.core.indexing import ChunkProjection, SelectorTuple
2324
from zarr.core.metadata import ArrayMetadata
2425

2526
__all__ = [
@@ -32,6 +33,7 @@
3233
"CodecInput",
3334
"CodecOutput",
3435
"CodecPipeline",
36+
"PreparedWrite",
3537
"SupportsSyncCodec",
3638
]
3739

@@ -204,9 +206,186 @@ class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]):
204206
"""Base class for array-to-array codecs."""
205207

206208

209+
@dataclass
210+
class PreparedWrite:
211+
"""Result of ``prepare_write``: existing encoded chunk bytes + selection info."""
212+
213+
chunk_dict: dict[tuple[int, ...], Buffer | None]
214+
inner_codec_chain: Any # CodecChain — typed as Any to avoid circular import
215+
inner_chunk_spec: ArraySpec
216+
indexer: list[ChunkProjection]
217+
value_selection: SelectorTuple | None = None
218+
write_full_shard: bool = True
219+
is_complete_shard: bool = False
220+
shard_data: NDBuffer | None = field(default=None)
221+
222+
207223
class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]):
208224
"""Base class for array-to-bytes codecs."""
209225

226+
@property
227+
def inner_codec_chain(self) -> Any:
228+
"""The codec chain for decoding inner chunks after deserialization.
229+
230+
Returns ``None`` by default — the pipeline should use its own codec chain.
231+
``ShardingCodec`` overrides to return its inner codec chain.
232+
"""
233+
return None
234+
235+
def deserialize(
236+
self, raw: Buffer | None, chunk_spec: ArraySpec
237+
) -> dict[tuple[int, ...], Buffer | None]:
238+
"""Unpack stored bytes into per-inner-chunk buffers.
239+
240+
Default: single chunk keyed at ``(0,)``.
241+
``ShardingCodec`` overrides to decode the shard index and slice the
242+
blob into per-chunk buffers.
243+
"""
244+
return {(0,): raw}
245+
246+
def serialize(
247+
self,
248+
chunk_dict: dict[tuple[int, ...], Buffer | None],
249+
chunk_spec: ArraySpec,
250+
) -> Buffer | None:
251+
"""Pack per-inner-chunk buffers into a storage blob.
252+
253+
Default: return the single chunk's bytes (or ``None`` if absent).
254+
``ShardingCodec`` overrides to concatenate chunks and build an index.
255+
Returns ``None`` when all chunks are empty (caller should delete the key).
256+
"""
257+
return chunk_dict.get((0,))
258+
259+
# ------------------------------------------------------------------
260+
# prepare / finalize — sync
261+
# ------------------------------------------------------------------
262+
263+
def prepare_read_sync(
264+
self,
265+
byte_getter: Any,
266+
chunk_spec: ArraySpec,
267+
chunk_selection: SelectorTuple,
268+
codec_chain: Any,
269+
aa_chain: Any,
270+
ab_pair: Any,
271+
bb_chain: Any,
272+
) -> NDBuffer | None:
273+
"""Sync IO + full decode for the selected region."""
274+
raw = byte_getter.get_sync(prototype=chunk_spec.prototype)
275+
chunk_array: NDBuffer | None = codec_chain.decode_chunk(
276+
raw, chunk_spec, aa_chain, ab_pair, bb_chain
277+
)
278+
if chunk_array is not None:
279+
return chunk_array[chunk_selection]
280+
return None
281+
282+
def prepare_write_sync(
283+
self,
284+
byte_setter: Any,
285+
chunk_spec: ArraySpec,
286+
chunk_selection: SelectorTuple,
287+
out_selection: SelectorTuple,
288+
replace: bool,
289+
codec_chain: Any,
290+
) -> PreparedWrite:
291+
"""Sync IO + deserialize. Returns a :class:`PreparedWrite`."""
292+
existing: Buffer | None = None
293+
if not replace:
294+
existing = byte_setter.get_sync(prototype=chunk_spec.prototype)
295+
chunk_dict = self.deserialize(existing, chunk_spec)
296+
inner_chain = self.inner_codec_chain or codec_chain
297+
return PreparedWrite(
298+
chunk_dict=chunk_dict,
299+
inner_codec_chain=inner_chain,
300+
inner_chunk_spec=chunk_spec,
301+
indexer=[
302+
( # type: ignore[list-item]
303+
(0,),
304+
chunk_selection,
305+
out_selection,
306+
replace,
307+
)
308+
],
309+
)
310+
311+
def finalize_write_sync(
312+
self,
313+
prepared: PreparedWrite,
314+
chunk_spec: ArraySpec,
315+
byte_setter: Any,
316+
) -> None:
317+
"""Serialize the prepared *chunk_dict* and write to store."""
318+
blob = self.serialize(prepared.chunk_dict, chunk_spec)
319+
if blob is None:
320+
byte_setter.delete_sync()
321+
else:
322+
byte_setter.set_sync(blob)
323+
324+
# ------------------------------------------------------------------
325+
# prepare / finalize — async
326+
# ------------------------------------------------------------------
327+
328+
async def prepare_read(
329+
self,
330+
byte_getter: Any,
331+
chunk_spec: ArraySpec,
332+
chunk_selection: SelectorTuple,
333+
codec_chain: Any,
334+
aa_chain: Any,
335+
ab_pair: Any,
336+
bb_chain: Any,
337+
) -> NDBuffer | None:
338+
"""Async IO + full decode for the selected region."""
339+
raw = await byte_getter.get(prototype=chunk_spec.prototype)
340+
chunk_array: NDBuffer | None = codec_chain.decode_chunk(
341+
raw, chunk_spec, aa_chain, ab_pair, bb_chain
342+
)
343+
if chunk_array is not None:
344+
return chunk_array[chunk_selection]
345+
return None
346+
347+
async def prepare_write(
348+
self,
349+
byte_setter: Any,
350+
chunk_spec: ArraySpec,
351+
chunk_selection: SelectorTuple,
352+
out_selection: SelectorTuple,
353+
replace: bool,
354+
codec_chain: Any,
355+
) -> PreparedWrite:
356+
"""Async IO + deserialize. Returns a :class:`PreparedWrite`."""
357+
existing: Buffer | None = None
358+
if not replace:
359+
existing = await byte_setter.get(prototype=chunk_spec.prototype)
360+
chunk_dict = self.deserialize(existing, chunk_spec)
361+
inner_chain = self.inner_codec_chain or codec_chain
362+
return PreparedWrite(
363+
chunk_dict=chunk_dict,
364+
inner_codec_chain=inner_chain,
365+
inner_chunk_spec=chunk_spec,
366+
indexer=[
367+
( # type: ignore[list-item]
368+
(0,),
369+
chunk_selection,
370+
out_selection,
371+
replace,
372+
)
373+
],
374+
)
375+
376+
async def finalize_write(
377+
self,
378+
prepared: PreparedWrite,
379+
chunk_spec: ArraySpec,
380+
byte_setter: Any,
381+
) -> None:
382+
"""Async version of :meth:`finalize_write_sync`."""
383+
blob = self.serialize(prepared.chunk_dict, chunk_spec)
384+
if blob is None:
385+
await byte_setter.delete()
386+
else:
387+
await byte_setter.set(blob)
388+
210389

211390
class BytesBytesCodec(BaseCodec[Buffer, Buffer]):
212391
"""Base class for bytes-to-bytes codecs."""

0 commit comments

Comments
 (0)