diff --git a/encodings/fastlanes/src/for/vtable/mod.rs b/encodings/fastlanes/src/for/vtable/mod.rs index 37373456746..cef31f8250c 100644 --- a/encodings/fastlanes/src/for/vtable/mod.rs +++ b/encodings/fastlanes/src/for/vtable/mod.rs @@ -132,9 +132,9 @@ impl VTable for FoRVTable { dtype: &DType, _len: usize, _buffers: &[BufferHandle], - _session: &VortexSession, + session: &VortexSession, ) -> VortexResult { - let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?; + let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype, session)?; Scalar::try_new(dtype.clone(), scalar_value) } diff --git a/encodings/sequence/src/array.rs b/encodings/sequence/src/array.rs index 09c3fe27183..8f2258f025e 100644 --- a/encodings/sequence/src/array.rs +++ b/encodings/sequence/src/array.rs @@ -322,7 +322,7 @@ impl VTable for SequenceVTable { dtype: &DType, _len: usize, _buffers: &[BufferHandle], - _session: &VortexSession, + session: &VortexSession, ) -> VortexResult { let prost = as DeserializeMetadata>::deserialize(bytes)?; @@ -336,6 +336,7 @@ impl VTable for SequenceVTable { .as_ref() .ok_or_else(|| vortex_err!("base required"))?, &DType::Primitive(ptype, NonNullable), + session, )? .as_primitive() .pvalue() @@ -347,6 +348,7 @@ impl VTable for SequenceVTable { .as_ref() .ok_or_else(|| vortex_err!("multiplier required"))?, &DType::Primitive(ptype, NonNullable), + session, )? .as_primitive() .pvalue() diff --git a/encodings/sparse/src/lib.rs b/encodings/sparse/src/lib.rs index 29c8d5c190e..292c5054682 100644 --- a/encodings/sparse/src/lib.rs +++ b/encodings/sparse/src/lib.rs @@ -164,7 +164,7 @@ impl VTable for SparseVTable { dtype: &DType, _len: usize, buffers: &[BufferHandle], - _session: &VortexSession, + session: &VortexSession, ) -> VortexResult { let prost_patches = as DeserializeMetadata>::deserialize(bytes)?; @@ -176,7 +176,7 @@ impl VTable for SparseVTable { } let scalar_bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?; - let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype)?; + let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype, session)?; let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?; Ok(SparseMetadata { diff --git a/vortex-array/src/arrays/constant/vtable/mod.rs b/vortex-array/src/arrays/constant/vtable/mod.rs index ec7405a3700..879eb7005bc 100644 --- a/vortex-array/src/arrays/constant/vtable/mod.rs +++ b/vortex-array/src/arrays/constant/vtable/mod.rs @@ -122,7 +122,7 @@ impl VTable for ConstantVTable { dtype: &DType, _len: usize, buffers: &[BufferHandle], - _session: &VortexSession, + session: &VortexSession, ) -> VortexResult { vortex_ensure!( buffers.len() == 1, @@ -133,7 +133,7 @@ impl VTable for ConstantVTable { let buffer = buffers[0].clone().try_to_host_sync()?; let bytes: &[u8] = buffer.as_ref(); - let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?; + let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype, session)?; let scalar = Scalar::try_new(dtype.clone(), scalar_value)?; Ok(scalar) diff --git a/vortex-array/src/arrays/fixed_size_list/vtable/operations.rs b/vortex-array/src/arrays/fixed_size_list/vtable/operations.rs index 3c57bfaac11..baf29d2b592 100644 --- a/vortex-array/src/arrays/fixed_size_list/vtable/operations.rs +++ b/vortex-array/src/arrays/fixed_size_list/vtable/operations.rs @@ -6,22 +6,14 @@ use vortex_error::VortexResult; use crate::arrays::FixedSizeListArray; use crate::arrays::FixedSizeListVTable; use crate::scalar::Scalar; +use crate::scalar::ScalarValue; use crate::vtable::OperationsVTable; impl OperationsVTable for FixedSizeListVTable { fn scalar_at(array: &FixedSizeListArray, index: usize) -> VortexResult { // By the preconditions we know that the list scalar is not null. let list = array.fixed_size_list_elements_at(index)?; - let children_elements: Vec = (0..list.len()) - .map(|i| list.scalar_at(i)) - .collect::>()?; - - debug_assert_eq!(children_elements.len(), array.list_size() as usize); - - Ok(Scalar::fixed_size_list( - list.dtype().clone(), - children_elements, - array.dtype().nullability(), - )) + let scalar_value = ScalarValue::Array(list); + Scalar::try_new(array.dtype().clone(), Some(scalar_value)) } } diff --git a/vortex-array/src/arrays/list/vtable/operations.rs b/vortex-array/src/arrays/list/vtable/operations.rs index c326786a510..3459557d97a 100644 --- a/vortex-array/src/arrays/list/vtable/operations.rs +++ b/vortex-array/src/arrays/list/vtable/operations.rs @@ -1,27 +1,19 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::sync::Arc; - use vortex_error::VortexResult; use crate::arrays::ListArray; use crate::arrays::ListVTable; use crate::scalar::Scalar; +use crate::scalar::ScalarValue; use crate::vtable::OperationsVTable; impl OperationsVTable for ListVTable { fn scalar_at(array: &ListArray, index: usize) -> VortexResult { // By the preconditions we know that the list scalar is not null. - let elems = array.list_elements_at(index)?; - let scalars: Vec = (0..elems.len()) - .map(|i| elems.scalar_at(i)) - .collect::>()?; - - Ok(Scalar::list( - Arc::new(elems.dtype().clone()), - scalars, - array.dtype().nullability(), - )) + let list = array.list_elements_at(index)?; + let scalar_value = ScalarValue::Array(list); + Scalar::try_new(array.dtype().clone(), Some(scalar_value)) } } diff --git a/vortex-array/src/arrays/listview/vtable/operations.rs b/vortex-array/src/arrays/listview/vtable/operations.rs index 60bfc873548..f953b2769da 100644 --- a/vortex-array/src/arrays/listview/vtable/operations.rs +++ b/vortex-array/src/arrays/listview/vtable/operations.rs @@ -1,27 +1,19 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::sync::Arc; - use vortex_error::VortexResult; use crate::arrays::ListViewArray; use crate::arrays::ListViewVTable; use crate::scalar::Scalar; +use crate::scalar::ScalarValue; use crate::vtable::OperationsVTable; impl OperationsVTable for ListViewVTable { fn scalar_at(array: &ListViewArray, index: usize) -> VortexResult { // By the preconditions we know that the list scalar is not null. let list = array.list_elements_at(index)?; - let children: Vec = (0..list.len()) - .map(|i| list.scalar_at(i)) - .collect::>()?; - - Ok(Scalar::list( - Arc::new(list.dtype().clone()), - children, - array.dtype.nullability(), - )) + let scalar_value = ScalarValue::Array(list); + Scalar::try_new(array.dtype().clone(), Some(scalar_value)) } } diff --git a/vortex-array/src/scalar/proto.rs b/vortex-array/src/scalar/proto.rs index 5d0166c933c..8b9fe912c0d 100644 --- a/vortex-array/src/scalar/proto.rs +++ b/vortex-array/src/scalar/proto.rs @@ -8,17 +8,22 @@ use num_traits::ToPrimitive; use prost::Message; use vortex_buffer::BufferString; use vortex_buffer::ByteBuffer; +use vortex_buffer::ByteBufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; use vortex_error::vortex_err; use vortex_proto::scalar as pb; use vortex_proto::scalar::ListValue; use vortex_proto::scalar::scalar_value::Kind; use vortex_session::VortexSession; +use crate::ArrayContext; +use crate::ArrayRef; use crate::dtype::DType; +use crate::dtype::DecimalDType; use crate::dtype::PType; use crate::dtype::half::f16; use crate::dtype::i256; @@ -26,6 +31,8 @@ use crate::scalar::DecimalValue; use crate::scalar::PValue; use crate::scalar::Scalar; use crate::scalar::ScalarValue; +use crate::serde::ArrayParts; +use crate::serde::SerializeOptions; //////////////////////////////////////////////////////////////////////////////////////////////////// // Serialize INTO proto. @@ -110,6 +117,31 @@ impl From<&ScalarValue> for pb::ScalarValue { kind: Some(Kind::ListValue(ListValue { values })), } } + ScalarValue::Array(array) => { + let ctx = ArrayContext::empty(); + + let serialized = array + .to_array() + .serialize(&ctx, &SerializeOptions::default()) + .vortex_expect("somehow unable to serialize value as array"); + + let mut concat = ByteBufferMut::empty(); + + // We reserve the first 8 bytes for the length of the array. + let array_len = array.to_array().len() as u64; + concat.extend_from_slice(&array_len.to_le_bytes()); + debug_assert_eq!(concat.len(), 8); + + // Then we serialize the rest of the array buffers. + for buf in serialized { + concat.extend_from_slice(buf.as_ref()); + } + let concat = concat.freeze(); + + pb::ScalarValue { + kind: Some(Kind::BytesValue(concat.to_vec())), + } + } } } } @@ -167,8 +199,12 @@ impl Scalar { /// # Errors /// /// Returns an error if type validation fails. - pub fn from_proto_value(value: &pb::ScalarValue, dtype: &DType) -> VortexResult { - let scalar_value = ScalarValue::from_proto(value, dtype)?; + pub fn from_proto_value( + value: &pb::ScalarValue, + dtype: &DType, + session: &VortexSession, + ) -> VortexResult { + let scalar_value = ScalarValue::from_proto(value, dtype, session)?; Scalar::try_new(dtype.clone(), scalar_value) } @@ -192,7 +228,7 @@ impl Scalar { .as_ref() .ok_or_else(|| vortex_err!(Serde: "Scalar missing value"))?; - let value: Option = ScalarValue::from_proto(pb_scalar_value, &dtype)?; + let value: Option = ScalarValue::from_proto(pb_scalar_value, &dtype, session)?; Scalar::try_new(dtype, value) } @@ -207,9 +243,13 @@ impl ScalarValue { /// # Errors /// /// Returns an error if decoding or type validation fails. - pub fn from_proto_bytes(bytes: &[u8], dtype: &DType) -> VortexResult> { + pub fn from_proto_bytes( + bytes: &[u8], + dtype: &DType, + session: &VortexSession, + ) -> VortexResult> { let proto = pb::ScalarValue::decode(bytes)?; - Self::from_proto(&proto, dtype) + Self::from_proto(&proto, dtype, session) } /// Creates a [`ScalarValue`] from its [protobuf](pb::ScalarValue) representation. @@ -220,7 +260,11 @@ impl ScalarValue { /// # Errors /// /// Returns an error if the protobuf value cannot be converted to the given [`DType`]. - pub fn from_proto(value: &pb::ScalarValue, dtype: &DType) -> VortexResult> { + pub fn from_proto( + value: &pb::ScalarValue, + dtype: &DType, + session: &VortexSession, + ) -> VortexResult> { let kind = value .kind .as_ref() @@ -241,8 +285,8 @@ impl ScalarValue { Kind::F32Value(v) => f32_from_proto(*v, dtype)?, Kind::F64Value(v) => f64_from_proto(*v, dtype)?, Kind::StringValue(s) => string_from_proto(s, dtype)?, - Kind::BytesValue(b) => bytes_from_proto(b, dtype)?, - Kind::ListValue(v) => list_from_proto(v, dtype)?, + Kind::BytesValue(b) => bytes_from_proto(b, dtype, session)?, + Kind::ListValue(v) => list_from_proto(v, dtype, session)?, })) } } @@ -365,64 +409,139 @@ fn string_from_proto(s: &str, dtype: &DType) -> VortexResult { } } -/// Deserialize a [`ScalarValue`] from a protobuf bytes and a `DType`. +/// Deserialize a [`ScalarValue`] from a protobuf bytes and a [`DType`]. +/// +/// Handles all variable-size scalars, including: /// -/// Handles [`Utf8`](ScalarValue::Utf8), [`Binary`](ScalarValue::Binary), and -/// [`Decimal`](ScalarValue::Decimal) dtypes. -fn bytes_from_proto(bytes: &[u8], dtype: &DType) -> VortexResult { +/// - `Utf8` -> `ScalarValue::Utf8` +/// - `Binary` -> `ScalarValue::Binary` +/// - `Decimal` -> `ScalarValue::Decimal` (Since decimal has different width representations) +fn bytes_from_proto( + bytes: &[u8], + dtype: &DType, + session: &VortexSession, +) -> VortexResult { match dtype { DType::Utf8(_) => Ok(ScalarValue::Utf8(BufferString::try_from(bytes)?)), DType::Binary(_) => Ok(ScalarValue::Binary(ByteBuffer::copy_from(bytes))), - // TODO(connor): This is incorrect, we need to verify this matches the inner decimal_dtype. - DType::Decimal(..) => Ok(ScalarValue::Decimal(match bytes.len() { - 1 => DecimalValue::I8(bytes[0] as i8), - 2 => DecimalValue::I16(i16::from_le_bytes( - bytes - .try_into() - .ok() - .vortex_expect("Buffer has invalid number of bytes"), - )), - 4 => DecimalValue::I32(i32::from_le_bytes( - bytes - .try_into() - .ok() - .vortex_expect("Buffer has invalid number of bytes"), - )), - 8 => DecimalValue::I64(i64::from_le_bytes( - bytes - .try_into() - .ok() - .vortex_expect("Buffer has invalid number of bytes"), - )), - 16 => DecimalValue::I128(i128::from_le_bytes( - bytes - .try_into() - .ok() - .vortex_expect("Buffer has invalid number of bytes"), - )), - 32 => DecimalValue::I256(i256::from_le_bytes( - bytes - .try_into() - .ok() - .vortex_expect("Buffer has invalid number of bytes"), - )), - l => vortex_bail!(Serde: "invalid decimal byte length: {l}"), - })), + DType::Decimal(decimal_dtype, _) => decimal_from_proto(bytes, decimal_dtype), + DType::List(elem_dtype, _) => { + let array = array_from_proto(bytes, elem_dtype, session)?; + Ok(ScalarValue::Array(array)) + } + DType::FixedSizeList(elem_dtype, list_size, _) => { + let array = array_from_proto(bytes, elem_dtype, session)?; + vortex_ensure_eq!(array.len(), *list_size as usize); + Ok(ScalarValue::Array(array)) + } _ => vortex_bail!( - Serde: "expected Utf8, Binary, or Decimal dtype for BytesValue, got {dtype}" + Serde: "expected Utf8, Binary, List, FSL, or Decimal dtype for BytesValue, got {dtype}" ), } } +/// Deserialize a [`ScalarValue::Decimal`] from a protobuf bytes. +fn decimal_from_proto(bytes: &[u8], _decimal_dtype: &DecimalDType) -> VortexResult { + let nbytes = bytes.len(); + // TODO(connor): Figure out if this makes any sense. + // let max_width = decimal_dtype.required_bit_width(); + // vortex_ensure!( + // max_width <= nbytes * 8, + // Serde: "invalid decimal byte length {nbytes} for decimal dtype {decimal_dtype} \ + // which requires a width of {max_width}" + // ); + + let value = match nbytes { + 1 => DecimalValue::I8(bytes[0] as i8), + 2 => { + DecimalValue::I16(i16::from_le_bytes(bytes.try_into().ok().vortex_expect( + "we just checked that there was the correct number of bytes", + ))) + } + 4 => { + DecimalValue::I32(i32::from_le_bytes(bytes.try_into().ok().vortex_expect( + "we just checked that there was the correct number of bytes", + ))) + } + 8 => { + DecimalValue::I64(i64::from_le_bytes(bytes.try_into().ok().vortex_expect( + "we just checked that there was the correct number of bytes", + ))) + } + 16 => { + DecimalValue::I128(i128::from_le_bytes(bytes.try_into().ok().vortex_expect( + "we just checked that there was the correct number of bytes", + ))) + } + 32 => { + DecimalValue::I256(i256::from_le_bytes(bytes.try_into().ok().vortex_expect( + "we just checked that there was the correct number of bytes", + ))) + } + l => vortex_bail!(Serde: "invalid decimal byte length: {l}"), + }; + + Ok(ScalarValue::Decimal(value)) +} + +// TODO(connor): Maybe this function should live somewhere else? +/// Deserialize a [`ScalarValue::Array`] from a protobuf bytes representation. +/// +/// The byte layout is: +/// - First 8 bytes: the array length as a little-endian `u64`. +/// - Remaining bytes: the serialized [`ArrayParts`]. +/// +/// # Errors +/// +/// Returns an error if the byte slice is too short, the array length exceeds `usize::MAX`, or +/// deserialization of the [`ArrayParts`] fails. +fn array_from_proto( + bytes: &[u8], + elem_dtype: &DType, + session: &VortexSession, +) -> VortexResult { + let nbytes = bytes.len(); + vortex_ensure!( + nbytes >= 8, + Serde: "expected at least 8 bytes for array length prefix, got {nbytes}", + ); + + // Retrieve the array length first, which should be in the first 8 bytes. + let len = u64::from_le_bytes( + bytes[..8] + .try_into() + .ok() + .vortex_expect("we just checked that there are at least 8 bytes"), + ); + let len = usize::try_from(len) + .map_err(|_| vortex_err!(Serde: "array length {len} exceeds usize::MAX"))?; + + let parts = ArrayParts::try_from(&bytes[8..])?; + + // Deserialize the entire array. + let ctx = ArrayContext::empty(); + let decoded = parts.decode(elem_dtype, len, &ctx, session)?; + + Ok(decoded) +} + /// Deserialize a [`ScalarValue::List`] from a protobuf `ListValue`. -fn list_from_proto(v: &ListValue, dtype: &DType) -> VortexResult { +fn list_from_proto( + v: &ListValue, + dtype: &DType, + session: &VortexSession, +) -> VortexResult { let element_dtype = dtype .as_list_element_opt() .ok_or_else(|| vortex_err!(Serde: "expected List dtype for ListValue, got {dtype}"))?; let mut values = Vec::with_capacity(v.values.len()); for elem in v.values.iter() { - values.push(ScalarValue::from_proto(elem, element_dtype.as_ref())?); + values.push(ScalarValue::from_proto( + elem, + element_dtype.as_ref(), + session, + )?); } Ok(ScalarValue::List(values)) @@ -604,6 +723,7 @@ mod tests { let scalar_value = ScalarValue::from_proto( &pb_scalar_value, &DType::Primitive(PType::U64, Nullability::NonNullable), + &session(), ) .unwrap(); assert_eq!( @@ -615,6 +735,7 @@ mod tests { let scalar_value_f16 = ScalarValue::from_proto( &pb_scalar_value, &DType::Primitive(PType::F16, Nullability::Nullable), + &session(), ) .unwrap(); @@ -651,6 +772,7 @@ mod tests { let read_back = ScalarValue::from_proto( &pb_value, &DType::Primitive(PType::F16, Nullability::NonNullable), + &session(), ) .unwrap(); @@ -729,7 +851,7 @@ mod tests { for (name, value, dtype) in exact_roundtrip_cases { let pb_value = ScalarValue::to_proto(value.as_ref()); - let read_back = ScalarValue::from_proto(&pb_value, &dtype).unwrap(); + let read_back = ScalarValue::from_proto(&pb_value, &dtype, &session()).unwrap(); let original_debug = format!("{value:?}"); let roundtrip_debug = format!("{read_back:?}"); @@ -764,7 +886,7 @@ mod tests { for (name, value, dtype, expected) in unsigned_cases { let pb_value = ScalarValue::to_proto(Some(&value)); - let read_back = ScalarValue::from_proto(&pb_value, &dtype).unwrap(); + let read_back = ScalarValue::from_proto(&pb_value, &dtype, &session()).unwrap(); match read_back.as_ref() { Some(ScalarValue::Primitive(pv)) => { @@ -808,7 +930,7 @@ mod tests { for (name, value, dtype, expected) in signed_cases { let pb_value = ScalarValue::to_proto(Some(&value)); - let read_back = ScalarValue::from_proto(&pb_value, &dtype).unwrap(); + let read_back = ScalarValue::from_proto(&pb_value, &dtype, &session()).unwrap(); match read_back.as_ref() { Some(ScalarValue::Primitive(pv)) => { @@ -837,7 +959,8 @@ mod tests { assert_eq!( Scalar::from_proto_value( &pb::ScalarValue::from(&v), - &DType::Primitive(PType::U64, Nullability::Nullable) + &DType::Primitive(PType::U64, Nullability::Nullable), + &session() ) .unwrap(), Scalar::primitive(0u64, Nullability::Nullable) @@ -852,7 +975,8 @@ mod tests { assert_eq!( Scalar::from_proto_value( &pb::ScalarValue::from(&v), - &DType::Primitive(PType::I64, Nullability::Nullable) + &DType::Primitive(PType::I64, Nullability::Nullable), + &session() ) .unwrap(), Scalar::primitive(0i64, Nullability::Nullable) diff --git a/vortex-array/src/scalar/scalar_value.rs b/vortex-array/src/scalar/scalar_value.rs index 71bf4fcfc77..c3e0cbb73a8 100644 --- a/vortex-array/src/scalar/scalar_value.rs +++ b/vortex-array/src/scalar/scalar_value.rs @@ -4,36 +4,63 @@ //! Core [`ScalarValue`] type definition. use std::cmp::Ordering; -use std::fmt::Display; -use std::fmt::Formatter; +use std::fmt; +use std::hash::Hash; +use std::hash::Hasher; use itertools::Itertools; use vortex_buffer::BufferString; use vortex_buffer::ByteBuffer; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_error::vortex_panic; +use crate::ArrayRef; use crate::dtype::DType; use crate::scalar::DecimalValue; use crate::scalar::PValue; -/// The value stored in a [`Scalar`][crate::scalar::Scalar]. +/// The value stored in a [`Scalar`](crate::scalar::Scalar). /// /// This enum represents the possible non-null values that can be stored in a scalar. When the /// scalar is null, the value is represented as `None` in the `Option` field. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// +/// We can think of this type loosely (heavy emphasis on "loosely") as the serialization / physical +/// type of scalars. +#[derive(Debug, Clone)] pub enum ScalarValue { /// A boolean value. Bool(bool), + /// A primitive numeric value. Primitive(PValue), + /// A decimal value. Decimal(DecimalValue), + /// A UTF-8 encoded string value. Utf8(BufferString), + /// A binary (byte array) value. Binary(ByteBuffer), + /// A list of potentially null scalar values. + /// + /// Previously, we used this to represent all of struct, list, and fixed-size list scalars, but + /// with the recent addition of the `ScalarValue::Array` variant below, this only stores struct + /// scalars (and any other list scalars written to files in the past for backcompat). List(Vec>), + + /// An Array. + /// + /// We serialize this by using [`Array::serialize()`] into protobuf bytes. Note that because we + /// require passing the length of the array to deserialize with [`ArrayParts::decode()`], we + /// store the array length in the first 8 bytes. + /// + /// [`Array::serialize()`]: crate::Array::serialize + /// [`ArrayParts::decode()`]: crate::serde::ArrayParts::decode + Array(ArrayRef), } impl ScalarValue { @@ -102,25 +129,25 @@ impl ScalarValue { } }) } -} -impl PartialOrd for ScalarValue { - fn partial_cmp(&self, other: &Self) -> Option { - match (self, other) { - (ScalarValue::Bool(a), ScalarValue::Bool(b)) => a.partial_cmp(b), - (ScalarValue::Primitive(a), ScalarValue::Primitive(b)) => a.partial_cmp(b), - (ScalarValue::Decimal(a), ScalarValue::Decimal(b)) => a.partial_cmp(b), - (ScalarValue::Utf8(a), ScalarValue::Utf8(b)) => a.partial_cmp(b), - (ScalarValue::Binary(a), ScalarValue::Binary(b)) => a.partial_cmp(b), - (ScalarValue::List(a), ScalarValue::List(b)) => a.partial_cmp(b), - // (ScalarValue::Extension(a), ScalarValue::Extension(b)) => a.partial_cmp(b), - _ => None, - } + /// Turns a scalar value `Array` variant into a `List` variant. + /// + /// TODO + pub fn expand_array(&self) -> VortexResult { + let Self::Array(array) = self else { + vortex_bail!("tried to expand a scalar value that was not an array"); + }; + + let values = (0..array.len()) + .map(|i| array.scalar_at(i).map(|scalar| scalar.value().cloned())) + .collect::>>()?; + + Ok(Self::List(values)) } } -impl Display for ScalarValue { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for ScalarValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { ScalarValue::Bool(b) => write!(f, "{b}"), ScalarValue::Primitive(p) => write!(f, "{p}"), @@ -163,6 +190,83 @@ impl Display for ScalarValue { } write!(f, "]") } + ScalarValue::Array(_) => { + // We simply expand the value out into a list to display it. + let expanded = self.expand_array().vortex_expect( + "something went wrong when expanding scalar value for displaying", + ); + + expanded.fmt(f) + } + } + } +} + +impl PartialEq for ScalarValue { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (ScalarValue::Bool(a), ScalarValue::Bool(b)) => a == b, + (ScalarValue::Primitive(a), ScalarValue::Primitive(b)) => a == b, + (ScalarValue::Decimal(a), ScalarValue::Decimal(b)) => a == b, + (ScalarValue::Utf8(a), ScalarValue::Utf8(b)) => a == b, + (ScalarValue::Binary(a), ScalarValue::Binary(b)) => a == b, + (ScalarValue::List(a), ScalarValue::List(b)) => a == b, + (ScalarValue::Array(_), ScalarValue::Array(_)) => { + // We simply expand the value out into a list before doing any comparison. + let lhs = self.expand_array().vortex_expect( + "something went wrong when expanding lhs scalar value for comparison", + ); + let rhs = other.expand_array().vortex_expect( + "something went wrong when expanding rhs scalar value for comparison", + ); + + lhs == rhs + } + _ => false, + } + } +} + +impl Eq for ScalarValue {} + +impl PartialOrd for ScalarValue { + fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (ScalarValue::Bool(a), ScalarValue::Bool(b)) => a.partial_cmp(b), + (ScalarValue::Primitive(a), ScalarValue::Primitive(b)) => a.partial_cmp(b), + (ScalarValue::Decimal(a), ScalarValue::Decimal(b)) => a.partial_cmp(b), + (ScalarValue::Utf8(a), ScalarValue::Utf8(b)) => a.partial_cmp(b), + (ScalarValue::Binary(a), ScalarValue::Binary(b)) => a.partial_cmp(b), + (ScalarValue::List(a), ScalarValue::List(b)) => a.partial_cmp(b), + (ScalarValue::Array(_), ScalarValue::Array(_)) => { + // We simply expand the value out into a list before doing any comparison. + let lhs = self.expand_array().vortex_expect( + "something went wrong when expanding lhs scalar value for comparison", + ); + let rhs = other.expand_array().vortex_expect( + "something went wrong when expanding rhs scalar value for comparison", + ); + + lhs.partial_cmp(&rhs) + } + _ => None, + } + } +} + +impl Hash for ScalarValue { + fn hash(&self, state: &mut H) { + match self { + ScalarValue::Bool(b) => b.hash(state), + ScalarValue::Primitive(p) => p.hash(state), + ScalarValue::Decimal(d) => d.hash(state), + ScalarValue::Utf8(s) => s.hash(state), + ScalarValue::Binary(b) => b.hash(state), + ScalarValue::List(l) => l.hash(state), + ScalarValue::Array(_) => self + .expand_array() + .vortex_expect("something went wrong when expanding scalar value for hashing") + .hash(state), } } } diff --git a/vortex-array/src/scalar/validate.rs b/vortex-array/src/scalar/validate.rs index 3fed201ed93..22d3a4ae0dd 100644 --- a/vortex-array/src/scalar/validate.rs +++ b/vortex-array/src/scalar/validate.rs @@ -74,6 +74,13 @@ impl Scalar { ); } DType::List(elem_dtype, _) => { + if let ScalarValue::Array(array) = value { + // If the scalar value is an array, we just need to check that the array + // elements match the elem_dtype, so just check that the dtype is the same. + vortex_ensure_eq!(elem_dtype.as_ref(), array.dtype()); + return Ok(()); + } + let ScalarValue::List(elements) = value else { vortex_bail!("list dtype expected List value, got {value}"); }; @@ -84,6 +91,13 @@ impl Scalar { } } DType::FixedSizeList(elem_dtype, size, _) => { + if let ScalarValue::Array(array) = value { + // If the scalar value is an array, we just need to check that the array + // elements match the elem_dtype, so just check that the dtype is the same. + vortex_ensure_eq!(elem_dtype.as_ref(), array.dtype()); + return Ok(()); + } + let ScalarValue::List(elements) = value else { vortex_bail!("fixed-size list dtype expected List value, got {value}",); }; diff --git a/vortex-array/src/serde.rs b/vortex-array/src/serde.rs index d360a7bc35c..6caa46dd39c 100644 --- a/vortex-array/src/serde.rs +++ b/vortex-array/src/serde.rs @@ -377,7 +377,7 @@ impl ArrayParts { if let Some(stats) = self.flatbuffer().stats() { decoded .statistics() - .set_iter(StatsSet::from_flatbuffer(&stats, dtype)?.into_iter()); + .set_iter(StatsSet::from_flatbuffer(&stats, dtype, session)?.into_iter()); } Ok(decoded) @@ -639,6 +639,30 @@ impl ArrayChildren for ArrayPartsChildren<'_> { } } +// TODO(connor): There is probably performance that is left on the table here. + +impl TryFrom<&[u8]> for ArrayParts { + type Error = VortexError; + + fn try_from(value: &[u8]) -> Result { + // The final 4 bytes contain the length of the flatbuffer. + if value.len() < 4 { + vortex_bail!("ArrayParts buffer is too short"); + } + + let fb_length = u32::try_from_le_bytes(&value[value.len() - 4..])? as usize; + if value.len() < 4 + fb_length { + vortex_bail!("ArrayParts buffer is too short for flatbuffer"); + } + + let fb_offset = value.len() - 4 - fb_length; + let array_tree = ByteBuffer::copy_from(&value[fb_offset..fb_offset + fb_length]); + let segment = BufferHandle::new_host(ByteBuffer::copy_from(&value[0..fb_offset])); + + Self::from_flatbuffer_and_segment(array_tree, segment) + } +} + impl TryFrom for ArrayParts { type Error = VortexError; diff --git a/vortex-array/src/stats/flatbuffers.rs b/vortex-array/src/stats/flatbuffers.rs index fafa6ae32f1..2fbe7774a40 100644 --- a/vortex-array/src/stats/flatbuffers.rs +++ b/vortex-array/src/stats/flatbuffers.rs @@ -7,6 +7,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_flatbuffers::WriteFlatBuffer; use vortex_flatbuffers::array as fba; +use vortex_session::VortexSession; use crate::dtype::DType; use crate::dtype::Nullability; @@ -113,6 +114,7 @@ impl StatsSet { pub fn from_flatbuffer<'a>( fb: &fba::ArrayStats<'a>, array_dtype: &DType, + session: &VortexSession, ) -> VortexResult { let mut stats_set = StatsSet::default(); @@ -142,7 +144,8 @@ impl StatsSet { if let Some(max) = fb.max() && let Some(stat_dtype) = stat_dtype { - let value = ScalarValue::from_proto_bytes(max.bytes(), &stat_dtype)?; + let value = + ScalarValue::from_proto_bytes(max.bytes(), &stat_dtype, session)?; let Some(value) = value else { continue; }; @@ -161,7 +164,8 @@ impl StatsSet { if let Some(min) = fb.min() && let Some(stat_dtype) = stat_dtype { - let value = ScalarValue::from_proto_bytes(min.bytes(), &stat_dtype)?; + let value = + ScalarValue::from_proto_bytes(min.bytes(), &stat_dtype, session)?; let Some(value) = value else { continue; }; @@ -193,7 +197,8 @@ impl StatsSet { if let Some(sum) = fb.sum() && let Some(stat_dtype) = stat_dtype { - let value = ScalarValue::from_proto_bytes(sum.bytes(), &stat_dtype)?; + let value = + ScalarValue::from_proto_bytes(sum.bytes(), &stat_dtype, session)?; let Some(value) = value else { continue; }; diff --git a/vortex-file/src/footer/deserializer.rs b/vortex-file/src/footer/deserializer.rs index fb644812ea7..9b594e31fd6 100644 --- a/vortex-file/src/footer/deserializer.rs +++ b/vortex-file/src/footer/deserializer.rs @@ -142,7 +142,13 @@ impl FooterDeserializer { .statistics .as_ref() .map(|segment| { - self.parse_file_statistics(initial_offset, &self.buffer, segment, &dtype) + self.parse_file_statistics( + initial_offset, + &self.buffer, + segment, + &dtype, + &self.session, + ) }) .transpose()?; @@ -222,13 +228,14 @@ impl FooterDeserializer { initial_read: &[u8], segment: &PostscriptSegment, dtype: &DType, + session: &VortexSession, ) -> VortexResult { let offset = usize::try_from(segment.offset - initial_offset)?; let sliced_buffer = FlatBuffer::copy_from(&initial_read[offset..offset + (segment.length as usize)]); let fb = root::(&sliced_buffer)?; - FileStatistics::from_flatbuffer(&fb, dtype) + FileStatistics::from_flatbuffer(&fb, dtype, session) } /// Parse the rest of the footer from the initial read. diff --git a/vortex-file/src/footer/file_statistics.rs b/vortex-file/src/footer/file_statistics.rs index 704ae6a44c0..4fac3ad8482 100644 --- a/vortex-file/src/footer/file_statistics.rs +++ b/vortex-file/src/footer/file_statistics.rs @@ -20,6 +20,7 @@ use vortex_flatbuffers::FlatBufferRoot; use vortex_flatbuffers::WriteFlatBuffer; use vortex_flatbuffers::array::ArrayStats; use vortex_flatbuffers::footer as fb; +use vortex_session::VortexSession; /// Contains statistical information about the data in a Vortex file. /// @@ -90,6 +91,7 @@ impl FileStatistics { pub fn from_flatbuffer<'a>( fb: &fb::FileStatistics<'a>, file_dtype: &DType, + session: &VortexSession, ) -> VortexResult { let field_stats = fb.field_stats().unwrap_or_default(); let mut array_stats: Vec = field_stats.iter().collect(); @@ -101,7 +103,7 @@ impl FileStatistics { .into_iter() .zip(struct_fields.fields()) .map(|(array_stat, field_dtype)| { - StatsSet::from_flatbuffer(&array_stat, &field_dtype) + StatsSet::from_flatbuffer(&array_stat, &field_dtype, session) }) .try_collect()?; @@ -117,7 +119,7 @@ impl FileStatistics { let array_stat = array_stats .pop() .vortex_expect("we just checked that there was 1 field"); - let stats_set = StatsSet::from_flatbuffer(&array_stat, file_dtype)?; + let stats_set = StatsSet::from_flatbuffer(&array_stat, file_dtype, session)?; Ok(Self { stats: Arc::new([stats_set]),