Skip to content

Commit f181d1a

Browse files
buraksenncoderfender
authored andcommitted
fix(spark): mod/pmod returns NULL instead of NaN for float division by zero (apache#21557)
## Which issue does this PR close? - Closes apache#21514. ## Rationale for this change please see issue ## What changes are included in this PR? adjust pmod and mod to return NULL as expected ## Are these changes tested? added and adjusted tests ## Are there any user-facing changes? This methods will return NULL instead of Nan on this specific case.
1 parent 901062a commit f181d1a

2 files changed

Lines changed: 44 additions & 17 deletions

File tree

datafusion/spark/src/function/math/modulus.rs

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,28 +28,27 @@ use datafusion_expr::{
2828
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
2929
};
3030

31-
/// Attempts `rem(left, right)` with per-element divide-by-zero handling.
31+
/// Computes `rem(left, right)` with divide-by-zero handling.
3232
/// In ANSI mode, any zero divisor causes an error.
33-
/// In legacy mode (ANSI off), positions where the divisor is zero return NULL
34-
/// while other positions compute normally.
33+
/// In legacy mode (ANSI off), zero divisors are replaced with NULL before
34+
/// computing the remainder, so those positions return NULL while others
35+
/// compute normally.
3536
fn try_rem(
3637
left: &arrow::array::ArrayRef,
3738
right: &arrow::array::ArrayRef,
3839
enable_ansi_mode: bool,
3940
) -> Result<arrow::array::ArrayRef> {
40-
match rem(left, right) {
41-
Ok(result) => Ok(result),
42-
Err(arrow::error::ArrowError::DivideByZero) if !enable_ansi_mode => {
43-
// Integer rem fails when ANY divisor element is zero.
44-
// Handle per-element: null out zero divisors
45-
let zero = ScalarValue::new_zero(right.data_type())?.to_array()?;
46-
let zero = Scalar::new(zero);
47-
let null = Scalar::new(new_null_array(right.data_type(), 1));
48-
let is_zero = eq(right, &zero)?;
49-
let safe_right = zip(&is_zero, &null, right)?;
50-
Ok(rem(left, &safe_right)?)
51-
}
52-
Err(e) => Err(e.into()),
41+
if enable_ansi_mode {
42+
Ok(rem(left, right)?)
43+
} else {
44+
// In legacy mode, null out zero divisors so that division by zero
45+
// returns NULL instead of erroring (integers) or returning NaN (floats).
46+
let zero = ScalarValue::new_zero(right.data_type())?.to_array()?;
47+
let zero = Scalar::new(zero);
48+
let null = Scalar::new(new_null_array(right.data_type(), 1));
49+
let is_zero = eq(right, &zero)?;
50+
let safe_right = zip(&is_zero, &null, right)?;
51+
Ok(rem(left, &safe_right)?)
5352
}
5453
}
5554

@@ -241,6 +240,8 @@ mod test {
241240
Some(5.0),
242241
Some(f64::NAN),
243242
Some(f64::INFINITY),
243+
Some(10.5),
244+
Some(15.8),
244245
]);
245246
let right = Float64Array::from(vec![
246247
Some(3.0),
@@ -252,6 +253,8 @@ mod test {
252253
Some(f64::INFINITY),
253254
Some(f64::INFINITY),
254255
Some(f64::NAN),
256+
Some(0.0),
257+
Some(0.0),
255258
]);
256259

257260
let left_value = ColumnarValue::Array(Arc::new(left));
@@ -280,6 +283,9 @@ mod test {
280283
assert!(result_float64.value(7).is_nan());
281284
// inf % nan = nan
282285
assert!(result_float64.value(8).is_nan());
286+
// Division by zero returns NULL
287+
assert!(result_float64.is_null(9)); // 10.5 % 0.0 = NULL
288+
assert!(result_float64.is_null(10)); // 15.8 % 0.0 = NULL
283289
} else {
284290
panic!("Expected array result");
285291
}
@@ -297,6 +303,8 @@ mod test {
297303
Some(5.0),
298304
Some(f32::NAN),
299305
Some(f32::INFINITY),
306+
Some(10.5),
307+
Some(15.8),
300308
]);
301309
let right = Float32Array::from(vec![
302310
Some(3.0),
@@ -308,6 +316,8 @@ mod test {
308316
Some(f32::INFINITY),
309317
Some(f32::INFINITY),
310318
Some(f32::NAN),
319+
Some(0.0),
320+
Some(0.0),
311321
]);
312322

313323
let left_value = ColumnarValue::Array(Arc::new(left));
@@ -336,6 +346,9 @@ mod test {
336346
assert!(result_float32.value(7).is_nan());
337347
// inf % nan = nan
338348
assert!(result_float32.value(8).is_nan());
349+
// Division by zero returns NULL
350+
assert!(result_float32.is_null(9)); // 10.5 % 0.0 = NULL
351+
assert!(result_float32.is_null(10)); // 15.8 % 0.0 = NULL
339352
} else {
340353
panic!("Expected array result");
341354
}
@@ -462,6 +475,8 @@ mod test {
462475
Some(f64::INFINITY),
463476
Some(5.0),
464477
Some(-5.0),
478+
Some(10.5),
479+
Some(-7.2),
465480
]);
466481
let right = Float64Array::from(vec![
467482
Some(3.0),
@@ -472,6 +487,8 @@ mod test {
472487
Some(2.0),
473488
Some(f64::INFINITY),
474489
Some(f64::INFINITY),
490+
Some(0.0),
491+
Some(0.0),
475492
]);
476493

477494
let left_value = ColumnarValue::Array(Arc::new(left));
@@ -497,6 +514,9 @@ mod test {
497514
assert!((result_float64.value(6) - 5.0).abs() < f64::EPSILON);
498515
// -5.0 pmod inf = NaN
499516
assert!(result_float64.value(7).is_nan());
517+
// Division by zero returns NULL
518+
assert!(result_float64.is_null(8)); // 10.5 pmod 0.0 = NULL
519+
assert!(result_float64.is_null(9)); // -7.2 pmod 0.0 = NULL
500520
} else {
501521
panic!("Expected array result");
502522
}
@@ -513,6 +533,8 @@ mod test {
513533
Some(f32::INFINITY),
514534
Some(5.0),
515535
Some(-5.0),
536+
Some(10.5),
537+
Some(-7.2),
516538
]);
517539
let right = Float32Array::from(vec![
518540
Some(3.0),
@@ -523,6 +545,8 @@ mod test {
523545
Some(2.0),
524546
Some(f32::INFINITY),
525547
Some(f32::INFINITY),
548+
Some(0.0),
549+
Some(0.0),
526550
]);
527551

528552
let left_value = ColumnarValue::Array(Arc::new(left));
@@ -548,6 +572,9 @@ mod test {
548572
assert!((result_float32.value(6) - 5.0).abs() < f32::EPSILON * 10.0);
549573
// -5.0 pmod inf = NaN
550574
assert!(result_float32.value(7).is_nan());
575+
// Division by zero returns NULL
576+
assert!(result_float32.is_null(8)); // 10.5 pmod 0.0 = NULL
577+
assert!(result_float32.is_null(9)); // -7.2 pmod 0.0 = NULL
551578
} else {
552579
panic!("Expected array result");
553580
}

datafusion/sqllogictest/test_files/spark/math/mod.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ NULL
158158
query R
159159
SELECT MOD(10.5::float8, 0.0::float8) as mod_div_zero_float;
160160
----
161-
NaN
161+
NULL
162162

163163
# Division by zero errors in ANSI mode
164164
statement ok

0 commit comments

Comments
 (0)