diff --git a/encodings/sequence/src/array.rs b/encodings/sequence/src/array.rs index ae280882ed5..09c3fe27183 100644 --- a/encodings/sequence/src/array.rs +++ b/encodings/sequence/src/array.rs @@ -324,16 +324,14 @@ impl VTable for SequenceVTable { _buffers: &[BufferHandle], _session: &VortexSession, ) -> VortexResult { - let prost = ProstMetadata( - as DeserializeMetadata>::deserialize(bytes)?, - ); + let prost = + as DeserializeMetadata>::deserialize(bytes)?; let ptype = dtype.as_ptype(); // We go via Scalar to validate that the value is valid for the ptype. let base = Scalar::from_proto_value( prost - .0 .base .as_ref() .ok_or_else(|| vortex_err!("base required"))?, @@ -345,7 +343,6 @@ impl VTable for SequenceVTable { let multiplier = Scalar::from_proto_value( prost - .0 .multiplier .as_ref() .ok_or_else(|| vortex_err!("multiplier required"))?, diff --git a/encodings/sparse/public-api.lock b/encodings/sparse/public-api.lock index f4456984bb1..ed3682bdffc 100644 --- a/encodings/sparse/public-api.lock +++ b/encodings/sparse/public-api.lock @@ -1,5 +1,25 @@ pub mod vortex_sparse +#[repr(C)] pub struct vortex_sparse::ProstPatchesMetadata + +impl core::clone::Clone for vortex_sparse::ProstPatchesMetadata + +pub fn vortex_sparse::ProstPatchesMetadata::clone(&self) -> vortex_sparse::ProstPatchesMetadata + +impl core::default::Default for vortex_sparse::ProstPatchesMetadata + +pub fn vortex_sparse::ProstPatchesMetadata::default() -> Self + +impl core::fmt::Debug for vortex_sparse::ProstPatchesMetadata + +pub fn vortex_sparse::ProstPatchesMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl prost::message::Message for vortex_sparse::ProstPatchesMetadata + +pub fn vortex_sparse::ProstPatchesMetadata::clear(&mut self) + +pub fn vortex_sparse::ProstPatchesMetadata::encoded_len(&self) -> usize + pub struct vortex_sparse::SparseArray impl vortex_sparse::SparseArray @@ -42,26 +62,12 @@ impl vortex_array::array::IntoArray for vortex_sparse::SparseArray pub fn vortex_sparse::SparseArray::into_array(self) -> vortex_array::array::ArrayRef -#[repr(C)] pub struct vortex_sparse::SparseMetadata - -impl core::clone::Clone for vortex_sparse::SparseMetadata - -pub fn vortex_sparse::SparseMetadata::clone(&self) -> vortex_sparse::SparseMetadata - -impl core::default::Default for vortex_sparse::SparseMetadata - -pub fn vortex_sparse::SparseMetadata::default() -> Self +pub struct vortex_sparse::SparseMetadata impl core::fmt::Debug for vortex_sparse::SparseMetadata pub fn vortex_sparse::SparseMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result -impl prost::message::Message for vortex_sparse::SparseMetadata - -pub fn vortex_sparse::SparseMetadata::clear(&mut self) - -pub fn vortex_sparse::SparseMetadata::encoded_len(&self) -> usize - pub struct vortex_sparse::SparseVTable impl vortex_sparse::SparseVTable @@ -96,7 +102,7 @@ impl vortex_array::vtable::VTable for vortex_sparse::SparseVTable pub type vortex_sparse::SparseVTable::Array = vortex_sparse::SparseArray -pub type vortex_sparse::SparseVTable::Metadata = vortex_array::metadata::ProstMetadata +pub type vortex_sparse::SparseVTable::Metadata = vortex_sparse::SparseMetadata pub type vortex_sparse::SparseVTable::OperationsVTable = vortex_sparse::SparseVTable @@ -110,13 +116,13 @@ pub fn vortex_sparse::SparseVTable::buffer(array: &vortex_sparse::SparseArray, i pub fn vortex_sparse::SparseVTable::buffer_name(_array: &vortex_sparse::SparseArray, idx: usize) -> core::option::Option -pub fn vortex_sparse::SparseVTable::build(dtype: &vortex_array::dtype::DType, len: usize, metadata: &Self::Metadata, buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren) -> vortex_error::VortexResult +pub fn vortex_sparse::SparseVTable::build(dtype: &vortex_array::dtype::DType, len: usize, metadata: &Self::Metadata, _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren) -> vortex_error::VortexResult pub fn vortex_sparse::SparseVTable::child(array: &vortex_sparse::SparseArray, idx: usize) -> vortex_array::array::ArrayRef pub fn vortex_sparse::SparseVTable::child_name(_array: &vortex_sparse::SparseArray, idx: usize) -> alloc::string::String -pub fn vortex_sparse::SparseVTable::deserialize(bytes: &[u8], _dtype: &vortex_array::dtype::DType, _len: usize, _buffers: &[vortex_array::buffer::BufferHandle], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult +pub fn vortex_sparse::SparseVTable::deserialize(bytes: &[u8], dtype: &vortex_array::dtype::DType, _len: usize, buffers: &[vortex_array::buffer::BufferHandle], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_sparse::SparseVTable::dtype(array: &vortex_sparse::SparseArray) -> &vortex_array::dtype::DType diff --git a/encodings/sparse/src/lib.rs b/encodings/sparse/src/lib.rs index df41d26a49f..29c8d5c190e 100644 --- a/encodings/sparse/src/lib.rs +++ b/encodings/sparse/src/lib.rs @@ -10,6 +10,7 @@ use vortex_array::Array; use vortex_array::ArrayEq; use vortex_array::ArrayHash; use vortex_array::ArrayRef; +use vortex_array::DeserializeMetadata; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::Precision; @@ -43,6 +44,7 @@ use vortex_error::VortexExpect as _; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; use vortex_error::vortex_panic; use vortex_mask::AllOr; use vortex_mask::Mask; @@ -60,9 +62,15 @@ mod slice; vtable!(Sparse); +#[derive(Debug)] +pub struct SparseMetadata { + patches: PatchesMetadata, + fill_value: Scalar, +} + #[derive(Clone, prost::Message)] #[repr(C)] -pub struct SparseMetadata { +pub struct ProstPatchesMetadata { #[prost(message, required, tag = "1")] patches: PatchesMetadata, } @@ -70,7 +78,7 @@ pub struct SparseMetadata { impl VTable for SparseVTable { type Array = SparseArray; - type Metadata = ProstMetadata; + type Metadata = SparseMetadata; type OperationsVTable = Self; type ValidityVTable = Self; @@ -134,65 +142,87 @@ impl VTable for SparseVTable { } fn metadata(array: &SparseArray) -> VortexResult { - Ok(ProstMetadata(SparseMetadata { - patches: array.patches().to_metadata(array.len(), array.dtype())?, - })) + let patches = array.patches().to_metadata(array.len(), array.dtype())?; + + Ok(SparseMetadata { + patches, + fill_value: array.fill_value.clone(), + }) } fn serialize(metadata: Self::Metadata) -> VortexResult>> { - Ok(Some(metadata.0.encode_to_vec())) + let prost_patches = ProstPatchesMetadata { + patches: metadata.patches, + }; + + // Note that we DO NOT serialize the fill value since that is stored in the buffers. + Ok(Some(prost_patches.encode_to_vec())) } fn deserialize( bytes: &[u8], - _dtype: &DType, + dtype: &DType, _len: usize, - _buffers: &[BufferHandle], + buffers: &[BufferHandle], _session: &VortexSession, ) -> VortexResult { - Ok(ProstMetadata(SparseMetadata::decode(bytes)?)) + let prost_patches = + as DeserializeMetadata>::deserialize(bytes)?; + + // Once we have the patches metadata, we need to get the fill value from the buffers. + + if buffers.len() != 1 { + vortex_bail!("Expected 1 buffer, got {}", buffers.len()); + } + let scalar_bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?; + + let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype)?; + let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?; + + Ok(SparseMetadata { + patches: prost_patches.patches, + fill_value, + }) } fn build( dtype: &DType, len: usize, metadata: &Self::Metadata, - buffers: &[BufferHandle], + _buffers: &[BufferHandle], children: &dyn ArrayChildren, ) -> VortexResult { - if children.len() != 2 { - vortex_bail!( - "Expected 2 children for sparse encoding, found {}", - children.len() - ) - } - vortex_ensure!( - metadata.0.patches.offset()? == 0, + vortex_ensure_eq!( + children.len(), + 2, + "SparseArray expects 2 children for sparse encoding, found {}", + children.len() + ); + vortex_ensure_eq!( + metadata.patches.offset()?, + 0, "Patches must start at offset 0" ); let patch_indices = children.get( 0, - &metadata.0.patches.indices_dtype()?, - metadata.0.patches.len()?, + &metadata.patches.indices_dtype()?, + metadata.patches.len()?, )?; - let patch_values = children.get(1, dtype, metadata.0.patches.len()?)?; + let patch_values = children.get(1, dtype, metadata.patches.len()?)?; - if buffers.len() != 1 { - vortex_bail!("Expected 1 buffer, got {}", buffers.len()); - } - - let bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?; - let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?; - - let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?; - - SparseArray::try_new(patch_indices, patch_values, len, fill_value) + SparseArray::try_new( + patch_indices, + patch_values, + len, + metadata.fill_value.clone(), + ) } fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { - vortex_ensure!( - children.len() == 2, + vortex_ensure_eq!( + children.len(), + 2, "SparseArray expects 2 children, got {}", children.len() );