Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions encodings/sequence/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,16 +324,14 @@ impl VTable for SequenceVTable {
_buffers: &[BufferHandle],
_session: &VortexSession,
) -> VortexResult<Self::Metadata> {
let prost = ProstMetadata(
<ProstMetadata<ProstSequenceMetadata> as DeserializeMetadata>::deserialize(bytes)?,
);
let prost =
<ProstMetadata<ProstSequenceMetadata> as DeserializeMetadata>::deserialize(bytes)?;

let ptype = dtype.as_ptype();

// We go via Scalar to validate that the value is valid for the ptype.
let base = Scalar::from_proto_value(
prost
.0
.base
.as_ref()
.ok_or_else(|| vortex_err!("base required"))?,
Expand All @@ -345,7 +343,6 @@ impl VTable for SequenceVTable {

let multiplier = Scalar::from_proto_value(
prost
.0
.multiplier
.as_ref()
.ok_or_else(|| vortex_err!("multiplier required"))?,
Expand Down
42 changes: 24 additions & 18 deletions encodings/sparse/public-api.lock
Original file line number Diff line number Diff line change
@@ -1,5 +1,25 @@
pub mod vortex_sparse

#[repr(C)] pub struct vortex_sparse::ProstPatchesMetadata

impl core::clone::Clone for vortex_sparse::ProstPatchesMetadata

pub fn vortex_sparse::ProstPatchesMetadata::clone(&self) -> vortex_sparse::ProstPatchesMetadata

impl core::default::Default for vortex_sparse::ProstPatchesMetadata

pub fn vortex_sparse::ProstPatchesMetadata::default() -> Self

impl core::fmt::Debug for vortex_sparse::ProstPatchesMetadata

pub fn vortex_sparse::ProstPatchesMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl prost::message::Message for vortex_sparse::ProstPatchesMetadata

pub fn vortex_sparse::ProstPatchesMetadata::clear(&mut self)

pub fn vortex_sparse::ProstPatchesMetadata::encoded_len(&self) -> usize

pub struct vortex_sparse::SparseArray

impl vortex_sparse::SparseArray
Expand Down Expand Up @@ -42,26 +62,12 @@ impl vortex_array::array::IntoArray for vortex_sparse::SparseArray

pub fn vortex_sparse::SparseArray::into_array(self) -> vortex_array::array::ArrayRef

#[repr(C)] pub struct vortex_sparse::SparseMetadata

impl core::clone::Clone for vortex_sparse::SparseMetadata

pub fn vortex_sparse::SparseMetadata::clone(&self) -> vortex_sparse::SparseMetadata

impl core::default::Default for vortex_sparse::SparseMetadata

pub fn vortex_sparse::SparseMetadata::default() -> Self
pub struct vortex_sparse::SparseMetadata

impl core::fmt::Debug for vortex_sparse::SparseMetadata

pub fn vortex_sparse::SparseMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl prost::message::Message for vortex_sparse::SparseMetadata

pub fn vortex_sparse::SparseMetadata::clear(&mut self)

pub fn vortex_sparse::SparseMetadata::encoded_len(&self) -> usize

pub struct vortex_sparse::SparseVTable

impl vortex_sparse::SparseVTable
Expand Down Expand Up @@ -96,7 +102,7 @@ impl vortex_array::vtable::VTable for vortex_sparse::SparseVTable

pub type vortex_sparse::SparseVTable::Array = vortex_sparse::SparseArray

pub type vortex_sparse::SparseVTable::Metadata = vortex_array::metadata::ProstMetadata<vortex_sparse::SparseMetadata>
pub type vortex_sparse::SparseVTable::Metadata = vortex_sparse::SparseMetadata

pub type vortex_sparse::SparseVTable::OperationsVTable = vortex_sparse::SparseVTable

Expand All @@ -110,13 +116,13 @@ pub fn vortex_sparse::SparseVTable::buffer(array: &vortex_sparse::SparseArray, i

pub fn vortex_sparse::SparseVTable::buffer_name(_array: &vortex_sparse::SparseArray, idx: usize) -> core::option::Option<alloc::string::String>

pub fn vortex_sparse::SparseVTable::build(dtype: &vortex_array::dtype::DType, len: usize, metadata: &Self::Metadata, buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren) -> vortex_error::VortexResult<vortex_sparse::SparseArray>
pub fn vortex_sparse::SparseVTable::build(dtype: &vortex_array::dtype::DType, len: usize, metadata: &Self::Metadata, _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren) -> vortex_error::VortexResult<vortex_sparse::SparseArray>

pub fn vortex_sparse::SparseVTable::child(array: &vortex_sparse::SparseArray, idx: usize) -> vortex_array::array::ArrayRef

pub fn vortex_sparse::SparseVTable::child_name(_array: &vortex_sparse::SparseArray, idx: usize) -> alloc::string::String

pub fn vortex_sparse::SparseVTable::deserialize(bytes: &[u8], _dtype: &vortex_array::dtype::DType, _len: usize, _buffers: &[vortex_array::buffer::BufferHandle], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Metadata>
pub fn vortex_sparse::SparseVTable::deserialize(bytes: &[u8], dtype: &vortex_array::dtype::DType, _len: usize, buffers: &[vortex_array::buffer::BufferHandle], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Metadata>

pub fn vortex_sparse::SparseVTable::dtype(array: &vortex_sparse::SparseArray) -> &vortex_array::dtype::DType

Expand Down
96 changes: 63 additions & 33 deletions encodings/sparse/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use vortex_array::Array;
use vortex_array::ArrayEq;
use vortex_array::ArrayHash;
use vortex_array::ArrayRef;
use vortex_array::DeserializeMetadata;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::Precision;
Expand Down Expand Up @@ -43,6 +44,7 @@ use vortex_error::VortexExpect as _;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_ensure;
use vortex_error::vortex_ensure_eq;
use vortex_error::vortex_panic;
use vortex_mask::AllOr;
use vortex_mask::Mask;
Expand All @@ -60,17 +62,23 @@ mod slice;

vtable!(Sparse);

#[derive(Debug)]
pub struct SparseMetadata {
patches: PatchesMetadata,
fill_value: Scalar,
}

#[derive(Clone, prost::Message)]
#[repr(C)]
pub struct SparseMetadata {
pub struct ProstPatchesMetadata {
#[prost(message, required, tag = "1")]
patches: PatchesMetadata,
}

impl VTable for SparseVTable {
type Array = SparseArray;

type Metadata = ProstMetadata<SparseMetadata>;
type Metadata = SparseMetadata;
type OperationsVTable = Self;
type ValidityVTable = Self;

Expand Down Expand Up @@ -134,65 +142,87 @@ impl VTable for SparseVTable {
}

fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
Ok(ProstMetadata(SparseMetadata {
patches: array.patches().to_metadata(array.len(), array.dtype())?,
}))
let patches = array.patches().to_metadata(array.len(), array.dtype())?;

Ok(SparseMetadata {
patches,
fill_value: array.fill_value.clone(),
})
}

fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
Ok(Some(metadata.0.encode_to_vec()))
let prost_patches = ProstPatchesMetadata {
patches: metadata.patches,
};

// Note that we DO NOT serialize the fill value since that is stored in the buffers.
Ok(Some(prost_patches.encode_to_vec()))
}

fn deserialize(
bytes: &[u8],
_dtype: &DType,
dtype: &DType,
_len: usize,
_buffers: &[BufferHandle],
buffers: &[BufferHandle],
_session: &VortexSession,
) -> VortexResult<Self::Metadata> {
Ok(ProstMetadata(SparseMetadata::decode(bytes)?))
let prost_patches =
<ProstMetadata<ProstPatchesMetadata> as DeserializeMetadata>::deserialize(bytes)?;

// Once we have the patches metadata, we need to get the fill value from the buffers.

if buffers.len() != 1 {
vortex_bail!("Expected 1 buffer, got {}", buffers.len());
}
let scalar_bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;

let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype)?;
let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;

Ok(SparseMetadata {
patches: prost_patches.patches,
fill_value,
})
}

fn build(
dtype: &DType,
len: usize,
metadata: &Self::Metadata,
buffers: &[BufferHandle],
_buffers: &[BufferHandle],
children: &dyn ArrayChildren,
) -> VortexResult<SparseArray> {
if children.len() != 2 {
vortex_bail!(
"Expected 2 children for sparse encoding, found {}",
children.len()
)
}
vortex_ensure!(
metadata.0.patches.offset()? == 0,
vortex_ensure_eq!(
children.len(),
2,
"SparseArray expects 2 children for sparse encoding, found {}",
children.len()
);
vortex_ensure_eq!(
metadata.patches.offset()?,
0,
"Patches must start at offset 0"
);

let patch_indices = children.get(
0,
&metadata.0.patches.indices_dtype()?,
metadata.0.patches.len()?,
&metadata.patches.indices_dtype()?,
metadata.patches.len()?,
)?;
let patch_values = children.get(1, dtype, metadata.0.patches.len()?)?;
let patch_values = children.get(1, dtype, metadata.patches.len()?)?;

if buffers.len() != 1 {
vortex_bail!("Expected 1 buffer, got {}", buffers.len());
}

let bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?;

let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;

SparseArray::try_new(patch_indices, patch_values, len, fill_value)
SparseArray::try_new(
patch_indices,
patch_values,
len,
metadata.fill_value.clone(),
)
}

fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
vortex_ensure!(
children.len() == 2,
vortex_ensure_eq!(
children.len(),
2,
"SparseArray expects 2 children, got {}",
children.len()
);
Expand Down
Loading