Skip to content

Commit 492d34a

Browse files
committed
refined first_true_2d, added tests
1 parent c9f95ca commit 492d34a

2 files changed

Lines changed: 184 additions & 28 deletions

File tree

src/lib.rs

Lines changed: 79 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ fn first_true_1d_e(array: PyReadonlyArray1<bool>) -> isize {
201201
}
202202

203203
#[pyfunction]
204-
#[pyo3(signature = (array, forward=true))]
204+
#[pyo3(signature = (array, *, forward=true))]
205205
fn first_true_1d(py: Python, array: PyReadonlyArray1<bool>, forward: bool) -> isize {
206206
if let Ok(slice) = array.as_slice() {
207207
const LANES: usize = 32;
@@ -241,8 +241,7 @@ fn first_true_1d(py: Python, array: PyReadonlyArray1<bool>, forward: bool) -> is
241241
i -= LANES;
242242
let bytes = &*(ptr.add(i) as *const [u8; LANES]);
243243
let chunk = u8x32::from(*bytes);
244-
let equal_one = chunk.cmp_eq(ones);
245-
if equal_one.any() {
244+
if chunk.cmp_eq(ones).any() {
246245
// Found a true in this chunk, search backwards within it
247246
for j in (i..i + LANES).rev() {
248247
if *ptr.add(j) != 0 {
@@ -350,31 +349,53 @@ fn first_true_1d(py: Python, array: PyReadonlyArray1<bool>, forward: bool) -> is
350349
fn prepare_array_for_axis<'py>(
351350
py: Python<'py>,
352351
array: PyReadonlyArray2<'py, bool>,
353-
axis: usize,
352+
axis: isize,
354353
) -> PyResult<Bound<'py, PyArray2<bool>>> {
355354
if axis != 0 && axis != 1 {
356355
return Err(PyValueError::new_err("axis must be 0 or 1"));
357356
}
358357

359358
let is_c = array.is_c_contiguous();
360359
let is_f = array.is_fortran_contiguous();
360+
let array_view = array.as_array();
361361

362362
match (is_c, is_f, axis) {
363-
(true, _, 0) => Ok(array.as_array().reversed_axes().to_pyarray(py)),
364-
(true, _, 1) => Ok(array.as_array().to_pyarray(py)),
365-
(_, true, 0) => Ok(array.as_array().reversed_axes().to_pyarray(py)),
366-
(_, true, 1) => Ok(array.as_array().to_pyarray(py)),
367-
(_, _, 0) => Ok(array.as_array().reversed_axes().to_pyarray(py)),
368-
(_, _, 1) => Ok(array.as_array().to_pyarray(py)),
363+
(true, _, 1) => {
364+
// Already C-contiguous, no copy needed
365+
Ok(array_view.to_pyarray(py).to_owned())
366+
}
367+
(_, true, 0) => {
368+
// F-contiguous original -> transposed will be C-contiguous, no copy needed
369+
Ok(array_view.reversed_axes().to_pyarray(py).to_owned())
370+
}
371+
(_, true, 1) => {
372+
// F-contiguous, need to copy to C-contiguous
373+
let contiguous = array_view.as_standard_layout();
374+
Ok(contiguous.to_pyarray(py).to_owned())
375+
}
376+
(_, _, 1) => {
377+
// Neither C nor F contiguous, need to copy
378+
let contiguous = array_view.as_standard_layout();
379+
Ok(contiguous.to_pyarray(py).to_owned())
380+
}
381+
382+
(true, _, 0) | (_, _, 0) => {
383+
// C-contiguous or neither -> transposed won't be C-contiguous, need copy
384+
let transposed = array_view.reversed_axes();
385+
let contiguous = transposed.as_standard_layout();
386+
Ok(contiguous.to_pyarray(py).to_owned())
387+
}
369388
_ => unreachable!(),
370389
}
371390
}
372391

373392
#[pyfunction]
393+
#[pyo3(signature = (array, *, forward=true, axis))]
374394
pub fn first_true_2d<'py>(
375395
py: Python<'py>,
376396
array: PyReadonlyArray2<'py, bool>,
377-
axis: usize,
397+
forward: bool,
398+
axis: isize,
378399
) -> PyResult<Bound<'py, PyArray1<isize>>> {
379400
let prepped = prepare_array_for_axis(py, array, axis)?;
380401
let view = unsafe { prepped.as_array() };
@@ -387,28 +408,62 @@ pub fn first_true_2d<'py>(
387408
const LANES: usize = 32;
388409
let ones = u8x32::splat(1);
389410

411+
390412
for row in 0..rows {
391413
let mut found = -1;
392414
let row_slice = &view.row(row);
393415
let ptr = row_slice.as_ptr() as *const u8;
394416
let len = row_slice.len();
395-
let mut i = 0;
396417

397-
unsafe {
398-
while i + LANES <= len {
399-
let chunk = &*(ptr.add(i) as *const [u8; LANES]);
400-
let vec = u8x32::from(*chunk);
401-
if vec.cmp_eq(ones).any() {
402-
break;
418+
if forward {
419+
// Forward search
420+
let mut i = 0;
421+
unsafe {
422+
while i + LANES <= len {
423+
let chunk = &*(ptr.add(i) as *const [u8; LANES]);
424+
let vec = u8x32::from(*chunk);
425+
if vec.cmp_eq(ones).any() {
426+
break;
427+
}
428+
i += LANES;
429+
}
430+
while i < len {
431+
if *ptr.add(i) != 0 {
432+
found = i as isize;
433+
break;
434+
}
435+
i += 1;
403436
}
404-
i += LANES;
405437
}
406-
while i < len {
407-
if *ptr.add(i) != 0 {
408-
found = i as isize;
409-
break;
438+
} else {
439+
// Backward search
440+
let mut i = len;
441+
unsafe {
442+
// Process LANES bytes at a time with SIMD (backwards)
443+
while i >= LANES {
444+
i -= LANES;
445+
let chunk = &*(ptr.add(i) as *const [u8; LANES]);
446+
let vec = u8x32::from(*chunk);
447+
if vec.cmp_eq(ones).any() {
448+
// Found a true in this chunk, search backwards within it
449+
for j in (i..i + LANES).rev() {
450+
if *ptr.add(j) != 0 {
451+
found = j as isize;
452+
break;
453+
}
454+
}
455+
break;
456+
}
457+
}
458+
// Handle remaining bytes at the beginning
459+
if found == -1 && i > 0 {
460+
for j in (0..i).rev() {
461+
if *ptr.add(j) != 0 {
462+
found = j as isize;
463+
break;
464+
}
465+
}
410466
}
411-
i += 1;
412467
}
413468
}
414469
result.push(found);

tests/test_first_true_2d.py

Lines changed: 105 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,119 @@
1+
from socket import IP_ADD_SOURCE_MEMBERSHIP
12
from arrayredox import first_true_2d
23
import numpy as np
34

45
import pytest
56

67
#--------------------------------------------------------------------------
7-
def test_first_true_2d_a() -> None:
8+
def test_first_true_2d_a1() -> None:
89
a1 = np.arange(20).reshape(4, 5) % 3 == 0
9-
pos2 = first_true_2d(a1, 1)
10+
pos2 = first_true_2d(a1, axis=1)
1011
assert pos2.tolist() == [0, 1, 2, 0]
1112

12-
pos1 = first_true_2d(a1, 0)
13-
# import ipdb; ipdb.set_trace()
13+
pos1 = first_true_2d(a1, axis=0, forward=True)
1414
assert pos1.tolist() == [0, 1, 2, 0, 1]
1515

16+
pos2 = first_true_2d(a1, axis=0, forward=False)
17+
assert pos2.tolist() == [3, 1, 2, 3, 1]
1618

1719

20+
def test_first_true_2d_a2() -> None:
21+
a1 = np.isin(np.arange(100), (9, 19, 38, 68, 96)).reshape(5, 20)
22+
23+
post1 = first_true_2d(a1, axis=1, forward=True)
24+
# NOTE: this is an axis 1 result by argmax
25+
assert post1.tolist() == [9, 18, -1, 8, 16]
26+
27+
post2 = first_true_2d(a1, axis=1, forward=False)
28+
assert post2.tolist() == [19, 18, -1, 8, 16]
29+
30+
def test_first_true_2d_b() -> None:
31+
a1 = np.isin(np.arange(20), (3, 7, 10, 15, 18)).reshape(5, 4)
32+
33+
post1 = first_true_2d(a1, axis=1, forward=False)
34+
assert post1.tolist() == [3, 3, 2, 3, 2]
35+
36+
post2 = first_true_2d(a1, axis=1, forward=True)
37+
assert post2.tolist() == [3, 3, 2, 3, 2]
38+
39+
post3 = first_true_2d(a1, axis=0, forward=False)
40+
assert post3.tolist() == [-1, -1, 4, 3]
41+
42+
post4 = first_true_2d(a1, axis=0, forward=True)
43+
assert post4.tolist() == [-1, -1, 2, 0]
44+
45+
def test_first_true_2d_c() -> None:
46+
a1 = np.isin(np.arange(20), ()).reshape(5, 4)
47+
48+
post1 = first_true_2d(a1, axis=1, forward=False)
49+
assert post1.tolist() == [-1, -1, -1, -1, -1]
50+
51+
post2 = first_true_2d(a1, axis=1, forward=True)
52+
assert post2.tolist() == [-1, -1, -1, -1, -1]
53+
54+
post3 = first_true_2d(a1, axis=0, forward=False)
55+
assert post3.tolist() == [-1, -1, -1, -1]
56+
57+
post4 = first_true_2d(a1, axis=0, forward=True)
58+
assert post4.tolist() == [-1, -1, -1, -1]
59+
60+
61+
def test_first_true_2d_d() -> None:
62+
a1 = np.isin(np.arange(20), (0, 3, 4, 7, 8, 11, 12, 15, 16, 19)).reshape(5, 4)
63+
64+
post1 = first_true_2d(a1, axis=1, forward=False)
65+
assert post1.tolist() == [3, 3, 3, 3, 3]
66+
67+
post2 = first_true_2d(a1, axis=1, forward=True)
68+
assert post2.tolist() == [0, 0, 0, 0, 0]
69+
70+
post3 = first_true_2d(a1, axis=0, forward=True)
71+
assert post3.tolist() == [0, -1, -1, 0]
72+
73+
post4 = first_true_2d(a1, axis=0, forward=False)
74+
assert post4.tolist() == [4, -1, -1, 4]
75+
76+
def test_first_true_2d_e() -> None:
77+
a1 = np.isin(np.arange(15), (2, 7, 12)).reshape(3, 5)
78+
79+
post1 = first_true_2d(a1, axis=1, forward=False)
80+
assert post1.tolist() == [2, 2, 2]
81+
82+
post2 = first_true_2d(a1, axis=1, forward=True)
83+
assert post2.tolist() == [2, 2, 2]
84+
85+
def test_first_true_2d_f() -> None:
86+
a1 = np.isin(np.arange(15), (2, 7, 12)).reshape(3, 5)
87+
88+
with pytest.raises(ValueError):
89+
post1 = first_true_2d(a1, axis=-1)
90+
91+
with pytest.raises(ValueError):
92+
post1 = first_true_2d(a1, axis=2)
93+
94+
95+
def test_first_true_2d_g() -> None:
96+
a1 = np.isin(np.arange(15), (1, 7, 14)).reshape(3, 5)
97+
post1 = first_true_2d(a1, axis=0, forward=True)
98+
assert post1.tolist() == [-1, 0, 1, -1, 2]
99+
100+
post2 = first_true_2d(a1, axis=0, forward=False)
101+
assert post2.tolist() == [-1, 0, 1, -1, 2]
102+
103+
104+
def test_first_true_2d_h() -> None:
105+
a1 = np.isin(np.arange(15), (1, 7, 14)).reshape(3, 5).T # force fortran ordering
106+
assert first_true_2d(a1, axis=0, forward=True).tolist() == [1, 2, 4]
107+
assert first_true_2d(a1, axis=0, forward=False).tolist() == [1, 2, 4]
108+
assert first_true_2d(a1, axis=1, forward=True).tolist() == [-1, 0, 1, -1, 2]
109+
assert first_true_2d(a1, axis=1, forward=False).tolist() == [-1, 0, 1, -1, 2]
110+
111+
112+
def test_first_true_2d_i() -> None:
113+
# force fortran ordering, non-contiguous, non-owned
114+
a1 = np.isin(np.arange(15), (1, 4, 5, 7, 8, 12, 15)).reshape(3, 5).T[:4]
115+
assert first_true_2d(a1, axis=0, forward=True).tolist() == [1, 0, 2]
116+
assert first_true_2d(a1, axis=0, forward=False).tolist() == [1, 3, 2]
117+
assert first_true_2d(a1, axis=1, forward=True).tolist() == [1, 0, 1, 1]
118+
assert first_true_2d(a1, axis=1, forward=False).tolist() == [1, 0, 2, 1]
18119

0 commit comments

Comments
 (0)