Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 128 additions & 2 deletions vortex-array/src/arrays/struct_/compute/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ impl ArrayParentReduceRule<Struct> for StructCastPushDownRule {
parent: ScalarFnArrayView<Cast>,
_child_idx: usize,
) -> VortexResult<Option<ArrayRef>> {
let target_fields = parent.options.as_struct_fields();
let Some(target_fields) = parent.options.as_struct_fields_opt() else {
return Ok(None);
};
let mut new_fields = Vec::with_capacity(target_fields.nfields());

for (target_name, target_dtype) in target_fields.names().iter().zip(target_fields.fields())
Expand Down Expand Up @@ -136,6 +138,8 @@ impl ArrayParentReduceRule<Struct> for StructGetItemRule {

#[cfg(test)]
mod tests {
use vortex_buffer::buffer;

use crate::IntoArray;
use crate::arrays::StructArray;
use crate::arrays::VarBinViewArray;
Expand All @@ -146,6 +150,7 @@ mod tests {
use crate::dtype::DType;
use crate::dtype::FieldNames;
use crate::dtype::Nullability;
use crate::dtype::PType;
use crate::dtype::StructFields;
use crate::scalar::Scalar;
use crate::validity::Validity;
Expand Down Expand Up @@ -173,7 +178,8 @@ mod tests {
Nullability::NonNullable,
);

// Use ArrayBuiltins::cast which goes through the optimizer and applies StructCastPushDownRule
// Use `ArrayBuiltins::cast` which goes through the optimizer and applies
// `StructCastPushDownRule`.
let result = source.into_array().cast(target).unwrap().to_struct();
assert_arrays_eq!(
result.unmasked_field_by_name("a").unwrap(),
Expand All @@ -188,4 +194,124 @@ mod tests {
ConstantArray::new(Scalar::null(utf8_null), 1)
);
}

/// Regression test: casting a struct to a non-struct DType must not panic. Previously,
/// `StructCastPushDownRule` called `as_struct_fields()` which panics on non-struct types.
#[test]
fn cast_struct_to_non_struct_does_not_panic() {
let source = StructArray::try_new(
FieldNames::from(["x"]),
vec![buffer![1i32, 2, 3].into_array()],
3,
Validity::NonNullable,
)
.unwrap();

// Casting a struct to a primitive type should not panic. Before the fix,
// `StructCastPushDownRule` would panic via `as_struct_fields()` on the non-struct target.
let result = source
.into_array()
.cast(DType::Primitive(PType::I32, Nullability::NonNullable));
// Whether this errors or succeeds depends on execution, but the key invariant is that the
// optimizer rule does not panic.
if let Ok(arr) = &result {
assert_eq!(
arr.dtype(),
&DType::Primitive(PType::I32, Nullability::NonNullable)
);
}
}

#[test]
fn cast_struct_drop_field() {
// Casting to a struct with a subset of fields should succeed.
let source = StructArray::try_new(
FieldNames::from(["a", "b", "c"]),
vec![
buffer![1i32, 2, 3].into_array(),
buffer![10i64, 20, 30].into_array(),
buffer![100u8, 200, 255].into_array(),
],
3,
Validity::NonNullable,
)
.unwrap();

let target = DType::Struct(
StructFields::new(
FieldNames::from(["a", "c"]),
vec![
DType::Primitive(PType::I32, Nullability::NonNullable),
DType::Primitive(PType::U8, Nullability::NonNullable),
],
),
Nullability::NonNullable,
);

let result = source.into_array().cast(target).unwrap().to_struct();
assert_eq!(result.unmasked_fields().len(), 2);
assert_arrays_eq!(
result.unmasked_field_by_name("a").unwrap(),
buffer![1i32, 2, 3].into_array()
);
assert_arrays_eq!(
result.unmasked_field_by_name("c").unwrap(),
buffer![100u8, 200, 255].into_array()
);
}

#[test]
fn cast_struct_field_type_widening() {
// Casting struct fields to wider types (i32 -> i64).
let source = StructArray::try_new(
FieldNames::from(["val"]),
vec![buffer![1i32, 2, 3].into_array()],
3,
Validity::NonNullable,
)
.unwrap();

let target = DType::Struct(
StructFields::new(
FieldNames::from(["val"]),
vec![DType::Primitive(PType::I64, Nullability::NonNullable)],
),
Nullability::NonNullable,
);

let result = source.into_array().cast(target).unwrap().to_struct();
assert_eq!(
result.unmasked_field_by_name("val").unwrap().dtype(),
&DType::Primitive(PType::I64, Nullability::NonNullable)
);
assert_arrays_eq!(
result.unmasked_field_by_name("val").unwrap(),
buffer![1i64, 2, 3].into_array()
);
}

#[test]
fn cast_struct_add_non_nullable_field_fails() {
// Adding a non-nullable field via cast should fail.
let source = StructArray::try_new(
FieldNames::from(["a"]),
vec![buffer![1i32].into_array()],
1,
Validity::NonNullable,
)
.unwrap();

let target = DType::Struct(
StructFields::new(
FieldNames::from(["a", "b"]),
vec![
DType::Primitive(PType::I32, Nullability::NonNullable),
DType::Primitive(PType::I32, Nullability::NonNullable),
],
),
Nullability::NonNullable,
);

assert!(source.into_array().cast(target).is_err());
}
}
Loading