From 55d1a46b8ee50c64c6ce204ab620dcc8d53b475a Mon Sep 17 00:00:00 2001 From: BrianMichell Date: Wed, 27 May 2026 16:27:55 +0000 Subject: [PATCH 1/2] Add composable index strategies for SEG-Y header transformation This commit introduces a new module for index strategies that replaces the monolithic GridOverrider with a set of single-responsibility IndexStrategy classes. The new structure allows for better composition and selection of strategies via the IndexStrategyRegistry, while maintaining compatibility with the existing ingestion behavior. Additionally, unit tests for the new strategies have been added to ensure functionality and correctness. --- src/mdio/ingestion/index_strategies.py | 420 ++++++++++++ src/mdio/segy/geometry.py | 638 ++++-------------- tests/unit/test_ingestion_index_strategies.py | 569 ++++++++++++++++ 3 files changed, 1118 insertions(+), 509 deletions(-) create mode 100644 src/mdio/ingestion/index_strategies.py create mode 100644 tests/unit/test_ingestion_index_strategies.py diff --git a/src/mdio/ingestion/index_strategies.py b/src/mdio/ingestion/index_strategies.py new file mode 100644 index 00000000..bec58b50 --- /dev/null +++ b/src/mdio/ingestion/index_strategies.py @@ -0,0 +1,420 @@ +"""Composable index strategies for transforming SEG-Y headers into indexable dimensions. + +This module replaces the monolithic :class:`mdio.segy.geometry.GridOverrider` command +dispatch with a small set of single-responsibility :class:`IndexStrategy` objects that can +be composed via :class:`CompositeStrategy`. + +Strategies are selected by :class:`IndexStrategyRegistry` from the typed +:class:`mdio.segy.geometry.GridOverrides` configuration plus optional template hints. The +public contract preserved by :class:`mdio.segy.geometry.GridOverrider` (a thin shim around +this module) keeps end-to-end ingestion behavior identical to v1.1.x. +""" + +from __future__ import annotations + +import logging +from abc import ABC +from abc import abstractmethod +from typing import TYPE_CHECKING + +import numpy as np +from numpy.lib import recfunctions as rfn + +from mdio.core import Dimension +from mdio.ingestion.segy.header_analysis import ShotGunGeometryType +from mdio.ingestion.segy.header_analysis import StreamerShotGeometryType +from mdio.ingestion.segy.header_analysis import analyze_lines_for_guns +from mdio.ingestion.segy.header_analysis import analyze_non_indexed_headers +from mdio.ingestion.segy.header_analysis import analyze_streamer_headers +from mdio.segy.exceptions import GridOverrideKeysError + +if TYPE_CHECKING: + from collections.abc import Iterable + + from numpy.typing import DTypeLike + from segy.arrays import HeaderArray + + from mdio.builder.templates.base import AbstractDatasetTemplate + from mdio.segy.geometry import GridOverrides + +logger = logging.getLogger(__name__) + + +class IndexStrategy(ABC): + """Abstract base for header indexing strategies. + + A strategy transforms a raw header array (e.g. add or rebase fields) and computes + the resulting :class:`Dimension` list. Strategies are composable through + :class:`CompositeStrategy`. The default :meth:`compute_dimensions` builds dimensions + from unique header values; subclasses override only when they need different + semantics (currently just :class:`CompositeStrategy`). + + Subclasses with header preconditions set :attr:`required_keys` so the shim and + :class:`CompositeStrategy` can raise :class:`GridOverrideKeysError` with a clear + "missing fields X, Y, Z" message before numpy fails on a deeper key lookup. + """ + + @property + def required_keys(self) -> frozenset[str]: + """Header field names that must be present before :meth:`transform_headers` runs. + + Empty by default. Override on subclasses whose transform indexes specific fields. + """ + return frozenset() + + def validate_headers(self, headers: HeaderArray) -> None: + """Raise :class:`GridOverrideKeysError` if any required header field is missing. + + Callers (the :class:`mdio.segy.geometry.GridOverrider` shim and + :class:`CompositeStrategy`) invoke this before each transform so failure points + at the user-facing override name rather than at a numpy structured-array key error. + """ + required = self.required_keys + if not required: + return + present = set(headers.dtype.names or ()) + if not required.issubset(present): + raise GridOverrideKeysError(self.name, set(required)) + + @abstractmethod + def transform_headers(self, headers: HeaderArray) -> HeaderArray: + """Return a new header array with this strategy's transformation applied.""" + + def compute_dimensions(self, headers: HeaderArray, dim_names: tuple[str, ...]) -> list[Dimension]: + """Build one :class:`Dimension` per requested name from unique header values. + + Names absent from ``headers.dtype.names`` are silently skipped, matching the v1.1 + ``GridOverrider`` post-processing step. + """ + return [ + Dimension(coords=np.unique(headers[name]), name=name) + for name in dim_names + if name in headers.dtype.names + ] + + @property + def name(self) -> str: + """Return the strategy's class name; useful for logging and tests.""" + return self.__class__.__name__ + + +class RegularGridStrategy(IndexStrategy): + """Default strategy: headers untouched, dimensions are unique values per name.""" + + def transform_headers(self, headers: HeaderArray) -> HeaderArray: + """Pass headers through unchanged.""" + return headers + + +class DuplicateHandlingStrategy(IndexStrategy): + """Disambiguate duplicate index tuples by appending a per-tuple ``trace`` counter. + + Mirrors the v1.1 ``DuplicateIndex`` command: count occurrences of each unique + combination of dimension fields (excluding coordinate fields and any caller-declared + ``excluded_fields``), then attach the resulting 1-based counter as a new ``trace`` field + on the original headers. + + Args: + coord_fields: Names of header fields that are template coordinates and must be + excluded from the dimension grouping (their values vary independently of the + grid index). + excluded_fields: Additional fields to exclude from grouping. Used by + :class:`NonBinnedStrategy` to keep the explicit ``non_binned_dims`` from + polluting the per-tuple counter. + dtype: NumPy dtype for the appended ``trace`` counter. + """ + + def __init__( + self, + coord_fields: Iterable[str] = (), + excluded_fields: Iterable[str] = (), + dtype: DTypeLike = np.int16, + ) -> None: + self.coord_fields = frozenset(coord_fields) + self.excluded_fields = frozenset(excluded_fields) + self.dtype = dtype + + def _dim_fields(self, headers: HeaderArray) -> list[str]: + """Header field names that participate in the duplicate grouping.""" + return [ + name + for name in headers.dtype.names + if name != "trace" and name not in self.coord_fields and name not in self.excluded_fields + ] + + def transform_headers(self, headers: HeaderArray) -> HeaderArray: + """Append a per-(dim-tuple) ``trace`` counter to ``headers``.""" + dim_fields = self._dim_fields(headers) + dim_headers = headers[dim_fields] if dim_fields else headers + with_trace = analyze_non_indexed_headers(dim_headers, dtype=self.dtype) + + if with_trace is None or "trace" not in with_trace.dtype.names: + return headers + + trace_values = np.array(with_trace["trace"]) + return rfn.append_fields(headers, "trace", trace_values, usemask=False) + + +class NonBinnedStrategy(DuplicateHandlingStrategy): + """Collapse selected non-binned dimensions into a single ``trace`` dimension. + + Inherits the per-tuple ``trace`` counter from :class:`DuplicateHandlingStrategy` and + captures ``chunksize`` so the :class:`mdio.segy.geometry.GridOverrider` shim can size + the new ``trace`` chunk correctly. + + Args: + chunksize: Chunk size to assign to the ``trace`` dimension. The strategy itself + does not apply this value; the shim uses it when rewriting the chunksize tuple. + non_binned_dims: Header fields collapsed into ``trace``. They are excluded from + the duplicate grouping so the counter only varies along the remaining dims. + coord_fields: Template coordinate names to exclude from grouping. + dtype: NumPy dtype for the appended ``trace`` counter. + """ + + def __init__( + self, + chunksize: int, + non_binned_dims: Iterable[str], + coord_fields: Iterable[str] = (), + dtype: DTypeLike = np.int16, + ) -> None: + non_binned_dims = tuple(non_binned_dims) + super().__init__( + coord_fields=coord_fields, + excluded_fields=non_binned_dims, + dtype=dtype, + ) + self.chunksize = chunksize + self.non_binned_dims = non_binned_dims + + +class ChannelWrappingStrategy(IndexStrategy): + """Renumber streamer channels per cable when geometry is Type B. + + Detects whether channel numbering is per-cable (Type A; pass-through) or sequential + across cables (Type B; rebase to 1..N per cable). Mirrors the v1.1 ``AutoChannelWrap`` + command. + """ + + @property + def required_keys(self) -> frozenset[str]: + """Streamer channel detection needs the cable-channel-shot triplet.""" + return frozenset({"shot_point", "cable", "channel"}) + + def transform_headers(self, headers: HeaderArray) -> HeaderArray: + """Rebase ``channel`` per cable for Type B geometry; pass through for Type A.""" + unique_cables, cable_chan_min, cable_chan_max, geom_type = analyze_streamer_headers(headers) + + logger.info("Ingesting dataset as %s", geom_type.name) + for cable, chan_min, chan_max in zip(unique_cables, cable_chan_min, cable_chan_max, strict=True): + logger.info("Cable: %s has min chan: %s and max chan: %s", cable, chan_min, chan_max) + + if geom_type != StreamerShotGeometryType.B: + return headers + + for idx, cable in enumerate(unique_cables): + cable_idxs = np.where(headers["cable"][:] == cable) + headers["channel"][cable_idxs] = headers["channel"][cable_idxs] - cable_chan_min[idx] + 1 + + return headers + + +class ShotWrappingStrategy(IndexStrategy): + """Derive a dense ``shot_index`` field from sparse or interleaved ``shot_point`` values. + + Replaces the v1.1 ``AutoShotWrap`` (streamer) and ``CalculateShotIndex`` (OBN) + commands. The two callers differ only in: + + * ``line_field`` -- ``sail_line`` for streamer, ``shot_line`` for OBN. + * ``always_calculate`` -- streamer skips the transform entirely for Type A geometries + (per-gun shot points are already dense), OBN always emits ``shot_index`` because the + template declares it as a calculated dimension. + + Args: + line_field: Header field used to group shots into independent lines. + always_calculate: When ``True``, emit ``shot_index`` for every geometry type. For + Type A this builds a 0-based ``np.searchsorted`` over sorted unique shot + points per line. + """ + + _STREAMER_LINE_FIELD = "sail_line" + + def __init__(self, line_field: str, always_calculate: bool = False) -> None: + self.line_field = line_field + self.always_calculate = always_calculate + + @property + def required_keys(self) -> frozenset[str]: + """Streamer (``sail_line``) needs cable+channel too; OBN (``shot_line``) does not. + + Mirrors the v1.1 split between ``AutoShotWrap.required_keys`` and + ``CalculateShotIndex.required_keys``. + """ + base = {self.line_field, "gun", "shot_point"} + if self.line_field == self._STREAMER_LINE_FIELD: + base |= {"cable", "channel"} + return frozenset(base) + + def transform_headers(self, headers: HeaderArray) -> HeaderArray: + """Append ``shot_index`` derived from ``shot_point`` per line.""" + unique_lines, unique_guns_per_line, geom_type = analyze_lines_for_guns(headers, line_field=self.line_field) + + logger.info("Ingesting dataset as shot type: %s (line_field=%s)", geom_type.name, self.line_field) + + max_num_guns = 1 + for line_val in unique_lines: + guns = unique_guns_per_line[str(line_val)] + logger.info("%s: %s has guns: %s", self.line_field, line_val, guns) + max_num_guns = max(len(guns), max_num_guns) + + if geom_type == ShotGunGeometryType.A and not self.always_calculate: + return headers + + shot_index = np.empty(len(headers), dtype="uint32") + # `.base` is None for non-view arrays; fall back to the array itself. + base_array = headers.base if headers.base is not None else headers + headers = rfn.append_fields(base_array, "shot_index", shot_index, usemask=False) + + if geom_type == ShotGunGeometryType.B: + for line_val in unique_lines: + line_idxs = np.where(headers[self.line_field][:] == line_val) + headers["shot_index"][line_idxs] = np.floor(headers["shot_point"][line_idxs] / max_num_guns) + headers["shot_index"][line_idxs] -= headers["shot_index"][line_idxs].min() + else: + for line_val in unique_lines: + line_idxs = np.where(headers[self.line_field][:] == line_val) + shot_points = headers["shot_point"][line_idxs] + unique_shots = np.unique(shot_points) + headers["shot_index"][line_idxs] = np.searchsorted(unique_shots, shot_points) + + return headers + + +class ComponentSynthesisStrategy(IndexStrategy): + """Synthesize template-required dimension fields that are absent from the headers. + + Currently used to fill the ``component`` dimension with a constant value of 1 for + OBN templates whose SEG-Y spec does not include a component header. Mirrors the + v1.1 ``GridOverrider._synthesize_obn_component`` behavior. + + Args: + synthesize_dims: Names of dimension fields to synthesize when missing. + """ + + def __init__(self, synthesize_dims: Iterable[str]) -> None: + self.synthesize_dims = tuple(synthesize_dims) + + def transform_headers(self, headers: HeaderArray) -> HeaderArray: + """Append constant-1 fields for any synthesize_dims not already present.""" + for dim in self.synthesize_dims: + if dim in headers.dtype.names: + continue + logger.warning( + "SEG-Y headers do not contain '%s' field required by template; " + "synthesizing dimension with constant value 1 for all traces.", + dim, + ) + comp_array = np.ones(len(headers), dtype=np.uint8) + base_array = headers.base if headers.base is not None else headers + headers = rfn.append_fields(base_array, dim, comp_array, usemask=False) + return headers + + +class CompositeStrategy(IndexStrategy): + """Apply multiple strategies in order; each transform feeds the next. + + Dimension computation is delegated to the final strategy on the assumption it is + aware of all preceding header transformations. + """ + + def __init__(self, strategies: list[IndexStrategy]) -> None: + if not strategies: + msg = "CompositeStrategy requires at least one strategy" + raise ValueError(msg) + self.strategies = strategies + + def transform_headers(self, headers: HeaderArray) -> HeaderArray: + """Validate then run each child strategy's transform in sequence. + + Each step re-validates against the running header array, so a strategy that + produces a field (e.g. :class:`ComponentSynthesisStrategy` adding ``component``) + can satisfy a later strategy's :attr:`required_keys`. + """ + result = headers + for strategy in self.strategies: + logger.debug("Applying strategy: %s", strategy.name) + strategy.validate_headers(result) + result = strategy.transform_headers(result) + return result + + def compute_dimensions(self, headers: HeaderArray, dim_names: tuple[str, ...]) -> list[Dimension]: + """Delegate to the final child strategy.""" + return self.strategies[-1].compute_dimensions(headers, dim_names) + + +class IndexStrategyRegistry: + """Picks the right :class:`IndexStrategy` from grid overrides + template hints.""" + + def create_strategy( + self, + grid_overrides: GridOverrides | None = None, + synthesize_dims: tuple[str, ...] = (), + template: AbstractDatasetTemplate | None = None, + ) -> IndexStrategy: + """Build a strategy (possibly composite) for the given config. + + Strategy ordering, when multiple flags are set, mirrors v1.1 behavior: + + 1. ``ComponentSynthesisStrategy`` (so later strategies can rely on the synthesized + field being present). + 2. ``ChannelWrappingStrategy`` (rebases ``channel`` before any shot calculation). + 3. ``ShotWrappingStrategy`` for ``auto_shot_wrap`` (streamer; ``sail_line``). + 4. ``ShotWrappingStrategy`` for ``calculate_shot_index`` (OBN; ``shot_line``, + ``always_calculate=True``). + 5. ``NonBinnedStrategy`` or ``DuplicateHandlingStrategy`` (mutually exclusive; + ``non_binned`` wins when both are set, matching v1.x semantics). + + Args: + grid_overrides: Typed grid override configuration, or ``None`` for no + user-driven overrides. + synthesize_dims: Dimensions to synthesize if missing (e.g. ``component``). + template: Optional dataset template; used to look up coordinate names so + duplicate-handling counters group on dimension fields only. + + Returns: + A single :class:`IndexStrategy` instance. Returns + :class:`RegularGridStrategy` when no overrides and no synthesis are required. + """ + strategies: list[IndexStrategy] = [] + + if synthesize_dims: + strategies.append(ComponentSynthesisStrategy(synthesize_dims)) + + coord_fields: tuple[str, ...] = template.coordinate_names if template is not None else () + + if grid_overrides: + if grid_overrides.auto_channel_wrap: + strategies.append(ChannelWrappingStrategy()) + + if grid_overrides.auto_shot_wrap: + strategies.append(ShotWrappingStrategy(line_field="sail_line", always_calculate=False)) + + if grid_overrides.calculate_shot_index: + strategies.append(ShotWrappingStrategy(line_field="shot_line", always_calculate=True)) + + if grid_overrides.non_binned: + strategies.append( + NonBinnedStrategy( + chunksize=grid_overrides.chunksize, + non_binned_dims=grid_overrides.non_binned_dims or (), + coord_fields=coord_fields, + ) + ) + elif grid_overrides.has_duplicates: + strategies.append(DuplicateHandlingStrategy(coord_fields=coord_fields)) + + if not strategies: + return RegularGridStrategy() + if len(strategies) == 1: + return strategies[0] + return CompositeStrategy(strategies) diff --git a/src/mdio/segy/geometry.py b/src/mdio/segy/geometry.py index eaa28fb4..c2b80b80 100644 --- a/src/mdio/segy/geometry.py +++ b/src/mdio/segy/geometry.py @@ -1,32 +1,28 @@ -"""SEG-Y geometry handling functions.""" +"""SEG-Y grid override configuration model and legacy executor shim. + +The Pydantic :class:`GridOverrides` model is the supported public API for configuring +grid overrides. The :class:`GridOverrider` class is retained as a thin shim that +delegates to :class:`mdio.ingestion.index_strategies.IndexStrategyRegistry`; it preserves +the v1.1 ``run(...)`` contract for callers that still pass a legacy ``dict``. +""" from __future__ import annotations import logging -from abc import ABC -from abc import abstractmethod from typing import TYPE_CHECKING from typing import Any -import numpy as np -from numpy.lib import recfunctions as rfn from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field -from mdio.ingestion.segy.header_analysis import ShotGunGeometryType -from mdio.ingestion.segy.header_analysis import StreamerShotGeometryType -from mdio.ingestion.segy.header_analysis import analyze_lines_for_guns -from mdio.ingestion.segy.header_analysis import analyze_non_indexed_headers -from mdio.ingestion.segy.header_analysis import analyze_streamer_headers -from mdio.segy.exceptions import GridOverrideKeysError +from mdio.ingestion.index_strategies import IndexStrategyRegistry from mdio.segy.exceptions import GridOverrideMissingParameterError from mdio.segy.exceptions import GridOverrideUnknownError if TYPE_CHECKING: from collections.abc import Sequence - from numpy.typing import NDArray from segy.arrays import HeaderArray from mdio.builder.templates.base import AbstractDatasetTemplate @@ -90,533 +86,157 @@ def to_legacy_dict(self) -> dict[str, Any]: return self.model_dump(by_alias=True, exclude_defaults=True) -class GridOverrideCommand(ABC): - """Abstract base class for grid override commands.""" - - @property - @abstractmethod - def required_keys(self) -> set: - """Get the set of required keys for the grid override command.""" - - @property - @abstractmethod - def required_parameters(self) -> set: - """Get the set of required parameters for the grid override command.""" - - @abstractmethod - def validate( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate | None = None, - ) -> None: - """Validate if this transform should run on the type of data.""" - - @abstractmethod - def transform( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate, # noqa: ARG002 - ) -> NDArray: - """Perform the grid transform.""" - - def transform_index_names(self, index_names: Sequence[str]) -> Sequence[str]: - """Perform the transform of index names. - - Optional method: Subclasses may override this method to provide custom behavior. If not - overridden, this default implementation will be used, which is a no-op. - - Args: - index_names: List of index names to be modified. - - Returns: - New tuple of index names after the transform. - """ - return index_names - - def transform_chunksize( - self, - chunksize: Sequence[int], - grid_overrides: dict[str, bool | int], - ) -> Sequence[int]: - """Perform the transform of chunksize. - - Optional method: Subclasses may override this method to provide custom behavior. If not - overridden, this default implementation will be used, which is a no-op. - - Args: - chunksize: List of chunk sizes to be modified. - grid_overrides: Full grid override parameterization. - - Returns: - New tuple of chunk sizes after the transform. - """ - _ = grid_overrides # Unused, required for ABC compatibility - return chunksize - - @property - def name(self) -> str: - """Convenience property to get the name of the command.""" - return self.__class__.__name__ - - def check_required_keys(self, index_headers: HeaderArray) -> None: - """Check if all required keys are present in the index headers.""" - index_names = index_headers.dtype.names - if not self.required_keys.issubset(index_names): - raise GridOverrideKeysError(self.name, self.required_keys) - - def check_required_params(self, grid_overrides: dict[str, str | int]) -> None: - """Check if all required keys are present in the index headers.""" - if self.required_parameters is None: - return - - passed_parameters = set(grid_overrides.keys()) - - if not self.required_parameters.issubset(passed_parameters): - missing_params = self.required_parameters - passed_parameters - raise GridOverrideMissingParameterError(self.name, missing_params) - +def _resolve_synthesize_dims(template: AbstractDatasetTemplate | None) -> tuple[str, ...]: + """Return dimension fields to synthesize when missing for a given template. -class DuplicateIndex(GridOverrideCommand): - """Automatically handle duplicate traces in a new axis - trace with chunksize 1.""" - - required_keys = None - required_parameters = None - - def validate( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate | None = None, # noqa: ARG002 - ) -> None: - """Validate if this transform should run on the type of data.""" - if self.required_keys is not None: - self.check_required_keys(index_headers) - self.check_required_params(grid_overrides) - - def transform( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate, - ) -> NDArray: - """Perform the grid transform.""" - self.validate(index_headers, grid_overrides) - - # Filter out coordinate fields, keep only dimensions for trace indexing - coord_fields = set(template.coordinate_names) if template else set() - - # For NonBinned: non_binned_dims should be excluded from trace indexing grouping - # because they become coordinates indexed by the trace dimension, not grouping keys. - # The trace index should count all traces per remaining dimension combination. - non_binned_dims = set(grid_overrides.get("non_binned_dims", [])) if grid_overrides else set() - - dim_fields = [ - name - for name in index_headers.dtype.names - if name != "trace" and name not in coord_fields and name not in non_binned_dims - ] - - # Create trace indices on dimension fields only - dim_headers = index_headers[dim_fields] if dim_fields else index_headers - dim_headers_with_trace = analyze_non_indexed_headers(dim_headers) - - # Add trace field back to full headers - if dim_headers_with_trace is not None and "trace" in dim_headers_with_trace.dtype.names: - trace_values = np.array(dim_headers_with_trace["trace"]) - index_headers = rfn.append_fields(index_headers, "trace", trace_values, usemask=False) - - return index_headers - - def transform_index_names(self, index_names: Sequence[str]) -> Sequence[str]: - """Insert dimension "trace" to the sample-1 dimension.""" - new_names = list(index_names) - new_names.append("trace") - return tuple(new_names) - - def transform_chunksize( - self, - chunksize: Sequence[int], - grid_overrides: dict[str, bool | int], - ) -> Sequence[int]: - """Insert chunksize of 1 to the sample-1 dimension.""" - _ = grid_overrides # Unused, required for ABC compatibility - new_chunks = list(chunksize) - new_chunks.insert(-1, 1) - return tuple(new_chunks) - - -class NonBinned(DuplicateIndex): - """Handle non-binned dimensions by converting them to a trace dimension with coordinates. - - This override takes dimensions that are not regularly sampled (non-binned) and converts - them into a single 'trace' dimension. The original non-binned dimensions become coordinates - indexed by the trace dimension. - - Example: - Template with dimensions [shot_point, cable, channel, azimuth, offset, sample] - and non_binned_dims=['azimuth', 'offset'] becomes: - - dimensions: [shot_point, cable, channel, trace, sample] - - coordinates: azimuth and offset with dimensions [shot_point, cable, channel, trace] - - Attributes: - required_keys: No required keys for this override. - required_parameters: Set containing 'chunksize' and 'non_binned_dims'. + Only the OBN receiver gathers template currently synthesizes ``component``; every + other template returns ``()`` so the strategy registry skips synthesis entirely. """ - - required_keys = None - required_parameters = {"chunksize", "non_binned_dims"} - - def validate( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate | None = None, # noqa: ARG002 - ) -> None: - """Validate if this transform should run on the type of data.""" - self.check_required_params(grid_overrides) - - # Validate that non_binned_dims is a list - non_binned_dims = grid_overrides.get("non_binned_dims", []) - if not isinstance(non_binned_dims, list): - msg = f"non_binned_dims must be a list, got {type(non_binned_dims)}" - raise ValueError(msg) - - # Validate that all non-binned dimensions exist in headers - missing_dims = set(non_binned_dims) - set(index_headers.dtype.names) - if missing_dims: - msg = f"Non-binned dimensions {missing_dims} not found in index headers" - raise ValueError(msg) - - def transform_chunksize( - self, - chunksize: Sequence[int], - grid_overrides: dict[str, bool | int], - ) -> Sequence[int]: - """Insert chunksize for trace dimension at N-1 position.""" - new_chunks = list(chunksize) - trace_chunksize = grid_overrides["chunksize"] - new_chunks.insert(-1, trace_chunksize) - return tuple(new_chunks) - - -class AutoChannelWrap(GridOverrideCommand): - """Automatically determine Streamer acquisition type.""" - - required_keys = {"shot_point", "cable", "channel"} - required_parameters = None - - def validate( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate | None = None, # noqa: ARG002 - ) -> None: - """Validate if this transform should run on the type of data.""" - self.check_required_keys(index_headers) - self.check_required_params(grid_overrides) - - def transform( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate, # noqa: ARG002 - ) -> NDArray: - """Perform the grid transform.""" - self.validate(index_headers, grid_overrides) - - result = analyze_streamer_headers(index_headers) - unique_cables, cable_chan_min, cable_chan_max, geom_type = result - logger.info("Ingesting dataset as %s", geom_type.name) - - for cable, chan_min, chan_max in zip(unique_cables, cable_chan_min, cable_chan_max, strict=True): - logger.info("Cable: %s has min chan: %s and max chan: %s", cable, chan_min, chan_max) - - # This might be slow and could be improved with a rewrite to prevent so many lookups - if geom_type == StreamerShotGeometryType.B: - for idx, cable in enumerate(unique_cables): - cable_idxs = np.where(index_headers["cable"][:] == cable) - cc_min = cable_chan_min[idx] - - index_headers["channel"][cable_idxs] = index_headers["channel"][cable_idxs] - cc_min + 1 - - return index_headers - - -class CalculateShotIndex(GridOverrideCommand): - """Calculate dense shot_index from shot_point values for OBN templates. - - Creates a 0-based shot_index dimension from sparse or interleaved shot_point - values, grouping by shot_line and gun. This is required for the OBN template - which uses shot_index as a calculated dimension. - - Required headers: shot_line, gun, shot_point - - Attributes: - required_parameters: Set of required parameters (None for this class). + if template is None: + return () + # Lazy import: builder templates pull in builder schemas that indirectly import this + # module's ``GridOverrides``, so a top-level import would cycle. + from mdio.builder.templates.seismic_3d_obn import Seismic3DObnReceiverGathersTemplate # noqa: PLC0415 + + if isinstance(template, Seismic3DObnReceiverGathersTemplate): + return ("component",) + return () + + +def _validate_template_for_overrides( + config: GridOverrides, + template: AbstractDatasetTemplate | None, +) -> None: + """Reject grid override / template pairings that v1.1 forbade. + + ``auto_shot_wrap`` is streamer-only and ``calculate_shot_index`` is OBN-only; using + either with the wrong template silently produced wrong shot indices in v1.1 unless + the per-command validator caught it. This function restores that guard. + + Args: + config: Typed grid overrides extracted from the user's legacy dict. + template: Template chosen by the caller, or ``None`` if omitted. + + Raises: + TypeError: When ``auto_shot_wrap`` is set without a streamer template, or + ``calculate_shot_index`` is set without an OBN receiver-gathers template. """ - - required_parameters = None - - @property - def required_keys(self) -> set: - """Return required header keys for OBN shot index calculation.""" - return {"shot_line", "gun", "shot_point"} - - def validate( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate | None = None, - ) -> None: - """Validate if this transform should run on the type of data.""" - # Import here to avoid circular imports at module load time - from mdio.builder.templates.seismic_3d_obn import Seismic3DObnReceiverGathersTemplate # noqa: PLC0415 - - if template is None: - msg = "CalculateShotIndex requires a template." - raise TypeError(msg) - - if not isinstance(template, Seismic3DObnReceiverGathersTemplate): - msg = ( - f"CalculateShotIndex only supports Seismic3DObnReceiverGathersTemplate, got {type(template).__name__}." - ) - raise TypeError(msg) - - index_names = set(index_headers.dtype.names) - if not self.required_keys.issubset(index_names): - raise GridOverrideKeysError(self.name, self.required_keys) - self.check_required_params(grid_overrides) - - def transform( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate, - ) -> NDArray: - """Perform the grid transform to calculate shot_index from shot_point.""" - self.validate(index_headers, grid_overrides, template) - - line_field = "shot_line" - result = analyze_lines_for_guns(index_headers, line_field=line_field) - unique_lines, unique_guns_per_line, geom_type = result - logger.info("Ingesting OBN dataset as shot type: %s", geom_type.name) - - max_num_guns = 1 - for line_val in unique_lines: - guns = unique_guns_per_line[str(line_val)] - logger.info("%s: %s has guns: %s", line_field, line_val, guns) - max_num_guns = max(len(guns), max_num_guns) - - # Always calculate shot_index - the OBN template requires it - shot_index = np.empty(len(index_headers), dtype="uint32") - # Use .base if available (view of another array), otherwise use the array directly - base_array = index_headers.base if index_headers.base is not None else index_headers - index_headers = rfn.append_fields(base_array, "shot_index", shot_index) - - if geom_type == ShotGunGeometryType.B: - # Type B: shot points are interleaved across guns, divide to get dense index - for line_val in unique_lines: - line_idxs = np.where(index_headers[line_field][:] == line_val) - index_headers["shot_index"][line_idxs] = np.floor(index_headers["shot_point"][line_idxs] / max_num_guns) - # Make shot index zero-based PER line - index_headers["shot_index"][line_idxs] -= index_headers["shot_index"][line_idxs].min() - else: - # Type A: shot points are already unique per gun, create 0-based index from unique values - for line_val in unique_lines: - line_idxs = np.where(index_headers[line_field][:] == line_val) - shot_points = index_headers["shot_point"][line_idxs] - # np.unique returns sorted values; searchsorted maps each shot_point to its 0-based index - unique_shots = np.unique(shot_points) - index_headers["shot_index"][line_idxs] = np.searchsorted(unique_shots, shot_points) - - return index_headers - - -class AutoShotWrap(GridOverrideCommand): - """Automatic shot index calculation from interleaved shot points for Streamer templates. - - This grid override handles multi-gun acquisition where shot points may be - interleaved across guns. It calculates a dense shot_index from sparse shot_point values. - - Supported Templates: - - Seismic3DStreamerFieldRecordsTemplate: Uses sail_line, requires cable/channel - - Note: - For OBN templates, use CalculateShotIndex instead. - - Attributes: - required_parameters: Set of required parameters (None for this class). - """ - - required_parameters = None - - @property - def required_keys(self) -> set: - """Return required header keys for streamer shot index calculation.""" - return {"sail_line", "gun", "shot_point", "cable", "channel"} - - def validate( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate | None = None, - ) -> None: - """Validate if this transform should run on the type of data.""" - # Import here to avoid circular imports at module load time + if config.auto_shot_wrap: + # Lazy import: see ``_resolve_synthesize_dims`` for the cycle rationale. from mdio.builder.templates.seismic_3d_streamer_field import ( # noqa: PLC0415 Seismic3DStreamerFieldRecordsTemplate, ) - if template is None: - msg = "AutoShotWrap requires a template." - raise TypeError(msg) - if not isinstance(template, Seismic3DStreamerFieldRecordsTemplate): + actual = type(template).__name__ if template is not None else "None" msg = ( - f"AutoShotWrap only supports Seismic3DStreamerFieldRecordsTemplate, " - f"got {type(template).__name__}. For OBN templates, use CalculateShotIndex." + f"auto_shot_wrap only supports Seismic3DStreamerFieldRecordsTemplate, " + f"got {actual}. For OBN templates, use calculate_shot_index." ) raise TypeError(msg) - index_names = set(index_headers.dtype.names) - if not self.required_keys.issubset(index_names): - raise GridOverrideKeysError(self.name, self.required_keys) - self.check_required_params(grid_overrides) + if config.calculate_shot_index: + from mdio.builder.templates.seismic_3d_obn import Seismic3DObnReceiverGathersTemplate # noqa: PLC0415 - def transform( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate, - ) -> NDArray: - """Perform the grid transform to calculate shot_index from shot_point.""" - self.validate(index_headers, grid_overrides, template) - - line_field = "sail_line" - result = analyze_lines_for_guns(index_headers, line_field=line_field) - unique_lines, unique_guns_per_line, geom_type = result - logger.info("Ingesting streamer dataset as shot type: %s", geom_type.name) - - max_num_guns = 1 - for line_val in unique_lines: - guns = unique_guns_per_line[str(line_val)] - logger.info("%s: %s has guns: %s", line_field, line_val, guns) - max_num_guns = max(len(guns), max_num_guns) - - # Only calculate shot_index when shot points are interleaved across guns (Type B) - if geom_type == ShotGunGeometryType.B: - shot_index = np.empty(len(index_headers), dtype="uint32") - # Use .base if available (view of another array), otherwise use the array directly - base_array = index_headers.base if index_headers.base is not None else index_headers - index_headers = rfn.append_fields(base_array, "shot_index", shot_index) - - for line_val in unique_lines: - line_idxs = np.where(index_headers[line_field][:] == line_val) - index_headers["shot_index"][line_idxs] = np.floor(index_headers["shot_point"][line_idxs] / max_num_guns) - # Make shot index zero-based PER line - index_headers["shot_index"][line_idxs] -= index_headers["shot_index"][line_idxs].min() - - return index_headers + if not isinstance(template, Seismic3DObnReceiverGathersTemplate): + actual = type(template).__name__ if template is not None else "None" + msg = f"calculate_shot_index only supports Seismic3DObnReceiverGathersTemplate, got {actual}." + raise TypeError(msg) class GridOverrider: - """Executor for grid overrides. + """Legacy facade that adapts the dict-based v1.1 API onto :class:`IndexStrategyRegistry`. - We support a certain type of grid overrides, and they have to be implemented following the - ABC's in this module. - - This class applies the grid overrides if needed. + Existing callers (notably :func:`mdio.segy.utilities.get_grid_plan`) still build a + legacy ``dict`` of grid overrides and call :meth:`run`. This class translates the dict + into a typed :class:`GridOverrides`, dispatches to the appropriate + :class:`IndexStrategy`, and returns the ``(headers, names, chunksize)`` tuple shape + those callers depend on. It will be removed once all callers move to the typed API. """ def __init__(self) -> None: - self.commands = { - "AutoChannelWrap": AutoChannelWrap(), - "AutoShotWrap": AutoShotWrap(), - "CalculateShotIndex": CalculateShotIndex(), - "NonBinned": NonBinned(), - "HasDuplicates": DuplicateIndex(), - } - - self.parameters = self.get_allowed_parameters() - - def get_allowed_parameters(self) -> set: - """Get list of allowed parameters from the allowed commands.""" - parameters = set() - for command in self.commands.values(): - if command.required_parameters is None: - continue - - parameters.update(command.required_parameters) - - # Add optional parameters that are not strictly required but are valid - parameters.add("non_binned_dims") - - return parameters - - def _synthesize_obn_component( - self, - index_headers: HeaderArray, - template: AbstractDatasetTemplate | None, - ) -> HeaderArray: - """Synthesize component field for OBN template when missing from headers. - - OBN data may not have a component field in the SEG-Y headers (e.g., single-component - data). When using Seismic3DObnReceiverGathersTemplate and component is missing, - this method synthesizes it with a constant value of 1. - - Args: - index_headers: The parsed index headers from SEG-Y. - template: The dataset template. - - Returns: - Headers with component field added if applicable. - """ - # Import here to avoid circular imports at module load time - from mdio.builder.templates.seismic_3d_obn import Seismic3DObnReceiverGathersTemplate # noqa: PLC0415 - - if not isinstance(template, Seismic3DObnReceiverGathersTemplate): - return index_headers - - if "component" in index_headers.dtype.names: - return index_headers - - logger.warning( - "SEG-Y headers do not contain 'component' field required by template '%s'. " - "Synthesizing 'component' dimension with constant value 1 for all traces.", - template.name, - ) - synthetic_col = np.full(len(index_headers), 1, dtype=np.int32) - base_array = index_headers.base if index_headers.base is not None else index_headers - return rfn.append_fields(base_array, "component", synthetic_col, usemask=False) + self._registry = IndexStrategyRegistry() def run( self, index_headers: HeaderArray, index_names: Sequence[str], - grid_overrides: dict[str, bool], + grid_overrides: dict[str, Any] | None, chunksize: Sequence[int] | None = None, template: AbstractDatasetTemplate | None = None, - ) -> tuple[HeaderArray, tuple[str], tuple[int]]: - """Run grid overrides and return result.""" - # Synthesize component for OBN template if missing - index_headers = self._synthesize_obn_component(index_headers, template) + ) -> tuple[HeaderArray, tuple[str, ...], tuple[int, ...] | None]: + """Run the configured grid overrides and return updated headers/names/chunks. - for override in grid_overrides: - if override in self.parameters: - continue + Args: + index_headers: Parsed SEG-Y trace headers; structured numpy array. + index_names: Names of the index dimensions before any override is applied. + grid_overrides: Legacy dict of overrides (CamelCase keys). + chunksize: Optional chunk shape that may be expanded by overrides that add a + ``trace`` dimension. + template: Optional dataset template; used to identify coordinate fields and + to drive component synthesis for OBN. - if override not in self.commands: - raise GridOverrideUnknownError(override) + Returns: + Tuple of ``(transformed_headers, new_index_names, new_chunksize)``. The + chunksize tuple is ``None`` when the caller did not pass a chunksize. + + Raises: + GridOverrideUnknownError: When ``grid_overrides`` contains an unknown key. + GridOverrideMissingParameterError: When ``NonBinned`` is enabled without + ``chunksize`` or ``non_binned_dims``. + + Notes: + Header-precondition checks (``GridOverrideKeysError``) are delegated to + :meth:`IndexStrategy.validate_headers`; template-compatibility checks + (``TypeError``) are delegated to :func:`_validate_template_for_overrides`. + """ + grid_overrides = grid_overrides or {} + + field_names = set(GridOverrides.model_fields.keys()) + aliases = {field.alias for field in GridOverrides.model_fields.values() if field.alias} + valid_keys = field_names | aliases + for key in grid_overrides: + if key not in valid_keys: + raise GridOverrideUnknownError(key) + + config = GridOverrides.model_validate(grid_overrides) + + if config.non_binned: + missing: set[str] = set() + if config.chunksize is None: + missing.add("chunksize") + if not config.non_binned_dims: + missing.add("non_binned_dims") + if missing: + command = "NonBinned" + raise GridOverrideMissingParameterError(command, missing) + + _validate_template_for_overrides(config, template) + + synthesize_dims = _resolve_synthesize_dims(template) + strategy = self._registry.create_strategy( + grid_overrides=config, + synthesize_dims=synthesize_dims, + template=template, + ) + logger.debug("Selected grid override strategy: %s", strategy.name) - function = self.commands[override].transform - index_headers = function(index_headers, grid_overrides=grid_overrides, template=template) + strategy.validate_headers(index_headers) + new_headers = strategy.transform_headers(index_headers) - function = self.commands[override].transform_index_names - index_names = function(index_names) + new_names = list(index_names) + new_chunks = list(chunksize) if chunksize is not None else None - function = self.commands[override].transform_chunksize - chunksize = function(chunksize, grid_overrides=grid_overrides) + # Both NonBinned and HasDuplicates add a 'trace' dim at index -1; HasDuplicates + # always uses chunksize 1, NonBinned uses the user-supplied value. + if config.non_binned or config.has_duplicates: + new_names.append("trace") + if new_chunks is not None: + inserted_chunk = config.chunksize if config.non_binned else 1 + new_chunks.insert(-1, inserted_chunk) - return index_headers, index_names, chunksize + return ( + new_headers, + tuple(new_names), + tuple(new_chunks) if new_chunks is not None else None, + ) diff --git a/tests/unit/test_ingestion_index_strategies.py b/tests/unit/test_ingestion_index_strategies.py new file mode 100644 index 00000000..0cac710c --- /dev/null +++ b/tests/unit/test_ingestion_index_strategies.py @@ -0,0 +1,569 @@ +"""Unit tests for the v1.2 ingestion index strategies and the strategy registry. + +These tests exercise individual :class:`mdio.ingestion.index_strategies.IndexStrategy` +subclasses with synthetic structured numpy arrays (mimicking the shape semantics of +:class:`segy.arrays.HeaderArray`) so they remain fast and do not require any real SEG-Y +data. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +from mdio.builder.template_registry import TemplateRegistry +from mdio.ingestion.index_strategies import ChannelWrappingStrategy +from mdio.ingestion.index_strategies import ComponentSynthesisStrategy +from mdio.ingestion.index_strategies import CompositeStrategy +from mdio.ingestion.index_strategies import DuplicateHandlingStrategy +from mdio.ingestion.index_strategies import IndexStrategyRegistry +from mdio.ingestion.index_strategies import NonBinnedStrategy +from mdio.ingestion.index_strategies import RegularGridStrategy +from mdio.ingestion.index_strategies import ShotWrappingStrategy +from mdio.segy.exceptions import GridOverrideKeysError +from mdio.segy.geometry import GridOverrider +from mdio.segy.geometry import GridOverrides + + +def _make_struct(data: dict[str, np.ndarray]) -> np.ndarray: + """Build a 1-D structured array from a name -> 1-D array mapping.""" + names = list(data.keys()) + arrays = [data[name] for name in names] + n = len(arrays[0]) + dtype = np.dtype([(name, arr.dtype) for name, arr in zip(names, arrays, strict=True)]) + out = np.empty(n, dtype=dtype) + for name, arr in zip(names, arrays, strict=True): + out[name] = arr + return out + + +# --------------------------------------------------------------------------- +# IndexStrategyRegistry +# --------------------------------------------------------------------------- + + +class TestIndexStrategyRegistry: + """Selection rules for :class:`IndexStrategyRegistry.create_strategy`.""" + + def test_default_returns_regular_grid(self) -> None: + """No grid overrides and no synthesis -> regular grid.""" + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=None) + assert isinstance(strategy, RegularGridStrategy) + + def test_falsy_overrides_returns_regular_grid(self) -> None: + """A default ``GridOverrides()`` with all flags off must be treated as no-op.""" + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=GridOverrides()) + assert isinstance(strategy, RegularGridStrategy) + + def test_synthesize_dims_only(self) -> None: + """Synthesis-only configuration returns a single ComponentSynthesisStrategy.""" + strategy = IndexStrategyRegistry().create_strategy(synthesize_dims=("component",)) + assert isinstance(strategy, ComponentSynthesisStrategy) + assert strategy.synthesize_dims == ("component",) + + def test_non_binned_only(self) -> None: + """``non_binned`` -> NonBinnedStrategy with chunksize and excluded dims wired.""" + overrides = GridOverrides(non_binned=True, chunksize=64, non_binned_dims=["channel"]) + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=overrides) + assert isinstance(strategy, NonBinnedStrategy) + assert strategy.chunksize == 64 + assert strategy.non_binned_dims == ("channel",) + + def test_has_duplicates_only(self) -> None: + """``has_duplicates`` -> DuplicateHandlingStrategy.""" + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=GridOverrides(has_duplicates=True)) + assert isinstance(strategy, DuplicateHandlingStrategy) + + def test_non_binned_wins_over_has_duplicates(self) -> None: + """Both flags set -> NonBinned wins (matches v1.x semantics).""" + overrides = GridOverrides( + non_binned=True, + chunksize=8, + non_binned_dims=["channel"], + has_duplicates=True, + ) + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=overrides) + assert isinstance(strategy, NonBinnedStrategy) + + def test_composite_with_channel_wrap(self) -> None: + """Channel wrap + non-binned -> CompositeStrategy ordered for safe layering.""" + overrides = GridOverrides( + auto_channel_wrap=True, + non_binned=True, + chunksize=64, + non_binned_dims=["channel"], + ) + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=overrides) + assert isinstance(strategy, CompositeStrategy) + assert [s.name for s in strategy.strategies] == ["ChannelWrappingStrategy", "NonBinnedStrategy"] + + def test_synthesize_dims_runs_first(self) -> None: + """Synthesis must run before any strategy that may depend on the synthesized field.""" + overrides = GridOverrides(calculate_shot_index=True) + strategy = IndexStrategyRegistry().create_strategy( + grid_overrides=overrides, + synthesize_dims=("component",), + ) + assert isinstance(strategy, CompositeStrategy) + assert strategy.strategies[0].name == "ComponentSynthesisStrategy" + + def test_auto_shot_wrap_uses_sail_line(self) -> None: + """``auto_shot_wrap`` is the streamer flag: line_field=sail_line, no always-calc.""" + overrides = GridOverrides(auto_shot_wrap=True) + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=overrides) + assert isinstance(strategy, ShotWrappingStrategy) + assert strategy.line_field == "sail_line" + assert strategy.always_calculate is False + + def test_calculate_shot_index_uses_shot_line(self) -> None: + """``calculate_shot_index`` is the OBN flag: line_field=shot_line, always-calc.""" + overrides = GridOverrides(calculate_shot_index=True) + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=overrides) + assert isinstance(strategy, ShotWrappingStrategy) + assert strategy.line_field == "shot_line" + assert strategy.always_calculate is True + + def test_template_coord_names_propagate_to_duplicate_strategy(self) -> None: + """Template coordinates flow into the duplicate-handling strategy as exclusions.""" + template = TemplateRegistry().get("StreamerShotGathers3D") + strategy = IndexStrategyRegistry().create_strategy( + grid_overrides=GridOverrides(has_duplicates=True), + template=template, + ) + assert isinstance(strategy, DuplicateHandlingStrategy) + assert strategy.coord_fields == frozenset(template.coordinate_names) + + +# --------------------------------------------------------------------------- +# RegularGridStrategy +# --------------------------------------------------------------------------- + + +class TestRegularGridStrategy: + """Pass-through strategy used when no overrides are required.""" + + def test_returns_unique_dims(self) -> None: + """Each dim name maps to its sorted unique values.""" + headers = _make_struct( + { + "inline": np.array([1, 1, 2, 2], dtype=np.int32), + "crossline": np.array([10, 11, 10, 11], dtype=np.int32), + } + ) + dims = RegularGridStrategy().compute_dimensions(headers, ("inline", "crossline")) + assert [d.name for d in dims] == ["inline", "crossline"] + np.testing.assert_array_equal(dims[0].coords, [1, 2]) + np.testing.assert_array_equal(dims[1].coords, [10, 11]) + + def test_unknown_dim_silently_skipped(self) -> None: + """Names absent from the header dtype are dropped, matching v1.1 behavior.""" + headers = _make_struct({"inline": np.array([1, 2], dtype=np.int32)}) + dims = RegularGridStrategy().compute_dimensions(headers, ("inline", "missing")) + assert [d.name for d in dims] == ["inline"] + + def test_default_required_keys_empty_and_validate_is_noop(self) -> None: + """Strategies that don't override ``required_keys`` must not block any headers.""" + headers = _make_struct({"inline": np.array([1], dtype=np.int32)}) + strategy = RegularGridStrategy() + assert strategy.required_keys == frozenset() + strategy.validate_headers(headers) # must not raise + + +# --------------------------------------------------------------------------- +# DuplicateHandlingStrategy +# --------------------------------------------------------------------------- + + +class TestDuplicateHandlingStrategy: + """Counter-based ``trace`` field for duplicated index tuples.""" + + def test_appends_per_tuple_counter(self) -> None: + """Each (inline, crossline) tuple gets a 1-based duplicate counter.""" + headers = _make_struct( + { + "inline": np.array([1, 1, 1, 2], dtype=np.int32), + "crossline": np.array([10, 10, 11, 10], dtype=np.int32), + } + ) + out = DuplicateHandlingStrategy().transform_headers(headers) + assert "trace" in out.dtype.names + # (1,10) appears twice -> {1, 2}; (1,11) once -> {1}; (2,10) once -> {1}. + np.testing.assert_array_equal(np.sort(out["trace"]), [1, 1, 1, 2]) + + def test_coord_fields_excluded_from_grouping(self) -> None: + """Coordinate fields must not influence the duplicate counter.""" + headers = _make_struct( + { + "inline": np.array([1, 1, 2, 2], dtype=np.int32), + "crossline": np.array([10, 10, 11, 11], dtype=np.int32), + # 'cdp_x' varies per row but is a coord, so the counter must ignore it. + "cdp_x": np.array([100.0, 100.5, 200.0, 200.5], dtype=np.float64), + } + ) + strategy = DuplicateHandlingStrategy(coord_fields=("cdp_x",)) + out = strategy.transform_headers(headers) + # Each (inline, crossline) tuple appears twice -> counters cycle 1..2. + np.testing.assert_array_equal(out["trace"], [1, 2, 1, 2]) + + def test_excluded_fields_dropped_from_grouping(self) -> None: + """Caller-declared excluded fields (e.g. non_binned_dims) are not part of the key.""" + headers = _make_struct( + { + "shot_point": np.array([1, 1, 1], dtype=np.int32), + "channel": np.array([1, 2, 3], dtype=np.int32), + } + ) + strategy = DuplicateHandlingStrategy(excluded_fields=("channel",)) + out = strategy.transform_headers(headers) + # Without excluding 'channel', counters would all be 1; excluding it groups by + # shot_point alone so each row in the same shot_point gets a fresh counter. + np.testing.assert_array_equal(out["trace"], [1, 2, 3]) + + +# --------------------------------------------------------------------------- +# NonBinnedStrategy +# --------------------------------------------------------------------------- + + +class TestNonBinnedStrategy: + """``NonBinned`` is the duplicate counter wired with explicit collapse dims.""" + + def test_chunksize_and_dims_recorded(self) -> None: + """Constructor stores both for the ``GridOverrider`` shim to use.""" + strategy = NonBinnedStrategy(chunksize=4, non_binned_dims=("channel",)) + assert strategy.chunksize == 4 + assert strategy.non_binned_dims == ("channel",) + + def test_collapse_dim_excluded_from_counter(self) -> None: + """The non-binned dim must NOT participate in the duplicate counter grouping.""" + headers = _make_struct( + { + "shot_point": np.array([1, 1, 2, 2], dtype=np.int32), + "cable": np.array([1, 2, 1, 2], dtype=np.int32), + "channel": np.array([10, 11, 10, 11], dtype=np.int32), + } + ) + strategy = NonBinnedStrategy(chunksize=4, non_binned_dims=("channel",)) + out = strategy.transform_headers(headers) + # Each (shot_point, cable) tuple appears once -> counters all equal 1. + assert "trace" in out.dtype.names + np.testing.assert_array_equal(out["trace"], [1, 1, 1, 1]) + + def test_coord_fields_also_excluded(self) -> None: + """Both template coords and non_binned_dims are removed from the grouping key.""" + headers = _make_struct( + { + "shot_point": np.array([1, 1, 2, 2], dtype=np.int32), + "channel": np.array([1, 2, 1, 2], dtype=np.int32), + "cdp_x": np.array([10.0, 20.0, 30.0, 40.0], dtype=np.float64), + } + ) + strategy = NonBinnedStrategy( + chunksize=4, + non_binned_dims=("channel",), + coord_fields=("cdp_x",), + ) + out = strategy.transform_headers(headers) + # Grouping is by shot_point only -> each shot_point has 2 rows -> counters {1, 2}. + np.testing.assert_array_equal(out["trace"], [1, 2, 1, 2]) + + +# --------------------------------------------------------------------------- +# ChannelWrappingStrategy +# --------------------------------------------------------------------------- + + +class TestChannelWrappingStrategy: + """Streamer Type-A vs Type-B detection and channel rebasing.""" + + def test_type_a_pass_through(self) -> None: + """Type A (per-cable channel numbering with overlap) -> headers untouched.""" + headers = _make_struct( + { + "cable": np.array([1, 1, 2, 2], dtype=np.int32), + "channel": np.array([1, 2, 1, 2], dtype=np.int32), + } + ) + out = ChannelWrappingStrategy().transform_headers(headers) + np.testing.assert_array_equal(out["channel"], [1, 2, 1, 2]) + + def test_type_b_renumbers_per_cable(self) -> None: + """Type B (sequential numbering across cables) -> rebased to 1..N per cable.""" + headers = _make_struct( + { + "cable": np.array([1, 1, 2, 2], dtype=np.int32), + "channel": np.array([1, 2, 3, 4], dtype=np.int32), + } + ) + out = ChannelWrappingStrategy().transform_headers(headers) + # Cable 1: 1,2 -> 1,2; cable 2: 3,4 -> 1,2. + np.testing.assert_array_equal(out["channel"], [1, 2, 1, 2]) + + def test_required_keys(self) -> None: + """Channel wrap declares the cable-channel-shot triplet as preconditions.""" + assert ChannelWrappingStrategy().required_keys == frozenset({"shot_point", "cable", "channel"}) + + def test_validate_headers_raises_when_field_missing(self) -> None: + """Missing ``cable`` -> :class:`GridOverrideKeysError`, not a deeper numpy crash.""" + headers = _make_struct( + { + "shot_point": np.array([1, 2], dtype=np.int32), + "channel": np.array([1, 2], dtype=np.int32), + } + ) + strategy = ChannelWrappingStrategy() + with pytest.raises(GridOverrideKeysError, match="ChannelWrappingStrategy"): + strategy.validate_headers(headers) + + +# --------------------------------------------------------------------------- +# ShotWrappingStrategy +# --------------------------------------------------------------------------- + + +class TestShotWrappingStrategy: + """Shot-index derivation for both streamer (sail_line) and OBN (shot_line).""" + + def test_type_b_streamer_emits_shot_index(self) -> None: + """Sail line 1 with two interleaved guns -> dense shot_index per line.""" + headers = _make_struct( + { + "sail_line": np.array([1, 1, 1, 1], dtype=np.int32), + "gun": np.array([1, 2, 1, 2], dtype=np.int32), + "shot_point": np.array([1, 2, 3, 4], dtype=np.int32), + } + ) + out = ShotWrappingStrategy(line_field="sail_line").transform_headers(headers) + assert "shot_index" in out.dtype.names + # floor(shot_point / 2) zero-based per line: 0, 1, 1, 2. + np.testing.assert_array_equal(out["shot_index"], [0, 1, 1, 2]) + + def test_type_a_streamer_skipped_without_always_calculate(self) -> None: + """Type A streamer geometry produces no shot_index unless always_calculate=True.""" + headers = _make_struct( + { + "sail_line": np.array([1, 1, 1, 1, 1, 1], dtype=np.int32), + "gun": np.array([1, 1, 1, 2, 2, 2], dtype=np.int32), + "shot_point": np.array([1, 2, 3, 1, 2, 3], dtype=np.int32), + } + ) + out = ShotWrappingStrategy(line_field="sail_line").transform_headers(headers) + assert "shot_index" not in out.dtype.names + + def test_type_a_obn_always_calculates(self) -> None: + """OBN forces shot_index calculation; Type A uses dense per-line searchsorted.""" + headers = _make_struct( + { + "shot_line": np.array([1, 1, 1, 1, 1, 1, 1, 1], dtype=np.int32), + "gun": np.array([1, 1, 1, 1, 2, 2, 2, 2], dtype=np.int32), + "shot_point": np.array([1, 2, 3, 4, 1, 2, 3, 4], dtype=np.int32), + } + ) + out = ShotWrappingStrategy(line_field="shot_line", always_calculate=True).transform_headers(headers) + assert "shot_index" in out.dtype.names + np.testing.assert_array_equal(out["shot_index"], [0, 1, 2, 3, 0, 1, 2, 3]) + + def test_obn_multiline_type_a_processes_all_lines(self) -> None: + """Regression: Type A detection on line 1 must not mask later lines.""" + headers = _make_struct( + { + "shot_line": np.array([1, 1, 2, 2, 3, 3], dtype=np.int32), + "gun": np.array([1, 2, 1, 2, 1, 2], dtype=np.int32), + "shot_point": np.array([1, 2, 1, 2, 1, 2], dtype=np.int32), + } + ) + out = ShotWrappingStrategy(line_field="shot_line", always_calculate=True).transform_headers(headers) + # Each line gets independent dense per-line indices. + np.testing.assert_array_equal(out["shot_index"], [0, 1, 0, 1, 0, 1]) + + def test_required_keys_sail_line(self) -> None: + """Streamer variant requires the streamer-cable-channel headers in addition to shot fields.""" + strategy = ShotWrappingStrategy(line_field="sail_line") + assert strategy.required_keys == frozenset({"sail_line", "gun", "shot_point", "cable", "channel"}) + + def test_required_keys_shot_line(self) -> None: + """OBN variant deliberately omits cable/channel from required keys.""" + strategy = ShotWrappingStrategy(line_field="shot_line", always_calculate=True) + assert strategy.required_keys == frozenset({"shot_line", "gun", "shot_point"}) + + def test_validate_headers_raises_for_obn_when_missing_gun(self) -> None: + """Missing ``gun`` on the OBN path -> :class:`GridOverrideKeysError`.""" + headers = _make_struct( + { + "shot_line": np.array([1, 1], dtype=np.int32), + "shot_point": np.array([1, 2], dtype=np.int32), + } + ) + strategy = ShotWrappingStrategy(line_field="shot_line", always_calculate=True) + with pytest.raises(GridOverrideKeysError, match="ShotWrappingStrategy"): + strategy.validate_headers(headers) + + +# --------------------------------------------------------------------------- +# ComponentSynthesisStrategy +# --------------------------------------------------------------------------- + + +class TestComponentSynthesisStrategy: + """Synthesize template-required dimensions when missing from the SEG-Y headers.""" + + def test_synthesizes_missing_field(self) -> None: + """Missing field is added with constant value 1 for every row.""" + headers = _make_struct({"receiver": np.array([1, 2, 3], dtype=np.int32)}) + out = ComponentSynthesisStrategy(("component",)).transform_headers(headers) + assert "component" in out.dtype.names + np.testing.assert_array_equal(out["component"], [1, 1, 1]) + + def test_existing_field_left_alone(self) -> None: + """If the field is already present, the existing values are preserved.""" + headers = _make_struct( + { + "receiver": np.array([1, 2, 3], dtype=np.int32), + "component": np.array([2, 3, 4], dtype=np.uint8), + } + ) + out = ComponentSynthesisStrategy(("component",)).transform_headers(headers) + np.testing.assert_array_equal(out["component"], [2, 3, 4]) + + +# --------------------------------------------------------------------------- +# CompositeStrategy +# --------------------------------------------------------------------------- + + +class TestCompositeStrategy: + """Strategy chaining with deterministic execution order.""" + + def test_requires_at_least_one_strategy(self) -> None: + """An empty composite is a programming error and must raise.""" + with pytest.raises(ValueError, match="at least one strategy"): + CompositeStrategy([]) + + def test_strategies_run_in_order(self) -> None: + """Synthesis must produce the field that the next strategy then duplicates.""" + headers = _make_struct( + { + "shot_point": np.array([1, 1, 2, 2], dtype=np.int32), + "channel": np.array([1, 2, 1, 2], dtype=np.int32), + } + ) + composite = CompositeStrategy( + [ + ComponentSynthesisStrategy(("component",)), + NonBinnedStrategy(chunksize=4, non_binned_dims=("channel",)), + ] + ) + out = composite.transform_headers(headers) + assert "component" in out.dtype.names + assert "trace" in out.dtype.names + + def test_progressive_validation_raises_for_first_unsatisfied_child(self) -> None: + """Composite must surface a child's required-keys failure as :class:`GridOverrideKeysError`. + + Channel wrap needs ``cable``; ``RegularGridStrategy`` runs first and is a no-op, + so the composite reaches channel wrap with the same incomplete headers and must + raise rather than crash inside numpy. + """ + headers = _make_struct( + { + "shot_point": np.array([1, 2], dtype=np.int32), + "channel": np.array([1, 2], dtype=np.int32), + } + ) + composite = CompositeStrategy([RegularGridStrategy(), ChannelWrappingStrategy()]) + with pytest.raises(GridOverrideKeysError, match="ChannelWrappingStrategy"): + composite.transform_headers(headers) + + +# --------------------------------------------------------------------------- +# GridOverrider template-compatibility checks (shim level) +# --------------------------------------------------------------------------- + + +class TestGridOverriderTemplateValidation: + """Restore v1.1's template-type guards for shot-wrapping overrides. + + ``AutoShotWrap`` was streamer-only and ``CalculateShotIndex`` was OBN-only; pairing + either with the wrong template silently produced incorrect shot indices. The shim + raises :class:`TypeError` early so misconfigurations fail loudly at the API boundary. + """ + + def test_auto_shot_wrap_rejects_obn_template(self) -> None: + """Streamer override + OBN template -> TypeError, no transform run.""" + headers = _make_struct( + { + "sail_line": np.array([1, 1], dtype=np.int32), + "gun": np.array([1, 2], dtype=np.int32), + "shot_point": np.array([1, 2], dtype=np.int32), + "cable": np.array([1, 1], dtype=np.int32), + "channel": np.array([1, 2], dtype=np.int32), + } + ) + template = TemplateRegistry().get("ObnReceiverGathers3D") + with pytest.raises(TypeError, match="auto_shot_wrap.*Seismic3DStreamerFieldRecordsTemplate"): + GridOverrider().run( + headers, + index_names=("sail_line", "gun", "shot_point"), + grid_overrides={"AutoShotWrap": True}, + template=template, + ) + + def test_calculate_shot_index_rejects_streamer_template(self) -> None: + """OBN override + streamer template -> TypeError.""" + headers = _make_struct( + { + "shot_line": np.array([1, 1], dtype=np.int32), + "gun": np.array([1, 2], dtype=np.int32), + "shot_point": np.array([1, 2], dtype=np.int32), + } + ) + template = TemplateRegistry().get("StreamerShotGathers3D") + with pytest.raises(TypeError, match="calculate_shot_index.*Seismic3DObnReceiverGathersTemplate"): + GridOverrider().run( + headers, + index_names=("shot_line", "gun", "shot_point"), + grid_overrides={"CalculateShotIndex": True}, + template=template, + ) + + def test_shot_wrap_without_template_raises(self) -> None: + """Omitting the template is the same misconfiguration as passing the wrong type.""" + headers = _make_struct( + { + "sail_line": np.array([1], dtype=np.int32), + "gun": np.array([1], dtype=np.int32), + "shot_point": np.array([1], dtype=np.int32), + "cable": np.array([1], dtype=np.int32), + "channel": np.array([1], dtype=np.int32), + } + ) + with pytest.raises(TypeError, match="auto_shot_wrap"): + GridOverrider().run( + headers, + index_names=("sail_line", "gun", "shot_point"), + grid_overrides={"AutoShotWrap": True}, + template=None, + ) + + def test_header_keys_missing_raises_via_shim(self) -> None: + """``AutoShotWrap`` + correct template but missing ``cable`` -> :class:`GridOverrideKeysError`. + + Uses the ``StreamerFieldRecords3D`` template (the only template v1.1's + ``AutoShotWrap`` accepts) so the template-type check passes and we exercise the + header-key validator instead. + """ + headers = _make_struct( + { + "sail_line": np.array([1, 1], dtype=np.int32), + "gun": np.array([1, 2], dtype=np.int32), + "shot_point": np.array([1, 2], dtype=np.int32), + } + ) + template = TemplateRegistry().get("StreamerFieldRecords3D") + with pytest.raises(GridOverrideKeysError, match="ShotWrappingStrategy"): + GridOverrider().run( + headers, + index_names=("sail_line", "gun", "shot_point"), + grid_overrides={"AutoShotWrap": True}, + template=template, + ) From ba2d7114e5a69e2a04b84140dc26f560a8b44f06 Mon Sep 17 00:00:00 2001 From: BrianMichell Date: Wed, 27 May 2026 16:28:31 +0000 Subject: [PATCH 2/2] pre-commit --- src/mdio/ingestion/index_strategies.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mdio/ingestion/index_strategies.py b/src/mdio/ingestion/index_strategies.py index bec58b50..ede35e39 100644 --- a/src/mdio/ingestion/index_strategies.py +++ b/src/mdio/ingestion/index_strategies.py @@ -87,9 +87,7 @@ def compute_dimensions(self, headers: HeaderArray, dim_names: tuple[str, ...]) - ``GridOverrider`` post-processing step. """ return [ - Dimension(coords=np.unique(headers[name]), name=name) - for name in dim_names - if name in headers.dtype.names + Dimension(coords=np.unique(headers[name]), name=name) for name in dim_names if name in headers.dtype.names ] @property