diff --git a/vortex-array/src/arrays/struct_/compute/rules.rs b/vortex-array/src/arrays/struct_/compute/rules.rs index e450374a3b2..910d5ebe481 100644 --- a/vortex-array/src/arrays/struct_/compute/rules.rs +++ b/vortex-array/src/arrays/struct_/compute/rules.rs @@ -51,7 +51,9 @@ impl ArrayParentReduceRule for StructCastPushDownRule { parent: ScalarFnArrayView, _child_idx: usize, ) -> VortexResult> { - let target_fields = parent.options.as_struct_fields(); + let Some(target_fields) = parent.options.as_struct_fields_opt() else { + return Ok(None); + }; let mut new_fields = Vec::with_capacity(target_fields.nfields()); for (target_name, target_dtype) in target_fields.names().iter().zip(target_fields.fields()) @@ -136,6 +138,8 @@ impl ArrayParentReduceRule for StructGetItemRule { #[cfg(test)] mod tests { + use vortex_buffer::buffer; + use crate::IntoArray; use crate::arrays::StructArray; use crate::arrays::VarBinViewArray; @@ -146,6 +150,7 @@ mod tests { use crate::dtype::DType; use crate::dtype::FieldNames; use crate::dtype::Nullability; + use crate::dtype::PType; use crate::dtype::StructFields; use crate::scalar::Scalar; use crate::validity::Validity; @@ -173,7 +178,8 @@ mod tests { Nullability::NonNullable, ); - // Use ArrayBuiltins::cast which goes through the optimizer and applies StructCastPushDownRule + // Use `ArrayBuiltins::cast` which goes through the optimizer and applies + // `StructCastPushDownRule`. let result = source.into_array().cast(target).unwrap().to_struct(); assert_arrays_eq!( result.unmasked_field_by_name("a").unwrap(), @@ -188,4 +194,124 @@ mod tests { ConstantArray::new(Scalar::null(utf8_null), 1) ); } + + /// Regression test: casting a struct to a non-struct DType must not panic. Previously, + /// `StructCastPushDownRule` called `as_struct_fields()` which panics on non-struct types. + #[test] + fn cast_struct_to_non_struct_does_not_panic() { + let source = StructArray::try_new( + FieldNames::from(["x"]), + vec![buffer![1i32, 2, 3].into_array()], + 3, + Validity::NonNullable, + ) + .unwrap(); + + // Casting a struct to a primitive type should not panic. Before the fix, + // `StructCastPushDownRule` would panic via `as_struct_fields()` on the non-struct target. + let result = source + .into_array() + .cast(DType::Primitive(PType::I32, Nullability::NonNullable)); + // Whether this errors or succeeds depends on execution, but the key invariant is that the + // optimizer rule does not panic. + if let Ok(arr) = &result { + assert_eq!( + arr.dtype(), + &DType::Primitive(PType::I32, Nullability::NonNullable) + ); + } + } + + #[test] + fn cast_struct_drop_field() { + // Casting to a struct with a subset of fields should succeed. + let source = StructArray::try_new( + FieldNames::from(["a", "b", "c"]), + vec![ + buffer![1i32, 2, 3].into_array(), + buffer![10i64, 20, 30].into_array(), + buffer![100u8, 200, 255].into_array(), + ], + 3, + Validity::NonNullable, + ) + .unwrap(); + + let target = DType::Struct( + StructFields::new( + FieldNames::from(["a", "c"]), + vec![ + DType::Primitive(PType::I32, Nullability::NonNullable), + DType::Primitive(PType::U8, Nullability::NonNullable), + ], + ), + Nullability::NonNullable, + ); + + let result = source.into_array().cast(target).unwrap().to_struct(); + assert_eq!(result.unmasked_fields().len(), 2); + assert_arrays_eq!( + result.unmasked_field_by_name("a").unwrap(), + buffer![1i32, 2, 3].into_array() + ); + assert_arrays_eq!( + result.unmasked_field_by_name("c").unwrap(), + buffer![100u8, 200, 255].into_array() + ); + } + + #[test] + fn cast_struct_field_type_widening() { + // Casting struct fields to wider types (i32 -> i64). + let source = StructArray::try_new( + FieldNames::from(["val"]), + vec![buffer![1i32, 2, 3].into_array()], + 3, + Validity::NonNullable, + ) + .unwrap(); + + let target = DType::Struct( + StructFields::new( + FieldNames::from(["val"]), + vec![DType::Primitive(PType::I64, Nullability::NonNullable)], + ), + Nullability::NonNullable, + ); + + let result = source.into_array().cast(target).unwrap().to_struct(); + assert_eq!( + result.unmasked_field_by_name("val").unwrap().dtype(), + &DType::Primitive(PType::I64, Nullability::NonNullable) + ); + assert_arrays_eq!( + result.unmasked_field_by_name("val").unwrap(), + buffer![1i64, 2, 3].into_array() + ); + } + + #[test] + fn cast_struct_add_non_nullable_field_fails() { + // Adding a non-nullable field via cast should fail. + let source = StructArray::try_new( + FieldNames::from(["a"]), + vec![buffer![1i32].into_array()], + 1, + Validity::NonNullable, + ) + .unwrap(); + + let target = DType::Struct( + StructFields::new( + FieldNames::from(["a", "b"]), + vec![ + DType::Primitive(PType::I32, Nullability::NonNullable), + DType::Primitive(PType::I32, Nullability::NonNullable), + ], + ), + Nullability::NonNullable, + ); + + assert!(source.into_array().cast(target).is_err()); + } }