Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,8 +874,8 @@ def upsert(
format_version=self.table_metadata.format_version,
)

# get list of rows that exist so we don't have to load the entire target table
matched_predicate = upsert_util.create_match_filter(df, join_cols)
# Use a conservative file-pruning predicate for the initial scan; exact matching happens below.
matched_predicate = upsert_util.create_file_match_filter(df, join_cols)

# We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.

Expand Down
30 changes: 30 additions & 0 deletions pyiceberg/table/upsert_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,41 @@
AlwaysFalse,
BooleanExpression,
EqualTo,
GreaterThanOrEqual,
In,
IsNull,
LessThanOrEqual,
Or,
)


def create_file_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
    """Return a coarse, file-level pruning predicate for an upsert source.

    Each join column contributes one clause and the clauses are AND-ed together:

    * a column holding only nulls contributes ``IsNull(col)``;
    * otherwise it contributes ``min <= col <= max``, widened with
      ``OR IsNull(col)`` when the column also contains at least one null.

    The result is deliberately conservative: it may select files that hold no
    matching row, but it must never reject a file that could hold one. Exact
    row-level matching is expected to happen later.

    Args:
        df: Source rows of the upsert.
        join_cols: Names of the key columns used to match rows.

    Returns:
        ``AlwaysFalse()`` when ``df`` has no rows, otherwise the conjunction of
        the per-column clauses described above.
    """
    if not len(df):
        # Nothing to upsert, so no file can possibly be relevant.
        return AlwaysFalse()

    def _column_clause(name: str) -> BooleanExpression:
        values = df.column(name)
        extremes = pc.min_max(values).as_py()
        lower, upper = extremes["min"], extremes["max"]

        # A missing minimum means there were no non-null values to aggregate.
        if lower is None:
            return IsNull(name)

        in_range: BooleanExpression = GreaterThanOrEqual(name, lower) & LessThanOrEqual(name, upper)
        contains_null = pc.any(pc.is_null(values)).as_py()
        return in_range | IsNull(name) if contains_null else in_range

    clauses: list[BooleanExpression] = [_column_clause(name) for name in join_cols]
    return functools.reduce(operator.and_, clauses)


def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])

Expand Down
149 changes: 147 additions & 2 deletions tests/table/test_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import subprocess
import sys
import textwrap
from pathlib import PosixPath

import pyarrow as pa
Expand All @@ -23,13 +26,23 @@

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference
from pyiceberg.expressions import (
AlwaysFalse,
AlwaysTrue,
And,
EqualTo,
GreaterThanOrEqual,
IsNull,
LessThanOrEqual,
Or,
Reference,
)
from pyiceberg.expressions.literals import LongLiteral
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import Schema
from pyiceberg.table import Table, UpsertResult
from pyiceberg.table.snapshots import Operation
from pyiceberg.table.upsert_util import create_match_filter
from pyiceberg.table.upsert_util import create_file_match_filter, create_match_filter
from pyiceberg.types import IntegerType, NestedField, StringType, StructType
from tests.catalog.test_base import InMemoryCatalog

Expand Down Expand Up @@ -443,6 +456,138 @@ def test_create_match_filter_single_condition() -> None:
)


def test_create_file_match_filter_empty_source_prunes_everything() -> None:
    """A source without rows must yield ``AlwaysFalse`` as its pruning filter."""
    join_cols = ["order_id", "order_line_id"]
    empty_source = pa.table(
        {
            "order_id": pa.array([], type=pa.int64()),
            "order_line_id": pa.array([], type=pa.int64()),
        }
    )

    pruning_filter = create_file_match_filter(empty_source, join_cols)

    assert pruning_filter == AlwaysFalse()


def test_create_file_match_filter_multi_column_bounds() -> None:
    """Every join column must contribute its own inclusive min/max bounds."""
    source = pa.table({"order_id": [1, 2, 3], "order_line_id": [100, 200, 300]})

    predicate = create_file_match_filter(source, ["order_id", "order_line_id"])

    # Flatten the nested ``And`` nodes with an explicit stack. Pushing the right
    # child before the left one keeps the leaves in left-to-right order.
    leaves: list[object] = []
    pending: list[object] = [predicate]
    while pending:
        node = pending.pop()
        if isinstance(node, And):
            pending.append(node.right)
            pending.append(node.left)
        else:
            leaves.append(node)

    # Index the bound values by column name and by predicate class.
    bounds_by_col: dict[str, dict[type, object]] = {}
    for leaf in leaves:
        bound = leaf.literal.value  # type: ignore[attr-defined]
        column_bounds = bounds_by_col.setdefault(leaf.term.name, {})  # type: ignore[attr-defined]
        column_bounds[type(leaf)] = bound

    expected_bounds = {
        "order_id": {GreaterThanOrEqual: 1, LessThanOrEqual: 3},
        "order_line_id": {GreaterThanOrEqual: 100, LessThanOrEqual: 300},
    }
    assert bounds_by_col == expected_bounds


