@@ -55,14 +55,16 @@ use url::Url;
5555use uuid:: Uuid ;
5656
5757use crate :: catalog:: { PyCatalog , RustWrappedPyCatalogProvider } ;
58+ use crate :: common:: data_type:: PyScalarValue ;
5859use crate :: dataframe:: PyDataFrame ;
5960use crate :: dataset:: Dataset ;
60- use crate :: errors:: { py_datafusion_err, PyDataFusionResult } ;
61+ use crate :: errors:: { py_datafusion_err, PyDataFusionError , PyDataFusionResult } ;
6162use crate :: expr:: sort_expr:: PySortExpr ;
6263use crate :: physical_plan:: PyExecutionPlan ;
6364use crate :: record_batch:: PyRecordBatchStream ;
6465use crate :: sql:: exceptions:: py_value_err;
6566use crate :: sql:: logical:: PyLogicalPlan ;
67+ use crate :: sql:: util:: replace_placeholders_with_strings;
6668use crate :: store:: StorageContexts ;
6769use crate :: table:: PyTable ;
6870use crate :: udaf:: PyAggregateUDF ;
@@ -422,27 +424,41 @@ impl PySessionContext {
422424 self . ctx . register_udtf ( & name, func) ;
423425 }
424426
425- /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
426- pub fn sql ( & self , query : & str , py : Python ) -> PyDataFusionResult < PyDataFrame > {
427- let result = self . ctx . sql ( query) ;
428- let df = wait_for_future ( py, result) ??;
429- Ok ( PyDataFrame :: new ( df) )
430- }
431-
432- #[ pyo3( signature = ( query, options=None ) ) ]
427+ #[ pyo3( signature = ( query, options=None , param_values=HashMap :: default ( ) , param_strings=HashMap :: default ( ) ) ) ]
433428 pub fn sql_with_options (
434429 & self ,
435- query : & str ,
436- options : Option < PySQLOptions > ,
437430 py : Python ,
431+ mut query : String ,
432+ options : Option < PySQLOptions > ,
433+ param_values : HashMap < String , PyScalarValue > ,
434+ param_strings : HashMap < String , String > ,
438435 ) -> PyDataFusionResult < PyDataFrame > {
439436 let options = if let Some ( options) = options {
440437 options. options
441438 } else {
442439 SQLOptions :: new ( )
443440 } ;
444- let result = self . ctx . sql_with_options ( query, options) ;
445- let df = wait_for_future ( py, result) ??;
441+
442+ let param_values = param_values
443+ . into_iter ( )
444+ . map ( |( name, value) | ( name, ScalarValue :: from ( value) ) )
445+ . collect :: < HashMap < _ , _ > > ( ) ;
446+
447+ let state = self . ctx . state ( ) ;
448+ let dialect = state. config ( ) . options ( ) . sql_parser . dialect . as_str ( ) ;
449+
450+ if !param_strings. is_empty ( ) {
451+ query = replace_placeholders_with_strings ( & query, dialect, param_strings) ?;
452+ }
453+
454+ let mut df = wait_for_future ( py, async {
455+ self . ctx . sql_with_options ( & query, options) . await
456+ } ) ??;
457+
458+ if !param_values. is_empty ( ) {
459+ df = df. with_param_values ( param_values) ?;
460+ }
461+
446462 Ok ( PyDataFrame :: new ( df) )
447463 }
448464
@@ -550,7 +566,7 @@ impl PySessionContext {
550566
551567 ( array. schema ( ) . as_ref ( ) . to_owned ( ) , vec ! [ array] )
552568 } else {
553- return Err ( crate :: errors :: PyDataFusionError :: Common (
569+ return Err ( PyDataFusionError :: Common (
554570 "Expected either a Arrow Array or Arrow Stream in from_arrow()." . to_string ( ) ,
555571 ) ) ;
556572 } ;
@@ -714,7 +730,7 @@ impl PySessionContext {
714730 ) -> PyDataFusionResult < ( ) > {
715731 let delimiter = delimiter. as_bytes ( ) ;
716732 if delimiter. len ( ) != 1 {
717- return Err ( crate :: errors :: PyDataFusionError :: PythonError ( py_value_err (
733+ return Err ( PyDataFusionError :: PythonError ( py_value_err (
718734 "Delimiter must be a single character" ,
719735 ) ) ) ;
720736 }
@@ -968,7 +984,7 @@ impl PySessionContext {
968984 ) -> PyDataFusionResult < PyDataFrame > {
969985 let delimiter = delimiter. as_bytes ( ) ;
970986 if delimiter. len ( ) != 1 {
971- return Err ( crate :: errors :: PyDataFusionError :: PythonError ( py_value_err (
987+ return Err ( PyDataFusionError :: PythonError ( py_value_err (
972988 "Delimiter must be a single character" ,
973989 ) ) ) ;
974990 } ;
0 commit comments