Skip to content

Commit 9d0ac50

Browse files
committed
Refactor common code into helper function so we do not duplicate it.
1 parent 85ee4f7 commit 9d0ac50

File tree

1 file changed

+15
-21
lines changed

1 file changed

+15
-21
lines changed

src/pyarrow_util.rs

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
2020
use std::sync::Arc;
2121

22-
use arrow::array::{make_array, Array, ArrayData, ListArray};
22+
use arrow::array::{make_array, Array, ArrayData, ArrayRef, ListArray};
2323
use arrow::buffer::OffsetBuffer;
2424
use arrow::datatypes::Field;
2525
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
@@ -31,13 +31,7 @@ use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python};
3131
use crate::common::data_type::PyScalarValue;
3232
use crate::errors::PyDataFusionError;
3333

34-
fn pyobj_extract_scalar_via_capsule(
35-
value: &Bound<'_, PyAny>,
36-
as_list_array: bool,
37-
) -> PyResult<PyScalarValue> {
38-
let array_data = ArrayData::from_pyarrow_bound(value)?;
39-
let array = make_array(array_data);
40-
34+
fn array_to_scalar_value(array: ArrayRef, as_list_array: bool) -> PyResult<PyScalarValue> {
4135
if as_list_array {
4236
let field = Arc::new(Field::new_list_field(
4337
array.data_type().clone(),
@@ -52,6 +46,16 @@ fn pyobj_extract_scalar_via_capsule(
5246
}
5347
}
5448

49+
fn pyobj_extract_scalar_via_capsule(
50+
value: &Bound<'_, PyAny>,
51+
as_list_array: bool,
52+
) -> PyResult<PyScalarValue> {
53+
let array_data = ArrayData::from_pyarrow_bound(value)?;
54+
let array = make_array(array_data);
55+
56+
array_to_scalar_value(array, as_list_array)
57+
}
58+
5559
impl FromPyArrow for PyScalarValue {
5660
fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
5761
let py = value.py();
@@ -115,19 +119,9 @@ impl FromPyArrow for PyScalarValue {
115119

116120
let array_data = ArrayData::from_pyarrow_bound(value)?;
117121
let array = make_array(array_data);
118-
if array.len() == 1 {
119-
let scalar =
120-
ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
121-
return Ok(PyScalarValue(scalar));
122-
} else {
123-
let field = Arc::new(Field::new_list_field(
124-
array.data_type().clone(),
125-
array.nulls().is_some(),
126-
));
127-
let offsets = OffsetBuffer::from_lengths(vec![array.len()]);
128-
let list_array = ListArray::new(field, offsets, array, None);
129-
return Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array))));
130-
}
122+
123+
let as_array_list = array.len() != 1;
124+
return array_to_scalar_value(array, as_array_list);
131125
}
132126

133127
// Last attempt - try to create a PyArrow scalar from a plain Python object

0 commit comments

Comments
 (0)