diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index d1a798eae06..b1ec105ab49 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -12296,7 +12296,9 @@ pub fn vortex_array::scalar::Scalar::from_proto_value(value: &vortex_proto::scal impl vortex_array::scalar::Scalar -pub fn vortex_array::scalar::Scalar::struct_(dtype: vortex_array::dtype::DType, children: alloc::vec::Vec) -> Self +pub fn vortex_array::scalar::Scalar::struct_(dtype: vortex_array::dtype::DType, children: impl core::iter::traits::collect::IntoIterator) -> Self + +pub unsafe fn vortex_array::scalar::Scalar::struct_unchecked(dtype: vortex_array::dtype::DType, children: impl core::iter::traits::collect::IntoIterator) -> Self impl vortex_array::scalar::Scalar diff --git a/vortex-array/src/arrays/struct_/vtable/operations.rs b/vortex-array/src/arrays/struct_/vtable/operations.rs index dba7a594ed1..111832a20b3 100644 --- a/vortex-array/src/arrays/struct_/vtable/operations.rs +++ b/vortex-array/src/arrays/struct_/vtable/operations.rs @@ -11,11 +11,13 @@ use crate::vtable::OperationsVTable; impl OperationsVTable for StructVTable { fn scalar_at(array: &StructArray, index: usize) -> VortexResult { - let field_scalars: VortexResult> = array + let field_scalars: VortexResult> = array .unmasked_fields() .iter() .map(|field| field.scalar_at(index)) .collect(); - Ok(Scalar::struct_(array.dtype().clone(), field_scalars?)) + // SAFETY: The vtable guarantees index is in-bounds and non-null before this is called. + // Each field's scalar_at returns a scalar with the field's own dtype. + Ok(unsafe { Scalar::struct_unchecked(array.dtype().clone(), field_scalars?) }) } } diff --git a/vortex-array/src/scalar/typed_view/struct_.rs b/vortex-array/src/scalar/typed_view/struct_.rs index acc56e195c4..ce571350a5b 100644 --- a/vortex-array/src/scalar/typed_view/struct_.rs +++ b/vortex-array/src/scalar/typed_view/struct_.rs @@ -279,12 +279,13 @@ impl<'a> StructScalar<'a> { } impl Scalar { - /// Creates a new struct scalar with the given fields. - pub fn struct_(dtype: DType, children: Vec) -> Self { + /// Creates a new struct scalar with the given fields, checking dtypes at runtime. + pub fn struct_(dtype: DType, children: impl IntoIterator) -> Self { let DType::Struct(struct_fields, _) = &dtype else { vortex_panic!("Expected struct dtype, found {}", dtype); }; + let children: Vec = children.into_iter().collect(); let field_dtypes = struct_fields.fields(); if children.len() != field_dtypes.len() { vortex_panic!( @@ -305,9 +306,24 @@ impl Scalar { } } - let mut value_children = Vec::with_capacity(children.len()); - value_children.extend(children.into_iter().map(|x| x.into_value())); + let value_children: Vec<_> = children.into_iter().map(|x| x.into_value()).collect(); + Self::try_new(dtype, Some(ScalarValue::List(value_children))) + .vortex_expect("unable to construct a struct `Scalar`") + } + /// Creates a new struct scalar from an iterator of field scalars, skipping dtype checks. + /// + /// # Safety + /// + /// Caller must ensure: + /// - `dtype` is `DType::Struct` + /// - The iterator yields exactly as many scalars as `dtype` has fields + /// - Each scalar's dtype matches the corresponding field dtype in `dtype` + pub unsafe fn struct_unchecked( + dtype: DType, + children: impl IntoIterator, + ) -> Self { + let value_children: Vec<_> = children.into_iter().map(|s| s.into_value()).collect(); Self::try_new(dtype, Some(ScalarValue::List(value_children))) .vortex_expect("unable to construct a struct `Scalar`") } diff --git a/vortex-python/src/scalar/factory.rs b/vortex-python/src/scalar/factory.rs index 81286574a17..d4a0e000cf1 100644 --- a/vortex-python/src/scalar/factory.rs +++ b/vortex-python/src/scalar/factory.rs @@ -151,12 +151,14 @@ fn scalar_helper_inner(value: &Bound<'_, PyAny>, dtype: Option<&DType>) -> PyRes ))); } + let children: Vec = dict + .values() + .into_iter() + .map(|item| scalar_helper_inner(&item, None)) + .try_collect()?; return Ok(Scalar::struct_( DType::Struct(dtype.clone(), *nullability), - dict.values() - .into_iter() - .map(|item| scalar_helper_inner(&item, None)) - .try_collect()?, + children, )); } else { let values: Vec = dict