|
2 | 2 |
|
3 | 3 | from abc import abstractmethod |
4 | 4 | from collections.abc import Mapping |
5 | | -from typing import TYPE_CHECKING, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable |
| 5 | +from dataclasses import dataclass, field |
| 6 | +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable |
6 | 7 |
|
7 | 8 | from typing_extensions import ReadOnly, TypedDict |
8 | 9 |
|
|
19 | 20 | from zarr.core.array_spec import ArraySpec |
20 | 21 | from zarr.core.chunk_grids import ChunkGrid |
21 | 22 | from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType |
22 | | - from zarr.core.indexing import SelectorTuple |
| 23 | + from zarr.core.indexing import ChunkProjection, SelectorTuple |
23 | 24 | from zarr.core.metadata import ArrayMetadata |
24 | 25 |
|
25 | 26 | __all__ = [ |
|
32 | 33 | "CodecInput", |
33 | 34 | "CodecOutput", |
34 | 35 | "CodecPipeline", |
| 36 | + "PreparedWrite", |
35 | 37 | "SupportsSyncCodec", |
36 | 38 | ] |
37 | 39 |
|
@@ -204,9 +206,186 @@ class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]): |
204 | 206 | """Base class for array-to-array codecs.""" |
205 | 207 |
|
206 | 208 |
|
| 209 | +@dataclass |
| 210 | +class PreparedWrite: |
| 211 | + """Result of ``prepare_write``: existing encoded chunk bytes + selection info.""" |
| 212 | + |
| 213 | + chunk_dict: dict[tuple[int, ...], Buffer | None] |
| 214 | + inner_codec_chain: Any # CodecChain — typed as Any to avoid circular import |
| 215 | + inner_chunk_spec: ArraySpec |
| 216 | + indexer: list[ChunkProjection] |
| 217 | + value_selection: SelectorTuple | None = None |
| 218 | + write_full_shard: bool = True |
| 219 | + is_complete_shard: bool = False |
| 220 | + shard_data: NDBuffer | None = field(default=None) |
| 221 | + |
| 222 | + |
207 | 223 | class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]): |
208 | 224 | """Base class for array-to-bytes codecs.""" |
209 | 225 |
|
| 226 | + @property |
| 227 | + def inner_codec_chain(self) -> Any: |
| 228 | + """The codec chain for decoding inner chunks after deserialization. |
| 229 | +
|
| 230 | + Returns ``None`` by default — the pipeline should use its own codec chain. |
| 231 | + ``ShardingCodec`` overrides to return its inner codec chain. |
| 232 | + """ |
| 233 | + return None |
| 234 | + |
| 235 | + def deserialize( |
| 236 | + self, raw: Buffer | None, chunk_spec: ArraySpec |
| 237 | + ) -> dict[tuple[int, ...], Buffer | None]: |
| 238 | + """Unpack stored bytes into per-inner-chunk buffers. |
| 239 | +
|
| 240 | + Default: single chunk keyed at ``(0,)``. |
| 241 | + ``ShardingCodec`` overrides to decode the shard index and slice the |
| 242 | + blob into per-chunk buffers. |
| 243 | + """ |
| 244 | + return {(0,): raw} |
| 245 | + |
| 246 | + def serialize( |
| 247 | + self, |
| 248 | + chunk_dict: dict[tuple[int, ...], Buffer | None], |
| 249 | + chunk_spec: ArraySpec, |
| 250 | + ) -> Buffer | None: |
| 251 | + """Pack per-inner-chunk buffers into a storage blob. |
| 252 | +
|
| 253 | + Default: return the single chunk's bytes (or ``None`` if absent). |
| 254 | + ``ShardingCodec`` overrides to concatenate chunks and build an index. |
| 255 | + Returns ``None`` when all chunks are empty (caller should delete the key). |
| 256 | + """ |
| 257 | + return chunk_dict.get((0,)) |
| 258 | + |
| 259 | + # ------------------------------------------------------------------ |
| 260 | + # prepare / finalize — sync |
| 261 | + # ------------------------------------------------------------------ |
| 262 | + |
| 263 | + def prepare_read_sync( |
| 264 | + self, |
| 265 | + byte_getter: Any, |
| 266 | + chunk_spec: ArraySpec, |
| 267 | + chunk_selection: SelectorTuple, |
| 268 | + codec_chain: Any, |
| 269 | + aa_chain: Any, |
| 270 | + ab_pair: Any, |
| 271 | + bb_chain: Any, |
| 272 | + ) -> NDBuffer | None: |
| 273 | + """Sync IO + full decode for the selected region.""" |
| 274 | + raw = byte_getter.get_sync(prototype=chunk_spec.prototype) |
| 275 | + chunk_array: NDBuffer | None = codec_chain.decode_chunk( |
| 276 | + raw, chunk_spec, aa_chain, ab_pair, bb_chain |
| 277 | + ) |
| 278 | + if chunk_array is not None: |
| 279 | + return chunk_array[chunk_selection] |
| 280 | + return None |
| 281 | + |
| 282 | + def prepare_write_sync( |
| 283 | + self, |
| 284 | + byte_setter: Any, |
| 285 | + chunk_spec: ArraySpec, |
| 286 | + chunk_selection: SelectorTuple, |
| 287 | + out_selection: SelectorTuple, |
| 288 | + replace: bool, |
| 289 | + codec_chain: Any, |
| 290 | + ) -> PreparedWrite: |
| 291 | + """Sync IO + deserialize. Returns a :class:`PreparedWrite`.""" |
| 292 | + existing: Buffer | None = None |
| 293 | + if not replace: |
| 294 | + existing = byte_setter.get_sync(prototype=chunk_spec.prototype) |
| 295 | + chunk_dict = self.deserialize(existing, chunk_spec) |
| 296 | + inner_chain = self.inner_codec_chain or codec_chain |
| 297 | + return PreparedWrite( |
| 298 | + chunk_dict=chunk_dict, |
| 299 | + inner_codec_chain=inner_chain, |
| 300 | + inner_chunk_spec=chunk_spec, |
| 301 | + indexer=[ |
| 302 | + ( # type: ignore[list-item] |
| 303 | + (0,), |
| 304 | + chunk_selection, |
| 305 | + out_selection, |
| 306 | + replace, |
| 307 | + ) |
| 308 | + ], |
| 309 | + ) |
| 310 | + |
| 311 | + def finalize_write_sync( |
| 312 | + self, |
| 313 | + prepared: PreparedWrite, |
| 314 | + chunk_spec: ArraySpec, |
| 315 | + byte_setter: Any, |
| 316 | + ) -> None: |
| 317 | + """Serialize the prepared *chunk_dict* and write to store.""" |
| 318 | + blob = self.serialize(prepared.chunk_dict, chunk_spec) |
| 319 | + if blob is None: |
| 320 | + byte_setter.delete_sync() |
| 321 | + else: |
| 322 | + byte_setter.set_sync(blob) |
| 323 | + |
| 324 | + # ------------------------------------------------------------------ |
| 325 | + # prepare / finalize — async |
| 326 | + # ------------------------------------------------------------------ |
| 327 | + |
| 328 | + async def prepare_read( |
| 329 | + self, |
| 330 | + byte_getter: Any, |
| 331 | + chunk_spec: ArraySpec, |
| 332 | + chunk_selection: SelectorTuple, |
| 333 | + codec_chain: Any, |
| 334 | + aa_chain: Any, |
| 335 | + ab_pair: Any, |
| 336 | + bb_chain: Any, |
| 337 | + ) -> NDBuffer | None: |
| 338 | + """Async IO + full decode for the selected region.""" |
| 339 | + raw = await byte_getter.get(prototype=chunk_spec.prototype) |
| 340 | + chunk_array: NDBuffer | None = codec_chain.decode_chunk( |
| 341 | + raw, chunk_spec, aa_chain, ab_pair, bb_chain |
| 342 | + ) |
| 343 | + if chunk_array is not None: |
| 344 | + return chunk_array[chunk_selection] |
| 345 | + return None |
| 346 | + |
| 347 | + async def prepare_write( |
| 348 | + self, |
| 349 | + byte_setter: Any, |
| 350 | + chunk_spec: ArraySpec, |
| 351 | + chunk_selection: SelectorTuple, |
| 352 | + out_selection: SelectorTuple, |
| 353 | + replace: bool, |
| 354 | + codec_chain: Any, |
| 355 | + ) -> PreparedWrite: |
| 356 | + """Async IO + deserialize. Returns a :class:`PreparedWrite`.""" |
| 357 | + existing: Buffer | None = None |
| 358 | + if not replace: |
| 359 | + existing = await byte_setter.get(prototype=chunk_spec.prototype) |
| 360 | + chunk_dict = self.deserialize(existing, chunk_spec) |
| 361 | + inner_chain = self.inner_codec_chain or codec_chain |
| 362 | + return PreparedWrite( |
| 363 | + chunk_dict=chunk_dict, |
| 364 | + inner_codec_chain=inner_chain, |
| 365 | + inner_chunk_spec=chunk_spec, |
| 366 | + indexer=[ |
| 367 | + ( # type: ignore[list-item] |
| 368 | + (0,), |
| 369 | + chunk_selection, |
| 370 | + out_selection, |
| 371 | + replace, |
| 372 | + ) |
| 373 | + ], |
| 374 | + ) |
| 375 | + |
| 376 | + async def finalize_write( |
| 377 | + self, |
| 378 | + prepared: PreparedWrite, |
| 379 | + chunk_spec: ArraySpec, |
| 380 | + byte_setter: Any, |
| 381 | + ) -> None: |
| 382 | + """Async version of :meth:`finalize_write_sync`.""" |
| 383 | + blob = self.serialize(prepared.chunk_dict, chunk_spec) |
| 384 | + if blob is None: |
| 385 | + await byte_setter.delete() |
| 386 | + else: |
| 387 | + await byte_setter.set(blob) |
| 388 | + |
210 | 389 |
|
211 | 390 | class BytesBytesCodec(BaseCodec[Buffer, Buffer]): |
212 | 391 | """Base class for bytes-to-bytes codecs.""" |
|
0 commit comments