From 02a52594617f3d970e98cd7ec5b4f3e1e47fc470 Mon Sep 17 00:00:00 2001
From: Arun Sharma
Date: Tue, 12 May 2026 09:07:21 -0700
Subject: [PATCH 1/2] Add regression test for ndarray parameters

---
 test/test_issue.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/test/test_issue.py b/test/test_issue.py
index df1fccb..5cea6b0 100644
--- a/test/test_issue.py
+++ b/test/test_issue.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import pytest
+
 try:
     from tools.python_api.test.type_aliases import ConnDB
 except ImportError:
@@ -129,6 +131,27 @@ def test_int8_type_sniffing(conn_db_readwrite: ConnDB) -> None:
     result.close()
 
 
+def test_issue_483_numpy_ndarray_parameter(conn_db_readwrite: ConnDB) -> None:
+    np = pytest.importorskip("numpy")
+
+    conn, _ = conn_db_readwrite
+    conn.execute("CREATE NODE TABLE T(id INT64, v FLOAT[3], PRIMARY KEY(id))")
+    conn.execute("CREATE (:T {id: 1})")
+
+    arr = np.array([0.1, 0.2, 0.3], dtype=np.float32)
+    result = conn.execute(
+        "MATCH (n:T {id: 1}) SET n.v = $v RETURN n.v",
+        {"v": arr},
+    )
+
+    assert result.has_next()
+    assert result.get_next() == [
+        [pytest.approx(0.1), pytest.approx(0.2), pytest.approx(0.3)]
+    ]
+    assert not result.has_next()
+    result.close()
+
+
 # TODO(Maxwell): check if we should change getCastCost() for the following test
 # def test_issue_3248(conn_db_readwrite: ConnDB) -> None:
 #     conn, _ = conn_db_readwrite

From d4c0f404db6a7eccc4d23912d7fa88b34de1b891 Mon Sep 17 00:00:00 2001
From: Arun Sharma
Date: Tue, 12 May 2026 16:05:51 -0700
Subject: [PATCH 2/2] Handle NumPy ndarray query parameters

Fix ndarray parameter conversion by importing importlib.util directly,
inferring NumPy parameter types from dtype and shape, reading ndarray buffers
through Python's buffer metadata instead of materializing tolist(), and using
stable homogeneous-list inference for Python bool, int, and float lists.
---
 .../include/cached_import/py_cached_modules.h |   5 +-
 src_cpp/py_connection.cpp                     | 146 ++++++++++++++++++
 src_py/_lbug_capi.py                          |  90 +++++++++++
 test/test_scan_pandas_pyarrow.py              |  38 ++---
 test/test_scan_pyarrow.py                     |   2 +-
 5 files changed, 258 insertions(+), 23 deletions(-)

diff --git a/src_cpp/include/cached_import/py_cached_modules.h b/src_cpp/include/cached_import/py_cached_modules.h
index dad381a..0d18bde 100644
--- a/src_cpp/include/cached_import/py_cached_modules.h
+++ b/src_cpp/include/cached_import/py_cached_modules.h
@@ -25,14 +25,13 @@ class DecimalCachedItem : public PythonCachedItem {
 class ImportLibCachedItem : public PythonCachedItem {
     class UtilCachedItem : public PythonCachedItem {
     public:
-        explicit UtilCachedItem(PythonCachedItem* parent)
-            : PythonCachedItem{"util", parent}, find_spec{"find_spec", this} {}
+        UtilCachedItem() : PythonCachedItem{"importlib.util"}, find_spec{"find_spec", this} {}
 
         PythonCachedItem find_spec;
     };
 
 public:
-    ImportLibCachedItem() : PythonCachedItem("importlib"), util(this) {}
+    ImportLibCachedItem() : PythonCachedItem("importlib"), util() {}
 
     UtilCachedItem util;
 };

diff --git a/src_cpp/py_connection.cpp b/src_cpp/py_connection.cpp
index 8c5bd4d..56a7c20 100644
--- a/src_cpp/py_connection.cpp
+++ b/src_cpp/py_connection.cpp
@@ -15,6 +15,7 @@
 #include "include/py_udf.h"
 #include "main/connection.h"
 #include "main/query_result/materialized_query_result.h"
+#include "numpy/numpy_type.h"
 #include "pandas/pandas_scan.h"
 #include "processor/result/factorized_table.h"
 #include "pyarrow/pyarrow_scan.h"
@@ -388,6 +389,47 @@ bool integerFitsIn(int64_t val) {
     return val >= 0 && val <= UINT8_MAX;
 }
 
+static LogicalType pyHomogeneousListType(const py::list& lst) {
+    py::handle firstNonNull;
+    for (auto child : lst) {
+        if (!child.is_none()) {
+            firstNonNull = child;
+            break;
+        }
+    }
+    if (!firstNonNull) {
+        return LogicalType::LIST(LogicalType::ANY());
+    }
+    if (!py::isinstance<py::bool_>(firstNonNull) && !py::isinstance<py::int_>(firstNonNull) &&
+        !py::isinstance<py::float_>(firstNonNull)) {
+        return LogicalType::ANY();
+    }
+    for (auto child : lst) {
+        if (child.is_none()) {
+            continue;
+        }
+        if (child.get_type().ptr() != firstNonNull.get_type().ptr()) {
+            return LogicalType::ANY();
+        }
+    }
+    if (py::isinstance<py::bool_>(firstNonNull)) {
+        return LogicalType::LIST(LogicalType::BOOL());
+    }
+    if (py::isinstance<py::int_>(firstNonNull)) {
+        return LogicalType::LIST(LogicalType::INT64());
+    }
+    return LogicalType::LIST(LogicalType::DOUBLE());
+}
+
+static LogicalType pyNumpyArrayLogicalType(const py::array& arr) {
+    auto npType = NumpyTypeUtils::convertNumpyType(arr.attr("dtype"));
+    auto type = NumpyTypeUtils::numpyToLogicalType(npType);
+    for (auto i = 0; i < arr.ndim(); ++i) {
+        type = LogicalType::LIST(std::move(type));
+    }
+    return type;
+}
+
 static LogicalType pyLogicalType(const py::handle& val) {
     auto datetime_datetime = importCache->datetime.datetime();
     auto time_delta = importCache->datetime.timedelta();
@@ -468,8 +510,14 @@ static LogicalType pyLogicalType(const py::handle& val) {
             childValueType = std::move(resultValue);
         }
         return LogicalType::MAP(std::move(childKeyType), std::move(childValueType));
+    } else if (py::isinstance<py::array>(val)) {
+        return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val));
     } else if (py::isinstance<py::list>(val)) {
         py::list lst = py::reinterpret_borrow<py::list>(val);
+        auto homogeneousType = pyHomogeneousListType(lst);
+        if (homogeneousType.getLogicalTypeID() != LogicalTypeID::ANY) {
+            return homogeneousType;
+        }
         auto childType = LogicalType::ANY();
         for (auto child : lst) {
             auto curChildType = pyLogicalType(child);
@@ -568,8 +616,14 @@ static LogicalType pyLogicalTypeFromParameter(const py::handle& val) {
             structFields.emplace_back(std::move(keyName), std::move(keyType));
         }
         return LogicalType::STRUCT(std::move(structFields));
+    } else if (py::isinstance<py::array>(val)) {
+        return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val));
     } else if (py::isinstance<py::list>(val)) {
         py::list lst = py::reinterpret_borrow<py::list>(val);
+        auto homogeneousType = pyHomogeneousListType(lst);
+        if (homogeneousType.getLogicalTypeID() != LogicalTypeID::ANY) {
+            return homogeneousType;
+        }
         auto childType = LogicalType::ANY();
         for (auto child : lst) {
             auto curChildType = pyLogicalTypeFromParameter(child);
@@ -603,6 +657,90 @@ static std::string pythonObjectToJsonString(const py::handle& val) {
     return py::cast<std::string>(jsonStr);
 }
 
+template<typename T>
+static Value transformNumpyScalarAs(const void* ptr, const LogicalType& type) {
+    auto value = *reinterpret_cast<const T*>(ptr);
+    switch (type.getLogicalTypeID()) {
+    case LogicalTypeID::BOOL:
+        return Value::createValue(static_cast<bool>(value));
+    case LogicalTypeID::INT64:
+        return Value::createValue(static_cast<int64_t>(value));
+    case LogicalTypeID::UINT32:
+        return Value::createValue(static_cast<uint32_t>(value));
+    case LogicalTypeID::INT32:
+        return Value::createValue(static_cast<int32_t>(value));
+    case LogicalTypeID::UINT16:
+        return Value::createValue(static_cast<uint16_t>(value));
+    case LogicalTypeID::INT16:
+        return Value::createValue(static_cast<int16_t>(value));
+    case LogicalTypeID::UINT8:
+        return Value::createValue(static_cast<uint8_t>(value));
+    case LogicalTypeID::INT8:
+        return Value::createValue(static_cast<int8_t>(value));
+    case LogicalTypeID::FLOAT:
+        return Value(static_cast<float>(value));
+    case LogicalTypeID::DOUBLE:
+        return Value::createValue(static_cast<double>(value));
+    default:
+        throw RuntimeException("Unsupported numpy ndarray parameter child type " + type.toString());
+    }
+}
+
+static Value transformNumpyScalarAs(const void* ptr, NumpyNullableType npType,
+    const LogicalType& type) {
+    switch (npType) {
+    case NumpyNullableType::BOOL:
+        return transformNumpyScalarAs<bool>(ptr, type);
+    case NumpyNullableType::INT_8:
+        return transformNumpyScalarAs<int8_t>(ptr, type);
+    case NumpyNullableType::UINT_8:
+        return transformNumpyScalarAs<uint8_t>(ptr, type);
+    case NumpyNullableType::INT_16:
+        return transformNumpyScalarAs<int16_t>(ptr, type);
+    case NumpyNullableType::UINT_16:
+        return transformNumpyScalarAs<uint16_t>(ptr, type);
+    case NumpyNullableType::INT_32:
+        return transformNumpyScalarAs<int32_t>(ptr, type);
+    case NumpyNullableType::UINT_32:
+        return transformNumpyScalarAs<uint32_t>(ptr, type);
+    case NumpyNullableType::INT_64:
+        return transformNumpyScalarAs<int64_t>(ptr, type);
+    case NumpyNullableType::UINT_64:
+        return transformNumpyScalarAs<uint64_t>(ptr, type);
+    case NumpyNullableType::FLOAT_32:
+        return transformNumpyScalarAs<float>(ptr, type);
+    case NumpyNullableType::FLOAT_64:
+        return transformNumpyScalarAs<double>(ptr, type);
+    default:
+        throw RuntimeException("Unsupported numpy ndarray parameter dtype");
+    }
+}
+
+static Value transformNumpyArrayAs(const LogicalType& type, uint64_t dimension, const uint8_t* ptr,
+    const py::buffer_info& info, NumpyNullableType npType) {
+    if (dimension == static_cast<uint64_t>(info.ndim)) {
+        return transformNumpyScalarAs(ptr, npType, type);
+    }
+    if (type.getLogicalTypeID() != LogicalTypeID::LIST) {
+        throw RuntimeException("Cannot convert numpy ndarray parameter to " + type.toString());
+    }
+    std::vector<std::unique_ptr<Value>> children;
+    children.reserve(info.shape[dimension]);
+    const auto& childType = ListType::getChildType(type);
+    for (auto i = 0; i < info.shape[dimension]; ++i) {
+        auto childPtr = ptr + i * info.strides[dimension];
+        children.push_back(std::make_unique<Value>(
+            transformNumpyArrayAs(childType, dimension + 1, childPtr, info, npType)));
+    }
+    return Value(type.copy(), std::move(children));
+}
+
+static Value transformNumpyArrayAs(const py::array& arr, const LogicalType& type) {
+    auto info = arr.request();
+    auto npType = NumpyTypeUtils::convertNumpyType(arr.attr("dtype")).type;
+    return transformNumpyArrayAs(type, 0, reinterpret_cast<const uint8_t*>(info.ptr), info, npType);
+}
+
 Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalType& type) {
     // ignore the type of the actual python object, just directly cast
     auto datetime_datetime = importCache->datetime.datetime();
@@ -632,6 +770,8 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT
         return Value::createValue(py::cast(val).cast());
     case LogicalTypeID::DOUBLE:
         return Value::createValue(py::cast(val).cast());
+    case LogicalTypeID::FLOAT:
+        return Value(py::cast<py::float_>(val).cast<float>());
     case LogicalTypeID::DECIMAL: {
         auto str = py::cast<std::string>(py::str(val));
         int128_t result = 0;
@@ -708,6 +848,9 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT
         return Value{uuidToAppend};
     }
     case LogicalTypeID::LIST: {
+        if (py::isinstance<py::array>(val)) {
+            return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type);
+        }
         py::list lst = py::reinterpret_borrow<py::list>(val);
         std::vector<std::unique_ptr<Value>> children;
         for (auto child : lst) {
@@ -763,6 +906,9 @@ Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val,
         auto jsonStr = pythonObjectToJsonString(val);
         return Value::createValue(jsonStr);
     }
+    if (py::isinstance<py::array>(val)) {
+        return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type);
+    }
     py::list lst = py::reinterpret_borrow<py::list>(val);
     std::vector<std::unique_ptr<Value>> children;
     for (auto child : lst) {

diff --git a/src_py/_lbug_capi.py b/src_py/_lbug_capi.py
index 1fbacfa..eb464e2 100644
--- a/src_py/_lbug_capi.py
+++ b/src_py/_lbug_capi.py
@@ -244,6 +244,8 @@ def _ensure_arrow_atexit_cleanup() -> None:
 _LBUG_MAP = 55
 _LBUG_UNION = 56
 _LBUG_UUID = 59
+_NUMPY_MODULE: Any | None = None
+_NUMPY_IMPORT_ATTEMPTED = False
 
 
 def _setup_signatures() -> None:
@@ -392,6 +394,16 @@ def _setup_signatures() -> None:
     _LIB.lbug_value_create_int32.restype = ctypes.POINTER(_LbugValue)
     _LIB.lbug_value_create_int64.argtypes = [ctypes.c_int64]
     _LIB.lbug_value_create_int64.restype = ctypes.POINTER(_LbugValue)
+    _LIB.lbug_value_create_uint8.argtypes = [ctypes.c_uint8]
+    _LIB.lbug_value_create_uint8.restype = ctypes.POINTER(_LbugValue)
+    _LIB.lbug_value_create_uint16.argtypes = [ctypes.c_uint16]
+    _LIB.lbug_value_create_uint16.restype = ctypes.POINTER(_LbugValue)
+    _LIB.lbug_value_create_uint32.argtypes = [ctypes.c_uint32]
+    _LIB.lbug_value_create_uint32.restype = ctypes.POINTER(_LbugValue)
+    _LIB.lbug_value_create_uint64.argtypes = [ctypes.c_uint64]
+    _LIB.lbug_value_create_uint64.restype = ctypes.POINTER(_LbugValue)
+    _LIB.lbug_value_create_float.argtypes = [ctypes.c_float]
+    _LIB.lbug_value_create_float.restype = ctypes.POINTER(_LbugValue)
     _LIB.lbug_value_create_double.argtypes = [ctypes.c_double]
     _LIB.lbug_value_create_double.restype = ctypes.POINTER(_LbugValue)
     _LIB.lbug_value_create_string.argtypes = [ctypes.c_char_p]
@@ -930,11 +942,89 @@ def _parse_rendered_value(value: str) -> Any:
     return value
 
 
+def _numpy_module() -> Any | None:
+    global _NUMPY_IMPORT_ATTEMPTED, _NUMPY_MODULE
+    if _NUMPY_IMPORT_ATTEMPTED:
+        return _NUMPY_MODULE
+    _NUMPY_IMPORT_ATTEMPTED = True
+    try:
+        import numpy as np
+    except ModuleNotFoundError:
+        return None
+    _NUMPY_MODULE = np
+    return np
+
+
+def _is_numpy_scalar(value: Any) -> bool:
+    np = _numpy_module()
+    return bool(np is not None and isinstance(value, np.generic))
+
+
+def _is_numpy_array(value: Any) -> bool:
+    np = _numpy_module()
+    return bool(np is not None and isinstance(value, np.ndarray))
+
+
+def _numpy_scalar_value_from_python(value: Any) -> ctypes.POINTER(_LbugValue):
+    dtype = value.dtype
+    kind = dtype.kind
+    item = value.item()
+    if kind == "b":
+        return _LIB.lbug_value_create_bool(bool(item))
+    if kind == "i":
+        if dtype.itemsize == 1:
+            return _LIB.lbug_value_create_int8(item)
+        if dtype.itemsize == 2:
+            return _LIB.lbug_value_create_int16(item)
+        if dtype.itemsize == 4:
+            return _LIB.lbug_value_create_int32(item)
+        return _LIB.lbug_value_create_int64(item)
+    if kind == "u":
+        if dtype.itemsize == 1:
+            return _LIB.lbug_value_create_uint8(item)
+        if dtype.itemsize == 2:
+            return _LIB.lbug_value_create_uint16(item)
+        if dtype.itemsize == 4:
+            return _LIB.lbug_value_create_uint32(item)
+        return _LIB.lbug_value_create_uint64(item)
+    if kind == "f":
+        if dtype.itemsize == 4:
+            return _LIB.lbug_value_create_float(item)
+        return _LIB.lbug_value_create_double(item)
+
+    return _value_from_python(item)
+
+
+def _numpy_array_value_from_python(value: Any) -> ctypes.POINTER(_LbugValue):
+    if value.ndim == 0:
+        return _numpy_scalar_value_from_python(value[()])
+
+    child_ptrs: list[ctypes.POINTER(_LbugValue)] = []
+    try:
+        for item in value:
+            child_ptrs.append(_value_from_python(item))
+        out = ctypes.POINTER(_LbugValue)()
+        arr_type = ctypes.POINTER(_LbugValue) * len(child_ptrs)
+        arr = arr_type(*child_ptrs) if child_ptrs else arr_type()
+        _check_state(
+            _LIB.lbug_value_create_list(len(child_ptrs), arr, ctypes.byref(out)),
+            "Failed to create numpy ndarray list value",
+        )
+        return out
+    finally:
+        for ptr in child_ptrs:
+            _LIB.lbug_value_destroy(ptr)
+
+
 def _value_from_python(value: Any) -> ctypes.POINTER(_LbugValue):
     if value is None:
         return _LIB.lbug_value_create_null()
     if isinstance(value, CAPIJsonParameter):
         return _LIB.lbug_value_create_json(value.value.encode())
+    if _is_numpy_array(value):
+        return _numpy_array_value_from_python(value)
+    if _is_numpy_scalar(value):
+        return _numpy_scalar_value_from_python(value)
     if isinstance(value, bool):
         return _LIB.lbug_value_create_bool(value)
     if isinstance(value, int) and not isinstance(value, bool):

diff --git a/test/test_scan_pandas_pyarrow.py b/test/test_scan_pandas_pyarrow.py
index 438a62f..7cfbab4 100644
--- a/test/test_scan_pandas_pyarrow.py
+++ b/test/test_scan_pandas_pyarrow.py
@@ -228,7 +228,7 @@ def test_pyarrow_blob(conn_db_readonly: ConnDB) -> None:
             "col4": arrowtopd(col4),
         }
     ).sort_values(by=["index"])
-    result = conn.execute("LOAD FROM df RETURN * ORDER BY index").get_as_df()
+    result = conn.execute("LOAD FROM df RETURN * ORDER BY `index`").get_as_df()
     for colname in ["col1", "col2", "col3", "col4"]:
         for expected, actual in zip(df[colname], result[colname], strict=False):
             if is_null(expected) or is_null(actual):
@@ -277,7 +277,7 @@ def test_pyarrow_string(conn_db_readonly: ConnDB) -> None:
             "col3": arrowtopd(col3),
         }
     ).sort_values(by=["index"])
-    result = conn.execute("LOAD FROM df RETURN * ORDER BY index").get_as_df()
+    result = conn.execute("LOAD FROM df RETURN * ORDER BY `index`").get_as_df()
    for colname in ["col1", "col2", "col3"]:
         for expected, actual in zip(df[colname], result[colname], strict=False):
             if is_null(expected) or is_null(actual):
@@ -305,7 +305,7 @@ def test_pyarrow_dict(conn_db_readonly: ConnDB) -> None:
     df = pd.DataFrame(
         {"index": arrowtopd(index), "col1": arrowtopd(col1), "col2": arrowtopd(col2)}
     )
-    result = conn.execute("LOAD FROM df RETURN * ORDER BY index").get_as_df()
+    result = conn.execute("LOAD FROM df RETURN * ORDER BY `index`").get_as_df()
     for colname in ["col1", "col2"]:
         for expected, actual in zip(df[colname], result[colname], strict=False):
             assert expected == actual
@@ -320,7 +320,7 @@ def test_pyarrow_dict_offset(conn_db_readonly: ConnDB) -> None:
     dictionary = pa.array([1, 2, 3, 4])
     col1 = pa.DictionaryArray.from_arrays(indices, dictionary.slice(1, 3))
     df = pd.DataFrame({"index": arrowtopd(index), "col1": arrowtopd(col1)})
-    result = conn.execute("LOAD FROM df RETURN * ORDER BY index")
+    result = conn.execute("LOAD FROM df RETURN * ORDER BY `index`")
     idx = 0
     while result.has_next():
         assert idx < len(index)
@@ -370,7 +370,7 @@ def test_pyarrow_list(conn_db_readonly: ConnDB) -> None:
     df = pd.DataFrame(
         {"index": arrowtopd(index), "col1": arrowtopd(col1), "col2": arrowtopd(col2)}
     )
-    result = conn.execute("LOAD FROM df RETURN * ORDER BY index")
+    result = conn.execute("LOAD FROM df RETURN * ORDER BY `index`")
     idx = 0
     while result.has_next():
         assert idx < len(index)
@@ -412,7 +412,7 @@ def test_pyarrow_list_offset(conn_db_readonly: ConnDB) -> None:
             "col1": arrowtopd(col1),
         }
     )
-    result = conn.execute("LOAD FROM df RETURN * ORDER BY index")
+    result = conn.execute("LOAD FROM df RETURN * ORDER BY `index`")
     idx = 0
     while result.has_next():
         assert idx < len(index)
@@ -561,7 +561,7 @@ def test_pyarrow_fixed_list(conn_db_readonly: ConnDB) -> None:
             "map_col": arrowtopd(map_col),
         }
     )
-    result = conn.execute("LOAD FROM df RETURN * ORDER BY index")
+    result = conn.execute("LOAD FROM df RETURN * ORDER BY `index`")
 
     idx = 0
     while result.has_next():
@@ -621,7 +621,7 @@ def test_pyarrow_fixed_list_offset(conn_db_readonly: ConnDB) -> None:
             "col2": arrowtopd(col2),
         }
     )
-    result = conn.execute("LOAD FROM df RETURN * ORDER BY index")
+    result = conn.execute("LOAD FROM df RETURN * ORDER BY `index`")
     idx = 0
     while result.has_next():
         assert idx < len(index)
@@ -654,7 +654,7 @@ def test_pyarrow_struct(conn_db_readonly: ConnDB) -> None:
         pa.struct([("a", pa.int32()), ("b", pa.struct([("c", pa.string())]))]),
     )
     df = pd.DataFrame({"index": arrowtopd(index), "col1": arrowtopd(col1)})
-    result = conn.execute("LOAD FROM df RETURN * ORDER BY index")
+    result = conn.execute("LOAD FROM df RETURN * ORDER BY `index`")
     idx = 0
     while result.has_next():
         assert idx < len(index)
@@ -691,7 +691,7 @@ def test_pyarrow_struct_offset(conn_db_readonly: ConnDB) -> None:
         mask=mask,
     )
     df = pd.DataFrame({"index": arrowtopd(index), "col1": arrowtopd(col1)})
-    result = conn.execute("LOAD FROM df RETURN * ORDER BY index")
+    result = conn.execute("LOAD FROM df RETURN * ORDER BY `index`")
     idx = 0
     while result.has_next():
         assert idx < len(index)
@@ -730,7 +730,7 @@ def test_pyarrow_union_sparse(conn_db_readonly: ConnDB) -> None:
         ],
     )
     df = pd.DataFrame({"index": arrowtopd(index), "col1": arrowtopd(col1)})
-    result = conn.execute("LOAD FROM df RETURN * ORDER BY index")
+    result = conn.execute("LOAD FROM df RETURN * ORDER BY `index`")
     idx = 0
     while result.has_next():
         assert idx < len(index)
@@ -774,7 +774,7 @@ def test_pyarrow_union_dense(conn_db_readonly: ConnDB) -> None:
         ],
     )
     df = pd.DataFrame({"index": arrowtopd(index), "col1": arrowtopd(col1)})
-    result = conn.execute("LOAD FROM df RETURN * ORDER BY index")
+    result = conn.execute("LOAD FROM df RETURN * ORDER BY `index`")
     idx = 0
     while result.has_next():
         assert idx < len(index)
@@ -810,7 +810,7 @@ def test_pyarrow_map(conn_db_readonly: ConnDB) -> None:
         type=pa.map_(pa.string(), pa.string()),
     )
     df = pd.DataFrame({"index": arrowtopd(index), "col1": arrowtopd(col1)})
-    result = conn.execute("LOAD FROM df RETURN * ORDER BY index")
+    result = conn.execute("LOAD FROM df RETURN * ORDER BY `index`")
     idx = 0
     while result.has_next():
         assert idx < len(index)
@@ -854,7 +854,7 @@ def test_pyarrow_map_offset(conn_db_readonly: ConnDB) -> None:
             "col1": arrowtopd(col1),
         }
     )
-    result = conn.execute("LOAD FROM df RETURN * ORDER BY index")
+    result = conn.execute("LOAD FROM df RETURN * ORDER BY `index`")
     idx = 0
     while result.has_next():
         assert idx < len(index)
@@ -887,9 +887,9 @@ def test_pyarrow_decimal(conn_db_readwrite: ConnDB) -> None:
     conn.execute(
         "CREATE NODE TABLE tab(id INT64, col1 DECIMAL(7, 2), col2 DECIMAL(38, 0), primary key(id))"
     )
-    conn.execute("LOAD FROM df CREATE (t:tab {id: index, col1: col1, col2: col2})")
+    conn.execute("LOAD FROM df CREATE (t:tab {id: `index`, col1: col1, col2: col2})")
     result = conn.execute(
-        "MATCH (t:tab) RETURN t.id as index, t.col1 as col1, t.col2 as col2"
+        "MATCH (t:tab) RETURN t.id as `index`, t.col1 as col1, t.col2 as col2"
     ).get_as_arrow()
     expected = pa.Table.from_arrays(
         [index, decimal52, decimal380], names=["index", "col1", "col2"]
@@ -923,7 +923,7 @@ def test_pyarrow_skip_limit(conn_db_readonly: ConnDB) -> None:
         }
     )
     result = conn.execute(
-        "LOAD FROM df (SKIP=5000, LIMIT=5000) RETURN * ORDER BY index"
+        "LOAD FROM df (SKIP=5000, LIMIT=5000) RETURN * ORDER BY `index`"
     ).get_as_arrow()
     expected = pa.Table.from_pandas(df).slice(5000, 5000)
     assert result["index"].to_pylist() == expected["index"].to_pylist()
@@ -933,13 +933,13 @@ def test_pyarrow_skip_limit(conn_db_readonly: ConnDB) -> None:
 
     # skip bounds check
     result = conn.execute(
-        "LOAD FROM df (SKIP=500000, LIMIT=5000) RETURN * ORDER BY index"
+        "LOAD FROM df (SKIP=500000, LIMIT=5000) RETURN * ORDER BY `index`"
     ).get_as_arrow()
     assert len(result) == 0
 
     # limit bounds check
     result = conn.execute(
-        "LOAD FROM df (SKIP=0, LIMIT=500000) RETURN * ORDER BY index"
+        "LOAD FROM df (SKIP=0, LIMIT=500000) RETURN * ORDER BY `index`"
     ).get_as_arrow()
     expected = pa.Table.from_pandas(df)
     assert result["index"].to_pylist() == expected["index"].to_pylist()

diff --git a/test/test_scan_pyarrow.py b/test/test_scan_pyarrow.py
index 48917b2..c3d0b92 100644
--- a/test/test_scan_pyarrow.py
+++ b/test/test_scan_pyarrow.py
@@ -137,7 +137,7 @@ def test_pyarrow_copy_from_invalid_source(conn_db_readwrite: ConnDB) -> None:
     )
 
     with pytest.raises(
         RuntimeError,
-        match=r"Binder exception: Trying to scan from unsupported data type INT8\[\]. The only parameter types that can be scanned from are pandas/polars dataframes and pyarrow tables.",
+        match=r"Binder exception: Trying to scan from unsupported data type INT(8|64)\[\]. The only parameter types that can be scanned from are pandas/polars dataframes and pyarrow tables.",
     ):
         conn.execute("COPY pyarrowtab FROM $tab", {"tab": [1, 2, 3]})
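
Example usage enabled by this series (an illustrative sketch, not part of the patch; it assumes NumPy is installed and that `conn` is an open read-write connection, such as the one provided by the `conn_db_readwrite` fixture used in the tests above):

    import numpy as np

    # A FLOAT[3] column, as in the regression test from PATCH 1/2.
    conn.execute("CREATE NODE TABLE T(id INT64, v FLOAT[3], PRIMARY KEY(id))")
    conn.execute("CREATE (:T {id: 1})")

    # With PATCH 2/2 applied, an ndarray binds directly as a query parameter:
    # its dtype and shape drive the inferred list type, and the values are read
    # from the array buffer instead of going through tolist().
    vec = np.array([0.1, 0.2, 0.3], dtype=np.float32)
    result = conn.execute("MATCH (n:T {id: 1}) SET n.v = $v RETURN n.v", {"v": vec})
    print(result.get_next())  # one row: [[0.1, 0.2, 0.3]] to float32 precision
    result.close()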