diff --git a/encodings/runend/src/compute/take.rs b/encodings/runend/src/compute/take.rs index 7100faf9eac..46c8c2a35fc 100644 --- a/encodings/runend/src/compute/take.rs +++ b/encodings/runend/src/compute/take.rs @@ -7,87 +7,481 @@ use vortex_array::ArrayRef; use vortex_array::ArrayView; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; +use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::dict::TakeExecute; +use vortex_array::dtype::UnsignedPType; use vortex_array::match_each_integer_ptype; -use vortex_array::search_sorted::SearchResult; -use vortex_array::search_sorted::SearchSorted; -use vortex_array::search_sorted::SearchSortedSide; +use vortex_array::match_each_unsigned_integer_ptype; +use vortex_array::scalar::Scalar; use vortex_array::validity::Validity; use vortex_buffer::Buffer; +use vortex_buffer::BufferMut; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_mask::AllOr; +use vortex_mask::Mask; use crate::RunEnd; use crate::array::RunEndArrayExt; +use crate::iter::trimmed_ends_iter; + +const SORTED_LINEAR_RUNS_PER_INDEX_THRESHOLD: usize = 16; +const UNSORTED_LINEAR_RUNS_PER_INDEX_THRESHOLD: usize = 4; +/// Sorting the indices and merging only beats per-index binary search once the run ends are too +/// large to stay cache-resident; below this run count binary search wins. +const UNSORTED_LINEAR_MIN_RUNS: usize = 1 << 19; +/// Use a dense logical-position-to-run-index table when the array length is at most this many +/// times the number of valid indices: building the table is O(array_len) and each index then +/// resolves with a single unconditional gather. +const TABLE_LEN_PER_INDEX_THRESHOLD: usize = 8; impl TakeExecute for RunEnd { - #[expect( - clippy::cast_possible_truncation, - reason = "index cast to usize inside macro" - )] fn take( array: ArrayView<'_, Self>, indices: &ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult> { let primitive_indices = indices.clone().execute::(ctx)?; + let indices_validity = primitive_indices.validity()?; + let indices_mask = indices_validity.execute_mask(primitive_indices.len(), ctx)?; - let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| { - primitive_indices - .as_slice::

