Skip to content

Commit

Permalink
chore: prefer downcast_array_ref (#1786)
Browse files Browse the repository at this point in the history
  • Loading branch information
danking authored Jan 3, 2025
1 parent 60af509 commit 4eda890
Show file tree
Hide file tree
Showing 19 changed files with 64 additions and 190 deletions.
4 changes: 2 additions & 2 deletions vortex-array/src/compute/binary_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use vortex_scalar::{BinaryNumericOperator, Scalar};

use crate::array::ConstantArray;
use crate::arrow::{Datum, FromArrowArray};
use crate::encoding::{downcast_array_ref, Encoding};
use crate::encoding::Encoding;
use crate::{ArrayDType, ArrayData, IntoArrayData as _};

pub trait BinaryNumericFn<Array> {
Expand All @@ -30,7 +30,7 @@ where
rhs: &ArrayData,
op: BinaryNumericOperator,
) -> VortexResult<Option<ArrayData>> {
let (array_ref, encoding) = downcast_array_ref::<E>(lhs)?;
let (array_ref, encoding) = lhs.downcast_array_ref::<E>()?;
BinaryNumericFn::binary_numeric(encoding, array_ref, rhs, op)
}
}
Expand Down
9 changes: 2 additions & 7 deletions vortex-array/src/compute/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use arrow_array::cast::AsArray;
use arrow_array::ArrayRef;
use vortex_dtype::DType;
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
use vortex_error::{vortex_bail, VortexError, VortexResult};

use crate::arrow::FromArrowArray;
use crate::encoding::Encoding;
Expand Down Expand Up @@ -40,12 +40,7 @@ where
rhs: &ArrayData,
op: BinaryOperator,
) -> VortexResult<Option<ArrayData>> {
let array_ref = <&E::Array>::try_from(lhs)?;
let encoding = lhs
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = lhs.downcast_array_ref::<E>()?;
BinaryBooleanFn::binary_boolean(encoding, array_ref, rhs, op)
}
}
Expand Down
9 changes: 2 additions & 7 deletions vortex-array/src/compute/cast.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use vortex_dtype::DType;
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
use vortex_error::{vortex_bail, VortexError, VortexResult};

use crate::encoding::Encoding;
use crate::{ArrayDType, ArrayData, IntoArrayData, IntoCanonical};
Expand All @@ -14,12 +14,7 @@ where
for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
fn cast(&self, array: &ArrayData, dtype: &DType) -> VortexResult<ArrayData> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = array.downcast_array_ref::<E>()?;
CastFn::cast(encoding, array_ref, dtype)
}
}
Expand Down
9 changes: 2 additions & 7 deletions vortex-array/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::fmt::{Display, Formatter};

use arrow_ord::cmp;
use vortex_dtype::{DType, Nullability};
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
use vortex_error::{vortex_bail, VortexError, VortexResult};
use vortex_scalar::Scalar;

use crate::arrow::{Datum, FromArrowArray};
Expand Down Expand Up @@ -92,12 +92,7 @@ where
rhs: &ArrayData,
operator: Operator,
) -> VortexResult<Option<ArrayData>> {
let lhs_ref = <&E::Array>::try_from(lhs)?;
let encoding = lhs
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (lhs_ref, encoding) = lhs.downcast_array_ref::<E>()?;
CompareFn::compare(encoding, lhs_ref, rhs, operator)
}
}
Expand Down
7 changes: 1 addition & 6 deletions vortex-array/src/compute/fill_forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@ where
for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
fn fill_forward(&self, array: &ArrayData) -> VortexResult<ArrayData> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = array.downcast_array_ref::<E>()?;
FillForwardFn::fill_forward(encoding, array_ref)
}
}
Expand Down
9 changes: 2 additions & 7 deletions vortex-array/src/compute/fill_null.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
use vortex_error::{vortex_bail, VortexError, VortexResult};
use vortex_scalar::Scalar;

use crate::encoding::Encoding;
Expand All @@ -17,12 +17,7 @@ where
for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
fn fill_null(&self, array: &ArrayData, fill_value: Scalar) -> VortexResult<ArrayData> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = array.downcast_array_ref::<E>()?;
FillNullFn::fill_null(encoding, array_ref, fill_value)
}
}
Expand Down
7 changes: 1 addition & 6 deletions vortex-array/src/compute/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,7 @@ where
for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
fn filter(&self, array: &ArrayData, mask: FilterMask) -> VortexResult<ArrayData> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = array.downcast_array_ref::<E>()?;
FilterFn::filter(encoding, array_ref, mask)
}
}
Expand Down
9 changes: 2 additions & 7 deletions vortex-array/src/compute/invert.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use vortex_dtype::DType;
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
use vortex_error::{vortex_bail, VortexError, VortexResult};

use crate::encoding::Encoding;
use crate::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant};
Expand All @@ -15,12 +15,7 @@ where
for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
fn invert(&self, array: &ArrayData) -> VortexResult<ArrayData> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = array.downcast_array_ref::<E>()?;
InvertFn::invert(encoding, array_ref)
}
}
Expand Down
9 changes: 2 additions & 7 deletions vortex-array/src/compute/like.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use vortex_dtype::DType;
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
use vortex_error::{vortex_bail, VortexError, VortexResult};

use crate::arrow::{Datum, FromArrowArray};
use crate::encoding::Encoding;
Expand All @@ -25,12 +25,7 @@ where
pattern: &ArrayData,
options: LikeOptions,
) -> VortexResult<ArrayData> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = array.downcast_array_ref::<E>()?;
LikeFn::like(encoding, array_ref, pattern, options)
}
}
Expand Down
7 changes: 1 addition & 6 deletions vortex-array/src/compute/scalar_at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@ where
for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
fn scalar_at(&self, array: &ArrayData, index: usize) -> VortexResult<Scalar> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = array.downcast_array_ref::<E>()?;
ScalarAtFn::scalar_at(encoding, array_ref, index)
}
}
Expand Down
30 changes: 5 additions & 25 deletions vortex-array/src/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::fmt::{Debug, Display, Formatter};
use std::hint;

