Skip to content

Commit f9d9cdf

Browse files
committed
implementation of first_true_2s
1 parent 0fb8b12 commit f9d9cdf

1 file changed

Lines changed: 72 additions & 34 deletions

File tree

src/lib.rs

Lines changed: 72 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
use numpy::PyReadonlyArray1;
2-
use pyo3::prelude::*;
32
use pyo3::exceptions::PyValueError;
3+
use pyo3::prelude::*;
44
use pyo3::Bound;
55
// use pyo3::types::{PyBool, PyAny};
66
use wide::*;
77
// use std::simd::Simd;
88
// use std::simd::cmp::SimdPartialEq;
99

10-
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2};
11-
use numpy::ToPyArray;
1210
use numpy::PyArrayMethods;
1311
use numpy::PyUntypedArrayMethods;
14-
12+
use numpy::ToPyArray;
13+
use numpy::{PyArray2, PyReadonlyArray2};
14+
use numpy::PyArray1;
1515

1616
#[pyfunction]
1717
fn first_true_1d_a(array: PyReadonlyArray1<bool>) -> isize {
@@ -200,13 +200,9 @@ fn first_true_1d_e(array: PyReadonlyArray1<bool>) -> isize {
200200
}
201201
}
202202

203-
204203
#[pyfunction]
205204
#[pyo3(signature = (array, forward=true))]
206-
fn first_true_1d(py: Python,
207-
array: PyReadonlyArray1<bool>,
208-
forward: bool,
209-
) -> isize {
205+
fn first_true_1d(py: Python, array: PyReadonlyArray1<bool>, forward: bool) -> isize {
210206
if let Ok(slice) = array.as_slice() {
211207
const LANES: usize = 32;
212208

@@ -287,8 +283,6 @@ fn first_true_1d(py: Python,
287283
}
288284
}
289285

290-
291-
292286
// #[pyfunction]
293287
// fn first_true_1d_g(py: Python, array: PyReadonlyArray1<bool>) -> isize {
294288
// if let Ok(slice) = array.as_slice() {
@@ -338,10 +332,22 @@ fn first_true_1d(py: Python,
338332
// }
339333
// }
340334

341-
342335
//------------------------------------------------------------------------------
343336

344337

338+
// NOTE: we copy the entire array into contiguous memory when necessary.
339+
// axis = 0 returns the pos per col
340+
// axis = 1 returns the pos per row (as contiguous bytes)
341+
// if c contiguous:
342+
// axis == 0: transpose, copy to C
343+
// axis == 1: keep
344+
// if f contiguous:
345+
// axis == 0: transpose, keep
346+
// axis == 1: copy to C
347+
// else
348+
// axis == 0: transpose, copy to C
349+
// axis == 1: copy to C
350+
345351
fn prepare_array_for_axis<'py>(
346352
py: Python<'py>,
347353
array: PyReadonlyArray2<'py, bool>,
@@ -355,34 +361,66 @@ fn prepare_array_for_axis<'py>(
355361
let is_f = array.is_fortran_contiguous();
356362

357363
match (is_c, is_f, axis) {
358-
(true, _, 0) => {
359-
let transposed = array.as_array().reversed_axes().to_owned();
360-
Ok(transposed.into_pyarray(py))
361-
}
362-
(true, _, 1) => Ok(array.as_array().to_owned().into_pyarray(py)), // copy to get full ownership
363-
(_, true, 0) => {
364-
let transposed = array.as_array().reversed_axes();
365-
Ok(transposed.to_owned().into_pyarray(py))
366-
}
367-
(_, true, 1) => {
368-
let owned = array.as_array().to_owned();
369-
Ok(owned.into_pyarray(py))
370-
}
371-
(false, false, 0) => {
372-
let transposed = array.as_array().reversed_axes().to_owned();
373-
Ok(transposed.into_pyarray(py))
374-
}
375-
(false, false, 1) => {
376-
let owned = array.as_array().to_owned();
377-
Ok(owned.into_pyarray(py))
378-
}
364+
(true, _, 0) => Ok(array.as_array().reversed_axes().to_pyarray(py)),
365+
(true, _, 1) => Ok(array.as_array().to_pyarray(py)),
366+
(_, true, 0) => Ok(array.as_array().reversed_axes().to_pyarray(py)),
367+
(_, true, 1) => Ok(array.as_array().to_pyarray(py)),
368+
(_, _, 0) => Ok(array.as_array().reversed_axes().to_pyarray(py)),
369+
(_, _, 1) => Ok(array.as_array().to_pyarray(py)),
379370
_ => unreachable!(),
380371
}
381372
}
382373

374+
#[pyfunction]
375+
pub fn first_true_2d<'py>(
376+
py: Python<'py>,
377+
array: PyReadonlyArray2<'py, bool>,
378+
axis: usize,
379+
) -> PyResult<Bound<'py, PyArray1<isize>>> {
380+
let prepped = prepare_array_for_axis(py, array, axis)?;
381+
let view = unsafe { prepped.as_array() };
383382

384-
//------------------------------------------------------------------------------
383+
// NOTE: these are rows in the view, not always the same as rows
384+
let rows = view.nrows();
385+
let mut result = Vec::with_capacity(rows);
385386

387+
py.allow_threads(|| {
388+
const LANES: usize = 32;
389+
let ones = u8x32::splat(1);
390+
391+
for row in 0..rows {
392+
let mut found = -1;
393+
let row_slice = &view.row(row);
394+
let ptr = row_slice.as_ptr() as *const u8;
395+
let len = row_slice.len();
396+
let mut i = 0;
397+
398+
unsafe {
399+
while i + LANES <= len {
400+
let chunk = &*(ptr.add(i) as *const [u8; LANES]);
401+
let vec = u8x32::from(*chunk);
402+
if vec.cmp_eq(ones).any() {
403+
break;
404+
}
405+
i += LANES;
406+
}
407+
while i < len {
408+
if *ptr.add(i) != 0 {
409+
found = i as isize;
410+
break;
411+
}
412+
i += 1;
413+
}
414+
}
415+
result.push(found);
416+
}
417+
});
418+
419+
Ok(PyArray1::from_vec(py, result).to_owned())
420+
}
421+
422+
423+
//------------------------------------------------------------------------------
386424

387425
#[pymodule]
388426
fn arrayredox(m: &Bound<'_, PyModule>) -> PyResult<()> {

0 commit comments

Comments
 (0)