@pytest.mark.parametrize(
    ("values", "expected"),
    [
        # No nulls: a plain inclusive range.
        ([1, 5], And(GreaterThanOrEqual("order_id", 1), LessThanOrEqual("order_id", 5))),
        # Only nulls: no range can be derived, so only the null check remains.
        ([None, None], IsNull("order_id")),
        # Mixed: the range is widened with a null check.
        ([1, None, 5], Or(And(GreaterThanOrEqual("order_id", 1), LessThanOrEqual("order_id", 5)), IsNull("order_id"))),
    ],
)
def test_create_file_match_filter_null_shape(values: list[int | None], expected: object) -> None:
    """Verify the shape of the single-column predicate for each null pattern."""
    key_column = pa.array(values, type=pa.int64())
    source = pa.table({"order_id": key_column})

    actual = create_file_match_filter(source, ["order_id"])

    assert actual == expected


def test_upsert_multi_col_file_match_filter_culls_false_positives(catalog: Catalog) -> None:
    """Rows inside the per-column bounds but not in the source must stay untouched.

    The source keys are (1, 100) and (2, 200), so the per-column ranges are
    ``order_id`` in [1, 2] and ``order_line_id`` in [100, 200]. The existing rows
    (1, 200) and (2, 100) fall inside both ranges without being source keys;
    only (1, 100) may be updated and only (2, 200) may be inserted.
    """
    identifier = "default.test_upsert_multi_col_file_match_filter_culls_false_positives"
    _drop_table(catalog, identifier)

    schema = pa.schema([("order_id", pa.int32()), ("order_line_id", pa.int32()), ("payload", pa.string())])
    existing_rows = [
        {"order_id": 1, "order_line_id": 200, "payload": "keep-1"},
        {"order_id": 2, "order_line_id": 100, "payload": "keep-2"},
        {"order_id": 1, "order_line_id": 100, "payload": "old"},
    ]
    incoming_rows = [
        {"order_id": 1, "order_line_id": 100, "payload": "new"},
        {"order_id": 2, "order_line_id": 200, "payload": "insert"},
    ]

    table = catalog.create_table(identifier, schema)
    table.append(pa.Table.from_pylist(existing_rows, schema=schema))

    source = pa.Table.from_pylist(incoming_rows, schema=schema)
    result = table.upsert(source, join_cols=["order_id", "order_line_id"])

    assert result.rows_updated == 1
    assert result.rows_inserted == 1

    # Re-read the table and index the payloads by composite key.
    rows_by_key = {}
    for row in table.scan().to_arrow().to_pylist():
        rows_by_key[(row["order_id"], row["order_line_id"])] = row["payload"]

    assert rows_by_key == {
        (1, 200): "keep-1",
        (2, 100): "keep-2",
        (1, 100): "new",
        (2, 200): "insert",
    }


def test_upsert_large_composite_key_initial_scan_does_not_recurse(tmp_path: PosixPath) -> None:
    """Regression: initial scan planning must not build the exact composite-key tree.

    Running this in a subprocess keeps the test runner alive on runtimes where the
    old recursive visitor shape can overflow the C stack.

    The child process creates an ``InMemoryCatalog`` whose warehouse is
    ``tmp_path``, appends a single row, and then upserts a 30,000-row source keyed
    on two columns with ``when_not_matched_insert_all=False``. The test passes
    when the child interpreter exits with return code 0.
    """
    # The program is an f-string only so the warehouse path can be injected;
    # doubled braces ({{ }}) are literal braces in the generated script.
    # NOTE(review): the child imports ``tests.catalog.test_base``, so it presumably
    # has to start in a working directory from which ``tests`` is importable - confirm.
    script = textwrap.dedent(
        f"""
        import sys

        sys.setrecursionlimit(10**7)

        import pyarrow as pa

        from tests.catalog.test_base import InMemoryCatalog

        n = 30_000
        catalog = InMemoryCatalog("test", warehouse={str(tmp_path)!r})
        catalog.create_namespace("default")
        schema = pa.schema([
            ("order_id", pa.int64()),
            ("order_line_id", pa.int64()),
            ("payload", pa.string()),
        ])
        table = catalog.create_table("default.regression", schema)
        table.append(
            pa.Table.from_pylist(
                [{{"order_id": 0, "order_line_id": 100_000, "payload": "old"}}],
                schema=schema,
            )
        )
        source = pa.table({{
            "order_id": pa.array(range(n), type=pa.int64()),
            "order_line_id": pa.array(range(100_000, 100_000 + n), type=pa.int64()),
            "payload": pa.array(["old"] * n, type=pa.string()),
        }})

        table.upsert(
            source,
            join_cols=["order_id", "order_line_id"],
            when_not_matched_insert_all=False,
        )
        """
    )
    # check=False: a failing child does not raise here; it is reported by the
    # assertion below together with the captured stdout and stderr.
    result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, check=False)

    assert result.returncode == 0, f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}"


def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None:
identifier = "default.test_upsert_with_duplicate_rows_in_table"

Expand Down