use itertools::Itertools;
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
use vortex_error::{vortex_bail, VortexError, VortexResult};
use vortex_scalar::Scalar;

use crate::compute::scalar_at;
Expand Down Expand Up @@ -164,12 +164,7 @@ where
value: &Scalar,
side: SearchSortedSide,
) -> VortexResult<SearchResult> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = array.downcast_array_ref::<E>()?;
SearchSortedFn::search_sorted(encoding, array_ref, value, side)
}

Expand All @@ -179,12 +174,7 @@ where
values: &[Scalar],
side: SearchSortedSide,
) -> VortexResult<Vec<SearchResult>> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = array.downcast_array_ref::<E>()?;
SearchSortedFn::search_sorted_many(encoding, array_ref, values, side)
}
}
Expand All @@ -200,12 +190,7 @@ where
value: usize,
side: SearchSortedSide,
) -> VortexResult<SearchResult> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = array.downcast_array_ref::<E>()?;
SearchSortedUsizeFn::search_sorted_usize(encoding, array_ref, value, side)
}

Expand All @@ -215,12 +200,7 @@ where
values: &[usize],
side: SearchSortedSide,
) -> VortexResult<Vec<SearchResult>> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = array.downcast_array_ref::<E>()?;
SearchSortedUsizeFn::search_sorted_usize_many(encoding, array_ref, values, side)
}
}
Expand Down
7 changes: 1 addition & 6 deletions vortex-array/src/compute/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@ where
for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
fn slice(&self, array: &ArrayData, start: usize, stop: usize) -> VortexResult<ArrayData> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = array.downcast_array_ref::<E>()?;
SliceFn::slice(encoding, array_ref, start, stop)
}
}
Expand Down
7 changes: 1 addition & 6 deletions vortex-array/src/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,7 @@ where
for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
fn take(&self, array: &ArrayData, indices: &ArrayData) -> VortexResult<ArrayData> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = array.downcast_array_ref::<E>()?;
TakeFn::take(encoding, array_ref, indices)
}
}
Expand Down
17 changes: 15 additions & 2 deletions vortex-array/src/data/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ use owned::OwnedArrayData;
use viewed::ViewedArrayData;
use vortex_buffer::ByteBuffer;
use vortex_dtype::DType;
use vortex_error::{vortex_err, VortexExpect, VortexResult};
use vortex_error::{vortex_err, VortexError, VortexExpect, VortexResult};
use vortex_scalar::Scalar;

use crate::array::{
BoolEncoding, ExtensionEncoding, NullEncoding, PrimitiveEncoding, StructEncoding,
VarBinEncoding, VarBinViewEncoding,
};
use crate::compute::scalar_at;
use crate::encoding::{EncodingId, EncodingRef, EncodingVTable};
use crate::encoding::{Encoding, EncodingId, EncodingRef, EncodingVTable};
use crate::iter::{ArrayIterator, ArrayIteratorAdapter};
use crate::stats::{ArrayStatistics, Stat, Statistics, StatsSet};
use crate::stream::{ArrayStream, ArrayStreamAdapter};
Expand Down Expand Up @@ -371,6 +371,19 @@ impl ArrayData {
log::warn!("{:?}", bt);
}
}

pub fn downcast_array_ref<E: Encoding>(self: &ArrayData) -> VortexResult<(&E::Array, &E)>
where
for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
let array_ref = <&E::Array>::try_from(self)?;
let encoding = self
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
Ok((array_ref, encoding))
}
}

impl Display for ArrayData {
Expand Down
15 changes: 0 additions & 15 deletions vortex-array/src/encoding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ use std::any::Any;
use std::fmt::{Debug, Display, Formatter};
use std::hash::{Hash, Hasher};

use vortex_error::{vortex_err, VortexError, VortexResult};

use crate::compute::ComputeVTable;
use crate::stats::StatisticsVTable;
use crate::validity::ValidityVTable;
Expand Down Expand Up @@ -69,19 +67,6 @@ pub trait Encoding: 'static {
type Metadata: ArrayMetadata;
}

pub fn downcast_array_ref<E: Encoding>(array: &ArrayData) -> VortexResult<(&E::Array, &E)>
where
for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
Ok((array_ref, encoding))
}

pub type EncodingRef = &'static dyn EncodingVTable;

/// Object-safe encoding trait for an array.
Expand Down
9 changes: 2 additions & 7 deletions vortex-array/src/stats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use num_enum::{IntoPrimitive, TryFromPrimitive};
pub use statsset::*;
use vortex_dtype::Nullability::NonNullable;
use vortex_dtype::{DType, NativePType, PType};
use vortex_error::{vortex_err, vortex_panic, VortexError, VortexExpect, VortexResult};
use vortex_error::{vortex_panic, VortexError, VortexExpect, VortexResult};
use vortex_scalar::Scalar;

use crate::encoding::Encoding;
Expand Down Expand Up @@ -205,12 +205,7 @@ where
for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
fn compute_statistics(&self, array: &ArrayData, stat: Stat) -> VortexResult<StatsSet> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
let (array_ref, encoding) = array.downcast_array_ref::<E>()?;
StatisticsVTable::compute_statistics(encoding, array_ref, stat)
}
}
Expand Down
Loading

0 comments on commit 4eda890

Please sign in to comment.