diff --git a/src/mdio/ingestion/index_strategies.py b/src/mdio/ingestion/index_strategies.py new file mode 100644 index 00000000..ede35e39 --- /dev/null +++ b/src/mdio/ingestion/index_strategies.py @@ -0,0 +1,418 @@ +"""Composable index strategies for transforming SEG-Y headers into indexable dimensions. + +This module replaces the monolithic :class:`mdio.segy.geometry.GridOverrider` command +dispatch with a small set of single-responsibility :class:`IndexStrategy` objects that can +be composed via :class:`CompositeStrategy`. + +Strategies are selected by :class:`IndexStrategyRegistry` from the typed +:class:`mdio.segy.geometry.GridOverrides` configuration plus optional template hints. The +public contract preserved by :class:`mdio.segy.geometry.GridOverrider` (a thin shim around +this module) keeps end-to-end ingestion behavior identical to v1.1.x. +""" + +from __future__ import annotations + +import logging +from abc import ABC +from abc import abstractmethod +from typing import TYPE_CHECKING + +import numpy as np +from numpy.lib import recfunctions as rfn + +from mdio.core import Dimension +from mdio.ingestion.segy.header_analysis import ShotGunGeometryType +from mdio.ingestion.segy.header_analysis import StreamerShotGeometryType +from mdio.ingestion.segy.header_analysis import analyze_lines_for_guns +from mdio.ingestion.segy.header_analysis import analyze_non_indexed_headers +from mdio.ingestion.segy.header_analysis import analyze_streamer_headers +from mdio.segy.exceptions import GridOverrideKeysError + +if TYPE_CHECKING: + from collections.abc import Iterable + + from numpy.typing import DTypeLike + from segy.arrays import HeaderArray + + from mdio.builder.templates.base import AbstractDatasetTemplate + from mdio.segy.geometry import GridOverrides + +logger = logging.getLogger(__name__) + + +class IndexStrategy(ABC): + """Abstract base for header indexing strategies. + + A strategy transforms a raw header array (e.g. add or rebase fields) and computes + the resulting :class:`Dimension` list. Strategies are composable through + :class:`CompositeStrategy`. The default :meth:`compute_dimensions` builds dimensions + from unique header values; subclasses override only when they need different + semantics (currently just :class:`CompositeStrategy`). + + Subclasses with header preconditions set :attr:`required_keys` so the shim and + :class:`CompositeStrategy` can raise :class:`GridOverrideKeysError` with a clear + "missing fields X, Y, Z" message before numpy fails on a deeper key lookup. + """ + + @property + def required_keys(self) -> frozenset[str]: + """Header field names that must be present before :meth:`transform_headers` runs. + + Empty by default. Override on subclasses whose transform indexes specific fields. + """ + return frozenset() + + def validate_headers(self, headers: HeaderArray) -> None: + """Raise :class:`GridOverrideKeysError` if any required header field is missing. + + Callers (the :class:`mdio.segy.geometry.GridOverrider` shim and + :class:`CompositeStrategy`) invoke this before each transform so failure points + at the user-facing override name rather than at a numpy structured-array key error. + """ + required = self.required_keys + if not required: + return + present = set(headers.dtype.names or ()) + if not required.issubset(present): + raise GridOverrideKeysError(self.name, set(required)) + + @abstractmethod + def transform_headers(self, headers: HeaderArray) -> HeaderArray: + """Return a new header array with this strategy's transformation applied.""" + + def compute_dimensions(self, headers: HeaderArray, dim_names: tuple[str, ...]) -> list[Dimension]: + """Build one :class:`Dimension` per requested name from unique header values. + + Names absent from ``headers.dtype.names`` are silently skipped, matching the v1.1 + ``GridOverrider`` post-processing step. + """ + return [ + Dimension(coords=np.unique(headers[name]), name=name) for name in dim_names if name in headers.dtype.names + ] + + @property + def name(self) -> str: + """Return the strategy's class name; useful for logging and tests.""" + return self.__class__.__name__ + + +class RegularGridStrategy(IndexStrategy): + """Default strategy: headers untouched, dimensions are unique values per name.""" + + def transform_headers(self, headers: HeaderArray) -> HeaderArray: + """Pass headers through unchanged.""" + return headers + + +class DuplicateHandlingStrategy(IndexStrategy): + """Disambiguate duplicate index tuples by appending a per-tuple ``trace`` counter. + + Mirrors the v1.1 ``DuplicateIndex`` command: count occurrences of each unique + combination of dimension fields (excluding coordinate fields and any caller-declared + ``excluded_fields``), then attach the resulting 1-based counter as a new ``trace`` field + on the original headers. + + Args: + coord_fields: Names of header fields that are template coordinates and must be + excluded from the dimension grouping (their values vary independently of the + grid index). + excluded_fields: Additional fields to exclude from grouping. Used by + :class:`NonBinnedStrategy` to keep the explicit ``non_binned_dims`` from + polluting the per-tuple counter. + dtype: NumPy dtype for the appended ``trace`` counter. + """ + + def __init__( + self, + coord_fields: Iterable[str] = (), + excluded_fields: Iterable[str] = (), + dtype: DTypeLike = np.int16, + ) -> None: + self.coord_fields = frozenset(coord_fields) + self.excluded_fields = frozenset(excluded_fields) + self.dtype = dtype + + def _dim_fields(self, headers: HeaderArray) -> list[str]: + """Header field names that participate in the duplicate grouping.""" + return [ + name + for name in headers.dtype.names + if name != "trace" and name not in self.coord_fields and name not in self.excluded_fields + ] + + def transform_headers(self, headers: HeaderArray) -> HeaderArray: + """Append a per-(dim-tuple) ``trace`` counter to ``headers``.""" + dim_fields = self._dim_fields(headers) + dim_headers = headers[dim_fields] if dim_fields else headers + with_trace = analyze_non_indexed_headers(dim_headers, dtype=self.dtype) + + if with_trace is None or "trace" not in with_trace.dtype.names: + return headers + + trace_values = np.array(with_trace["trace"]) + return rfn.append_fields(headers, "trace", trace_values, usemask=False) + + +class NonBinnedStrategy(DuplicateHandlingStrategy): + """Collapse selected non-binned dimensions into a single ``trace`` dimension. + + Inherits the per-tuple ``trace`` counter from :class:`DuplicateHandlingStrategy` and + captures ``chunksize`` so the :class:`mdio.segy.geometry.GridOverrider` shim can size + the new ``trace`` chunk correctly. + + Args: + chunksize: Chunk size to assign to the ``trace`` dimension. The strategy itself + does not apply this value; the shim uses it when rewriting the chunksize tuple. + non_binned_dims: Header fields collapsed into ``trace``. They are excluded from + the duplicate grouping so the counter only varies along the remaining dims. + coord_fields: Template coordinate names to exclude from grouping. + dtype: NumPy dtype for the appended ``trace`` counter. + """ + + def __init__( + self, + chunksize: int, + non_binned_dims: Iterable[str], + coord_fields: Iterable[str] = (), + dtype: DTypeLike = np.int16, + ) -> None: + non_binned_dims = tuple(non_binned_dims) + super().__init__( + coord_fields=coord_fields, + excluded_fields=non_binned_dims, + dtype=dtype, + ) + self.chunksize = chunksize + self.non_binned_dims = non_binned_dims + + +class ChannelWrappingStrategy(IndexStrategy): + """Renumber streamer channels per cable when geometry is Type B. + + Detects whether channel numbering is per-cable (Type A; pass-through) or sequential + across cables (Type B; rebase to 1..N per cable). Mirrors the v1.1 ``AutoChannelWrap`` + command. + """ + + @property + def required_keys(self) -> frozenset[str]: + """Streamer channel detection needs the cable-channel-shot triplet.""" + return frozenset({"shot_point", "cable", "channel"}) + + def transform_headers(self, headers: HeaderArray) -> HeaderArray: + """Rebase ``channel`` per cable for Type B geometry; pass through for Type A.""" + unique_cables, cable_chan_min, cable_chan_max, geom_type = analyze_streamer_headers(headers) + + logger.info("Ingesting dataset as %s", geom_type.name) + for cable, chan_min, chan_max in zip(unique_cables, cable_chan_min, cable_chan_max, strict=True): + logger.info("Cable: %s has min chan: %s and max chan: %s", cable, chan_min, chan_max) + + if geom_type != StreamerShotGeometryType.B: + return headers + + for idx, cable in enumerate(unique_cables): + cable_idxs = np.where(headers["cable"][:] == cable) + headers["channel"][cable_idxs] = headers["channel"][cable_idxs] - cable_chan_min[idx] + 1 + + return headers + + +class ShotWrappingStrategy(IndexStrategy): + """Derive a dense ``shot_index`` field from sparse or interleaved ``shot_point`` values. + + Replaces the v1.1 ``AutoShotWrap`` (streamer) and ``CalculateShotIndex`` (OBN) + commands. The two callers differ only in: + + * ``line_field`` -- ``sail_line`` for streamer, ``shot_line`` for OBN. + * ``always_calculate`` -- streamer skips the transform entirely for Type A geometries + (per-gun shot points are already dense), OBN always emits ``shot_index`` because the + template declares it as a calculated dimension. + + Args: + line_field: Header field used to group shots into independent lines. + always_calculate: When ``True``, emit ``shot_index`` for every geometry type. For + Type A this builds a 0-based ``np.searchsorted`` over sorted unique shot + points per line. + """ + + _STREAMER_LINE_FIELD = "sail_line" + + def __init__(self, line_field: str, always_calculate: bool = False) -> None: + self.line_field = line_field + self.always_calculate = always_calculate + + @property + def required_keys(self) -> frozenset[str]: + """Streamer (``sail_line``) needs cable+channel too; OBN (``shot_line``) does not. + + Mirrors the v1.1 split between ``AutoShotWrap.required_keys`` and + ``CalculateShotIndex.required_keys``. + """ + base = {self.line_field, "gun", "shot_point"} + if self.line_field == self._STREAMER_LINE_FIELD: + base |= {"cable", "channel"} + return frozenset(base) + + def transform_headers(self, headers: HeaderArray) -> HeaderArray: + """Append ``shot_index`` derived from ``shot_point`` per line.""" + unique_lines, unique_guns_per_line, geom_type = analyze_lines_for_guns(headers, line_field=self.line_field) + + logger.info("Ingesting dataset as shot type: %s (line_field=%s)", geom_type.name, self.line_field) + + max_num_guns = 1 + for line_val in unique_lines: + guns = unique_guns_per_line[str(line_val)] + logger.info("%s: %s has guns: %s", self.line_field, line_val, guns) + max_num_guns = max(len(guns), max_num_guns) + + if geom_type == ShotGunGeometryType.A and not self.always_calculate: + return headers + + shot_index = np.empty(len(headers), dtype="uint32") + # `.base` is None for non-view arrays; fall back to the array itself. + base_array = headers.base if headers.base is not None else headers + headers = rfn.append_fields(base_array, "shot_index", shot_index, usemask=False) + + if geom_type == ShotGunGeometryType.B: + for line_val in unique_lines: + line_idxs = np.where(headers[self.line_field][:] == line_val) + headers["shot_index"][line_idxs] = np.floor(headers["shot_point"][line_idxs] / max_num_guns) + headers["shot_index"][line_idxs] -= headers["shot_index"][line_idxs].min() + else: + for line_val in unique_lines: + line_idxs = np.where(headers[self.line_field][:] == line_val) + shot_points = headers["shot_point"][line_idxs] + unique_shots = np.unique(shot_points) + headers["shot_index"][line_idxs] = np.searchsorted(unique_shots, shot_points) + + return headers + + +class ComponentSynthesisStrategy(IndexStrategy): + """Synthesize template-required dimension fields that are absent from the headers. + + Currently used to fill the ``component`` dimension with a constant value of 1 for + OBN templates whose SEG-Y spec does not include a component header. Mirrors the + v1.1 ``GridOverrider._synthesize_obn_component`` behavior. + + Args: + synthesize_dims: Names of dimension fields to synthesize when missing. + """ + + def __init__(self, synthesize_dims: Iterable[str]) -> None: + self.synthesize_dims = tuple(synthesize_dims) + + def transform_headers(self, headers: HeaderArray) -> HeaderArray: + """Append constant-1 fields for any synthesize_dims not already present.""" + for dim in self.synthesize_dims: + if dim in headers.dtype.names: + continue + logger.warning( + "SEG-Y headers do not contain '%s' field required by template; " + "synthesizing dimension with constant value 1 for all traces.", + dim, + ) + comp_array = np.ones(len(headers), dtype=np.uint8) + base_array = headers.base if headers.base is not None else headers + headers = rfn.append_fields(base_array, dim, comp_array, usemask=False) + return headers + + +class CompositeStrategy(IndexStrategy): + """Apply multiple strategies in order; each transform feeds the next. + + Dimension computation is delegated to the final strategy on the assumption it is + aware of all preceding header transformations. + """ + + def __init__(self, strategies: list[IndexStrategy]) -> None: + if not strategies: + msg = "CompositeStrategy requires at least one strategy" + raise ValueError(msg) + self.strategies = strategies + + def transform_headers(self, headers: HeaderArray) -> HeaderArray: + """Validate then run each child strategy's transform in sequence. + + Each step re-validates against the running header array, so a strategy that + produces a field (e.g. :class:`ComponentSynthesisStrategy` adding ``component``) + can satisfy a later strategy's :attr:`required_keys`. + """ + result = headers + for strategy in self.strategies: + logger.debug("Applying strategy: %s", strategy.name) + strategy.validate_headers(result) + result = strategy.transform_headers(result) + return result + + def compute_dimensions(self, headers: HeaderArray, dim_names: tuple[str, ...]) -> list[Dimension]: + """Delegate to the final child strategy.""" + return self.strategies[-1].compute_dimensions(headers, dim_names) + + +class IndexStrategyRegistry: + """Picks the right :class:`IndexStrategy` from grid overrides + template hints.""" + + def create_strategy( + self, + grid_overrides: GridOverrides | None = None, + synthesize_dims: tuple[str, ...] = (), + template: AbstractDatasetTemplate | None = None, + ) -> IndexStrategy: + """Build a strategy (possibly composite) for the given config. + + Strategy ordering, when multiple flags are set, mirrors v1.1 behavior: + + 1. ``ComponentSynthesisStrategy`` (so later strategies can rely on the synthesized + field being present). + 2. ``ChannelWrappingStrategy`` (rebases ``channel`` before any shot calculation). + 3. ``ShotWrappingStrategy`` for ``auto_shot_wrap`` (streamer; ``sail_line``). + 4. ``ShotWrappingStrategy`` for ``calculate_shot_index`` (OBN; ``shot_line``, + ``always_calculate=True``). + 5. ``NonBinnedStrategy`` or ``DuplicateHandlingStrategy`` (mutually exclusive; + ``non_binned`` wins when both are set, matching v1.x semantics). + + Args: + grid_overrides: Typed grid override configuration, or ``None`` for no + user-driven overrides. + synthesize_dims: Dimensions to synthesize if missing (e.g. ``component``). + template: Optional dataset template; used to look up coordinate names so + duplicate-handling counters group on dimension fields only. + + Returns: + A single :class:`IndexStrategy` instance. Returns + :class:`RegularGridStrategy` when no overrides and no synthesis are required. + """ + strategies: list[IndexStrategy] = [] + + if synthesize_dims: + strategies.append(ComponentSynthesisStrategy(synthesize_dims)) + + coord_fields: tuple[str, ...] = template.coordinate_names if template is not None else () + + if grid_overrides: + if grid_overrides.auto_channel_wrap: + strategies.append(ChannelWrappingStrategy()) + + if grid_overrides.auto_shot_wrap: + strategies.append(ShotWrappingStrategy(line_field="sail_line", always_calculate=False)) + + if grid_overrides.calculate_shot_index: + strategies.append(ShotWrappingStrategy(line_field="shot_line", always_calculate=True)) + + if grid_overrides.non_binned: + strategies.append( + NonBinnedStrategy( + chunksize=grid_overrides.chunksize, + non_binned_dims=grid_overrides.non_binned_dims or (), + coord_fields=coord_fields, + ) + ) + elif grid_overrides.has_duplicates: + strategies.append(DuplicateHandlingStrategy(coord_fields=coord_fields)) + + if not strategies: + return RegularGridStrategy() + if len(strategies) == 1: + return strategies[0] + return CompositeStrategy(strategies) diff --git a/src/mdio/segy/geometry.py b/src/mdio/segy/geometry.py index eaa28fb4..c2b80b80 100644 --- a/src/mdio/segy/geometry.py +++ b/src/mdio/segy/geometry.py @@ -1,32 +1,28 @@ -"""SEG-Y geometry handling functions.""" +"""SEG-Y grid override configuration model and legacy executor shim. + +The Pydantic :class:`GridOverrides` model is the supported public API for configuring +grid overrides. The :class:`GridOverrider` class is retained as a thin shim that +delegates to :class:`mdio.ingestion.index_strategies.IndexStrategyRegistry`; it preserves +the v1.1 ``run(...)`` contract for callers that still pass a legacy ``dict``. +""" from __future__ import annotations import logging -from abc import ABC -from abc import abstractmethod from typing import TYPE_CHECKING from typing import Any -import numpy as np -from numpy.lib import recfunctions as rfn from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field -from mdio.ingestion.segy.header_analysis import ShotGunGeometryType -from mdio.ingestion.segy.header_analysis import StreamerShotGeometryType -from mdio.ingestion.segy.header_analysis import analyze_lines_for_guns -from mdio.ingestion.segy.header_analysis import analyze_non_indexed_headers -from mdio.ingestion.segy.header_analysis import analyze_streamer_headers -from mdio.segy.exceptions import GridOverrideKeysError +from mdio.ingestion.index_strategies import IndexStrategyRegistry from mdio.segy.exceptions import GridOverrideMissingParameterError from mdio.segy.exceptions import GridOverrideUnknownError if TYPE_CHECKING: from collections.abc import Sequence - from numpy.typing import NDArray from segy.arrays import HeaderArray from mdio.builder.templates.base import AbstractDatasetTemplate @@ -90,533 +86,157 @@ def to_legacy_dict(self) -> dict[str, Any]: return self.model_dump(by_alias=True, exclude_defaults=True) -class GridOverrideCommand(ABC): - """Abstract base class for grid override commands.""" - - @property - @abstractmethod - def required_keys(self) -> set: - """Get the set of required keys for the grid override command.""" - - @property - @abstractmethod - def required_parameters(self) -> set: - """Get the set of required parameters for the grid override command.""" - - @abstractmethod - def validate( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate | None = None, - ) -> None: - """Validate if this transform should run on the type of data.""" - - @abstractmethod - def transform( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate, # noqa: ARG002 - ) -> NDArray: - """Perform the grid transform.""" - - def transform_index_names(self, index_names: Sequence[str]) -> Sequence[str]: - """Perform the transform of index names. - - Optional method: Subclasses may override this method to provide custom behavior. If not - overridden, this default implementation will be used, which is a no-op. - - Args: - index_names: List of index names to be modified. - - Returns: - New tuple of index names after the transform. - """ - return index_names - - def transform_chunksize( - self, - chunksize: Sequence[int], - grid_overrides: dict[str, bool | int], - ) -> Sequence[int]: - """Perform the transform of chunksize. - - Optional method: Subclasses may override this method to provide custom behavior. If not - overridden, this default implementation will be used, which is a no-op. - - Args: - chunksize: List of chunk sizes to be modified. - grid_overrides: Full grid override parameterization. - - Returns: - New tuple of chunk sizes after the transform. - """ - _ = grid_overrides # Unused, required for ABC compatibility - return chunksize - - @property - def name(self) -> str: - """Convenience property to get the name of the command.""" - return self.__class__.__name__ - - def check_required_keys(self, index_headers: HeaderArray) -> None: - """Check if all required keys are present in the index headers.""" - index_names = index_headers.dtype.names - if not self.required_keys.issubset(index_names): - raise GridOverrideKeysError(self.name, self.required_keys) - - def check_required_params(self, grid_overrides: dict[str, str | int]) -> None: - """Check if all required keys are present in the index headers.""" - if self.required_parameters is None: - return - - passed_parameters = set(grid_overrides.keys()) - - if not self.required_parameters.issubset(passed_parameters): - missing_params = self.required_parameters - passed_parameters - raise GridOverrideMissingParameterError(self.name, missing_params) - +def _resolve_synthesize_dims(template: AbstractDatasetTemplate | None) -> tuple[str, ...]: + """Return dimension fields to synthesize when missing for a given template. -class DuplicateIndex(GridOverrideCommand): - """Automatically handle duplicate traces in a new axis - trace with chunksize 1.""" - - required_keys = None - required_parameters = None - - def validate( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate | None = None, # noqa: ARG002 - ) -> None: - """Validate if this transform should run on the type of data.""" - if self.required_keys is not None: - self.check_required_keys(index_headers) - self.check_required_params(grid_overrides) - - def transform( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate, - ) -> NDArray: - """Perform the grid transform.""" - self.validate(index_headers, grid_overrides) - - # Filter out coordinate fields, keep only dimensions for trace indexing - coord_fields = set(template.coordinate_names) if template else set() - - # For NonBinned: non_binned_dims should be excluded from trace indexing grouping - # because they become coordinates indexed by the trace dimension, not grouping keys. - # The trace index should count all traces per remaining dimension combination. - non_binned_dims = set(grid_overrides.get("non_binned_dims", [])) if grid_overrides else set() - - dim_fields = [ - name - for name in index_headers.dtype.names - if name != "trace" and name not in coord_fields and name not in non_binned_dims - ] - - # Create trace indices on dimension fields only - dim_headers = index_headers[dim_fields] if dim_fields else index_headers - dim_headers_with_trace = analyze_non_indexed_headers(dim_headers) - - # Add trace field back to full headers - if dim_headers_with_trace is not None and "trace" in dim_headers_with_trace.dtype.names: - trace_values = np.array(dim_headers_with_trace["trace"]) - index_headers = rfn.append_fields(index_headers, "trace", trace_values, usemask=False) - - return index_headers - - def transform_index_names(self, index_names: Sequence[str]) -> Sequence[str]: - """Insert dimension "trace" to the sample-1 dimension.""" - new_names = list(index_names) - new_names.append("trace") - return tuple(new_names) - - def transform_chunksize( - self, - chunksize: Sequence[int], - grid_overrides: dict[str, bool | int], - ) -> Sequence[int]: - """Insert chunksize of 1 to the sample-1 dimension.""" - _ = grid_overrides # Unused, required for ABC compatibility - new_chunks = list(chunksize) - new_chunks.insert(-1, 1) - return tuple(new_chunks) - - -class NonBinned(DuplicateIndex): - """Handle non-binned dimensions by converting them to a trace dimension with coordinates. - - This override takes dimensions that are not regularly sampled (non-binned) and converts - them into a single 'trace' dimension. The original non-binned dimensions become coordinates - indexed by the trace dimension. - - Example: - Template with dimensions [shot_point, cable, channel, azimuth, offset, sample] - and non_binned_dims=['azimuth', 'offset'] becomes: - - dimensions: [shot_point, cable, channel, trace, sample] - - coordinates: azimuth and offset with dimensions [shot_point, cable, channel, trace] - - Attributes: - required_keys: No required keys for this override. - required_parameters: Set containing 'chunksize' and 'non_binned_dims'. + Only the OBN receiver gathers template currently synthesizes ``component``; every + other template returns ``()`` so the strategy registry skips synthesis entirely. """ - - required_keys = None - required_parameters = {"chunksize", "non_binned_dims"} - - def validate( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate | None = None, # noqa: ARG002 - ) -> None: - """Validate if this transform should run on the type of data.""" - self.check_required_params(grid_overrides) - - # Validate that non_binned_dims is a list - non_binned_dims = grid_overrides.get("non_binned_dims", []) - if not isinstance(non_binned_dims, list): - msg = f"non_binned_dims must be a list, got {type(non_binned_dims)}" - raise ValueError(msg) - - # Validate that all non-binned dimensions exist in headers - missing_dims = set(non_binned_dims) - set(index_headers.dtype.names) - if missing_dims: - msg = f"Non-binned dimensions {missing_dims} not found in index headers" - raise ValueError(msg) - - def transform_chunksize( - self, - chunksize: Sequence[int], - grid_overrides: dict[str, bool | int], - ) -> Sequence[int]: - """Insert chunksize for trace dimension at N-1 position.""" - new_chunks = list(chunksize) - trace_chunksize = grid_overrides["chunksize"] - new_chunks.insert(-1, trace_chunksize) - return tuple(new_chunks) - - -class AutoChannelWrap(GridOverrideCommand): - """Automatically determine Streamer acquisition type.""" - - required_keys = {"shot_point", "cable", "channel"} - required_parameters = None - - def validate( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate | None = None, # noqa: ARG002 - ) -> None: - """Validate if this transform should run on the type of data.""" - self.check_required_keys(index_headers) - self.check_required_params(grid_overrides) - - def transform( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate, # noqa: ARG002 - ) -> NDArray: - """Perform the grid transform.""" - self.validate(index_headers, grid_overrides) - - result = analyze_streamer_headers(index_headers) - unique_cables, cable_chan_min, cable_chan_max, geom_type = result - logger.info("Ingesting dataset as %s", geom_type.name) - - for cable, chan_min, chan_max in zip(unique_cables, cable_chan_min, cable_chan_max, strict=True): - logger.info("Cable: %s has min chan: %s and max chan: %s", cable, chan_min, chan_max) - - # This might be slow and could be improved with a rewrite to prevent so many lookups - if geom_type == StreamerShotGeometryType.B: - for idx, cable in enumerate(unique_cables): - cable_idxs = np.where(index_headers["cable"][:] == cable) - cc_min = cable_chan_min[idx] - - index_headers["channel"][cable_idxs] = index_headers["channel"][cable_idxs] - cc_min + 1 - - return index_headers - - -class CalculateShotIndex(GridOverrideCommand): - """Calculate dense shot_index from shot_point values for OBN templates. - - Creates a 0-based shot_index dimension from sparse or interleaved shot_point - values, grouping by shot_line and gun. This is required for the OBN template - which uses shot_index as a calculated dimension. - - Required headers: shot_line, gun, shot_point - - Attributes: - required_parameters: Set of required parameters (None for this class). + if template is None: + return () + # Lazy import: builder templates pull in builder schemas that indirectly import this + # module's ``GridOverrides``, so a top-level import would cycle. + from mdio.builder.templates.seismic_3d_obn import Seismic3DObnReceiverGathersTemplate # noqa: PLC0415 + + if isinstance(template, Seismic3DObnReceiverGathersTemplate): + return ("component",) + return () + + +def _validate_template_for_overrides( + config: GridOverrides, + template: AbstractDatasetTemplate | None, +) -> None: + """Reject grid override / template pairings that v1.1 forbade. + + ``auto_shot_wrap`` is streamer-only and ``calculate_shot_index`` is OBN-only; using + either with the wrong template silently produced wrong shot indices in v1.1 unless + the per-command validator caught it. This function restores that guard. + + Args: + config: Typed grid overrides extracted from the user's legacy dict. + template: Template chosen by the caller, or ``None`` if omitted. + + Raises: + TypeError: When ``auto_shot_wrap`` is set without a streamer template, or + ``calculate_shot_index`` is set without an OBN receiver-gathers template. """ - - required_parameters = None - - @property - def required_keys(self) -> set: - """Return required header keys for OBN shot index calculation.""" - return {"shot_line", "gun", "shot_point"} - - def validate( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate | None = None, - ) -> None: - """Validate if this transform should run on the type of data.""" - # Import here to avoid circular imports at module load time - from mdio.builder.templates.seismic_3d_obn import Seismic3DObnReceiverGathersTemplate # noqa: PLC0415 - - if template is None: - msg = "CalculateShotIndex requires a template." - raise TypeError(msg) - - if not isinstance(template, Seismic3DObnReceiverGathersTemplate): - msg = ( - f"CalculateShotIndex only supports Seismic3DObnReceiverGathersTemplate, got {type(template).__name__}." - ) - raise TypeError(msg) - - index_names = set(index_headers.dtype.names) - if not self.required_keys.issubset(index_names): - raise GridOverrideKeysError(self.name, self.required_keys) - self.check_required_params(grid_overrides) - - def transform( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate, - ) -> NDArray: - """Perform the grid transform to calculate shot_index from shot_point.""" - self.validate(index_headers, grid_overrides, template) - - line_field = "shot_line" - result = analyze_lines_for_guns(index_headers, line_field=line_field) - unique_lines, unique_guns_per_line, geom_type = result - logger.info("Ingesting OBN dataset as shot type: %s", geom_type.name) - - max_num_guns = 1 - for line_val in unique_lines: - guns = unique_guns_per_line[str(line_val)] - logger.info("%s: %s has guns: %s", line_field, line_val, guns) - max_num_guns = max(len(guns), max_num_guns) - - # Always calculate shot_index - the OBN template requires it - shot_index = np.empty(len(index_headers), dtype="uint32") - # Use .base if available (view of another array), otherwise use the array directly - base_array = index_headers.base if index_headers.base is not None else index_headers - index_headers = rfn.append_fields(base_array, "shot_index", shot_index) - - if geom_type == ShotGunGeometryType.B: - # Type B: shot points are interleaved across guns, divide to get dense index - for line_val in unique_lines: - line_idxs = np.where(index_headers[line_field][:] == line_val) - index_headers["shot_index"][line_idxs] = np.floor(index_headers["shot_point"][line_idxs] / max_num_guns) - # Make shot index zero-based PER line - index_headers["shot_index"][line_idxs] -= index_headers["shot_index"][line_idxs].min() - else: - # Type A: shot points are already unique per gun, create 0-based index from unique values - for line_val in unique_lines: - line_idxs = np.where(index_headers[line_field][:] == line_val) - shot_points = index_headers["shot_point"][line_idxs] - # np.unique returns sorted values; searchsorted maps each shot_point to its 0-based index - unique_shots = np.unique(shot_points) - index_headers["shot_index"][line_idxs] = np.searchsorted(unique_shots, shot_points) - - return index_headers - - -class AutoShotWrap(GridOverrideCommand): - """Automatic shot index calculation from interleaved shot points for Streamer templates. - - This grid override handles multi-gun acquisition where shot points may be - interleaved across guns. It calculates a dense shot_index from sparse shot_point values. - - Supported Templates: - - Seismic3DStreamerFieldRecordsTemplate: Uses sail_line, requires cable/channel - - Note: - For OBN templates, use CalculateShotIndex instead. - - Attributes: - required_parameters: Set of required parameters (None for this class). - """ - - required_parameters = None - - @property - def required_keys(self) -> set: - """Return required header keys for streamer shot index calculation.""" - return {"sail_line", "gun", "shot_point", "cable", "channel"} - - def validate( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate | None = None, - ) -> None: - """Validate if this transform should run on the type of data.""" - # Import here to avoid circular imports at module load time + if config.auto_shot_wrap: + # Lazy import: see ``_resolve_synthesize_dims`` for the cycle rationale. from mdio.builder.templates.seismic_3d_streamer_field import ( # noqa: PLC0415 Seismic3DStreamerFieldRecordsTemplate, ) - if template is None: - msg = "AutoShotWrap requires a template." - raise TypeError(msg) - if not isinstance(template, Seismic3DStreamerFieldRecordsTemplate): + actual = type(template).__name__ if template is not None else "None" msg = ( - f"AutoShotWrap only supports Seismic3DStreamerFieldRecordsTemplate, " - f"got {type(template).__name__}. For OBN templates, use CalculateShotIndex." + f"auto_shot_wrap only supports Seismic3DStreamerFieldRecordsTemplate, " + f"got {actual}. For OBN templates, use calculate_shot_index." ) raise TypeError(msg) - index_names = set(index_headers.dtype.names) - if not self.required_keys.issubset(index_names): - raise GridOverrideKeysError(self.name, self.required_keys) - self.check_required_params(grid_overrides) + if config.calculate_shot_index: + from mdio.builder.templates.seismic_3d_obn import Seismic3DObnReceiverGathersTemplate # noqa: PLC0415 - def transform( - self, - index_headers: HeaderArray, - grid_overrides: dict[str, bool | int], - template: AbstractDatasetTemplate, - ) -> NDArray: - """Perform the grid transform to calculate shot_index from shot_point.""" - self.validate(index_headers, grid_overrides, template) - - line_field = "sail_line" - result = analyze_lines_for_guns(index_headers, line_field=line_field) - unique_lines, unique_guns_per_line, geom_type = result - logger.info("Ingesting streamer dataset as shot type: %s", geom_type.name) - - max_num_guns = 1 - for line_val in unique_lines: - guns = unique_guns_per_line[str(line_val)] - logger.info("%s: %s has guns: %s", line_field, line_val, guns) - max_num_guns = max(len(guns), max_num_guns) - - # Only calculate shot_index when shot points are interleaved across guns (Type B) - if geom_type == ShotGunGeometryType.B: - shot_index = np.empty(len(index_headers), dtype="uint32") - # Use .base if available (view of another array), otherwise use the array directly - base_array = index_headers.base if index_headers.base is not None else index_headers - index_headers = rfn.append_fields(base_array, "shot_index", shot_index) - - for line_val in unique_lines: - line_idxs = np.where(index_headers[line_field][:] == line_val) - index_headers["shot_index"][line_idxs] = np.floor(index_headers["shot_point"][line_idxs] / max_num_guns) - # Make shot index zero-based PER line - index_headers["shot_index"][line_idxs] -= index_headers["shot_index"][line_idxs].min() - - return index_headers + if not isinstance(template, Seismic3DObnReceiverGathersTemplate): + actual = type(template).__name__ if template is not None else "None" + msg = f"calculate_shot_index only supports Seismic3DObnReceiverGathersTemplate, got {actual}." + raise TypeError(msg) class GridOverrider: - """Executor for grid overrides. + """Legacy facade that adapts the dict-based v1.1 API onto :class:`IndexStrategyRegistry`. - We support a certain type of grid overrides, and they have to be implemented following the - ABC's in this module. - - This class applies the grid overrides if needed. + Existing callers (notably :func:`mdio.segy.utilities.get_grid_plan`) still build a + legacy ``dict`` of grid overrides and call :meth:`run`. This class translates the dict + into a typed :class:`GridOverrides`, dispatches to the appropriate + :class:`IndexStrategy`, and returns the ``(headers, names, chunksize)`` tuple shape + those callers depend on. It will be removed once all callers move to the typed API. """ def __init__(self) -> None: - self.commands = { - "AutoChannelWrap": AutoChannelWrap(), - "AutoShotWrap": AutoShotWrap(), - "CalculateShotIndex": CalculateShotIndex(), - "NonBinned": NonBinned(), - "HasDuplicates": DuplicateIndex(), - } - - self.parameters = self.get_allowed_parameters() - - def get_allowed_parameters(self) -> set: - """Get list of allowed parameters from the allowed commands.""" - parameters = set() - for command in self.commands.values(): - if command.required_parameters is None: - continue - - parameters.update(command.required_parameters) - - # Add optional parameters that are not strictly required but are valid - parameters.add("non_binned_dims") - - return parameters - - def _synthesize_obn_component( - self, - index_headers: HeaderArray, - template: AbstractDatasetTemplate | None, - ) -> HeaderArray: - """Synthesize component field for OBN template when missing from headers. - - OBN data may not have a component field in the SEG-Y headers (e.g., single-component - data). When using Seismic3DObnReceiverGathersTemplate and component is missing, - this method synthesizes it with a constant value of 1. - - Args: - index_headers: The parsed index headers from SEG-Y. - template: The dataset template. - - Returns: - Headers with component field added if applicable. - """ - # Import here to avoid circular imports at module load time - from mdio.builder.templates.seismic_3d_obn import Seismic3DObnReceiverGathersTemplate # noqa: PLC0415 - - if not isinstance(template, Seismic3DObnReceiverGathersTemplate): - return index_headers - - if "component" in index_headers.dtype.names: - return index_headers - - logger.warning( - "SEG-Y headers do not contain 'component' field required by template '%s'. " - "Synthesizing 'component' dimension with constant value 1 for all traces.", - template.name, - ) - synthetic_col = np.full(len(index_headers), 1, dtype=np.int32) - base_array = index_headers.base if index_headers.base is not None else index_headers - return rfn.append_fields(base_array, "component", synthetic_col, usemask=False) + self._registry = IndexStrategyRegistry() def run( self, index_headers: HeaderArray, index_names: Sequence[str], - grid_overrides: dict[str, bool], + grid_overrides: dict[str, Any] | None, chunksize: Sequence[int] | None = None, template: AbstractDatasetTemplate | None = None, - ) -> tuple[HeaderArray, tuple[str], tuple[int]]: - """Run grid overrides and return result.""" - # Synthesize component for OBN template if missing - index_headers = self._synthesize_obn_component(index_headers, template) + ) -> tuple[HeaderArray, tuple[str, ...], tuple[int, ...] | None]: + """Run the configured grid overrides and return updated headers/names/chunks. - for override in grid_overrides: - if override in self.parameters: - continue + Args: + index_headers: Parsed SEG-Y trace headers; structured numpy array. + index_names: Names of the index dimensions before any override is applied. + grid_overrides: Legacy dict of overrides (CamelCase keys). + chunksize: Optional chunk shape that may be expanded by overrides that add a + ``trace`` dimension. + template: Optional dataset template; used to identify coordinate fields and + to drive component synthesis for OBN. - if override not in self.commands: - raise GridOverrideUnknownError(override) + Returns: + Tuple of ``(transformed_headers, new_index_names, new_chunksize)``. The + chunksize tuple is ``None`` when the caller did not pass a chunksize. + + Raises: + GridOverrideUnknownError: When ``grid_overrides`` contains an unknown key. + GridOverrideMissingParameterError: When ``NonBinned`` is enabled without + ``chunksize`` or ``non_binned_dims``. + + Notes: + Header-precondition checks (``GridOverrideKeysError``) are delegated to + :meth:`IndexStrategy.validate_headers`; template-compatibility checks + (``TypeError``) are delegated to :func:`_validate_template_for_overrides`. + """ + grid_overrides = grid_overrides or {} + + field_names = set(GridOverrides.model_fields.keys()) + aliases = {field.alias for field in GridOverrides.model_fields.values() if field.alias} + valid_keys = field_names | aliases + for key in grid_overrides: + if key not in valid_keys: + raise GridOverrideUnknownError(key) + + config = GridOverrides.model_validate(grid_overrides) + + if config.non_binned: + missing: set[str] = set() + if config.chunksize is None: + missing.add("chunksize") + if not config.non_binned_dims: + missing.add("non_binned_dims") + if missing: + command = "NonBinned" + raise GridOverrideMissingParameterError(command, missing) + + _validate_template_for_overrides(config, template) + + synthesize_dims = _resolve_synthesize_dims(template) + strategy = self._registry.create_strategy( + grid_overrides=config, + synthesize_dims=synthesize_dims, + template=template, + ) + logger.debug("Selected grid override strategy: %s", strategy.name) - function = self.commands[override].transform - index_headers = function(index_headers, grid_overrides=grid_overrides, template=template) + strategy.validate_headers(index_headers) + new_headers = strategy.transform_headers(index_headers) - function = self.commands[override].transform_index_names - index_names = function(index_names) + new_names = list(index_names) + new_chunks = list(chunksize) if chunksize is not None else None - function = self.commands[override].transform_chunksize - chunksize = function(chunksize, grid_overrides=grid_overrides) + # Both NonBinned and HasDuplicates add a 'trace' dim at index -1; HasDuplicates + # always uses chunksize 1, NonBinned uses the user-supplied value. + if config.non_binned or config.has_duplicates: + new_names.append("trace") + if new_chunks is not None: + inserted_chunk = config.chunksize if config.non_binned else 1 + new_chunks.insert(-1, inserted_chunk) - return index_headers, index_names, chunksize + return ( + new_headers, + tuple(new_names), + tuple(new_chunks) if new_chunks is not None else None, + ) diff --git a/tests/unit/test_ingestion_index_strategies.py b/tests/unit/test_ingestion_index_strategies.py new file mode 100644 index 00000000..0cac710c --- /dev/null +++ b/tests/unit/test_ingestion_index_strategies.py @@ -0,0 +1,569 @@ +"""Unit tests for the v1.2 ingestion index strategies and the strategy registry. + +These tests exercise individual :class:`mdio.ingestion.index_strategies.IndexStrategy` +subclasses with synthetic structured numpy arrays (mimicking the shape semantics of +:class:`segy.arrays.HeaderArray`) so they remain fast and do not require any real SEG-Y +data. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +from mdio.builder.template_registry import TemplateRegistry +from mdio.ingestion.index_strategies import ChannelWrappingStrategy +from mdio.ingestion.index_strategies import ComponentSynthesisStrategy +from mdio.ingestion.index_strategies import CompositeStrategy +from mdio.ingestion.index_strategies import DuplicateHandlingStrategy +from mdio.ingestion.index_strategies import IndexStrategyRegistry +from mdio.ingestion.index_strategies import NonBinnedStrategy +from mdio.ingestion.index_strategies import RegularGridStrategy +from mdio.ingestion.index_strategies import ShotWrappingStrategy +from mdio.segy.exceptions import GridOverrideKeysError +from mdio.segy.geometry import GridOverrider +from mdio.segy.geometry import GridOverrides + + +def _make_struct(data: dict[str, np.ndarray]) -> np.ndarray: + """Build a 1-D structured array from a name -> 1-D array mapping.""" + names = list(data.keys()) + arrays = [data[name] for name in names] + n = len(arrays[0]) + dtype = np.dtype([(name, arr.dtype) for name, arr in zip(names, arrays, strict=True)]) + out = np.empty(n, dtype=dtype) + for name, arr in zip(names, arrays, strict=True): + out[name] = arr + return out + + +# --------------------------------------------------------------------------- +# IndexStrategyRegistry +# --------------------------------------------------------------------------- + + +class TestIndexStrategyRegistry: + """Selection rules for :class:`IndexStrategyRegistry.create_strategy`.""" + + def test_default_returns_regular_grid(self) -> None: + """No grid overrides and no synthesis -> regular grid.""" + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=None) + assert isinstance(strategy, RegularGridStrategy) + + def test_falsy_overrides_returns_regular_grid(self) -> None: + """A default ``GridOverrides()`` with all flags off must be treated as no-op.""" + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=GridOverrides()) + assert isinstance(strategy, RegularGridStrategy) + + def test_synthesize_dims_only(self) -> None: + """Synthesis-only configuration returns a single ComponentSynthesisStrategy.""" + strategy = IndexStrategyRegistry().create_strategy(synthesize_dims=("component",)) + assert isinstance(strategy, ComponentSynthesisStrategy) + assert strategy.synthesize_dims == ("component",) + + def test_non_binned_only(self) -> None: + """``non_binned`` -> NonBinnedStrategy with chunksize and excluded dims wired.""" + overrides = GridOverrides(non_binned=True, chunksize=64, non_binned_dims=["channel"]) + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=overrides) + assert isinstance(strategy, NonBinnedStrategy) + assert strategy.chunksize == 64 + assert strategy.non_binned_dims == ("channel",) + + def test_has_duplicates_only(self) -> None: + """``has_duplicates`` -> DuplicateHandlingStrategy.""" + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=GridOverrides(has_duplicates=True)) + assert isinstance(strategy, DuplicateHandlingStrategy) + + def test_non_binned_wins_over_has_duplicates(self) -> None: + """Both flags set -> NonBinned wins (matches v1.x semantics).""" + overrides = GridOverrides( + non_binned=True, + chunksize=8, + non_binned_dims=["channel"], + has_duplicates=True, + ) + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=overrides) + assert isinstance(strategy, NonBinnedStrategy) + + def test_composite_with_channel_wrap(self) -> None: + """Channel wrap + non-binned -> CompositeStrategy ordered for safe layering.""" + overrides = GridOverrides( + auto_channel_wrap=True, + non_binned=True, + chunksize=64, + non_binned_dims=["channel"], + ) + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=overrides) + assert isinstance(strategy, CompositeStrategy) + assert [s.name for s in strategy.strategies] == ["ChannelWrappingStrategy", "NonBinnedStrategy"] + + def test_synthesize_dims_runs_first(self) -> None: + """Synthesis must run before any strategy that may depend on the synthesized field.""" + overrides = GridOverrides(calculate_shot_index=True) + strategy = IndexStrategyRegistry().create_strategy( + grid_overrides=overrides, + synthesize_dims=("component",), + ) + assert isinstance(strategy, CompositeStrategy) + assert strategy.strategies[0].name == "ComponentSynthesisStrategy" + + def test_auto_shot_wrap_uses_sail_line(self) -> None: + """``auto_shot_wrap`` is the streamer flag: line_field=sail_line, no always-calc.""" + overrides = GridOverrides(auto_shot_wrap=True) + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=overrides) + assert isinstance(strategy, ShotWrappingStrategy) + assert strategy.line_field == "sail_line" + assert strategy.always_calculate is False + + def test_calculate_shot_index_uses_shot_line(self) -> None: + """``calculate_shot_index`` is the OBN flag: line_field=shot_line, always-calc.""" + overrides = GridOverrides(calculate_shot_index=True) + strategy = IndexStrategyRegistry().create_strategy(grid_overrides=overrides) + assert isinstance(strategy, ShotWrappingStrategy) + assert strategy.line_field == "shot_line" + assert strategy.always_calculate is True + + def test_template_coord_names_propagate_to_duplicate_strategy(self) -> None: + """Template coordinates flow into the duplicate-handling strategy as exclusions.""" + template = TemplateRegistry().get("StreamerShotGathers3D") + strategy = IndexStrategyRegistry().create_strategy( + grid_overrides=GridOverrides(has_duplicates=True), + template=template, + ) + assert isinstance(strategy, DuplicateHandlingStrategy) + assert strategy.coord_fields == frozenset(template.coordinate_names) + + +# --------------------------------------------------------------------------- +# RegularGridStrategy +# --------------------------------------------------------------------------- + + +class TestRegularGridStrategy: + """Pass-through strategy used when no overrides are required.""" + + def test_returns_unique_dims(self) -> None: + """Each dim name maps to its sorted unique values.""" + headers = _make_struct( + { + "inline": np.array([1, 1, 2, 2], dtype=np.int32), + "crossline": np.array([10, 11, 10, 11], dtype=np.int32), + } + ) + dims = RegularGridStrategy().compute_dimensions(headers, ("inline", "crossline")) + assert [d.name for d in dims] == ["inline", "crossline"] + np.testing.assert_array_equal(dims[0].coords, [1, 2]) + np.testing.assert_array_equal(dims[1].coords, [10, 11]) + + def test_unknown_dim_silently_skipped(self) -> None: + """Names absent from the header dtype are dropped, matching v1.1 behavior.""" + headers = _make_struct({"inline": np.array([1, 2], dtype=np.int32)}) + dims = RegularGridStrategy().compute_dimensions(headers, ("inline", "missing")) + assert [d.name for d in dims] == ["inline"] + + def test_default_required_keys_empty_and_validate_is_noop(self) -> None: + """Strategies that don't override ``required_keys`` must not block any headers.""" + headers = _make_struct({"inline": np.array([1], dtype=np.int32)}) + strategy = RegularGridStrategy() + assert strategy.required_keys == frozenset() + strategy.validate_headers(headers) # must not raise + + +# --------------------------------------------------------------------------- +# DuplicateHandlingStrategy +# --------------------------------------------------------------------------- + + +class TestDuplicateHandlingStrategy: + """Counter-based ``trace`` field for duplicated index tuples.""" + + def test_appends_per_tuple_counter(self) -> None: + """Each (inline, crossline) tuple gets a 1-based duplicate counter.""" + headers = _make_struct( + { + "inline": np.array([1, 1, 1, 2], dtype=np.int32), + "crossline": np.array([10, 10, 11, 10], dtype=np.int32), + } + ) + out = DuplicateHandlingStrategy().transform_headers(headers) + assert "trace" in out.dtype.names + # (1,10) appears twice -> {1, 2}; (1,11) once -> {1}; (2,10) once -> {1}. + np.testing.assert_array_equal(np.sort(out["trace"]), [1, 1, 1, 2]) + + def test_coord_fields_excluded_from_grouping(self) -> None: + """Coordinate fields must not influence the duplicate counter.""" + headers = _make_struct( + { + "inline": np.array([1, 1, 2, 2], dtype=np.int32), + "crossline": np.array([10, 10, 11, 11], dtype=np.int32), + # 'cdp_x' varies per row but is a coord, so the counter must ignore it. + "cdp_x": np.array([100.0, 100.5, 200.0, 200.5], dtype=np.float64), + } + ) + strategy = DuplicateHandlingStrategy(coord_fields=("cdp_x",)) + out = strategy.transform_headers(headers) + # Each (inline, crossline) tuple appears twice -> counters cycle 1..2. + np.testing.assert_array_equal(out["trace"], [1, 2, 1, 2]) + + def test_excluded_fields_dropped_from_grouping(self) -> None: + """Caller-declared excluded fields (e.g. non_binned_dims) are not part of the key.""" + headers = _make_struct( + { + "shot_point": np.array([1, 1, 1], dtype=np.int32), + "channel": np.array([1, 2, 3], dtype=np.int32), + } + ) + strategy = DuplicateHandlingStrategy(excluded_fields=("channel",)) + out = strategy.transform_headers(headers) + # Without excluding 'channel', counters would all be 1; excluding it groups by + # shot_point alone so each row in the same shot_point gets a fresh counter. + np.testing.assert_array_equal(out["trace"], [1, 2, 3]) + + +# --------------------------------------------------------------------------- +# NonBinnedStrategy +# --------------------------------------------------------------------------- + + +class TestNonBinnedStrategy: + """``NonBinned`` is the duplicate counter wired with explicit collapse dims.""" + + def test_chunksize_and_dims_recorded(self) -> None: + """Constructor stores both for the ``GridOverrider`` shim to use.""" + strategy = NonBinnedStrategy(chunksize=4, non_binned_dims=("channel",)) + assert strategy.chunksize == 4 + assert strategy.non_binned_dims == ("channel",) + + def test_collapse_dim_excluded_from_counter(self) -> None: + """The non-binned dim must NOT participate in the duplicate counter grouping.""" + headers = _make_struct( + { + "shot_point": np.array([1, 1, 2, 2], dtype=np.int32), + "cable": np.array([1, 2, 1, 2], dtype=np.int32), + "channel": np.array([10, 11, 10, 11], dtype=np.int32), + } + ) + strategy = NonBinnedStrategy(chunksize=4, non_binned_dims=("channel",)) + out = strategy.transform_headers(headers) + # Each (shot_point, cable) tuple appears once -> counters all equal 1. + assert "trace" in out.dtype.names + np.testing.assert_array_equal(out["trace"], [1, 1, 1, 1]) + + def test_coord_fields_also_excluded(self) -> None: + """Both template coords and non_binned_dims are removed from the grouping key.""" + headers = _make_struct( + { + "shot_point": np.array([1, 1, 2, 2], dtype=np.int32), + "channel": np.array([1, 2, 1, 2], dtype=np.int32), + "cdp_x": np.array([10.0, 20.0, 30.0, 40.0], dtype=np.float64), + } + ) + strategy = NonBinnedStrategy( + chunksize=4, + non_binned_dims=("channel",), + coord_fields=("cdp_x",), + ) + out = strategy.transform_headers(headers) + # Grouping is by shot_point only -> each shot_point has 2 rows -> counters {1, 2}. + np.testing.assert_array_equal(out["trace"], [1, 2, 1, 2]) + + +# --------------------------------------------------------------------------- +# ChannelWrappingStrategy +# --------------------------------------------------------------------------- + + +class TestChannelWrappingStrategy: + """Streamer Type-A vs Type-B detection and channel rebasing.""" + + def test_type_a_pass_through(self) -> None: + """Type A (per-cable channel numbering with overlap) -> headers untouched.""" + headers = _make_struct( + { + "cable": np.array([1, 1, 2, 2], dtype=np.int32), + "channel": np.array([1, 2, 1, 2], dtype=np.int32), + } + ) + out = ChannelWrappingStrategy().transform_headers(headers) + np.testing.assert_array_equal(out["channel"], [1, 2, 1, 2]) + + def test_type_b_renumbers_per_cable(self) -> None: + """Type B (sequential numbering across cables) -> rebased to 1..N per cable.""" + headers = _make_struct( + { + "cable": np.array([1, 1, 2, 2], dtype=np.int32), + "channel": np.array([1, 2, 3, 4], dtype=np.int32), + } + ) + out = ChannelWrappingStrategy().transform_headers(headers) + # Cable 1: 1,2 -> 1,2; cable 2: 3,4 -> 1,2. + np.testing.assert_array_equal(out["channel"], [1, 2, 1, 2]) + + def test_required_keys(self) -> None: + """Channel wrap declares the cable-channel-shot triplet as preconditions.""" + assert ChannelWrappingStrategy().required_keys == frozenset({"shot_point", "cable", "channel"}) + + def test_validate_headers_raises_when_field_missing(self) -> None: + """Missing ``cable`` -> :class:`GridOverrideKeysError`, not a deeper numpy crash.""" + headers = _make_struct( + { + "shot_point": np.array([1, 2], dtype=np.int32), + "channel": np.array([1, 2], dtype=np.int32), + } + ) + strategy = ChannelWrappingStrategy() + with pytest.raises(GridOverrideKeysError, match="ChannelWrappingStrategy"): + strategy.validate_headers(headers) + + +# --------------------------------------------------------------------------- +# ShotWrappingStrategy +# --------------------------------------------------------------------------- + + +class TestShotWrappingStrategy: + """Shot-index derivation for both streamer (sail_line) and OBN (shot_line).""" + + def test_type_b_streamer_emits_shot_index(self) -> None: + """Sail line 1 with two interleaved guns -> dense shot_index per line.""" + headers = _make_struct( + { + "sail_line": np.array([1, 1, 1, 1], dtype=np.int32), + "gun": np.array([1, 2, 1, 2], dtype=np.int32), + "shot_point": np.array([1, 2, 3, 4], dtype=np.int32), + } + ) + out = ShotWrappingStrategy(line_field="sail_line").transform_headers(headers) + assert "shot_index" in out.dtype.names + # floor(shot_point / 2) zero-based per line: 0, 1, 1, 2. + np.testing.assert_array_equal(out["shot_index"], [0, 1, 1, 2]) + + def test_type_a_streamer_skipped_without_always_calculate(self) -> None: + """Type A streamer geometry produces no shot_index unless always_calculate=True.""" + headers = _make_struct( + { + "sail_line": np.array([1, 1, 1, 1, 1, 1], dtype=np.int32), + "gun": np.array([1, 1, 1, 2, 2, 2], dtype=np.int32), + "shot_point": np.array([1, 2, 3, 1, 2, 3], dtype=np.int32), + } + ) + out = ShotWrappingStrategy(line_field="sail_line").transform_headers(headers) + assert "shot_index" not in out.dtype.names + + def test_type_a_obn_always_calculates(self) -> None: + """OBN forces shot_index calculation; Type A uses dense per-line searchsorted.""" + headers = _make_struct( + { + "shot_line": np.array([1, 1, 1, 1, 1, 1, 1, 1], dtype=np.int32), + "gun": np.array([1, 1, 1, 1, 2, 2, 2, 2], dtype=np.int32), + "shot_point": np.array([1, 2, 3, 4, 1, 2, 3, 4], dtype=np.int32), + } + ) + out = ShotWrappingStrategy(line_field="shot_line", always_calculate=True).transform_headers(headers) + assert "shot_index" in out.dtype.names + np.testing.assert_array_equal(out["shot_index"], [0, 1, 2, 3, 0, 1, 2, 3]) + + def test_obn_multiline_type_a_processes_all_lines(self) -> None: + """Regression: Type A detection on line 1 must not mask later lines.""" + headers = _make_struct( + { + "shot_line": np.array([1, 1, 2, 2, 3, 3], dtype=np.int32), + "gun": np.array([1, 2, 1, 2, 1, 2], dtype=np.int32), + "shot_point": np.array([1, 2, 1, 2, 1, 2], dtype=np.int32), + } + ) + out = ShotWrappingStrategy(line_field="shot_line", always_calculate=True).transform_headers(headers) + # Each line gets independent dense per-line indices. + np.testing.assert_array_equal(out["shot_index"], [0, 1, 0, 1, 0, 1]) + + def test_required_keys_sail_line(self) -> None: + """Streamer variant requires the streamer-cable-channel headers in addition to shot fields.""" + strategy = ShotWrappingStrategy(line_field="sail_line") + assert strategy.required_keys == frozenset({"sail_line", "gun", "shot_point", "cable", "channel"}) + + def test_required_keys_shot_line(self) -> None: + """OBN variant deliberately omits cable/channel from required keys.""" + strategy = ShotWrappingStrategy(line_field="shot_line", always_calculate=True) + assert strategy.required_keys == frozenset({"shot_line", "gun", "shot_point"}) + + def test_validate_headers_raises_for_obn_when_missing_gun(self) -> None: + """Missing ``gun`` on the OBN path -> :class:`GridOverrideKeysError`.""" + headers = _make_struct( + { + "shot_line": np.array([1, 1], dtype=np.int32), + "shot_point": np.array([1, 2], dtype=np.int32), + } + ) + strategy = ShotWrappingStrategy(line_field="shot_line", always_calculate=True) + with pytest.raises(GridOverrideKeysError, match="ShotWrappingStrategy"): + strategy.validate_headers(headers) + + +# --------------------------------------------------------------------------- +# ComponentSynthesisStrategy +# --------------------------------------------------------------------------- + + +class TestComponentSynthesisStrategy: + """Synthesize template-required dimensions when missing from the SEG-Y headers.""" + + def test_synthesizes_missing_field(self) -> None: + """Missing field is added with constant value 1 for every row.""" + headers = _make_struct({"receiver": np.array([1, 2, 3], dtype=np.int32)}) + out = ComponentSynthesisStrategy(("component",)).transform_headers(headers) + assert "component" in out.dtype.names + np.testing.assert_array_equal(out["component"], [1, 1, 1]) + + def test_existing_field_left_alone(self) -> None: + """If the field is already present, the existing values are preserved.""" + headers = _make_struct( + { + "receiver": np.array([1, 2, 3], dtype=np.int32), + "component": np.array([2, 3, 4], dtype=np.uint8), + } + ) + out = ComponentSynthesisStrategy(("component",)).transform_headers(headers) + np.testing.assert_array_equal(out["component"], [2, 3, 4]) + + +# --------------------------------------------------------------------------- +# CompositeStrategy +# --------------------------------------------------------------------------- + + +class TestCompositeStrategy: + """Strategy chaining with deterministic execution order.""" + + def test_requires_at_least_one_strategy(self) -> None: + """An empty composite is a programming error and must raise.""" + with pytest.raises(ValueError, match="at least one strategy"): + CompositeStrategy([]) + + def test_strategies_run_in_order(self) -> None: + """Synthesis must produce the field that the next strategy then duplicates.""" + headers = _make_struct( + { + "shot_point": np.array([1, 1, 2, 2], dtype=np.int32), + "channel": np.array([1, 2, 1, 2], dtype=np.int32), + } + ) + composite = CompositeStrategy( + [ + ComponentSynthesisStrategy(("component",)), + NonBinnedStrategy(chunksize=4, non_binned_dims=("channel",)), + ] + ) + out = composite.transform_headers(headers) + assert "component" in out.dtype.names + assert "trace" in out.dtype.names + + def test_progressive_validation_raises_for_first_unsatisfied_child(self) -> None: + """Composite must surface a child's required-keys failure as :class:`GridOverrideKeysError`. + + Channel wrap needs ``cable``; ``RegularGridStrategy`` runs first and is a no-op, + so the composite reaches channel wrap with the same incomplete headers and must + raise rather than crash inside numpy. + """ + headers = _make_struct( + { + "shot_point": np.array([1, 2], dtype=np.int32), + "channel": np.array([1, 2], dtype=np.int32), + } + ) + composite = CompositeStrategy([RegularGridStrategy(), ChannelWrappingStrategy()]) + with pytest.raises(GridOverrideKeysError, match="ChannelWrappingStrategy"): + composite.transform_headers(headers) + + +# --------------------------------------------------------------------------- +# GridOverrider template-compatibility checks (shim level) +# --------------------------------------------------------------------------- + + +class TestGridOverriderTemplateValidation: + """Restore v1.1's template-type guards for shot-wrapping overrides. + + ``AutoShotWrap`` was streamer-only and ``CalculateShotIndex`` was OBN-only; pairing + either with the wrong template silently produced incorrect shot indices. The shim + raises :class:`TypeError` early so misconfigurations fail loudly at the API boundary. + """ + + def test_auto_shot_wrap_rejects_obn_template(self) -> None: + """Streamer override + OBN template -> TypeError, no transform run.""" + headers = _make_struct( + { + "sail_line": np.array([1, 1], dtype=np.int32), + "gun": np.array([1, 2], dtype=np.int32), + "shot_point": np.array([1, 2], dtype=np.int32), + "cable": np.array([1, 1], dtype=np.int32), + "channel": np.array([1, 2], dtype=np.int32), + } + ) + template = TemplateRegistry().get("ObnReceiverGathers3D") + with pytest.raises(TypeError, match="auto_shot_wrap.*Seismic3DStreamerFieldRecordsTemplate"): + GridOverrider().run( + headers, + index_names=("sail_line", "gun", "shot_point"), + grid_overrides={"AutoShotWrap": True}, + template=template, + ) + + def test_calculate_shot_index_rejects_streamer_template(self) -> None: + """OBN override + streamer template -> TypeError.""" + headers = _make_struct( + { + "shot_line": np.array([1, 1], dtype=np.int32), + "gun": np.array([1, 2], dtype=np.int32), + "shot_point": np.array([1, 2], dtype=np.int32), + } + ) + template = TemplateRegistry().get("StreamerShotGathers3D") + with pytest.raises(TypeError, match="calculate_shot_index.*Seismic3DObnReceiverGathersTemplate"): + GridOverrider().run( + headers, + index_names=("shot_line", "gun", "shot_point"), + grid_overrides={"CalculateShotIndex": True}, + template=template, + ) + + def test_shot_wrap_without_template_raises(self) -> None: + """Omitting the template is the same misconfiguration as passing the wrong type.""" + headers = _make_struct( + { + "sail_line": np.array([1], dtype=np.int32), + "gun": np.array([1], dtype=np.int32), + "shot_point": np.array([1], dtype=np.int32), + "cable": np.array([1], dtype=np.int32), + "channel": np.array([1], dtype=np.int32), + } + ) + with pytest.raises(TypeError, match="auto_shot_wrap"): + GridOverrider().run( + headers, + index_names=("sail_line", "gun", "shot_point"), + grid_overrides={"AutoShotWrap": True}, + template=None, + ) + + def test_header_keys_missing_raises_via_shim(self) -> None: + """``AutoShotWrap`` + correct template but missing ``cable`` -> :class:`GridOverrideKeysError`. + + Uses the ``StreamerFieldRecords3D`` template (the only template v1.1's + ``AutoShotWrap`` accepts) so the template-type check passes and we exercise the + header-key validator instead. + """ + headers = _make_struct( + { + "sail_line": np.array([1, 1], dtype=np.int32), + "gun": np.array([1, 2], dtype=np.int32), + "shot_point": np.array([1, 2], dtype=np.int32), + } + ) + template = TemplateRegistry().get("StreamerFieldRecords3D") + with pytest.raises(GridOverrideKeysError, match="ShotWrappingStrategy"): + GridOverrider().run( + headers, + index_names=("sail_line", "gun", "shot_point"), + grid_overrides={"AutoShotWrap": True}, + template=template, + )