() - .iter() - .copied() - .map(|idx| { - let usize_idx = idx as usize; - if usize_idx >= array.len() { - vortex_bail!(OutOfBounds: usize_idx, 0, array.len()); - } - Ok(usize_idx) - }) - .collect::>>()? + let taken = match_each_integer_ptype!(primitive_indices.ptype(), |P| { + take_indices( + array, + primitive_indices.as_slice::

(), + &indices_validity, + &indices_mask, + true, + ctx, + )? }); - let indices_validity = primitive_indices.validity()?; - take_indices_unchecked(array, &checked_indices, &indices_validity, ctx).map(Some) + Ok(Some(taken)) } } -/// Perform a take operation on a RunEndArray by binary searching for each of the indices. +/// Perform a take operation on a RunEndArray without bounds-checking the indices. +/// +/// The caller must guarantee that all valid indices are in bounds for the array. pub fn take_indices_unchecked>( array: ArrayView<'_, RunEnd>, indices: &[T], validity: &Validity, ctx: &mut ExecutionCtx, ) -> VortexResult { + let validity_mask = validity.execute_mask(indices.len(), ctx)?; + take_indices(array, indices, validity, &validity_mask, false, ctx) +} + +fn take_indices>( + array: ArrayView<'_, RunEnd>, + indices: &[T], + validity: &Validity, + validity_mask: &Mask, + check_bounds: bool, + ctx: &mut ExecutionCtx, +) -> VortexResult { + if validity_mask.all_false() { + return Ok( + ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len()) + .into_array(), + ); + } + + let stats = valid_indices_stats(indices, validity_mask, array.len(), check_bounds)?; let ends = array.ends().clone().execute::(ctx)?; - let ends_len = ends.len(); - - // TODO(joe): use the validity mask to skip search sorted. - let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| { - let end_slices = ends.as_slice::(); - let physical_indices_vec: Vec = indices - .iter() - .map(|idx| idx.as_() + array.offset()) - .map(|idx| { - match ::from(idx) { - Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right), - None => { - // The idx is too large for I, therefore it's out of bounds. - Ok(SearchResult::NotFound(ends_len)) - } - } - }) - .map(|result| result.map(|r| r.to_ends_index(ends_len) as u64)) - .collect::>>()?; - let buffer = Buffer::from(physical_indices_vec); - PrimitiveArray::new(buffer, validity.clone()) + let physical_indices = match_each_unsigned_integer_ptype!(ends.ptype(), |I| { + let ends = ends.as_slice::(); + // Run indices fit in u32 for any realistic array; the narrower physical indices halve + // the memory traffic of the downstream take on the values. + if ends.len() <= u32::MAX as usize { + PrimitiveArray::new( + physical_indices_with_stats::<_, _, u32>( + ends, + array.offset(), + array.len(), + indices, + validity_mask, + stats, + ), + validity.clone(), + ) + } else { + PrimitiveArray::new( + physical_indices_with_stats::<_, _, u64>( + ends, + array.offset(), + array.len(), + indices, + validity_mask, + stats, + ), + validity.clone(), + ) + } }); array.values().take(physical_indices.into_array()) } +#[derive(Clone, Copy)] +struct ValidIndicesStats { + count: usize, + sorted: bool, +} + +fn physical_indices_with_stats( + ends: &[I], + offset: usize, + array_len: usize, + indices: &[T], + validity_mask: &Mask, + stats: ValidIndicesStats, +) -> Buffer +where + I: UnsignedPType, + T: AsPrimitive, + O: UnsignedPType, + usize: AsPrimitive, +{ + if stats.count == 0 { + return Buffer::zeroed(indices.len()); + } + + if stats.sorted + && prefer_linear_scan( + ends.len(), + stats.count, + SORTED_LINEAR_RUNS_PER_INDEX_THRESHOLD, + ) + { + return physical_indices_linear_sorted(ends, offset, indices, validity_mask); + } + + // A dense take resolves fastest through the position table regardless of index ordering. + // Sorted indices reach here only when there are too many runs for the sorted linear scan, + // for example a narrow slice of a heavily run-encoded array where runs far exceed array_len. + if array_len <= stats.count.saturating_mul(TABLE_LEN_PER_INDEX_THRESHOLD) { + return physical_indices_table(ends, offset, array_len, indices, validity_mask); + } + + if ends.len() >= UNSORTED_LINEAR_MIN_RUNS + && prefer_linear_scan( + ends.len(), + stats.count, + UNSORTED_LINEAR_RUNS_PER_INDEX_THRESHOLD, + ) + { + return physical_indices_linear_unsorted(ends, offset, indices, validity_mask, stats.count); + } + + physical_indices_binary(ends, offset, indices, validity_mask) +} + +/// Count the valid indices and determine whether they are sorted, bounds-checking each valid +/// index against `array_len` when `check_bounds` is set. +fn valid_indices_stats>( + indices: &[T], + validity_mask: &Mask, + array_len: usize, + check_bounds: bool, +) -> VortexResult { + debug_assert_eq!(indices.len(), validity_mask.len()); + + let count = validity_mask.true_count(); + if count == 0 { + return Ok(ValidIndicesStats { + count, + sorted: true, + }); + } + + let sorted = match validity_mask.bit_buffer() { + AllOr::All => valid_indices_sorted_all(indices, array_len, check_bounds)?, + AllOr::None => true, + AllOr::Some(validity) => { + valid_indices_sorted_masked(indices, validity.iter(), array_len, check_bounds)? + } + }; + + Ok(ValidIndicesStats { count, sorted }) +} + +fn valid_indices_sorted_all>( + indices: &[T], + array_len: usize, + check_bounds: bool, +) -> VortexResult { + // Seed the comparison with the first index; an empty or single-element slice is trivially + // sorted, so the loop below starts from the second element. + let Some((first, rest)) = indices.split_first() else { + return Ok(true); + }; + + let mut previous_idx = first.as_(); + if check_bounds { + check_index(previous_idx, array_len)?; + } + + let mut sorted = true; + for idx in rest { + let idx = idx.as_(); + if check_bounds { + check_index(idx, array_len)?; + } + if previous_idx > idx { + sorted = false; + if !check_bounds { + break; + } + } + previous_idx = idx; + } + + Ok(sorted) +} + +fn valid_indices_sorted_masked>( + indices: &[T], + is_valid: impl Iterator, + array_len: usize, + check_bounds: bool, +) -> VortexResult { + // Invalid positions are skipped without a bounds check, matching the take path that never + // dereferences them. + let mut valid = is_valid + .zip(indices.iter()) + .filter(|(is_valid, _)| *is_valid) + .map(|(_, idx)| idx.as_()); + + // Seed the comparison with the first valid index; zero or one valid index is trivially + // sorted, so the loop below starts from the second valid index. + let Some(mut previous_idx) = valid.next() else { + return Ok(true); + }; + if check_bounds { + check_index(previous_idx, array_len)?; + } + + let mut sorted = true; + for idx in valid { + if check_bounds { + check_index(idx, array_len)?; + } + if previous_idx > idx { + sorted = false; + if !check_bounds { + break; + } + } + previous_idx = idx; + } + + Ok(sorted) +} + +fn prefer_linear_scan( + ends_len: usize, + valid_count: usize, + runs_per_index_threshold: usize, +) -> bool { + ends_len <= valid_count.saturating_mul(runs_per_index_threshold) +} + +fn check_index(index: usize, array_len: usize) -> VortexResult<()> { + if index >= array_len { + vortex_bail!(OutOfBounds: index, 0, array_len); + } + Ok(()) +} + +fn physical_indices_linear_sorted( + ends: &[I], + offset: usize, + indices: &[T], + validity_mask: &Mask, +) -> Buffer +where + I: UnsignedPType, + T: AsPrimitive, + O: UnsignedPType, + usize: AsPrimitive, +{ + let mut run_idx = 0; + + match validity_mask.bit_buffer() { + AllOr::All => Buffer::from_trusted_len_iter(indices.iter().map(|idx| { + advance_run(ends, &mut run_idx, idx.as_() + offset); + run_idx.as_() + })), + AllOr::None => unreachable!("AllInvalid indices have been handled earlier"), + AllOr::Some(validity) => { + // Invalid positions keep physical index zero, which is always in-bounds for the + // values and masked out by the result validity. + let mut physical_indices = BufferMut::zeroed(indices.len()); + for (idx_pos, (is_valid, idx)) in validity.iter().zip(indices.iter()).enumerate() { + if !is_valid { + continue; + } + + advance_run(ends, &mut run_idx, idx.as_() + offset); + physical_indices[idx_pos] = run_idx.as_(); + } + physical_indices.freeze() + } + } +} + +/// Resolve indices through a dense logical-position-to-run-index table. +/// +/// Building the table costs O(array_len), but every index then resolves with an unconditional +/// gather, which beats per-index binary search and sort-then-merge for dense takes. Invalid +/// indices may hold arbitrary values (even out of bounds), so they are redirected to position +/// zero instead of branching; the result validity masks whatever they resolve to. +fn physical_indices_table( + ends: &[I], + offset: usize, + array_len: usize, + indices: &[T], + validity_mask: &Mask, +) -> Buffer +where + I: UnsignedPType, + T: AsPrimitive, + O: UnsignedPType, + usize: AsPrimitive, +{ + let table = run_index_table::(ends, offset, array_len); + let table = table.as_slice(); + + match validity_mask.bit_buffer() { + AllOr::All => Buffer::from_trusted_len_iter(indices.iter().map(|idx| table[idx.as_()])), + AllOr::None => unreachable!("AllInvalid indices have been handled earlier"), + AllOr::Some(validity) => Buffer::from_trusted_len_iter( + validity + .iter() + .zip(indices.iter()) + .map(|(is_valid, idx)| table[if is_valid { idx.as_() } else { 0 }]), + ), + } +} + +/// Materialize the run index of every logical position in `[0, len)`. +fn run_index_table(ends: &[I], offset: usize, len: usize) -> Buffer +where + I: UnsignedPType, + O: UnsignedPType, + usize: AsPrimitive, +{ + let mut table = BufferMut::with_capacity(len); + let mut run_start = 0; + for (run_idx, run_end) in trimmed_ends_iter(ends, offset, len).enumerate() { + table.push_n(run_idx.as_(), run_end - run_start); + run_start = run_end; + } + table.freeze() +} + +fn physical_indices_linear_unsorted( + ends: &[I], + offset: usize, + indices: &[T], + validity_mask: &Mask, + valid_count: usize, +) -> Buffer +where + I: UnsignedPType, + T: AsPrimitive, + O: UnsignedPType, + usize: AsPrimitive, +{ + let mut pairs = Vec::with_capacity(valid_count); + match validity_mask.bit_buffer() { + AllOr::All => { + pairs.extend( + indices + .iter() + .enumerate() + .map(|(idx_pos, idx)| (idx.as_(), idx_pos)), + ); + } + AllOr::None => unreachable!("AllInvalid indices have been handled earlier"), + AllOr::Some(validity) => { + for (idx_pos, (is_valid, idx)) in validity.iter().zip(indices.iter()).enumerate() { + if is_valid { + pairs.push((idx.as_(), idx_pos)); + } + } + } + } + pairs.sort_unstable(); + + let mut physical_indices = BufferMut::zeroed(indices.len()); + let mut run_idx = 0; + + for (idx, idx_pos) in pairs { + advance_run(ends, &mut run_idx, idx + offset); + physical_indices[idx_pos] = run_idx.as_(); + } + + physical_indices.freeze() +} + +fn physical_indices_binary( + ends: &[I], + offset: usize, + indices: &[T], + validity_mask: &Mask, +) -> Buffer +where + I: UnsignedPType, + T: AsPrimitive, + O: UnsignedPType, + usize: AsPrimitive, +{ + match validity_mask.bit_buffer() { + AllOr::All => Buffer::from_trusted_len_iter( + indices + .iter() + .map(|idx| physical_index_binary(ends, idx.as_() + offset).as_()), + ), + AllOr::None => Buffer::zeroed(indices.len()), + AllOr::Some(validity) => { + let mut physical_indices = BufferMut::zeroed(indices.len()); + for (idx_pos, (is_valid, idx)) in validity.iter().zip(indices.iter()).enumerate() { + if !is_valid { + continue; + } + + physical_indices[idx_pos] = physical_index_binary(ends, idx.as_() + offset).as_(); + } + physical_indices.freeze() + } + } +} + +fn physical_index_binary(ends: &[I], logical_idx: usize) -> usize { + let index = match ::from(logical_idx) { + Some(logical_idx) => ends.partition_point(|end| *end <= logical_idx), + None => ends.len(), + }; + index.min(ends.len() - 1) +} + +fn advance_run(ends: &[I], run_idx: &mut usize, logical_idx: usize) { + // A logical index that overflows the run-end type sits past every run, so it lands in the + // final run; otherwise advance while the current run ends at or before it. + let Some(logical_idx) = I::from(logical_idx) else { + *run_idx = ends.len().saturating_sub(1); + return; + }; + while *run_idx + 1 < ends.len() && ends[*run_idx] <= logical_idx { + *run_idx += 1; + } +} + #[cfg(test)] mod tests { use rstest::rstest; @@ -96,11 +490,18 @@ mod tests { use vortex_array::IntoArray; use vortex_array::LEGACY_SESSION; use vortex_array::VortexSessionExecute; + use vortex_array::arrays::BoolArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; use vortex_array::compute::conformance::take::test_take_conformance; + use vortex_array::validity::Validity; use vortex_buffer::buffer; + use vortex_mask::Mask; + use super::physical_indices_binary; + use super::physical_indices_linear_sorted; + use super::physical_indices_linear_unsorted; + use super::physical_indices_table; use crate::RunEnd; use crate::RunEndArray; @@ -126,6 +527,15 @@ mod tests { assert_arrays_eq!(taken, expected); } + #[test] + fn ree_take_sorted_boundaries() { + let taken = ree_array() + .take(buffer![0, 2, 3, 6, 8, 11].into_array()) + .unwrap(); + let expected = PrimitiveArray::from_iter(vec![1i32, 1, 4, 2, 5, 5]).into_array(); + assert_arrays_eq!(taken, expected); + } + #[test] #[should_panic] fn ree_take_out_of_bounds() { @@ -145,6 +555,15 @@ mod tests { assert_arrays_eq!(taken, expected); } + #[test] + fn sliced_take_unsorted_dense() { + let sliced = ree_array().slice(4..9).unwrap(); + let taken = sliced.take(buffer![4, 0, 2, 1].into_array()).unwrap(); + + let expected = PrimitiveArray::from_iter(vec![5i32, 4, 2, 4]).into_array(); + assert_arrays_eq!(taken, expected); + } + #[test] fn ree_take_nullable() { let taken = ree_array() @@ -155,6 +574,115 @@ mod tests { assert_arrays_eq!(taken, expected.into_array()); } + #[test] + fn ree_take_all_null_indices() { + let taken = ree_array() + .take(PrimitiveArray::from_option_iter([None::, None]).into_array()) + .unwrap(); + + let expected = PrimitiveArray::from_option_iter([None::, None]); + assert_arrays_eq!(taken, expected.into_array()); + } + + #[test] + fn ree_take_null_index_skips_out_of_bounds_value() { + let indices = PrimitiveArray::new( + buffer![1u64, 12], + Validity::Array(BoolArray::from_iter([true, false]).into_array()), + ); + let taken = ree_array().take(indices.into_array()).unwrap(); + + let expected = PrimitiveArray::from_option_iter([Some(1i32), None]); + assert_arrays_eq!(taken, expected.into_array()); + } + + #[test] + fn ree_take_unsorted_null_index_skips_out_of_bounds_value() { + let indices = PrimitiveArray::new( + buffer![3u64, 12, 1], + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + ); + let taken = ree_array().take(indices.into_array()).unwrap(); + + let expected = PrimitiveArray::from_option_iter([Some(4i32), None, Some(1)]); + assert_arrays_eq!(taken, expected.into_array()); + } + + #[test] + fn ree_take_dense_null_index_skips_out_of_bounds_value() { + let indices = PrimitiveArray::new( + buffer![0u64, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12], + Validity::Array( + BoolArray::from_iter([ + true, true, true, true, true, true, true, true, true, true, true, false, + ]) + .into_array(), + ), + ); + let taken = ree_array().take(indices.into_array()).unwrap(); + + let expected = PrimitiveArray::from_option_iter([ + Some(1i32), + Some(1), + Some(1), + Some(4), + Some(4), + Some(4), + Some(2), + Some(2), + Some(5), + Some(5), + Some(5), + None, + ]); + assert_arrays_eq!(taken, expected.into_array()); + } + + #[rstest] + #[case(vec![3u32, 6, 8, 12], 0, 12, vec![0u64, 11, 3, 3, 7, 2, 9], Mask::new_true(7))] + #[case(vec![3u32, 6, 8, 12], 0, 12, vec![5u64, 100, 2, 11, 0], Mask::from_indices(5, [0, 2, 3, 4]))] + #[case(vec![6u32, 8, 12], 4, 5, vec![4u64, 0, 2, 1, 3], Mask::new_true(5))] + fn unsorted_strategies_agree( + #[case] ends: Vec, + #[case] offset: usize, + #[case] len: usize, + #[case] indices: Vec, + #[case] mask: Mask, + ) { + let binary = physical_indices_binary::(&ends, offset, &indices, &mask); + let table = physical_indices_table::(&ends, offset, len, &indices, &mask); + let sort_merge = physical_indices_linear_unsorted::( + &ends, + offset, + &indices, + &mask, + mask.true_count(), + ); + + assert_eq!(binary.as_slice(), table.as_slice()); + assert_eq!(binary.as_slice(), sort_merge.as_slice()); + } + + #[rstest] + #[case(vec![3u32, 6, 8, 12], 0, 12, vec![0u64, 2, 3, 6, 8, 11], Mask::new_true(6))] + #[case(vec![3u32, 6, 8, 12], 0, 12, vec![1u64, 100, 5, 9], Mask::from_indices(4, [0, 2, 3]))] + #[case(vec![6u32, 8, 12], 4, 5, vec![0u64, 1, 3, 4], Mask::new_true(4))] + fn sorted_strategies_agree( + #[case] ends: Vec, + #[case] offset: usize, + #[case] len: usize, + #[case] indices: Vec, + #[case] mask: Mask, + ) { + let binary = physical_indices_binary::(&ends, offset, &indices, &mask); + let table = physical_indices_table::(&ends, offset, len, &indices, &mask); + let sorted = + physical_indices_linear_sorted::(&ends, offset, &indices, &mask); + + assert_eq!(binary.as_slice(), table.as_slice()); + assert_eq!(binary.as_slice(), sorted.as_slice()); + } + #[rstest] #[case(ree_array())] #[case(RunEnd::encode(