Skip to content

Commit aa3b194

Browse files
timsaucer and claude
authored
Add missing registration methods (#1474)
* Add missing SessionContext read/register methods for Arrow IPC and batches

  Add read_arrow, read_empty, register_arrow, and register_batch methods to SessionContext, exposing upstream DataFusion v53 functionality. The write_* methods and read_batch/read_batches are already covered by DataFrame.write_* and SessionContext.from_arrow respectively. Closes #1458.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Remove redundant read_empty Rust binding, make Python read_empty an alias for empty_table

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Add pathlib.Path and empty batch tests for Arrow IPC and register_batch

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Make test_read_empty more robust with length and num_rows checks

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Add examples to docstrings for new register/read methods

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Empty table actually returns record batch of length one but there are no columns

* Add optional argument examples to register_arrow and read_arrow docstrings

  Demonstrate schema= and file_extension= keyword arguments in the docstring examples for register_arrow and read_arrow, following project guidelines for optional parameter documentation.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Simplify read_empty docstring to use alias pattern

  Follow the same See Also alias convention used in functions.py since read_empty is a simple alias for empty_table.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Remove shared ctx from doctest namespace, use inline SessionContext

  Avoid shared SessionContext state across doctests by having each docstring example create its own ctx instance, matching the pattern used throughout the rest of the codebase.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Remove redundant import pyarrow as pa from docstrings

  The pa alias is already provided by the doctest namespace in conftest.py, so inline imports are unnecessary.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 46f9ab8 commit aa3b194

File tree

5 files changed

+302
-4
lines changed

5 files changed

+302
-4
lines changed

conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import datafusion as dfn
2121
import numpy as np
22+
import pyarrow as pa
2223
import pytest
2324
from datafusion import col, lit
2425
from datafusion import functions as F
@@ -29,6 +30,7 @@ def _doctest_namespace(doctest_namespace: dict) -> None:
2930
"""Add common imports to the doctest namespace."""
3031
doctest_namespace["dfn"] = dfn
3132
doctest_namespace["np"] = np
33+
doctest_namespace["pa"] = pa
3234
doctest_namespace["col"] = col
3335
doctest_namespace["lit"] = lit
3436
doctest_namespace["F"] = F

crates/core/src/context.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use datafusion::execution::context::{
4141
};
4242
use datafusion::execution::disk_manager::DiskManagerMode;
4343
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
44-
use datafusion::execution::options::ReadOptions;
44+
use datafusion::execution::options::{ArrowReadOptions, ReadOptions};
4545
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
4646
use datafusion::execution::session_state::SessionStateBuilder;
4747
use datafusion::prelude::{
@@ -974,6 +974,39 @@ impl PySessionContext {
974974
Ok(())
975975
}
976976

977+
#[pyo3(signature = (name, path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))]
978+
pub fn register_arrow(
979+
&self,
980+
name: &str,
981+
path: &str,
982+
schema: Option<PyArrowType<Schema>>,
983+
file_extension: &str,
984+
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
985+
py: Python,
986+
) -> PyDataFusionResult<()> {
987+
let mut options = ArrowReadOptions::default().table_partition_cols(
988+
table_partition_cols
989+
.into_iter()
990+
.map(|(name, ty)| (name, ty.0))
991+
.collect::<Vec<(String, DataType)>>(),
992+
);
993+
options.file_extension = file_extension;
994+
options.schema = schema.as_ref().map(|x| &x.0);
995+
996+
let result = self.ctx.register_arrow(name, path, options);
997+
wait_for_future(py, result)??;
998+
Ok(())
999+
}
1000+
1001+
pub fn register_batch(
1002+
&self,
1003+
name: &str,
1004+
batch: PyArrowType<RecordBatch>,
1005+
) -> PyDataFusionResult<()> {
1006+
self.ctx.register_batch(name, batch.0)?;
1007+
Ok(())
1008+
}
1009+
9771010
// Registers a PyArrow.Dataset
9781011
pub fn register_dataset(
9791012
&self,
@@ -1214,6 +1247,29 @@ impl PySessionContext {
12141247
Ok(PyDataFrame::new(df))
12151248
}
12161249

1250+
#[pyo3(signature = (path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))]
1251+
pub fn read_arrow(
1252+
&self,
1253+
path: &str,
1254+
schema: Option<PyArrowType<Schema>>,
1255+
file_extension: &str,
1256+
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
1257+
py: Python,
1258+
) -> PyDataFusionResult<PyDataFrame> {
1259+
let mut options = ArrowReadOptions::default().table_partition_cols(
1260+
table_partition_cols
1261+
.into_iter()
1262+
.map(|(name, ty)| (name, ty.0))
1263+
.collect::<Vec<(String, DataType)>>(),
1264+
);
1265+
options.file_extension = file_extension;
1266+
options.schema = schema.as_ref().map(|x| &x.0);
1267+
1268+
let result = self.ctx.read_arrow(path, options);
1269+
let df = wait_for_future(py, result)??;
1270+
Ok(PyDataFrame::new(df))
1271+
}
1272+
12171273
pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
12181274
let session = self.clone().into_bound_py_any(table.py())?;
12191275
let table = PyTable::new(table, Some(session))?;

python/datafusion/context.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,27 @@ def register_udtf(self, func: TableFunction) -> None:
903903
"""Register a user defined table function."""
904904
self.ctx.register_udtf(func._udtf)
905905

906+
def register_batch(self, name: str, batch: pa.RecordBatch) -> None:
907+
"""Register a single :py:class:`pa.RecordBatch` as a table.
908+
909+
Args:
910+
name: Name of the resultant table.
911+
batch: Record batch to register as a table.
912+
913+
Examples:
914+
>>> ctx = dfn.SessionContext()
915+
>>> batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
916+
>>> ctx.register_batch("batch_tbl", batch)
917+
>>> ctx.sql("SELECT * FROM batch_tbl").collect()[0].column(0)
918+
<pyarrow.lib.Int64Array object at ...>
919+
[
920+
1,
921+
2,
922+
3
923+
]
924+
"""
925+
self.ctx.register_batch(name, batch)
926+
906927
def deregister_udtf(self, name: str) -> None:
907928
"""Remove a user-defined table function from the session.
908929
@@ -1109,6 +1130,86 @@ def register_avro(
11091130
name, str(path), schema, file_extension, table_partition_cols
11101131
)
11111132

1133+
def register_arrow(
1134+
self,
1135+
name: str,
1136+
path: str | pathlib.Path,
1137+
schema: pa.Schema | None = None,
1138+
file_extension: str = ".arrow",
1139+
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
1140+
) -> None:
1141+
"""Register an Arrow IPC file as a table.
1142+
1143+
The registered table can be referenced from SQL statements executed
1144+
against this context.
1145+
1146+
Args:
1147+
name: Name of the table to register.
1148+
path: Path to the Arrow IPC file.
1149+
schema: The data source schema.
1150+
file_extension: File extension to select.
1151+
table_partition_cols: Partition columns.
1152+
1153+
Examples:
1154+
>>> import tempfile, os
1155+
>>> ctx = dfn.SessionContext()
1156+
>>> table = pa.table({"x": [10, 20, 30]})
1157+
>>> with tempfile.TemporaryDirectory() as tmpdir:
1158+
... path = os.path.join(tmpdir, "data.arrow")
1159+
... with pa.ipc.new_file(path, table.schema) as writer:
1160+
... writer.write_table(table)
1161+
... ctx.register_arrow("arrow_tbl", path)
1162+
... ctx.sql("SELECT * FROM arrow_tbl").collect()[0].column(0)
1163+
<pyarrow.lib.Int64Array object at ...>
1164+
[
1165+
10,
1166+
20,
1167+
30
1168+
]
1169+
1170+
Provide an explicit ``schema`` to override schema inference:
1171+
1172+
>>> with tempfile.TemporaryDirectory() as tmpdir:
1173+
... path = os.path.join(tmpdir, "data.arrow")
1174+
... with pa.ipc.new_file(path, table.schema) as writer:
1175+
... writer.write_table(table)
1176+
... ctx.register_arrow(
1177+
... "arrow_schema",
1178+
... path,
1179+
... schema=pa.schema([("x", pa.int64())]),
1180+
... )
1181+
... ctx.sql("SELECT * FROM arrow_schema").collect()[0].column(0)
1182+
<pyarrow.lib.Int64Array object at ...>
1183+
[
1184+
10,
1185+
20,
1186+
30
1187+
]
1188+
1189+
Use ``file_extension`` to read files with a non-default extension:
1190+
1191+
>>> with tempfile.TemporaryDirectory() as tmpdir:
1192+
... path = os.path.join(tmpdir, "data.ipc")
1193+
... with pa.ipc.new_file(path, table.schema) as writer:
1194+
... writer.write_table(table)
1195+
... ctx.register_arrow(
1196+
... "arrow_ipc", path, file_extension=".ipc"
1197+
... )
1198+
... ctx.sql("SELECT * FROM arrow_ipc").collect()[0].column(0)
1199+
<pyarrow.lib.Int64Array object at ...>
1200+
[
1201+
10,
1202+
20,
1203+
30
1204+
]
1205+
"""
1206+
if table_partition_cols is None:
1207+
table_partition_cols = []
1208+
table_partition_cols = _convert_table_partition_cols(table_partition_cols)
1209+
self.ctx.register_arrow(
1210+
name, str(path), schema, file_extension, table_partition_cols
1211+
)
1212+
11121213
def register_dataset(self, name: str, dataset: pa.dataset.Dataset) -> None:
11131214
"""Register a :py:class:`pa.dataset.Dataset` as a table.
11141215
@@ -1369,6 +1470,86 @@ def read_avro(
13691470
self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
13701471
)
13711472

1473+
def read_arrow(
1474+
self,
1475+
path: str | pathlib.Path,
1476+
schema: pa.Schema | None = None,
1477+
file_extension: str = ".arrow",
1478+
file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
1479+
) -> DataFrame:
1480+
"""Create a :py:class:`DataFrame` for reading an Arrow IPC data source.
1481+
1482+
Args:
1483+
path: Path to the Arrow IPC file.
1484+
schema: The data source schema.
1485+
file_extension: File extension to select.
1486+
file_partition_cols: Partition columns.
1487+
1488+
Returns:
1489+
DataFrame representation of the read Arrow IPC file.
1490+
1491+
Examples:
1492+
>>> import tempfile, os
1493+
>>> ctx = dfn.SessionContext()
1494+
>>> table = pa.table({"a": [1, 2, 3]})
1495+
>>> with tempfile.TemporaryDirectory() as tmpdir:
1496+
... path = os.path.join(tmpdir, "data.arrow")
1497+
... with pa.ipc.new_file(path, table.schema) as writer:
1498+
... writer.write_table(table)
1499+
... df = ctx.read_arrow(path)
1500+
... df.collect()[0].column(0)
1501+
<pyarrow.lib.Int64Array object at ...>
1502+
[
1503+
1,
1504+
2,
1505+
3
1506+
]
1507+
1508+
Provide an explicit ``schema`` to override schema inference:
1509+
1510+
>>> with tempfile.TemporaryDirectory() as tmpdir:
1511+
... path = os.path.join(tmpdir, "data.arrow")
1512+
... with pa.ipc.new_file(path, table.schema) as writer:
1513+
... writer.write_table(table)
1514+
... df = ctx.read_arrow(path, schema=pa.schema([("a", pa.int64())]))
1515+
... df.collect()[0].column(0)
1516+
<pyarrow.lib.Int64Array object at ...>
1517+
[
1518+
1,
1519+
2,
1520+
3
1521+
]
1522+
1523+
Use ``file_extension`` to read files with a non-default extension:
1524+
1525+
>>> with tempfile.TemporaryDirectory() as tmpdir:
1526+
... path = os.path.join(tmpdir, "data.ipc")
1527+
... with pa.ipc.new_file(path, table.schema) as writer:
1528+
... writer.write_table(table)
1529+
... df = ctx.read_arrow(path, file_extension=".ipc")
1530+
... df.collect()[0].column(0)
1531+
<pyarrow.lib.Int64Array object at ...>
1532+
[
1533+
1,
1534+
2,
1535+
3
1536+
]
1537+
"""
1538+
if file_partition_cols is None:
1539+
file_partition_cols = []
1540+
file_partition_cols = _convert_table_partition_cols(file_partition_cols)
1541+
return DataFrame(
1542+
self.ctx.read_arrow(str(path), schema, file_extension, file_partition_cols)
1543+
)
1544+
1545+
def read_empty(self) -> DataFrame:
1546+
"""Create an empty :py:class:`DataFrame` with no columns or rows.
1547+
1548+
See Also:
1549+
This is an alias for :meth:`empty_table`.
1550+
"""
1551+
return self.empty_table()
1552+
13721553
def read_table(
13731554
self, table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset
13741555
) -> DataFrame:

python/datafusion/user_defined.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,6 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
213213
Examples:
214214
Using ``udf`` as a function:
215215
216-
>>> import pyarrow as pa
217216
>>> import pyarrow.compute as pc
218217
>>> from datafusion.user_defined import ScalarUDF
219218
>>> def double_func(x):
@@ -480,7 +479,6 @@ def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901
480479
instance in which this UDAF is used.
481480
482481
Examples:
483-
>>> import pyarrow as pa
484482
>>> import pyarrow.compute as pc
485483
>>> from datafusion.user_defined import AggregateUDF, Accumulator, udaf
486484
>>> class Summarize(Accumulator):
@@ -874,7 +872,6 @@ def udwf(*args: Any, **kwargs: Any): # noqa: D417
874872
When using ``udwf`` as a decorator, do not pass ``func`` explicitly.
875873
876874
Examples:
877-
>>> import pyarrow as pa
878875
>>> from datafusion.user_defined import WindowUDF, WindowEvaluator, udwf
879876
>>> class BiasedNumbers(WindowEvaluator):
880877
... def __init__(self, start: int = 0):

python/tests/test_context.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,68 @@ def test_read_avro(ctx):
788788
assert avro_df is not None
789789

790790

791+
def test_read_arrow(ctx, tmp_path):
792+
# Write an Arrow IPC file, then read it back
793+
table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
794+
arrow_path = tmp_path / "test.arrow"
795+
with pa.ipc.new_file(str(arrow_path), table.schema) as writer:
796+
writer.write_table(table)
797+
798+
df = ctx.read_arrow(str(arrow_path))
799+
result = df.collect()
800+
assert result[0].column(0) == pa.array([1, 2, 3])
801+
assert result[0].column(1) == pa.array(["x", "y", "z"])
802+
803+
# Also verify pathlib.Path works
804+
df = ctx.read_arrow(arrow_path)
805+
result = df.collect()
806+
assert result[0].column(0) == pa.array([1, 2, 3])
807+
808+
809+
def test_read_empty(ctx):
810+
df = ctx.read_empty()
811+
result = df.collect()
812+
assert len(result) == 1
813+
assert result[0].num_columns == 0
814+
815+
df = ctx.empty_table()
816+
result = df.collect()
817+
assert len(result) == 1
818+
assert result[0].num_columns == 0
819+
820+
821+
def test_register_arrow(ctx, tmp_path):
822+
# Write an Arrow IPC file, then register and query it
823+
table = pa.table({"x": [10, 20, 30]})
824+
arrow_path = tmp_path / "test.arrow"
825+
with pa.ipc.new_file(str(arrow_path), table.schema) as writer:
826+
writer.write_table(table)
827+
828+
ctx.register_arrow("arrow_tbl", str(arrow_path))
829+
result = ctx.sql("SELECT * FROM arrow_tbl").collect()
830+
assert result[0].column(0) == pa.array([10, 20, 30])
831+
832+
# Also verify pathlib.Path works
833+
ctx.register_arrow("arrow_tbl_path", arrow_path)
834+
result = ctx.sql("SELECT * FROM arrow_tbl_path").collect()
835+
assert result[0].column(0) == pa.array([10, 20, 30])
836+
837+
838+
def test_register_batch(ctx):
839+
batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
840+
ctx.register_batch("batch_tbl", batch)
841+
result = ctx.sql("SELECT * FROM batch_tbl").collect()
842+
assert result[0].column(0) == pa.array([1, 2, 3])
843+
assert result[0].column(1) == pa.array([4, 5, 6])
844+
845+
846+
def test_register_batch_empty(ctx):
847+
batch = pa.RecordBatch.from_pydict({"a": pa.array([], type=pa.int64())})
848+
ctx.register_batch("empty_batch_tbl", batch)
849+
result = ctx.sql("SELECT * FROM empty_batch_tbl").collect()
850+
assert result[0].num_rows == 0
851+
852+
791853
def test_create_sql_options():
792854
SQLOptions()
793855

0 commit comments

Comments (0)