diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 2abdf599602..fb2f9b9f20e 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -6366,23 +6366,33 @@ pub trait vortex_array::dtype::extension::ExtVTable: 'static + core::marker::Siz pub type vortex_array::dtype::extension::ExtVTable::Metadata: 'static + core::marker::Send + core::marker::Sync + core::clone::Clone + core::fmt::Debug + core::fmt::Display + core::cmp::Eq + core::hash::Hash -pub fn vortex_array::dtype::extension::ExtVTable::deserialize(&self, metadata: &[u8]) -> vortex_error::VortexResult +pub type vortex_array::dtype::extension::ExtVTable::NativeValue<'a>: core::fmt::Display + +pub fn vortex_array::dtype::extension::ExtVTable::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::dtype::extension::ExtVTable::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::dtype::extension::ExtVTable::serialize(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> +pub fn vortex_array::dtype::extension::ExtVTable::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_array::dtype::extension::ExtVTable::unpack_native<'a>(&self, metadata: &'a Self::Metadata, storage_dtype: &'a vortex_array::dtype::DType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult pub fn vortex_array::dtype::extension::ExtVTable::validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> +pub fn vortex_array::dtype::extension::ExtVTable::validate_scalar_value(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> + impl vortex_array::dtype::extension::ExtVTable for vortex_array::extension::datetime::Date pub type vortex_array::extension::datetime::Date::Metadata = 
vortex_array::extension::datetime::TimeUnit -pub fn vortex_array::extension::datetime::Date::deserialize(&self, metadata: &[u8]) -> vortex_error::VortexResult +pub type vortex_array::extension::datetime::Date::NativeValue<'a> = vortex_array::extension::datetime::DateValue + +pub fn vortex_array::extension::datetime::Date::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Date::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::extension::datetime::Date::serialize(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> +pub fn vortex_array::extension::datetime::Date::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_array::extension::datetime::Date::unpack_native(&self, metadata: &Self::Metadata, _storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Date::validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> @@ -6390,11 +6400,15 @@ impl vortex_array::dtype::extension::ExtVTable for vortex_array::extension::date pub type vortex_array::extension::datetime::Time::Metadata = vortex_array::extension::datetime::TimeUnit -pub fn vortex_array::extension::datetime::Time::deserialize(&self, data: &[u8]) -> vortex_error::VortexResult +pub type vortex_array::extension::datetime::Time::NativeValue<'a> = vortex_array::extension::datetime::TimeValue + +pub fn vortex_array::extension::datetime::Time::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Time::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::extension::datetime::Time::serialize(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> +pub fn 
vortex_array::extension::datetime::Time::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_array::extension::datetime::Time::unpack_native(&self, metadata: &Self::Metadata, _storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Time::validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> @@ -6402,11 +6416,15 @@ impl vortex_array::dtype::extension::ExtVTable for vortex_array::extension::date pub type vortex_array::extension::datetime::Timestamp::Metadata = vortex_array::extension::datetime::TimestampOptions -pub fn vortex_array::extension::datetime::Timestamp::deserialize(&self, data: &[u8]) -> vortex_error::VortexResult +pub type vortex_array::extension::datetime::Timestamp::NativeValue<'a> = vortex_array::extension::datetime::TimestampValue<'a> + +pub fn vortex_array::extension::datetime::Timestamp::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Timestamp::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::extension::datetime::Timestamp::serialize(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> +pub fn vortex_array::extension::datetime::Timestamp::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_array::extension::datetime::Timestamp::unpack_native<'a>(&self, metadata: &'a Self::Metadata, _storage_dtype: &'a vortex_array::dtype::DType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Timestamp::validate_dtype(&self, _metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> @@ -12192,6 +12210,16 @@ pub mod vortex_array::extension pub mod 
vortex_array::extension::datetime +pub enum vortex_array::extension::datetime::DateValue + +pub vortex_array::extension::datetime::DateValue::Days(i32) + +pub vortex_array::extension::datetime::DateValue::Milliseconds(i64) + +impl core::fmt::Display for vortex_array::extension::datetime::DateValue + +pub fn vortex_array::extension::datetime::DateValue::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + pub enum vortex_array::extension::datetime::TemporalJiff pub vortex_array::extension::datetime::TemporalJiff::Date(jiff::civil::date::Date) @@ -12306,6 +12334,34 @@ impl core::marker::Copy for vortex_array::extension::datetime::TimeUnit impl core::marker::StructuralPartialEq for vortex_array::extension::datetime::TimeUnit +pub enum vortex_array::extension::datetime::TimeValue + +pub vortex_array::extension::datetime::TimeValue::Microseconds(i64) + +pub vortex_array::extension::datetime::TimeValue::Milliseconds(i32) + +pub vortex_array::extension::datetime::TimeValue::Nanoseconds(i64) + +pub vortex_array::extension::datetime::TimeValue::Seconds(i32) + +impl core::fmt::Display for vortex_array::extension::datetime::TimeValue + +pub fn vortex_array::extension::datetime::TimeValue::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub enum vortex_array::extension::datetime::TimestampValue<'a> + +pub vortex_array::extension::datetime::TimestampValue::Microseconds(i64, core::option::Option<&'a alloc::sync::Arc>) + +pub vortex_array::extension::datetime::TimestampValue::Milliseconds(i64, core::option::Option<&'a alloc::sync::Arc>) + +pub vortex_array::extension::datetime::TimestampValue::Nanoseconds(i64, core::option::Option<&'a alloc::sync::Arc>) + +pub vortex_array::extension::datetime::TimestampValue::Seconds(i64, core::option::Option<&'a alloc::sync::Arc>) + +impl core::fmt::Display for vortex_array::extension::datetime::TimestampValue<'_> + +pub fn vortex_array::extension::datetime::TimestampValue<'_>::fmt(&self, f: &mut 
core::fmt::Formatter<'_>) -> core::fmt::Result + pub struct vortex_array::extension::datetime::AnyTemporal impl vortex_array::dtype::extension::Matcher for vortex_array::extension::datetime::AnyTemporal @@ -12350,11 +12406,15 @@ impl vortex_array::dtype::extension::ExtVTable for vortex_array::extension::date pub type vortex_array::extension::datetime::Date::Metadata = vortex_array::extension::datetime::TimeUnit -pub fn vortex_array::extension::datetime::Date::deserialize(&self, metadata: &[u8]) -> vortex_error::VortexResult +pub type vortex_array::extension::datetime::Date::NativeValue<'a> = vortex_array::extension::datetime::DateValue + +pub fn vortex_array::extension::datetime::Date::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Date::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::extension::datetime::Date::serialize(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> +pub fn vortex_array::extension::datetime::Date::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_array::extension::datetime::Date::unpack_native(&self, metadata: &Self::Metadata, _storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Date::validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> @@ -12394,11 +12454,15 @@ impl vortex_array::dtype::extension::ExtVTable for vortex_array::extension::date pub type vortex_array::extension::datetime::Time::Metadata = vortex_array::extension::datetime::TimeUnit -pub fn vortex_array::extension::datetime::Time::deserialize(&self, data: &[u8]) -> vortex_error::VortexResult +pub type vortex_array::extension::datetime::Time::NativeValue<'a> = vortex_array::extension::datetime::TimeValue + +pub fn 
vortex_array::extension::datetime::Time::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Time::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::extension::datetime::Time::serialize(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> +pub fn vortex_array::extension::datetime::Time::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_array::extension::datetime::Time::unpack_native(&self, metadata: &Self::Metadata, _storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Time::validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> @@ -12440,11 +12504,15 @@ impl vortex_array::dtype::extension::ExtVTable for vortex_array::extension::date pub type vortex_array::extension::datetime::Timestamp::Metadata = vortex_array::extension::datetime::TimestampOptions -pub fn vortex_array::extension::datetime::Timestamp::deserialize(&self, data: &[u8]) -> vortex_error::VortexResult +pub type vortex_array::extension::datetime::Timestamp::NativeValue<'a> = vortex_array::extension::datetime::TimestampValue<'a> + +pub fn vortex_array::extension::datetime::Timestamp::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Timestamp::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::extension::datetime::Timestamp::serialize(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> +pub fn vortex_array::extension::datetime::Timestamp::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_array::extension::datetime::Timestamp::unpack_native<'a>(&self, metadata: &'a Self::Metadata, _storage_dtype: &'a vortex_array::dtype::DType, 
storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Timestamp::validate_dtype(&self, _metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> @@ -13036,6 +13104,80 @@ pub const vortex_array::patches::PATCH_CHUNK_SIZE: usize pub mod vortex_array::scalar +pub mod vortex_array::scalar::extension + +pub struct vortex_array::scalar::extension::ExtScalarValue(_) + +impl vortex_array::scalar::extension::ExtScalarValue + +pub fn vortex_array::scalar::extension::ExtScalarValue::erased(self) -> vortex_array::scalar::extension::ExtScalarValueRef + +pub fn vortex_array::scalar::extension::ExtScalarValue::id(&self) -> vortex_array::dtype::extension::ExtId + +pub fn vortex_array::scalar::extension::ExtScalarValue::storage_value(&self) -> &vortex_array::scalar::ScalarValue + +pub fn vortex_array::scalar::extension::ExtScalarValue::try_new(ext_dtype: &vortex_array::dtype::extension::ExtDType, storage: vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult + +pub fn vortex_array::scalar::extension::ExtScalarValue::vtable(&self) -> &V + +impl core::clone::Clone for vortex_array::scalar::extension::ExtScalarValue + +pub fn vortex_array::scalar::extension::ExtScalarValue::clone(&self) -> vortex_array::scalar::extension::ExtScalarValue + +impl core::cmp::Eq for vortex_array::scalar::extension::ExtScalarValue + +impl core::cmp::PartialEq for vortex_array::scalar::extension::ExtScalarValue + +pub fn vortex_array::scalar::extension::ExtScalarValue::eq(&self, other: &vortex_array::scalar::extension::ExtScalarValue) -> bool + +impl core::fmt::Debug for vortex_array::scalar::extension::ExtScalarValue + +pub fn vortex_array::scalar::extension::ExtScalarValue::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_array::scalar::extension::ExtScalarValue + +pub fn 
vortex_array::scalar::extension::ExtScalarValue::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +impl core::marker::StructuralPartialEq for vortex_array::scalar::extension::ExtScalarValue + +pub struct vortex_array::scalar::extension::ExtScalarValueRef(_) + +impl vortex_array::scalar::extension::ExtScalarValueRef + +pub fn vortex_array::scalar::extension::ExtScalarValueRef::downcast(self) -> vortex_array::scalar::extension::ExtScalarValue + +pub fn vortex_array::scalar::extension::ExtScalarValueRef::id(&self) -> vortex_array::dtype::extension::ExtId + +pub fn vortex_array::scalar::extension::ExtScalarValueRef::storage_value(&self) -> &vortex_array::scalar::ScalarValue + +pub fn vortex_array::scalar::extension::ExtScalarValueRef::try_downcast(self) -> core::result::Result, vortex_array::scalar::extension::ExtScalarValueRef> + +impl core::clone::Clone for vortex_array::scalar::extension::ExtScalarValueRef + +pub fn vortex_array::scalar::extension::ExtScalarValueRef::clone(&self) -> vortex_array::scalar::extension::ExtScalarValueRef + +impl core::cmp::Eq for vortex_array::scalar::extension::ExtScalarValueRef + +impl core::cmp::PartialEq for vortex_array::scalar::extension::ExtScalarValueRef + +pub fn vortex_array::scalar::extension::ExtScalarValueRef::eq(&self, other: &Self) -> bool + +impl core::cmp::PartialOrd for vortex_array::scalar::extension::ExtScalarValueRef + +pub fn vortex_array::scalar::extension::ExtScalarValueRef::partial_cmp(&self, other: &Self) -> core::option::Option + +impl core::fmt::Debug for vortex_array::scalar::extension::ExtScalarValueRef + +pub fn vortex_array::scalar::extension::ExtScalarValueRef::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_array::scalar::extension::ExtScalarValueRef + +pub fn vortex_array::scalar::extension::ExtScalarValueRef::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for 
vortex_array::scalar::extension::ExtScalarValueRef + +pub fn vortex_array::scalar::extension::ExtScalarValueRef::hash(&self, state: &mut H) + pub enum vortex_array::scalar::DecimalValue pub vortex_array::scalar::DecimalValue::I128(i128) diff --git a/vortex-array/src/arrays/extension/compute/rules.rs b/vortex-array/src/arrays/extension/compute/rules.rs index a6d890dae4e..74aab6ad4fd 100644 --- a/vortex-array/src/arrays/extension/compute/rules.rs +++ b/vortex-array/src/arrays/extension/compute/rules.rs @@ -75,16 +75,26 @@ mod tests { use crate::extension::EmptyMetadata; use crate::optimizer::ArrayOptimizer; use crate::scalar::Scalar; + use crate::scalar::ScalarValue; #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] struct TestExt; impl ExtVTable for TestExt { type Metadata = EmptyMetadata; + type NativeValue<'a> = &'a str; fn id(&self) -> ExtId { ExtId::new_ref("test_ext") } + fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult> { + Ok(vec![]) + } + + fn deserialize_metadata(&self, _data: &[u8]) -> VortexResult { + Ok(EmptyMetadata) + } + fn validate_dtype( &self, _options: &Self::Metadata, @@ -92,6 +102,15 @@ mod tests { ) -> VortexResult<()> { Ok(()) } + + fn unpack_native<'a>( + &self, + _metadata: &'a Self::Metadata, + _storage_dtype: &'a DType, + _storage_value: &'a ScalarValue, + ) -> VortexResult> { + Ok("") + } } fn test_ext_dtype() -> ExtDTypeRef { @@ -164,11 +183,20 @@ mod tests { struct TestExt2; impl ExtVTable for TestExt2 { type Metadata = EmptyMetadata; + type NativeValue<'a> = &'a str; fn id(&self) -> ExtId { ExtId::new_ref("test_ext_2") } + fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult> { + Ok(vec![]) + } + + fn deserialize_metadata(&self, _data: &[u8]) -> VortexResult { + Ok(EmptyMetadata) + } + fn validate_dtype( &self, _options: &Self::Metadata, @@ -176,6 +204,15 @@ mod tests { ) -> VortexResult<()> { Ok(()) } + + fn unpack_native<'a>( + &self, + _metadata: &'a Self::Metadata, + _storage_dtype: 
&'a DType, + _storage_value: &'a ScalarValue, + ) -> VortexResult> { + Ok("") + } } let ext_dtype1 = ExtDType::::try_new( diff --git a/vortex-array/src/dtype/extension/mod.rs b/vortex-array/src/dtype/extension/mod.rs index ed5d5b76b0b..65eb0ea13a6 100644 --- a/vortex-array/src/dtype/extension/mod.rs +++ b/vortex-array/src/dtype/extension/mod.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Extension DTypes, and interfaces for working with extension types (dtypes, scalars, and arrays). +//! Extension DTypes, and interfaces for working with extension types. //! //! ## File layout convention //! diff --git a/vortex-array/src/dtype/extension/plugin.rs b/vortex-array/src/dtype/extension/plugin.rs index 2b45b140849..abde9c45dac 100644 --- a/vortex-array/src/dtype/extension/plugin.rs +++ b/vortex-array/src/dtype/extension/plugin.rs @@ -32,7 +32,7 @@ impl ExtDTypePlugin for V { } fn deserialize(&self, data: &[u8], storage_dtype: DType) -> VortexResult { - let metadata = ExtVTable::deserialize(self, data)?; + let metadata = ExtVTable::deserialize_metadata(self, data)?; Ok(ExtDType::try_with_vtable(self.clone(), metadata, storage_dtype)?.erased()) } } diff --git a/vortex-array/src/dtype/extension/typed.rs b/vortex-array/src/dtype/extension/typed.rs index 8e69df584d4..ab61309e0dd 100644 --- a/vortex-array/src/dtype/extension/typed.rs +++ b/vortex-array/src/dtype/extension/typed.rs @@ -168,7 +168,7 @@ impl DynExtDType for ExtDTypeInner { } fn metadata_serialize(&self) -> VortexResult> { - V::serialize(&self.vtable, &self.metadata) + V::serialize_metadata(&self.vtable, &self.metadata) } fn with_nullability(&self, nullability: Nullability) -> ExtDTypeRef { diff --git a/vortex-array/src/dtype/extension/vtable.rs b/vortex-array/src/dtype/extension/vtable.rs index e37a7be400b..84f530ea06f 100644 --- a/vortex-array/src/dtype/extension/vtable.rs +++ b/vortex-array/src/dtype/extension/vtable.rs @@ -6,40 +6,70 
@@ use std::fmt::Display; use std::hash::Hash; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use crate::dtype::DType; use crate::dtype::extension::ExtId; +use crate::scalar::ScalarValue; /// The public API for defining new extension types. /// /// This is the non-object-safe trait that plugin authors implement to define a new extension /// type. It specifies the type's identity, metadata, serialization, and validation. pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { - /// Associated type containing the deserialized metadata for this extension type + /// Associated type containing the deserialized metadata for this extension type. type Metadata: 'static + Send + Sync + Clone + Debug + Display + Eq + Hash; + /// A native Rust value that represents a scalar of the extension type. + /// + /// The value only represents non-null values. We denote nullable values as `Option`. + type NativeValue<'a>: Display; + /// Returns the ID for this extension type. fn id(&self) -> ExtId; + // Methods related to the extension `DType`. + /// Serialize the metadata into a byte vector. - fn serialize(&self, metadata: &Self::Metadata) -> VortexResult> { - _ = metadata; - vortex_bail!( - "Serialization not implemented for extension type {}", - self.id() - ); - } + fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult>; /// Deserialize the metadata from a byte slice. - fn deserialize(&self, metadata: &[u8]) -> VortexResult { - _ = metadata; - vortex_bail!( - "Deserialization not implemented for extension type {}", - self.id() - ); - } + fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult; /// Validate that the given storage type is compatible with this extension type. fn validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &DType) -> VortexResult<()>; + + // Methods related to the extension scalar values. + + /// Validate the given storage value is compatible with the extension type. 
+ /// + /// By default, this calls [`unpack_native()`](ExtVTable::unpack_native) and discards the result. + /// + /// # Errors + /// + /// Returns an error if the storage [`ScalarValue`] is not compatible with the extension type. + fn validate_scalar_value( + &self, + metadata: &Self::Metadata, + storage_dtype: &DType, + storage_value: &ScalarValue, + ) -> VortexResult<()> { + self.unpack_native(metadata, storage_dtype, storage_value) + .map(|_| ()) + } + + /// Validate and unpack a native value from the storage [`ScalarValue`]. + /// + /// Note that [`ExtVTable::validate_dtype()`] is always called first to validate the storage + /// [`DType`], and the [`Scalar`](crate::scalar::Scalar) implementation will verify that the + /// storage value is compatible with the storage dtype on construction. + /// + /// # Errors + /// + /// Returns an error if the storage [`ScalarValue`] is not compatible with the extension type. + fn unpack_native<'a>( + &self, + metadata: &'a Self::Metadata, + storage_dtype: &'a DType, + storage_value: &'a ScalarValue, + ) -> VortexResult>; } diff --git a/vortex-array/src/extension/datetime/date.rs b/vortex-array/src/extension/datetime/date.rs index ed29a0ae330..c31a2b2f058 100644 --- a/vortex-array/src/extension/datetime/date.rs +++ b/vortex-array/src/extension/datetime/date.rs @@ -1,8 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use std::fmt; + +use jiff::Span; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; @@ -13,11 +17,25 @@ use crate::dtype::extension::ExtDType; use crate::dtype::extension::ExtId; use crate::dtype::extension::ExtVTable; use crate::extension::datetime::TimeUnit; +use crate::scalar::ScalarValue; + +/// The Unix epoch date (1970-01-01). +const EPOCH: jiff::civil::Date = jiff::civil::Date::constant(1970, 1, 1); /// Date DType. 
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct Date; +fn date_ptype(time_unit: &TimeUnit) -> Option { + match time_unit { + TimeUnit::Nanoseconds => None, + TimeUnit::Microseconds => None, + TimeUnit::Milliseconds => Some(PType::I64), + TimeUnit::Seconds => None, + TimeUnit::Days => Some(PType::I32), + } +} + impl Date { /// Creates a new Date extension dtype with the given time unit and nullability. /// @@ -38,18 +56,37 @@ impl Date { } } +/// Unpacked value of a [`Date`] extension scalar. +pub enum DateValue { + /// Days since the Unix epoch. + Days(i32), + /// Milliseconds since the Unix epoch. + Milliseconds(i64), +} + +impl fmt::Display for DateValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let date = match self { + DateValue::Days(days) => EPOCH + Span::new().days(*days), + DateValue::Milliseconds(ms) => EPOCH + Span::new().milliseconds(*ms), + }; + write!(f, "{}", date) + } +} + impl ExtVTable for Date { type Metadata = TimeUnit; + type NativeValue<'a> = DateValue; fn id(&self) -> ExtId { ExtId::new_ref("vortex.date") } - fn serialize(&self, metadata: &Self::Metadata) -> VortexResult> { + fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { Ok(vec![u8::from(*metadata)]) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult { + fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult { let tag = metadata[0]; TimeUnit::try_from(tag) } @@ -67,14 +104,19 @@ impl ExtVTable for Date { Ok(()) } -} -fn date_ptype(time_unit: &TimeUnit) -> Option { - match time_unit { - TimeUnit::Nanoseconds => None, - TimeUnit::Microseconds => None, - TimeUnit::Milliseconds => Some(PType::I64), - TimeUnit::Seconds => None, - TimeUnit::Days => Some(PType::I32), + fn unpack_native( + &self, + metadata: &Self::Metadata, + _storage_dtype: &DType, + storage_value: &ScalarValue, + ) -> VortexResult> { + match metadata { + TimeUnit::Milliseconds => Ok(DateValue::Milliseconds( + 
storage_value.as_primitive().cast::()?, + )), + TimeUnit::Days => Ok(DateValue::Days(storage_value.as_primitive().cast::()?)), + _ => vortex_bail!("Date type does not support time unit {}", metadata), + } } } diff --git a/vortex-array/src/extension/datetime/time.rs b/vortex-array/src/extension/datetime/time.rs index e1e630f237b..a13352afe92 100644 --- a/vortex-array/src/extension/datetime/time.rs +++ b/vortex-array/src/extension/datetime/time.rs @@ -1,8 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use std::fmt; + +use jiff::Span; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; @@ -13,15 +17,24 @@ use crate::dtype::extension::ExtDType; use crate::dtype::extension::ExtId; use crate::dtype::extension::ExtVTable; use crate::extension::datetime::TimeUnit; +use crate::scalar::ScalarValue; /// Time DType. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct Time; +fn time_ptype(time_unit: &TimeUnit) -> Option { + Some(match time_unit { + TimeUnit::Nanoseconds | TimeUnit::Microseconds => PType::I64, + TimeUnit::Milliseconds | TimeUnit::Seconds => PType::I32, + TimeUnit::Days => return None, + }) +} + impl Time { /// Creates a new Time extension dtype with the given time unit and nullability. /// - /// Note that only Milliseconds and Days time units are supported for Time. + /// Note that Days units are not supported for Time. pub fn try_new(time_unit: TimeUnit, nullability: Nullability) -> VortexResult> { let ptype = time_ptype(&time_unit) .ok_or_else(|| vortex_err!("Time type does not support time unit {}", time_unit))?; @@ -34,18 +47,47 @@ impl Time { } } +/// Unpacked value of a [`Time`] extension scalar. +pub enum TimeValue { + /// Seconds since midnight. + Seconds(i32), + /// Milliseconds since midnight. + Milliseconds(i32), + /// Microseconds since midnight. 
+ Microseconds(i64), + /// Nanoseconds since midnight. + Nanoseconds(i64), +} + +impl fmt::Display for TimeValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let min = jiff::civil::Time::MIN; + + let time = match self { + TimeValue::Seconds(s) => min + Span::new().seconds(*s), + TimeValue::Milliseconds(ms) => min + Span::new().milliseconds(*ms), + TimeValue::Microseconds(us) => min + Span::new().microseconds(*us), + TimeValue::Nanoseconds(ns) => min + Span::new().nanoseconds(*ns), + }; + + write!(f, "{}", time) + } +} + impl ExtVTable for Time { type Metadata = TimeUnit; + type NativeValue<'a> = TimeValue; + fn id(&self) -> ExtId { ExtId::new_ref("vortex.time") } - fn serialize(&self, metadata: &Self::Metadata) -> VortexResult> { + fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { Ok(vec![u8::from(*metadata)]) } - fn deserialize(&self, data: &[u8]) -> VortexResult { + fn deserialize_metadata(&self, data: &[u8]) -> VortexResult { let tag = data[0]; TimeUnit::try_from(tag) } @@ -63,12 +105,42 @@ impl ExtVTable for Time { Ok(()) } -} -fn time_ptype(time_unit: &TimeUnit) -> Option { - Some(match time_unit { - TimeUnit::Nanoseconds | TimeUnit::Microseconds => PType::I64, - TimeUnit::Milliseconds | TimeUnit::Seconds => PType::I32, - TimeUnit::Days => return None, - }) + fn unpack_native( + &self, + metadata: &Self::Metadata, + _storage_dtype: &DType, + storage_value: &ScalarValue, + ) -> VortexResult> { + let length_of_time = storage_value.as_primitive().cast::()?; + + let (span, value) = match *metadata { + TimeUnit::Seconds => { + let v = i32::try_from(length_of_time) + .map_err(|e| vortex_err!("Time seconds value out of i32 range: {e}"))?; + (Span::new().seconds(v), TimeValue::Seconds(v)) + } + TimeUnit::Milliseconds => { + let v = i32::try_from(length_of_time) + .map_err(|e| vortex_err!("Time milliseconds value out of i32 range: {e}"))?; + (Span::new().milliseconds(v), TimeValue::Milliseconds(v)) + } + 
TimeUnit::Microseconds => ( + Span::new().microseconds(length_of_time), + TimeValue::Microseconds(length_of_time), + ), + TimeUnit::Nanoseconds => ( + Span::new().nanoseconds(length_of_time), + TimeValue::Nanoseconds(length_of_time), + ), + d @ TimeUnit::Days => vortex_bail!("Time type does not support time unit {d}"), + }; + + // Validate the storage value is within the valid range for Time. + jiff::civil::Time::MIN + .checked_add(span) + .map_err(|e| vortex_err!("Invalid time scalar: {}", e))?; + + Ok(value) + } } diff --git a/vortex-array/src/extension/datetime/timestamp.rs b/vortex-array/src/extension/datetime/timestamp.rs index 5b49f4b0ff9..fd3636d9066 100644 --- a/vortex-array/src/extension/datetime/timestamp.rs +++ b/vortex-array/src/extension/datetime/timestamp.rs @@ -3,12 +3,13 @@ //! Temporal extension data types. -use std::fmt::Display; -use std::fmt::Formatter; +use std::fmt; use std::sync::Arc; +use jiff::Span; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; use vortex_error::vortex_panic; @@ -20,6 +21,7 @@ use crate::dtype::extension::ExtDType; use crate::dtype::extension::ExtId; use crate::dtype::extension::ExtVTable; use crate::extension::datetime::TimeUnit; +use crate::scalar::ScalarValue; /// Timestamp DType. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] @@ -63,8 +65,8 @@ pub struct TimestampOptions { pub tz: Option>, } -impl Display for TimestampOptions { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for TimestampOptions { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.tz { Some(tz) => write!(f, "{}, tz={}", self.unit, tz), None => write!(f, "{}", self.unit), @@ -72,16 +74,52 @@ impl Display for TimestampOptions { } } +/// Unpacked value of a [`Timestamp`] extension scalar. +/// +/// Each variant carries the raw storage value and an optional timezone. 
+pub enum TimestampValue<'a> { + /// Seconds since the Unix epoch. + Seconds(i64, Option<&'a Arc>), + /// Milliseconds since the Unix epoch. + Milliseconds(i64, Option<&'a Arc>), + /// Microseconds since the Unix epoch. + Microseconds(i64, Option<&'a Arc>), + /// Nanoseconds since the Unix epoch. + Nanoseconds(i64, Option<&'a Arc>), +} + +impl fmt::Display for TimestampValue<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let (span, tz) = match self { + TimestampValue::Seconds(v, tz) => (Span::new().seconds(*v), *tz), + TimestampValue::Milliseconds(v, tz) => (Span::new().milliseconds(*v), *tz), + TimestampValue::Microseconds(v, tz) => (Span::new().microseconds(*v), *tz), + TimestampValue::Nanoseconds(v, tz) => (Span::new().nanoseconds(*v), *tz), + }; + let ts = jiff::Timestamp::UNIX_EPOCH + span; + + match tz { + None => write!(f, "{ts}"), + Some(tz) => { + let adjusted_ts = ts.in_tz(tz.as_ref()).vortex_expect("unknown timezone"); + write!(f, "{adjusted_ts}",) + } + } + } +} + impl ExtVTable for Timestamp { type Metadata = TimestampOptions; + type NativeValue<'a> = TimestampValue<'a>; + fn id(&self) -> ExtId { ExtId::new_ref("vortex.timestamp") } // NOTE(ngates): unfortunately we're stuck with this hand-rolled serialization format for // backwards compatibility. 
- fn serialize(&self, metadata: &Self::Metadata) -> VortexResult> { + fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { let mut bytes = Vec::with_capacity(4); let unit_tag: u8 = metadata.unit.into(); @@ -102,7 +140,7 @@ impl ExtVTable for Timestamp { Ok(bytes) } - fn deserialize(&self, data: &[u8]) -> VortexResult { + fn deserialize_metadata(&self, data: &[u8]) -> VortexResult { vortex_ensure!(data.len() >= 3); let tag = data[0]; @@ -142,4 +180,46 @@ impl ExtVTable for Timestamp { ); Ok(()) } + + fn unpack_native<'a>( + &self, + metadata: &'a Self::Metadata, + _storage_dtype: &'a DType, + storage_value: &'a ScalarValue, + ) -> VortexResult> { + let ts_value = storage_value.as_primitive().cast::()?; + let tz = metadata.tz.as_ref(); + + let (span, value) = match metadata.unit { + TimeUnit::Nanoseconds => ( + Span::new().nanoseconds(ts_value), + TimestampValue::Nanoseconds(ts_value, tz), + ), + TimeUnit::Microseconds => ( + Span::new().microseconds(ts_value), + TimestampValue::Microseconds(ts_value, tz), + ), + TimeUnit::Milliseconds => ( + Span::new().milliseconds(ts_value), + TimestampValue::Milliseconds(ts_value, tz), + ), + TimeUnit::Seconds => ( + Span::new().seconds(ts_value), + TimestampValue::Seconds(ts_value, tz), + ), + TimeUnit::Days => vortex_bail!("Timestamp does not support Days time unit"), + }; + + // Validate the storage value is within the valid range for Timestamp. 
+ let ts = jiff::Timestamp::UNIX_EPOCH + .checked_add(span) + .map_err(|e| vortex_err!("Invalid timestamp scalar: {}", e))?; + + if let Some(tz) = tz { + ts.in_tz(tz.as_ref()) + .map_err(|e| vortex_err!("Invalid timezone for timestamp scalar: {}", e))?; + } + + Ok(value) + } } diff --git a/vortex-array/src/extension/mod.rs b/vortex-array/src/extension/mod.rs index 1d08bc73e8b..5c29154ded0 100644 --- a/vortex-array/src/extension/mod.rs +++ b/vortex-array/src/extension/mod.rs @@ -7,6 +7,9 @@ use std::fmt; pub mod datetime; +#[cfg(test)] +mod tests; + /// An empty metadata struct for extension dtypes that do not require any metadata. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct EmptyMetadata; diff --git a/vortex-array/src/extension/tests/divisible_int.rs b/vortex-array/src/extension/tests/divisible_int.rs new file mode 100644 index 00000000000..e946dbdda51 --- /dev/null +++ b/vortex-array/src/extension/tests/divisible_int.rs @@ -0,0 +1,187 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! A test extension type representing unsigned integers divisible by a given divisor. + +use std::fmt; + +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; + +use crate::dtype::DType; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::dtype::extension::ExtDType; +use crate::dtype::extension::ExtId; +use crate::dtype::extension::ExtVTable; +use crate::scalar::ScalarValue; + +/// The divisor stored as extension metadata. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Divisor(pub u64); + +impl fmt::Display for Divisor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "divisible by {}", self.0) + } +} + +/// Extension type for unsigned integers that must be divisible by the metadata divisor. 
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +pub struct DivisibleInt; + +impl DivisibleInt { + /// Creates a new divisible integer extension dtype. + pub fn new(divisor: u64, nullability: Nullability) -> ExtDType { + ExtDType::try_new(Divisor(divisor), DType::Primitive(PType::U64, nullability)) + .vortex_expect("valid divisible int dtype") + } +} + +impl ExtVTable for DivisibleInt { + type Metadata = Divisor; + type NativeValue<'a> = u64; + + fn id(&self) -> ExtId { + ExtId::new_ref("test.divisible_int") + } + + fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { + Ok(metadata.0.to_le_bytes().to_vec()) + } + + fn deserialize_metadata(&self, data: &[u8]) -> VortexResult { + vortex_ensure!(data.len() == 8, "divisible int metadata must be 8 bytes"); + let bytes: [u8; 8] = data + .try_into() + .map_err(|_| vortex_error::vortex_err!("divisible int metadata must be 8 bytes"))?; + let n = u64::from_le_bytes(bytes); + vortex_ensure!(n > 0, "divisor must be greater than 0"); + Ok(Divisor(n)) + } + + fn validate_dtype( + &self, + _metadata: &Self::Metadata, + storage_dtype: &DType, + ) -> VortexResult<()> { + vortex_ensure!( + matches!(storage_dtype, DType::Primitive(PType::U64, _)), + "divisible int storage dtype must be u64" + ); + Ok(()) + } + + fn unpack_native( + &self, + metadata: &Self::Metadata, + _storage_dtype: &DType, + storage_value: &ScalarValue, + ) -> VortexResult> { + let value = storage_value.as_primitive().cast::()?; + if value % metadata.0 != 0 { + vortex_bail!("{} is not divisible by {}", value, metadata.0); + } + Ok(value) + } +} + +#[cfg(test)] +mod tests { + use vortex_error::VortexResult; + + use super::DivisibleInt; + use super::Divisor; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::dtype::extension::ExtVTable; + use crate::scalar::PValue; + use crate::scalar::ScalarValue; + use crate::scalar::extension::ExtScalarValue; + + #[test] + fn accepts_divisible_values() 
-> VortexResult<()> { + let div7 = DivisibleInt::new(7, Nullability::NonNullable); + + for multiple in [0, 7, 14, 21, 7000] { + let sv = ExtScalarValue::::try_new( + &div7, + ScalarValue::Primitive(PValue::U64(multiple)), + )?; + assert_eq!( + sv.storage_value(), + &ScalarValue::Primitive(PValue::U64(multiple)) + ); + } + + Ok(()) + } + + #[test] + fn rejects_non_divisible_values() -> VortexResult<()> { + let div7 = DivisibleInt::new(7, Nullability::NonNullable); + + for bad in [1, 2, 6, 8, 13, 15] { + assert!( + ExtScalarValue::::try_new( + &div7, + ScalarValue::Primitive(PValue::U64(bad)), + ) + .is_err(), + "{bad} should not be accepted as divisible by 7" + ); + } + + Ok(()) + } + + #[test] + fn metadata_roundtrip() -> VortexResult<()> { + let vtable = DivisibleInt; + let divisor = Divisor(42); + + let bytes = vtable.serialize_metadata(&divisor)?; + let decoded = vtable.deserialize_metadata(&bytes)?; + + assert_eq!(decoded, divisor); + Ok(()) + } + + #[test] + fn rejects_zero_divisor() { + let vtable = DivisibleInt; + let bytes = 0u64.to_le_bytes(); + assert!(vtable.deserialize_metadata(&bytes).is_err()); + } + + #[test] + fn rejects_wrong_storage_dtype() { + let vtable = DivisibleInt; + let divisor = Divisor(10); + + assert!( + vtable + .validate_dtype( + &divisor, + &DType::Primitive(PType::I32, Nullability::NonNullable) + ) + .is_err() + ); + assert!( + vtable + .validate_dtype(&divisor, &DType::Utf8(Nullability::NonNullable)) + .is_err() + ); + assert!( + vtable + .validate_dtype( + &divisor, + &DType::Primitive(PType::U64, Nullability::NonNullable) + ) + .is_ok() + ); + } +} diff --git a/vortex-array/src/extension/tests/mod.rs b/vortex-array/src/extension/tests/mod.rs new file mode 100644 index 00000000000..31df677e61d --- /dev/null +++ b/vortex-array/src/extension/tests/mod.rs @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! 
Test extension types for exercising the [`ExtVTable`] contract. + +mod divisible_int; diff --git a/vortex-array/src/scalar/arrow.rs b/vortex-array/src/scalar/arrow.rs index 2b15c9e7ef7..e94dac46551 100644 --- a/vortex-array/src/scalar/arrow.rs +++ b/vortex-array/src/scalar/arrow.rs @@ -212,6 +212,7 @@ mod tests { use crate::extension::datetime::TimestampOptions; use crate::scalar::DecimalValue; use crate::scalar::Scalar; + use crate::scalar::ScalarValue; #[test] fn test_null_scalar_to_arrow() { @@ -449,16 +450,17 @@ mod tests { struct SomeExt; impl ExtVTable for SomeExt { type Metadata = String; + type NativeValue<'a> = &'a str; fn id(&self) -> ExtId { ExtId::new_ref("some_ext") } - fn serialize(&self, _options: &Self::Metadata) -> VortexResult> { + fn serialize_metadata(&self, _options: &Self::Metadata) -> VortexResult> { vortex_bail!("not implemented") } - fn deserialize(&self, _data: &[u8]) -> VortexResult { + fn deserialize_metadata(&self, _data: &[u8]) -> VortexResult { vortex_bail!("not implemented") } @@ -469,6 +471,15 @@ mod tests { ) -> VortexResult<()> { Ok(()) } + + fn unpack_native<'a>( + &self, + _metadata: &'a Self::Metadata, + _storage_dtype: &'a DType, + _storage_value: &'a ScalarValue, + ) -> VortexResult> { + Ok("") + } } let scalar = Scalar::extension::( diff --git a/vortex-array/src/scalar/extension/erased.rs b/vortex-array/src/scalar/extension/erased.rs new file mode 100644 index 00000000000..69bad510789 --- /dev/null +++ b/vortex-array/src/scalar/extension/erased.rs @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::any::type_name; +use std::cmp::Ordering; +use std::fmt; +use std::hash::Hash; +use std::hash::Hasher; +use std::sync::Arc; + +use vortex_error::VortexExpect; +use vortex_error::vortex_err; + +use crate::dtype::extension::ExtId; +use crate::dtype::extension::ExtVTable; +use crate::scalar::ScalarValue; +use crate::scalar::extension::ExtScalarValue; +use 
crate::scalar::extension::typed::DynExtScalarValue; +use crate::scalar::extension::typed::ExtScalarValueInner; + +/// A type-erased extension scalar value. +/// +/// This is the extension scalar analog of [`ExtDTypeRef`]: it stores an [`ExtVTable`] +/// and a storage [`ScalarValue`] behind a trait object, allowing heterogeneous storage inside +/// `ScalarValue::Extension` (so that we do not need a generic parameter). +/// +/// You can use [`try_downcast()`] or [`downcast()`] to recover the concrete vtable type as an +/// [`ExtScalarValue`]. +/// +/// [`ExtDTypeRef`]: crate::dtype::extension::ExtDTypeRef +/// [`try_downcast()`]: ExtScalarValueRef::try_downcast +/// [`downcast()`]: ExtScalarValueRef::downcast +#[derive(Clone)] +pub struct ExtScalarValueRef(pub(super) Arc); + +// NB: If you need access to the vtable, you probably want to add a method and implementation to +// `DynExtScalarValue` and `ExtScalarValueInner`. +/// Methods for downcasting type-erased extension scalars. +impl ExtScalarValueRef { + /// Returns the [`ExtId`] identifying this extension scalar's type. + pub fn id(&self) -> ExtId { + self.0.id() + } + + /// Returns a reference to the underlying storage [`ScalarValue`]. + pub fn storage_value(&self) -> &ScalarValue { + self.0.storage_value() + } + + /// Attempts to downcast to a concrete [`ExtScalarValue`]. + /// + /// # Errors + /// + /// Returns `Err(self)` if the underlying vtable type does not match `V`. + pub fn try_downcast(self) -> Result, ExtScalarValueRef> { + // `ExtScalarValueInner` is the only implementor of `DynExtScalarValue` (due to the sealed + // trait in the parent module), so if the vtable is correct, we know the type can be + // downcast and reinterpreted safely. + if !self.0.as_any().is::>() { + return Err(self); + } + + let ptr = Arc::into_raw(self.0) as *const ExtScalarValueInner; + // SAFETY: We verified the type matches above, so the size and alignment are correct.
+ let inner = unsafe { Arc::from_raw(ptr) }; + + Ok(ExtScalarValue(inner)) + } + + /// Downcasts to a concrete [`ExtScalarValue`]. + /// + /// # Panics + /// + /// Panics if the underlying vtable type does not match `V`. + pub fn downcast(self) -> ExtScalarValue { + self.try_downcast::() + .map_err(|this| { + vortex_err!( + "Failed to downcast ExtScalar {} to {}", + this.0.id(), + type_name::(), + ) + }) + .vortex_expect("Failed to downcast ExtScalar") + } +} + +impl fmt::Display for ExtScalarValueRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}({})", self.0.id(), self.0.storage_value()) + } +} + +impl fmt::Debug for ExtScalarValueRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ExtScalar") + .field("id", &self.0.id()) + .field("storage_value", self.0.storage_value()) + .finish() + } +} + +// TODO(connor): In the future we may want to allow implementors to customize this behavior. + +impl PartialEq for ExtScalarValueRef { + fn eq(&self, other: &Self) -> bool { + self.0.id() == other.0.id() && self.0.storage_value() == other.0.storage_value() + } +} +impl Eq for ExtScalarValueRef {} + +impl PartialOrd for ExtScalarValueRef { + fn partial_cmp(&self, other: &Self) -> Option { + // TODO(connor): Should this check if the IDs are equal before ordering? + self.0.storage_value().partial_cmp(other.0.storage_value()) + } +} + +impl Hash for ExtScalarValueRef { + fn hash(&self, state: &mut H) { + self.0.id().hash(state); + self.0.storage_value().hash(state); + } +} diff --git a/vortex-array/src/scalar/extension/mod.rs b/vortex-array/src/scalar/extension/mod.rs new file mode 100644 index 00000000000..c3e2608b37a --- /dev/null +++ b/vortex-array/src/scalar/extension/mod.rs @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Extension Scalar Values, and interfaces for working with them. +//! +//! 
We define normal [`Scalar`]s as the combination of a [`ScalarValue`] and a [`DType`]. +//! +//! Similarly, we define an extension [`Scalar`] as the combination of an [`ExtScalarValueRef`] and +//! an [`ExtDTypeRef`]. +//! +//! [`Scalar`]: crate::scalar::Scalar +//! [`ScalarValue`]: crate::scalar::ScalarValue +//! [`DType`]: crate::dtype::DType +//! [`ExtDTypeRef`]: crate::dtype::extension::ExtDTypeRef + +mod typed; +pub use typed::ExtScalarValue; + +mod erased; +pub use erased::ExtScalarValueRef; + +/// Private module to seal [`DynExtScalarValue`]. +mod sealed { + use crate::dtype::extension::ExtVTable; + use crate::scalar::extension::typed::ExtScalarValueInner; + + /// Marker trait to prevent external implementations of [`DynExtScalarValue`]. + pub(super) trait Sealed {} + + /// This can be the **only** implementor for [`super::typed::DynExtScalarValue`]. + impl Sealed for ExtScalarValueInner {} +} + +#[cfg(test)] +mod tests; diff --git a/vortex-array/src/scalar/extension/tests.rs b/vortex-array/src/scalar/extension/tests.rs new file mode 100644 index 00000000000..feadd2f319d --- /dev/null +++ b/vortex-array/src/scalar/extension/tests.rs @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use vortex_error::VortexResult; + +use crate::dtype::Nullability; +use crate::extension::datetime::Date; +use crate::extension::datetime::Time; +use crate::extension::datetime::TimeUnit; +use crate::extension::datetime::Timestamp; +use crate::scalar::PValue; +use crate::scalar::ScalarValue; +use crate::scalar::extension::ExtScalarValue; +use crate::scalar::extension::ExtScalarValueRef; + +#[test] +fn try_new_date_valid() -> VortexResult<()> { + let ext_dtype = Date::new(TimeUnit::Days, Nullability::NonNullable); + let storage = ScalarValue::Primitive(PValue::I32(100)); + + let sv = ExtScalarValue::::try_new(&ext_dtype, storage.clone())?; + + assert_eq!(sv.id().as_ref(), "vortex.date"); + 
assert_eq!(sv.storage_value(), &storage); + assert_eq!(sv.vtable(), &Date); + Ok(()) +} + +#[test] +fn try_new_time_rejects_out_of_range() -> VortexResult<()> { + let ext_dtype = Time::new(TimeUnit::Seconds, Nullability::NonNullable); + + // 25 hours in seconds exceeds valid time-of-day range. + let too_large = ScalarValue::Primitive(PValue::I32(90_000)); + assert!(ExtScalarValue::