From dbf2aa5ad9b61bec17c6f6010359383f8707b5ba Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sat, 22 Nov 2025 15:29:54 -0300 Subject: [PATCH 01/47] add lambda support --- Cargo.lock | 1 + .../examples/custom_file_casts.rs | 11 +- .../examples/default_column_values.rs | 14 +- datafusion-examples/examples/expr_api.rs | 8 +- .../examples/json_shredding.rs | 14 +- datafusion/catalog-listing/src/helpers.rs | 9 +- datafusion/common/src/column.rs | 6 + datafusion/common/src/cse.rs | 22 +- datafusion/common/src/dfschema.rs | 16 +- datafusion/common/src/lib.rs | 2 + datafusion/common/src/utils/mod.rs | 127 +++- .../core/src/execution/session_state.rs | 5 +- datafusion/core/tests/parquet/mod.rs | 2 +- .../core/tests/parquet/schema_adapter.rs | 8 +- .../datasource-parquet/src/row_filter.rs | 16 +- datafusion/expr/src/expr.rs | 69 +- datafusion/expr/src/expr_rewriter/mod.rs | 49 +- datafusion/expr/src/expr_rewriter/order_by.rs | 4 + datafusion/expr/src/expr_schema.rs | 60 +- datafusion/expr/src/lib.rs | 7 +- datafusion/expr/src/tree_node.rs | 702 +++++++++++++++++- datafusion/expr/src/udf.rs | 564 +++++++++++++- datafusion/expr/src/utils.rs | 41 +- datafusion/ffi/src/udf/mod.rs | 8 +- datafusion/ffi/src/udf/return_type_args.rs | 9 +- .../functions-nested/src/array_transform.rs | 266 +++++++ .../src/analyzer/function_rewrite.rs | 21 +- .../optimizer/src/analyzer/type_coercion.rs | 98 +-- .../optimizer/src/common_subexpr_eliminate.rs | 23 +- datafusion/optimizer/src/decorrelate.rs | 20 +- .../optimizer/src/optimize_projections/mod.rs | 4 +- datafusion/optimizer/src/push_down_filter.rs | 37 +- .../optimizer/src/scalar_subquery_to_join.rs | 67 +- .../simplify_expressions/expr_simplifier.rs | 105 ++- datafusion/optimizer/src/utils.rs | 4 +- .../src/schema_rewriter.rs | 29 +- datafusion/physical-expr/Cargo.toml | 4 + .../src/async_scalar_function.rs | 2 + .../physical-expr/src/expressions/column.rs | 21 +- 
.../physical-expr/src/expressions/lambda.rs | 139 ++++ .../physical-expr/src/expressions/mod.rs | 2 + datafusion/physical-expr/src/lib.rs | 2 + datafusion/physical-expr/src/physical_expr.rs | 10 +- datafusion/physical-expr/src/planner.rs | 29 +- datafusion/physical-expr/src/projection.rs | 53 +- .../physical-expr/src/scalar_function.rs | 701 ++++++++++++++++- .../physical-expr/src/simplifier/mod.rs | 20 +- .../src/simplifier/unwrap_cast.rs | 12 +- datafusion/physical-expr/src/utils/mod.rs | 21 +- .../src/enforce_sorting/sort_pushdown.rs | 60 +- .../src/projection_pushdown.rs | 55 +- datafusion/physical-plan/src/async_func.rs | 6 +- .../src/joins/stream_join_utils.rs | 29 +- datafusion/physical-plan/src/projection.rs | 61 +- datafusion/proto/src/logical_plan/to_proto.rs | 5 + datafusion/pruning/src/pruning_predicate.rs | 11 +- datafusion/sql/src/expr/function.rs | 28 +- datafusion/sql/src/expr/identifier.rs | 13 + datafusion/sql/src/planner.rs | 30 +- datafusion/sql/src/select.rs | 6 +- datafusion/sql/src/unparser/expr.rs | 17 +- datafusion/sql/src/unparser/plan.rs | 6 +- datafusion/sql/src/unparser/rewrite.rs | 14 +- datafusion/sql/src/unparser/utils.rs | 44 +- datafusion/sql/src/utils.rs | 32 +- datafusion/sqllogictest/test_files/array.slt | 8 +- datafusion/sqllogictest/test_files/lambda.slt | 166 +++++ .../src/logical_plan/producer/expr/mod.rs | 1 + 68 files changed, 3573 insertions(+), 483 deletions(-) create mode 100644 datafusion/functions-nested/src/array_transform.rs create mode 100644 datafusion/physical-expr/src/expressions/lambda.rs create mode 100644 datafusion/sqllogictest/test_files/lambda.slt diff --git a/Cargo.lock b/Cargo.lock index f500265108ff5..4a315ff38f2aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2489,6 +2489,7 @@ dependencies = [ "paste", "petgraph 0.8.3", "rand 0.9.2", + "recursive", "rstest", ] diff --git a/datafusion-examples/examples/custom_file_casts.rs b/datafusion-examples/examples/custom_file_casts.rs index 
4d97ecd91dc64..d8db97d1e0440 100644 --- a/datafusion-examples/examples/custom_file_casts.rs +++ b/datafusion-examples/examples/custom_file_casts.rs @@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion::assert_batches_eq; use datafusion::common::not_impl_err; -use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion::common::tree_node::{Transformed, TransformedResult}; use datafusion::common::{Result, ScalarValue}; use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, @@ -31,7 +31,7 @@ use datafusion::execution::context::SessionContext; use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::parquet::arrow::ArrowWriter; use datafusion::physical_expr::expressions::CastExpr; -use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; use datafusion::prelude::SessionConfig; use datafusion_physical_expr_adapter::{ DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, @@ -181,11 +181,10 @@ impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter { expr = self.inner.rewrite(expr)?; // Now we can apply custom casting rules or even swap out all CastExprs for a custom cast kernel / expression // For example, [DataFusion Comet](https://github.com/apache/datafusion-comet) has a [custom cast kernel](https://github.com/apache/datafusion-comet/blob/b4ac876ab420ed403ac7fc8e1b29f42f1f442566/native/spark-expr/src/conversion_funcs/cast.rs#L133-L138). 
- expr.transform(|expr| { + expr.transform_with_schema(&self.physical_file_schema, |expr, schema| { if let Some(cast) = expr.as_any().downcast_ref::() { - let input_data_type = - cast.expr().data_type(&self.physical_file_schema)?; - let output_data_type = cast.data_type(&self.physical_file_schema)?; + let input_data_type = cast.expr().data_type(schema)?; + let output_data_type = cast.data_type(schema)?; if !cast.is_bigger_cast(&input_data_type) { return not_impl_err!( "Unsupported CAST from {input_data_type} to {output_data_type}" diff --git a/datafusion-examples/examples/default_column_values.rs b/datafusion-examples/examples/default_column_values.rs index d3a7d2ec67f3c..0d00d2c3af827 100644 --- a/datafusion-examples/examples/default_column_values.rs +++ b/datafusion-examples/examples/default_column_values.rs @@ -26,8 +26,8 @@ use async_trait::async_trait; use datafusion::assert_batches_eq; use datafusion::catalog::memory::DataSourceExec; use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion::common::DFSchema; +use datafusion::common::tree_node::{Transformed, TransformedResult}; +use datafusion::common::{DFSchema, HashSet}; use datafusion::common::{Result, ScalarValue}; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; @@ -38,7 +38,7 @@ use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType}; use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_expr::expressions::{CastExpr, Column, Literal}; -use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::{lit, SessionConfig}; use datafusion_physical_expr_adapter::{ @@ -308,11 +308,12 @@ impl PhysicalExprAdapter for 
DefaultValuePhysicalExprAdapter { fn rewrite(&self, expr: Arc) -> Result> { // First try our custom default value injection for missing columns let rewritten = expr - .transform(|expr| { + .transform_with_lambdas_params(|expr, lambdas_params| { self.inject_default_values( expr, &self.logical_file_schema, &self.physical_file_schema, + lambdas_params, ) }) .data()?; @@ -348,12 +349,15 @@ impl DefaultValuePhysicalExprAdapter { expr: Arc, logical_file_schema: &Schema, physical_file_schema: &Schema, + lambdas_params: &HashSet, ) -> Result>> { if let Some(column) = expr.as_any().downcast_ref::() { let column_name = column.name(); // Check if this column exists in the physical schema - if physical_file_schema.index_of(column_name).is_err() { + if !lambdas_params.contains(column_name) + && physical_file_schema.index_of(column_name).is_err() + { // Column is missing from physical schema, check if logical schema has a default if let Ok(logical_field) = logical_file_schema.field_with_name(column_name) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 56f960870e58a..29f074e2b400c 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -23,7 +23,7 @@ use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::common::stats::Precision; -use datafusion::common::tree_node::{Transformed, TreeNode}; +use datafusion::common::tree_node::Transformed; use datafusion::common::{ColumnStatistics, DFSchema}; use datafusion::common::{ScalarValue, ToDFSchema}; use datafusion::error::Result; @@ -556,7 +556,7 @@ fn type_coercion_demo() -> Result<()> { // 3. Type coercion with `TypeCoercionRewriter`. let coerced_expr = expr .clone() - .rewrite(&mut TypeCoercionRewriter::new(&df_schema))? + .rewrite_with_schema(&df_schema, &mut TypeCoercionRewriter::new(&df_schema))? 
.data; let physical_expr = datafusion::physical_expr::create_physical_expr( &coerced_expr, @@ -567,7 +567,7 @@ fn type_coercion_demo() -> Result<()> { // 4. Apply explicit type coercion by manually rewriting the expression let coerced_expr = expr - .transform(|e| { + .transform_with_schema(&df_schema, |e, df_schema| { // Only type coerces binary expressions. let Expr::BinaryExpr(e) = e else { return Ok(Transformed::no(e)); @@ -575,7 +575,7 @@ fn type_coercion_demo() -> Result<()> { if let Expr::Column(ref col_expr) = *e.left { let field = df_schema.field_with_name(None, col_expr.name())?; let cast_to_type = field.data_type(); - let coerced_right = e.right.cast_to(cast_to_type, &df_schema)?; + let coerced_right = e.right.cast_to(cast_to_type, df_schema)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( e.left, e.op, diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs index 5ef8b59b64200..e97f27b818d8d 100644 --- a/datafusion-examples/examples/json_shredding.rs +++ b/datafusion-examples/examples/json_shredding.rs @@ -22,10 +22,8 @@ use arrow::array::{RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion::assert_batches_eq; -use datafusion::common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, -}; -use datafusion::common::{assert_contains, exec_datafusion_err, Result}; +use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNodeRecursion}; +use datafusion::common::{assert_contains, exec_datafusion_err, HashSet, Result}; use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, }; @@ -36,8 +34,8 @@ use datafusion::logical_expr::{ }; use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; -use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_expr::{expressions, ScalarFunctionExpr}; 
+use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; use datafusion::prelude::SessionConfig; use datafusion::scalar::ScalarValue; use datafusion_physical_expr_adapter::{ @@ -302,7 +300,9 @@ impl PhysicalExprAdapter for ShreddedJsonRewriter { fn rewrite(&self, expr: Arc) -> Result> { // First try our custom JSON shredding rewrite let rewritten = expr - .transform(|expr| self.rewrite_impl(expr, &self.physical_file_schema)) + .transform_with_lambdas_params(|expr, lambdas_params| { + self.rewrite_impl(expr, &self.physical_file_schema, lambdas_params) + }) .data()?; // Then apply the default adapter as a fallback to handle standard schema differences @@ -335,6 +335,7 @@ impl ShreddedJsonRewriter { &self, expr: Arc, physical_file_schema: &Schema, + lambdas_params: &HashSet, ) -> Result>> { if let Some(func) = expr.as_any().downcast_ref::() { if func.name() == "json_get_str" && func.args().len() == 2 { @@ -348,6 +349,7 @@ impl ShreddedJsonRewriter { if let Some(column) = func.args()[1] .as_any() .downcast_ref::() + .filter(|col| !lambdas_params.contains(col.name())) { let column_name = column.name(); // Check if there's a flat column with underscore prefix diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 82cc36867939e..444f505f4280b 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -52,9 +52,9 @@ use object_store::{ObjectMeta, ObjectStore}; /// was performed pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { let mut is_applicable = true; - expr.apply(|expr| match expr { - Expr::Column(Column { ref name, .. 
}) => { - is_applicable &= col_names.contains(&name.as_str()); + expr.apply_with_lambdas_params(|expr, lambdas_params| match expr { + Expr::Column(col) => { + is_applicable &= col_names.contains(&col.name()) || col.is_lambda_parameter(lambdas_params); if is_applicable { Ok(TreeNodeRecursion::Jump) } else { @@ -86,7 +86,8 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::GroupingSet(_) - | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), + | Expr::Case(_) + | Expr::Lambda(_) => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match scalar_function.func.signature().volatility { diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index c7f0b5a4f4881..dd9b985e6485c 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -22,6 +22,7 @@ use crate::utils::parse_identifiers_normalized; use crate::utils::quote_identifier; use crate::{DFSchema, Diagnostic, Result, SchemaError, Spans, TableReference}; use arrow::datatypes::{Field, FieldRef}; +use std::borrow::Borrow; use std::collections::HashSet; use std::fmt; @@ -325,6 +326,11 @@ impl Column { ..self.clone() } } + + pub fn is_lambda_parameter(&self, lambdas_params: &crate::HashSet + Eq + std::hash::Hash>) -> bool { + // currently, references to lambda parameters are always unqualified + self.relation.is_none() && lambdas_params.contains(self.name()) + } } impl From<&str> for Column { diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs index 674d3386171f8..a7ffde52c93b2 100644 --- a/datafusion/common/src/cse.rs +++ b/datafusion/common/src/cse.rs @@ -178,6 +178,14 @@ pub trait CSEController { /// if all are always evaluated. fn conditional_children(node: &Self::Node) -> Option>; + // A helper method called on each node before is_ignored, during top-down traversal during the first, + // visiting traversal of CSE. 
+ fn visit_f_down(&mut self, _node: &Self::Node) {} + + // A helper method called on each node after is_ignored, during bottom-up traversal during the first, + // visiting traversal of CSE. + fn visit_f_up(&mut self, _node: &Self::Node) {} + // Returns true if a node is valid. If a node is invalid then it can't be eliminated. // Validity is propagated up which means no subtree can be eliminated that contains // an invalid node. @@ -274,7 +282,7 @@ where /// thus can not be extracted as a common [`TreeNode`]. conditional: bool, - controller: &'a C, + controller: &'a mut C, } /// Record item that used when traversing a [`TreeNode`] tree. @@ -352,6 +360,7 @@ where self.visit_stack .push(VisitRecord::EnterMark(self.down_index)); self.down_index += 1; + self.controller.visit_f_down(node); // If a node can short-circuit then some of its children might not be executed so // count the occurrence either normal or conditional. @@ -414,6 +423,7 @@ where self.visit_stack .push(VisitRecord::NodeItem(node_id, is_valid)); self.up_index += 1; + self.controller.visit_f_up(node); Ok(TreeNodeRecursion::Continue) } @@ -532,7 +542,7 @@ where /// Add an identifier to `id_array` for every [`TreeNode`] in this tree. fn node_to_id_array<'n>( - &self, + &mut self, node: &'n N, node_stats: &mut NodeStats<'n, N>, id_array: &mut IdArray<'n, N>, @@ -546,7 +556,7 @@ where random_state: &self.random_state, found_common: false, conditional: false, - controller: &self.controller, + controller: &mut self.controller, }; node.visit(&mut visitor)?; @@ -561,7 +571,7 @@ where /// Each element is itself the result of [`CSE::node_to_id_array`] for that node /// (e.g. 
the identifiers for each node in the tree) fn to_arrays<'n>( - &self, + &mut self, nodes: &'n [N], node_stats: &mut NodeStats<'n, N>, ) -> Result<(bool, Vec>)> { @@ -761,7 +771,7 @@ mod test { #[test] fn id_array_visitor() -> Result<()> { let alias_generator = AliasGenerator::new(); - let eliminator = CSE::new(TestTreeNodeCSEController::new( + let mut eliminator = CSE::new(TestTreeNodeCSEController::new( &alias_generator, TestTreeNodeMask::Normal, )); @@ -853,7 +863,7 @@ mod test { assert_eq!(expected, id_array); // include aggregates - let eliminator = CSE::new(TestTreeNodeCSEController::new( + let mut eliminator = CSE::new(TestTreeNodeCSEController::new( &alias_generator, TestTreeNodeMask::NormalAndAggregates, )); diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 24d152a7dba8c..8a09d61292b27 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -314,8 +314,10 @@ impl DFSchema { return; } - let self_fields: HashSet<(Option<&TableReference>, &FieldRef)> = - self.iter().collect(); + let self_fields: HashSet<(Option<&TableReference>, &str)> = self + .iter() + .map(|(qualifier, field)| (qualifier, field.name().as_str())) + .collect(); let self_unqualified_names: HashSet<&str> = self .inner .fields @@ -328,7 +330,10 @@ impl DFSchema { for (qualifier, field) in other_schema.iter() { // skip duplicate columns let duplicated_field = match qualifier { - Some(q) => self_fields.contains(&(Some(q), field)), + Some(q) => { + self_fields.contains(&(Some(q), field.name().as_str())) + || self_fields.contains(&(None, field.name().as_str())) + } // for unqualified columns, check as unqualified name None => self_unqualified_names.contains(field.name().as_str()), }; @@ -867,6 +872,11 @@ impl DFSchema { &self.functional_dependencies } + /// Get functional dependencies + pub fn field_qualifiers(&self) -> &[Option] { + &self.field_qualifiers + } + /// Iterate over the qualifiers and fields in the DFSchema pub fn 
iter(&self) -> impl Iterator, &FieldRef)> { self.field_qualifiers diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 76c7b46e32737..8923df683f899 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -117,6 +117,8 @@ pub mod hash_set { pub use hashbrown::hash_set::Entry; } +pub use hashbrown; + /// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is /// not possible. In normal usage of DataFusion the downcast should always succeed. /// diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 7b145ac3ae21d..ec2dad505a561 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -22,15 +22,20 @@ pub mod memory; pub mod proxy; pub mod string_utils; -use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err}; +use crate::error::{ + _exec_datafusion_err, _exec_err, _internal_datafusion_err, _internal_err, +}; use crate::{Result, ScalarValue}; use arrow::array::{ cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, }; +use arrow::array::{ArrowPrimitiveType, PrimitiveArray}; use arrow::buffer::OffsetBuffer; use arrow::compute::{partition, SortColumn, SortOptions}; -use arrow::datatypes::{DataType, Field, SchemaRef}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Field, Int32Type, Int64Type, SchemaRef, +}; #[cfg(feature = "sql")] use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; use std::borrow::{Borrow, Cow}; @@ -939,6 +944,124 @@ pub fn take_function_args( }) } +/// [0, 2, 2, 5, 6] -> [0, 0, 2, 2, 2, 3] +pub fn make_list_array_indices( + offsets: &OffsetBuffer, +) -> PrimitiveArray { + let mut indices = Vec::with_capacity( + offsets.last().unwrap().as_usize() - offsets.first().unwrap().as_usize(), + ); + + for (i, (&start, &end)) in std::iter::zip(&offsets[..], &offsets[1..]).enumerate() { + 
indices.extend(std::iter::repeat_n( + T::Native::usize_as(i), + end.as_usize() - start.as_usize(), + )); + } + + PrimitiveArray::new(indices.into(), None) +} + +/// [0, 2, 2, 5, 6] -> [0, 1, 0, 1, 2, 0] +pub fn make_list_element_indices( + offsets: &OffsetBuffer, +) -> PrimitiveArray { + let mut indices = vec![ + T::default_value(); + offsets.last().unwrap().as_usize() + - offsets.first().unwrap().as_usize() + ]; + + for (&start, &end) in std::iter::zip(&offsets[..], &offsets[1..]) { + for i in 0..end.as_usize() - start.as_usize() { + indices[start.as_usize() + i] = T::Native::usize_as(i); + } + } + + PrimitiveArray::new(indices.into(), None) +} + +/// (3, 2) -> [0, 0, 1, 1, 2, 2] +pub fn make_fsl_array_indices( + list_size: i32, + array_len: usize, +) -> PrimitiveArray { + let mut indices = vec![0; list_size as usize * array_len]; + + for i in 0..array_len { + for j in 0..list_size as usize { + indices[i + j] = i as i32; + } + } + + PrimitiveArray::new(indices.into(), None) +} + +/// (3, 2) -> [0, 1, 0, 1, 0, 1] +pub fn make_fsl_element_indices( + list_size: i32, + array_len: usize, +) -> PrimitiveArray { + let mut indices = vec![0; list_size as usize * array_len]; + + for i in 0..array_len { + for j in 0..list_size as usize { + indices[i + j] = j as i32; + } + } + + PrimitiveArray::new(indices.into(), None) +} + +pub fn list_values(array: &dyn Array) -> Result<&ArrayRef> { + match array.data_type() { + DataType::List(_) => Ok(array.as_list::().values()), + DataType::LargeList(_) => Ok(array.as_list::().values()), + DataType::FixedSizeList(_, _) => Ok(array.as_fixed_size_list().values()), + other => _exec_err!("expected list, got {other}"), + } +} + +pub fn list_indices(array: &dyn Array) -> Result { + match array.data_type() { + DataType::List(_) => Ok(Arc::new(make_list_array_indices::( + array.as_list().offsets(), + ))), + DataType::LargeList(_) => Ok(Arc::new(make_list_array_indices::( + array.as_list().offsets(), + ))), + DataType::FixedSizeList(_, _) => { + 
let fixed_size_list = array.as_fixed_size_list(); + + Ok(Arc::new(make_fsl_array_indices( + fixed_size_list.value_length(), + fixed_size_list.len(), + ))) + } + other => _exec_err!("expected list, got {other}"), + } +} + +pub fn elements_indices(array: &dyn Array) -> Result { + match array.data_type() { + DataType::List(_) => Ok(Arc::new(make_list_element_indices::( + array.as_list::().offsets(), + ))), + DataType::LargeList(_) => Ok(Arc::new(make_list_element_indices::( + array.as_list::().offsets(), + ))), + DataType::FixedSizeList(_, _) => { + let fixed_size_list = array.as_fixed_size_list(); + + Ok(Arc::new(make_fsl_element_indices( + fixed_size_list.value_length(), + fixed_size_list.len(), + ))) + } + other => _exec_err!("expected list, got {other}"), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index c15b7eae08432..ad4ffb487ee1d 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -41,7 +41,6 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::config::Dialect; use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; -use datafusion_common::tree_node::TreeNode; use datafusion_common::{ config_err, exec_err, plan_datafusion_err, DFSchema, DataFusionError, ResolvedTableReference, TableReference, @@ -701,7 +700,9 @@ impl SessionState { let config_options = self.config_options(); for rewrite in self.analyzer.function_rewrites() { expr = expr - .transform_up(|expr| rewrite.rewrite(expr, df_schema, config_options))? + .transform_up_with_schema(df_schema, |expr, df_schema| { + rewrite.rewrite(expr, df_schema, config_options) + })? 
.data; } create_physical_expr(&expr, df_schema, self.execution_props()) diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 097600e45eadd..eea6085c02b9f 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -516,7 +516,7 @@ fn make_uint_batches(start: u8, end: u8) -> RecordBatch { Field::new("u64", DataType::UInt64, true), ])); let v8: Vec = (start..end).collect(); - let v16: Vec = (start as _..end as _).collect(); + let v16: Vec = (start as u16..end as _).collect(); let v32: Vec = (start as _..end as _).collect(); let v64: Vec = (start as _..end as _).collect(); RecordBatch::try_new( diff --git a/datafusion/core/tests/parquet/schema_adapter.rs b/datafusion/core/tests/parquet/schema_adapter.rs index 40fc6176e212b..dfa4c91ba5dd8 100644 --- a/datafusion/core/tests/parquet/schema_adapter.rs +++ b/datafusion/core/tests/parquet/schema_adapter.rs @@ -27,7 +27,7 @@ use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult}; use datafusion_common::DataFusionError; use datafusion_common::{ColumnStatistics, ScalarValue}; use datafusion_datasource::file::FileSource; @@ -39,7 +39,7 @@ use datafusion_datasource::ListingTableUrl; use datafusion_datasource_parquet::source::ParquetSource; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_physical_expr::expressions::{self, Column}; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::{PhysicalExpr, PhysicalExprExt}; use datafusion_physical_expr_adapter::{ DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, @@ -181,10 +181,10 @@ struct CustomPhysicalExprAdapter { impl PhysicalExprAdapter for 
CustomPhysicalExprAdapter { fn rewrite(&self, mut expr: Arc) -> Result> { expr = expr - .transform(|expr| { + .transform_with_lambdas_params(|expr, lambdas_params| { if let Some(column) = expr.as_any().downcast_ref::() { let field_name = column.name(); - if self + if !lambdas_params.contains(field_name) && self .physical_file_schema .field_with_name(field_name) .ok() diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index 660b32f486120..45441ad71086c 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -77,7 +77,7 @@ use datafusion_common::Result; use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaMapper}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::reassign_expr_columns; -use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; +use datafusion_physical_expr::{split_conjunction, PhysicalExpr, PhysicalExprExt}; use datafusion_physical_plan::metrics; @@ -336,6 +336,20 @@ impl<'schema> PushdownChecker<'schema> { fn prevents_pushdown(&self) -> bool { self.non_primitive_columns || self.projected_columns } + + fn check(&mut self, node: Arc) -> Result { + node.apply_with_lambdas_params(|node, lamdas_params| { + if let Some(column) = node.as_any().downcast_ref::() { + if !lamdas_params.contains(column.name()) { + if let Some(recursion) = self.check_single_column(column.name()) { + return Ok(recursion); + } + } + } + + Ok(TreeNodeRecursion::Continue) + }) + } } impl TreeNodeVisitor<'_> for PushdownChecker<'_> { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 13160d573ab4d..e2845ea5a7de8 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -398,6 +398,10 @@ pub enum Expr { OuterReferenceColumn(FieldRef, Column), /// Unnest expression Unnest(Unnest), + /// Lambda expression, valid only as a scalar function argument + /// Note 
that it has it's own scoped schema, different from the plan schema, + /// that can be constructed with ScalarUDF::arguments_schemas and variants + Lambda(Lambda), } impl Default for Expr { @@ -1211,6 +1215,23 @@ impl GroupingSet { } } +/// Lambda expression. +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct Lambda { + pub params: Vec, + pub body: Box, +} + +impl Lambda { + /// Create a new lambda expression + pub fn new(params: Vec, body: Expr) -> Self { + Self { + params, + body: Box::new(body), + } + } +} + #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] #[cfg(not(feature = "sql"))] pub struct IlikeSelectItem { @@ -1525,6 +1546,7 @@ impl Expr { #[expect(deprecated)] Expr::Wildcard { .. } => "Wildcard", Expr::Unnest { .. } => "Unnest", + Expr::Lambda { .. } => "Lambda", } } @@ -1908,9 +1930,11 @@ impl Expr { /// /// See [`Self::column_refs`] for details pub fn add_column_refs<'a>(&'a self, set: &mut HashSet<&'a Column>) { - self.apply(|expr| { + self.apply_with_lambdas_params(|expr, lambdas_params| { if let Expr::Column(col) = expr { - set.insert(col); + if col.relation.is_some() || !lambdas_params.contains(col.name()) { + set.insert(col); + } } Ok(TreeNodeRecursion::Continue) }) @@ -1943,9 +1967,11 @@ impl Expr { /// /// See [`Self::column_refs_counts`] for details pub fn add_column_ref_counts<'a>(&'a self, map: &mut HashMap<&'a Column, usize>) { - self.apply(|expr| { + self.apply_with_lambdas_params(|expr, lambdas_params| { if let Expr::Column(col) = expr { - *map.entry(col).or_default() += 1; + if !col.is_lambda_parameter(lambdas_params) { + *map.entry(col).or_default() += 1; + } } Ok(TreeNodeRecursion::Continue) }) @@ -1954,8 +1980,10 @@ impl Expr { /// Returns true if there are any column references in this Expr pub fn any_column_refs(&self) -> bool { - self.exists(|expr| Ok(matches!(expr, Expr::Column(_)))) - .expect("exists closure is infallible") + self.exists_with_lambdas_params(|expr, lambdas_params| { + Ok(matches!(expr, 
Expr::Column(c) if !c.is_lambda_parameter(lambdas_params))) + }) + .expect("exists closure is infallible") } /// Return true if the expression contains out reference(correlated) expressions. @@ -1995,7 +2023,7 @@ impl Expr { /// at least one placeholder. pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> { let mut has_placeholder = false; - self.transform(|mut expr| { + self.transform_with_schema(schema, |mut expr, schema| { match &mut expr { // Default to assuming the arguments are the same type Expr::BinaryExpr(BinaryExpr { left, op: _, right }) => { @@ -2078,7 +2106,8 @@ impl Expr { | Expr::Wildcard { .. } | Expr::WindowFunction(..) | Expr::Literal(..) - | Expr::Placeholder(..) => false, + | Expr::Placeholder(..) + | Expr::Lambda { .. } => false, } } @@ -2674,6 +2703,12 @@ impl HashNode for Expr { column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} + Expr::Lambda(Lambda { + params, + body: _, + }) => { + params.hash(state); + } }; } } @@ -2987,6 +3022,12 @@ impl Display for SchemaDisplay<'_> { } } } + Expr::Lambda(Lambda { + params, + body, + }) => { + write!(f, "({}) -> {body}", display_comma_separated(params)) + } } } } @@ -3167,6 +3208,12 @@ impl Display for SqlDisplay<'_> { } } } + Expr::Lambda(Lambda { + params, + body, + }) => { + write!(f, "({}) -> {}", params.join(", "), SchemaDisplay(body)) + } _ => write!(f, "{}", self.0), } } @@ -3474,6 +3521,12 @@ impl Display for Expr { Expr::Unnest(Unnest { expr }) => { write!(f, "{UNNEST_COLUMN_PREFIX}({expr})") } + Expr::Lambda(Lambda { + params, + body, + }) => { + write!(f, "({}) -> {body}", params.join(", ")) + } } } } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 9c3c5df7007ff..81ec6e7acbe38 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -62,11 +62,15 @@ pub trait FunctionRewrite: Debug { /// Recursively call `LogicalPlanBuilder::normalize` on all 
[`Column`] expressions /// in the `expr` expression tree. pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - expr.transform(|expr| { + expr.transform_up_with_lambdas_params(|expr, lambdas_params| { Ok({ if let Expr::Column(c) = expr { - let col = LogicalPlanBuilder::normalize(plan, c)?; - Transformed::yes(Expr::Column(col)) + if c.relation.is_some() || !lambdas_params.contains(c.name()) { + let col = LogicalPlanBuilder::normalize(plan, c)?; + Transformed::yes(Expr::Column(col)) + } else { + Transformed::no(Expr::Column(c)) + } } else { Transformed::no(expr) } @@ -91,14 +95,21 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( return Ok(Expr::Unnest(Unnest { expr: Box::new(e) })); } - expr.transform(|expr| { + expr.transform_up_with_lambdas_params(|expr, lambdas_params| { Ok({ - if let Expr::Column(c) = expr { - let col = - c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?; - Transformed::yes(Expr::Column(col)) - } else { - Transformed::no(expr) + match expr { + Expr::Column(c) => { + if c.relation.is_none() && lambdas_params.contains(c.name()) { + Transformed::no(Expr::Column(c)) + } else { + let col = c.normalize_with_schemas_and_ambiguity_check( + schemas, + using_columns, + )?; + Transformed::yes(Expr::Column(col)) + } + } + _ => Transformed::no(expr), } }) }) @@ -133,15 +144,18 @@ pub fn normalize_sorts( /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. 
pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - expr.transform(|expr| { + expr.transform_up_with_lambdas_params(|expr, lambdas_params| { Ok({ - if let Expr::Column(c) = &expr { - match replace_map.get(c) { - Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())), - None => Transformed::no(expr), + match &expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { + match replace_map.get(c) { + Some(new_c) => { + Transformed::yes(Expr::Column((*new_c).to_owned())) + } + None => Transformed::no(expr), + } } - } else { - Transformed::no(expr) + _ => Transformed::no(expr), } }) }) @@ -201,6 +215,7 @@ pub fn strip_outer_reference(expr: Expr) -> Expr { expr.transform(|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { + //todo: what if this col collides with a lambda parameter? Transformed::yes(Expr::Column(col)) } else { Transformed::no(expr) diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 6db95555502da..b94c632ce74b3 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -77,6 +77,10 @@ fn rewrite_in_terms_of_projection( // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" expr.transform(|expr| { + if matches!(expr, Expr::Lambda(_)) { + return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) + } + // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { let (qualifier, field_name) = found.qualified_name(); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 9e8d6080b82c8..4a1efadccd0ec 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -16,19 +16,24 @@ // under the License. 
use super::{Between, Expr, Like}; +use crate::expr::FieldMetadata; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, - InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, + InSubquery, Lambda, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; use crate::type_coercion::functions::{ - data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf, + fields_with_aggregate_udf, fields_with_window_udf, +}; +use crate::{ + type_coercion::functions::data_types_with_scalar_udf, udf::ReturnFieldArgs, utils, + LogicalPlan, Projection, Subquery, WindowFunctionDefinition, +}; +use arrow::datatypes::FieldRef; +use arrow::{ + compute::can_cast_types, + datatypes::{DataType, Field}, }; -use crate::udf::ReturnFieldArgs; -use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; -use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, Spans, TableReference, @@ -229,6 +234,7 @@ impl ExprSchemable for Expr { // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } + Expr::Lambda { .. } => Ok(DataType::Null), } } @@ -347,6 +353,7 @@ impl ExprSchemable for Expr { // in projections Ok(true) } + Expr::Lambda { .. 
} => Ok(false), } } @@ -535,14 +542,31 @@ impl ExprSchemable for Expr { func.return_field(&new_fields) } + // Expr::Lambda(Lambda { params, body}) => body.to_field(schema), Expr::ScalarFunction(ScalarFunction { func, args }) => { - let (arg_types, fields): (Vec, Vec>) = args + let fields = if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) { + let lambdas_schemas = func.arguments_expr_schema(args, schema)?; + + std::iter::zip(args, lambdas_schemas) + // .map(|(e, schema)| e.to_field(schema).map(|(_, f)| f)) + .map(|(e, schema)| match e { + Expr::Lambda(Lambda { params: _, body }) => { + body.to_field(&schema).map(|(_, f)| f) + } + _ => e.to_field(&schema).map(|(_, f)| f), + }) + .collect::>>()? + } else { + args.iter() + .map(|e| e.to_field(schema).map(|(_, f)| f)) + .collect::>>()? + }; + + let arg_types = fields .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()? - .into_iter() - .map(|f| (f.data_type().clone(), f)) - .unzip(); + .map(|f| f.data_type().clone()) + .collect::>(); + // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_types, func) .map_err(|err| { @@ -573,9 +597,16 @@ impl ExprSchemable for Expr { _ => None, }) .collect::>(); + + let lambdas = args + .iter() + .map(|e| matches!(e, Expr::Lambda { .. })) + .collect::>(); + let args = ReturnFieldArgs { arg_fields: &new_fields, scalar_arguments: &arguments, + lambdas: &lambdas, }; func.return_field_from_args(args) @@ -600,7 +631,8 @@ impl ExprSchemable for Expr { | Expr::Wildcard { .. 
} | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::Unnest(_) => Ok(Arc::new(Field::new( + | Expr::Unnest(_) + | Expr::Lambda(_) => Ok(Arc::new(Field::new( &schema_name, self.get_type(schema)?, self.nullable(schema)?, diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 2b7cc9d46ad34..46c7422814ace 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -117,7 +117,12 @@ pub use udaf::{ udaf_default_window_function_schema_name, AggregateUDF, AggregateUDFImpl, ReversedUDAF, SetMonotonicity, StatisticsArgs, }; -pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +pub use udf::{ + merge_captures_with_args, merge_captures_with_boxed_lazy_args, + merge_captures_with_lazy_args, ReturnFieldArgs, ScalarFunctionArgs, + ScalarFunctionLambdaArg, ScalarUDF, ScalarUDFImpl, ValueOrLambda, ValueOrLambdaField, + ValueOrLambdaParameter, +}; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 81846b4f80608..63c535b43ee8b 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -17,17 +17,20 @@ //! 
Tree node implementation for Logical Expressions -use crate::expr::{ - AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast, - GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, - WindowFunction, WindowFunctionParams, +use crate::{ + expr::{ + AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, + Cast, GroupingSet, InList, InSubquery, Lambda, Like, Placeholder, ScalarFunction, + TryCast, Unnest, WindowFunction, WindowFunctionParams, + }, + Expr, }; -use crate::Expr; - -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, +use datafusion_common::{ + tree_node::{ + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, + }, + DFSchema, HashSet, Result, }; -use datafusion_common::Result; /// Implementation of the [`TreeNode`] trait /// @@ -106,6 +109,7 @@ impl TreeNode for Expr { Expr::InList(InList { expr, list, .. }) => { (expr, list).apply_ref_elements(f) } + Expr::Lambda (Lambda{ params: _, body}) => body.apply_elements(f) } } @@ -311,6 +315,686 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_list)| { Expr::InList(InList::new(new_expr, new_list, negated)) }), + Expr::Lambda(Lambda { params, body }) => body + .map_elements(f)? + .update_data(|body| Expr::Lambda(Lambda { params, body })), + }) + } +} + +impl Expr { + /// Similarly to [`Self::rewrite`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + pub fn rewrite_with_schema< + R: for<'a> TreeNodeRewriterWithPayload = &'a DFSchema>, + >( + self, + schema: &DFSchema, + rewriter: &mut R, + ) -> Result> { + rewriter + .f_down(self, schema)? 
+ .transform_children(|n| match &n { + Expr::ScalarFunction(ScalarFunction { func, args }) + if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => + { + let mut lambdas_schemas = func + .arguments_schema_from_logical_args(args, schema)? + .into_iter(); + + n.map_children(|n| { + n.rewrite_with_schema(&lambdas_schemas.next().unwrap(), rewriter) + }) + } + _ => n.map_children(|n| n.rewrite_with_schema(schema, rewriter)), + })? + .transform_parent(|n| rewriter.f_up(n, schema)) + } + + /// Similarly to [`Self::rewrite`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. + pub fn rewrite_with_lambdas_params< + R: for<'a> TreeNodeRewriterWithPayload< + Node = Expr, + Payload<'a> = &'a HashSet, + >, + >( + self, + rewriter: &mut R, + ) -> Result> { + self.rewrite_with_lambdas_params_impl(&HashSet::new(), rewriter) + } + + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn rewrite_with_lambdas_params_impl< + R: for<'a> TreeNodeRewriterWithPayload< + Node = Expr, + Payload<'a> = &'a HashSet, + >, + >( + self, + args: &HashSet, + rewriter: &mut R, + ) -> Result> { + rewriter + .f_down(self, args)? + .transform_children(|n| match n { + Expr::Lambda(Lambda { + ref params, + body: _, + }) => { + let mut args = args.clone(); + + args.extend(params.iter().cloned()); + + n.map_children(|n| { + n.rewrite_with_lambdas_params_impl(&args, rewriter) + }) + } + _ => { + n.map_children(|n| n.rewrite_with_lambdas_params_impl(args, rewriter)) + } + })? + .transform_parent(|n| rewriter.f_up(n, args)) + } + + /// Similarly to [`Self::map_children`], rewrites all lambdas that may + /// appear in expressions such as `array_transform([1, 2], v -> v*2)`. + /// + /// Returns the current node. 
+ pub fn map_children_with_lambdas_params< + F: FnMut(Self, &HashSet) -> Result>, + >( + self, + args: &HashSet, + mut f: F, + ) -> Result> { + match &self { + Expr::Lambda(Lambda { params, body: _ }) => { + let mut args = args.clone(); + + args.extend(params.iter().cloned()); + + self.map_children(|expr| f(expr, &args)) + } + _ => self.map_children(|expr| f(expr, args)), + } + } + + /// Similarly to [`Self::transform_up`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. + pub fn transform_up_with_lambdas_params< + F: FnMut(Self, &HashSet) -> Result>, + >( + self, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_up_with_lambdas_params_impl< + F: FnMut(Expr, &HashSet) -> Result>, + >( + node: Expr, + args: &HashSet, + f: &mut F, + ) -> Result> { + node.map_children_with_lambdas_params(args, |node, args| { + transform_up_with_lambdas_params_impl(node, args, f) + })? + .transform_parent(|node| f(node, args)) + /*match &node { + Expr::Lambda(Lambda { params, body: _ }) => { + let mut args = args.clone(); + + args.extend(params.iter().cloned()); + + node.map_children(|n| { + transform_up_with_lambdas_params_impl(n, &args, f) + })? + .transform_parent(|n| f(n, &args)) + } + _ => node + .map_children(|n| transform_up_with_lambdas_params_impl(n, args, f))? + .transform_parent(|n| f(n, args)), + }*/ + } + + transform_up_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + /// Similarly to [`Self::transform_down`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. 
+ pub fn transform_down_with_lambdas_params< + F: FnMut(Self, &HashSet) -> Result>, + >( + self, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_down_with_lambdas_params_impl< + F: FnMut(Expr, &HashSet) -> Result>, + >( + node: Expr, + args: &HashSet, + f: &mut F, + ) -> Result> { + f(node, args)?.transform_children(|node| { + node.map_children_with_lambdas_params(args, |node, args| { + transform_down_with_lambdas_params_impl(node, args, f) + }) + }) + } + + transform_down_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + pub fn apply_with_lambdas_params< + 'n, + F: FnMut(&'n Self, &HashSet<&'n str>) -> Result, + >( + &'n self, + mut f: F, + ) -> Result { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn apply_with_lambdas_params_impl< + 'n, + F: FnMut(&'n Expr, &HashSet<&'n str>) -> Result, + >( + node: &'n Expr, + args: &HashSet<&'n str>, + f: &mut F, + ) -> Result { + match node { + Expr::Lambda(Lambda { params, body: _ }) => { + let mut args = args.clone(); + + args.extend(params.iter().map(|v| v.as_str())); + + f(node, &args)?.visit_children(|| { + node.apply_children(|c| { + apply_with_lambdas_params_impl(c, &args, f) + }) + }) + } + _ => f(node, args)?.visit_children(|| { + node.apply_children(|c| apply_with_lambdas_params_impl(c, args, f)) + }), + } + } + + apply_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + /// Similarly to [`Self::transform`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. 
+ pub fn transform_with_schema< + F: FnMut(Self, &DFSchema) -> Result>, + >( + self, + schema: &DFSchema, + f: F, + ) -> Result> { + self.transform_up_with_schema(schema, f) + } + + /// Similarly to [`Self::transform_up`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. + pub fn transform_up_with_schema< + F: FnMut(Self, &DFSchema) -> Result>, + >( + self, + schema: &DFSchema, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_up_with_schema_impl< + F: FnMut(Expr, &DFSchema) -> Result>, + >( + node: Expr, + schema: &DFSchema, + f: &mut F, + ) -> Result> { + node.map_children_with_schema(schema, |n, schema| { + transform_up_with_schema_impl(n, schema, f) + })? + .transform_parent(|n| f(n, schema)) + } + + transform_up_with_schema_impl(self, schema, &mut f) + } + + pub fn map_children_with_schema< + F: FnMut(Self, &DFSchema) -> Result>, + >( + self, + schema: &DFSchema, + mut f: F, + ) -> Result> { + match self { + Expr::ScalarFunction(ref fun) + if fun.args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => + { + let mut args_schemas = fun + .func + .arguments_schema_from_logical_args(&fun.args, schema)? + .into_iter(); + + self.map_children(|expr| f(expr, &args_schemas.next().unwrap())) + } + _ => self.map_children(|expr| f(expr, schema)), + } + } + + pub fn exists_with_lambdas_params) -> Result>( + &self, + mut f: F, + ) -> Result { + let mut found = false; + + self.apply_with_lambdas_params(|n, lambdas_params| { + if f(n, lambdas_params)? { + found = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + })?; + + Ok(found) + } +} + +pub trait ExprWithLambdasRewriter2: Sized { + /// Invoked while traversing down the tree before any children are rewritten. + /// Default implementation returns the node as is and continues recursion. 
+ fn f_down(&mut self, node: Expr, _schema: &DFSchema) -> Result> { + Ok(Transformed::no(node)) + } + + /// Invoked while traversing up the tree after all children have been rewritten. + /// Default implementation returns the node as is and continues recursion. + fn f_up(&mut self, node: Expr, _schema: &DFSchema) -> Result> { + Ok(Transformed::no(node)) + } +} +pub trait TreeNodeRewriterWithPayload: Sized { + type Node; + type Payload<'a>; + + /// Invoked while traversing down the tree before any children are rewritten. + /// Default implementation returns the node as is and continues recursion. + fn f_down<'a>( + &mut self, + node: Self::Node, + _payload: Self::Payload<'a>, + ) -> Result> { + Ok(Transformed::no(node)) + } + + /// Invoked while traversing up the tree after all children have been rewritten. + /// Default implementation returns the node as is and continues recursion. + fn f_up<'a>( + &mut self, + node: Self::Node, + _payload: Self::Payload<'a>, + ) -> Result> { + Ok(Transformed::no(node)) + } +} + +/* +struct LambdaColumnNormalizer<'a> { + existing_qualifiers: HashSet<&'a str>, + alias_generator: AliasGenerator, + lambdas_columns: HashMap>, +} + +impl<'a> LambdaColumnNormalizer<'a> { + fn new(dfschema: &'a DFSchema, expr: &'a Expr) -> Self { + let mut existing_qualifiers: HashSet<&'a str> = dfschema + .field_qualifiers() + .iter() + .flatten() + .map(|tbl| tbl.table()) + .filter(|table| table.starts_with("lambda_")) + .collect(); + + expr.apply(|node| { + if let Expr::Lambda(lambda) = node { + if let Some(qualifier) = &lambda.qualifier { + existing_qualifiers.insert(qualifier); + } + } + + Ok(TreeNodeRecursion::Continue) }) + .unwrap(); + + Self { + existing_qualifiers, + alias_generator: AliasGenerator::new(), + lambdas_columns: HashMap::new(), + } + } +} + +impl TreeNodeRewriter for LambdaColumnNormalizer<'_> { + type Node = Expr; + + fn f_down(&mut self, node: Self::Node) -> Result> { + match node { + Expr::Lambda(mut lambda) => { + let tbl = 
lambda.qualifier.as_ref().map_or_else( + || loop { + let table = self.alias_generator.next("lambda"); + + if !self.existing_qualifiers.contains(table.as_str()) { + break TableReference::bare(table); + } + }, + |qualifier| TableReference::bare(qualifier.as_str()), + ); + + for param in &lambda.params { + self.lambdas_columns + .entry_ref(param) + .or_default() + .push(tbl.clone()); + } + + if lambda.qualifier.is_none() { + lambda.qualifier = Some(tbl.table().to_owned()); + + Ok(Transformed::yes(Expr::Lambda(lambda))) + } else { + Ok(Transformed::no(Expr::Lambda(lambda))) + } + } + Expr::Column(c) if c.relation.is_none() => { + if let Some(lambda_qualifier) = self.lambdas_columns.get(c.name()) { + Ok(Transformed::yes(Expr::Column( + c.with_relation(lambda_qualifier.last().unwrap().clone()), + ))) + } else { + Ok(Transformed::no(Expr::Column(c))) + } + } + _ => Ok(Transformed::no(node)) + } + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + if let Expr::Lambda(lambda) = &node { + for param in &lambda.params { + match self.lambdas_columns.entry_ref(param) { + EntryRef::Occupied(mut entry) => { + let chain = entry.get_mut(); + + chain.pop(); + + if chain.is_empty() { + entry.remove(); + } + } + EntryRef::Vacant(_) => unreachable!(), + } + } + } + + Ok(Transformed::no(node)) + } +} +*/ + +// helpers used in udf.rs +#[cfg(test)] +pub(crate) mod tests { + use super::TreeNodeRewriterWithPayload; + use crate::{ + col, expr::Lambda, Expr, ScalarUDF, ScalarUDFImpl, ValueOrLambdaParameter, + }; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{ + tree_node::{Transformed, TreeNodeRecursion}, + DFSchema, HashSet, Result, + }; + use datafusion_expr_common::signature::{Signature, Volatility}; + + pub(crate) fn list_list_int() -> DFSchema { + DFSchema::try_from(Schema::new(vec![Field::new( + "v", + DataType::new_list(DataType::new_list(DataType::Int32, false), false), + false, + )])) + .unwrap() + } + + pub(crate) fn list_int() -> DFSchema { + 
DFSchema::try_from(Schema::new(vec![Field::new( + "v", + DataType::new_list(DataType::Int32, false), + false, + )])) + .unwrap() + } + + fn int() -> DFSchema { + DFSchema::try_from(Schema::new(vec![Field::new("v", DataType::Int32, false)])) + .unwrap() + } + + pub(crate) fn array_transform_udf() -> ScalarUDF { + ScalarUDF::new_from_impl(ArrayTransformFunc::new()) + } + + pub(crate) fn args() -> Vec { + vec![ + col("v"), + Expr::Lambda(Lambda::new( + vec!["v".into()], + array_transform_udf().call(vec![ + col("v"), + Expr::Lambda(Lambda::new(vec!["v".into()], -col("v"))), + ]), + )), + ] + } + + // array_transform(v, |v| -> array_transform(v, |v| -> -v)) + fn array_transform() -> Expr { + array_transform_udf().call(args()) + } + + #[derive(Debug, PartialEq, Eq, Hash)] + pub(crate) struct ArrayTransformFunc { + signature: Signature, + } + + impl ArrayTransformFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for ArrayTransformFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "array_transform" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + let ValueOrLambdaParameter::Value(value_field) = &args[0] else { + unreachable!() + }; + + let DataType::List(field) = value_field.data_type() else { + unreachable!() + }; + + Ok(vec![ + None, + Some(vec![Field::new( + "", + field.data_type().clone(), + field.is_nullable(), + )]), + ]) + } + + fn invoke_with_args( + &self, + _args: crate::ScalarFunctionArgs, + ) -> Result { + unimplemented!() + } + } + + #[test] + fn test_rewrite_with_schema() { + let schema = list_list_int(); + let array_transform = array_transform(); + + let mut rewriter = OkRewriter::default(); + + array_transform + 
.rewrite_with_schema(&schema, &mut rewriter) + .unwrap(); + + let expected = [ + ( + "f_down array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", + list_list_int(), + ), + ("f_down v", list_list_int()), + ("f_up v", list_list_int()), + ("f_down (v) -> array_transform(v, (v) -> (- v))", list_int()), + ("f_down array_transform(v, (v) -> (- v))", list_int()), + ("f_down v", list_int()), + ("f_up v", list_int()), + ("f_down (v) -> (- v)", int()), + ("f_down (- v)", int()), + ("f_down v", int()), + ("f_up v", int()), + ("f_up (- v)", int()), + ("f_up (v) -> (- v)", int()), + ("f_up array_transform(v, (v) -> (- v))", list_int()), + ("f_up (v) -> array_transform(v, (v) -> (- v))", list_int()), + ( + "f_up array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", + list_list_int(), + ), + ] + .map(|(a, b)| (String::from(a), b)); + + assert_eq!(rewriter.steps, expected) + } + + #[derive(Default)] + struct OkRewriter { + steps: Vec<(String, DFSchema)>, + } + + impl TreeNodeRewriterWithPayload for OkRewriter { + type Node = Expr; + type Payload<'a> = &'a DFSchema; + + fn f_down( + &mut self, + node: Expr, + schema: &DFSchema, + ) -> Result> { + self.steps.push((format!("f_down {node}"), schema.clone())); + + Ok(Transformed::no(node)) + } + + fn f_up( + &mut self, + node: Expr, + schema: &DFSchema, + ) -> Result> { + self.steps.push((format!("f_up {node}"), schema.clone())); + + Ok(Transformed::no(node)) + } + } + + #[test] + fn test_transform_up_with_lambdas_params() { + let mut steps = vec![]; + + array_transform() + .transform_up_with_lambdas_params(|node, params| { + steps.push((node.to_string(), params.clone())); + + Ok(Transformed::no(node)) + }) + .unwrap(); + + let lambdas_params = &HashSet::from([String::from("v")]); + + let expected = [ + ("v", lambdas_params), + ("v", lambdas_params), + ("v", lambdas_params), + ("(- v)", lambdas_params), + ("(v) -> (- v)", lambdas_params), + ("array_transform(v, (v) -> (- v))", lambdas_params), + ("(v) -> 
array_transform(v, (v) -> (- v))", lambdas_params), + ( + "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", + lambdas_params, + ), + ] + .map(|(a, b)| (String::from(a), b.clone())); + + assert_eq!(steps, expected); + } + + #[test] + fn test_apply_with_lambdas_params() { + let array_transform = array_transform(); + let mut steps = vec![]; + + array_transform + .apply_with_lambdas_params(|node, params| { + steps.push((node.to_string(), params.clone())); + + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + + let expected = [ + ("v", HashSet::from(["v"])), + ("v", HashSet::from(["v"])), + ("v", HashSet::from(["v"])), + ("(- v)", HashSet::from(["v"])), + ("(v) -> (- v)", HashSet::from(["v"])), + ("array_transform(v, (v) -> (- v))", HashSet::from(["v"])), + ("(v) -> array_transform(v, (v) -> (- v))", HashSet::from(["v"])), + ( + "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", + HashSet::from(["v"]), + ), + ] + .map(|(a, b)| (String::from(a), b)); + + assert_eq!(steps, expected); } } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index fd54bb13a62f3..74ac1b456ff04 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -18,21 +18,30 @@ //! 
[`ScalarUDF`]: Scalar User Defined Functions use crate::async_udf::AsyncScalarUDF; -use crate::expr::schema_name_from_exprs_comma_separated_without_space; +use crate::expr::{schema_name_from_exprs_comma_separated_without_space, Lambda}; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::udf_eq::UdfEq; -use crate::{ColumnarValue, Documentation, Expr, Signature}; -use arrow::datatypes::{DataType, Field, FieldRef}; +use crate::{ColumnarValue, Documentation, Expr, ExprSchemable, Signature}; +use arrow::array::{ArrayRef, RecordBatch}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema}; +use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; -use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{ + exec_err, not_impl_err, DFSchema, ExprSchema, Result, ScalarValue, +}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use indexmap::IndexMap; use std::any::Any; +use std::borrow::Cow; use std::cmp::Ordering; +use std::collections::HashMap; use std::fmt::Debug; use std::hash::{Hash, Hasher}; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; /// Logical representation of a Scalar User Defined Function. 
/// @@ -343,6 +352,272 @@ impl ScalarUDF { pub fn as_async(&self) -> Option<&AsyncScalarUDF> { self.inner().as_any().downcast_ref::() } + + /// Variation of arguments_from_logical_args that works with arrow Schema's and ScalarFunctionArgMetadata instead + pub(crate) fn arguments_expr_schema<'a>( + &self, + args: &[Expr], + schema: &'a dyn ExprSchema, + ) -> Result> { + self.arguments_scope_with( + &lambda_parameters(args, schema)?, + ExtendableExprSchema::new(schema), + ) + } + + /// Variation of arguments_from_logical_args that works with arrow Schema's and ScalarFunctionArgMetadata instead, + pub fn arguments_arrow_schema<'a>( + &self, + args: &[ValueOrLambdaParameter], + schema: &'a Schema, + ) -> Result>> { + self.arguments_scope_with(args, Cow::Borrowed(schema)) + } + + pub fn arguments_schema_from_logical_args<'a>( + &self, + args: &[Expr], + schema: &'a DFSchema, + ) -> Result>> { + self.arguments_scope_with( + &lambda_parameters(args, schema)?, + Cow::Borrowed(schema), + ) + } + + /// Scalar function supports lambdas as arguments, which will be evaluated with + /// a different schema that of the function itself. 
This function returns a vec + /// with the corresponding schema that each argument will run with + /// + /// Return a vec with a value for each argument in args that, if it's a value, it's a clone of base_scope, + /// if it's a lambda, it's the return of merge called with the index and the fields from lambdas_parameters + /// updated with names from metadata + fn arguments_scope_with( + &self, + args: &[ValueOrLambdaParameter], + schema: T, + ) -> Result> { + let parameters = self.inner().lambdas_parameters(args)?; + + if parameters.len() != args.len() { + return exec_err!( + "lambdas_schemas: {} lambdas_parameters returned {} values instead of {}", + self.name(), + parameters.len(), + args.len() + ); + } + + std::iter::zip(args, parameters) + .enumerate() + .map(|(i, (arg, parameters))| match (arg, parameters) { + (ValueOrLambdaParameter::Value(_), None) => Ok(schema.clone()), + (ValueOrLambdaParameter::Value(_), Some(_)) => exec_err!("lambdas_schemas: {} argument {} (0-indexed) is a value but lambdas_parameters result treat it as a lambda", self.name(), i), + (ValueOrLambdaParameter::Lambda(_, _), None) => exec_err!("lambdas_schemas: {} argument {} (0-indexed) is a lambda but lambdas_parameters result treat it as a value", self.name(), i), + (ValueOrLambdaParameter::Lambda(names, captures), Some(args)) => { + if names.len() > args.len() { + return exec_err!("lambdas_schemas: {} argument {} (0-indexed), a lambda, supports up to {} arguments, but got {}", self.name(), i, args.len(), names.len()) + } + + let fields = std::iter::zip(*names, args) + .map(|(name, arg)| arg.with_name(name)) + .collect::(); + + if *captures { + schema.extend(fields) + } else { + T::from_fields(fields) + } + } + }) + .collect() + } +} + +pub trait ExtendSchema: Sized { + fn from_fields(params: Fields) -> Result; + fn extend(&self, params: Fields) -> Result; +} + +impl ExtendSchema for DFSchema { + fn from_fields(params: Fields) -> Result { + DFSchema::from_unqualified_fields(params, 
Default::default()) + } + + fn extend(&self, params: Fields) -> Result { + let qualified_fields = self + .iter() + .map(|(qualifier, field)| { + if params.find(field.name().as_str()).is_none() { + return (qualifier.cloned(), Arc::clone(field)); + } + + let alias_gen = AliasGenerator::new(); + + loop { + let alias = alias_gen.next(field.name().as_str()); + + if params.find(&alias).is_none() + && !self.has_column_with_unqualified_name(&alias) + { + return ( + qualifier.cloned(), + Arc::new(Field::new( + alias, + field.data_type().clone(), + field.is_nullable(), + )), + ); + } + } + }) + .collect(); + + let mut schema = DFSchema::new_with_metadata(qualified_fields, HashMap::new())?; + let fields_schema = DFSchema::from_unqualified_fields(params, HashMap::new())?; + + schema.merge(&fields_schema); + + assert_eq!( + schema.fields().len(), + self.fields().len() + fields_schema.fields().len() + ); + + Ok(schema) + } +} + +impl ExtendSchema for Schema { + fn from_fields(params: Fields) -> Result { + Ok(Schema::new(params)) + } + + fn extend(&self, params: Fields) -> Result { + let mut params2 = params.iter() + .map(|f| (f.name().as_str(), Some(Arc::clone(f)))) + .collect::>(); + + let mut fields = self.fields() + .iter() + .map(|field| { + match params2.get_mut(field.name().as_str()).and_then(|p| p.take()) { + Some(param) => param, + None => Arc::clone(field), + } + }) + .collect::>(); + + fields.extend(params2.into_values().flatten()); + + let fields = self + .fields() + .iter() + .map(|field| { + if params.find(field.name().as_str()).is_none() { + return Arc::clone(field); + } + + let alias_gen = AliasGenerator::new(); + + loop { + let alias = alias_gen.next(field.name().as_str()); + + if params.find(&alias).is_none() + && self.column_with_name(&alias).is_none() + { + return Arc::new(Field::new( + alias, + field.data_type().clone(), + field.is_nullable(), + )); + } + } + }) + .chain(params.iter().cloned()) + .collect::(); + + assert_eq!(fields.len(), self.fields().len() 
+ params.len()); + + Ok(Schema::new_with_metadata(fields, self.metadata.clone())) + } +} + +impl ExtendSchema for Cow<'_, T> { + fn from_fields(params: Fields) -> Result { + Ok(Cow::Owned(T::from_fields(params)?)) + } + + fn extend(&self, params: Fields) -> Result { + Ok(Cow::Owned(self.as_ref().extend(params)?)) + } +} + +impl ExtendSchema for Arc { + fn from_fields(params: Fields) -> Result { + Ok(Arc::new(T::from_fields(params)?)) + } + + fn extend(&self, params: Fields) -> Result { + Ok(Arc::new(self.as_ref().extend(params)?)) + } +} + +impl ExtendSchema for ExtendableExprSchema<'_> { + fn from_fields(params: Fields) -> Result { + static EMPTY_DFSCHEMA: LazyLock = LazyLock::new(DFSchema::empty); + + Ok(ExtendableExprSchema { + fields_chain: vec![params], + outer_schema: &*EMPTY_DFSCHEMA, + }) + } + + fn extend(&self, params: Fields) -> Result { + Ok(ExtendableExprSchema { + fields_chain: std::iter::once(params) + .chain(self.fields_chain.iter().cloned()) + .collect(), + outer_schema: self.outer_schema, + }) + } +} + +/// A `&dyn ExprSchema` wrapper that supports adding the parameters of a lambda +#[derive(Clone, Debug)] +struct ExtendableExprSchema<'a> { + fields_chain: Vec, + outer_schema: &'a dyn ExprSchema, +} + +impl<'a> ExtendableExprSchema<'a> { + fn new(schema: &'a dyn ExprSchema) -> Self { + Self { + fields_chain: vec![], + outer_schema: schema, + } + } +} + +impl ExprSchema for ExtendableExprSchema<'_> { + fn field_from_column(&self, col: &datafusion_common::Column) -> Result<&Field> { + if col.relation.is_none() { + for fields in &self.fields_chain { + if let Some((_index, lambda_param)) = fields.find(&col.name) { + return Ok(lambda_param); + } + } + } + + self.outer_schema.field_from_column(col) + } +} + +#[derive(Clone, Debug)] +pub enum ValueOrLambdaParameter<'a> { + /// A columnar value with the given field + Value(FieldRef), + /// A lambda with the given parameter names and a flag indicating whether it captures any columns + Lambda(&'a [String], 
bool), } impl From for ScalarUDF @@ -359,6 +634,7 @@ where #[derive(Debug, Clone)] pub struct ScalarFunctionArgs { /// The evaluated arguments to the function + /// If it's a lambda, will be `ColumnarValue::Scalar(ScalarValue::Null)` pub args: Vec, /// Field associated with each arg, if it exists pub arg_fields: Vec, @@ -370,6 +646,30 @@ pub struct ScalarFunctionArgs { pub return_field: FieldRef, /// The config options at execution time pub config_options: Arc, + /// The lambdas passed to the function + /// If it's not a lambda it will be `None` + pub lambdas: Option>>, +} + +/// A lambda argument to a ScalarFunction +#[derive(Clone, Debug)] +pub struct ScalarFunctionLambdaArg { + /// The parameters defined in this lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be `vec![Field::new("v", DataType::Int32, true)]` + pub params: Vec, + /// The body of the lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be the physical expression of `-v` + pub body: Arc, + /// A RecordBatch containing at least the captured columns inside this lambda body, if any + /// Note that it may contain additional, non-specified columns, but that's implementation detail + /// + /// For example, for `array_transform([2], v -> v + a + b)`, + /// this will be a `RecordBatch` with two columns, `a` and `b` + pub captures: Option, } impl ScalarFunctionArgs { @@ -378,6 +678,25 @@ impl ScalarFunctionArgs { pub fn return_type(&self) -> &DataType { self.return_field.data_type() } + + pub fn to_lambda_args(&self) -> Vec> { + match &self.lambdas { + Some(lambdas) => std::iter::zip(&self.args, lambdas) + .map(|(arg, lambda)| match lambda { + Some(lambda) => ValueOrLambda::Lambda(lambda), + None => ValueOrLambda::Value(arg), + }) + .collect(), + None => self.args.iter().map(ValueOrLambda::Value).collect(), + } + } +} + +// An argument to a ScalarUDF that supports lambdas +#[derive(Debug)] +pub enum ValueOrLambda<'a> { + Value(&'a 
ColumnarValue), + Lambda(&'a ScalarFunctionLambdaArg), } /// Information about arguments passed to the function @@ -390,6 +709,12 @@ impl ScalarFunctionArgs { #[derive(Debug)] pub struct ReturnFieldArgs<'a> { /// The data types of the arguments to the function + /// + /// If argument `i` to the function is a lambda, it will be the field returned by the + /// lambda when executed with the arguments returned from `ScalarUDFImpl::lambdas_parameters` + /// + /// For example, with `array_transform([1], v -> v == 5)` + /// this field will be `[Field::new("", DataType::List(DataType::Int32), false), Field::new("", DataType::Boolean, false)]` pub arg_fields: &'a [FieldRef], /// Is argument `i` to the function a scalar (constant)? /// @@ -398,6 +723,36 @@ pub struct ReturnFieldArgs<'a> { /// For example, if a function is called like `my_function(column_a, 5)` /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` pub scalar_arguments: &'a [Option<&'a ScalarValue>], + /// Is argument `i` to the function a lambda? + /// + /// For example, with `array_transform([1], v -> v == 5)` + /// this field will be `[false, true]` + pub lambdas: &'a [bool], +} + +/// A tagged Field indicating whether it corresponds to a value or a lambda argument +#[derive(Debug)] +pub enum ValueOrLambdaField<'a> { + /// The Field of a ColumnarValue argument + Value(&'a FieldRef), + /// The Field of the return of the lambda body when evaluated with the parameters from ScalarUDF::lambda_parameters + Lambda(&'a FieldRef), +} + +impl<'a> ReturnFieldArgs<'a> { + /// Based on self.lambdas, encodes self.arg_fields to tagged enums + /// indicating whether it corresponds to a value or a lambda argument + pub fn to_lambda_args(&self) -> Vec> { + std::iter::zip(self.arg_fields, self.lambdas) + .map(|(field, is_lambda)| { + if *is_lambda { + ValueOrLambdaField::Lambda(field) + } else { + ValueOrLambdaField::Value(field) + } + }) + .collect() + } } /// Trait for implementing user defined scalar functions. 
@@ -841,6 +1196,14 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { fn documentation(&self) -> Option<&Documentation> { None } + + /// Returns the parameters that any lambda supports + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + Ok(vec![None; args.len()]) + } } /// ScalarUDF that adds an alias to the underlying function. It is better to @@ -959,6 +1322,118 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn documentation(&self) -> Option<&Documentation> { self.inner.documentation() } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + self.inner.lambdas_parameters(args) + } +} + +fn lambda_parameters<'a>( + args: &'a [Expr], + schema: &dyn ExprSchema, +) -> Result>> { + args.iter() + .map(|e| match e { + Expr::Lambda(Lambda { params, body: _ }) => { + let mut captures = false; + + e.apply_with_lambdas_params(|expr, lambdas_params| match expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { + captures = true; + + Ok(TreeNodeRecursion::Stop) + } + _ => Ok(TreeNodeRecursion::Continue), + }) + .unwrap(); + + Ok(ValueOrLambdaParameter::Lambda(params.as_slice(), captures)) + } + _ => Ok(ValueOrLambdaParameter::Value(e.to_field(schema)?.1)), + }) + .collect() +} + +/// Merge the lambda body captured columns with its arguments +/// DataFusion relies on an unspecified field ordering implemented in this function +/// As such, this is the only correct way to merge the captured values with the arguments +/// The number of args should not be lower than the number of params +/// +/// See also merge_captures_with_lazy_args and merge_captures_with_boxed_lazy_args that lazily +/// computes only the necessary arguments to match the number of params +pub fn merge_captures_with_args( + captures: Option<&RecordBatch>, + params: &[FieldRef], + args: &[ArrayRef], +) -> Result { + if args.len() < params.len() { + return exec_err!( + "merge_captures_with_args called 
with {} params but with {} args", + params.len(), + args.len() + ); + } + + // the order of the merged batch must be kept in sync with ScalarFunction::lambdas_schemas variants + let (fields, columns) = match captures { + Some(captures) => { + let fields = captures + .schema() + .fields() + .iter() + .chain(params) + .cloned() + .collect::>(); + + let columns = [captures.columns(), args].concat(); + + (fields, columns) + } + None => (params.to_vec(), args.to_vec()), + }; + + Ok(RecordBatch::try_new( + Arc::new(Schema::new(fields)), + columns, + )?) +} + +/// Lazy version of merge_captures_with_args that receives closures to compute the arguments, +/// and calls only the necessary to match the number of params +pub fn merge_captures_with_lazy_args( + captures: Option<&RecordBatch>, + params: &[FieldRef], + args: &[&dyn Fn() -> Result], +) -> Result { + merge_captures_with_args( + captures, + params, + &args + .iter() + .take(params.len()) + .map(|arg| arg()) + .collect::>>()?, + ) +} + +/// Variation of merge_captures_with_lazy_args that take boxed closures +pub fn merge_captures_with_boxed_lazy_args( + captures: Option<&RecordBatch>, + params: &[FieldRef], + args: &[Box Result>], +) -> Result { + merge_captures_with_args( + captures, + params, + &args + .iter() + .take(params.len()) + .map(|arg| arg()) + .collect::>>()?, + ) } #[cfg(test)] @@ -1039,4 +1514,83 @@ mod tests { value.hash(hasher); hasher.finish() } + + use std::borrow::Cow; + + use arrow::datatypes::Fields; + + use crate::{ + tree_node::tests::{args, list_int, list_list_int, array_transform_udf}, + udf::{lambda_parameters, ExtendableExprSchema}, + }; + + #[test] + fn test_arguments_expr_schema() { + let args = args(); + let schema = list_list_int(); + + let schemas = array_transform_udf() + .arguments_expr_schema(&args, &schema) + .unwrap() + .into_iter() + .map(|s| format!("{s:?}")) + .collect::>(); + + let mut lambdas_parameters = array_transform_udf() + .inner() + 
.lambdas_parameters(&lambda_parameters(&args, &schema).unwrap()) + .unwrap(); + + assert_eq!( + schemas, + &[ + format!("{}", &list_list_int()), + format!( + "{:?}", + ExtendableExprSchema { + fields_chain: vec![Fields::from( + lambdas_parameters[0].take().unwrap() + )], + outer_schema: &list_list_int() + } + ), + ] + ) + } + + #[test] + fn test_arguments_arrow_schema() { + let list_int = list_int(); + let list_list_int = list_list_int(); + + let schemas = array_transform_udf() + .arguments_arrow_schema( + &lambda_parameters(&args(), &list_list_int).unwrap(), + //&[HashSet::new(), HashSet::from([0])], + list_list_int.as_arrow(), + ) + .unwrap(); + + assert_eq!( + schemas, + &[ + Cow::Borrowed(list_list_int.as_arrow()), + Cow::Owned(list_int.as_arrow().clone()) + ] + ) + } + + #[test] + fn test_arguments_schema_from_logical_args() { + let list_list_int = list_list_int(); + + let schemas = array_transform_udf() + .arguments_schema_from_logical_args(&args(), &list_list_int) + .unwrap(); + + assert_eq!( + schemas, + &[Cow::Borrowed(&list_list_int), Cow::Owned(list_int())] + ) + } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index cd733e0a130a9..93fcfaef882ff 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -266,10 +266,12 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { /// Recursively walk an expression tree, collecting the unique set of columns /// referenced in the expression pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { - expr.apply(|expr| { + expr.apply_with_lambdas_params(|expr, lambdas_params| { match expr { Expr::Column(qc) => { - accum.insert(qc.clone()); + if qc.relation.is_some() || !lambdas_params.contains(qc.name()) { + accum.insert(qc.clone()); + } } // Use explicit pattern match instead of a default // implementation, so that in the future if someone adds @@ -307,7 +309,8 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | 
Expr::ScalarSubquery(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) - | Expr::OuterReferenceColumn { .. } => {} + | Expr::OuterReferenceColumn { .. } + | Expr::Lambda { .. } => {} } Ok(TreeNodeRecursion::Continue) }) @@ -650,6 +653,7 @@ where /// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the /// provided test. The returned `Expr`'s are deduplicated and returned in order /// of appearance (depth first). +/// todo: document about that columns may refer to a lambda parameter? fn find_exprs_in_expr(expr: &Expr, test_fn: &F) -> Vec where F: Fn(&Expr) -> bool, @@ -672,6 +676,7 @@ where } /// Recursively inspect an [`Expr`] and all its children. +/// todo: document about that columns may refer to a lambda parameter? pub fn inspect_expr_pre(expr: &Expr, mut f: F) -> Result<(), E> where F: FnMut(&Expr) -> Result<(), E>, @@ -743,13 +748,19 @@ pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result { _ => return Ok(e), }; let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect(); - e.transform_down(|node: Expr| match exprs_map.get(&node) { - Some(column) => Ok(Transformed::new( - Expr::Column(column.clone()), - true, - TreeNodeRecursion::Jump, - )), - None => Ok(Transformed::no(node)), + e.transform_down_with_lambdas_params(|node: Expr, lambdas_params| { + if matches!(&node, Expr::Column(c) if c.is_lambda_parameter(lambdas_params)) { + return Ok(Transformed::no(node)); + } + + match exprs_map.get(&node) { + Some(column) => Ok(Transformed::new( + Expr::Column(column.clone()), + true, + TreeNodeRecursion::Jump, + )), + None => Ok(Transformed::no(node)), + } }) .data() } @@ -766,9 +777,11 @@ pub fn find_column_exprs(exprs: &[Expr]) -> Vec { pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { let mut exprs = vec![]; - e.apply(|expr| { + e.apply_with_lambdas_params(|expr, lambdas_params| { if let Expr::Column(c) = expr { - exprs.push(c.clone()) + if !c.is_lambda_parameter(lambdas_params) { + exprs.push(c.clone()) + } 
} Ok(TreeNodeRecursion::Continue) }) @@ -797,9 +810,9 @@ pub(crate) fn find_column_indexes_referenced_by_expr( schema: &DFSchemaRef, ) -> Vec { let mut indexes = vec![]; - e.apply(|expr| { + e.apply_with_lambdas_params(|expr, lambdas_params| { match expr { - Expr::Column(qc) => { + Expr::Column(qc) if !qc.is_lambda_parameter(lambdas_params) => { if let Ok(idx) = schema.index_of_column(qc) { indexes.push(idx); } diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 5e59cfc5ecb07..400ad44696047 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -33,7 +33,7 @@ use arrow::{ }; use arrow_schema::FieldRef; use datafusion::config::ConfigOptions; -use datafusion::logical_expr::ReturnFieldArgs; +use datafusion::{common::exec_err, logical_expr::ReturnFieldArgs}; use datafusion::{ error::DataFusionError, logical_expr::type_coercion::functions::data_types_with_scalar_udf, @@ -210,6 +210,7 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( return_field, // TODO: pass config options: https://github.com/apache/datafusion/issues/17035 config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = rresult_return!(udf @@ -382,10 +383,15 @@ impl ScalarUDFImpl for ForeignScalarUDF { arg_fields, number_rows, return_field, + lambdas, // TODO: pass config options: https://github.com/apache/datafusion/issues/17035 config_options: _config_options, } = invoke_args; + if lambdas.is_some_and(|lambdas| lambdas.iter().any(|l| l.is_some())) { + return exec_err!("ForeignScalarUDF doesn't support lambdas"); + } + let args = args .into_iter() .map(|v| v.to_array(number_rows)) diff --git a/datafusion/ffi/src/udf/return_type_args.rs b/datafusion/ffi/src/udf/return_type_args.rs index c437c9537be6f..d5cbfff1d3a4b 100644 --- a/datafusion/ffi/src/udf/return_type_args.rs +++ b/datafusion/ffi/src/udf/return_type_args.rs @@ -21,7 +21,7 @@ use abi_stable::{ }; use arrow_schema::FieldRef; use datafusion::{ - 
common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnFieldArgs, + common::{exec_datafusion_err, exec_err}, error::DataFusionError, logical_expr::ReturnFieldArgs, scalar::ScalarValue, }; @@ -42,6 +42,10 @@ impl TryFrom> for FFI_ReturnFieldArgs { type Error = DataFusionError; fn try_from(value: ReturnFieldArgs) -> Result { + if value.lambdas.iter().any(|l| *l) { + return exec_err!("FFI_ReturnFieldArgs doesn't support lambdas") + } + let arg_fields = vec_fieldref_to_rvec_wrapped(value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments @@ -77,6 +81,7 @@ pub struct ForeignReturnFieldArgsOwned { pub struct ForeignReturnFieldArgs<'a> { arg_fields: &'a [FieldRef], scalar_arguments: Vec>, + lambdas: Vec, // currently always false, used to return a reference in From<&Self> for ReturnFieldArgs } impl TryFrom<&FFI_ReturnFieldArgs> for ForeignReturnFieldArgsOwned { @@ -116,6 +121,7 @@ impl<'a> From<&'a ForeignReturnFieldArgsOwned> for ForeignReturnFieldArgs<'a> { .iter() .map(|opt| opt.as_ref()) .collect(), + lambdas: vec![false; value.arg_fields.len()] } } } @@ -125,6 +131,7 @@ impl<'a> From<&'a ForeignReturnFieldArgs<'a>> for ReturnFieldArgs<'a> { ReturnFieldArgs { arg_fields: value.arg_fields, scalar_arguments: &value.scalar_arguments, + lambdas: &value.lambdas, } } } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs new file mode 100644 index 0000000000000..700fed477b4cb --- /dev/null +++ b/datafusion/functions-nested/src/array_transform.rs @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_transform function. + +use arrow::{ + array::{Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray}, + compute::take_record_batch, + datatypes::{DataType, Field}, +}; +use datafusion_common::{ + HashSet, Result, exec_err, internal_err, tree_node::{Transformed, TreeNode}, utils::{elements_indices, list_indices, list_values, take_function_args} +}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, expr::Lambda, merge_captures_with_lazy_args +}; +use datafusion_macros::user_doc; +use std::{any::Any, sync::Arc}; + +make_udf_expr_and_func!( + ArrayTransform, + array_transform, + array lambda, + "transforms the values of a array", + array_transform_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "transforms the values of a array", + syntax_example = "array_transform(array, x -> x*2)", + sql_example = r#"```sql +> select array_transform([1, 2, 3, 4, 5], x -> x*2); ++-------------------------------------------+ +| array_transform([1, 2, 3, 4, 5], x -> x*2) | ++-------------------------------------------+ +| [2, 4, 6, 8, 10] | ++-------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
+ ), + argument(name = "lambda", description = "Lambda") +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayTransform { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayTransform { + fn default() -> Self { + Self::new() + } +} + +impl ArrayTransform { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + aliases: vec![String::from("list_transform")], + } + } +} + +impl ScalarUDFImpl for ArrayTransform { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_transform" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_type called instead of return_field_from_args") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result> { + let args = args.to_lambda_args(); + + let [ValueOrLambdaField::Value(list), ValueOrLambdaField::Lambda(lambda)] = + take_function_args(self.name(), &args)? + else { + return exec_err!( + "{} expects a value follewed by a lambda, got {:?}", + self.name(), + args + ); + }; + + //TODO: should metadata be passed? If so, with the same keys or prefixed/suffixed? 
+ + // lambda is the resulting field of executing the lambda body + // with the parameters returned in lambdas_parameters + let field = Arc::new(Field::new( + Field::LIST_FIELD_DEFAULT_NAME, + lambda.data_type().clone(), + lambda.is_nullable(), + )); + + let return_type = match list.data_type() { + DataType::List(_) => DataType::List(field), + DataType::LargeList(_) => DataType::LargeList(field), + DataType::FixedSizeList(_, size) => DataType::FixedSizeList(field, *size), + _ => unreachable!(), + }; + + Ok(Arc::new(Field::new("", return_type, list.is_nullable()))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // args.lambda_args allows the convenient match below, instead of inspecting both args.args and args.lambdas + let lambda_args = args.to_lambda_args(); + let [list_value, lambda] = take_function_args(self.name(), &lambda_args)?; + + let (ValueOrLambda::Value(list_value), ValueOrLambda::Lambda(lambda)) = + (list_value, lambda) + else { + return exec_err!( + "{} expects a value followed by a lambda, got {:?}", + self.name(), + &lambda_args + ); + }; + + let list_array = list_value.to_array(args.number_rows)?; + + // if any column got captured, we need to adjust it to the values arrays, + // duplicating values of list with multiple values and removing values of empty lists + // list_indices is not cheap so it is important to avoid it when no column is captured + let adjusted_captures = lambda + .captures + .as_ref() + .map(|captures| take_record_batch(captures, &list_indices(&list_array)?)) + .transpose()?; + + // use closures and merge_captures_with_lazy_args so that it calls only the needed ones based on the number of arguments + // avoiding unnecessary computations + let values_param = || Ok(Arc::clone(list_values(&list_array)?)); + let indices_param = || elements_indices(&list_array); + + // the order of the merged schema is an unspecified implementation detail that may change in the future, + // using this function is the correct 
way to merge as it returns the correct ordering and will change in sync with + // the implementation without the need for fixes. It also computes only the parameters requested + let lambda_batch = merge_captures_with_lazy_args( + adjusted_captures.as_ref(), + &lambda.params, // ScalarUDF already merged the fields returned in lambdas_parameters with the parameter names defined in the lambda, so we don't need to + &[&values_param, &indices_param], + )?; + + // call the transforming expression with the record batch composed of the list values merged with captured columns + let transformed_values = lambda + .body + .evaluate(&lambda_batch)? + .into_array(lambda_batch.num_rows())?; + + let field = match args.return_field.data_type() { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => Arc::clone(field), + _ => { + return exec_err!( + "{} expected ScalarFunctionArgs.return_field to be a list, got {}", + self.name(), + args.return_field + ) + } + }; + + let transformed_list = match list_array.data_type() { + DataType::List(_) => { + let list = list_array.as_list(); + + Arc::new(ListArray::new( + field, + list.offsets().clone(), + transformed_values, + list.nulls().cloned(), + )) as ArrayRef + } + DataType::LargeList(_) => { + let large_list = list_array.as_list(); + + Arc::new(LargeListArray::new( + field, + large_list.offsets().clone(), + transformed_values, + large_list.nulls().cloned(), + )) + } + DataType::FixedSizeList(_, value_length) => { + Arc::new(FixedSizeListArray::new( + field, + *value_length, + transformed_values, + list_array.as_fixed_size_list().nulls().cloned(), + )) + } + other => exec_err!("expected list, got {other}")?, + }; + + Ok(ColumnarValue::Array(transformed_list)) + } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda(_, _)] = + args + else { + return exec_err!( + "{} expects a value follewed 
by a lambda, got {:?}", + self.name(), + args + ); + }; + + let (field, index_type) = match list.data_type() { + DataType::List(field) => (field, DataType::Int32), + DataType::LargeList(field) => (field, DataType::Int64), + DataType::FixedSizeList(field, _) => (field, DataType::Int32), + _ => return exec_err!("expected list, got {list}"), + }; + + // we don't need to omit the index in the case the lambda don't specify, e.g. array_transform([], v -> v*2), + // nor check whether the lambda contains more than two parameters, e.g. array_transform([], (v, i, j) -> v+i+j), + // as datafusion will do that for us + let value = Field::new("value", field.data_type().clone(), field.is_nullable()) + .with_metadata(field.metadata().clone()); + let index = Field::new("index", index_type, false); + + Ok(vec![None, Some(vec![value, index])]) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index c6bf14ebce2e3..0e5e602f8238e 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -19,7 +19,7 @@ use super::AnalyzerRule; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::Transformed; use datafusion_common::{DFSchema, Result}; use crate::utils::NamePreserver; @@ -64,15 +64,16 @@ impl ApplyFunctionRewrites { let original_name = name_preserver.save(&expr); // recursively transform the expression, applying the rewrites at each step - let transformed_expr = expr.transform_up(|expr| { - let mut result = Transformed::no(expr); - for rewriter in self.function_rewrites.iter() { - result = result.transform_data(|expr| { - rewriter.rewrite(expr, &schema, options) - })?; - } - Ok(result) - })?; + let transformed_expr = + expr.transform_up_with_schema(&schema, |expr, schema| { + let 
mut result = Transformed::no(expr); + for rewriter in self.function_rewrites.iter() { + result = result.transform_data(|expr| { + rewriter.rewrite(expr, schema, options) + })?; + } + Ok(result) + })?; Ok(transformed_expr.update_data(|expr| original_name.restore(expr))) }) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4fb0f8553b4ba..1b82182e8600f 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use datafusion_expr::binary::BinaryTypeCoercer; +use datafusion_expr::tree_node::TreeNodeRewriterWithPayload; use itertools::{izip, Itertools as _}; use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -27,7 +28,7 @@ use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; use crate::analyzer::AnalyzerRule; use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::Transformed; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -140,7 +141,7 @@ fn analyze_internal( // apply coercion rewrite all expressions in the plan individually plan.map_expressions(|expr| { let original_name = name_preserver.save(&expr); - expr.rewrite(&mut expr_rewrite) + expr.rewrite_with_schema(&schema, &mut expr_rewrite) .map(|transformed| transformed.update_data(|e| original_name.restore(e))) })? 
// some plans need extra coercion after their expressions are coerced @@ -304,10 +305,11 @@ impl<'a> TypeCoercionRewriter<'a> { } } -impl TreeNodeRewriter for TypeCoercionRewriter<'_> { +impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { type Node = Expr; + type Payload<'a> = &'a DFSchema; - fn f_up(&mut self, expr: Expr) -> Result> { + fn f_up(&mut self, expr: Expr, schema: &DFSchema) -> Result> { match expr { Expr::Unnest(_) => not_impl_err!( "Unnest should be rewritten to LogicalPlan::Unnest before type coercion" @@ -318,7 +320,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { spans, }) => { let new_plan = - analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data; + analyze_internal(schema, Arc::unwrap_or_clone(subquery))?.data; Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, @@ -327,7 +329,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { } Expr::Exists(Exists { subquery, negated }) => { let new_plan = analyze_internal( - self.schema, + schema, Arc::unwrap_or_clone(subquery.subquery), )? .data; @@ -346,11 +348,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { negated, }) => { let new_plan = analyze_internal( - self.schema, + schema, Arc::unwrap_or_clone(subquery.subquery), )? 
.data; - let expr_type = expr.get_type(self.schema)?; + let expr_type = expr.get_type(schema)?; let subquery_type = new_plan.schema().field(0).data_type(); let common_type = comparison_coercion(&expr_type, subquery_type).ok_or( plan_datafusion_err!( @@ -363,32 +365,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { spans: subquery.spans, }; Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( - Box::new(expr.cast_to(&common_type, self.schema)?), + Box::new(expr.cast_to(&common_type, schema)?), cast_subquery(new_subquery, &common_type)?, negated, )))) } Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, - self.schema, + schema, )?))), Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::Like(Like { negated, @@ -397,8 +399,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { escape_char, case_insensitive, }) => { - let left_type = expr.get_type(self.schema)?; - let right_type = pattern.get_type(self.schema)?; + let left_type = expr.get_type(schema)?; + let right_type = pattern.get_type(schema)?; let coerced_type = 
like_coercion(&left_type, &right_type).ok_or_else(|| { let op_name = if case_insensitive { "ILIKE" @@ -411,9 +413,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { })?; let expr = match left_type { DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr, - _ => Box::new(expr.cast_to(&coerced_type, self.schema)?), + _ => Box::new(expr.cast_to(&coerced_type, schema)?), }; - let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?); + let pattern = Box::new(pattern.cast_to(&coerced_type, schema)?); Ok(Transformed::yes(Expr::Like(Like::new( negated, expr, @@ -424,7 +426,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left, right) = - self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?; + self.coerce_binary_op(*left, schema, op, *right, schema)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), op, @@ -437,15 +439,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { low, high, }) => { - let expr_type = expr.get_type(self.schema)?; - let low_type = low.get_type(self.schema)?; + let expr_type = expr.get_type(schema)?; + let low_type = low.get_type(schema)?; let low_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { internal_datafusion_err!( "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression" ) })?; - let high_type = high.get_type(self.schema)?; + let high_type = high.get_type(schema)?; let high_coerced_type = comparison_coercion(&expr_type, &high_type) .ok_or_else(|| { internal_datafusion_err!( @@ -460,10 +462,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { ) })?; Ok(Transformed::yes(Expr::Between(Between::new( - Box::new(expr.cast_to(&coercion_type, self.schema)?), + Box::new(expr.cast_to(&coercion_type, schema)?), negated, - Box::new(low.cast_to(&coercion_type, self.schema)?), - Box::new(high.cast_to(&coercion_type, self.schema)?), + Box::new(low.cast_to(&coercion_type, 
schema)?), + Box::new(high.cast_to(&coercion_type, schema)?), )))) } Expr::InList(InList { @@ -471,10 +473,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { list, negated, }) => { - let expr_data_type = expr.get_type(self.schema)?; + let expr_data_type = expr.get_type(schema)?; let list_data_types = list .iter() - .map(|list_expr| list_expr.get_type(self.schema)) + .map(|list_expr| list_expr.get_type(schema)) .collect::>>()?; let result_type = get_coerce_type_for_list(&expr_data_type, &list_data_types); @@ -484,11 +486,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { ), Some(coerced_type) => { // find the coerced type - let cast_expr = expr.cast_to(&coerced_type, self.schema)?; + let cast_expr = expr.cast_to(&coerced_type, schema)?; let cast_list_expr = list .into_iter() .map(|list_expr| { - list_expr.cast_to(&coerced_type, self.schema) + list_expr.cast_to(&coerced_type, schema) }) .collect::>>()?; Ok(Transformed::yes(Expr::InList(InList ::new( @@ -500,13 +502,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { } } Expr::Case(case) => { - let case = coerce_case_expression(case, self.schema)?; + let case = coerce_case_expression(case, schema)?; Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func, args }) => { let new_expr = coerce_arguments_for_signature_with_scalar_udf( args, - self.schema, + schema, &func, )?; Ok(Transformed::yes(Expr::ScalarFunction( @@ -526,7 +528,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }) => { let new_expr = coerce_arguments_for_signature_with_aggregate_udf( args, - self.schema, + schema, &func, )?; Ok(Transformed::yes(Expr::AggregateFunction( @@ -555,13 +557,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }, } = *window_fun; let window_frame = - coerce_window_frame(window_frame, self.schema, &order_by)?; + coerce_window_frame(window_frame, schema, &order_by)?; let args = match &fun { expr::WindowFunctionDefinition::AggregateUDF(udf) => { 
coerce_arguments_for_signature_with_aggregate_udf( args, - self.schema, + schema, udf, )? } @@ -597,7 +599,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), + | Expr::OuterReferenceColumn(_, _) + | Expr::Lambda { .. } => Ok(Transformed::no(expr)), } } } @@ -793,9 +796,11 @@ fn coerce_arguments_for_signature_with_scalar_udf( return Ok(expressions); } - let current_types = expressions - .iter() - .map(|e| e.get_type(schema)) + let current_types = expressions.iter() + .map(|e| match e { + Expr::Lambda { .. } => Ok(DataType::Null), + _ => e.get_type(schema), + }) .collect::>>()?; let new_types = data_types_with_scalar_udf(¤t_types, func)?; @@ -803,7 +808,10 @@ fn coerce_arguments_for_signature_with_scalar_udf( expressions .into_iter() .enumerate() - .map(|(i, expr)| expr.cast_to(&new_types[i], schema)) + .map(|(i, expr)| match expr { + lambda @ Expr::Lambda { .. 
} => Ok(lambda), + _ => expr.cast_to(&new_types[i], schema), + }) .collect() } @@ -1125,7 +1133,7 @@ mod test { use crate::analyzer::Analyzer; use crate::assert_analyzed_plan_with_config_eq_snapshot; use datafusion_common::config::ConfigOptions; - use datafusion_common::tree_node::{TransformedResult, TreeNode}; + use datafusion_common::tree_node::{TransformedResult}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort}; @@ -2076,7 +2084,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.rewrite(&mut rewriter).data()?; + let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?; assert_eq!(expected, result); // eq @@ -2087,7 +2095,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.rewrite(&mut rewriter).data()?; + let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?; assert_eq!(expected, result); // lt @@ -2098,7 +2106,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = expr.rewrite(&mut rewriter).data()?; + let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?; assert_eq!(expected, result); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 2510068494591..e06ed6e547eb5 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ 
-29,7 +29,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE}; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result}; +use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, HashSet, Result}; use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, @@ -632,6 +632,7 @@ struct ExprCSEController<'a> { // how many aliases have we seen so far alias_counter: usize, + lambdas_params: HashSet, } impl<'a> ExprCSEController<'a> { @@ -640,6 +641,7 @@ impl<'a> ExprCSEController<'a> { alias_generator, mask, alias_counter: 0, + lambdas_params: HashSet::new(), } } } @@ -693,11 +695,30 @@ impl CSEController for ExprCSEController<'_> { } } + fn visit_f_down(&mut self, node: &Expr) { + if let Expr::Lambda(lambda) = node { + self.lambdas_params + .extend(lambda.params.iter().cloned()); + } + } + + fn visit_f_up(&mut self, node: &Expr) { + if let Expr::Lambda(lambda) = node { + for param in &lambda.params { + self.lambdas_params.remove(param); + } + } + } + fn is_valid(node: &Expr) -> bool { !node.is_volatile_node() } fn is_ignored(&self, node: &Expr) -> bool { + if matches!(node, Expr::Column(c) if c.is_lambda_parameter(&self.lambdas_params)) { + return true + } + // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] let is_normal_minus_aggregates = matches!( diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 63236787743a4..0f43741834009 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -527,18 +527,17 @@ fn proj_exprs_evaluation_result_on_empty_batch( for expr in proj_expr.iter() { let result_expr = expr .clone() - .transform_up(|expr| { - if let Expr::Column(Column { name, .. 
}) = &expr { + .transform_up_with_lambdas_params(|expr, lambdas_params| match &expr { + Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { if let Some(result_expr) = - input_expr_result_map_for_count_bug.get(name) + input_expr_result_map_for_count_bug.get(col.name()) { Ok(Transformed::yes(result_expr.clone())) } else { Ok(Transformed::no(expr)) } - } else { - Ok(Transformed::no(expr)) } + _ => Ok(Transformed::no(expr)), }) .data()?; @@ -570,16 +569,17 @@ fn filter_exprs_evaluation_result_on_empty_batch( ) -> Result> { let result_expr = filter_expr .clone() - .transform_up(|expr| { - if let Expr::Column(Column { name, .. }) = &expr { - if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { + .transform_up_with_lambdas_params(|expr, lambdas_params| match &expr { + Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { + if let Some(result_expr) = + input_expr_result_map_for_count_bug.get(col.name()) + { Ok(Transformed::yes(result_expr.clone())) } else { Ok(Transformed::no(expr)) } - } else { - Ok(Transformed::no(expr)) } + _ => Ok(Transformed::no(expr)), }) .data()?; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 5db71417bc8fd..f0187b618ccc0 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -639,7 +639,7 @@ fn is_expr_trivial(expr: &Expr) -> bool { /// --Source(a, b) /// ``` fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { - expr.transform_up(|expr| { + expr.transform_up_with_lambdas_params(|expr, lambdas_params| { match expr { // remove any intermediate aliases if they do not carry metadata Expr::Alias(alias) => { @@ -653,7 +653,7 @@ fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { false => Ok(Transformed::no(Expr::Alias(alias))), } } - Expr::Column(col) => { + Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { // Find index 
of column: let idx = input.schema.index_of_column(&col)?; // get the corresponding unaliased input expression diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 1c0790b3e3acd..54cb026543270 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -293,7 +293,8 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Expr::AggregateFunction(_) | Expr::WindowFunction(_) | Expr::Wildcard { .. } - | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), + | Expr::GroupingSet(_) + | Expr::Lambda { .. } => internal_err!("Unsupported predicate type"), })?; Ok(is_evaluate) } @@ -1389,14 +1390,15 @@ pub fn replace_cols_by_name( e: Expr, replace_map: &HashMap, ) -> Result { - e.transform_up(|expr| { - Ok(if let Expr::Column(c) = &expr { - match replace_map.get(&c.flat_name()) { - Some(new_c) => Transformed::yes(new_c.clone()), - None => Transformed::no(expr), + e.transform_up_with_lambdas_params(|expr, lambdas_params| { + Ok(match &expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { + match replace_map.get(&c.flat_name()) { + Some(new_c) => Transformed::yes(new_c.clone()), + None => Transformed::no(expr), + } } - } else { - Transformed::no(expr) + _ => Transformed::no(expr), }) }) .data() @@ -1405,17 +1407,18 @@ pub fn replace_cols_by_name( /// check whether the expression uses the columns in `check_map`. 
fn contain(e: &Expr, check_map: &HashMap) -> bool { let mut is_contain = false; - e.apply(|expr| { - Ok(if let Expr::Column(c) = &expr { - match check_map.get(&c.flat_name()) { - Some(_) => { - is_contain = true; - TreeNodeRecursion::Stop + e.apply_with_lambdas_params(|expr, lambdas_params| { + Ok(match &expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { + match check_map.get(&c.flat_name()) { + Some(_) => { + is_contain = true; + TreeNodeRecursion::Stop + } + None => TreeNodeRecursion::Continue, } - None => TreeNodeRecursion::Continue, } - } else { - TreeNodeRecursion::Continue + _ => TreeNodeRecursion::Continue, }) }) .unwrap(); diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 48d1182527013..f1e619750f9c8 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -106,17 +106,22 @@ impl OptimizerRule for ScalarSubqueryToJoin { { if !expr_check_map.is_empty() { rewrite_expr = rewrite_expr - .transform_up(|expr| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = expr - .try_as_col() - .and_then(|col| expr_check_map.get(&col.name)) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }) + .transform_up_with_lambdas_params( + |expr, lambdas_params| { + // replace column references with entry in map, if it exists + if let Some(map_expr) = expr + .try_as_col() + .filter(|c| { + !c.is_lambda_parameter(lambdas_params) + }) + .and_then(|col| expr_check_map.get(&col.name)) + { + Ok(Transformed::yes(map_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }, + ) .data()?; } cur_input = optimized_subquery; @@ -171,18 +176,26 @@ impl OptimizerRule for ScalarSubqueryToJoin { { let new_expr = rewrite_expr .clone() - .transform_up(|expr| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = - 
expr.try_as_col().and_then(|col| { - expr_check_map.get(&col.name) - }) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }) + .transform_up_with_lambdas_params( + |expr, lambdas_params| { + // replace column references with entry in map, if it exists + if let Some(map_expr) = expr + .try_as_col() + .filter(|c| { + !c.is_lambda_parameter( + lambdas_params, + ) + }) + .and_then(|col| { + expr_check_map.get(&col.name) + }) + { + Ok(Transformed::yes(map_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }, + ) .data()?; expr_to_rewrite_expr_map.insert(expr, new_expr); } @@ -396,8 +409,12 @@ fn build_join( let mut expr_rewrite = TypeCoercionRewriter { schema: new_plan.schema(), }; - computation_project_expr - .insert(name, computer_expr.rewrite(&mut expr_rewrite).data()?); + computation_project_expr.insert( + name, + computer_expr + .rewrite_with_schema(new_plan.schema(), &mut expr_rewrite) + .data()?, + ); } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 05b8c28fadd6c..a824f6b7be49f 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -17,27 +17,30 @@ //! 
Expression simplification API +use std::collections::HashSet; +use std::ops::Not; +use std::{borrow::Cow, sync::Arc}; + use arrow::{ array::{new_null_array, AsArray}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use std::borrow::Cow; -use std::collections::HashSet; -use std::ops::Not; -use std::sync::Arc; +use datafusion_common::tree_node::TreeNodeContainer; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, metadata::FieldMetadata, - tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, + tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + }, }; use datafusion_common::{ exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, - Operator, Volatility, + and, binary::BinaryTypeCoercer, lit, or, simplify::SimplifyContext, BinaryExpr, Case, + ColumnarValue, Expr, Like, Operator, Volatility, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ @@ -267,7 +270,7 @@ impl ExprSimplifier { /// documentation for more details on type coercion pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite).data() + expr.rewrite_with_schema(schema, &mut expr_rewrite).data() } /// Input guarantees about the values of columns. @@ -649,7 +652,8 @@ impl<'a> ConstEvaluator<'a> { | Expr::WindowFunction { .. } | Expr::GroupingSet(_) | Expr::Wildcard { .. } - | Expr::Placeholder(_) => false, + | Expr::Placeholder(_) + | Expr::Lambda { .. } => false, Expr::ScalarFunction(ScalarFunction { func, .. 
}) => { Self::volatility_ok(func.signature().volatility) } @@ -754,6 +758,89 @@ impl<'a, S> Simplifier<'a, S> { impl TreeNodeRewriter for Simplifier<'_, S> { type Node = Expr; + fn f_down(&mut self, expr: Self::Node) -> Result> { + match expr { + Expr::ScalarFunction(ScalarFunction { func, args }) + if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => + { + // there's currently no way to adapt a generic SimplifyInfo with lambda parameters, + // so, if the scalar function has any lambda, we materialize a DFSchema using all the + // columns references in every arguments. Than we can call lambdas_schemas_from_args, + // and for each argument, we create a new SimplifyContext with the scoped schema, and + // simplify the argument using this 'sub-context'. Finally, we set Transformed.tnr to + // Jump so the parent context doesn't try to simplify the argument again, without the + // parameters info + + // get all columns references + let mut columns_refs = HashSet::new(); + + for arg in &args { + arg.add_column_refs(&mut columns_refs); + } + + // materialize columns references into qualified fields + let qualified_fields = columns_refs + .into_iter() + .map(|captured_column| { + let expr = Expr::Column(captured_column.clone()); + + Ok(( + captured_column.relation.clone(), + Arc::new(Field::new( + captured_column.name(), + self.info.get_data_type(&expr)?, + self.info.nullable(&expr)?, + )), + )) + }) + .collect::>()?; + + // create a schema using the materialized fields + let dfschema = + DFSchema::new_with_metadata(qualified_fields, Default::default())?; + + let mut scoped_schemas = func + .arguments_schema_from_logical_args(&args, &dfschema)? 
+ .into_iter(); + + let transformed_args = args + .map_elements(|arg| { + let scoped_schema = scoped_schemas.next().unwrap(); + + // create a sub-context, using the scoped schema, that includes information about the lambda parameters + let simplify_context = + SimplifyContext::new(self.info.execution_props()) + .with_schema(Arc::new(scoped_schema.into_owned())); + + let mut simplifier = Simplifier::new(&simplify_context); + + // simplify the argument using it's context + arg.rewrite(&mut simplifier) + })? + .update_data(|args| { + Expr::ScalarFunction(ScalarFunction { func, args }) + }); + + Ok(Transformed::new( + transformed_args.data, + transformed_args.transformed, + // return at least Jump so the parent contex doesn't try again to simplify the arguments + // (and fail because it doesn't contain info about lambdas paramters) + match transformed_args.tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + TreeNodeRecursion::Jump + } + TreeNodeRecursion::Stop => TreeNodeRecursion::Stop, + }, + )) + + // Ok(transformed_args.update_data(|args| Expr::ScalarFunction(ScalarFunction { func, args}))) + } + // Expr::Lambda(_) => Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)), + _ => Ok(Transformed::no(expr)), + } + } + /// rewrite the expression simplifying any constant expressions fn f_up(&mut self, expr: Expr) -> Result> { use datafusion_expr::Operator::{ diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 81763fa0552fb..d0ae4932628f3 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -23,7 +23,7 @@ use crate::analyzer::type_coercion::TypeCoercionRewriter; use arrow::array::{new_null_array, Array, RecordBatch}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::cast::as_boolean_array; -use datafusion_common::tree_node::{TransformedResult, TreeNode}; +use datafusion_common::tree_node::TransformedResult; use datafusion_common::{Column, DFSchema, 
Result, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::replace_col; @@ -148,7 +148,7 @@ fn evaluate_expr_with_null_column<'a>( fn coerce(expr: Expr, schema: &DFSchema) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite).data() + expr.rewrite_with_schema(schema, &mut expr_rewrite).data() } #[cfg(test)] diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 61cc97dae300e..4a81a5c99ac75 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -21,12 +21,14 @@ use std::sync::Arc; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; +use datafusion_common::HashSet; use datafusion_common::{ exec_err, - tree_node::{Transformed, TransformedResult, TreeNode}, + tree_node::{Transformed, TransformedResult}, Result, ScalarValue, }; use datafusion_functions::core::getfield::GetFieldFunc; +use datafusion_physical_expr::PhysicalExprExt; use datafusion_physical_expr::{ expressions::{self, CastExpr, Column}, ScalarFunctionExpr, @@ -217,8 +219,10 @@ impl PhysicalExprAdapter for DefaultPhysicalExprAdapter { physical_file_schema: &self.physical_file_schema, partition_fields: &self.partition_values, }; - expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr))) - .data() + expr.transform_with_lambdas_params(|expr, lambdas_params| { + rewriter.rewrite_expr(Arc::clone(&expr), lambdas_params) + }) + .data() } fn with_partition_values( @@ -242,13 +246,18 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { fn rewrite_expr( &self, expr: Arc, + lambdas_params: &HashSet, ) -> Result>> { - if let Some(transformed) = self.try_rewrite_struct_field_access(&expr)? { + if let Some(transformed) = + self.try_rewrite_struct_field_access(&expr, lambdas_params)? 
+ { return Ok(Transformed::yes(transformed)); } if let Some(column) = expr.as_any().downcast_ref::() { - return self.rewrite_column(Arc::clone(&expr), column); + if !lambdas_params.contains(column.name()) { + return self.rewrite_column(Arc::clone(&expr), column); + } } Ok(Transformed::no(expr)) @@ -260,6 +269,7 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { fn try_rewrite_struct_field_access( &self, expr: &Arc, + lambdas_params: &HashSet, ) -> Result>> { let get_field_expr = match ScalarFunctionExpr::try_downcast_func::(expr.as_ref()) { @@ -291,8 +301,8 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { }; let column = match source_expr.as_any().downcast_ref::() { - Some(column) => column, - None => return Ok(None), + Some(column) if !lambdas_params.contains(column.name()) => column, + _ => return Ok(None), }; let physical_field = @@ -446,6 +456,7 @@ mod tests { use super::*; use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::hashbrown::HashSet; use datafusion_common::{assert_contains, record_batch, Result, ScalarValue}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{col, lit, CastExpr, Column, Literal}; @@ -852,7 +863,9 @@ mod tests { // Test that when a field exists in physical schema, it returns None let column = Arc::new(Column::new("struct_col", 0)) as Arc; - let result = rewriter.try_rewrite_struct_field_access(&column).unwrap(); + let result = rewriter + .try_rewrite_struct_field_access(&column, &HashSet::new()) + .unwrap(); assert!(result.is_none()); // The actual test for the get_field expression would require creating a proper ScalarFunctionExpr diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index b7654a0f6f603..d4c0e1cbe6eb7 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -37,6 +37,9 @@ workspace = true [lib] name = "datafusion_physical_expr" 
+[features] +recursive_protection = ["dep:recursive"] + [dependencies] ahash = { workspace = true } arrow = { workspace = true } @@ -52,6 +55,7 @@ itertools = { workspace = true, features = ["use_std"] } parking_lot = { workspace = true } paste = "^1.0" petgraph = "0.8.3" +recursive = { workspace = true, optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index b434694a20cc8..a34d3cda47682 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -168,6 +168,7 @@ impl AsyncFuncExpr { number_rows: current_batch.num_rows(), return_field: Arc::clone(&self.return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .await?, ); @@ -187,6 +188,7 @@ impl AsyncFuncExpr { number_rows: batch.num_rows(), return_field: Arc::clone(&self.return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .await?, ); diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 9ca464b304306..c55f42ae333bc 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -22,12 +22,13 @@ use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; +use crate::PhysicalExprExt; use arrow::datatypes::FieldRef; use arrow::{ datatypes::{DataType, Schema, SchemaRef}, record_batch::RecordBatch, }; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult}; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::ColumnarValue; @@ -67,7 +68,8 @@ use datafusion_expr::ColumnarValue; pub struct Column { /// The name of the column (used for debugging and display purposes) name: String, - /// The index of 
the column in its schema + /// The index of the column in its schema. + /// Within a lambda body, this refer to the lambda scoped schema, not the plan schema. index: usize, } @@ -178,9 +180,9 @@ pub fn with_new_schema( expr: Arc, schema: &SchemaRef, ) -> Result> { - Ok(expr - .transform_up(|expr| { - if let Some(col) = expr.as_any().downcast_ref::() { + expr.transform_up_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { let idx = col.index(); let Some(field) = schema.fields().get(idx) else { return plan_err!( @@ -188,12 +190,13 @@ pub fn with_new_schema( ); }; let new_col = Column::new(field.name(), idx); + Ok(Transformed::yes(Arc::new(new_col) as _)) - } else { - Ok(Transformed::no(expr)) } - })? - .data) + _ => Ok(Transformed::no(expr)), + } + }) + .data() } #[cfg(test)] diff --git a/datafusion/physical-expr/src/expressions/lambda.rs b/datafusion/physical-expr/src/expressions/lambda.rs new file mode 100644 index 0000000000000..55110fdf5bf6b --- /dev/null +++ b/datafusion/physical-expr/src/expressions/lambda.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Physical column reference: [`Column`] + +use std::hash::Hash; +use std::sync::Arc; +use std::{any::Any, sync::OnceLock}; + +use crate::expressions::Column; +use crate::physical_expr::PhysicalExpr; +use crate::PhysicalExprExt; +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{internal_err, HashSet, Result}; +use datafusion_expr::ColumnarValue; + +/// Represents a lambda with the given parameters name and body +#[derive(Debug, Eq, Clone)] +pub struct LambdaExpr { + params: Vec, + body: Arc, + captures: OnceLock>, +} + +// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 [https://github.com/apache/datafusion/issues/13196] +impl PartialEq for LambdaExpr { + fn eq(&self, other: &Self) -> bool { + self.params.eq(&other.params) && self.body.eq(&other.body) + } +} + +impl Hash for LambdaExpr { + fn hash(&self, state: &mut H) { + self.params.hash(state); + self.body.hash(state); + } +} + +impl LambdaExpr { + /// Create a new lambda expression with the given parameters and body + pub fn new(params: Vec, body: Arc) -> Self { + Self { + params, + body, + captures: OnceLock::new(), + } + } + + /// Get the lambda's params names + pub fn params(&self) -> &[String] { + &self.params + } + + /// Get the lambda's body + pub fn body(&self) -> &Arc { + &self.body + } + + pub fn captures(&self) -> &HashSet { + self.captures.get_or_init(|| { + let mut indices = HashSet::new(); + + self.body + .apply_with_lambdas_params(|expr, lambdas_params| { + if let Some(column) = expr.as_any().downcast_ref::() { + if !lambdas_params.contains(column.name()) { + indices.insert(column.index()); + } + } + + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + + indices + }) + } +} + +impl std::fmt::Display for LambdaExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "({}) -> {}", self.params.join(", "), self.body) + 
} +} + +impl PhysicalExpr for LambdaExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::Null) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(true) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + internal_err!("Lambda::evaluate() should not be called") + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.body] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(Self { + params: self.params.clone(), + body: Arc::clone(&children[0]), + captures: OnceLock::new(), + })) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({}) -> {}", self.params.join(", "), self.body) + } +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 59d675753d985..e87941da5ef4c 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -27,6 +27,7 @@ mod dynamic_filters; mod in_list; mod is_not_null; mod is_null; +mod lambda; mod like; mod literal; mod negative; @@ -49,6 +50,7 @@ pub use dynamic_filters::DynamicFilterPhysicalExpr; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; +pub use lambda::LambdaExpr; pub use like::{like, LikeExpr}; pub use literal::{lit, Literal}; pub use negative::{negative, NegativeExpr}; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index aa8c9e50fd71e..873205f28bef4 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -70,6 +70,8 @@ pub use datafusion_physical_expr_common::sort_expr::{ pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; +pub use scalar_function::PhysicalExprExt; + pub use simplifier::PhysicalExprSimplifier; pub use 
utils::{conjunction, conjunction_opt, split_conjunction}; diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index c658a8eddc233..2584fc22885c2 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -18,11 +18,11 @@ use std::sync::Arc; use crate::expressions::{self, Column}; -use crate::{create_physical_expr, LexOrdering, PhysicalSortExpr}; +use crate::{create_physical_expr, LexOrdering, PhysicalExprExt, PhysicalSortExpr}; use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult}; use datafusion_common::{plan_err, Result}; use datafusion_common::{DFSchema, HashMap}; use datafusion_expr::execution_props::ExecutionProps; @@ -38,14 +38,14 @@ pub fn add_offset_to_expr( expr: Arc, offset: isize, ) -> Result> { - expr.transform_down(|e| match e.as_any().downcast_ref::() { - Some(col) => { + expr.transform_down_with_lambdas_params(|e, lambdas_params| match e.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { let Some(idx) = col.index().checked_add_signed(offset) else { return plan_err!("Column index overflow"); }; Ok(Transformed::yes(Arc::new(Column::new(col.name(), idx)))) } - None => Ok(Transformed::no(e)), + _ => Ok(Transformed::no(e)), }) .data() } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 7790380dffd56..0119c81b8ed94 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,6 +17,7 @@ use std::sync::Arc; +use crate::expressions::LambdaExpr; use crate::ScalarFunctionExpr; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, @@ -30,7 +31,7 @@ use datafusion_common::{ exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, 
ToDFSchema, }; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; +use datafusion_expr::expr::{Alias, Cast, InList, Lambda, Placeholder, ScalarFunction}; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ @@ -104,7 +105,8 @@ use datafusion_expr::{ /// /// * `e` - The logical expression /// * `input_dfschema` - The DataFusion schema for the input, used to resolve `Column` references -/// to qualified or unqualified fields by name. +/// to qualified or unqualified fields by name. Note that for creating a lambda, this must be +/// scoped lambda schema, and not the outer schema pub fn create_physical_expr( e: &Expr, input_dfschema: &DFSchema, @@ -314,9 +316,28 @@ pub fn create_physical_expr( input_dfschema, execution_props, )?), + Expr::Lambda { .. } => { + exec_err!("Expr::Lambda should be handled by Expr::ScalarFunction, as it can only exist within it") + } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let physical_args = - create_physical_exprs(args, input_dfschema, execution_props)?; + let lambdas_schemas = + func.arguments_schema_from_logical_args(args, input_dfschema)?; + + let physical_args = std::iter::zip(args, lambdas_schemas) + .map(|(expr, schema)| match expr { + Expr::Lambda(Lambda { params, body }) => { + Ok(Arc::new(LambdaExpr::new( + params.clone(), + create_physical_expr(body, &schema, execution_props)?, + )) as Arc) + } + expr => create_physical_expr(expr, &schema, execution_props), + }) + .collect::>>()?; + + //let physical_args = + // create_physical_exprs(args, input_dfschema, execution_props)?; + let config_options = match execution_props.config_options.as_ref() { Some(config_options) => Arc::clone(config_options), None => Arc::new(ConfigOptions::default()), diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index 
a120ab427e1de..70be717a8436c 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -20,11 +20,11 @@ use std::sync::Arc; use crate::expressions::Column; use crate::utils::collect_columns; -use crate::PhysicalExpr; +use crate::{PhysicalExpr, PhysicalExprExt}; use arrow::datatypes::{Field, Schema, SchemaRef}; use datafusion_common::stats::{ColumnStatistics, Precision}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult}; use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -499,13 +499,16 @@ pub fn update_expr( let mut state = RewriteState::Unchanged; let new_expr = Arc::clone(expr) - .transform_up(|expr| { + .transform_up_with_lambdas_params(|expr, lambdas_params| { if state == RewriteState::RewrittenInvalid { return Ok(Transformed::no(expr)); } - let Some(column) = expr.as_any().downcast_ref::() else { - return Ok(Transformed::no(expr)); + let column = match expr.as_any().downcast_ref::() { + Some(column) if !lambdas_params.contains(column.name()) => column, + _ => { + return Ok(Transformed::no(expr)); + } }; if sync_with_child { state = RewriteState::RewrittenValid; @@ -616,14 +619,14 @@ impl ProjectionMapping { let mut map = IndexMap::<_, ProjectionTargets>::new(); for (expr_idx, (expr, name)) in expr.into_iter().enumerate() { let target_expr = Arc::new(Column::new(&name, expr_idx)) as _; - let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::() { + let source_expr = expr.transform_down_with_schema(input_schema, |e, schema| match e.as_any().downcast_ref::() { Some(col) => { - // Sometimes, an expression and its name in the input_schema + // Sometimes, an expression and its name in the schema // doesn't match. 
This can cause problems, so we make sure - // that the expression name matches with the name in `input_schema`. + // that the expression name matches with the name in `schema`. // Conceptually, `source_expr` and `expression` should be the same. let idx = col.index(); - let matching_field = input_schema.field(idx); + let matching_field = schema.field(idx); let matching_name = matching_field.name(); if col.name() != matching_name { return internal_err!( @@ -737,21 +740,25 @@ pub fn project_ordering( ) -> Option { let mut projected_exprs = vec![]; for PhysicalSortExpr { expr, options } in ordering.iter() { - let transformed = Arc::clone(expr).transform_up(|expr| { - let Some(col) = expr.as_any().downcast_ref::() else { - return Ok(Transformed::no(expr)); - }; + let transformed = + Arc::clone(expr).transform_up_with_lambdas_params(|expr, lambdas_params| { + let col = match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => col, + _ => { + return Ok(Transformed::no(expr)); + } + }; - let name = col.name(); - if let Some((idx, _)) = schema.column_with_name(name) { - // Compute the new column expression (with correct index) after projection: - Ok(Transformed::yes(Arc::new(Column::new(name, idx)))) - } else { - // Cannot find expression in the projected_schema, - // signal this using an Err result - plan_err!("") - } - }); + let name = col.name(); + if let Some((idx, _)) = schema.column_with_name(name) { + // Compute the new column expression (with correct index) after projection: + Ok(Transformed::yes(Arc::new(Column::new(name, idx)))) + } else { + // Cannot find expression in the projected_schema, + // signal this using an Err result + plan_err!("") + } + }); match transformed { Ok(transformed) => { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 743d5b99cde95..22fa300f05df4 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ 
b/datafusion/physical-expr/src/scalar_function.rs @@ -30,23 +30,25 @@ //! to a function that supports f64, it is coerced to f64. use std::any::Any; +use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::expressions::Literal; +use crate::expressions::{Column, LambdaExpr, Literal}; use crate::PhysicalExpr; -use arrow::array::{Array, RecordBatch}; -use arrow::datatypes::{DataType, FieldRef, Schema}; +use arrow::array::{Array, NullArray, RecordBatch}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use datafusion_common::config::{ConfigEntry, ConfigOptions}; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::{internal_err, HashSet, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{ - expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, - Volatility, + expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, + ScalarFunctionLambdaArg, ScalarUDF, ValueOrLambdaParameter, Volatility, }; /// Physical expression of a scalar function @@ -94,10 +96,16 @@ impl ScalarFunctionExpr { schema: &Schema, config_options: Arc, ) -> Result { - let name = fun.name().to_string(); - let arg_fields = args - .iter() - .map(|e| e.return_field(schema)) + let lambdas_schemas = lambdas_schemas_from_args(&fun, &args, schema)?; + + let arg_fields = std::iter::zip(&args, lambdas_schemas) + .map(|(e, schema)| { + if let Some(lambda) = e.as_any().downcast_ref::() { + lambda.body().return_field(&schema) + } else { + e.return_field(&schema) + } + }) .collect::>>()?; // verify that input data types is consistent with function's `TypeSignature` @@ -105,6 +113,7 @@ impl ScalarFunctionExpr { .iter() .map(|f| 
f.data_type().clone()) .collect::>(); + data_types_with_scalar_udf(&arg_types, &fun)?; let arguments = args @@ -115,11 +124,21 @@ impl ScalarFunctionExpr { .map(|literal| literal.value()) }) .collect::>(); + + let lambdas = args + .iter() + .map(|e| e.as_any().is::()) + .collect::>(); + let ret_args = ReturnFieldArgs { arg_fields: &arg_fields, scalar_arguments: &arguments, + lambdas: &lambdas, }; + let return_field = fun.return_field_from_args(ret_args)?; + let name = fun.name().to_string(); + Ok(Self { fun, name, @@ -260,7 +279,10 @@ impl PhysicalExpr for ScalarFunctionExpr { let args = self .args .iter() - .map(|e| e.evaluate(batch)) + .map(|e| match e.as_any().downcast_ref::() { + Some(_) => Ok(ColumnarValue::Scalar(ScalarValue::Null)), + None => Ok(e.evaluate(batch)?), + }) .collect::>>()?; let arg_fields = self @@ -274,6 +296,111 @@ impl PhysicalExpr for ScalarFunctionExpr { .iter() .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + let lambdas = if self.args.iter().any(|arg| arg.as_any().is::()) { + let args_metadata = std::iter::zip(&self.args, &arg_fields) + .map( + |(expr, field)| match expr.as_any().downcast_ref::() { + Some(lambda) => { + let mut captures = false; + + expr.apply_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { + captures = true; + + Ok(TreeNodeRecursion::Stop) + } + _ => Ok(TreeNodeRecursion::Continue), + } + }) + .unwrap(); + + ValueOrLambdaParameter::Lambda(lambda.params(), captures) + } + None => ValueOrLambdaParameter::Value(Arc::clone(field)), + }, + ) + .collect::>(); + + let params = self.fun().inner().lambdas_parameters(&args_metadata)?; + + let lambdas = std::iter::zip(&self.args, params) + .map(|(arg, lambda_params)| { + arg.as_any() + .downcast_ref::() + .map(|lambda| { + let mut indices = HashSet::new(); + + arg.apply_with_lambdas_params(|expr, lambdas_params| { + if let Some(column) = + expr.as_any().downcast_ref::() + { + if 
!lambdas_params.contains(column.name()) { + indices.insert( + column.index(), //batch + // .schema_ref() + // .index_of(column.name())?, + ); + } + } + + Ok(TreeNodeRecursion::Continue) + })?; + + //let mut indices = indices.into_iter().collect::>(); + + //indices.sort_unstable(); + + let params = + std::iter::zip(lambda.params(), lambda_params.unwrap()) + .map(|(name, param)| Arc::new(param.with_name(name))) + .collect(); + + let captures = if !indices.is_empty() { + let (fields, columns): (Vec<_>, _) = std::iter::zip( + batch.schema_ref().fields(), + batch.columns(), + ) + .enumerate() + .map(|(column_index, (field, column))| { + if indices.contains(&column_index) { + (Arc::clone(field), Arc::clone(column)) + } else { + ( + Arc::new(Field::new( + field.name(), + DataType::Null, + false, + )), + Arc::new(NullArray::new(column.len())) as _, + ) + } + }) + .unzip(); + + let schema = Arc::new(Schema::new(fields)); + + Some(RecordBatch::try_new(schema, columns)?) + //Some(batch.project(&indices)?) 
+ } else { + None + }; + + Ok(ScalarFunctionLambdaArg { + params, + body: Arc::clone(lambda.body()), + captures, + }) + }) + .transpose() + }) + .collect::>>()?; + + Some(lambdas) + } else { + None + }; + // evaluate the function let output = self.fun.invoke_with_args(ScalarFunctionArgs { args, @@ -281,6 +408,7 @@ impl PhysicalExpr for ScalarFunctionExpr { number_rows: batch.num_rows(), return_field: Arc::clone(&self.return_field), config_options: Arc::clone(&self.config_options), + lambdas, })?; if let ColumnarValue::Array(array) = &output { @@ -365,14 +493,377 @@ impl PhysicalExpr for ScalarFunctionExpr { } } +pub fn lambdas_schemas_from_args<'a>( + fun: &ScalarUDF, + args: &[Arc], + schema: &'a Schema, +) -> Result>> { + let args_metadata = args + .iter() + .map(|e| match e.as_any().downcast_ref::() { + Some(lambda) => { + let mut captures = false; + + e.apply_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { + captures = true; + + Ok(TreeNodeRecursion::Stop) + } + _ => Ok(TreeNodeRecursion::Continue), + } + }) + .unwrap(); + + Ok(ValueOrLambdaParameter::Lambda(lambda.params(), captures)) + } + None => Ok(ValueOrLambdaParameter::Value(e.return_field(schema)?)), + }) + .collect::>>()?; + + /*let captures = args + .iter() + .map(|arg| { + if arg.as_any().is::() { + let mut columns = HashSet::new(); + + arg.apply_with_lambdas_params(|n, lambdas_params| { + if let Some(column) = n.as_any().downcast_ref::() { + if !lambdas_params.contains(column.name()) { + columns.insert(schema.index_of(column.name())?); + } + // columns.insert(column.index()); + } + + Ok(TreeNodeRecursion::Continue) + })?; + + Ok(columns) + } else { + Ok(HashSet::new()) + } + }) + .collect::>>()?; */ + + fun.arguments_arrow_schema(&args_metadata, schema) +} + +pub trait PhysicalExprExt: Sized { + fn apply_with_lambdas_params< + 'n, + F: FnMut(&'n Self, &HashSet<&'n str>) -> Result, + >( + &'n self, + f: F, 
+ ) -> Result; + + fn apply_with_schema<'n, F: FnMut(&'n Self, &Schema) -> Result>( + &'n self, + schema: &Schema, + f: F, + ) -> Result; + + fn apply_children_with_schema< + 'n, + F: FnMut(&'n Self, &Schema) -> Result, + >( + &'n self, + schema: &Schema, + f: F, + ) -> Result; + + fn transform_down_with_schema Result>>( + self, + schema: &Schema, + f: F, + ) -> Result>; + + fn transform_up_with_schema Result>>( + self, + schema: &Schema, + f: F, + ) -> Result>; + + fn transform_with_schema Result>>( + self, + schema: &Schema, + f: F, + ) -> Result> { + self.transform_up_with_schema(schema, f) + } + + fn transform_down_with_lambdas_params( + self, + f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result>; + + fn transform_up_with_lambdas_params( + self, + f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result>; + + fn transform_with_lambdas_params( + self, + f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result> { + self.transform_up_with_lambdas_params(f) + } +} + +impl PhysicalExprExt for Arc { + fn apply_with_lambdas_params< + 'n, + F: FnMut(&'n Self, &HashSet<&'n str>) -> Result, + >( + &'n self, + mut f: F, + ) -> Result { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn apply_with_lambdas_params_impl< + 'n, + F: FnMut( + &'n Arc, + &HashSet<&'n str>, + ) -> Result, + >( + node: &'n Arc, + args: &HashSet<&'n str>, + f: &mut F, + ) -> Result { + match node.as_any().downcast_ref::() { + Some(lambda) => { + let mut args = args.clone(); + + args.extend(lambda.params().iter().map(|v| v.as_str())); + + f(node, &args)?.visit_children(|| { + node.apply_children(|c| { + apply_with_lambdas_params_impl(c, &args, f) + }) + }) + } + _ => f(node, args)?.visit_children(|| { + node.apply_children(|c| apply_with_lambdas_params_impl(c, args, f)) + }), + } + } + + apply_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + fn apply_with_schema<'n, F: FnMut(&'n Self, &Schema) -> Result>( + &'n self, + schema: &Schema, + mut f: F, + ) -> 
Result { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn apply_with_lambdas_impl< + 'n, + F: FnMut(&'n Arc, &Schema) -> Result, + >( + node: &'n Arc, + schema: &Schema, + f: &mut F, + ) -> Result { + f(node, schema)?.visit_children(|| { + node.apply_children_with_schema(schema, |c, schema| { + apply_with_lambdas_impl(c, schema, f) + }) + }) + } + + apply_with_lambdas_impl(self, schema, &mut f) + } + + fn apply_children_with_schema< + 'n, + F: FnMut(&'n Self, &Schema) -> Result, + >( + &'n self, + schema: &Schema, + mut f: F, + ) -> Result { + match self.as_any().downcast_ref::() { + Some(scalar_function) + if scalar_function + .args() + .iter() + .any(|arg| arg.as_any().is::()) => + { + let mut lambdas_schemas = lambdas_schemas_from_args( + scalar_function.fun(), + scalar_function.args(), + schema, + )? + .into_iter(); + + self.apply_children(|expr| f(expr, &lambdas_schemas.next().unwrap())) + } + _ => self.apply_children(|e| f(e, schema)), + } + } + + fn transform_down_with_schema< + F: FnMut(Self, &Schema) -> Result>, + >( + self, + schema: &Schema, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_down_with_schema_impl< + F: FnMut( + Arc, + &Schema, + ) -> Result>>, + >( + node: Arc, + schema: &Schema, + f: &mut F, + ) -> Result>> { + f(node, schema)?.transform_children(|node| { + map_children_with_schema(node, schema, |n, schema| { + transform_down_with_schema_impl(n, schema, f) + }) + }) + } + + transform_down_with_schema_impl(self, schema, &mut f) + } + + fn transform_up_with_schema Result>>( + self, + schema: &Schema, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_up_with_schema_impl< + F: FnMut( + Arc, + &Schema, + ) -> Result>>, + >( + node: Arc, + schema: &Schema, + f: &mut F, + ) -> Result>> { + map_children_with_schema(node, schema, |n, schema| { + transform_up_with_schema_impl(n, schema, f) + })? 
+ .transform_parent(|n| f(n, schema)) + } + + transform_up_with_schema_impl(self, schema, &mut f) + } + + fn transform_up_with_lambdas_params( + self, + mut f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_up_with_lambdas_params_impl< + F: FnMut( + Arc, + &HashSet, + ) -> Result>>, + >( + node: Arc, + params: &HashSet, + f: &mut F, + ) -> Result>> { + map_children_with_lambdas_params(node, params, |n, params| { + transform_up_with_lambdas_params_impl(n, params, f) + })? + .transform_parent(|n| f(n, params)) + } + + transform_up_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + fn transform_down_with_lambdas_params( + self, + mut f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_down_with_lambdas_params_impl< + F: FnMut( + Arc, + &HashSet, + ) -> Result>>, + >( + node: Arc, + params: &HashSet, + f: &mut F, + ) -> Result>> { + f(node, params)?.transform_children(|node| { + map_children_with_lambdas_params(node, params, |node, args| { + transform_down_with_lambdas_params_impl(node, args, f) + }) + }) + } + + transform_down_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } +} + +fn map_children_with_schema( + node: Arc, + schema: &Schema, + mut f: impl FnMut( + Arc, + &Schema, + ) -> Result>>, +) -> Result>> { + match node.as_any().downcast_ref::() { + Some(fun) if fun.args().iter().any(|arg| arg.as_any().is::()) => { + let mut args_schemas = + lambdas_schemas_from_args(fun.fun(), fun.args(), schema)?.into_iter(); + + node.map_children(|node| f(node, &args_schemas.next().unwrap())) + } + _ => node.map_children(|node| f(node, schema)), + } +} + +fn map_children_with_lambdas_params( + node: Arc, + params: &HashSet, + mut f: impl FnMut( + Arc, + &HashSet, + ) -> Result>>, +) -> Result>> { + match node.as_any().downcast_ref::() { + Some(lambda) => { + let mut 
params = params.clone(); + + params.extend(lambda.params().iter().cloned()); + + node.map_children(|node| f(node, ¶ms)) + } + None => node.map_children(|node| f(node, params)), + } +} + #[cfg(test)] mod tests { + use std::any::Any; + use std::{borrow::Cow, sync::Arc}; + use super::*; + use super::{lambdas_schemas_from_args, PhysicalExprExt}; use crate::expressions::Column; + use crate::{create_physical_expr, ScalarFunctionExpr}; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{tree_node::TreeNodeRecursion, DFSchema, HashSet, Result}; + use datafusion_expr::{ + col, expr::Lambda, Expr, ScalarFunctionArgs, ValueOrLambdaParameter, Volatility, + }; use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature}; + use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr_common::physical_expr::is_volatile; - use std::any::Any; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// Test helper to create a mock UDF with a specific volatility #[derive(Debug, PartialEq, Eq, Hash)] @@ -444,4 +935,190 @@ mod tests { let stable_arc: Arc = Arc::new(stable_expr); assert!(!is_volatile(&stable_arc)); } + + fn list_list_int() -> Schema { + Schema::new(vec![Field::new( + "v", + DataType::new_list(DataType::new_list(DataType::Int32, false), false), + false, + )]) + } + + fn list_int() -> Schema { + Schema::new(vec![Field::new( + "v", + DataType::new_list(DataType::Int32, false), + false, + )]) + } + + fn int() -> Schema { + Schema::new(vec![Field::new("v", DataType::Int32, false)]) + } + + fn array_transform_udf() -> ScalarUDF { + ScalarUDF::new_from_impl(ArrayTransformFunc::new()) + } + + fn args() -> Vec { + vec![ + col("v"), + Expr::Lambda(Lambda::new( + vec!["v".into()], + array_transform_udf().call(vec![ + col("v"), + Expr::Lambda(Lambda::new(vec!["v".into()], -col("v"))), + ]), + )), + ] + } + + // array_transform(v, |v| -> array_transform(v, |v| -> -v)) + fn array_transform() -> Arc { + let e = 
array_transform_udf().call(args()); + + create_physical_expr( + &e, + &DFSchema::try_from(list_list_int()).unwrap(), + &Default::default(), + ) + .unwrap() + } + + #[derive(Debug, PartialEq, Eq, Hash)] + struct ArrayTransformFunc { + signature: Signature, + } + + impl ArrayTransformFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for ArrayTransformFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "array_transform" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + let ValueOrLambdaParameter::Value(value_field) = &args[0] else { + unimplemented!() + }; + let DataType::List(field) = value_field.data_type() else { + unimplemented!() + }; + + Ok(vec![ + None, + Some(vec![Field::new( + "", + field.data_type().clone(), + field.is_nullable(), + )]), + ]) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + unimplemented!() + } + } + + #[test] + fn test_lambdas_schemas_from_args() { + let schema = list_list_int(); + let expr = array_transform(); + + let args = expr + .as_any() + .downcast_ref::() + .unwrap() + .args(); + + let schemas = + lambdas_schemas_from_args(&array_transform_udf(), args, &schema).unwrap(); + + assert_eq!(schemas, &[Cow::Borrowed(&schema), Cow::Owned(list_int())]); + } + + #[test] + fn test_apply_with_schema() { + let mut steps = vec![]; + + array_transform() + .apply_with_schema(&list_list_int(), |node, schema| { + steps.push((node.to_string(), schema.clone())); + + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + + let expected = [ + ( + "array_transform(v@0, (v) -> array_transform(v@0, (v) -> (- v@0)))", + list_list_int(), + ), + ("(v) -> array_transform(v@0, (v) -> (- v@0))", list_int()), + 
("array_transform(v@0, (v) -> (- v@0))", list_int()), + ("(v) -> (- v@0)", int()), + ("(- v@0)", int()), + ("v@0", int()), + ("v@0", int()), + ("v@0", int()), + ] + .map(|(a, b)| (String::from(a), b)); + + assert_eq!(steps, expected); + } + + #[test] + fn test_apply_with_lambdas_params() { + let array_transform = array_transform(); + let mut steps = vec![]; + + array_transform + .apply_with_lambdas_params(|node, params| { + steps.push((node.to_string(), params.clone())); + + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + + let expected = [ + ( + "array_transform(v@0, (v) -> array_transform(v@0, (v) -> (- v@0)))", + HashSet::from(["v"]), + ), + ( + "(v) -> array_transform(v@0, (v) -> (- v@0))", + HashSet::from(["v"]), + ), + ("array_transform(v@0, (v) -> (- v@0))", HashSet::from(["v"])), + ("(v) -> (- v@0)", HashSet::from(["v"])), + ("(- v@0)", HashSet::from(["v"])), + ("v@0", HashSet::from(["v"])), + ("v@0", HashSet::from(["v"])), + ("v@0", HashSet::from(["v"])), + ] + .map(|(a, b)| (String::from(a), b)); + + assert_eq!(steps, expected); + } } diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index 80d6ee0a7b914..dd7e6e314672f 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -19,12 +19,12 @@ use arrow::datatypes::Schema; use datafusion_common::{ - tree_node::{Transformed, TreeNode, TreeNodeRewriter}, + tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, Result, }; use std::sync::Arc; -use crate::PhysicalExpr; +use crate::{PhysicalExpr, PhysicalExprExt}; pub mod unwrap_cast; @@ -48,6 +48,22 @@ impl<'a> PhysicalExprSimplifier<'a> { &mut self, expr: Arc, ) -> Result> { + return expr + .transform_up_with_schema(self.schema, |node, schema| { + // Apply unwrap cast optimization + #[cfg(test)] + let original_type = node.data_type(schema).unwrap(); + let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, schema)?; + 
#[cfg(test)] + assert_eq!( + unwrapped.data.data_type(schema).unwrap(), + original_type, + "Simplified expression should have the same data type as the original" + ); + Ok(unwrapped) + }) + .data(); + Ok(expr.rewrite(self)?.data) } } diff --git a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs index d409ce9cb5bf2..1ccfc1cfe84d8 100644 --- a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs +++ b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs @@ -34,22 +34,22 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{ - tree_node::{Transformed, TreeNode}, - Result, ScalarValue, -}; +use datafusion_common::{tree_node::Transformed, Result, ScalarValue}; use datafusion_expr::Operator; use datafusion_expr_common::casts::try_cast_literal_to_type; -use crate::expressions::{lit, BinaryExpr, CastExpr, Literal, TryCastExpr}; use crate::PhysicalExpr; +use crate::{ + expressions::{lit, BinaryExpr, CastExpr, Literal, TryCastExpr}, + PhysicalExprExt, +}; /// Attempts to unwrap casts in comparison expressions. pub(crate) fn unwrap_cast_in_comparison( expr: Arc, schema: &Schema, ) -> Result>> { - expr.transform_down(|e| { + expr.transform_down_with_schema(schema, |e, schema| { if let Some(binary) = e.as_any().downcast_ref::() { if let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? 
{ return Ok(Transformed::yes(unwrapped)); diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 745ae855efee2..92ecbb7176dc9 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -22,6 +22,7 @@ use std::borrow::Borrow; use std::sync::Arc; use crate::expressions::{BinaryExpr, Column}; +use crate::scalar_function::PhysicalExprExt; use crate::tree_node::ExprContext; use crate::PhysicalExpr; use crate::PhysicalSortExpr; @@ -227,9 +228,11 @@ where /// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`]. pub fn collect_columns(expr: &Arc) -> HashSet { let mut columns = HashSet::::new(); - expr.apply(|expr| { + expr.apply_with_lambdas_params(|expr, lambdas_params| { if let Some(column) = expr.as_any().downcast_ref::() { - columns.get_or_insert_owned(column); + if !lambdas_params.contains(column.name()) { + columns.get_or_insert_owned(column); + } } Ok(TreeNodeRecursion::Continue) }) @@ -251,14 +254,16 @@ pub fn reassign_expr_columns( expr: Arc, schema: &Schema, ) -> Result> { - expr.transform_down(|expr| { + expr.transform_down_with_lambdas_params(|expr, lambdas_params| { if let Some(column) = expr.as_any().downcast_ref::() { - let index = schema.index_of(column.name())?; + if !lambdas_params.contains(column.name()) { + let index = schema.index_of(column.name())?; - return Ok(Transformed::yes(Arc::new(Column::new( - column.name(), - index, - )))); + return Ok(Transformed::yes(Arc::new(Column::new( + column.name(), + index, + )))); + } } Ok(Transformed::no(expr)) }) diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs index 6e4e784866129..d87e001946414 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs @@ -29,7 +29,7 @@ use datafusion_expr::JoinType; use 
datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - add_offset_to_physical_sort_exprs, EquivalenceProperties, + add_offset_to_physical_sort_exprs, EquivalenceProperties, PhysicalExprExt, }; use datafusion_physical_expr_common::sort_expr::{ LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr, @@ -661,20 +661,21 @@ fn handle_custom_pushdown( .into_iter() .map(|req| { let child_schema = plan.children()[maintained_child_idx].schema(); - let updated_columns = req - .expr - .transform_up(|expr| { - if let Some(col) = expr.as_any().downcast_ref::() { - let new_index = col.index() - sub_offset; - Ok(Transformed::yes(Arc::new(Column::new( - child_schema.field(new_index).name(), - new_index, - )))) - } else { - Ok(Transformed::no(expr)) - } - })? - .data; + let updated_columns = + req.expr + .transform_up_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { + let new_index = col.index() - sub_offset; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(new_index).name(), + new_index, + )))) + } + _ => Ok(Transformed::no(expr)), + } + })? + .data; Ok(PhysicalSortRequirement::new(updated_columns, req.options)) }) .collect::>>()?; @@ -742,20 +743,21 @@ fn handle_hash_join( .into_iter() .map(|req| { let child_schema = plan.children()[1].schema(); - let updated_columns = req - .expr - .transform_up(|expr| { - if let Some(col) = expr.as_any().downcast_ref::() { - let index = projected_indices[col.index()].index; - Ok(Transformed::yes(Arc::new(Column::new( - child_schema.field(index).name(), - index, - )))) - } else { - Ok(Transformed::no(expr)) - } - })? 
- .data; + let updated_columns = + req.expr + .transform_up_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { + let index = projected_indices[col.index()].index; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(index).name(), + index, + )))) + } + _ => Ok(Transformed::no(expr)), + } + })? + .data; Ok(PhysicalSortRequirement::new(updated_columns, req.options)) }) .collect::>>()?; diff --git a/datafusion/physical-optimizer/src/projection_pushdown.rs b/datafusion/physical-optimizer/src/projection_pushdown.rs index 987e3cb6f713e..8ed81d3874d64 100644 --- a/datafusion/physical-optimizer/src/projection_pushdown.rs +++ b/datafusion/physical-optimizer/src/projection_pushdown.rs @@ -23,6 +23,7 @@ use crate::PhysicalOptimizerRule; use arrow::datatypes::{Fields, Schema, SchemaRef}; use datafusion_common::alias::AliasGenerator; +use datafusion_physical_expr::PhysicalExprExt; use std::collections::HashSet; use std::sync::Arc; @@ -243,9 +244,11 @@ fn minimize_join_filter( rhs_schema: &Schema, ) -> JoinFilter { let mut used_columns = HashSet::new(); - expr.apply(|expr| { + expr.apply_with_lambdas_params(|expr, lambdas_params| { if let Some(col) = expr.as_any().downcast_ref::() { - used_columns.insert(col.index()); + if !lambdas_params.contains(col.name()) { + used_columns.insert(col.index()); + } } Ok(TreeNodeRecursion::Continue) }) @@ -267,17 +270,19 @@ fn minimize_join_filter( .collect::(); let final_expr = expr - .transform_up(|expr| match expr.as_any().downcast_ref::() { - None => Ok(Transformed::no(expr)), - Some(column) => { - let new_idx = used_columns - .iter() - .filter(|idx| **idx < column.index()) - .count(); - let new_column = Column::new(column.name(), new_idx); - Ok(Transformed::yes( - Arc::new(new_column) as Arc - )) + .transform_up_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(column) if 
!lambdas_params.contains(column.name()) => { + let new_idx = used_columns + .iter() + .filter(|idx| **idx < column.index()) + .count(); + let new_column = Column::new(column.name(), new_idx); + Ok(Transformed::yes( + Arc::new(new_column) as Arc + )) + } + _ => Ok(Transformed::no(expr)), } }) .expect("Closure cannot fail"); @@ -380,10 +385,9 @@ impl<'a> JoinFilterRewriter<'a> { // First, add a new projection. The expression must be rewritten, as it is no longer // executed against the filter schema. let new_idx = self.join_side_projections.len(); - let rewritten_expr = expr.transform_up(|expr| { + let rewritten_expr = expr.transform_up_with_lambdas_params(|expr, lambdas_params| { Ok(match expr.as_any().downcast_ref::() { - None => Transformed::no(expr), - Some(column) => { + Some(column) if !lambdas_params.contains(column.name()) => { let intermediate_column = &self.intermediate_column_indices[column.index()]; assert_eq!(intermediate_column.side, self.join_side); @@ -393,6 +397,7 @@ impl<'a> JoinFilterRewriter<'a> { let new_column = Column::new(field.name(), join_side_index); Transformed::yes(Arc::new(new_column) as Arc) } + _ => Transformed::no(expr), }) })?; self.join_side_projections.push((rewritten_expr.data, name)); @@ -415,15 +420,17 @@ impl<'a> JoinFilterRewriter<'a> { join_side: JoinSide, ) -> Result { let mut result = false; - expr.apply(|expr| match expr.as_any().downcast_ref::() { - None => Ok(TreeNodeRecursion::Continue), - Some(c) => { - let column_index = &self.intermediate_column_indices[c.index()]; - if column_index.side == join_side { - result = true; - return Ok(TreeNodeRecursion::Stop); + expr.apply_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(c) if !lambdas_params.contains(c.name()) => { + let column_index = &self.intermediate_column_indices[c.index()]; + if column_index.side == join_side { + result = true; + return Ok(TreeNodeRecursion::Stop); + } + Ok(TreeNodeRecursion::Continue) } - 
Ok(TreeNodeRecursion::Continue) + _ => Ok(TreeNodeRecursion::Continue), } })?; diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index 54a76e0ebb971..be72a6af2b509 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -22,13 +22,13 @@ use crate::{ }; use arrow::array::RecordBatch; use arrow_schema::{Fields, Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNodeRecursion}; use datafusion_common::{internal_err, Result}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::ScalarFunctionExpr; +use datafusion_physical_expr::{PhysicalExprExt, ScalarFunctionExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use futures::stream::StreamExt; use log::trace; @@ -249,7 +249,7 @@ impl AsyncMapper { schema: &Schema, ) -> Result<()> { // recursively look for references to async functions - physical_expr.apply(|expr| { + physical_expr.apply_with_schema(schema, |expr, schema| { if let Some(scalar_func_expr) = expr.as_any().downcast_ref::() { diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 80221a77992ce..b70a8f60508a5 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -35,7 +35,7 @@ use arrow::array::{ }; use arrow::compute::concat_batches; use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult}; use 
datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ arrow_datafusion_err, DataFusionError, HashSet, JoinSide, Result, ScalarValue, @@ -44,7 +44,7 @@ use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalExprExt, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; use hashbrown::HashTable; @@ -312,13 +312,13 @@ pub fn convert_sort_expr_with_filter_schema( // Since we are sure that one to one column mapping includes all columns, we convert // the sort expression into a filter expression. let converted_filter_expr = expr - .transform_up(|p| { - convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { - match transformed { + .transform_up_with_lambdas_params(|p, lambdas_params| { + convert_filter_columns(p.as_ref(), &column_map, lambdas_params).map( + |transformed| match transformed { Some(transformed) => Transformed::yes(transformed), None => Transformed::no(p), - } - }) + }, + ) }) .data()?; // Search the converted `PhysicalExpr` in filter expression; if an exact @@ -361,14 +361,17 @@ pub fn build_filter_input_order( fn convert_filter_columns( input: &dyn PhysicalExpr, column_map: &HashMap, + lambdas_params: &HashSet, ) -> Result>> { // Attempt to downcast the input expression to a Column type. - Ok(if let Some(col) = input.as_any().downcast_ref::() { - // If the downcast is successful, retrieve the corresponding filter column. - column_map.get(col).map(|c| Arc::new(c.clone()) as _) - } else { - // If the downcast fails, return the input expression as is. 
- None + Ok(match input.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { + column_map.get(col).map(|c| Arc::new(c.clone()) as _) + } + _ => { + // If the downcast fails, return the input expression as is. + None + } }) } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index ead2196860cde..ab654e4eee1df 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -42,14 +42,13 @@ use std::task::{Context, Poll}; use arrow::datatypes::SchemaRef; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, -}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNodeRecursion}; use datafusion_common::{internal_err, JoinSide, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; +use datafusion_physical_expr::{PhysicalExprExt, PhysicalExprRef}; +use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; // Re-exported from datafusion-physical-expr for backwards compatibility // We recommend updating your imports to use datafusion-physical-expr directly @@ -866,10 +865,12 @@ fn try_unifying_projections( projection.expr().iter().for_each(|proj_expr| { proj_expr .expr - .apply(|expr| { + .apply_with_lambdas_params(|expr, lambdas_params| { Ok({ if let Some(column) = expr.as_any().downcast_ref::() { - *column_ref_map.entry(column.clone()).or_default() += 1; + if !lambdas_params.contains(column.name()) { + *column_ref_map.entry(column.clone()).or_default() += 1; + } } TreeNodeRecursion::Continue }) @@ -957,31 
+958,31 @@ fn new_columns_for_join_on( .filter_map(|on| { // Rewrite all columns in `on` Arc::clone(*on) - .transform(|expr| { - if let Some(column) = expr.as_any().downcast_ref::() { - // Find the column in the projection expressions - let new_column = projection_exprs - .iter() - .enumerate() - .find(|(_, (proj_column, _))| { - column.name() == proj_column.name() - && column.index() + column_index_offset - == proj_column.index() - }) - .map(|(index, (_, alias))| Column::new(alias, index)); - if let Some(new_column) = new_column { - Ok(Transformed::yes(Arc::new(new_column))) - } else { - // If the column is not found in the projection expressions, - // it means that the column is not projected. In this case, - // we cannot push the projection down. - internal_err!( - "Column {:?} not found in projection expressions", - column - ) + .transform_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(column) if !lambdas_params.contains(column.name()) => { + let new_column = projection_exprs + .iter() + .enumerate() + .find(|(_, (proj_column, _))| { + column.name() == proj_column.name() + && column.index() + column_index_offset + == proj_column.index() + }) + .map(|(index, (_, alias))| Column::new(alias, index)); + if let Some(new_column) = new_column { + Ok(Transformed::yes(Arc::new(new_column))) + } else { + // If the column is not found in the projection expressions, + // it means that the column is not projected. In this case, + // we cannot push the projection down. 
+ internal_err!( + "Column {:?} not found in projection expressions", + column + ) + } } - } else { - Ok(Transformed::no(expr)) + _ => Ok(Transformed::no(expr)), } }) .data() diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2774b5b6ba7c3..b87a50b3f5281 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -622,6 +622,11 @@ pub fn serialize_expr( .unwrap_or(HashMap::new()), })), }, + Expr::Lambda { .. } => { + return Err(Error::General( + "Proto serialization error: Lambda not supported".to_string(), + )) + } }; Ok(expr_node) diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index 380ada10df6e1..c9df93f8b693c 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -38,13 +38,14 @@ use datafusion_common::error::Result; use datafusion_common::tree_node::TransformedResult; use datafusion_common::{ internal_datafusion_err, internal_err, plan_datafusion_err, plan_err, - tree_node::{Transformed, TreeNode}, - ScalarValue, + tree_node::Transformed, ScalarValue, }; use datafusion_common::{Column, DFSchema}; use datafusion_expr_common::operator::Operator; use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; -use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; +use datafusion_physical_expr::{ + expressions as phys_expr, PhysicalExprExt, PhysicalExprRef, +}; use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion_physical_plan::{ColumnarValue, PhysicalExpr}; @@ -1204,9 +1205,9 @@ fn rewrite_column_expr( column_old: &phys_expr::Column, column_new: &phys_expr::Column, ) -> Result> { - e.transform(|expr| { + e.transform_with_lambdas_params(|expr, lambdas_params| { if let Some(column) = expr.as_any().downcast_ref::() { - if column == column_old { + if 
!lambdas_params.contains(column.name()) && column == column_old { return Ok(Transformed::yes(Arc::new(column_new.clone()))); } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 50e479af36204..c13fd33104eb0 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -22,10 +22,10 @@ use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Diagnostic, Result, Span, }; -use datafusion_expr::expr::{ - NullTreatment, ScalarFunction, Unnest, WildcardOptions, WindowFunction, -}; -use datafusion_expr::planner::{PlannerResult, RawAggregateExpr, RawWindowExpr}; +use datafusion_expr::expr::{Lambda, ScalarFunction, Unnest}; +use datafusion_expr::expr::{NullTreatment, WildcardOptions, WindowFunction}; +use datafusion_expr::planner::PlannerResult; +use datafusion_expr::planner::{RawAggregateExpr, RawWindowExpr}; use datafusion_expr::{expr, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition}; use sqlparser::ast::{ DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, @@ -724,6 +724,26 @@ impl SqlToRel<'_, S> { let arg_name = crate::utils::normalize_ident(name); Ok((expr, Some(arg_name))) } + FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Lambda( + sqlparser::ast::LambdaFunction { params, body }, + ))) => { + let params = params + .into_iter() + .map(|v| v.to_string()) + .collect::>(); + + Ok(( + Expr::Lambda(Lambda { + params: params.clone(), + body: Box::new(self.sql_expr_to_logical_expr( + *body, + schema, + &mut planner_context.clone().with_lambda_parameters(params), + )?), + }), + None, + )) + } FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { let expr = self.sql_expr_to_logical_expr(arg, schema, planner_context)?; Ok((expr, None)) diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 3c57d195ade67..dc39cb4de055d 100644 --- 
a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -53,6 +53,19 @@ impl SqlToRel<'_, S> { // identifier. (e.g. it is "foo.bar" not foo.bar) let normalize_ident = self.ident_normalizer.normalize(id); + if planner_context + .lambdas_parameters() + .contains(&normalize_ident) + { + let mut column = Column::new_unqualified(normalize_ident); + if self.options.collect_spans { + if let Some(span) = Span::try_from_sqlparser_span(id_span) { + column.spans_mut().add_span(span); + } + } + return Ok(Expr::Column(column)); + } + // Check for qualified field with unqualified name if let Ok((qualifier, _)) = schema.qualified_field_with_unqualified_name(normalize_ident.as_str()) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 7bac0337672dc..2992378fd1d6c 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -28,13 +28,11 @@ use datafusion_common::datatype::{DataTypeExt, FieldExt}; use datafusion_common::error::add_possible_columns_to_diag; use datafusion_common::TableReference; use datafusion_common::{ - field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, Diagnostic, - SchemaError, + field_not_found, plan_datafusion_err, DFSchemaRef, Diagnostic, HashSet, SchemaError, }; use datafusion_common::{not_impl_err, plan_err, DFSchema, DataFusionError, Result}; use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder}; pub use datafusion_expr::planner::ContextProvider; -use datafusion_expr::utils::find_column_exprs; use datafusion_expr::{col, Expr}; use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo, TimezoneInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; @@ -267,6 +265,8 @@ pub struct PlannerContext { outer_from_schema: Option, /// The query schema defined by the table create_table_schema: Option, + /// The lambda introduced columns names + lambdas_parameters: HashSet, } impl Default for PlannerContext { @@ -284,6 +284,7 @@ impl PlannerContext 
{ outer_query_schema: None, outer_from_schema: None, create_table_schema: None, + lambdas_parameters: HashSet::new(), } } @@ -370,6 +371,19 @@ impl PlannerContext { self.ctes.get(cte_name).map(|cte| cte.as_ref()) } + pub fn lambdas_parameters(&self) -> &HashSet { + &self.lambdas_parameters + } + + pub fn with_lambda_parameters( + mut self, + arguments: impl IntoIterator, + ) -> Self { + self.lambdas_parameters.extend(arguments); + + self + } + /// Remove the plan of CTE / Subquery for the specified name pub(super) fn remove_cte(&mut self, cte_name: &str) { self.ctes.remove(cte_name); @@ -531,10 +545,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, exprs: &[Expr], ) -> Result<()> { - find_column_exprs(exprs) + exprs .iter() - .try_for_each(|col| match col { - Expr::Column(col) => match &col.relation { + .flat_map(|expr| expr.column_refs()) + .try_for_each(|col| { + match &col.relation { Some(r) => schema.field_with_qualified_name(r, &col.name).map(|_| ()), None => { if !schema.fields_with_unqualified_name(&col.name).is_empty() { @@ -584,8 +599,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { err.with_diagnostic(diagnostic) } _ => err, - }), - _ => internal_err!("Not a column"), + }) }) } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 42013a76a8657..0e7490d2c780b 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -540,9 +540,11 @@ impl SqlToRel<'_, S> { None => { let mut columns = HashSet::new(); for expr in &aggr_expr { - expr.apply(|expr| { + expr.apply_with_lambdas_params(|expr, lambdas_params| { if let Expr::Column(c) = expr { - columns.insert(Expr::Column(c.clone())); + if !c.is_lambda_parameter(lambdas_params) { + columns.insert(Expr::Column(c.clone())); + } } Ok(TreeNodeRecursion::Continue) }) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 97f2b58bf8402..67ca92bb1c1f1 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ 
b/datafusion/sql/src/unparser/expr.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. -use datafusion_expr::expr::{AggregateFunctionParams, Unnest, WindowFunctionParams}; +use datafusion_expr::expr::{AggregateFunctionParams, WindowFunctionParams}; +use datafusion_expr::expr::{Lambda, Unnest}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Array, BinaryOperator, CaseWhen, DuplicateTreatment, Expr as AstExpr, Function, - Ident, Interval, ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, - ValueWithSpan, + self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, + LambdaFunction, ObjectName, Subscript, TimezoneInfo, UnaryOperator, }; +use sqlparser::ast::{CaseWhen, DuplicateTreatment, OrderByOptions, ValueWithSpan}; use std::sync::Arc; use std::vec; @@ -527,6 +528,14 @@ impl Unparser<'_> { } Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), Expr::Unnest(unnest) => self.unnest_to_sql(unnest), + Expr::Lambda(Lambda { params, body }) => { + Ok(ast::Expr::Lambda(LambdaFunction { + params: ast::OneOrManyWithParens::Many( + params.iter().map(|param| param.as_str().into()).collect(), + ), + body: Box::new(self.expr_to_sql_inner(body)?), + })) + } } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index e7535338b7677..c218ce547b312 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -40,7 +40,7 @@ use crate::unparser::{ast::UnnestRelationBuilder, rewrite::rewrite_qualify}; use crate::utils::UNNEST_PLACEHOLDER; use datafusion_common::{ internal_err, not_impl_err, - tree_node::{TransformedResult, TreeNode}, + tree_node::TransformedResult, Column, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::expr::OUTER_REFERENCE_COLUMN_PREFIX; @@ -1131,7 +1131,7 @@ impl Unparser<'_> { .cloned() .map(|expr| { if let Some(ref mut rewriter) = 
filter_alias_rewriter { - expr.rewrite(rewriter).data() + expr.rewrite_with_lambdas_params(rewriter).data() } else { Ok(expr) } @@ -1197,7 +1197,7 @@ impl Unparser<'_> { .cloned() .map(|expr| { if let Some(ref mut rewriter) = alias_rewriter { - expr.rewrite(rewriter).data() + expr.rewrite_with_lambdas_params(rewriter).data() } else { Ok(expr) } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index c961f1d6f1f0c..58f4435095517 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -20,10 +20,11 @@ use std::{collections::HashSet, sync::Arc}; use arrow::datatypes::Schema; use datafusion_common::tree_node::TreeNodeContainer; use datafusion_common::{ - tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, + tree_node::{Transformed, TransformedResult, TreeNode}, Column, HashMap, Result, TableReference, }; use datafusion_expr::expr::{Alias, UNNEST_COLUMN_PREFIX}; +use datafusion_expr::tree_node::TreeNodeRewriterWithPayload; use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr}; use sqlparser::ast::Ident; @@ -466,12 +467,17 @@ pub struct TableAliasRewriter<'a> { pub alias_name: TableReference, } -impl TreeNodeRewriter for TableAliasRewriter<'_> { +impl TreeNodeRewriterWithPayload for TableAliasRewriter<'_> { type Node = Expr; + type Payload<'a> = &'a datafusion_common::HashSet; - fn f_down(&mut self, expr: Expr) -> Result> { + fn f_down( + &mut self, + expr: Expr, + lambdas_params: &datafusion_common::HashSet, + ) -> Result> { match expr { - Expr::Column(column) => { + Expr::Column(column) if !column.is_lambda_parameter(lambdas_params) => { if let Ok(field) = self.table_schema.field_with_name(&column.name) { let new_column = Column::new(Some(self.alias_name.clone()), field.name().clone()); diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 8b3791017a8af..f785f640dbcee 100644 --- 
a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -161,11 +161,11 @@ pub(crate) fn find_window_nodes_within_select<'a>( /// For example, if expr contains the column expr "__unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)" /// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL]) pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result { - expr.transform(|sub_expr| { + expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| { if let Expr::Column(col_ref) = &sub_expr { // Check if the column is among the columns to run unnest on. // Currently, only List/Array columns (defined in `list_type_columns`) are supported for unnesting. - if unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) { + if !col_ref.is_lambda_parameter(lambdas_params) && unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) { if let Ok(idx) = unnest.schema.index_of_column(col_ref) { if let LogicalPlan::Projection(Projection { expr, .. }) = unnest.input.as_ref() { if let Some(unprojected_expr) = expr.get(idx) { @@ -195,22 +195,21 @@ pub(crate) fn unproject_agg_exprs( agg: &Aggregate, windows: Option<&[&Window]>, ) -> Result { - expr.transform(|sub_expr| { - if let Expr::Column(c) = sub_expr { - if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { - Ok(Transformed::yes(unprojected_expr.clone())) - } else if let Some(unprojected_expr) = - windows.and_then(|w| find_window_expr(w, &c.name).cloned()) - { - // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' 
that needs to be unprojected - Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?)) - } else { - internal_err!( - "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name - ) - } - } else { - Ok(Transformed::no(sub_expr)) + expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| { + match sub_expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { + Ok(Transformed::yes(unprojected_expr.clone())) + } else if let Some(unprojected_expr) = + windows.and_then(|w| find_window_expr(w, &c.name).cloned()) + { + // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected + Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?)) + } else { + internal_err!( + "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name + ) + }, + _ => Ok(Transformed::no(sub_expr)), } }) .map(|e| e.data) @@ -222,16 +221,15 @@ pub(crate) fn unproject_agg_exprs( /// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed /// into an actual window expression as identified in the window node. 
pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result { - expr.transform(|sub_expr| { - if let Expr::Column(c) = sub_expr { + expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| match sub_expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { if let Some(unproj) = find_window_expr(windows, &c.name) { Ok(Transformed::yes(unproj.clone())) } else { Ok(Transformed::no(Expr::Column(c))) } - } else { - Ok(Transformed::no(sub_expr)) } + _ => Ok(Transformed::no(sub_expr)), }) .map(|e| e.data) } @@ -376,7 +374,7 @@ pub(crate) fn try_transform_to_simple_table_scan_with_filters( .cloned() .map(|expr| { if let Some(ref mut rewriter) = filter_alias_rewriter { - expr.rewrite(rewriter).data() + expr.rewrite_with_lambdas_params(rewriter).data() } else { Ok(expr) } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 3c86d2d04905f..6380412e3b5ee 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -23,16 +23,16 @@ use arrow::datatypes::{ DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{ - exec_datafusion_err, exec_err, internal_err, plan_err, Column, DFSchemaRef, - Diagnostic, HashMap, Result, ScalarValue, + exec_datafusion_err, exec_err, internal_err, plan_err, Column, DFSchema, Diagnostic, HashMap, Result, ScalarValue }; use datafusion_expr::builder::get_struct_unnested_columns; use datafusion_expr::expr::{ Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams, }; +use datafusion_expr::tree_node::TreeNodeRewriterWithPayload; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; use datafusion_expr::{ col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan, @@ -44,9 +44,9 @@ use 
sqlparser::ast::{Ident, Value}; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { expr.clone() - .transform_up(|nested_expr| { + .transform_up_with_lambdas_params(|nested_expr, lambdas_params| { match nested_expr { - Expr::Column(col) => { + Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { let (qualifier, field) = plan.schema().qualified_field_from_column(&col)?; Ok(Transformed::yes(Expr::Column(Column::from(( @@ -81,6 +81,7 @@ pub(crate) fn rebase_expr( base_exprs: &[Expr], plan: &LogicalPlan, ) -> Result { + //todo user transform_down_with_lambdas_params expr.clone() .transform_down(|nested_expr| { if base_exprs.contains(&nested_expr) { @@ -231,8 +232,8 @@ pub(crate) fn resolve_aliases_to_exprs( expr: Expr, aliases: &HashMap, ) -> Result { - expr.transform_up(|nested_expr| match nested_expr { - Expr::Column(c) if c.relation.is_none() => { + expr.transform_up_with_lambdas_params(|nested_expr, lambdas_params| match nested_expr { + Expr::Column(c) if c.relation.is_none() && !c.is_lambda_parameter(lambdas_params) => { if let Some(aliased_expr) = aliases.get(&c.name) { Ok(Transformed::yes(aliased_expr.clone())) } else { @@ -371,7 +372,6 @@ This is only usedful when used with transform down up A full example of how the transformation works: */ struct RecursiveUnnestRewriter<'a> { - input_schema: &'a DFSchemaRef, root_expr: &'a Expr, // Useful to detect which child expr is a part of/ not a part of unnest operation top_most_unnest: Option, @@ -405,6 +405,7 @@ impl RecursiveUnnestRewriter<'_> { alias_name: String, expr_in_unnest: &Expr, struct_allowed: bool, + input_schema: &DFSchema, ) -> Result> { let inner_expr_name = expr_in_unnest.schema_name().to_string(); @@ -418,7 +419,7 @@ impl RecursiveUnnestRewriter<'_> { // column name as is, to comply with group by and order by let placeholder_column = Column::from_name(placeholder_name.clone()); - let 
(data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?; + let (data_type, _) = expr_in_unnest.data_type_and_nullable(input_schema)?; match data_type { DataType::Struct(inner_fields) => { @@ -468,17 +469,18 @@ impl RecursiveUnnestRewriter<'_> { } } -impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> { +impl TreeNodeRewriterWithPayload for RecursiveUnnestRewriter<'_> { type Node = Expr; + type Payload<'a> = &'a DFSchema; /// This downward traversal needs to keep track of: /// - Whether or not some unnest expr has been visited from the top util the current node /// - If some unnest expr has been visited, maintain a stack of such information, this /// is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))** - fn f_down(&mut self, expr: Expr) -> Result> { + fn f_down(&mut self, expr: Expr, input_schema: &DFSchema) -> Result> { if let Expr::Unnest(ref unnest_expr) = expr { let (data_type, _) = - unnest_expr.expr.data_type_and_nullable(self.input_schema)?; + unnest_expr.expr.data_type_and_nullable(input_schema)?; self.consecutive_unnest.push(Some(unnest_expr.clone())); // if expr inside unnest is a struct, do not consider // the next unnest as consecutive unnest (if any) @@ -532,7 +534,7 @@ impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> { /// column2 /// ``` /// - fn f_up(&mut self, expr: Expr) -> Result> { + fn f_up(&mut self, expr: Expr, input_schema: &DFSchema) -> Result> { if let Expr::Unnest(ref traversing_unnest) = expr { if traversing_unnest == self.top_most_unnest.as_ref().unwrap() { self.top_most_unnest = None; @@ -568,6 +570,7 @@ impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> { expr.schema_name().to_string(), inner_expr, struct_allowed, + input_schema, )?; if struct_allowed { self.transformed_root_exprs = Some(transformed_exprs.clone()); @@ -619,7 +622,6 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up( original_expr: &Expr, ) -> Result> { let mut rewriter = RecursiveUnnestRewriter 
{ - input_schema: input.schema(), root_expr: original_expr, top_most_unnest: None, consecutive_unnest: vec![], @@ -641,7 +643,7 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up( data: transformed_expr, transformed, tnr: _, - } = original_expr.clone().rewrite(&mut rewriter)?; + } = original_expr.clone().rewrite_with_schema(input.schema(), &mut rewriter)?; if !transformed { // TODO: remove the next line after `Expr::Wildcard` is removed diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 00629c392df48..29ea8cb786072 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5866,10 +5866,10 @@ select array_ndims(arrow_cast([null], 'List(List(List(Int64)))')); 3 # array_ndims scalar function #2 -query II -select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); ----- -3 21 +#query II +#select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); +#---- +#3 21 # array_ndims scalar function #3 query II diff --git a/datafusion/sqllogictest/test_files/lambda.slt b/datafusion/sqllogictest/test_files/lambda.slt new file mode 100644 index 0000000000000..0043eae17a60c --- /dev/null +++ b/datafusion/sqllogictest/test_files/lambda.slt @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +############# +## Array Expressions Tests +############# + +statement ok +set datafusion.sql_parser.dialect = databricks; + +statement ok +CREATE TABLE tt +AS VALUES +([1, 50], 10), +([4, 50], 40); + +statement ok +CREATE TABLE t AS SELECT 1 as f, [ [ [2, 3], [2] ], [ [1] ], [ [] ] ] as v, 1 as n; + +query I? +SELECT t.n, array_transform([], e1 -> t.n) from t; +---- +1 [] + +query ? +SELECT array_transform([1], e1 -> (select n from t)); +---- +[1] + +query ? +SELECT array_transform(v, (v1, i) -> array_transform(v1, (v2, j) -> array_transform(v2, v3 -> j)) ) from t; +---- +[[[0, 0], [1]], [[0]], [[]]] + +query I? +SELECT t.n, array_transform([1, 2], (e) -> n) from t; +---- +1 [1, 1] + +# selection pushdown not working yet +query ? +SELECT array_transform([1, 2], (e) -> n) from t; +---- +[1, 1] + +query ? +SELECT array_transform([1, 2], (e, i) -> i) from t; +---- +[0, 1] + +# type coercion +query ? 
+SELECT array_transform([1, 2], (e, i) -> e+i) from t; +---- +[1, 3] + +query TT +EXPLAIN SELECT array_transform([1, 2], (e, i) -> e+i); +---- +logical_plan +01)Projection: array_transform(List([1, 2]), (e, i) -> e + CAST(i AS Int64)) AS array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i) +02)--EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[array_transform([1, 2], (e, i) -> e@0 + CAST(i@1 AS Int64)) as array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i)] +02)--PlaceholderRowExec + +#cse +query TT +explain select n + 1, array_transform([1], v -> v + n + 1) from t; +---- +logical_plan +01)Projection: t.n + Int64(1), array_transform(List([1]), (v) -> v + t.n + Int64(1)) AS array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1)) +02)--TableScan: t projection=[n] +physical_plan +01)ProjectionExec: expr=[n@0 + 1 as t.n + Int64(1), array_transform([1], (v) -> v@1 + n@0 + 1) as array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query ? +SELECT array_transform([1,2,3,4,5], v -> v*2); +---- +[2, 4, 6, 8, 10] + +query ? +SELECT array_transform([1,2,3,4,5], v -> 2); +---- +[2, 2, 2, 2, 2] + +query ? +SELECT array_transform([[1,2],[3,4,5]], v -> array_transform(v, v -> v*2)); +---- +[[2, 4], [6, 8, 10]] + +query ? +SELECT array_transform([1,2,3,4,5], v -> repeat("a", v)); +---- +[a, aa, aaa, aaaa, aaaaa] + +query ? 
+SELECT array_transform([1,2,3,4,5], v -> list_repeat("a", v)); +---- +[[a], [a, a], [a, a, a], [a, a, a, a], [a, a, a, a, a]] + +query TT +EXPLAIN SELECT array_transform([1,2,3,4,5], v -> v*2); +---- +logical_plan +01)Projection: array_transform(List([1, 2, 3, 4, 5]), (v) -> v * Int64(2)) AS array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2)) +02)--EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[array_transform([1, 2, 3, 4, 5], (v) -> v@0 * 2) as array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2))] +02)--PlaceholderRowExec + + +query I?? +SELECT t.n, t.v, array_transform(t.v, (v, i) -> array_transform(v, (v, j) -> n) ) from t; +---- +1 [[[2, 3], [2]], [[1]], [[]]] [[1, 1], [1], [1]] + +query ? +SELECT array_transform([1,2,3,4,5], v -> v*2); +---- +[2, 4, 6, 8, 10] + + +# expr simplifier +query TT +EXPLAIN SELECT v = v, array_transform([1], v -> v = v) from t; +---- +logical_plan +01)Projection: Boolean(true) AS t.v = t.v, array_transform(List([1]), (v) -> v IS NOT NULL OR Boolean(NULL)) AS array_transform(make_array(Int64(1)),(v) -> v = v) +02)--TableScan: t projection=[] +physical_plan +01)ProjectionExec: expr=[true as t.v = t.v, array_transform([1], (v) -> v@0 IS NOT NULL OR NULL) as array_transform(make_array(Int64(1)),(v) -> v = v)] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + + +query error +select array_transform(); +---- +DataFusion error: Error during planning: 'array_transform' does not support zero arguments No function matches the given name and argument types 'array_transform()'. You might need to add explicit type casts. 
+ Candidate functions: + array_transform(Any, Any) + + +query error DataFusion error: Execution error: expected list, got Field \{ name: "Int64\(1\)", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: \{\} \} +select array_transform(1, v -> v*2); + +query error DataFusion error: Execution error: array_transform expects a value follewed by a lambda, got \[Lambda\(\["v"\], false\), Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)\] +select array_transform(v -> v*2, [1, 2]); + +query error DataFusion error: Execution error: lambdas_schemas: array_transform argument 1 \(0\-indexed\), a lambda, supports up to 2 arguments, but got 3 +SELECT array_transform([1, 2], (e, i, j) -> i) from t; diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index f4e43fd586773..103d593cafbc0 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -152,6 +152,7 @@ pub fn to_substrait_rex( not_impl_err!("Cannot convert {expr:?} to Substrait") } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 } } From fa4a8fbebe21207225077f41183c2e9016a24fbd Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sat, 22 Nov 2025 15:34:04 -0300 Subject: [PATCH 02/47] add lambdas: None to existing ScalarFunctionArgs in tests/benches --- datafusion/functions-nested/benches/map.rs | 1 + datafusion/functions-nested/src/array_has.rs | 2 ++ datafusion/functions-nested/src/lib.rs | 3 +++ 
datafusion/functions-nested/src/map_values.rs | 1 + datafusion/functions-nested/src/set_ops.rs | 1 + datafusion/functions/benches/ascii.rs | 4 ++++ .../functions/benches/character_length.rs | 4 ++++ datafusion/functions/benches/chr.rs | 1 + datafusion/functions/benches/concat.rs | 1 + datafusion/functions/benches/cot.rs | 2 ++ datafusion/functions/benches/date_bin.rs | 1 + datafusion/functions/benches/date_trunc.rs | 1 + datafusion/functions/benches/encoding.rs | 4 ++++ datafusion/functions/benches/find_in_set.rs | 4 ++++ datafusion/functions/benches/gcd.rs | 3 +++ datafusion/functions/benches/initcap.rs | 3 +++ datafusion/functions/benches/isnan.rs | 2 ++ datafusion/functions/benches/iszero.rs | 2 ++ datafusion/functions/benches/lower.rs | 6 ++++++ datafusion/functions/benches/ltrim.rs | 1 + datafusion/functions/benches/make_date.rs | 4 ++++ datafusion/functions/benches/nullif.rs | 1 + datafusion/functions/benches/pad.rs | 1 + datafusion/functions/benches/random.rs | 2 ++ datafusion/functions/benches/repeat.rs | 1 + datafusion/functions/benches/reverse.rs | 4 ++++ datafusion/functions/benches/signum.rs | 2 ++ datafusion/functions/benches/strpos.rs | 4 ++++ datafusion/functions/benches/substr.rs | 1 + datafusion/functions/benches/substr_index.rs | 1 + datafusion/functions/benches/to_char.rs | 6 ++++++ datafusion/functions/benches/to_hex.rs | 2 ++ datafusion/functions/benches/to_timestamp.rs | 6 ++++++ datafusion/functions/benches/trunc.rs | 2 ++ datafusion/functions/benches/upper.rs | 1 + datafusion/functions/benches/uuid.rs | 1 + datafusion/functions/src/core/union_extract.rs | 4 ++++ datafusion/functions/src/core/union_tag.rs | 8 ++++++-- datafusion/functions/src/core/version.rs | 1 + datafusion/functions/src/datetime/date_bin.rs | 1 + .../functions/src/datetime/date_trunc.rs | 2 ++ .../functions/src/datetime/from_unixtime.rs | 2 ++ datafusion/functions/src/datetime/make_date.rs | 1 + datafusion/functions/src/datetime/now.rs | 2 ++ 
datafusion/functions/src/datetime/to_char.rs | 7 +++++++ datafusion/functions/src/datetime/to_date.rs | 1 + .../functions/src/datetime/to_local_time.rs | 2 ++ .../functions/src/datetime/to_timestamp.rs | 2 ++ datafusion/functions/src/math/log.rs | 18 ++++++++++++++++++ datafusion/functions/src/math/power.rs | 2 ++ datafusion/functions/src/math/signum.rs | 2 ++ datafusion/functions/src/regex/regexpcount.rs | 1 + datafusion/functions/src/regex/regexpinstr.rs | 1 + datafusion/functions/src/string/concat.rs | 1 + datafusion/functions/src/string/concat_ws.rs | 2 ++ datafusion/functions/src/string/contains.rs | 1 + datafusion/functions/src/string/lower.rs | 1 + datafusion/functions/src/string/upper.rs | 1 + .../functions/src/unicode/find_in_set.rs | 1 + datafusion/functions/src/unicode/strpos.rs | 1 + datafusion/functions/src/utils.rs | 3 +++ datafusion/spark/benches/char.rs | 1 + .../spark/src/function/bitmap/bitmap_count.rs | 1 + .../src/function/datetime/make_dt_interval.rs | 1 + .../src/function/datetime/make_interval.rs | 1 + datafusion/spark/src/function/string/concat.rs | 2 ++ datafusion/spark/src/function/utils.rs | 5 ++++- 67 files changed, 162 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 3197cc55cc957..3075d2e573e4a 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -117,6 +117,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("map should work on valid values"), ); diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 080b2f16d92f3..d6a333c0a0ef3 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -819,6 +819,7 @@ mod tests { number_rows: 1, return_field, config_options: 
Arc::new(ConfigOptions::default()), + lambdas: None, })?; let output = result.into_array(1)?; @@ -847,6 +848,7 @@ mod tests { number_rows: 1, return_field, config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })?; let output = result.into_array(1)?; diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 3a66e65694768..55acf24ba4657 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -37,6 +37,7 @@ pub mod macros; pub mod array_has; +pub mod array_transform; pub mod cardinality; pub mod concat; pub mod dimension; @@ -78,6 +79,7 @@ pub mod expr_fn { pub use super::array_has::array_has; pub use super::array_has::array_has_all; pub use super::array_has::array_has_any; + pub use super::array_transform::array_transform; pub use super::cardinality::cardinality; pub use super::concat::array_append; pub use super::concat::array_concat; @@ -145,6 +147,7 @@ pub fn all_default_nested_functions() -> Vec> { array_has::array_has_udf(), array_has::array_has_all_udf(), array_has::array_has_any_udf(), + array_transform::array_transform_udf(), empty::array_empty_udf(), length::array_length_udf(), distance::array_distance_udf(), diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index 6ae8a278063da..ac21ff8acd3f9 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -204,6 +204,7 @@ mod tests { let args = datafusion_expr::ReturnFieldArgs { arg_fields: &[field], scalar_arguments: &[None::<&ScalarValue>], + lambdas: &[false], }; func.return_field_from_args(args).unwrap() diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index 53642bf1622b0..f26fc173d8a9f 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -596,6 +596,7 @@ mod tests { number_rows: 1, 
return_field: input_field.clone().into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })?; assert_eq!( diff --git a/datafusion/functions/benches/ascii.rs b/datafusion/functions/benches/ascii.rs index 03d25e9c3d4fe..97e6ab20ed458 100644 --- a/datafusion/functions/benches/ascii.rs +++ b/datafusion/functions/benches/ascii.rs @@ -60,6 +60,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -81,6 +82,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -108,6 +110,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -129,6 +132,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index 4a1a63d62765f..f98e8a8b1a68b 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -55,6 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -79,6 +80,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -103,6 +105,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ 
-127,6 +130,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 8356cf7c31726..d51cda4566d64 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -69,6 +69,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 09200139a244b..6378328537827 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -60,6 +60,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index 97f21ccd6d55e..56f50522acc5d 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -54,6 +54,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -80,6 +81,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 74390491d538c..1c3713723738a 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -66,6 +66,7 @@ fn criterion_benchmark(c: &mut Criterion) 
{ number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index 498a3e63ef290..b757535fb03c5 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -71,6 +71,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_trunc should work on valid values"), ) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index 98faee91e1911..72b033cf5d9ed 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -45,6 +45,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(); @@ -63,6 +64,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -82,6 +84,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(); @@ -101,6 +104,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index a928f5655806c..6fe498a58d84b 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ 
b/datafusion/functions/benches/find_in_set.rs @@ -168,6 +168,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })) }) }); @@ -186,6 +187,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })) }) }); @@ -208,6 +210,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })) }) }); @@ -228,6 +231,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index 19e196d9a3eab..2bfec91e290dd 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -58,6 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_bin should work on valid values"), ) @@ -79,6 +80,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_bin should work on valid values"), ) @@ -100,6 +102,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index 50aee8dbb9161..37d98596deb82 100644 --- 
a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -70,6 +70,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -86,6 +87,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -100,6 +102,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index 4a90d45d66223..dcce59e46ce41 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -53,6 +53,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Boolean, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -77,6 +78,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Boolean, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index 961cba7200ce0..574539fbb6427 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -55,6 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -82,6 +83,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: 
Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 6a5178b87fdce..e741afd0d8e01 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -145,6 +145,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); @@ -167,6 +168,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); @@ -191,6 +193,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -225,6 +228,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }), ); @@ -240,6 +244,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }), ); @@ -256,6 +261,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }), ); diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 4458af614396d..9b344cc6b143a 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -153,6 +153,7 @@ fn run_with_string_type( number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff 
--git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 15a895468db93..2a681ddedcbe8 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -81,6 +81,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -111,6 +112,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -141,6 +143,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -168,6 +171,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("make_date should work on valid values"), ) diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index d649697cc5188..15914cd7ee6c5 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -54,6 +54,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index f92a69bbf4f92..c7d46da3d26c6 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -116,6 +116,7 @@ fn invoke_pad_with_args( number_rows, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: 
Arc::clone(&config_options), + lambdas: None, }; if left_pad { diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 88efb2d1b5b93..2935876685800 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -43,6 +43,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 8192, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ); @@ -64,6 +65,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 128, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ); diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 80ffa8ee38f1a..9a7c63ed4f304 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -76,6 +76,7 @@ fn invoke_repeat_with_args( number_rows: repeat_times as usize, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) } diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index b1eca654fb254..a8af40cd8cc19 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -58,6 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -80,6 +81,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -107,6 +109,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, 
@@ -131,6 +134,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 24b8861e4d28c..805b62c83da6d 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -55,6 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -83,6 +84,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index 18a99e44bf487..708ebb5518727 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -128,6 +128,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -146,6 +147,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); @@ -165,6 +167,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -185,6 +188,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index 771413458c1fb..58fda73defd25 
100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -116,6 +116,7 @@ fn invoke_substr_with_args( number_rows, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) } diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index d0941d9baedda..a77b961657c5f 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -110,6 +110,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("substr_index should work on valid values"), ) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 945508aec7405..61990b4cb8b95 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -149,6 +149,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -176,6 +177,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -203,6 +205,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -229,6 +232,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: 
Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -256,6 +260,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -288,6 +293,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index a75ed9258791e..baa2de80c466f 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -44,6 +44,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -62,6 +63,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index a8f5c5816d4da..e510a7c3fad41 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -130,6 +130,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -150,6 +151,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) 
.expect("to_timestamp should work on valid values"), ) @@ -170,6 +172,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -203,6 +206,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -244,6 +248,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -286,6 +291,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index 6e225e0e7038b..0b08791f9ae50 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -49,6 +49,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -68,6 +69,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index 7328b32574a4a..e9f0941032d8a 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -50,6 +50,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: 
Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index 1368e2f2af5d1..8ad79b2866eaf 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -37,6 +37,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1024, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index a71e2e87388d5..ac542866f7e43 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -209,6 +209,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })?; assert_scalar(result, ScalarValue::Utf8(None)); @@ -232,6 +233,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })?; assert_scalar(result, ScalarValue::Utf8(None)); @@ -248,12 +250,14 @@ mod tests { .iter() .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) .collect::>(); + let result = fun.invoke_with_args(ScalarFunctionArgs { args, arg_fields, number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })?; assert_scalar(result, ScalarValue::new_utf8("42")); diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs index aeadb8292ba1e..ecdebf66e0043 100644 --- a/datafusion/functions/src/core/union_tag.rs +++ b/datafusion/functions/src/core/union_tag.rs @@ -173,6 +173,7 @@ mod tests { fields, UnionMode::Dense, ); + let arg_fields = 
vec![Field::new("a", scalar.data_type(), false).into()]; let return_type = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); @@ -180,10 +181,11 @@ mod tests { let result = UnionTagFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(scalar)], + arg_fields, number_rows: 1, return_field: Field::new("res", return_type, true).into(), - arg_fields: vec![], config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }) .unwrap(); @@ -196,6 +198,7 @@ mod tests { #[test] fn union_scalar_empty() { let scalar = ScalarValue::Union(None, UnionFields::empty(), UnionMode::Dense); + let arg_fields = vec![Field::new("a", scalar.data_type(), false).into()]; let return_type = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); @@ -203,10 +206,11 @@ mod tests { let result = UnionTagFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(scalar)], + arg_fields, number_rows: 1, return_field: Field::new("res", return_type, true).into(), - arg_fields: vec![], config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }) .unwrap(); diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index ef3c5aafa4801..390111028c8f2 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -112,6 +112,7 @@ mod test { number_rows: 0, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }) .unwrap(); diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 92af123dbafac..5466129314640 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -530,6 +530,7 @@ mod tests { number_rows, return_field: Arc::clone(return_field), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; 
DateBinFunc::new().invoke_with_args(args) } diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 913e6217af82d..5736c221cae84 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -892,6 +892,7 @@ mod tests { ) .into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { @@ -1080,6 +1081,7 @@ mod tests { ) .into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 5d6adfb6f119a..be44be094e5b7 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -179,6 +179,7 @@ mod test { number_rows: 1, return_field: Field::new("f", DataType::Timestamp(Second, None), true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); @@ -212,6 +213,7 @@ mod test { ) .into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 0fe5d156a8383..afa4ef132147a 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -250,6 +250,7 @@ mod tests { number_rows, return_field: Field::new("f", DataType::Date32, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; MakeDateFunc::new().invoke_with_args(args) } diff --git 
a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index 4723548a45584..f18e72a107e28 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -163,6 +163,7 @@ mod tests { .return_field_from_args(ReturnFieldArgs { arg_fields: &empty_fields, scalar_arguments: &empty_scalars, + lambdas: &[], }) .expect("legacy now() return field"); @@ -170,6 +171,7 @@ mod tests { .return_field_from_args(ReturnFieldArgs { arg_fields: &empty_fields, scalar_arguments: &empty_scalars, + lambdas: &[], }) .expect("configured now() return field"); diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 7d9b2bc241e1a..5d69ce233f643 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -375,6 +375,7 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&Arc::new(ConfigOptions::default())), + lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -480,6 +481,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -574,6 +576,7 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -738,6 +741,7 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -766,6 +770,7 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: 
Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -791,6 +796,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( @@ -812,6 +818,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index 3840c8d8bbb94..f6b313e6a28bb 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -186,6 +186,7 @@ mod tests { number_rows, return_field: Field::new("f", DataType::Date32, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; ToDateFunc::new().invoke_with_args(args) } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 6e0a150b0a35f..4d50a70d37236 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -549,6 +549,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", expected.data_type(), true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }) .unwrap(); match res { @@ -620,6 +621,7 @@ mod tests { ) .into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ToLocalTimeFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 0a0700097770f..f35e170073030 100644 --- 
a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -1033,6 +1033,7 @@ mod tests { number_rows: 4, return_field: Field::new("f", rt, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let res = udf .invoke_with_args(args) @@ -1083,6 +1084,7 @@ mod tests { number_rows: 5, return_field: Field::new("f", rt, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let res = udf .invoke_with_args(args) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index f66f6fcfc1f88..1a73ed8436a68 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -370,6 +370,7 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); assert!(result.is_err()); @@ -390,6 +391,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); @@ -407,6 +409,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -437,6 +440,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -471,6 +475,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -505,6 +510,7 @@ mod tests { number_rows: 1, 
return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -537,6 +543,7 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -572,6 +579,7 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -613,6 +621,7 @@ mod tests { number_rows: 5, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -655,6 +664,7 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -836,6 +846,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Decimal128(38, 0), true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -869,6 +880,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -903,6 +915,7 @@ mod tests { number_rows: 6, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -947,6 +960,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: 
Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -987,6 +1001,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -1037,6 +1052,7 @@ mod tests { number_rows: 7, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -1078,6 +1094,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); assert!(result.is_err()); @@ -1101,6 +1118,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); assert!(result.is_err()); diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index ad2e795d086e9..21a777abb3295 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -222,6 +222,7 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = PowerFunc::new() .invoke_with_args(args) @@ -258,6 +259,7 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = PowerFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index bbe6178f39b79..d1d49b1bf6f90 100644 --- a/datafusion/functions/src/math/signum.rs +++ 
b/datafusion/functions/src/math/signum.rs @@ -173,6 +173,7 @@ mod test { number_rows: array.len(), return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = SignumFunc::new() .invoke_with_args(args) @@ -220,6 +221,7 @@ mod test { number_rows: array.len(), return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = SignumFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 8bad506217aa5..ee6f412bb9a16 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -628,6 +628,7 @@ mod tests { number_rows: args.len(), return_field: Field::new("f", Int64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }) } diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs index 851c182a90dd0..1e64f7087ea74 100644 --- a/datafusion/functions/src/regex/regexpinstr.rs +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -494,6 +494,7 @@ mod tests { number_rows: args.len(), return_field: Arc::new(Field::new("f", Int64, true)), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }) } diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index a93e70e714e8b..661bcfe4e0fd8 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -487,6 +487,7 @@ mod tests { number_rows: 3, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ConcatFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 
cdd30ac8755ab..85704d6b2f468 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -495,6 +495,7 @@ mod tests { number_rows: 3, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ConcatWsFunc::new().invoke_with_args(args)?; @@ -532,6 +533,7 @@ mod tests { number_rows: 3, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ConcatWsFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 7e50676933c8d..1edab4c6bf334 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -177,6 +177,7 @@ mod test { number_rows: 2, return_field: Field::new("f", DataType::Boolean, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let actual = udf.invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index ee56a6a549857..099a3ffd44cc4 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -113,6 +113,7 @@ mod tests { arg_fields, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = match func.invoke_with_args(args)? { diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 8bb2ec1d511cd..d7d2bde94b0a3 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -112,6 +112,7 @@ mod tests { arg_fields: vec![arg_field], return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = match func.invoke_with_args(args)? 
{ diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index fa68e539600b0..219bd6eaa762c 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -485,6 +485,7 @@ mod tests { number_rows: cardinality, return_field: Field::new("f", return_type, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }); assert!(result.is_ok()); diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 4f238b2644bdf..a3734b0c0de4f 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -336,6 +336,7 @@ mod tests { Field::new("f2", DataType::Utf8, substring_nullable).into(), ], scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>], + lambdas: &[false; 2], }; strpos.return_field_from_args(args).unwrap().is_nullable() diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 932d61e8007cd..d6d56b32722de 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -234,6 +234,7 @@ pub mod test { let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { arg_fields: &field_array, scalar_arguments: &scalar_arguments_refs, + lambdas: &vec![false; scalar_arguments_refs.len()], }); let arg_fields = $ARGS.iter() .enumerate() @@ -252,6 +253,7 @@ pub mod test { arg_fields, number_rows: cardinality, return_field, + lambdas: None, config_options: $CONFIG_OPTIONS }); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); @@ -274,6 +276,7 @@ pub mod test { arg_fields, number_rows: cardinality, return_field, + lambdas: None, config_options: $CONFIG_OPTIONS, }) { Ok(_) => assert!(false, "expected error"), diff --git a/datafusion/spark/benches/char.rs b/datafusion/spark/benches/char.rs index 02eab7630d070..501bfd2a0186d 100644 --- 
a/datafusion/spark/benches/char.rs +++ b/datafusion/spark/benches/char.rs @@ -68,6 +68,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::new(Field::new("f", DataType::Utf8, true)), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/spark/src/function/bitmap/bitmap_count.rs b/datafusion/spark/src/function/bitmap/bitmap_count.rs index 56a9c5edb812c..e4c12ebe19665 100644 --- a/datafusion/spark/src/function/bitmap/bitmap_count.rs +++ b/datafusion/spark/src/function/bitmap/bitmap_count.rs @@ -217,6 +217,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", Int64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let udf = BitmapCount::new(); let actual = udf.invoke_with_args(args)?; diff --git a/datafusion/spark/src/function/datetime/make_dt_interval.rs b/datafusion/spark/src/function/datetime/make_dt_interval.rs index bbfba44861344..aaff5400d0c00 100644 --- a/datafusion/spark/src/function/datetime/make_dt_interval.rs +++ b/datafusion/spark/src/function/datetime/make_dt_interval.rs @@ -317,6 +317,7 @@ mod tests { number_rows, return_field: Field::new("f", Duration(Microsecond), true).into(), config_options: Arc::new(Default::default()), + lambdas: None, }; SparkMakeDtInterval::new().invoke_with_args(args) } diff --git a/datafusion/spark/src/function/datetime/make_interval.rs b/datafusion/spark/src/function/datetime/make_interval.rs index 8e3169556b95b..9f98c4b5ce9fb 100644 --- a/datafusion/spark/src/function/datetime/make_interval.rs +++ b/datafusion/spark/src/function/datetime/make_interval.rs @@ -516,6 +516,7 @@ mod tests { number_rows, return_field: Field::new("f", Interval(MonthDayNano), true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; SparkMakeInterval::new().invoke_with_args(args) } diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index 
0dcc58d5bb8ed..e2cd8d977fe29 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -105,6 +105,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { number_rows, return_field, config_options, + lambdas, } = args; // Handle zero-argument case: return empty string @@ -130,6 +131,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { number_rows, return_field, config_options, + lambdas, }; let result = concat_func.invoke_with_args(func_args)?; diff --git a/datafusion/spark/src/function/utils.rs b/datafusion/spark/src/function/utils.rs index e272d91d8a70e..1064acc342916 100644 --- a/datafusion/spark/src/function/utils.rs +++ b/datafusion/spark/src/function/utils.rs @@ -60,7 +60,8 @@ pub mod test { let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { arg_fields: &arg_fields, - scalar_arguments: &scalar_arguments_refs + scalar_arguments: &scalar_arguments_refs, + lambdas: &vec![false; arg_fields.len()], }); match expected { @@ -74,6 +75,7 @@ pub mod test { return_field, arg_fields: arg_fields.clone(), config_options: $CONFIG_OPTIONS, + lambdas: None, }) { Ok(col_value) => { match col_value.to_array(cardinality) { @@ -117,6 +119,7 @@ pub mod test { return_field: value, arg_fields, config_options: $CONFIG_OPTIONS, + lambdas: None, }) { Ok(_) => assert!(false, "expected error"), Err(error) => { From b18d2145163fb5933dfed55eb6305412743b6cac Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 15 Dec 2025 07:27:02 -0300 Subject: [PATCH 03/47] simplify lambda support --- Cargo.lock | 2 +- .../examples/custom_file_casts.rs | 11 +- .../examples/default_column_values.rs | 14 +- datafusion-examples/examples/expr_api.rs | 8 +- .../examples/json_shredding.rs | 14 +- datafusion/catalog-listing/src/helpers.rs | 9 +- datafusion/common/src/column.rs | 6 - datafusion/common/src/cse.rs | 22 +- datafusion/common/src/dfschema.rs | 16 +- 
datafusion/common/src/lib.rs | 2 - datafusion/common/src/utils/mod.rs | 32 +- .../core/src/execution/session_state.rs | 5 +- datafusion/core/tests/parquet/mod.rs | 2 +- .../core/tests/parquet/schema_adapter.rs | 8 +- .../datasource-parquet/src/row_filter.rs | 16 +- datafusion/expr/src/expr.rs | 84 ++- datafusion/expr/src/expr_rewriter/mod.rs | 49 +- datafusion/expr/src/expr_rewriter/order_by.rs | 4 - datafusion/expr/src/expr_schema.rs | 39 +- datafusion/expr/src/lib.rs | 6 +- datafusion/expr/src/tree_node.rs | 685 +----------------- datafusion/expr/src/udf.rs | 467 +----------- datafusion/expr/src/utils.rs | 41 +- datafusion/functions-nested/Cargo.toml | 1 + .../functions-nested/src/array_transform.rs | 110 ++- datafusion/functions/src/core/union_tag.rs | 6 +- .../src/analyzer/function_rewrite.rs | 21 +- .../optimizer/src/analyzer/type_coercion.rs | 85 ++- .../optimizer/src/common_subexpr_eliminate.rs | 26 +- datafusion/optimizer/src/decorrelate.rs | 20 +- .../optimizer/src/optimize_projections/mod.rs | 4 +- datafusion/optimizer/src/push_down_filter.rs | 41 +- .../optimizer/src/scalar_subquery_to_join.rs | 67 +- .../simplify_expressions/expr_simplifier.rs | 178 ++--- datafusion/optimizer/src/utils.rs | 4 +- .../src/schema_rewriter.rs | 28 +- datafusion/physical-expr/Cargo.toml | 4 - .../physical-expr/src/expressions/column.rs | 20 +- .../physical-expr/src/expressions/lambda.rs | 23 +- .../src/expressions/lambda_column.rs | 136 ++++ .../physical-expr/src/expressions/mod.rs | 2 + datafusion/physical-expr/src/lib.rs | 2 - datafusion/physical-expr/src/physical_expr.rs | 10 +- datafusion/physical-expr/src/planner.rs | 46 +- datafusion/physical-expr/src/projection.rs | 53 +- .../physical-expr/src/scalar_function.rs | 657 ++--------------- .../physical-expr/src/simplifier/mod.rs | 20 +- .../src/simplifier/unwrap_cast.rs | 12 +- datafusion/physical-expr/src/utils/mod.rs | 21 +- .../src/enforce_sorting/sort_pushdown.rs | 60 +- .../src/projection_pushdown.rs | 55 +- 
datafusion/physical-plan/src/async_func.rs | 6 +- .../src/joins/stream_join_utils.rs | 29 +- datafusion/physical-plan/src/projection.rs | 61 +- datafusion/proto/src/logical_plan/to_proto.rs | 4 +- datafusion/pruning/src/pruning_predicate.rs | 11 +- datafusion/sql/src/expr/function.rs | 117 ++- datafusion/sql/src/expr/identifier.rs | 11 +- datafusion/sql/src/planner.rs | 28 +- datafusion/sql/src/select.rs | 6 +- datafusion/sql/src/unparser/expr.rs | 3 + datafusion/sql/src/unparser/plan.rs | 6 +- datafusion/sql/src/unparser/rewrite.rs | 14 +- datafusion/sql/src/unparser/utils.rs | 44 +- datafusion/sql/src/utils.rs | 32 +- datafusion/sqllogictest/test_files/array.slt | 8 +- datafusion/sqllogictest/test_files/lambda.slt | 38 +- .../src/logical_plan/producer/expr/mod.rs | 1 + 68 files changed, 1007 insertions(+), 2666 deletions(-) create mode 100644 datafusion/physical-expr/src/expressions/lambda_column.rs diff --git a/Cargo.lock b/Cargo.lock index 4a315ff38f2aa..8377a263cd0cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2386,6 +2386,7 @@ dependencies = [ "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", "datafusion-macros", + "datafusion-physical-expr", "datafusion-physical-expr-common", "itertools 0.14.0", "log", @@ -2489,7 +2490,6 @@ dependencies = [ "paste", "petgraph 0.8.3", "rand 0.9.2", - "recursive", "rstest", ] diff --git a/datafusion-examples/examples/custom_file_casts.rs b/datafusion-examples/examples/custom_file_casts.rs index d8db97d1e0440..4d97ecd91dc64 100644 --- a/datafusion-examples/examples/custom_file_casts.rs +++ b/datafusion-examples/examples/custom_file_casts.rs @@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion::assert_batches_eq; use datafusion::common::not_impl_err; -use datafusion::common::tree_node::{Transformed, TransformedResult}; +use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion::common::{Result, ScalarValue}; use 
datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, @@ -31,7 +31,7 @@ use datafusion::execution::context::SessionContext; use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::parquet::arrow::ArrowWriter; use datafusion::physical_expr::expressions::CastExpr; -use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; +use datafusion::physical_expr::PhysicalExpr; use datafusion::prelude::SessionConfig; use datafusion_physical_expr_adapter::{ DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, @@ -181,10 +181,11 @@ impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter { expr = self.inner.rewrite(expr)?; // Now we can apply custom casting rules or even swap out all CastExprs for a custom cast kernel / expression // For example, [DataFusion Comet](https://github.com/apache/datafusion-comet) has a [custom cast kernel](https://github.com/apache/datafusion-comet/blob/b4ac876ab420ed403ac7fc8e1b29f42f1f442566/native/spark-expr/src/conversion_funcs/cast.rs#L133-L138). 
- expr.transform_with_schema(&self.physical_file_schema, |expr, schema| { + expr.transform(|expr| { if let Some(cast) = expr.as_any().downcast_ref::() { - let input_data_type = cast.expr().data_type(schema)?; - let output_data_type = cast.data_type(schema)?; + let input_data_type = + cast.expr().data_type(&self.physical_file_schema)?; + let output_data_type = cast.data_type(&self.physical_file_schema)?; if !cast.is_bigger_cast(&input_data_type) { return not_impl_err!( "Unsupported CAST from {input_data_type} to {output_data_type}" diff --git a/datafusion-examples/examples/default_column_values.rs b/datafusion-examples/examples/default_column_values.rs index 0d00d2c3af827..d3a7d2ec67f3c 100644 --- a/datafusion-examples/examples/default_column_values.rs +++ b/datafusion-examples/examples/default_column_values.rs @@ -26,8 +26,8 @@ use async_trait::async_trait; use datafusion::assert_batches_eq; use datafusion::catalog::memory::DataSourceExec; use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::tree_node::{Transformed, TransformedResult}; -use datafusion::common::{DFSchema, HashSet}; +use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion::common::DFSchema; use datafusion::common::{Result, ScalarValue}; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; @@ -38,7 +38,7 @@ use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType}; use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_expr::expressions::{CastExpr, Column, Literal}; -use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; +use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::{lit, SessionConfig}; use datafusion_physical_expr_adapter::{ @@ -308,12 +308,11 @@ impl PhysicalExprAdapter for 
DefaultValuePhysicalExprAdapter { fn rewrite(&self, expr: Arc) -> Result> { // First try our custom default value injection for missing columns let rewritten = expr - .transform_with_lambdas_params(|expr, lambdas_params| { + .transform(|expr| { self.inject_default_values( expr, &self.logical_file_schema, &self.physical_file_schema, - lambdas_params, ) }) .data()?; @@ -349,15 +348,12 @@ impl DefaultValuePhysicalExprAdapter { expr: Arc, logical_file_schema: &Schema, physical_file_schema: &Schema, - lambdas_params: &HashSet, ) -> Result>> { if let Some(column) = expr.as_any().downcast_ref::() { let column_name = column.name(); // Check if this column exists in the physical schema - if !lambdas_params.contains(column_name) - && physical_file_schema.index_of(column_name).is_err() - { + if physical_file_schema.index_of(column_name).is_err() { // Column is missing from physical schema, check if logical schema has a default if let Ok(logical_field) = logical_file_schema.field_with_name(column_name) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 29f074e2b400c..56f960870e58a 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -23,7 +23,7 @@ use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::common::stats::Precision; -use datafusion::common::tree_node::Transformed; +use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::common::{ColumnStatistics, DFSchema}; use datafusion::common::{ScalarValue, ToDFSchema}; use datafusion::error::Result; @@ -556,7 +556,7 @@ fn type_coercion_demo() -> Result<()> { // 3. Type coercion with `TypeCoercionRewriter`. let coerced_expr = expr .clone() - .rewrite_with_schema(&df_schema, &mut TypeCoercionRewriter::new(&df_schema))? + .rewrite(&mut TypeCoercionRewriter::new(&df_schema))? 
.data; let physical_expr = datafusion::physical_expr::create_physical_expr( &coerced_expr, @@ -567,7 +567,7 @@ fn type_coercion_demo() -> Result<()> { // 4. Apply explicit type coercion by manually rewriting the expression let coerced_expr = expr - .transform_with_schema(&df_schema, |e, df_schema| { + .transform(|e| { // Only type coerces binary expressions. let Expr::BinaryExpr(e) = e else { return Ok(Transformed::no(e)); @@ -575,7 +575,7 @@ fn type_coercion_demo() -> Result<()> { if let Expr::Column(ref col_expr) = *e.left { let field = df_schema.field_with_name(None, col_expr.name())?; let cast_to_type = field.data_type(); - let coerced_right = e.right.cast_to(cast_to_type, df_schema)?; + let coerced_right = e.right.cast_to(cast_to_type, &df_schema)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( e.left, e.op, diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs index e97f27b818d8d..5ef8b59b64200 100644 --- a/datafusion-examples/examples/json_shredding.rs +++ b/datafusion-examples/examples/json_shredding.rs @@ -22,8 +22,10 @@ use arrow::array::{RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion::assert_batches_eq; -use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNodeRecursion}; -use datafusion::common::{assert_contains, exec_datafusion_err, HashSet, Result}; +use datafusion::common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; +use datafusion::common::{assert_contains, exec_datafusion_err, Result}; use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, }; @@ -34,8 +36,8 @@ use datafusion::logical_expr::{ }; use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; +use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_expr::{expressions, ScalarFunctionExpr}; 
-use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; use datafusion::prelude::SessionConfig; use datafusion::scalar::ScalarValue; use datafusion_physical_expr_adapter::{ @@ -300,9 +302,7 @@ impl PhysicalExprAdapter for ShreddedJsonRewriter { fn rewrite(&self, expr: Arc) -> Result> { // First try our custom JSON shredding rewrite let rewritten = expr - .transform_with_lambdas_params(|expr, lambdas_params| { - self.rewrite_impl(expr, &self.physical_file_schema, lambdas_params) - }) + .transform(|expr| self.rewrite_impl(expr, &self.physical_file_schema)) .data()?; // Then apply the default adapter as a fallback to handle standard schema differences @@ -335,7 +335,6 @@ impl ShreddedJsonRewriter { &self, expr: Arc, physical_file_schema: &Schema, - lambdas_params: &HashSet, ) -> Result>> { if let Some(func) = expr.as_any().downcast_ref::() { if func.name() == "json_get_str" && func.args().len() == 2 { @@ -349,7 +348,6 @@ impl ShreddedJsonRewriter { if let Some(column) = func.args()[1] .as_any() .downcast_ref::() - .filter(|col| !lambdas_params.contains(col.name())) { let column_name = column.name(); // Check if there's a flat column with underscore prefix diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 444f505f4280b..78b46171006a7 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -52,9 +52,9 @@ use object_store::{ObjectMeta, ObjectStore}; /// was performed pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { let mut is_applicable = true; - expr.apply_with_lambdas_params(|expr, lambdas_params| match expr { - Expr::Column(col) => { - is_applicable &= col_names.contains(&col.name()) || col.is_lambda_parameter(lambdas_params); + expr.apply(|expr| match expr { + Expr::Column(Column { ref name, .. 
}) => { + is_applicable &= col_names.contains(&name.as_str()); if is_applicable { Ok(TreeNodeRecursion::Jump) } else { @@ -87,7 +87,8 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GroupingSet(_) | Expr::Case(_) - | Expr::Lambda(_) => Ok(TreeNodeRecursion::Continue), + | Expr::Lambda(_) + | Expr::LambdaColumn(_) => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match scalar_function.func.signature().volatility { diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index dd9b985e6485c..c7f0b5a4f4881 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -22,7 +22,6 @@ use crate::utils::parse_identifiers_normalized; use crate::utils::quote_identifier; use crate::{DFSchema, Diagnostic, Result, SchemaError, Spans, TableReference}; use arrow::datatypes::{Field, FieldRef}; -use std::borrow::Borrow; use std::collections::HashSet; use std::fmt; @@ -326,11 +325,6 @@ impl Column { ..self.clone() } } - - pub fn is_lambda_parameter(&self, lambdas_params: &crate::HashSet + Eq + std::hash::Hash>) -> bool { - // currently, references to lambda parameters are always unqualified - self.relation.is_none() && lambdas_params.contains(self.name()) - } } impl From<&str> for Column { diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs index a7ffde52c93b2..674d3386171f8 100644 --- a/datafusion/common/src/cse.rs +++ b/datafusion/common/src/cse.rs @@ -178,14 +178,6 @@ pub trait CSEController { /// if all are always evaluated. fn conditional_children(node: &Self::Node) -> Option>; - // A helper method called on each node before is_ignored, during top-down traversal during the first, - // visiting traversal of CSE. - fn visit_f_down(&mut self, _node: &Self::Node) {} - - // A helper method called on each node after is_ignored, during bottom-up traversal during the first, - // visiting traversal of CSE. 
- fn visit_f_up(&mut self, _node: &Self::Node) {} - // Returns true if a node is valid. If a node is invalid then it can't be eliminated. // Validity is propagated up which means no subtree can be eliminated that contains // an invalid node. @@ -282,7 +274,7 @@ where /// thus can not be extracted as a common [`TreeNode`]. conditional: bool, - controller: &'a mut C, + controller: &'a C, } /// Record item that used when traversing a [`TreeNode`] tree. @@ -360,7 +352,6 @@ where self.visit_stack .push(VisitRecord::EnterMark(self.down_index)); self.down_index += 1; - self.controller.visit_f_down(node); // If a node can short-circuit then some of its children might not be executed so // count the occurrence either normal or conditional. @@ -423,7 +414,6 @@ where self.visit_stack .push(VisitRecord::NodeItem(node_id, is_valid)); self.up_index += 1; - self.controller.visit_f_up(node); Ok(TreeNodeRecursion::Continue) } @@ -542,7 +532,7 @@ where /// Add an identifier to `id_array` for every [`TreeNode`] in this tree. fn node_to_id_array<'n>( - &mut self, + &self, node: &'n N, node_stats: &mut NodeStats<'n, N>, id_array: &mut IdArray<'n, N>, @@ -556,7 +546,7 @@ where random_state: &self.random_state, found_common: false, conditional: false, - controller: &mut self.controller, + controller: &self.controller, }; node.visit(&mut visitor)?; @@ -571,7 +561,7 @@ where /// Each element is itself the result of [`CSE::node_to_id_array`] for that node /// (e.g. 
the identifiers for each node in the tree) fn to_arrays<'n>( - &mut self, + &self, nodes: &'n [N], node_stats: &mut NodeStats<'n, N>, ) -> Result<(bool, Vec>)> { @@ -771,7 +761,7 @@ mod test { #[test] fn id_array_visitor() -> Result<()> { let alias_generator = AliasGenerator::new(); - let mut eliminator = CSE::new(TestTreeNodeCSEController::new( + let eliminator = CSE::new(TestTreeNodeCSEController::new( &alias_generator, TestTreeNodeMask::Normal, )); @@ -863,7 +853,7 @@ mod test { assert_eq!(expected, id_array); // include aggregates - let mut eliminator = CSE::new(TestTreeNodeCSEController::new( + let eliminator = CSE::new(TestTreeNodeCSEController::new( &alias_generator, TestTreeNodeMask::NormalAndAggregates, )); diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 8a09d61292b27..24d152a7dba8c 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -314,10 +314,8 @@ impl DFSchema { return; } - let self_fields: HashSet<(Option<&TableReference>, &str)> = self - .iter() - .map(|(qualifier, field)| (qualifier, field.name().as_str())) - .collect(); + let self_fields: HashSet<(Option<&TableReference>, &FieldRef)> = + self.iter().collect(); let self_unqualified_names: HashSet<&str> = self .inner .fields @@ -330,10 +328,7 @@ impl DFSchema { for (qualifier, field) in other_schema.iter() { // skip duplicate columns let duplicated_field = match qualifier { - Some(q) => { - self_fields.contains(&(Some(q), field.name().as_str())) - || self_fields.contains(&(None, field.name().as_str())) - } + Some(q) => self_fields.contains(&(Some(q), field)), // for unqualified columns, check as unqualified name None => self_unqualified_names.contains(field.name().as_str()), }; @@ -872,11 +867,6 @@ impl DFSchema { &self.functional_dependencies } - /// Get functional dependencies - pub fn field_qualifiers(&self) -> &[Option] { - &self.field_qualifiers - } - /// Iterate over the qualifiers and fields in the DFSchema pub fn 
iter(&self) -> impl Iterator, &FieldRef)> { self.field_qualifiers diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 8923df683f899..76c7b46e32737 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -117,8 +117,6 @@ pub mod hash_set { pub use hashbrown::hash_set::Entry; } -pub use hashbrown; - /// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is /// not possible. In normal usage of DataFusion the downcast should always succeed. /// diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index ec2dad505a561..3fd0683659caf 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -41,6 +41,7 @@ use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; use std::borrow::{Borrow, Cow}; use std::cmp::{min, Ordering}; use std::collections::HashSet; +use std::iter::repeat_n; use std::num::NonZero; use std::ops::Range; use std::sync::Arc; @@ -953,7 +954,7 @@ pub fn make_list_array_indices( ); for (i, (&start, &end)) in std::iter::zip(&offsets[..], &offsets[1..]).enumerate() { - indices.extend(std::iter::repeat_n( + indices.extend(repeat_n( T::Native::usize_as(i), end.as_usize() - start.as_usize(), )); @@ -966,16 +967,13 @@ pub fn make_list_array_indices( pub fn make_list_element_indices( offsets: &OffsetBuffer, ) -> PrimitiveArray { - let mut indices = vec![ - T::default_value(); - offsets.last().unwrap().as_usize() - - offsets.first().unwrap().as_usize() - ]; + let mut indices = + Vec::with_capacity(offsets.last().unwrap().as_usize() - offsets[0].as_usize()); for (&start, &end) in std::iter::zip(&offsets[..], &offsets[1..]) { - for i in 0..end.as_usize() - start.as_usize() { - indices[start.as_usize() + i] = T::Native::usize_as(i); - } + indices.extend( + (0..end.as_usize() - start.as_usize()).map(|i| T::Native::usize_as(i)), + ); } PrimitiveArray::new(indices.into(), None) @@ -986,12 +984,10 @@ 
pub fn make_fsl_array_indices( list_size: i32, array_len: usize, ) -> PrimitiveArray { - let mut indices = vec![0; list_size as usize * array_len]; + let mut indices = Vec::with_capacity(list_size as usize * array_len); for i in 0..array_len { - for j in 0..list_size as usize { - indices[i + j] = i as i32; - } + indices.extend(repeat_n(i as i32, list_size as usize)); } PrimitiveArray::new(indices.into(), None) @@ -1002,11 +998,13 @@ pub fn make_fsl_element_indices( list_size: i32, array_len: usize, ) -> PrimitiveArray { - let mut indices = vec![0; list_size as usize * array_len]; + let mut indices = Vec::with_capacity(list_size as usize * array_len); - for i in 0..array_len { - for j in 0..list_size as usize { - indices[i + j] = j as i32; + if array_len > 0 { + indices.extend((0..list_size as usize).map(|j| j as i32)); + + for _ in 1..array_len { + indices.extend_from_within(0..list_size as usize); } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index ad4ffb487ee1d..c15b7eae08432 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -41,6 +41,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::config::Dialect; use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; +use datafusion_common::tree_node::TreeNode; use datafusion_common::{ config_err, exec_err, plan_datafusion_err, DFSchema, DataFusionError, ResolvedTableReference, TableReference, @@ -700,9 +701,7 @@ impl SessionState { let config_options = self.config_options(); for rewrite in self.analyzer.function_rewrites() { expr = expr - .transform_up_with_schema(df_schema, |expr, df_schema| { - rewrite.rewrite(expr, df_schema, config_options) - })? + .transform_up(|expr| rewrite.rewrite(expr, df_schema, config_options))? 
.data; } create_physical_expr(&expr, df_schema, self.execution_props()) diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index eea6085c02b9f..097600e45eadd 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -516,7 +516,7 @@ fn make_uint_batches(start: u8, end: u8) -> RecordBatch { Field::new("u64", DataType::UInt64, true), ])); let v8: Vec = (start..end).collect(); - let v16: Vec = (start as u16..end as _).collect(); + let v16: Vec = (start as _..end as _).collect(); let v32: Vec = (start as _..end as _).collect(); let v64: Vec = (start as _..end as _).collect(); RecordBatch::try_new( diff --git a/datafusion/core/tests/parquet/schema_adapter.rs b/datafusion/core/tests/parquet/schema_adapter.rs index dfa4c91ba5dd8..40fc6176e212b 100644 --- a/datafusion/core/tests/parquet/schema_adapter.rs +++ b/datafusion/core/tests/parquet/schema_adapter.rs @@ -27,7 +27,7 @@ use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::tree_node::{Transformed, TransformedResult}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::DataFusionError; use datafusion_common::{ColumnStatistics, ScalarValue}; use datafusion_datasource::file::FileSource; @@ -39,7 +39,7 @@ use datafusion_datasource::ListingTableUrl; use datafusion_datasource_parquet::source::ParquetSource; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_physical_expr::expressions::{self, Column}; -use datafusion_physical_expr::{PhysicalExpr, PhysicalExprExt}; +use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_adapter::{ DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, @@ -181,10 +181,10 @@ struct CustomPhysicalExprAdapter { impl PhysicalExprAdapter for 
CustomPhysicalExprAdapter { fn rewrite(&self, mut expr: Arc) -> Result> { expr = expr - .transform_with_lambdas_params(|expr, lambdas_params| { + .transform(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { let field_name = column.name(); - if !lambdas_params.contains(field_name) && self + if self .physical_file_schema .field_with_name(field_name) .ok() diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index 45441ad71086c..660b32f486120 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -77,7 +77,7 @@ use datafusion_common::Result; use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaMapper}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::reassign_expr_columns; -use datafusion_physical_expr::{split_conjunction, PhysicalExpr, PhysicalExprExt}; +use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; use datafusion_physical_plan::metrics; @@ -336,20 +336,6 @@ impl<'schema> PushdownChecker<'schema> { fn prevents_pushdown(&self) -> bool { self.non_primitive_columns || self.projected_columns } - - fn check(&mut self, node: Arc) -> Result { - node.apply_with_lambdas_params(|node, lamdas_params| { - if let Some(column) = node.as_any().downcast_ref::() { - if !lamdas_params.contains(column.name()) { - if let Some(recursion) = self.check_single_column(column.name()) { - return Ok(recursion); - } - } - } - - Ok(TreeNodeRecursion::Continue) - }) - } } impl TreeNodeVisitor<'_> for PushdownChecker<'_> { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index e2845ea5a7de8..6387fc4a44f38 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -398,10 +398,30 @@ pub enum Expr { OuterReferenceColumn(FieldRef, Column), /// Unnest expression Unnest(Unnest), - /// Lambda expression, valid only as a scalar function argument - /// Note 
that it has it's own scoped schema, different from the plan schema, - /// that can be constructed with ScalarUDF::arguments_schemas and variants + /// Lambda expression Lambda(Lambda), + LambdaColumn(LambdaColumn), +} + +#[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] +pub struct LambdaColumn { + pub name: String, + pub field: FieldRef, + pub spans: Spans, +} + +impl LambdaColumn { + pub fn new(name: String, field: FieldRef) -> Self { + Self { + name, + field, + spans: Spans::new(), + } + } + + pub fn spans_mut(&mut self) -> &mut Spans { + &mut self.spans + } } impl Default for Expr { @@ -1547,6 +1567,7 @@ impl Expr { Expr::Wildcard { .. } => "Wildcard", Expr::Unnest { .. } => "Unnest", Expr::Lambda { .. } => "Lambda", + Expr::LambdaColumn { .. } => "LambdaColumn", } } @@ -1930,11 +1951,9 @@ impl Expr { /// /// See [`Self::column_refs`] for details pub fn add_column_refs<'a>(&'a self, set: &mut HashSet<&'a Column>) { - self.apply_with_lambdas_params(|expr, lambdas_params| { + self.apply(|expr| { if let Expr::Column(col) = expr { - if col.relation.is_some() || !lambdas_params.contains(col.name()) { - set.insert(col); - } + set.insert(col); } Ok(TreeNodeRecursion::Continue) }) @@ -1967,11 +1986,9 @@ impl Expr { /// /// See [`Self::column_refs_counts`] for details pub fn add_column_ref_counts<'a>(&'a self, map: &mut HashMap<&'a Column, usize>) { - self.apply_with_lambdas_params(|expr, lambdas_params| { + self.apply(|expr| { if let Expr::Column(col) = expr { - if !col.is_lambda_parameter(lambdas_params) { - *map.entry(col).or_default() += 1; - } + *map.entry(col).or_default() += 1; } Ok(TreeNodeRecursion::Continue) }) @@ -1980,10 +1997,8 @@ impl Expr { /// Returns true if there are any column references in this Expr pub fn any_column_refs(&self) -> bool { - self.exists_with_lambdas_params(|expr, lambdas_params| { - Ok(matches!(expr, Expr::Column(c) if !c.is_lambda_parameter(lambdas_params))) - }) - .expect("exists closure is infallible") + self.exists(|expr| 
Ok(matches!(expr, Expr::Column(_)))) + .expect("exists closure is infallible") } /// Return true if the expression contains out reference(correlated) expressions. @@ -2023,7 +2038,7 @@ impl Expr { /// at least one placeholder. pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> { let mut has_placeholder = false; - self.transform_with_schema(schema, |mut expr, schema| { + self.transform(|mut expr| { match &mut expr { // Default to assuming the arguments are the same type Expr::BinaryExpr(BinaryExpr { left, op: _, right }) => { @@ -2107,7 +2122,8 @@ impl Expr { | Expr::WindowFunction(..) | Expr::Literal(..) | Expr::Placeholder(..) - | Expr::Lambda { .. } => false, + | Expr::Lambda(..) + | Expr::LambdaColumn(..) => false, } } @@ -2703,12 +2719,17 @@ impl HashNode for Expr { column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} - Expr::Lambda(Lambda { - params, - body: _, - }) => { + Expr::Lambda(Lambda { params, body: _ }) => { params.hash(state); } + Expr::LambdaColumn(LambdaColumn { + name, + field, + spans: _, + }) => { + name.hash(state); + field.hash(state); + } }; } } @@ -3022,12 +3043,12 @@ impl Display for SchemaDisplay<'_> { } } } - Expr::Lambda(Lambda { - params, - body, - }) => { + Expr::Lambda(Lambda { params, body }) => { write!(f, "({}) -> {body}", display_comma_separated(params)) } + Expr::LambdaColumn(c) => { + write!(f, "{}", c.name) + } } } } @@ -3208,10 +3229,7 @@ impl Display for SqlDisplay<'_> { } } } - Expr::Lambda(Lambda { - params, - body, - }) => { + Expr::Lambda(Lambda { params, body }) => { write!(f, "({}) -> {}", params.join(", "), SchemaDisplay(body)) } _ => write!(f, "{}", self.0), @@ -3521,12 +3539,12 @@ impl Display for Expr { Expr::Unnest(Unnest { expr }) => { write!(f, "{UNNEST_COLUMN_PREFIX}({expr})") } - Expr::Lambda(Lambda { - params, - body, - }) => { + Expr::Lambda(Lambda { params, body }) => { write!(f, "({}) -> {body}", params.join(", ")) } + Expr::LambdaColumn(c) => { + write!(f, "{}", 
c.name) + } } } } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 81ec6e7acbe38..9c3c5df7007ff 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -62,15 +62,11 @@ pub trait FunctionRewrite: Debug { /// Recursively call `LogicalPlanBuilder::normalize` on all [`Column`] expressions /// in the `expr` expression tree. pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - expr.transform_up_with_lambdas_params(|expr, lambdas_params| { + expr.transform(|expr| { Ok({ if let Expr::Column(c) = expr { - if c.relation.is_some() || !lambdas_params.contains(c.name()) { - let col = LogicalPlanBuilder::normalize(plan, c)?; - Transformed::yes(Expr::Column(col)) - } else { - Transformed::no(Expr::Column(c)) - } + let col = LogicalPlanBuilder::normalize(plan, c)?; + Transformed::yes(Expr::Column(col)) } else { Transformed::no(expr) } @@ -95,21 +91,14 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( return Ok(Expr::Unnest(Unnest { expr: Box::new(e) })); } - expr.transform_up_with_lambdas_params(|expr, lambdas_params| { + expr.transform(|expr| { Ok({ - match expr { - Expr::Column(c) => { - if c.relation.is_none() && lambdas_params.contains(c.name()) { - Transformed::no(Expr::Column(c)) - } else { - let col = c.normalize_with_schemas_and_ambiguity_check( - schemas, - using_columns, - )?; - Transformed::yes(Expr::Column(col)) - } - } - _ => Transformed::no(expr), + if let Expr::Column(c) = expr { + let col = + c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?; + Transformed::yes(Expr::Column(col)) + } else { + Transformed::no(expr) } }) }) @@ -144,18 +133,15 @@ pub fn normalize_sorts( /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. 
pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - expr.transform_up_with_lambdas_params(|expr, lambdas_params| { + expr.transform(|expr| { Ok({ - match &expr { - Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { - match replace_map.get(c) { - Some(new_c) => { - Transformed::yes(Expr::Column((*new_c).to_owned())) - } - None => Transformed::no(expr), - } + if let Expr::Column(c) = &expr { + match replace_map.get(c) { + Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())), + None => Transformed::no(expr), } - _ => Transformed::no(expr), + } else { + Transformed::no(expr) } }) }) @@ -215,7 +201,6 @@ pub fn strip_outer_reference(expr: Expr) -> Expr { expr.transform(|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { - //todo: what if this col collides with a lambda parameter? Transformed::yes(Expr::Column(col)) } else { Transformed::no(expr) diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index b94c632ce74b3..6db95555502da 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -77,10 +77,6 @@ fn rewrite_in_terms_of_projection( // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" expr.transform(|expr| { - if matches!(expr, Expr::Lambda(_)) { - return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) - } - // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { let (qualifier, field_name) = found.qualified_name(); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 4a1efadccd0ec..1b7f5f0212c6b 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -16,12 +16,12 @@ // under the License. 
use super::{Between, Expr, Like}; -use crate::expr::FieldMetadata; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, InSubquery, Lambda, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; +use crate::expr::{FieldMetadata, LambdaColumn}; use crate::type_coercion::functions::{ fields_with_aggregate_udf, fields_with_window_udf, }; @@ -234,7 +234,10 @@ impl ExprSchemable for Expr { // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } - Expr::Lambda { .. } => Ok(DataType::Null), + Expr::Lambda(Lambda { params: _, body }) => body.get_type(schema), + Expr::LambdaColumn(LambdaColumn { name: _, field, .. }) => { + Ok(field.data_type().clone()) + } } } @@ -353,7 +356,8 @@ impl ExprSchemable for Expr { // in projections Ok(true) } - Expr::Lambda { .. } => Ok(false), + Expr::Lambda(l) => l.body.nullable(input_schema), + Expr::LambdaColumn(c) => Ok(c.field.is_nullable()), } } @@ -542,30 +546,14 @@ impl ExprSchemable for Expr { func.return_field(&new_fields) } - // Expr::Lambda(Lambda { params, body}) => body.to_field(schema), Expr::ScalarFunction(ScalarFunction { func, args }) => { - let fields = if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) { - let lambdas_schemas = func.arguments_expr_schema(args, schema)?; - - std::iter::zip(args, lambdas_schemas) - // .map(|(e, schema)| e.to_field(schema).map(|(_, f)| f)) - .map(|(e, schema)| match e { - Expr::Lambda(Lambda { params: _, body }) => { - body.to_field(&schema).map(|(_, f)| f) - } - _ => e.to_field(&schema).map(|(_, f)| f), - }) - .collect::>>()? - } else { - args.iter() - .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()? - }; - - let arg_types = fields + let (arg_types, fields): (Vec, Vec>) = args .iter() - .map(|f| f.data_type().clone()) - .collect::>(); + .map(|e| e.to_field(schema).map(|(_, f)| f)) + .collect::>>()? 
+ .into_iter() + .map(|f| (f.data_type().clone(), f)) + .unzip(); // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_types, func) @@ -637,6 +625,7 @@ impl ExprSchemable for Expr { self.get_type(schema)?, self.nullable(schema)?, ))), + Expr::LambdaColumn(c) => Ok(Arc::clone(&c.field)), }?; Ok(( diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 46c7422814ace..0f26218e74779 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -118,10 +118,8 @@ pub use udaf::{ ReversedUDAF, SetMonotonicity, StatisticsArgs, }; pub use udf::{ - merge_captures_with_args, merge_captures_with_boxed_lazy_args, - merge_captures_with_lazy_args, ReturnFieldArgs, ScalarFunctionArgs, - ScalarFunctionLambdaArg, ScalarUDF, ScalarUDFImpl, ValueOrLambda, ValueOrLambdaField, - ValueOrLambdaParameter, + ReturnFieldArgs, ScalarFunctionArgs, ScalarFunctionLambdaArg, ScalarUDF, + ScalarUDFImpl, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, }; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 63c535b43ee8b..df98f720a0f08 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -29,7 +29,7 @@ use datafusion_common::{ tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, }, - DFSchema, HashSet, Result, + Result, }; /// Implementation of the [`TreeNode`] trait @@ -80,7 +80,8 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } - | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue), + | Expr::Placeholder(_) + | Expr::LambdaColumn(_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. 
}) => { (left, right).apply_ref_elements(f) } @@ -131,7 +132,8 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_, _) => Transformed::no(self), + | Expr::Literal(_, _) + | Expr::LambdaColumn(_) => Transformed::no(self), Expr::Unnest(Unnest { expr, .. }) => expr .map_elements(f)? .update_data(|expr| Expr::Unnest(Unnest { expr })), @@ -321,680 +323,3 @@ impl TreeNode for Expr { }) } } - -impl Expr { - /// Similarly to [`Self::rewrite`], rewrites this expr and its inputs using `f`, - /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - pub fn rewrite_with_schema< - R: for<'a> TreeNodeRewriterWithPayload = &'a DFSchema>, - >( - self, - schema: &DFSchema, - rewriter: &mut R, - ) -> Result> { - rewriter - .f_down(self, schema)? - .transform_children(|n| match &n { - Expr::ScalarFunction(ScalarFunction { func, args }) - if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => - { - let mut lambdas_schemas = func - .arguments_schema_from_logical_args(args, schema)? - .into_iter(); - - n.map_children(|n| { - n.rewrite_with_schema(&lambdas_schemas.next().unwrap(), rewriter) - }) - } - _ => n.map_children(|n| n.rewrite_with_schema(schema, rewriter)), - })? - .transform_parent(|n| rewriter.f_up(n, schema)) - } - - /// Similarly to [`Self::rewrite`], rewrites this expr and its inputs using `f`, - /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. 
- pub fn rewrite_with_lambdas_params< - R: for<'a> TreeNodeRewriterWithPayload< - Node = Expr, - Payload<'a> = &'a HashSet, - >, - >( - self, - rewriter: &mut R, - ) -> Result> { - self.rewrite_with_lambdas_params_impl(&HashSet::new(), rewriter) - } - - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn rewrite_with_lambdas_params_impl< - R: for<'a> TreeNodeRewriterWithPayload< - Node = Expr, - Payload<'a> = &'a HashSet, - >, - >( - self, - args: &HashSet, - rewriter: &mut R, - ) -> Result> { - rewriter - .f_down(self, args)? - .transform_children(|n| match n { - Expr::Lambda(Lambda { - ref params, - body: _, - }) => { - let mut args = args.clone(); - - args.extend(params.iter().cloned()); - - n.map_children(|n| { - n.rewrite_with_lambdas_params_impl(&args, rewriter) - }) - } - _ => { - n.map_children(|n| n.rewrite_with_lambdas_params_impl(args, rewriter)) - } - })? - .transform_parent(|n| rewriter.f_up(n, args)) - } - - /// Similarly to [`Self::map_children`], rewrites all lambdas that may - /// appear in expressions such as `array_transform([1, 2], v -> v*2)`. - /// - /// Returns the current node. - pub fn map_children_with_lambdas_params< - F: FnMut(Self, &HashSet) -> Result>, - >( - self, - args: &HashSet, - mut f: F, - ) -> Result> { - match &self { - Expr::Lambda(Lambda { params, body: _ }) => { - let mut args = args.clone(); - - args.extend(params.iter().cloned()); - - self.map_children(|expr| f(expr, &args)) - } - _ => self.map_children(|expr| f(expr, args)), - } - } - - /// Similarly to [`Self::transform_up`], rewrites this expr and its inputs using `f`, - /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. 
- pub fn transform_up_with_lambdas_params< - F: FnMut(Self, &HashSet) -> Result>, - >( - self, - mut f: F, - ) -> Result> { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn transform_up_with_lambdas_params_impl< - F: FnMut(Expr, &HashSet) -> Result>, - >( - node: Expr, - args: &HashSet, - f: &mut F, - ) -> Result> { - node.map_children_with_lambdas_params(args, |node, args| { - transform_up_with_lambdas_params_impl(node, args, f) - })? - .transform_parent(|node| f(node, args)) - /*match &node { - Expr::Lambda(Lambda { params, body: _ }) => { - let mut args = args.clone(); - - args.extend(params.iter().cloned()); - - node.map_children(|n| { - transform_up_with_lambdas_params_impl(n, &args, f) - })? - .transform_parent(|n| f(n, &args)) - } - _ => node - .map_children(|n| transform_up_with_lambdas_params_impl(n, args, f))? - .transform_parent(|n| f(n, args)), - }*/ - } - - transform_up_with_lambdas_params_impl(self, &HashSet::new(), &mut f) - } - - /// Similarly to [`Self::transform_down`], rewrites this expr and its inputs using `f`, - /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. 
- pub fn transform_down_with_lambdas_params< - F: FnMut(Self, &HashSet) -> Result>, - >( - self, - mut f: F, - ) -> Result> { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn transform_down_with_lambdas_params_impl< - F: FnMut(Expr, &HashSet) -> Result>, - >( - node: Expr, - args: &HashSet, - f: &mut F, - ) -> Result> { - f(node, args)?.transform_children(|node| { - node.map_children_with_lambdas_params(args, |node, args| { - transform_down_with_lambdas_params_impl(node, args, f) - }) - }) - } - - transform_down_with_lambdas_params_impl(self, &HashSet::new(), &mut f) - } - - pub fn apply_with_lambdas_params< - 'n, - F: FnMut(&'n Self, &HashSet<&'n str>) -> Result, - >( - &'n self, - mut f: F, - ) -> Result { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn apply_with_lambdas_params_impl< - 'n, - F: FnMut(&'n Expr, &HashSet<&'n str>) -> Result, - >( - node: &'n Expr, - args: &HashSet<&'n str>, - f: &mut F, - ) -> Result { - match node { - Expr::Lambda(Lambda { params, body: _ }) => { - let mut args = args.clone(); - - args.extend(params.iter().map(|v| v.as_str())); - - f(node, &args)?.visit_children(|| { - node.apply_children(|c| { - apply_with_lambdas_params_impl(c, &args, f) - }) - }) - } - _ => f(node, args)?.visit_children(|| { - node.apply_children(|c| apply_with_lambdas_params_impl(c, args, f)) - }), - } - } - - apply_with_lambdas_params_impl(self, &HashSet::new(), &mut f) - } - - /// Similarly to [`Self::transform`], rewrites this expr and its inputs using `f`, - /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. 
- pub fn transform_with_schema< - F: FnMut(Self, &DFSchema) -> Result>, - >( - self, - schema: &DFSchema, - f: F, - ) -> Result> { - self.transform_up_with_schema(schema, f) - } - - /// Similarly to [`Self::transform_up`], rewrites this expr and its inputs using `f`, - /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. - pub fn transform_up_with_schema< - F: FnMut(Self, &DFSchema) -> Result>, - >( - self, - schema: &DFSchema, - mut f: F, - ) -> Result> { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn transform_up_with_schema_impl< - F: FnMut(Expr, &DFSchema) -> Result>, - >( - node: Expr, - schema: &DFSchema, - f: &mut F, - ) -> Result> { - node.map_children_with_schema(schema, |n, schema| { - transform_up_with_schema_impl(n, schema, f) - })? - .transform_parent(|n| f(n, schema)) - } - - transform_up_with_schema_impl(self, schema, &mut f) - } - - pub fn map_children_with_schema< - F: FnMut(Self, &DFSchema) -> Result>, - >( - self, - schema: &DFSchema, - mut f: F, - ) -> Result> { - match self { - Expr::ScalarFunction(ref fun) - if fun.args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => - { - let mut args_schemas = fun - .func - .arguments_schema_from_logical_args(&fun.args, schema)? - .into_iter(); - - self.map_children(|expr| f(expr, &args_schemas.next().unwrap())) - } - _ => self.map_children(|expr| f(expr, schema)), - } - } - - pub fn exists_with_lambdas_params) -> Result>( - &self, - mut f: F, - ) -> Result { - let mut found = false; - - self.apply_with_lambdas_params(|n, lambdas_params| { - if f(n, lambdas_params)? { - found = true; - Ok(TreeNodeRecursion::Stop) - } else { - Ok(TreeNodeRecursion::Continue) - } - })?; - - Ok(found) - } -} - -pub trait ExprWithLambdasRewriter2: Sized { - /// Invoked while traversing down the tree before any children are rewritten. - /// Default implementation returns the node as is and continues recursion. 
- fn f_down(&mut self, node: Expr, _schema: &DFSchema) -> Result> { - Ok(Transformed::no(node)) - } - - /// Invoked while traversing up the tree after all children have been rewritten. - /// Default implementation returns the node as is and continues recursion. - fn f_up(&mut self, node: Expr, _schema: &DFSchema) -> Result> { - Ok(Transformed::no(node)) - } -} -pub trait TreeNodeRewriterWithPayload: Sized { - type Node; - type Payload<'a>; - - /// Invoked while traversing down the tree before any children are rewritten. - /// Default implementation returns the node as is and continues recursion. - fn f_down<'a>( - &mut self, - node: Self::Node, - _payload: Self::Payload<'a>, - ) -> Result> { - Ok(Transformed::no(node)) - } - - /// Invoked while traversing up the tree after all children have been rewritten. - /// Default implementation returns the node as is and continues recursion. - fn f_up<'a>( - &mut self, - node: Self::Node, - _payload: Self::Payload<'a>, - ) -> Result> { - Ok(Transformed::no(node)) - } -} - -/* -struct LambdaColumnNormalizer<'a> { - existing_qualifiers: HashSet<&'a str>, - alias_generator: AliasGenerator, - lambdas_columns: HashMap>, -} - -impl<'a> LambdaColumnNormalizer<'a> { - fn new(dfschema: &'a DFSchema, expr: &'a Expr) -> Self { - let mut existing_qualifiers: HashSet<&'a str> = dfschema - .field_qualifiers() - .iter() - .flatten() - .map(|tbl| tbl.table()) - .filter(|table| table.starts_with("lambda_")) - .collect(); - - expr.apply(|node| { - if let Expr::Lambda(lambda) = node { - if let Some(qualifier) = &lambda.qualifier { - existing_qualifiers.insert(qualifier); - } - } - - Ok(TreeNodeRecursion::Continue) - }) - .unwrap(); - - Self { - existing_qualifiers, - alias_generator: AliasGenerator::new(), - lambdas_columns: HashMap::new(), - } - } -} - -impl TreeNodeRewriter for LambdaColumnNormalizer<'_> { - type Node = Expr; - - fn f_down(&mut self, node: Self::Node) -> Result> { - match node { - Expr::Lambda(mut lambda) => { - let tbl = 
lambda.qualifier.as_ref().map_or_else( - || loop { - let table = self.alias_generator.next("lambda"); - - if !self.existing_qualifiers.contains(table.as_str()) { - break TableReference::bare(table); - } - }, - |qualifier| TableReference::bare(qualifier.as_str()), - ); - - for param in &lambda.params { - self.lambdas_columns - .entry_ref(param) - .or_default() - .push(tbl.clone()); - } - - if lambda.qualifier.is_none() { - lambda.qualifier = Some(tbl.table().to_owned()); - - Ok(Transformed::yes(Expr::Lambda(lambda))) - } else { - Ok(Transformed::no(Expr::Lambda(lambda))) - } - } - Expr::Column(c) if c.relation.is_none() => { - if let Some(lambda_qualifier) = self.lambdas_columns.get(c.name()) { - Ok(Transformed::yes(Expr::Column( - c.with_relation(lambda_qualifier.last().unwrap().clone()), - ))) - } else { - Ok(Transformed::no(Expr::Column(c))) - } - } - _ => Ok(Transformed::no(node)) - } - } - - fn f_up(&mut self, node: Self::Node) -> Result> { - if let Expr::Lambda(lambda) = &node { - for param in &lambda.params { - match self.lambdas_columns.entry_ref(param) { - EntryRef::Occupied(mut entry) => { - let chain = entry.get_mut(); - - chain.pop(); - - if chain.is_empty() { - entry.remove(); - } - } - EntryRef::Vacant(_) => unreachable!(), - } - } - } - - Ok(Transformed::no(node)) - } -} -*/ - -// helpers used in udf.rs -#[cfg(test)] -pub(crate) mod tests { - use super::TreeNodeRewriterWithPayload; - use crate::{ - col, expr::Lambda, Expr, ScalarUDF, ScalarUDFImpl, ValueOrLambdaParameter, - }; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{ - tree_node::{Transformed, TreeNodeRecursion}, - DFSchema, HashSet, Result, - }; - use datafusion_expr_common::signature::{Signature, Volatility}; - - pub(crate) fn list_list_int() -> DFSchema { - DFSchema::try_from(Schema::new(vec![Field::new( - "v", - DataType::new_list(DataType::new_list(DataType::Int32, false), false), - false, - )])) - .unwrap() - } - - pub(crate) fn list_int() -> DFSchema { - 
DFSchema::try_from(Schema::new(vec![Field::new( - "v", - DataType::new_list(DataType::Int32, false), - false, - )])) - .unwrap() - } - - fn int() -> DFSchema { - DFSchema::try_from(Schema::new(vec![Field::new("v", DataType::Int32, false)])) - .unwrap() - } - - pub(crate) fn array_transform_udf() -> ScalarUDF { - ScalarUDF::new_from_impl(ArrayTransformFunc::new()) - } - - pub(crate) fn args() -> Vec { - vec![ - col("v"), - Expr::Lambda(Lambda::new( - vec!["v".into()], - array_transform_udf().call(vec![ - col("v"), - Expr::Lambda(Lambda::new(vec!["v".into()], -col("v"))), - ]), - )), - ] - } - - // array_transform(v, |v| -> array_transform(v, |v| -> -v)) - fn array_transform() -> Expr { - array_transform_udf().call(args()) - } - - #[derive(Debug, PartialEq, Eq, Hash)] - pub(crate) struct ArrayTransformFunc { - signature: Signature, - } - - impl ArrayTransformFunc { - pub fn new() -> Self { - Self { - signature: Signature::any(2, Volatility::Immutable), - } - } - } - - impl ScalarUDFImpl for ArrayTransformFunc { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn name(&self) -> &str { - "array_transform" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) - } - - fn lambdas_parameters( - &self, - args: &[ValueOrLambdaParameter], - ) -> Result>>> { - let ValueOrLambdaParameter::Value(value_field) = &args[0] else { - unreachable!() - }; - - let DataType::List(field) = value_field.data_type() else { - unreachable!() - }; - - Ok(vec![ - None, - Some(vec![Field::new( - "", - field.data_type().clone(), - field.is_nullable(), - )]), - ]) - } - - fn invoke_with_args( - &self, - _args: crate::ScalarFunctionArgs, - ) -> Result { - unimplemented!() - } - } - - #[test] - fn test_rewrite_with_schema() { - let schema = list_list_int(); - let array_transform = array_transform(); - - let mut rewriter = OkRewriter::default(); - - array_transform - 
.rewrite_with_schema(&schema, &mut rewriter) - .unwrap(); - - let expected = [ - ( - "f_down array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", - list_list_int(), - ), - ("f_down v", list_list_int()), - ("f_up v", list_list_int()), - ("f_down (v) -> array_transform(v, (v) -> (- v))", list_int()), - ("f_down array_transform(v, (v) -> (- v))", list_int()), - ("f_down v", list_int()), - ("f_up v", list_int()), - ("f_down (v) -> (- v)", int()), - ("f_down (- v)", int()), - ("f_down v", int()), - ("f_up v", int()), - ("f_up (- v)", int()), - ("f_up (v) -> (- v)", int()), - ("f_up array_transform(v, (v) -> (- v))", list_int()), - ("f_up (v) -> array_transform(v, (v) -> (- v))", list_int()), - ( - "f_up array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", - list_list_int(), - ), - ] - .map(|(a, b)| (String::from(a), b)); - - assert_eq!(rewriter.steps, expected) - } - - #[derive(Default)] - struct OkRewriter { - steps: Vec<(String, DFSchema)>, - } - - impl TreeNodeRewriterWithPayload for OkRewriter { - type Node = Expr; - type Payload<'a> = &'a DFSchema; - - fn f_down( - &mut self, - node: Expr, - schema: &DFSchema, - ) -> Result> { - self.steps.push((format!("f_down {node}"), schema.clone())); - - Ok(Transformed::no(node)) - } - - fn f_up( - &mut self, - node: Expr, - schema: &DFSchema, - ) -> Result> { - self.steps.push((format!("f_up {node}"), schema.clone())); - - Ok(Transformed::no(node)) - } - } - - #[test] - fn test_transform_up_with_lambdas_params() { - let mut steps = vec![]; - - array_transform() - .transform_up_with_lambdas_params(|node, params| { - steps.push((node.to_string(), params.clone())); - - Ok(Transformed::no(node)) - }) - .unwrap(); - - let lambdas_params = &HashSet::from([String::from("v")]); - - let expected = [ - ("v", lambdas_params), - ("v", lambdas_params), - ("v", lambdas_params), - ("(- v)", lambdas_params), - ("(v) -> (- v)", lambdas_params), - ("array_transform(v, (v) -> (- v))", lambdas_params), - ("(v) -> 
array_transform(v, (v) -> (- v))", lambdas_params), - ( - "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", - lambdas_params, - ), - ] - .map(|(a, b)| (String::from(a), b.clone())); - - assert_eq!(steps, expected); - } - - #[test] - fn test_apply_with_lambdas_params() { - let array_transform = array_transform(); - let mut steps = vec![]; - - array_transform - .apply_with_lambdas_params(|node, params| { - steps.push((node.to_string(), params.clone())); - - Ok(TreeNodeRecursion::Continue) - }) - .unwrap(); - - let expected = [ - ("v", HashSet::from(["v"])), - ("v", HashSet::from(["v"])), - ("v", HashSet::from(["v"])), - ("(- v)", HashSet::from(["v"])), - ("(v) -> (- v)", HashSet::from(["v"])), - ("array_transform(v, (v) -> (- v))", HashSet::from(["v"])), - ("(v) -> array_transform(v, (v) -> (- v))", HashSet::from(["v"])), - ( - "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", - HashSet::from(["v"]), - ), - ] - .map(|(a, b)| (String::from(a), b)); - - assert_eq!(steps, expected); - } -} diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 74ac1b456ff04..911fc890e2bc5 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -18,30 +18,23 @@ //! 
[`ScalarUDF`]: Scalar User Defined Functions use crate::async_udf::AsyncScalarUDF; -use crate::expr::{schema_name_from_exprs_comma_separated_without_space, Lambda}; +use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::udf_eq::UdfEq; -use crate::{ColumnarValue, Documentation, Expr, ExprSchemable, Signature}; -use arrow::array::{ArrayRef, RecordBatch}; -use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema}; -use datafusion_common::alias::AliasGenerator; +use crate::{ColumnarValue, Documentation, Expr, Signature}; +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::TreeNodeRecursion; -use datafusion_common::{ - exec_err, not_impl_err, DFSchema, ExprSchema, Result, ScalarValue, -}; +use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use indexmap::IndexMap; use std::any::Any; -use std::borrow::Cow; use std::cmp::Ordering; -use std::collections::HashMap; use std::fmt::Debug; use std::hash::{Hash, Hasher}; -use std::sync::{Arc, LazyLock}; +use std::sync::Arc; /// Logical representation of a Scalar User Defined Function. 
/// @@ -352,272 +345,14 @@ impl ScalarUDF { pub fn as_async(&self) -> Option<&AsyncScalarUDF> { self.inner().as_any().downcast_ref::() } - - /// Variation of arguments_from_logical_args that works with arrow Schema's and ScalarFunctionArgMetadata instead - pub(crate) fn arguments_expr_schema<'a>( - &self, - args: &[Expr], - schema: &'a dyn ExprSchema, - ) -> Result> { - self.arguments_scope_with( - &lambda_parameters(args, schema)?, - ExtendableExprSchema::new(schema), - ) - } - - /// Variation of arguments_from_logical_args that works with arrow Schema's and ScalarFunctionArgMetadata instead, - pub fn arguments_arrow_schema<'a>( - &self, - args: &[ValueOrLambdaParameter], - schema: &'a Schema, - ) -> Result>> { - self.arguments_scope_with(args, Cow::Borrowed(schema)) - } - - pub fn arguments_schema_from_logical_args<'a>( - &self, - args: &[Expr], - schema: &'a DFSchema, - ) -> Result>> { - self.arguments_scope_with( - &lambda_parameters(args, schema)?, - Cow::Borrowed(schema), - ) - } - - /// Scalar function supports lambdas as arguments, which will be evaluated with - /// a different schema that of the function itself. 
This functions returns a vec - /// with the correspoding schema that each argument will run - /// - /// Return a vec with a value for each argument in args that, if it's a value, it's a clone of base_scope, - /// if it's a lambda, it's the return of merge called with the index and the fields from lambdas_parameters - /// updated with names from metadata - fn arguments_scope_with( - &self, - args: &[ValueOrLambdaParameter], - schema: T, - ) -> Result> { - let parameters = self.inner().lambdas_parameters(args)?; - - if parameters.len() != args.len() { - return exec_err!( - "lambdas_schemas: {} lambdas_parameters returned {} values instead of {}", - self.name(), - args.len(), - parameters.len() - ); - } - - std::iter::zip(args, parameters) - .enumerate() - .map(|(i, (arg, parameters))| match (arg, parameters) { - (ValueOrLambdaParameter::Value(_), None) => Ok(schema.clone()), - (ValueOrLambdaParameter::Value(_), Some(_)) => exec_err!("lambdas_schemas: {} argument {} (0-indexed) is a value but lambdas_parameters result treat it as a lambda", self.name(), i), - (ValueOrLambdaParameter::Lambda(_, _), None) => exec_err!("lambdas_schemas: {} argument {} (0-indexed) is a lambda but lambdas_parameters result treat it as a value", self.name(), i), - (ValueOrLambdaParameter::Lambda(names, captures), Some(args)) => { - if names.len() > args.len() { - return exec_err!("lambdas_schemas: {} argument {} (0-indexed), a lambda, supports up to {} arguments, but got {}", self.name(), i, args.len(), names.len()) - } - - let fields = std::iter::zip(*names, args) - .map(|(name, arg)| arg.with_name(name)) - .collect::(); - - if *captures { - schema.extend(fields) - } else { - T::from_fields(fields) - } - } - }) - .collect() - } -} - -pub trait ExtendSchema: Sized { - fn from_fields(params: Fields) -> Result; - fn extend(&self, params: Fields) -> Result; -} - -impl ExtendSchema for DFSchema { - fn from_fields(params: Fields) -> Result { - DFSchema::from_unqualified_fields(params, 
Default::default()) - } - - fn extend(&self, params: Fields) -> Result { - let qualified_fields = self - .iter() - .map(|(qualifier, field)| { - if params.find(field.name().as_str()).is_none() { - return (qualifier.cloned(), Arc::clone(field)); - } - - let alias_gen = AliasGenerator::new(); - - loop { - let alias = alias_gen.next(field.name().as_str()); - - if params.find(&alias).is_none() - && !self.has_column_with_unqualified_name(&alias) - { - return ( - qualifier.cloned(), - Arc::new(Field::new( - alias, - field.data_type().clone(), - field.is_nullable(), - )), - ); - } - } - }) - .collect(); - - let mut schema = DFSchema::new_with_metadata(qualified_fields, HashMap::new())?; - let fields_schema = DFSchema::from_unqualified_fields(params, HashMap::new())?; - - schema.merge(&fields_schema); - - assert_eq!( - schema.fields().len(), - self.fields().len() + fields_schema.fields().len() - ); - - Ok(schema) - } -} - -impl ExtendSchema for Schema { - fn from_fields(params: Fields) -> Result { - Ok(Schema::new(params)) - } - - fn extend(&self, params: Fields) -> Result { - let mut params2 = params.iter() - .map(|f| (f.name().as_str(), Some(Arc::clone(f)))) - .collect::>(); - - let mut fields = self.fields() - .iter() - .map(|field| { - match params2.get_mut(field.name().as_str()).and_then(|p| p.take()) { - Some(param) => param, - None => Arc::clone(field), - } - }) - .collect::>(); - - fields.extend(params2.into_values().flatten()); - - let fields = self - .fields() - .iter() - .map(|field| { - if params.find(field.name().as_str()).is_none() { - return Arc::clone(field); - } - - let alias_gen = AliasGenerator::new(); - - loop { - let alias = alias_gen.next(field.name().as_str()); - - if params.find(&alias).is_none() - && self.column_with_name(&alias).is_none() - { - return Arc::new(Field::new( - alias, - field.data_type().clone(), - field.is_nullable(), - )); - } - } - }) - .chain(params.iter().cloned()) - .collect::(); - - assert_eq!(fields.len(), self.fields().len() 
+ params.len()); - - Ok(Schema::new_with_metadata(fields, self.metadata.clone())) - } -} - -impl ExtendSchema for Cow<'_, T> { - fn from_fields(params: Fields) -> Result { - Ok(Cow::Owned(T::from_fields(params)?)) - } - - fn extend(&self, params: Fields) -> Result { - Ok(Cow::Owned(self.as_ref().extend(params)?)) - } -} - -impl ExtendSchema for Arc { - fn from_fields(params: Fields) -> Result { - Ok(Arc::new(T::from_fields(params)?)) - } - - fn extend(&self, params: Fields) -> Result { - Ok(Arc::new(self.as_ref().extend(params)?)) - } -} - -impl ExtendSchema for ExtendableExprSchema<'_> { - fn from_fields(params: Fields) -> Result { - static EMPTY_DFSCHEMA: LazyLock = LazyLock::new(DFSchema::empty); - - Ok(ExtendableExprSchema { - fields_chain: vec![params], - outer_schema: &*EMPTY_DFSCHEMA, - }) - } - - fn extend(&self, params: Fields) -> Result { - Ok(ExtendableExprSchema { - fields_chain: std::iter::once(params) - .chain(self.fields_chain.iter().cloned()) - .collect(), - outer_schema: self.outer_schema, - }) - } -} - -/// A `&dyn ExprSchema` wrapper that supports adding the parameters of a lambda -#[derive(Clone, Debug)] -struct ExtendableExprSchema<'a> { - fields_chain: Vec, - outer_schema: &'a dyn ExprSchema, -} - -impl<'a> ExtendableExprSchema<'a> { - fn new(schema: &'a dyn ExprSchema) -> Self { - Self { - fields_chain: vec![], - outer_schema: schema, - } - } -} - -impl ExprSchema for ExtendableExprSchema<'_> { - fn field_from_column(&self, col: &datafusion_common::Column) -> Result<&Field> { - if col.relation.is_none() { - for fields in &self.fields_chain { - if let Some((_index, lambda_param)) = fields.find(&col.name) { - return Ok(lambda_param); - } - } - } - - self.outer_schema.field_from_column(col) - } } #[derive(Clone, Debug)] -pub enum ValueOrLambdaParameter<'a> { +pub enum ValueOrLambdaParameter { /// A columnar value with the given field Value(FieldRef), - /// A lambda with the given parameters names and a flag indicating wheter it captures any 
columns - Lambda(&'a [String], bool), + /// A lambda + Lambda, } impl From for ScalarUDF @@ -1331,111 +1066,6 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { } } -fn lambda_parameters<'a>( - args: &'a [Expr], - schema: &dyn ExprSchema, -) -> Result>> { - args.iter() - .map(|e| match e { - Expr::Lambda(Lambda { params, body: _ }) => { - let mut captures = false; - - e.apply_with_lambdas_params(|expr, lambdas_params| match expr { - Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { - captures = true; - - Ok(TreeNodeRecursion::Stop) - } - _ => Ok(TreeNodeRecursion::Continue), - }) - .unwrap(); - - Ok(ValueOrLambdaParameter::Lambda(params.as_slice(), captures)) - } - _ => Ok(ValueOrLambdaParameter::Value(e.to_field(schema)?.1)), - }) - .collect() -} - -/// Merge the lambda body captured columns with it's arguments -/// Datafusion relies on an unspecified field ordering implemented in this function -/// As such, this is the only correct way to merge the captured values with the arguments -/// The number of args should not be lower than the number of params -/// -/// See also merge_captures_with_lazy_args and merge_captures_with_boxed_lazy_args that lazily -/// computes only the necessary arguments to match the number of params -pub fn merge_captures_with_args( - captures: Option<&RecordBatch>, - params: &[FieldRef], - args: &[ArrayRef], -) -> Result { - if args.len() < params.len() { - return exec_err!( - "merge_captures_with_args called with {} params but with {} args", - params.len(), - args.len() - ); - } - - // the order of the merged batch must be kept in sync with ScalarFunction::lambdas_schemas variants - let (fields, columns) = match captures { - Some(captures) => { - let fields = captures - .schema() - .fields() - .iter() - .chain(params) - .cloned() - .collect::>(); - - let columns = [captures.columns(), args].concat(); - - (fields, columns) - } - None => (params.to_vec(), args.to_vec()), - }; - - Ok(RecordBatch::try_new( - 
Arc::new(Schema::new(fields)), - columns, - )?) -} - -/// Lazy version of merge_captures_with_args that receives closures to compute the arguments, -/// and calls only the necessary to match the number of params -pub fn merge_captures_with_lazy_args( - captures: Option<&RecordBatch>, - params: &[FieldRef], - args: &[&dyn Fn() -> Result], -) -> Result { - merge_captures_with_args( - captures, - params, - &args - .iter() - .take(params.len()) - .map(|arg| arg()) - .collect::>>()?, - ) -} - -/// Variation of merge_captures_with_lazy_args that take boxed closures -pub fn merge_captures_with_boxed_lazy_args( - captures: Option<&RecordBatch>, - params: &[FieldRef], - args: &[Box Result>], -) -> Result { - merge_captures_with_args( - captures, - params, - &args - .iter() - .take(params.len()) - .map(|arg| arg()) - .collect::>>()?, - ) -} - #[cfg(test)] mod tests { use super::*; @@ -1514,83 +1144,4 @@ mod tests { value.hash(hasher); hasher.finish() } - - use std::borrow::Cow; - - use arrow::datatypes::Fields; - - use crate::{ - tree_node::tests::{args, list_int, list_list_int, array_transform_udf}, - udf::{lambda_parameters, ExtendableExprSchema}, - }; - - #[test] - fn test_arguments_expr_schema() { - let args = args(); - let schema = list_list_int(); - - let schemas = array_transform_udf() - .arguments_expr_schema(&args, &schema) - .unwrap() - .into_iter() - .map(|s| format!("{s:?}")) - .collect::>(); - - let mut lambdas_parameters = array_transform_udf() - .inner() - .lambdas_parameters(&lambda_parameters(&args, &schema).unwrap()) - .unwrap(); - - assert_eq!( - schemas, - &[ - format!("{}", &list_list_int()), - format!( - "{:?}", - ExtendableExprSchema { - fields_chain: vec![Fields::from( - lambdas_parameters[0].take().unwrap() - )], - outer_schema: &list_list_int() - } - ), - ] - ) - } - - #[test] - fn test_arguments_arrow_schema() { - let list_int = list_int(); - let list_list_int = list_list_int(); - - let schemas = array_transform_udf() - .arguments_arrow_schema( - 
&lambda_parameters(&args(), &list_list_int).unwrap(), - //&[HashSet::new(), HashSet::from([0])], - list_list_int.as_arrow(), - ) - .unwrap(); - - assert_eq!( - schemas, - &[ - Cow::Borrowed(list_list_int.as_arrow()), - Cow::Owned(list_int.as_arrow().clone()) - ] - ) - } - - #[test] - fn test_arguments_schema_from_logical_args() { - let list_list_int = list_list_int(); - - let schemas = array_transform_udf() - .arguments_schema_from_logical_args(&args(), &list_list_int) - .unwrap(); - - assert_eq!( - schemas, - &[Cow::Borrowed(&list_list_int), Cow::Owned(list_int())] - ) - } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 93fcfaef882ff..1e5e4c9f5b99a 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -266,12 +266,10 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { /// Recursively walk an expression tree, collecting the unique set of columns /// referenced in the expression pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { - expr.apply_with_lambdas_params(|expr, lambdas_params| { + expr.apply(|expr| { match expr { Expr::Column(qc) => { - if qc.relation.is_some() || !lambdas_params.contains(qc.name()) { - accum.insert(qc.clone()); - } + accum.insert(qc.clone()); } // Use explicit pattern match instead of a default // implementation, so that in the future if someone adds @@ -310,7 +308,8 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Wildcard { .. } | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. } - | Expr::Lambda { .. } => {} + | Expr::Lambda(_) + | Expr::LambdaColumn(_) => {} } Ok(TreeNodeRecursion::Continue) }) @@ -653,7 +652,6 @@ where /// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the /// provided test. The returned `Expr`'s are deduplicated and returned in order /// of appearance (depth first). -/// todo: document about that columns may refer to a lambda parameter? 
fn find_exprs_in_expr(expr: &Expr, test_fn: &F) -> Vec where F: Fn(&Expr) -> bool, @@ -676,7 +674,6 @@ where } /// Recursively inspect an [`Expr`] and all its children. -/// todo: document about that columns may refer to a lambda parameter? pub fn inspect_expr_pre(expr: &Expr, mut f: F) -> Result<(), E> where F: FnMut(&Expr) -> Result<(), E>, @@ -748,19 +745,13 @@ pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result { _ => return Ok(e), }; let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect(); - e.transform_down_with_lambdas_params(|node: Expr, lambdas_params| { - if matches!(&node, Expr::Column(c) if c.is_lambda_parameter(lambdas_params)) { - return Ok(Transformed::no(node)); - } - - match exprs_map.get(&node) { - Some(column) => Ok(Transformed::new( - Expr::Column(column.clone()), - true, - TreeNodeRecursion::Jump, - )), - None => Ok(Transformed::no(node)), - } + e.transform_down(|node: Expr| match exprs_map.get(&node) { + Some(column) => Ok(Transformed::new( + Expr::Column(column.clone()), + true, + TreeNodeRecursion::Jump, + )), + None => Ok(Transformed::no(node)), }) .data() } @@ -777,11 +768,9 @@ pub fn find_column_exprs(exprs: &[Expr]) -> Vec { pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { let mut exprs = vec![]; - e.apply_with_lambdas_params(|expr, lambdas_params| { + e.apply(|expr| { if let Expr::Column(c) = expr { - if !c.is_lambda_parameter(lambdas_params) { - exprs.push(c.clone()) - } + exprs.push(c.clone()) } Ok(TreeNodeRecursion::Continue) }) @@ -810,9 +799,9 @@ pub(crate) fn find_column_indexes_referenced_by_expr( schema: &DFSchemaRef, ) -> Vec { let mut indexes = vec![]; - e.apply_with_lambdas_params(|expr, lambdas_params| { + e.apply(|expr| { match expr { - Expr::Column(qc) if !qc.is_lambda_parameter(lambdas_params) => { + Expr::Column(qc) => { if let Ok(idx) = schema.index_of_column(qc) { indexes.push(idx); } diff --git a/datafusion/functions-nested/Cargo.toml 
b/datafusion/functions-nested/Cargo.toml index 6e0d1048f9697..0299aebdcac47 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -53,6 +53,7 @@ datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-macros = { workspace = true } +datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 700fed477b4cb..dbc08473c246b 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -18,17 +18,23 @@ //! [`ScalarUDFImpl`] definitions for array_transform function. use arrow::{ - array::{Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray}, + array::{ + Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray, + RecordBatch, RecordBatchOptions, + }, compute::take_record_batch, - datatypes::{DataType, Field}, + datatypes::{DataType, Field, FieldRef, Schema}, }; use datafusion_common::{ - HashSet, Result, exec_err, internal_err, tree_node::{Transformed, TreeNode}, utils::{elements_indices, list_indices, list_values, take_function_args} + HashMap, Result, exec_err, internal_err, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter}, utils::{elements_indices, list_indices, list_values, take_function_args} }; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, expr::Lambda, merge_captures_with_lazy_args + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + ValueOrLambda, ValueOrLambdaField, 
ValueOrLambdaParameter, Volatility, }; use datafusion_macros::user_doc; +use datafusion_physical_expr::expressions::{LambdaColumn, LambdaExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::{any::Any, sync::Arc}; make_udf_expr_and_func!( @@ -115,7 +121,7 @@ impl ScalarUDFImpl for ArrayTransform { ); }; - //TODO: should metadata be passed? If so, with the same keys or prefixed/suffixed? + //TODO: should metadata be copied into the transformed array? // lambda is the resulting field of executing the lambda body // with the parameters returned in lambdas_parameters @@ -151,6 +157,7 @@ impl ScalarUDFImpl for ArrayTransform { }; let list_array = list_value.to_array(args.number_rows)?; + let list_values = list_values(&list_array)?; // if any column got captured, we need to adjust it to the values arrays, // duplicating values of list with mulitple values and removing values of empty lists @@ -163,23 +170,26 @@ impl ScalarUDFImpl for ArrayTransform { // use closures and merge_captures_with_lazy_args so that it calls only the needed ones based on the number of arguments // avoiding unnecessary computations - let values_param = || Ok(Arc::clone(list_values(&list_array)?)); + let values_param = || Ok(Arc::clone(list_values)); let indices_param = || elements_indices(&list_array); - // the order of the merged schema is an unspecified implementation detail that may change in the future, - // using this function is the correct way to merge as it return the correct ordering and will change in sync - // the implementation without the need for fixes. 
It also computes only the parameters requested - let lambda_batch = merge_captures_with_lazy_args( - adjusted_captures.as_ref(), - &lambda.params, // ScalarUDF already merged the fields returned in lambdas_parameters with the parameters names definied in the lambda, so we don't need to + let binded_body = bind_lambda_columns( + Arc::clone(&lambda.body), + &lambda.params, &[&values_param, &indices_param], )?; // call the transforming expression with the record batch composed of the list values merged with captured columns - let transformed_values = lambda - .body - .evaluate(&lambda_batch)? - .into_array(lambda_batch.num_rows())?; + let transformed_values = binded_body + .evaluate(&adjusted_captures.unwrap_or_else(|| { + RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions::new().with_row_count(Some(list_values.len())), + ) + .unwrap() + }))? + .into_array(list_values.len())?; let field = match args.return_field.data_type() { DataType::List(field) @@ -233,7 +243,7 @@ impl ScalarUDFImpl for ArrayTransform { &self, args: &[ValueOrLambdaParameter], ) -> Result>>> { - let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda(_, _)] = + let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda] = args else { return exec_err!( @@ -253,9 +263,9 @@ impl ScalarUDFImpl for ArrayTransform { // we don't need to omit the index in the case the lambda don't specify, e.g. array_transform([], v -> v*2), // nor check whether the lambda contains more than two parameters, e.g. 
array_transform([], (v, i, j) -> v+i+j), // as datafusion will do that for us - let value = Field::new("value", field.data_type().clone(), field.is_nullable()) + let value = Field::new("", field.data_type().clone(), field.is_nullable()) .with_metadata(field.metadata().clone()); - let index = Field::new("index", index_type, false); + let index = Field::new("", index_type, false); Ok(vec![None, Some(vec![value, index])]) } @@ -264,3 +274,65 @@ impl ScalarUDFImpl for ArrayTransform { self.doc() } } + +fn bind_lambda_columns( + expr: Arc, + params: &[FieldRef], + args: &[&dyn Fn() -> Result], +) -> Result> { + let columns = std::iter::zip(params, args) + .map(|(param, arg)| Ok((param.name().as_str(), (arg()?, 0)))) + .collect::>>()?; + + expr.rewrite(&mut BindLambdaColumn::new(columns)).data() +} + +struct BindLambdaColumn<'a> { + columns: HashMap<&'a str, (ArrayRef, usize)>, +} + +impl<'a> BindLambdaColumn<'a> { + fn new(columns: HashMap<&'a str, (ArrayRef, usize)>) -> Self { + Self { columns } + } +} + +impl TreeNodeRewriter for BindLambdaColumn<'_> { + type Node = Arc; + + fn f_down(&mut self, node: Self::Node) -> Result> { + if let Some(lambda_column) = node.as_any().downcast_ref::() { + if let Some((value, shadows)) = self.columns.get(lambda_column.name()) { + if *shadows == 0 { + return Ok(Transformed::yes(Arc::new( + lambda_column.clone().with_value(value.clone()), + ))); + } + } + } else if let Some(inner_lambda) = node.as_any().downcast_ref::() { + for param in inner_lambda.params() { + if let Some((_value, shadows)) = self.columns.get_mut(param.as_str()) { + *shadows += 1; + } + } + + if self.columns.values().all(|(_value, shadows)| *shadows > 0) { + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)) + } + } + + Ok(Transformed::no(node)) + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + if let Some(inner_lambda) = node.as_any().downcast_ref::() { + for param in inner_lambda.params() { + if let Some((_value, shadows)) = 
self.columns.get_mut(param.as_str()) { + *shadows -= 1; + } + } + } + + Ok(Transformed::no(node)) + } +} diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs index ecdebf66e0043..4832c368872bf 100644 --- a/datafusion/functions/src/core/union_tag.rs +++ b/datafusion/functions/src/core/union_tag.rs @@ -173,7 +173,6 @@ mod tests { fields, UnionMode::Dense, ); - let arg_fields = vec![Field::new("a", scalar.data_type(), false).into()]; let return_type = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); @@ -181,9 +180,9 @@ mod tests { let result = UnionTagFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(scalar)], - arg_fields, number_rows: 1, return_field: Field::new("res", return_type, true).into(), + arg_fields: vec![], config_options: Arc::new(ConfigOptions::default()), lambdas: None, }) @@ -198,7 +197,6 @@ mod tests { #[test] fn union_scalar_empty() { let scalar = ScalarValue::Union(None, UnionFields::empty(), UnionMode::Dense); - let arg_fields = vec![Field::new("a", scalar.data_type(), false).into()]; let return_type = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); @@ -206,9 +204,9 @@ mod tests { let result = UnionTagFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(scalar)], - arg_fields, number_rows: 1, return_field: Field::new("res", return_type, true).into(), + arg_fields: vec![], config_options: Arc::new(ConfigOptions::default()), lambdas: None, }) diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index 0e5e602f8238e..c6bf14ebce2e3 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -19,7 +19,7 @@ use super::AnalyzerRule; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::Transformed; +use 
datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{DFSchema, Result}; use crate::utils::NamePreserver; @@ -64,16 +64,15 @@ impl ApplyFunctionRewrites { let original_name = name_preserver.save(&expr); // recursively transform the expression, applying the rewrites at each step - let transformed_expr = - expr.transform_up_with_schema(&schema, |expr, schema| { - let mut result = Transformed::no(expr); - for rewriter in self.function_rewrites.iter() { - result = result.transform_data(|expr| { - rewriter.rewrite(expr, schema, options) - })?; - } - Ok(result) - })?; + let transformed_expr = expr.transform_up(|expr| { + let mut result = Transformed::no(expr); + for rewriter in self.function_rewrites.iter() { + result = result.transform_data(|expr| { + rewriter.rewrite(expr, &schema, options) + })?; + } + Ok(result) + })?; Ok(transformed_expr.update_data(|expr| original_name.restore(expr))) }) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 1b82182e8600f..763f693f2f607 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -20,7 +20,6 @@ use std::sync::Arc; use datafusion_expr::binary::BinaryTypeCoercer; -use datafusion_expr::tree_node::TreeNodeRewriterWithPayload; use itertools::{izip, Itertools as _}; use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -28,7 +27,7 @@ use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; use crate::analyzer::AnalyzerRule; use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::Transformed; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -141,7 +140,7 @@ fn analyze_internal( 
// apply coercion rewrite all expressions in the plan individually plan.map_expressions(|expr| { let original_name = name_preserver.save(&expr); - expr.rewrite_with_schema(&schema, &mut expr_rewrite) + expr.rewrite(&mut expr_rewrite) .map(|transformed| transformed.update_data(|e| original_name.restore(e))) })? // some plans need extra coercion after their expressions are coerced @@ -305,11 +304,10 @@ impl<'a> TypeCoercionRewriter<'a> { } } -impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { +impl TreeNodeRewriter for TypeCoercionRewriter<'_> { type Node = Expr; - type Payload<'a> = &'a DFSchema; - fn f_up(&mut self, expr: Expr, schema: &DFSchema) -> Result> { + fn f_up(&mut self, expr: Expr) -> Result> { match expr { Expr::Unnest(_) => not_impl_err!( "Unnest should be rewritten to LogicalPlan::Unnest before type coercion" @@ -320,7 +318,7 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { spans, }) => { let new_plan = - analyze_internal(schema, Arc::unwrap_or_clone(subquery))?.data; + analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data; Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, @@ -329,7 +327,7 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { } Expr::Exists(Exists { subquery, negated }) => { let new_plan = analyze_internal( - schema, + self.schema, Arc::unwrap_or_clone(subquery.subquery), )? .data; @@ -348,11 +346,11 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { negated, }) => { let new_plan = analyze_internal( - schema, + self.schema, Arc::unwrap_or_clone(subquery.subquery), )? 
.data; - let expr_type = expr.get_type(schema)?; + let expr_type = expr.get_type(self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); let common_type = comparison_coercion(&expr_type, subquery_type).ok_or( plan_datafusion_err!( @@ -365,32 +363,32 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { spans: subquery.spans, }; Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( - Box::new(expr.cast_to(&common_type, schema)?), + Box::new(expr.cast_to(&common_type, self.schema)?), cast_subquery(new_subquery, &common_type)?, negated, )))) } Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, - schema, + self.schema, )?))), Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( - get_casted_expr_for_bool_op(*expr, schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( - get_casted_expr_for_bool_op(*expr, schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( - get_casted_expr_for_bool_op(*expr, schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( - get_casted_expr_for_bool_op(*expr, schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( - get_casted_expr_for_bool_op(*expr, schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( - get_casted_expr_for_bool_op(*expr, schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::Like(Like { negated, @@ -399,8 +397,8 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { escape_char, case_insensitive, }) => { - let left_type = expr.get_type(schema)?; - let right_type = pattern.get_type(schema)?; + let left_type = expr.get_type(self.schema)?; + let right_type = pattern.get_type(self.schema)?; let 
coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { let op_name = if case_insensitive { "ILIKE" @@ -413,9 +411,9 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { })?; let expr = match left_type { DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr, - _ => Box::new(expr.cast_to(&coerced_type, schema)?), + _ => Box::new(expr.cast_to(&coerced_type, self.schema)?), }; - let pattern = Box::new(pattern.cast_to(&coerced_type, schema)?); + let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?); Ok(Transformed::yes(Expr::Like(Like::new( negated, expr, @@ -426,7 +424,7 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left, right) = - self.coerce_binary_op(*left, schema, op, *right, schema)?; + self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), op, @@ -439,15 +437,15 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { low, high, }) => { - let expr_type = expr.get_type(schema)?; - let low_type = low.get_type(schema)?; + let expr_type = expr.get_type(self.schema)?; + let low_type = low.get_type(self.schema)?; let low_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { internal_datafusion_err!( "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression" ) })?; - let high_type = high.get_type(schema)?; + let high_type = high.get_type(self.schema)?; let high_coerced_type = comparison_coercion(&expr_type, &high_type) .ok_or_else(|| { internal_datafusion_err!( @@ -462,10 +460,10 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { ) })?; Ok(Transformed::yes(Expr::Between(Between::new( - Box::new(expr.cast_to(&coercion_type, schema)?), + Box::new(expr.cast_to(&coercion_type, self.schema)?), negated, - Box::new(low.cast_to(&coercion_type, schema)?), - Box::new(high.cast_to(&coercion_type, 
schema)?), + Box::new(low.cast_to(&coercion_type, self.schema)?), + Box::new(high.cast_to(&coercion_type, self.schema)?), )))) } Expr::InList(InList { @@ -473,10 +471,10 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { list, negated, }) => { - let expr_data_type = expr.get_type(schema)?; + let expr_data_type = expr.get_type(self.schema)?; let list_data_types = list .iter() - .map(|list_expr| list_expr.get_type(schema)) + .map(|list_expr| list_expr.get_type(self.schema)) .collect::>>()?; let result_type = get_coerce_type_for_list(&expr_data_type, &list_data_types); @@ -486,11 +484,11 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { ), Some(coerced_type) => { // find the coerced type - let cast_expr = expr.cast_to(&coerced_type, schema)?; + let cast_expr = expr.cast_to(&coerced_type, self.schema)?; let cast_list_expr = list .into_iter() .map(|list_expr| { - list_expr.cast_to(&coerced_type, schema) + list_expr.cast_to(&coerced_type, self.schema) }) .collect::>>()?; Ok(Transformed::yes(Expr::InList(InList ::new( @@ -502,13 +500,13 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { } } Expr::Case(case) => { - let case = coerce_case_expression(case, schema)?; + let case = coerce_case_expression(case, self.schema)?; Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func, args }) => { let new_expr = coerce_arguments_for_signature_with_scalar_udf( args, - schema, + self.schema, &func, )?; Ok(Transformed::yes(Expr::ScalarFunction( @@ -528,7 +526,7 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { }) => { let new_expr = coerce_arguments_for_signature_with_aggregate_udf( args, - schema, + self.schema, &func, )?; Ok(Transformed::yes(Expr::AggregateFunction( @@ -557,13 +555,13 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { }, } = *window_fun; let window_frame = - coerce_window_frame(window_frame, schema, &order_by)?; + coerce_window_frame(window_frame, 
self.schema, &order_by)?; let args = match &fun { expr::WindowFunctionDefinition::AggregateUDF(udf) => { coerce_arguments_for_signature_with_aggregate_udf( args, - schema, + self.schema, udf, )? } @@ -600,7 +598,8 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { | Expr::GroupingSet(_) | Expr::Placeholder(_) | Expr::OuterReferenceColumn(_, _) - | Expr::Lambda { .. } => Ok(Transformed::no(expr)), + | Expr::Lambda(_) + | Expr::LambdaColumn(_) => Ok(Transformed::no(expr)), } } } @@ -1133,7 +1132,7 @@ mod test { use crate::analyzer::Analyzer; use crate::assert_analyzed_plan_with_config_eq_snapshot; use datafusion_common::config::ConfigOptions; - use datafusion_common::tree_node::{TransformedResult}; + use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort}; @@ -2084,7 +2083,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // eq @@ -2095,7 +2094,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // lt @@ -2106,7 +2105,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = 
expr.rewrite_with_schema(&schema, &mut rewriter).data()?; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index e06ed6e547eb5..f95a0f908b813 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -29,7 +29,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE}; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, HashSet, Result}; +use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, @@ -632,7 +632,6 @@ struct ExprCSEController<'a> { // how many aliases have we seen so far alias_counter: usize, - lambdas_params: HashSet, } impl<'a> ExprCSEController<'a> { @@ -641,7 +640,6 @@ impl<'a> ExprCSEController<'a> { alias_generator, mask, alias_counter: 0, - lambdas_params: HashSet::new(), } } } @@ -695,30 +693,11 @@ impl CSEController for ExprCSEController<'_> { } } - fn visit_f_down(&mut self, node: &Expr) { - if let Expr::Lambda(lambda) = node { - self.lambdas_params - .extend(lambda.params.iter().cloned()); - } - } - - fn visit_f_up(&mut self, node: &Expr) { - if let Expr::Lambda(lambda) = node { - for param in &lambda.params { - self.lambdas_params.remove(param); - } - } - } - fn is_valid(node: &Expr) -> bool { - !node.is_volatile_node() + !node.is_volatile_node() && !matches!(node, Expr::LambdaColumn(_)) } fn is_ignored(&self, node: &Expr) -> bool { - if matches!(node, Expr::Column(c) if c.is_lambda_parameter(&self.lambdas_params)) { - return true - } - // TODO: remove the next line after `Expr::Wildcard` is removed 
#[expect(deprecated)] let is_normal_minus_aggregates = matches!( @@ -728,6 +707,7 @@ impl CSEController for ExprCSEController<'_> { | Expr::ScalarVariable(..) | Expr::Alias(..) | Expr::Wildcard { .. } + | Expr::LambdaColumn(_) ); let is_aggr = matches!(node, Expr::AggregateFunction(..)); diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 0f43741834009..63236787743a4 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -527,17 +527,18 @@ fn proj_exprs_evaluation_result_on_empty_batch( for expr in proj_expr.iter() { let result_expr = expr .clone() - .transform_up_with_lambdas_params(|expr, lambdas_params| match &expr { - Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { + .transform_up(|expr| { + if let Expr::Column(Column { name, .. }) = &expr { if let Some(result_expr) = - input_expr_result_map_for_count_bug.get(col.name()) + input_expr_result_map_for_count_bug.get(name) { Ok(Transformed::yes(result_expr.clone())) } else { Ok(Transformed::no(expr)) } + } else { + Ok(Transformed::no(expr)) } - _ => Ok(Transformed::no(expr)), }) .data()?; @@ -569,17 +570,16 @@ fn filter_exprs_evaluation_result_on_empty_batch( ) -> Result> { let result_expr = filter_expr .clone() - .transform_up_with_lambdas_params(|expr, lambdas_params| match &expr { - Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { - if let Some(result_expr) = - input_expr_result_map_for_count_bug.get(col.name()) - { + .transform_up(|expr| { + if let Expr::Column(Column { name, .. 
}) = &expr { + if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { Ok(Transformed::yes(result_expr.clone())) } else { Ok(Transformed::no(expr)) } + } else { + Ok(Transformed::no(expr)) } - _ => Ok(Transformed::no(expr)), }) .data()?; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index f0187b618ccc0..5db71417bc8fd 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -639,7 +639,7 @@ fn is_expr_trivial(expr: &Expr) -> bool { /// --Source(a, b) /// ``` fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { - expr.transform_up_with_lambdas_params(|expr, lambdas_params| { + expr.transform_up(|expr| { match expr { // remove any intermediate aliases if they do not carry metadata Expr::Alias(alias) => { @@ -653,7 +653,7 @@ fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { false => Ok(Transformed::no(Expr::Alias(alias))), } } - Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { + Expr::Column(col) => { // Find index of column: let idx = input.schema.index_of_column(&col)?; // get the corresponding unaliased input expression diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 54cb026543270..77c533ce2f01e 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -287,14 +287,15 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::Cast(_) | Expr::TryCast(_) | Expr::InList { .. } - | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue), + | Expr::ScalarFunction(_) + | Expr::Lambda(_) + | Expr::LambdaColumn(_) => Ok(TreeNodeRecursion::Continue), // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::AggregateFunction(_) | Expr::WindowFunction(_) | Expr::Wildcard { .. 
} - | Expr::GroupingSet(_) - | Expr::Lambda { .. } => internal_err!("Unsupported predicate type"), + | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), })?; Ok(is_evaluate) } @@ -1390,15 +1391,14 @@ pub fn replace_cols_by_name( e: Expr, replace_map: &HashMap, ) -> Result { - e.transform_up_with_lambdas_params(|expr, lambdas_params| { - Ok(match &expr { - Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { - match replace_map.get(&c.flat_name()) { - Some(new_c) => Transformed::yes(new_c.clone()), - None => Transformed::no(expr), - } + e.transform_up(|expr| { + Ok(if let Expr::Column(c) = &expr { + match replace_map.get(&c.flat_name()) { + Some(new_c) => Transformed::yes(new_c.clone()), + None => Transformed::no(expr), } - _ => Transformed::no(expr), + } else { + Transformed::no(expr) }) }) .data() @@ -1407,18 +1407,17 @@ pub fn replace_cols_by_name( /// check whether the expression uses the columns in `check_map`. fn contain(e: &Expr, check_map: &HashMap) -> bool { let mut is_contain = false; - e.apply_with_lambdas_params(|expr, lambdas_params| { - Ok(match &expr { - Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { - match check_map.get(&c.flat_name()) { - Some(_) => { - is_contain = true; - TreeNodeRecursion::Stop - } - None => TreeNodeRecursion::Continue, + e.apply(|expr| { + Ok(if let Expr::Column(c) = &expr { + match check_map.get(&c.flat_name()) { + Some(_) => { + is_contain = true; + TreeNodeRecursion::Stop } + None => TreeNodeRecursion::Continue, } - _ => TreeNodeRecursion::Continue, + } else { + TreeNodeRecursion::Continue }) }) .unwrap(); diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index f1e619750f9c8..48d1182527013 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -106,22 +106,17 @@ impl OptimizerRule for ScalarSubqueryToJoin { { if !expr_check_map.is_empty() { 
rewrite_expr = rewrite_expr - .transform_up_with_lambdas_params( - |expr, lambdas_params| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = expr - .try_as_col() - .filter(|c| { - !c.is_lambda_parameter(lambdas_params) - }) - .and_then(|col| expr_check_map.get(&col.name)) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }, - ) + .transform_up(|expr| { + // replace column references with entry in map, if it exists + if let Some(map_expr) = expr + .try_as_col() + .and_then(|col| expr_check_map.get(&col.name)) + { + Ok(Transformed::yes(map_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }) .data()?; } cur_input = optimized_subquery; @@ -176,26 +171,18 @@ impl OptimizerRule for ScalarSubqueryToJoin { { let new_expr = rewrite_expr .clone() - .transform_up_with_lambdas_params( - |expr, lambdas_params| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = expr - .try_as_col() - .filter(|c| { - !c.is_lambda_parameter( - lambdas_params, - ) - }) - .and_then(|col| { - expr_check_map.get(&col.name) - }) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }, - ) + .transform_up(|expr| { + // replace column references with entry in map, if it exists + if let Some(map_expr) = + expr.try_as_col().and_then(|col| { + expr_check_map.get(&col.name) + }) + { + Ok(Transformed::yes(map_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }) .data()?; expr_to_rewrite_expr_map.insert(expr, new_expr); } @@ -409,12 +396,8 @@ fn build_join( let mut expr_rewrite = TypeCoercionRewriter { schema: new_plan.schema(), }; - computation_project_expr.insert( - name, - computer_expr - .rewrite_with_schema(new_plan.schema(), &mut expr_rewrite) - .data()?, - ); + computation_project_expr + .insert(name, computer_expr.rewrite(&mut expr_rewrite).data()?); } } diff --git 
a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index a824f6b7be49f..779c0acea9963 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -27,20 +27,18 @@ use arrow::{ record_batch::RecordBatch, }; -use datafusion_common::tree_node::TreeNodeContainer; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, metadata::FieldMetadata, - tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, - }, + tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{ - exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, ScalarValue, + exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, + ScalarValue, }; use datafusion_expr::{ - and, binary::BinaryTypeCoercer, lit, or, simplify::SimplifyContext, BinaryExpr, Case, - ColumnarValue, Expr, Like, Operator, Volatility, + and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, + Operator, Volatility, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ @@ -270,7 +268,7 @@ impl ExprSimplifier { /// documentation for more details on type coercion pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite_with_schema(schema, &mut expr_rewrite).data() + expr.rewrite(&mut expr_rewrite).data() } /// Input guarantees about the values of columns. 
@@ -469,9 +467,11 @@ impl TreeNodeRewriter for Canonicalizer { }; match (left.as_ref(), right.as_ref(), op.swap()) { // - (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op)) - if right_col > left_col => - { + ( + left_col @ (Expr::Column(_) | Expr::LambdaColumn(_)), + right_col @ (Expr::Column(_) | Expr::LambdaColumn(_)), + Some(swapped_op), + ) if right_col > left_col => { Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, @@ -479,13 +479,15 @@ impl TreeNodeRewriter for Canonicalizer { }))) } // - (Expr::Literal(_a, _), Expr::Column(_b), Some(swapped_op)) => { - Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { - left: right, - op: swapped_op, - right: left, - }))) - } + ( + Expr::Literal(_, _), + Expr::Column(_) | Expr::LambdaColumn(_), + Some(swapped_op), + ) => Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { + left: right, + op: swapped_op, + right: left, + }))), _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, @@ -653,7 +655,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::GroupingSet(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) - | Expr::Lambda { .. } => false, + | Expr::LambdaColumn(_) => false, Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } @@ -677,7 +679,8 @@ impl<'a> ConstEvaluator<'a> { | Expr::Case(_) | Expr::Cast { .. } | Expr::TryCast { .. } - | Expr::InList { .. } => true, + | Expr::InList { .. 
} + | Expr::Lambda(_) => true, } } @@ -758,89 +761,6 @@ impl<'a, S> Simplifier<'a, S> { impl TreeNodeRewriter for Simplifier<'_, S> { type Node = Expr; - fn f_down(&mut self, expr: Self::Node) -> Result> { - match expr { - Expr::ScalarFunction(ScalarFunction { func, args }) - if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => - { - // there's currently no way to adapt a generic SimplifyInfo with lambda parameters, - // so, if the scalar function has any lambda, we materialize a DFSchema using all the - // columns references in every arguments. Than we can call lambdas_schemas_from_args, - // and for each argument, we create a new SimplifyContext with the scoped schema, and - // simplify the argument using this 'sub-context'. Finally, we set Transformed.tnr to - // Jump so the parent context doesn't try to simplify the argument again, without the - // parameters info - - // get all columns references - let mut columns_refs = HashSet::new(); - - for arg in &args { - arg.add_column_refs(&mut columns_refs); - } - - // materialize columns references into qualified fields - let qualified_fields = columns_refs - .into_iter() - .map(|captured_column| { - let expr = Expr::Column(captured_column.clone()); - - Ok(( - captured_column.relation.clone(), - Arc::new(Field::new( - captured_column.name(), - self.info.get_data_type(&expr)?, - self.info.nullable(&expr)?, - )), - )) - }) - .collect::>()?; - - // create a schema using the materialized fields - let dfschema = - DFSchema::new_with_metadata(qualified_fields, Default::default())?; - - let mut scoped_schemas = func - .arguments_schema_from_logical_args(&args, &dfschema)? 
- .into_iter(); - - let transformed_args = args - .map_elements(|arg| { - let scoped_schema = scoped_schemas.next().unwrap(); - - // create a sub-context, using the scoped schema, that includes information about the lambda parameters - let simplify_context = - SimplifyContext::new(self.info.execution_props()) - .with_schema(Arc::new(scoped_schema.into_owned())); - - let mut simplifier = Simplifier::new(&simplify_context); - - // simplify the argument using it's context - arg.rewrite(&mut simplifier) - })? - .update_data(|args| { - Expr::ScalarFunction(ScalarFunction { func, args }) - }); - - Ok(Transformed::new( - transformed_args.data, - transformed_args.transformed, - // return at least Jump so the parent contex doesn't try again to simplify the arguments - // (and fail because it doesn't contain info about lambdas paramters) - match transformed_args.tnr { - TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { - TreeNodeRecursion::Jump - } - TreeNodeRecursion::Stop => TreeNodeRecursion::Stop, - }, - )) - - // Ok(transformed_args.update_data(|args| Expr::ScalarFunction(ScalarFunction { func, args}))) - } - // Expr::Lambda(_) => Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)), - _ => Ok(Transformed::no(expr)), - } - } - /// rewrite the expression simplifying any constant expressions fn f_up(&mut self, expr: Expr) -> Result> { use datafusion_expr::Operator::{ @@ -2092,8 +2012,8 @@ fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool { let left = as_inlist(left); let right = as_inlist(right); if let (Some(lhs), Some(rhs)) = (left, right) { - matches!(lhs.expr.as_ref(), Expr::Column(_)) - && matches!(rhs.expr.as_ref(), Expr::Column(_)) + matches!(lhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaColumn(_)) + && matches!(rhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaColumn(_)) && lhs.expr == rhs.expr && !lhs.negated && !rhs.negated @@ -2108,16 +2028,20 @@ fn as_inlist(expr: &'_ Expr) -> Option> { Expr::InList(inlist) => 
Some(Cow::Borrowed(inlist)), Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_, _)) => Some(Cow::Owned(InList { - expr: left.clone(), - list: vec![*right.clone()], - negated: false, - })), - (Expr::Literal(_, _), Expr::Column(_)) => Some(Cow::Owned(InList { - expr: right.clone(), - list: vec![*left.clone()], - negated: false, - })), + (Expr::Column(_) | Expr::LambdaColumn(_), Expr::Literal(_, _)) => { + Some(Cow::Owned(InList { + expr: left.clone(), + list: vec![*right.clone()], + negated: false, + })) + } + (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaColumn(_)) => { + Some(Cow::Owned(InList { + expr: right.clone(), + list: vec![*left.clone()], + negated: false, + })) + } _ => None, } } @@ -2133,16 +2057,20 @@ fn to_inlist(expr: Expr) -> Option { op: Operator::Eq, right, }) => match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_, _)) => Some(InList { - expr: left, - list: vec![*right], - negated: false, - }), - (Expr::Literal(_, _), Expr::Column(_)) => Some(InList { - expr: right, - list: vec![*left], - negated: false, - }), + (Expr::Column(_) | Expr::LambdaColumn(_), Expr::Literal(_, _)) => { + Some(InList { + expr: left, + list: vec![*right], + negated: false, + }) + } + (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaColumn(_)) => { + Some(InList { + expr: right, + list: vec![*left], + negated: false, + }) + } _ => None, }, _ => None, diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index d0ae4932628f3..81763fa0552fb 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -23,7 +23,7 @@ use crate::analyzer::type_coercion::TypeCoercionRewriter; use arrow::array::{new_null_array, Array, RecordBatch}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::cast::as_boolean_array; -use datafusion_common::tree_node::TransformedResult; +use 
datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{Column, DFSchema, Result, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::replace_col; @@ -148,7 +148,7 @@ fn evaluate_expr_with_null_column<'a>( fn coerce(expr: Expr, schema: &DFSchema) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite_with_schema(schema, &mut expr_rewrite).data() + expr.rewrite(&mut expr_rewrite).data() } #[cfg(test)] diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 4a81a5c99ac75..d1957ae1892ea 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -21,14 +21,13 @@ use std::sync::Arc; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; -use datafusion_common::HashSet; +use datafusion_common::tree_node::TreeNode; use datafusion_common::{ exec_err, tree_node::{Transformed, TransformedResult}, Result, ScalarValue, }; use datafusion_functions::core::getfield::GetFieldFunc; -use datafusion_physical_expr::PhysicalExprExt; use datafusion_physical_expr::{ expressions::{self, CastExpr, Column}, ScalarFunctionExpr, @@ -219,10 +218,8 @@ impl PhysicalExprAdapter for DefaultPhysicalExprAdapter { physical_file_schema: &self.physical_file_schema, partition_fields: &self.partition_values, }; - expr.transform_with_lambdas_params(|expr, lambdas_params| { - rewriter.rewrite_expr(Arc::clone(&expr), lambdas_params) - }) - .data() + expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr))) + .data() } fn with_partition_values( @@ -246,18 +243,13 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { fn rewrite_expr( &self, expr: Arc, - lambdas_params: &HashSet, ) -> Result>> { - if let Some(transformed) = - self.try_rewrite_struct_field_access(&expr, lambdas_params)? 
- { + if let Some(transformed) = self.try_rewrite_struct_field_access(&expr)? { return Ok(Transformed::yes(transformed)); } if let Some(column) = expr.as_any().downcast_ref::() { - if !lambdas_params.contains(column.name()) { - return self.rewrite_column(Arc::clone(&expr), column); - } + return self.rewrite_column(Arc::clone(&expr), column); } Ok(Transformed::no(expr)) @@ -269,7 +261,6 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { fn try_rewrite_struct_field_access( &self, expr: &Arc, - lambdas_params: &HashSet, ) -> Result>> { let get_field_expr = match ScalarFunctionExpr::try_downcast_func::(expr.as_ref()) { @@ -301,8 +292,8 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { }; let column = match source_expr.as_any().downcast_ref::() { - Some(column) if !lambdas_params.contains(column.name()) => column, - _ => return Ok(None), + Some(column) => column, + None => return Ok(None), }; let physical_field = @@ -456,7 +447,6 @@ mod tests { use super::*; use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::hashbrown::HashSet; use datafusion_common::{assert_contains, record_batch, Result, ScalarValue}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{col, lit, CastExpr, Column, Literal}; @@ -863,9 +853,7 @@ mod tests { // Test that when a field exists in physical schema, it returns None let column = Arc::new(Column::new("struct_col", 0)) as Arc; - let result = rewriter - .try_rewrite_struct_field_access(&column, &HashSet::new()) - .unwrap(); + let result = rewriter.try_rewrite_struct_field_access(&column).unwrap(); assert!(result.is_none()); // The actual test for the get_field expression would require creating a proper ScalarFunctionExpr diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index d4c0e1cbe6eb7..b7654a0f6f603 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ 
-37,9 +37,6 @@ workspace = true [lib] name = "datafusion_physical_expr" -[features] -recursive_protection = ["dep:recursive"] - [dependencies] ahash = { workspace = true } arrow = { workspace = true } @@ -55,7 +52,6 @@ itertools = { workspace = true, features = ["use_std"] } parking_lot = { workspace = true } paste = "^1.0" petgraph = "0.8.3" -recursive = { workspace = true, optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index c55f42ae333bc..7040fa2bfc9b4 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -22,13 +22,12 @@ use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; -use crate::PhysicalExprExt; use arrow::datatypes::FieldRef; use arrow::{ datatypes::{DataType, Schema, SchemaRef}, record_batch::RecordBatch, }; -use datafusion_common::tree_node::{Transformed, TransformedResult}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::ColumnarValue; @@ -68,8 +67,7 @@ use datafusion_expr::ColumnarValue; pub struct Column { /// The name of the column (used for debugging and display purposes) name: String, - /// The index of the column in its schema. - /// Within a lambda body, this refer to the lambda scoped schema, not the plan schema. 
+ /// The index of the column in its schema index: usize, } @@ -180,9 +178,9 @@ pub fn with_new_schema( expr: Arc, schema: &SchemaRef, ) -> Result> { - expr.transform_up_with_lambdas_params(|expr, lambdas_params| { - match expr.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => { + Ok(expr + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { let idx = col.index(); let Some(field) = schema.fields().get(idx) else { return plan_err!( @@ -192,11 +190,11 @@ pub fn with_new_schema( let new_col = Column::new(field.name(), idx); Ok(Transformed::yes(Arc::new(new_col) as _)) + } else { + Ok(Transformed::no(expr)) } - _ => Ok(Transformed::no(expr)), - } - }) - .data() + })? + .data) } #[cfg(test)] diff --git a/datafusion/physical-expr/src/expressions/lambda.rs b/datafusion/physical-expr/src/expressions/lambda.rs index 55110fdf5bf6b..38b64e3c7f3e1 100644 --- a/datafusion/physical-expr/src/expressions/lambda.rs +++ b/datafusion/physical-expr/src/expressions/lambda.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Physical column reference: [`Column`] +//! 
Physical lambda expression: [`LambdaExpr`] use std::hash::Hash; use std::sync::Arc; @@ -23,16 +23,15 @@ use std::{any::Any, sync::OnceLock}; use crate::expressions::Column; use crate::physical_expr::PhysicalExpr; -use crate::PhysicalExprExt; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, HashSet, Result}; use datafusion_expr::ColumnarValue; -/// Represents a lambda with the given parameters name and body +/// Represents a lambda with the given parameters names and body #[derive(Debug, Eq, Clone)] pub struct LambdaExpr { params: Vec, @@ -79,16 +78,14 @@ impl LambdaExpr { let mut indices = HashSet::new(); self.body - .apply_with_lambdas_params(|expr, lambdas_params| { + .apply(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { - if !lambdas_params.contains(column.name()) { - indices.insert(column.index()); - } + indices.insert(column.index()); } Ok(TreeNodeRecursion::Continue) }) - .unwrap(); + .expect("closure should be infallibe"); indices }) @@ -106,12 +103,12 @@ impl PhysicalExpr for LambdaExpr { self } - fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(DataType::Null) + fn data_type(&self, input_schema: &Schema) -> Result { + self.body.data_type(input_schema) } - fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(true) + fn nullable(&self, input_schema: &Schema) -> Result { + self.body.nullable(input_schema) } fn evaluate(&self, _batch: &RecordBatch) -> Result { diff --git a/datafusion/physical-expr/src/expressions/lambda_column.rs b/datafusion/physical-expr/src/expressions/lambda_column.rs new file mode 100644 index 0000000000000..4aed16186ba6f --- /dev/null +++ b/datafusion/physical-expr/src/expressions/lambda_column.rs @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical lambda column reference: [`LambdaColumn`] + +use std::any::Any; +use std::hash::Hash; +use std::sync::Arc; + +use crate::physical_expr::PhysicalExpr; +use arrow::array::ArrayRef; +use arrow::datatypes::FieldRef; +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion_common::{Result, exec_datafusion_err}; +use datafusion_expr::ColumnarValue; + +/// Represents the lambda column with a given name and field +#[derive(Debug, Clone)] +pub struct LambdaColumn { + name: String, + field: FieldRef, + value: Option, +} + +impl Eq for LambdaColumn {} + +impl PartialEq for LambdaColumn { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.field == other.field + } +} + +impl Hash for LambdaColumn { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.field.hash(state); + } +} + +impl LambdaColumn { + /// Create a new lambda column expression + pub fn new(name: &str, field: FieldRef) -> Self { + Self { + name: name.to_owned(), + field, + value: None, + } + } + + /// Get the column's name + pub fn name(&self) -> &str { + &self.name + } + + /// Get the column's field + pub fn field(&self) -> &FieldRef { + &self.field + } + + pub fn with_value(self, value: ArrayRef) -> Self { + Self { + name: self.name, 
+ field: self.field, + value: Some(ColumnarValue::Array(value)), + } + } +} + +impl std::fmt::Display for LambdaColumn { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}@-1", self.name) + } +} + +impl PhysicalExpr for LambdaColumn { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + /// Get the data type of this expression, given the schema of the input + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.field.data_type().clone()) + } + + /// Decide whether this expression is nullable, given the schema of the input + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(self.field.is_nullable()) + } + + /// Evaluate the expression + fn evaluate(&self, _batch: &RecordBatch) -> Result { + self.value.clone().ok_or_else(|| exec_datafusion_err!("Physical LambdaColumn {} missing value", self.name)) + } + + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.field)) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.name) + } +} + +/// Create a column expression +pub fn lambda_col(name: &str, field: FieldRef) -> Result> { + Ok(Arc::new(LambdaColumn::new(name, field))) +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index e87941da5ef4c..5d044ab848550 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -23,6 +23,7 @@ mod case; mod cast; mod cast_column; mod column; +mod lambda_column; mod dynamic_filters; mod in_list; mod is_not_null; @@ -45,6 +46,7 @@ pub use case::{case, CaseExpr}; pub use cast::{cast, CastExpr}; pub use cast_column::CastColumnExpr; pub use column::{col, with_new_schema, Column}; +pub use 
lambda_column::{lambda_col, LambdaColumn}; pub use datafusion_expr::utils::format_state_name; pub use dynamic_filters::DynamicFilterPhysicalExpr; pub use in_list::{in_list, InListExpr}; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 873205f28bef4..aa8c9e50fd71e 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -70,8 +70,6 @@ pub use datafusion_physical_expr_common::sort_expr::{ pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; -pub use scalar_function::PhysicalExprExt; - pub use simplifier::PhysicalExprSimplifier; pub use utils::{conjunction, conjunction_opt, split_conjunction}; diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 2584fc22885c2..c658a8eddc233 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -18,11 +18,11 @@ use std::sync::Arc; use crate::expressions::{self, Column}; -use crate::{create_physical_expr, LexOrdering, PhysicalExprExt, PhysicalSortExpr}; +use crate::{create_physical_expr, LexOrdering, PhysicalSortExpr}; use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TransformedResult}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_err, Result}; use datafusion_common::{DFSchema, HashMap}; use datafusion_expr::execution_props::ExecutionProps; @@ -38,14 +38,14 @@ pub fn add_offset_to_expr( expr: Arc, offset: isize, ) -> Result> { - expr.transform_down_with_lambdas_params(|e, lambdas_params| match e.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => { + expr.transform_down(|e| match e.as_any().downcast_ref::() { + Some(col) => { let Some(idx) = col.index().checked_add_signed(offset) else { return plan_err!("Column index 
overflow"); }; Ok(Transformed::yes(Arc::new(Column::new(col.name(), idx)))) } - _ => Ok(Transformed::no(e)), + None => Ok(Transformed::no(e)), }) .data() } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 0119c81b8ed94..4c3d1352cce0f 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use crate::expressions::LambdaExpr; +use crate::expressions::{lambda_col, LambdaExpr}; use crate::ScalarFunctionExpr; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, @@ -28,10 +28,13 @@ use arrow::datatypes::Schema; use datafusion_common::config::ConfigOptions; use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ - exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema, + exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, + ToDFSchema, }; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::expr::{Alias, Cast, InList, Lambda, Placeholder, ScalarFunction}; +use datafusion_expr::expr::{ + Alias, Cast, InList, Lambda, LambdaColumn, Placeholder, ScalarFunction, +}; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ @@ -105,8 +108,7 @@ use datafusion_expr::{ /// /// * `e` - The logical expression /// * `input_dfschema` - The DataFusion schema for the input, used to resolve `Column` references -/// to qualified or unqualified fields by name. Note that for creating a lambda, this must be -/// scoped lambda schema, and not the outer schema +/// to qualified or unqualified fields by name. pub fn create_physical_expr( e: &Expr, input_dfschema: &DFSchema, @@ -316,27 +318,13 @@ pub fn create_physical_expr( input_dfschema, execution_props, )?), - Expr::Lambda { .. 
} => { - exec_err!("Expr::Lambda should be handled by Expr::ScalarFunction, as it can only exist within it") - } + Expr::Lambda(Lambda { params, body }) => Ok(Arc::new(LambdaExpr::new( + params.clone(), + create_physical_expr(body, input_dfschema, execution_props)?, + ))), Expr::ScalarFunction(ScalarFunction { func, args }) => { - let lambdas_schemas = - func.arguments_schema_from_logical_args(args, input_dfschema)?; - - let physical_args = std::iter::zip(args, lambdas_schemas) - .map(|(expr, schema)| match expr { - Expr::Lambda(Lambda { params, body }) => { - Ok(Arc::new(LambdaExpr::new( - params.clone(), - create_physical_expr(body, &schema, execution_props)?, - )) as Arc) - } - expr => create_physical_expr(expr, &schema, execution_props), - }) - .collect::>>()?; - - //let physical_args = - // create_physical_exprs(args, input_dfschema, execution_props)?; + let physical_args = + create_physical_exprs(args, input_dfschema, execution_props)?; let config_options = match execution_props.config_options.as_ref() { Some(config_options) => Arc::clone(config_options), @@ -404,6 +392,14 @@ pub fn create_physical_expr( Expr::Placeholder(Placeholder { id, .. 
}) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } + Expr::LambdaColumn(LambdaColumn { + name, + field, + spans: _, + }) => lambda_col( + name, + Arc::clone(field), + ), other => { not_impl_err!("Physical plan does not support logical expression {other:?}") } diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index 70be717a8436c..a120ab427e1de 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -20,11 +20,11 @@ use std::sync::Arc; use crate::expressions::Column; use crate::utils::collect_columns; -use crate::{PhysicalExpr, PhysicalExprExt}; +use crate::PhysicalExpr; use arrow::datatypes::{Field, Schema, SchemaRef}; use datafusion_common::stats::{ColumnStatistics, Precision}; -use datafusion_common::tree_node::{Transformed, TransformedResult}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -499,16 +499,13 @@ pub fn update_expr( let mut state = RewriteState::Unchanged; let new_expr = Arc::clone(expr) - .transform_up_with_lambdas_params(|expr, lambdas_params| { + .transform_up(|expr| { if state == RewriteState::RewrittenInvalid { return Ok(Transformed::no(expr)); } - let column = match expr.as_any().downcast_ref::() { - Some(column) if !lambdas_params.contains(column.name()) => column, - _ => { - return Ok(Transformed::no(expr)); - } + let Some(column) = expr.as_any().downcast_ref::() else { + return Ok(Transformed::no(expr)); }; if sync_with_child { state = RewriteState::RewrittenValid; @@ -619,14 +616,14 @@ impl ProjectionMapping { let mut map = IndexMap::<_, ProjectionTargets>::new(); for (expr_idx, (expr, name)) in expr.into_iter().enumerate() { let target_expr = Arc::new(Column::new(&name, expr_idx)) as _; - let source_expr = 
expr.transform_down_with_schema(input_schema, |e, schema| match e.as_any().downcast_ref::() { + let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::() { Some(col) => { - // Sometimes, an expression and its name in the schema + // Sometimes, an expression and its name in the input_schema // doesn't match. This can cause problems, so we make sure - // that the expression name matches with the name in `schema`. + // that the expression name matches with the name in `input_schema`. // Conceptually, `source_expr` and `expression` should be the same. let idx = col.index(); - let matching_field = schema.field(idx); + let matching_field = input_schema.field(idx); let matching_name = matching_field.name(); if col.name() != matching_name { return internal_err!( @@ -740,25 +737,21 @@ pub fn project_ordering( ) -> Option { let mut projected_exprs = vec![]; for PhysicalSortExpr { expr, options } in ordering.iter() { - let transformed = - Arc::clone(expr).transform_up_with_lambdas_params(|expr, lambdas_params| { - let col = match expr.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => col, - _ => { - return Ok(Transformed::no(expr)); - } - }; + let transformed = Arc::clone(expr).transform_up(|expr| { + let Some(col) = expr.as_any().downcast_ref::() else { + return Ok(Transformed::no(expr)); + }; - let name = col.name(); - if let Some((idx, _)) = schema.column_with_name(name) { - // Compute the new column expression (with correct index) after projection: - Ok(Transformed::yes(Arc::new(Column::new(name, idx)))) - } else { - // Cannot find expression in the projected_schema, - // signal this using an Err result - plan_err!("") - } - }); + let name = col.name(); + if let Some((idx, _)) = schema.column_with_name(name) { + // Compute the new column expression (with correct index) after projection: + Ok(Transformed::yes(Arc::new(Column::new(name, idx)))) + } else { + // Cannot find expression in the projected_schema, + // signal this 
using an Err result + plan_err!("") + } + }); match transformed { Ok(transformed) => { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 22fa300f05df4..2527e84241fe3 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -30,19 +30,17 @@ //! to a function that supports f64, it is coerced to f64. use std::any::Any; -use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::expressions::{Column, LambdaExpr, Literal}; +use crate::expressions::{LambdaExpr, Literal}; use crate::PhysicalExpr; use arrow::array::{Array, NullArray, RecordBatch}; use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use datafusion_common::config::{ConfigEntry, ConfigOptions}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; -use datafusion_common::{internal_err, HashSet, Result, ScalarValue}; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; @@ -96,16 +94,10 @@ impl ScalarFunctionExpr { schema: &Schema, config_options: Arc, ) -> Result { - let lambdas_schemas = lambdas_schemas_from_args(&fun, &args, schema)?; - - let arg_fields = std::iter::zip(&args, lambdas_schemas) - .map(|(e, schema)| { - if let Some(lambda) = e.as_any().downcast_ref::() { - lambda.body().return_field(&schema) - } else { - e.return_field(&schema) - } - }) + let name = fun.name().to_string(); + let arg_fields = args + .iter() + .map(|e| e.return_field(schema)) .collect::>>()?; // verify that input data types is consistent with function's `TypeSignature` @@ -137,7 +129,6 @@ impl ScalarFunctionExpr { }; let return_field = fun.return_field_from_args(ret_args)?; - let name = fun.name().to_string(); 
Ok(Self { fun, @@ -300,23 +291,7 @@ impl PhysicalExpr for ScalarFunctionExpr { let args_metadata = std::iter::zip(&self.args, &arg_fields) .map( |(expr, field)| match expr.as_any().downcast_ref::() { - Some(lambda) => { - let mut captures = false; - - expr.apply_with_lambdas_params(|expr, lambdas_params| { - match expr.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => { - captures = true; - - Ok(TreeNodeRecursion::Stop) - } - _ => Ok(TreeNodeRecursion::Continue), - } - }) - .unwrap(); - - ValueOrLambdaParameter::Lambda(lambda.params(), captures) - } + Some(_lambda) => ValueOrLambdaParameter::Lambda, None => ValueOrLambdaParameter::Value(Arc::clone(field)), }, ) @@ -326,44 +301,30 @@ impl PhysicalExpr for ScalarFunctionExpr { let lambdas = std::iter::zip(&self.args, params) .map(|(arg, lambda_params)| { - arg.as_any() - .downcast_ref::() - .map(|lambda| { - let mut indices = HashSet::new(); - - arg.apply_with_lambdas_params(|expr, lambdas_params| { - if let Some(column) = - expr.as_any().downcast_ref::() - { - if !lambdas_params.contains(column.name()) { - indices.insert( - column.index(), //batch - // .schema_ref() - // .index_of(column.name())?, - ); - } - } - - Ok(TreeNodeRecursion::Continue) - })?; - - //let mut indices = indices.into_iter().collect::>(); - - //indices.sort_unstable(); - - let params = - std::iter::zip(lambda.params(), lambda_params.unwrap()) - .map(|(name, param)| Arc::new(param.with_name(name))) - .collect(); - - let captures = if !indices.is_empty() { + match (arg.as_any().downcast_ref::(), lambda_params) { + (Some(lambda), Some(lambda_params)) => { + if lambda.params().len() > lambda_params.len() { + return exec_err!( + "lambda defined {} params but UDF support only {}", + lambda.params().len(), + lambda_params.len() + ); + } + + let captures = lambda.captures(); + + let params = std::iter::zip(lambda.params(), lambda_params) + .map(|(name, param)| Arc::new(param.with_name(name))) + .collect(); + + let 
captures = if !captures.is_empty() { let (fields, columns): (Vec<_>, _) = std::iter::zip( batch.schema_ref().fields(), batch.columns(), ) .enumerate() .map(|(column_index, (field, column))| { - if indices.contains(&column_index) { + if captures.contains(&column_index) { (Arc::clone(field), Arc::clone(column)) } else { ( @@ -381,18 +342,26 @@ impl PhysicalExpr for ScalarFunctionExpr { let schema = Arc::new(Schema::new(fields)); Some(RecordBatch::try_new(schema, columns)?) - //Some(batch.project(&indices)?) } else { None }; - Ok(ScalarFunctionLambdaArg { + Ok(Some(ScalarFunctionLambdaArg { params, body: Arc::clone(lambda.body()), captures, - }) - }) - .transpose() + })) + } + (Some(_lambda), None) => exec_err!( + "{} don't reported the parameters of one of it's lambdas", + self.fun.name() + ), + (None, Some(_lambda_params)) => exec_err!( + "{} reported parameters for an argument that is not a lambda", + self.fun.name() + ), + _ => Ok(None), + } }) .collect::>>()?; @@ -493,373 +462,17 @@ impl PhysicalExpr for ScalarFunctionExpr { } } -pub fn lambdas_schemas_from_args<'a>( - fun: &ScalarUDF, - args: &[Arc], - schema: &'a Schema, -) -> Result>> { - let args_metadata = args - .iter() - .map(|e| match e.as_any().downcast_ref::() { - Some(lambda) => { - let mut captures = false; - - e.apply_with_lambdas_params(|expr, lambdas_params| { - match expr.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => { - captures = true; - - Ok(TreeNodeRecursion::Stop) - } - _ => Ok(TreeNodeRecursion::Continue), - } - }) - .unwrap(); - - Ok(ValueOrLambdaParameter::Lambda(lambda.params(), captures)) - } - None => Ok(ValueOrLambdaParameter::Value(e.return_field(schema)?)), - }) - .collect::>>()?; - - /*let captures = args - .iter() - .map(|arg| { - if arg.as_any().is::() { - let mut columns = HashSet::new(); - - arg.apply_with_lambdas_params(|n, lambdas_params| { - if let Some(column) = n.as_any().downcast_ref::() { - if !lambdas_params.contains(column.name()) { 
- columns.insert(schema.index_of(column.name())?); - } - // columns.insert(column.index()); - } - - Ok(TreeNodeRecursion::Continue) - })?; - - Ok(columns) - } else { - Ok(HashSet::new()) - } - }) - .collect::>>()?; */ - - fun.arguments_arrow_schema(&args_metadata, schema) -} - -pub trait PhysicalExprExt: Sized { - fn apply_with_lambdas_params< - 'n, - F: FnMut(&'n Self, &HashSet<&'n str>) -> Result, - >( - &'n self, - f: F, - ) -> Result; - - fn apply_with_schema<'n, F: FnMut(&'n Self, &Schema) -> Result>( - &'n self, - schema: &Schema, - f: F, - ) -> Result; - - fn apply_children_with_schema< - 'n, - F: FnMut(&'n Self, &Schema) -> Result, - >( - &'n self, - schema: &Schema, - f: F, - ) -> Result; - - fn transform_down_with_schema Result>>( - self, - schema: &Schema, - f: F, - ) -> Result>; - - fn transform_up_with_schema Result>>( - self, - schema: &Schema, - f: F, - ) -> Result>; - - fn transform_with_schema Result>>( - self, - schema: &Schema, - f: F, - ) -> Result> { - self.transform_up_with_schema(schema, f) - } - - fn transform_down_with_lambdas_params( - self, - f: impl FnMut(Self, &HashSet) -> Result>, - ) -> Result>; - - fn transform_up_with_lambdas_params( - self, - f: impl FnMut(Self, &HashSet) -> Result>, - ) -> Result>; - - fn transform_with_lambdas_params( - self, - f: impl FnMut(Self, &HashSet) -> Result>, - ) -> Result> { - self.transform_up_with_lambdas_params(f) - } -} - -impl PhysicalExprExt for Arc { - fn apply_with_lambdas_params< - 'n, - F: FnMut(&'n Self, &HashSet<&'n str>) -> Result, - >( - &'n self, - mut f: F, - ) -> Result { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn apply_with_lambdas_params_impl< - 'n, - F: FnMut( - &'n Arc, - &HashSet<&'n str>, - ) -> Result, - >( - node: &'n Arc, - args: &HashSet<&'n str>, - f: &mut F, - ) -> Result { - match node.as_any().downcast_ref::() { - Some(lambda) => { - let mut args = args.clone(); - - args.extend(lambda.params().iter().map(|v| v.as_str())); - - f(node, 
&args)?.visit_children(|| { - node.apply_children(|c| { - apply_with_lambdas_params_impl(c, &args, f) - }) - }) - } - _ => f(node, args)?.visit_children(|| { - node.apply_children(|c| apply_with_lambdas_params_impl(c, args, f)) - }), - } - } - - apply_with_lambdas_params_impl(self, &HashSet::new(), &mut f) - } - - fn apply_with_schema<'n, F: FnMut(&'n Self, &Schema) -> Result>( - &'n self, - schema: &Schema, - mut f: F, - ) -> Result { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn apply_with_lambdas_impl< - 'n, - F: FnMut(&'n Arc, &Schema) -> Result, - >( - node: &'n Arc, - schema: &Schema, - f: &mut F, - ) -> Result { - f(node, schema)?.visit_children(|| { - node.apply_children_with_schema(schema, |c, schema| { - apply_with_lambdas_impl(c, schema, f) - }) - }) - } - - apply_with_lambdas_impl(self, schema, &mut f) - } - - fn apply_children_with_schema< - 'n, - F: FnMut(&'n Self, &Schema) -> Result, - >( - &'n self, - schema: &Schema, - mut f: F, - ) -> Result { - match self.as_any().downcast_ref::() { - Some(scalar_function) - if scalar_function - .args() - .iter() - .any(|arg| arg.as_any().is::()) => - { - let mut lambdas_schemas = lambdas_schemas_from_args( - scalar_function.fun(), - scalar_function.args(), - schema, - )? 
- .into_iter(); - - self.apply_children(|expr| f(expr, &lambdas_schemas.next().unwrap())) - } - _ => self.apply_children(|e| f(e, schema)), - } - } - - fn transform_down_with_schema< - F: FnMut(Self, &Schema) -> Result>, - >( - self, - schema: &Schema, - mut f: F, - ) -> Result> { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn transform_down_with_schema_impl< - F: FnMut( - Arc, - &Schema, - ) -> Result>>, - >( - node: Arc, - schema: &Schema, - f: &mut F, - ) -> Result>> { - f(node, schema)?.transform_children(|node| { - map_children_with_schema(node, schema, |n, schema| { - transform_down_with_schema_impl(n, schema, f) - }) - }) - } - - transform_down_with_schema_impl(self, schema, &mut f) - } - - fn transform_up_with_schema Result>>( - self, - schema: &Schema, - mut f: F, - ) -> Result> { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn transform_up_with_schema_impl< - F: FnMut( - Arc, - &Schema, - ) -> Result>>, - >( - node: Arc, - schema: &Schema, - f: &mut F, - ) -> Result>> { - map_children_with_schema(node, schema, |n, schema| { - transform_up_with_schema_impl(n, schema, f) - })? - .transform_parent(|n| f(n, schema)) - } - - transform_up_with_schema_impl(self, schema, &mut f) - } - - fn transform_up_with_lambdas_params( - self, - mut f: impl FnMut(Self, &HashSet) -> Result>, - ) -> Result> { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn transform_up_with_lambdas_params_impl< - F: FnMut( - Arc, - &HashSet, - ) -> Result>>, - >( - node: Arc, - params: &HashSet, - f: &mut F, - ) -> Result>> { - map_children_with_lambdas_params(node, params, |n, params| { - transform_up_with_lambdas_params_impl(n, params, f) - })? 
- .transform_parent(|n| f(n, params)) - } - - transform_up_with_lambdas_params_impl(self, &HashSet::new(), &mut f) - } - - fn transform_down_with_lambdas_params( - self, - mut f: impl FnMut(Self, &HashSet) -> Result>, - ) -> Result> { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn transform_down_with_lambdas_params_impl< - F: FnMut( - Arc, - &HashSet, - ) -> Result>>, - >( - node: Arc, - params: &HashSet, - f: &mut F, - ) -> Result>> { - f(node, params)?.transform_children(|node| { - map_children_with_lambdas_params(node, params, |node, args| { - transform_down_with_lambdas_params_impl(node, args, f) - }) - }) - } - - transform_down_with_lambdas_params_impl(self, &HashSet::new(), &mut f) - } -} - -fn map_children_with_schema( - node: Arc, - schema: &Schema, - mut f: impl FnMut( - Arc, - &Schema, - ) -> Result>>, -) -> Result>> { - match node.as_any().downcast_ref::() { - Some(fun) if fun.args().iter().any(|arg| arg.as_any().is::()) => { - let mut args_schemas = - lambdas_schemas_from_args(fun.fun(), fun.args(), schema)?.into_iter(); - - node.map_children(|node| f(node, &args_schemas.next().unwrap())) - } - _ => node.map_children(|node| f(node, schema)), - } -} - -fn map_children_with_lambdas_params( - node: Arc, - params: &HashSet, - mut f: impl FnMut( - Arc, - &HashSet, - ) -> Result>>, -) -> Result>> { - match node.as_any().downcast_ref::() { - Some(lambda) => { - let mut params = params.clone(); - - params.extend(lambda.params().iter().cloned()); - - node.map_children(|node| f(node, ¶ms)) - } - None => node.map_children(|node| f(node, params)), - } -} - #[cfg(test)] mod tests { use std::any::Any; - use std::{borrow::Cow, sync::Arc}; + use std::sync::Arc; use super::*; - use super::{lambdas_schemas_from_args, PhysicalExprExt}; use crate::expressions::Column; - use crate::{create_physical_expr, ScalarFunctionExpr}; + use crate::ScalarFunctionExpr; use arrow::datatypes::{DataType, Field, Schema}; - use 
datafusion_common::{tree_node::TreeNodeRecursion, DFSchema, HashSet, Result}; - use datafusion_expr::{ - col, expr::Lambda, Expr, ScalarFunctionArgs, ValueOrLambdaParameter, Volatility, - }; + use datafusion_common::Result; + use datafusion_expr::{ScalarFunctionArgs, Volatility}; use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr_common::physical_expr::is_volatile; @@ -935,190 +548,4 @@ mod tests { let stable_arc: Arc = Arc::new(stable_expr); assert!(!is_volatile(&stable_arc)); } - - fn list_list_int() -> Schema { - Schema::new(vec![Field::new( - "v", - DataType::new_list(DataType::new_list(DataType::Int32, false), false), - false, - )]) - } - - fn list_int() -> Schema { - Schema::new(vec![Field::new( - "v", - DataType::new_list(DataType::Int32, false), - false, - )]) - } - - fn int() -> Schema { - Schema::new(vec![Field::new("v", DataType::Int32, false)]) - } - - fn array_transform_udf() -> ScalarUDF { - ScalarUDF::new_from_impl(ArrayTransformFunc::new()) - } - - fn args() -> Vec { - vec![ - col("v"), - Expr::Lambda(Lambda::new( - vec!["v".into()], - array_transform_udf().call(vec![ - col("v"), - Expr::Lambda(Lambda::new(vec!["v".into()], -col("v"))), - ]), - )), - ] - } - - // array_transform(v, |v| -> array_transform(v, |v| -> -v)) - fn array_transform() -> Arc { - let e = array_transform_udf().call(args()); - - create_physical_expr( - &e, - &DFSchema::try_from(list_list_int()).unwrap(), - &Default::default(), - ) - .unwrap() - } - - #[derive(Debug, PartialEq, Eq, Hash)] - struct ArrayTransformFunc { - signature: Signature, - } - - impl ArrayTransformFunc { - pub fn new() -> Self { - Self { - signature: Signature::any(2, Volatility::Immutable), - } - } - } - - impl ScalarUDFImpl for ArrayTransformFunc { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn name(&self) -> &str { - "array_transform" - } - - fn signature(&self) -> &Signature { - &self.signature - 
} - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) - } - - fn lambdas_parameters( - &self, - args: &[ValueOrLambdaParameter], - ) -> Result>>> { - let ValueOrLambdaParameter::Value(value_field) = &args[0] else { - unimplemented!() - }; - let DataType::List(field) = value_field.data_type() else { - unimplemented!() - }; - - Ok(vec![ - None, - Some(vec![Field::new( - "", - field.data_type().clone(), - field.is_nullable(), - )]), - ]) - } - - fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { - unimplemented!() - } - } - - #[test] - fn test_lambdas_schemas_from_args() { - let schema = list_list_int(); - let expr = array_transform(); - - let args = expr - .as_any() - .downcast_ref::() - .unwrap() - .args(); - - let schemas = - lambdas_schemas_from_args(&array_transform_udf(), args, &schema).unwrap(); - - assert_eq!(schemas, &[Cow::Borrowed(&schema), Cow::Owned(list_int())]); - } - - #[test] - fn test_apply_with_schema() { - let mut steps = vec![]; - - array_transform() - .apply_with_schema(&list_list_int(), |node, schema| { - steps.push((node.to_string(), schema.clone())); - - Ok(TreeNodeRecursion::Continue) - }) - .unwrap(); - - let expected = [ - ( - "array_transform(v@0, (v) -> array_transform(v@0, (v) -> (- v@0)))", - list_list_int(), - ), - ("(v) -> array_transform(v@0, (v) -> (- v@0))", list_int()), - ("array_transform(v@0, (v) -> (- v@0))", list_int()), - ("(v) -> (- v@0)", int()), - ("(- v@0)", int()), - ("v@0", int()), - ("v@0", int()), - ("v@0", int()), - ] - .map(|(a, b)| (String::from(a), b)); - - assert_eq!(steps, expected); - } - - #[test] - fn test_apply_with_lambdas_params() { - let array_transform = array_transform(); - let mut steps = vec![]; - - array_transform - .apply_with_lambdas_params(|node, params| { - steps.push((node.to_string(), params.clone())); - - Ok(TreeNodeRecursion::Continue) - }) - .unwrap(); - - let expected = [ - ( - "array_transform(v@0, (v) -> array_transform(v@0, (v) -> (- 
v@0)))", - HashSet::from(["v"]), - ), - ( - "(v) -> array_transform(v@0, (v) -> (- v@0))", - HashSet::from(["v"]), - ), - ("array_transform(v@0, (v) -> (- v@0))", HashSet::from(["v"])), - ("(v) -> (- v@0)", HashSet::from(["v"])), - ("(- v@0)", HashSet::from(["v"])), - ("v@0", HashSet::from(["v"])), - ("v@0", HashSet::from(["v"])), - ("v@0", HashSet::from(["v"])), - ] - .map(|(a, b)| (String::from(a), b)); - - assert_eq!(steps, expected); - } } diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index dd7e6e314672f..80d6ee0a7b914 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -19,12 +19,12 @@ use arrow::datatypes::Schema; use datafusion_common::{ - tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, + tree_node::{Transformed, TreeNode, TreeNodeRewriter}, Result, }; use std::sync::Arc; -use crate::{PhysicalExpr, PhysicalExprExt}; +use crate::PhysicalExpr; pub mod unwrap_cast; @@ -48,22 +48,6 @@ impl<'a> PhysicalExprSimplifier<'a> { &mut self, expr: Arc, ) -> Result> { - return expr - .transform_up_with_schema(self.schema, |node, schema| { - // Apply unwrap cast optimization - #[cfg(test)] - let original_type = node.data_type(schema).unwrap(); - let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, schema)?; - #[cfg(test)] - assert_eq!( - unwrapped.data.data_type(schema).unwrap(), - original_type, - "Simplified expression should have the same data type as the original" - ); - Ok(unwrapped) - }) - .data(); - Ok(expr.rewrite(self)?.data) } } diff --git a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs index 1ccfc1cfe84d8..d409ce9cb5bf2 100644 --- a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs +++ b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs @@ -34,22 +34,22 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema}; -use 
datafusion_common::{tree_node::Transformed, Result, ScalarValue}; +use datafusion_common::{ + tree_node::{Transformed, TreeNode}, + Result, ScalarValue, +}; use datafusion_expr::Operator; use datafusion_expr_common::casts::try_cast_literal_to_type; +use crate::expressions::{lit, BinaryExpr, CastExpr, Literal, TryCastExpr}; use crate::PhysicalExpr; -use crate::{ - expressions::{lit, BinaryExpr, CastExpr, Literal, TryCastExpr}, - PhysicalExprExt, -}; /// Attempts to unwrap casts in comparison expressions. pub(crate) fn unwrap_cast_in_comparison( expr: Arc, schema: &Schema, ) -> Result>> { - expr.transform_down_with_schema(schema, |e, schema| { + expr.transform_down(|e| { if let Some(binary) = e.as_any().downcast_ref::() { if let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? { return Ok(Transformed::yes(unwrapped)); diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 92ecbb7176dc9..745ae855efee2 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -22,7 +22,6 @@ use std::borrow::Borrow; use std::sync::Arc; use crate::expressions::{BinaryExpr, Column}; -use crate::scalar_function::PhysicalExprExt; use crate::tree_node::ExprContext; use crate::PhysicalExpr; use crate::PhysicalSortExpr; @@ -228,11 +227,9 @@ where /// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`]. 
pub fn collect_columns(expr: &Arc) -> HashSet { let mut columns = HashSet::::new(); - expr.apply_with_lambdas_params(|expr, lambdas_params| { + expr.apply(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { - if !lambdas_params.contains(column.name()) { - columns.get_or_insert_owned(column); - } + columns.get_or_insert_owned(column); } Ok(TreeNodeRecursion::Continue) }) @@ -254,16 +251,14 @@ pub fn reassign_expr_columns( expr: Arc, schema: &Schema, ) -> Result> { - expr.transform_down_with_lambdas_params(|expr, lambdas_params| { + expr.transform_down(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { - if !lambdas_params.contains(column.name()) { - let index = schema.index_of(column.name())?; + let index = schema.index_of(column.name())?; - return Ok(Transformed::yes(Arc::new(Column::new( - column.name(), - index, - )))); - } + return Ok(Transformed::yes(Arc::new(Column::new( + column.name(), + index, + )))); } Ok(Transformed::no(expr)) }) diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs index d87e001946414..6e4e784866129 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs @@ -29,7 +29,7 @@ use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - add_offset_to_physical_sort_exprs, EquivalenceProperties, PhysicalExprExt, + add_offset_to_physical_sort_exprs, EquivalenceProperties, }; use datafusion_physical_expr_common::sort_expr::{ LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr, @@ -661,21 +661,20 @@ fn handle_custom_pushdown( .into_iter() .map(|req| { let child_schema = plan.children()[maintained_child_idx].schema(); - let updated_columns = - req.expr - .transform_up_with_lambdas_params(|expr, lambdas_params| { 
- match expr.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => { - let new_index = col.index() - sub_offset; - Ok(Transformed::yes(Arc::new(Column::new( - child_schema.field(new_index).name(), - new_index, - )))) - } - _ => Ok(Transformed::no(expr)), - } - })? - .data; + let updated_columns = req + .expr + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + let new_index = col.index() - sub_offset; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(new_index).name(), + new_index, + )))) + } else { + Ok(Transformed::no(expr)) + } + })? + .data; Ok(PhysicalSortRequirement::new(updated_columns, req.options)) }) .collect::>>()?; @@ -743,21 +742,20 @@ fn handle_hash_join( .into_iter() .map(|req| { let child_schema = plan.children()[1].schema(); - let updated_columns = - req.expr - .transform_up_with_lambdas_params(|expr, lambdas_params| { - match expr.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => { - let index = projected_indices[col.index()].index; - Ok(Transformed::yes(Arc::new(Column::new( - child_schema.field(index).name(), - index, - )))) - } - _ => Ok(Transformed::no(expr)), - } - })? - .data; + let updated_columns = req + .expr + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + let index = projected_indices[col.index()].index; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(index).name(), + index, + )))) + } else { + Ok(Transformed::no(expr)) + } + })? 
+ .data; Ok(PhysicalSortRequirement::new(updated_columns, req.options)) }) .collect::>>()?; diff --git a/datafusion/physical-optimizer/src/projection_pushdown.rs b/datafusion/physical-optimizer/src/projection_pushdown.rs index 8ed81d3874d64..987e3cb6f713e 100644 --- a/datafusion/physical-optimizer/src/projection_pushdown.rs +++ b/datafusion/physical-optimizer/src/projection_pushdown.rs @@ -23,7 +23,6 @@ use crate::PhysicalOptimizerRule; use arrow::datatypes::{Fields, Schema, SchemaRef}; use datafusion_common::alias::AliasGenerator; -use datafusion_physical_expr::PhysicalExprExt; use std::collections::HashSet; use std::sync::Arc; @@ -244,11 +243,9 @@ fn minimize_join_filter( rhs_schema: &Schema, ) -> JoinFilter { let mut used_columns = HashSet::new(); - expr.apply_with_lambdas_params(|expr, lambdas_params| { + expr.apply(|expr| { if let Some(col) = expr.as_any().downcast_ref::() { - if !lambdas_params.contains(col.name()) { - used_columns.insert(col.index()); - } + used_columns.insert(col.index()); } Ok(TreeNodeRecursion::Continue) }) @@ -270,19 +267,17 @@ fn minimize_join_filter( .collect::(); let final_expr = expr - .transform_up_with_lambdas_params(|expr, lambdas_params| { - match expr.as_any().downcast_ref::() { - Some(column) if !lambdas_params.contains(column.name()) => { - let new_idx = used_columns - .iter() - .filter(|idx| **idx < column.index()) - .count(); - let new_column = Column::new(column.name(), new_idx); - Ok(Transformed::yes( - Arc::new(new_column) as Arc - )) - } - _ => Ok(Transformed::no(expr)), + .transform_up(|expr| match expr.as_any().downcast_ref::() { + None => Ok(Transformed::no(expr)), + Some(column) => { + let new_idx = used_columns + .iter() + .filter(|idx| **idx < column.index()) + .count(); + let new_column = Column::new(column.name(), new_idx); + Ok(Transformed::yes( + Arc::new(new_column) as Arc + )) } }) .expect("Closure cannot fail"); @@ -385,9 +380,10 @@ impl<'a> JoinFilterRewriter<'a> { // First, add a new projection. 
The expression must be rewritten, as it is no longer // executed against the filter schema. let new_idx = self.join_side_projections.len(); - let rewritten_expr = expr.transform_up_with_lambdas_params(|expr, lambdas_params| { + let rewritten_expr = expr.transform_up(|expr| { Ok(match expr.as_any().downcast_ref::() { - Some(column) if !lambdas_params.contains(column.name()) => { + None => Transformed::no(expr), + Some(column) => { let intermediate_column = &self.intermediate_column_indices[column.index()]; assert_eq!(intermediate_column.side, self.join_side); @@ -397,7 +393,6 @@ impl<'a> JoinFilterRewriter<'a> { let new_column = Column::new(field.name(), join_side_index); Transformed::yes(Arc::new(new_column) as Arc) } - _ => Transformed::no(expr), }) })?; self.join_side_projections.push((rewritten_expr.data, name)); @@ -420,17 +415,15 @@ impl<'a> JoinFilterRewriter<'a> { join_side: JoinSide, ) -> Result { let mut result = false; - expr.apply_with_lambdas_params(|expr, lambdas_params| { - match expr.as_any().downcast_ref::() { - Some(c) if !lambdas_params.contains(c.name()) => { - let column_index = &self.intermediate_column_indices[c.index()]; - if column_index.side == join_side { - result = true; - return Ok(TreeNodeRecursion::Stop); - } - Ok(TreeNodeRecursion::Continue) + expr.apply(|expr| match expr.as_any().downcast_ref::() { + None => Ok(TreeNodeRecursion::Continue), + Some(c) => { + let column_index = &self.intermediate_column_indices[c.index()]; + if column_index.side == join_side { + result = true; + return Ok(TreeNodeRecursion::Stop); } - _ => Ok(TreeNodeRecursion::Continue), + Ok(TreeNodeRecursion::Continue) } })?; diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index be72a6af2b509..54a76e0ebb971 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -22,13 +22,13 @@ use crate::{ }; use arrow::array::RecordBatch; use arrow_schema::{Fields, Schema, 
SchemaRef}; -use datafusion_common::tree_node::{Transformed, TreeNodeRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, Result}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::{PhysicalExprExt, ScalarFunctionExpr}; +use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use futures::stream::StreamExt; use log::trace; @@ -249,7 +249,7 @@ impl AsyncMapper { schema: &Schema, ) -> Result<()> { // recursively look for references to async functions - physical_expr.apply_with_schema(schema, |expr, schema| { + physical_expr.apply(|expr| { if let Some(scalar_func_expr) = expr.as_any().downcast_ref::() { diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index b70a8f60508a5..80221a77992ce 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -35,7 +35,7 @@ use arrow::array::{ }; use arrow::compute::concat_batches; use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TransformedResult}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ arrow_datafusion_err, DataFusionError, HashSet, JoinSide, Result, ScalarValue, @@ -44,7 +44,7 @@ use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::collect_columns; -use 
datafusion_physical_expr::{PhysicalExpr, PhysicalExprExt, PhysicalSortExpr}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; use hashbrown::HashTable; @@ -312,13 +312,13 @@ pub fn convert_sort_expr_with_filter_schema( // Since we are sure that one to one column mapping includes all columns, we convert // the sort expression into a filter expression. let converted_filter_expr = expr - .transform_up_with_lambdas_params(|p, lambdas_params| { - convert_filter_columns(p.as_ref(), &column_map, lambdas_params).map( - |transformed| match transformed { + .transform_up(|p| { + convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { + match transformed { Some(transformed) => Transformed::yes(transformed), None => Transformed::no(p), - }, - ) + } + }) }) .data()?; // Search the converted `PhysicalExpr` in filter expression; if an exact @@ -361,17 +361,14 @@ pub fn build_filter_input_order( fn convert_filter_columns( input: &dyn PhysicalExpr, column_map: &HashMap, - lambdas_params: &HashSet, ) -> Result>> { // Attempt to downcast the input expression to a Column type. - Ok(match input.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => { - column_map.get(col).map(|c| Arc::new(c.clone()) as _) - } - _ => { - // If the downcast fails, return the input expression as is. - None - } + Ok(if let Some(col) = input.as_any().downcast_ref::() { + // If the downcast is successful, retrieve the corresponding filter column. + column_map.get(col).map(|c| Arc::new(c.clone()) as _) + } else { + // If the downcast fails, return the input expression as is. 
+ None }) } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index ab654e4eee1df..ead2196860cde 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -42,13 +42,14 @@ use std::task::{Context, Poll}; use arrow::datatypes::SchemaRef; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNodeRecursion}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{internal_err, JoinSide, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr::{PhysicalExprExt, PhysicalExprRef}; -use datafusion_physical_expr_common::physical_expr::fmt_sql; +use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; // Re-exported from datafusion-physical-expr for backwards compatibility // We recommend updating your imports to use datafusion-physical-expr directly @@ -865,12 +866,10 @@ fn try_unifying_projections( projection.expr().iter().for_each(|proj_expr| { proj_expr .expr - .apply_with_lambdas_params(|expr, lambdas_params| { + .apply(|expr| { Ok({ if let Some(column) = expr.as_any().downcast_ref::() { - if !lambdas_params.contains(column.name()) { - *column_ref_map.entry(column.clone()).or_default() += 1; - } + *column_ref_map.entry(column.clone()).or_default() += 1; } TreeNodeRecursion::Continue }) @@ -958,31 +957,31 @@ fn new_columns_for_join_on( .filter_map(|on| { // Rewrite all columns in `on` Arc::clone(*on) - .transform_with_lambdas_params(|expr, lambdas_params| { - match expr.as_any().downcast_ref::() { - Some(column) if 
!lambdas_params.contains(column.name()) => { - let new_column = projection_exprs - .iter() - .enumerate() - .find(|(_, (proj_column, _))| { - column.name() == proj_column.name() - && column.index() + column_index_offset - == proj_column.index() - }) - .map(|(index, (_, alias))| Column::new(alias, index)); - if let Some(new_column) = new_column { - Ok(Transformed::yes(Arc::new(new_column))) - } else { - // If the column is not found in the projection expressions, - // it means that the column is not projected. In this case, - // we cannot push the projection down. - internal_err!( - "Column {:?} not found in projection expressions", - column - ) - } + .transform(|expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + // Find the column in the projection expressions + let new_column = projection_exprs + .iter() + .enumerate() + .find(|(_, (proj_column, _))| { + column.name() == proj_column.name() + && column.index() + column_index_offset + == proj_column.index() + }) + .map(|(index, (_, alias))| Column::new(alias, index)); + if let Some(new_column) = new_column { + Ok(Transformed::yes(Arc::new(new_column))) + } else { + // If the column is not found in the projection expressions, + // it means that the column is not projected. In this case, + // we cannot push the projection down. + internal_err!( + "Column {:?} not found in projection expressions", + column + ) } - _ => Ok(Transformed::no(expr)), + } else { + Ok(Transformed::no(expr)) } }) .data() diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b87a50b3f5281..c6e33159aa89f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -622,9 +622,9 @@ pub fn serialize_expr( .unwrap_or(HashMap::new()), })), }, - Expr::Lambda { .. 
} => { + Expr::Lambda(_) | Expr::LambdaColumn(_) => { return Err(Error::General( - "Proto serialization error: Lambda not supported".to_string(), + "Proto serialization error: Lambda not implemented".to_string(), )) } }; diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index c9df93f8b693c..380ada10df6e1 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -38,14 +38,13 @@ use datafusion_common::error::Result; use datafusion_common::tree_node::TransformedResult; use datafusion_common::{ internal_datafusion_err, internal_err, plan_datafusion_err, plan_err, - tree_node::Transformed, ScalarValue, + tree_node::{Transformed, TreeNode}, + ScalarValue, }; use datafusion_common::{Column, DFSchema}; use datafusion_expr_common::operator::Operator; use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; -use datafusion_physical_expr::{ - expressions as phys_expr, PhysicalExprExt, PhysicalExprRef, -}; +use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion_physical_plan::{ColumnarValue, PhysicalExpr}; @@ -1205,9 +1204,9 @@ fn rewrite_column_expr( column_old: &phys_expr::Column, column_new: &phys_expr::Column, ) -> Result> { - e.transform_with_lambdas_params(|expr, lambdas_params| { + e.transform(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { - if !lambdas_params.contains(column.name()) && column == column_old { + if column == column_old { return Ok(Transformed::yes(Arc::new(column_new.clone()))); } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index c13fd33104eb0..439e65e8f7e47 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+use std::sync::Arc; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow::datatypes::DataType; @@ -26,7 +28,10 @@ use datafusion_expr::expr::{Lambda, ScalarFunction, Unnest}; use datafusion_expr::expr::{NullTreatment, WildcardOptions, WindowFunction}; use datafusion_expr::planner::PlannerResult; use datafusion_expr::planner::{RawAggregateExpr, RawWindowExpr}; -use datafusion_expr::{expr, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition}; +use datafusion_expr::{ + expr, Expr, ExprSchemable, ValueOrLambdaParameter, WindowFrame, + WindowFunctionDefinition, +}; use sqlparser::ast::{ DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, @@ -274,8 +279,94 @@ impl SqlToRel<'_, S> { } // User-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { - let (args, arg_names) = - self.function_args_to_expr_with_names(args, schema, planner_context)?; + enum ExprOrLambda { + ExprWithName((Expr, Option)), + Lambda(sqlparser::ast::LambdaFunction), + } + + let pairs = args + .into_iter() + .map(|a| match a { + FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Lambda( + lambda, + ))) => Ok(ExprOrLambda::Lambda(lambda)), + _ => Ok(ExprOrLambda::ExprWithName( + self.sql_fn_arg_to_logical_expr_with_name( + a, + schema, + planner_context, + )?, + )), + }) + .collect::>>()?; + + let metadata = pairs + .iter() + .map(|e| match e { + ExprOrLambda::ExprWithName((expr, _name)) => { + Ok(ValueOrLambdaParameter::Value(expr.to_field(schema)?.1)) + } + ExprOrLambda::Lambda(_lambda_function) => { + Ok(ValueOrLambdaParameter::Lambda) + } + }) + .collect::>>()?; + + let lambdas_parameters = fm.inner().lambdas_parameters(&metadata)?; + + let pairs = pairs + .into_iter() + .zip(lambdas_parameters) + .map(|(e, lambda_parameters)| match (e, lambda_parameters) { + (ExprOrLambda::ExprWithName(expr_with_name), 
None) => { + Ok(expr_with_name) + } + (ExprOrLambda::Lambda(lambda), Some(lambda_params)) => { + if lambda.params.len() > lambda_params.len() { + return plan_err!( + "lambda defined {} params but UDF support only {}", + lambda.params.len(), + lambda_params.len() + ); + } + + let params = + lambda.params.iter().map(|p| p.value.clone()).collect(); + + let lambda_parameters = lambda_params + .into_iter() + .zip(¶ms) + .map(|(f, n)| Arc::new(f.with_name(n))); + + let mut planner_context = planner_context + .clone() + .with_lambda_parameters(lambda_parameters); + + Ok(( + Expr::Lambda(Lambda { + params, + body: Box::new(self.sql_expr_to_logical_expr( + *lambda.body, + schema, + &mut planner_context, + )?), + }), + None, + )) + } + (ExprOrLambda::ExprWithName(_), Some(_)) => plan_err!( + "{} reported parameters for an argument that is not a lambda", + fm.name() + ), + (ExprOrLambda::Lambda(_), None) => plan_err!( + "{} don't reported the parameters of one of it's lambdas", + fm.name() + ), + }) + .collect::>>()?; + + let (args, arg_names): (Vec, Vec>) = + pairs.into_iter().unzip(); let resolved_args = if arg_names.iter().any(|name| name.is_some()) { if let Some(param_names) = &fm.signature().parameter_names { @@ -724,26 +815,6 @@ impl SqlToRel<'_, S> { let arg_name = crate::utils::normalize_ident(name); Ok((expr, Some(arg_name))) } - FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Lambda( - sqlparser::ast::LambdaFunction { params, body }, - ))) => { - let params = params - .into_iter() - .map(|v| v.to_string()) - .collect::>(); - - Ok(( - Expr::Lambda(Lambda { - params: params.clone(), - body: Box::new(self.sql_expr_to_logical_expr( - *body, - schema, - &mut planner_context.clone().with_lambda_parameters(params), - )?), - }), - None, - )) - } FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { let expr = self.sql_expr_to_logical_expr(arg, schema, planner_context)?; Ok((expr, None)) diff --git a/datafusion/sql/src/expr/identifier.rs 
b/datafusion/sql/src/expr/identifier.rs index dc39cb4de055d..73be980d686d0 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -20,6 +20,7 @@ use datafusion_common::{ exec_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, Result, Span, TableReference, }; +use datafusion_expr::expr::LambdaColumn; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{Case, Expr}; use sqlparser::ast::{CaseWhen, Expr as SQLExpr, Ident}; @@ -53,17 +54,17 @@ impl SqlToRel<'_, S> { // identifier. (e.g. it is "foo.bar" not foo.bar) let normalize_ident = self.ident_normalizer.normalize(id); - if planner_context + if let Some(field) = planner_context .lambdas_parameters() - .contains(&normalize_ident) + .get(&normalize_ident) { - let mut column = Column::new_unqualified(normalize_ident); + let mut lambda_column = LambdaColumn::new(normalize_ident, Arc::clone(field)); if self.options.collect_spans { if let Some(span) = Span::try_from_sqlparser_span(id_span) { - column.spans_mut().add_span(span); + lambda_column.spans_mut().add_span(span); } } - return Ok(Expr::Column(column)); + return Ok(Expr::LambdaColumn(lambda_column)); } // Check for qualified field with unqualified name diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 2992378fd1d6c..8cc7747ffe16b 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -26,13 +26,14 @@ use arrow::datatypes::*; use datafusion_common::config::SqlParserOptions; use datafusion_common::datatype::{DataTypeExt, FieldExt}; use datafusion_common::error::add_possible_columns_to_diag; -use datafusion_common::TableReference; use datafusion_common::{ - field_not_found, plan_datafusion_err, DFSchemaRef, Diagnostic, HashSet, SchemaError, + field_not_found, plan_datafusion_err, DFSchemaRef, Diagnostic, SchemaError, }; +use datafusion_common::{internal_err, TableReference}; use datafusion_common::{not_impl_err, 
plan_err, DFSchema, DataFusionError, Result}; use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder}; pub use datafusion_expr::planner::ContextProvider; +use datafusion_expr::utils::find_column_exprs; use datafusion_expr::{col, Expr}; use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo, TimezoneInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; @@ -265,8 +266,8 @@ pub struct PlannerContext { outer_from_schema: Option, /// The query schema defined by the table create_table_schema: Option, - /// The lambda introduced columns names - lambdas_parameters: HashSet, + /// The parameters of all lambdas seen so far + lambdas_parameters: HashMap, } impl Default for PlannerContext { @@ -284,7 +285,7 @@ impl PlannerContext { outer_query_schema: None, outer_from_schema: None, create_table_schema: None, - lambdas_parameters: HashSet::new(), + lambdas_parameters: HashMap::new(), } } @@ -371,15 +372,16 @@ impl PlannerContext { self.ctes.get(cte_name).map(|cte| cte.as_ref()) } - pub fn lambdas_parameters(&self) -> &HashSet { + pub fn lambdas_parameters(&self) -> &HashMap { &self.lambdas_parameters } pub fn with_lambda_parameters( mut self, - arguments: impl IntoIterator, + arguments: impl IntoIterator, ) -> Self { - self.lambdas_parameters.extend(arguments); + self.lambdas_parameters + .extend(arguments.into_iter().map(|f| (f.name().clone(), f))); self } @@ -545,11 +547,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, exprs: &[Expr], ) -> Result<()> { - exprs + find_column_exprs(exprs) .iter() - .flat_map(|expr| expr.column_refs()) - .try_for_each(|col| { - match &col.relation { + .try_for_each(|col| match col { + Expr::Column(col) => match &col.relation { Some(r) => schema.field_with_qualified_name(r, &col.name).map(|_| ()), None => { if !schema.fields_with_unqualified_name(&col.name).is_empty() { @@ -599,7 +600,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { err.with_diagnostic(diagnostic) } _ => err, - }) + }), + _ 
=> internal_err!("Not a column"), }) } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 0e7490d2c780b..42013a76a8657 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -540,11 +540,9 @@ impl SqlToRel<'_, S> { None => { let mut columns = HashSet::new(); for expr in &aggr_expr { - expr.apply_with_lambdas_params(|expr, lambdas_params| { + expr.apply(|expr| { if let Expr::Column(c) = expr { - if !c.is_lambda_parameter(lambdas_params) { - columns.insert(Expr::Column(c.clone())); - } + columns.insert(Expr::Column(c.clone())); } Ok(TreeNodeRecursion::Continue) }) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 67ca92bb1c1f1..71a1a342a9c5e 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -536,6 +536,9 @@ impl Unparser<'_> { body: Box::new(self.expr_to_sql_inner(body)?), })) } + Expr::LambdaColumn(l) => Ok(ast::Expr::Identifier( + self.new_ident_quoted_if_needs(l.name.clone()), + )), } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index c218ce547b312..e7535338b7677 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -40,7 +40,7 @@ use crate::unparser::{ast::UnnestRelationBuilder, rewrite::rewrite_qualify}; use crate::utils::UNNEST_PLACEHOLDER; use datafusion_common::{ internal_err, not_impl_err, - tree_node::TransformedResult, + tree_node::{TransformedResult, TreeNode}, Column, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::expr::OUTER_REFERENCE_COLUMN_PREFIX; @@ -1131,7 +1131,7 @@ impl Unparser<'_> { .cloned() .map(|expr| { if let Some(ref mut rewriter) = filter_alias_rewriter { - expr.rewrite_with_lambdas_params(rewriter).data() + expr.rewrite(rewriter).data() } else { Ok(expr) } @@ -1197,7 +1197,7 @@ impl Unparser<'_> { .cloned() .map(|expr| { if let Some(ref mut rewriter) = alias_rewriter { - 
expr.rewrite_with_lambdas_params(rewriter).data() + expr.rewrite(rewriter).data() } else { Ok(expr) } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 58f4435095517..c961f1d6f1f0c 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -20,11 +20,10 @@ use std::{collections::HashSet, sync::Arc}; use arrow::datatypes::Schema; use datafusion_common::tree_node::TreeNodeContainer; use datafusion_common::{ - tree_node::{Transformed, TransformedResult, TreeNode}, + tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, Column, HashMap, Result, TableReference, }; use datafusion_expr::expr::{Alias, UNNEST_COLUMN_PREFIX}; -use datafusion_expr::tree_node::TreeNodeRewriterWithPayload; use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr}; use sqlparser::ast::Ident; @@ -467,17 +466,12 @@ pub struct TableAliasRewriter<'a> { pub alias_name: TableReference, } -impl TreeNodeRewriterWithPayload for TableAliasRewriter<'_> { +impl TreeNodeRewriter for TableAliasRewriter<'_> { type Node = Expr; - type Payload<'a> = &'a datafusion_common::HashSet; - fn f_down( - &mut self, - expr: Expr, - lambdas_params: &datafusion_common::HashSet, - ) -> Result> { + fn f_down(&mut self, expr: Expr) -> Result> { match expr { - Expr::Column(column) if !column.is_lambda_parameter(lambdas_params) => { + Expr::Column(column) => { if let Ok(field) = self.table_schema.field_with_name(&column.name) { let new_column = Column::new(Some(self.alias_name.clone()), field.name().clone()); diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index f785f640dbcee..8b3791017a8af 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -161,11 +161,11 @@ pub(crate) fn find_window_nodes_within_select<'a>( /// For example, if expr contains the column expr 
"__unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)" /// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL]) pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result { - expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| { + expr.transform(|sub_expr| { if let Expr::Column(col_ref) = &sub_expr { // Check if the column is among the columns to run unnest on. // Currently, only List/Array columns (defined in `list_type_columns`) are supported for unnesting. - if !col_ref.is_lambda_parameter(lambdas_params) && unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) { + if unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) { if let Ok(idx) = unnest.schema.index_of_column(col_ref) { if let LogicalPlan::Projection(Projection { expr, .. }) = unnest.input.as_ref() { if let Some(unprojected_expr) = expr.get(idx) { @@ -195,21 +195,22 @@ pub(crate) fn unproject_agg_exprs( agg: &Aggregate, windows: Option<&[&Window]>, ) -> Result { - expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| { - match sub_expr { - Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { - Ok(Transformed::yes(unprojected_expr.clone())) - } else if let Some(unprojected_expr) = - windows.and_then(|w| find_window_expr(w, &c.name).cloned()) - { - // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected - Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?)) - } else { - internal_err!( - "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name - ) - }, - _ => Ok(Transformed::no(sub_expr)), + expr.transform(|sub_expr| { + if let Expr::Column(c) = sub_expr { + if let Some(unprojected_expr) = find_agg_expr(agg, &c)? 
{ + Ok(Transformed::yes(unprojected_expr.clone())) + } else if let Some(unprojected_expr) = + windows.and_then(|w| find_window_expr(w, &c.name).cloned()) + { + // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected + Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?)) + } else { + internal_err!( + "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name + ) + } + } else { + Ok(Transformed::no(sub_expr)) } }) .map(|e| e.data) @@ -221,15 +222,16 @@ pub(crate) fn unproject_agg_exprs( /// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed /// into an actual window expression as identified in the window node. pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result { - expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| match sub_expr { - Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { + expr.transform(|sub_expr| { + if let Expr::Column(c) = sub_expr { if let Some(unproj) = find_window_expr(windows, &c.name) { Ok(Transformed::yes(unproj.clone())) } else { Ok(Transformed::no(Expr::Column(c))) } + } else { + Ok(Transformed::no(sub_expr)) } - _ => Ok(Transformed::no(sub_expr)), }) .map(|e| e.data) } @@ -374,7 +376,7 @@ pub(crate) fn try_transform_to_simple_table_scan_with_filters( .cloned() .map(|expr| { if let Some(ref mut rewriter) = filter_alias_rewriter { - expr.rewrite_with_lambdas_params(rewriter).data() + expr.rewrite(rewriter).data() } else { Ok(expr) } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 6380412e3b5ee..3c86d2d04905f 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -23,16 +23,16 @@ use arrow::datatypes::{ DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, 
TreeNodeRecursion, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{ - exec_datafusion_err, exec_err, internal_err, plan_err, Column, DFSchema, Diagnostic, HashMap, Result, ScalarValue + exec_datafusion_err, exec_err, internal_err, plan_err, Column, DFSchemaRef, + Diagnostic, HashMap, Result, ScalarValue, }; use datafusion_expr::builder::get_struct_unnested_columns; use datafusion_expr::expr::{ Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams, }; -use datafusion_expr::tree_node::TreeNodeRewriterWithPayload; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; use datafusion_expr::{ col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan, @@ -44,9 +44,9 @@ use sqlparser::ast::{Ident, Value}; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { expr.clone() - .transform_up_with_lambdas_params(|nested_expr, lambdas_params| { + .transform_up(|nested_expr| { match nested_expr { - Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { + Expr::Column(col) => { let (qualifier, field) = plan.schema().qualified_field_from_column(&col)?; Ok(Transformed::yes(Expr::Column(Column::from(( @@ -81,7 +81,6 @@ pub(crate) fn rebase_expr( base_exprs: &[Expr], plan: &LogicalPlan, ) -> Result { - //todo user transform_down_with_lambdas_params expr.clone() .transform_down(|nested_expr| { if base_exprs.contains(&nested_expr) { @@ -232,8 +231,8 @@ pub(crate) fn resolve_aliases_to_exprs( expr: Expr, aliases: &HashMap, ) -> Result { - expr.transform_up_with_lambdas_params(|nested_expr, lambdas_params| match nested_expr { - Expr::Column(c) if c.relation.is_none() && !c.is_lambda_parameter(lambdas_params) => { + expr.transform_up(|nested_expr| match nested_expr { + Expr::Column(c) if c.relation.is_none() => { if let Some(aliased_expr) = aliases.get(&c.name) { 
Ok(Transformed::yes(aliased_expr.clone())) } else { @@ -372,6 +371,7 @@ This is only usedful when used with transform down up A full example of how the transformation works: */ struct RecursiveUnnestRewriter<'a> { + input_schema: &'a DFSchemaRef, root_expr: &'a Expr, // Useful to detect which child expr is a part of/ not a part of unnest operation top_most_unnest: Option, @@ -405,7 +405,6 @@ impl RecursiveUnnestRewriter<'_> { alias_name: String, expr_in_unnest: &Expr, struct_allowed: bool, - input_schema: &DFSchema, ) -> Result> { let inner_expr_name = expr_in_unnest.schema_name().to_string(); @@ -419,7 +418,7 @@ impl RecursiveUnnestRewriter<'_> { // column name as is, to comply with group by and order by let placeholder_column = Column::from_name(placeholder_name.clone()); - let (data_type, _) = expr_in_unnest.data_type_and_nullable(input_schema)?; + let (data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?; match data_type { DataType::Struct(inner_fields) => { @@ -469,18 +468,17 @@ impl RecursiveUnnestRewriter<'_> { } } -impl TreeNodeRewriterWithPayload for RecursiveUnnestRewriter<'_> { +impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> { type Node = Expr; - type Payload<'a> = &'a DFSchema; /// This downward traversal needs to keep track of: /// - Whether or not some unnest expr has been visited from the top util the current node /// - If some unnest expr has been visited, maintain a stack of such information, this /// is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))** - fn f_down(&mut self, expr: Expr, input_schema: &DFSchema) -> Result> { + fn f_down(&mut self, expr: Expr) -> Result> { if let Expr::Unnest(ref unnest_expr) = expr { let (data_type, _) = - unnest_expr.expr.data_type_and_nullable(input_schema)?; + unnest_expr.expr.data_type_and_nullable(self.input_schema)?; self.consecutive_unnest.push(Some(unnest_expr.clone())); // if expr inside unnest is a struct, do not consider // the 
next unnest as consecutive unnest (if any) @@ -534,7 +532,7 @@ impl TreeNodeRewriterWithPayload for RecursiveUnnestRewriter<'_> { /// column2 /// ``` /// - fn f_up(&mut self, expr: Expr, input_schema: &DFSchema) -> Result> { + fn f_up(&mut self, expr: Expr) -> Result> { if let Expr::Unnest(ref traversing_unnest) = expr { if traversing_unnest == self.top_most_unnest.as_ref().unwrap() { self.top_most_unnest = None; @@ -570,7 +568,6 @@ impl TreeNodeRewriterWithPayload for RecursiveUnnestRewriter<'_> { expr.schema_name().to_string(), inner_expr, struct_allowed, - input_schema, )?; if struct_allowed { self.transformed_root_exprs = Some(transformed_exprs.clone()); @@ -622,6 +619,7 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up( original_expr: &Expr, ) -> Result> { let mut rewriter = RecursiveUnnestRewriter { + input_schema: input.schema(), root_expr: original_expr, top_most_unnest: None, consecutive_unnest: vec![], @@ -643,7 +641,7 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up( data: transformed_expr, transformed, tnr: _, - } = original_expr.clone().rewrite_with_schema(input.schema(), &mut rewriter)?; + } = original_expr.clone().rewrite(&mut rewriter)?; if !transformed { // TODO: remove the next line after `Expr::Wildcard` is removed diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 29ea8cb786072..00629c392df48 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5866,10 +5866,10 @@ select array_ndims(arrow_cast([null], 'List(List(List(Int64)))')); 3 # array_ndims scalar function #2 -#query II -#select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); -#---- -#3 21 +query II +select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); +---- +3 21 # array_ndims scalar function #3 query II diff 
--git a/datafusion/sqllogictest/test_files/lambda.slt b/datafusion/sqllogictest/test_files/lambda.slt index 0043eae17a60c..af5334a644421 100644 --- a/datafusion/sqllogictest/test_files/lambda.slt +++ b/datafusion/sqllogictest/test_files/lambda.slt @@ -42,9 +42,9 @@ SELECT array_transform([1], e1 -> (select n from t)); [1] query ? -SELECT array_transform(v, (v1, i) -> array_transform(v1, (v2, j) -> array_transform(v2, v3 -> j)) ) from t; +SELECT array_transform(t.v, (v1, i) -> array_transform(v1, (v2, j) -> array_transform(v2, v3 -> j)) ) from t; ---- -[[[0, 0], [1]], [[0]], [[]]] +[[[0, 1], [0]], [[0]], [[]]] query I? SELECT t.n, array_transform([1, 2], (e) -> n) from t; @@ -75,7 +75,7 @@ logical_plan 01)Projection: array_transform(List([1, 2]), (e, i) -> e + CAST(i AS Int64)) AS array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i) 02)--EmptyRelation: rows=1 physical_plan -01)ProjectionExec: expr=[array_transform([1, 2], (e, i) -> e@0 + CAST(i@1 AS Int64)) as array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i)] +01)ProjectionExec: expr=[array_transform([1, 2], (e, i) -> e@-1 + CAST(i@-1 AS Int64)) as array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i)] 02)--PlaceholderRowExec #cse @@ -86,7 +86,7 @@ logical_plan 01)Projection: t.n + Int64(1), array_transform(List([1]), (v) -> v + t.n + Int64(1)) AS array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1)) 02)--TableScan: t projection=[n] physical_plan -01)ProjectionExec: expr=[n@0 + 1 as t.n + Int64(1), array_transform([1], (v) -> v@1 + n@0 + 1) as array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1))] +01)ProjectionExec: expr=[n@0 + 1 as t.n + Int64(1), array_transform([1], (v) -> v@-1 + n@0 + 1) as array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1))] 02)--DataSourceExec: partitions=1, partition_sizes=[1] query ? 
@@ -121,9 +121,19 @@ logical_plan 01)Projection: array_transform(List([1, 2, 3, 4, 5]), (v) -> v * Int64(2)) AS array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2)) 02)--EmptyRelation: rows=1 physical_plan -01)ProjectionExec: expr=[array_transform([1, 2, 3, 4, 5], (v) -> v@0 * 2) as array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2))] +01)ProjectionExec: expr=[array_transform([1, 2, 3, 4, 5], (v) -> v@-1 * 2) as array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2))] 02)--PlaceholderRowExec +query ? +SELECT array_transform( + [[1]], + v -> array_concat( + array_transform(v, v -> v), + array_transform(v, v1 -> v1 + v[0]) + ) +); +---- +[[1, NULL]] query I?? SELECT t.n, t.v, array_transform(t.v, (v, i) -> array_transform(v, (v, j) -> n) ) from t; @@ -144,23 +154,27 @@ logical_plan 01)Projection: Boolean(true) AS t.v = t.v, array_transform(List([1]), (v) -> v IS NOT NULL OR Boolean(NULL)) AS array_transform(make_array(Int64(1)),(v) -> v = v) 02)--TableScan: t projection=[] physical_plan -01)ProjectionExec: expr=[true as t.v = t.v, array_transform([1], (v) -> v@0 IS NOT NULL OR NULL) as array_transform(make_array(Int64(1)),(v) -> v = v)] +01)ProjectionExec: expr=[true as t.v = t.v, array_transform([1], (v) -> v@-1 IS NOT NULL OR NULL) as array_transform(make_array(Int64(1)),(v) -> v = v)] 02)--DataSourceExec: partitions=1, partition_sizes=[1] query error select array_transform(); ---- -DataFusion error: Error during planning: 'array_transform' does not support zero arguments No function matches the given name and argument types 'array_transform()'. You might need to add explicit type casts. 
- Candidate functions: - array_transform(Any, Any) +DataFusion error: Execution error: array_transform expects a value follewed by a lambda, got [] -query error DataFusion error: Execution error: expected list, got Field \{ name: "Int64\(1\)", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: \{\} \} +query error DataFusion error: Execution error: expected list, got Field \{ "Int64\(1\)": Int64 \} select array_transform(1, v -> v*2); -query error DataFusion error: Execution error: array_transform expects a value follewed by a lambda, got \[Lambda\(\["v"\], false\), Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)\] +query error DataFusion error: Execution error: array_transform expects a value follewed by a lambda, got \[Lambda, Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ data_type: Int64, nullable: true \}\), nullable: true \}\)\] select array_transform(v -> v*2, [1, 2]); -query error DataFusion error: Execution error: lambdas_schemas: array_transform argument 1 \(0\-indexed\), a lambda, supports up to 2 arguments, but got 3 +query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 2 SELECT array_transform([1, 2], (e, i, j) -> i) from t; + +#todo: this should error due to duplicate names +query ? 
+SELECT array_transform([1], (v, v) -> v*2); +---- +[0] diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index 103d593cafbc0..b16fd8032877f 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -153,6 +153,7 @@ pub fn to_substrait_rex( } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 + Expr::LambdaColumn(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 } } From d844b2d81d9ab5c4bc95671e9134ea6475af79bb Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Fri, 19 Dec 2025 05:07:33 -0300 Subject: [PATCH 04/47] rename LambdaColumn to LambdaVariable --- datafusion/catalog-listing/src/helpers.rs | 2 +- datafusion/expr/src/expr.rs | 16 ++++++------- datafusion/expr/src/expr_schema.rs | 8 +++---- datafusion/expr/src/tree_node.rs | 4 ++-- datafusion/expr/src/utils.rs | 2 +- .../functions-nested/src/array_transform.rs | 20 ++++++++-------- .../optimizer/src/analyzer/type_coercion.rs | 2 +- .../optimizer/src/common_subexpr_eliminate.rs | 4 ++-- datafusion/optimizer/src/push_down_filter.rs | 2 +- .../simplify_expressions/expr_simplifier.rs | 20 ++++++++-------- .../{lambda_column.rs => lambda_variable.rs} | 24 +++++++++---------- .../physical-expr/src/expressions/mod.rs | 4 ++-- datafusion/physical-expr/src/planner.rs | 8 +++---- datafusion/proto/src/logical_plan/to_proto.rs | 2 +- datafusion/sql/src/expr/identifier.rs | 8 +++---- datafusion/sql/src/unparser/expr.rs | 2 +- .../src/logical_plan/producer/expr/mod.rs | 2 +- 17 files changed, 65 insertions(+), 65 deletions(-) rename datafusion/physical-expr/src/expressions/{lambda_column.rs => 
lambda_variable.rs} (86%) diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 78b46171006a7..eca681e3c604c 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -88,7 +88,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::GroupingSet(_) | Expr::Case(_) | Expr::Lambda(_) - | Expr::LambdaColumn(_) => Ok(TreeNodeRecursion::Continue), + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match scalar_function.func.signature().volatility { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 6387fc4a44f38..07f1bc129c597 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -400,17 +400,17 @@ pub enum Expr { Unnest(Unnest), /// Lambda expression Lambda(Lambda), - LambdaColumn(LambdaColumn), + LambdaVariable(LambdaVariable), } #[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] -pub struct LambdaColumn { +pub struct LambdaVariable { pub name: String, pub field: FieldRef, pub spans: Spans, } -impl LambdaColumn { +impl LambdaVariable { pub fn new(name: String, field: FieldRef) -> Self { Self { name, @@ -1567,7 +1567,7 @@ impl Expr { Expr::Wildcard { .. } => "Wildcard", Expr::Unnest { .. } => "Unnest", Expr::Lambda { .. } => "Lambda", - Expr::LambdaColumn { .. } => "LambdaColumn", + Expr::LambdaVariable { .. } => "LambdaVariable", } } @@ -2123,7 +2123,7 @@ impl Expr { | Expr::Literal(..) | Expr::Placeholder(..) | Expr::Lambda(..) - | Expr::LambdaColumn(..) => false, + | Expr::LambdaVariable(..) 
=> false, } } @@ -2722,7 +2722,7 @@ impl HashNode for Expr { Expr::Lambda(Lambda { params, body: _ }) => { params.hash(state); } - Expr::LambdaColumn(LambdaColumn { + Expr::LambdaVariable(LambdaVariable { name, field, spans: _, @@ -3046,7 +3046,7 @@ impl Display for SchemaDisplay<'_> { Expr::Lambda(Lambda { params, body }) => { write!(f, "({}) -> {body}", display_comma_separated(params)) } - Expr::LambdaColumn(c) => { + Expr::LambdaVariable(c) => { write!(f, "{}", c.name) } } @@ -3542,7 +3542,7 @@ impl Display for Expr { Expr::Lambda(Lambda { params, body }) => { write!(f, "({}) -> {body}", params.join(", ")) } - Expr::LambdaColumn(c) => { + Expr::LambdaVariable(c) => { write!(f, "{}", c.name) } } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 1b7f5f0212c6b..3e3ff7dacb9d8 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -21,7 +21,7 @@ use crate::expr::{ InSubquery, Lambda, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; -use crate::expr::{FieldMetadata, LambdaColumn}; +use crate::expr::{FieldMetadata, LambdaVariable}; use crate::type_coercion::functions::{ fields_with_aggregate_udf, fields_with_window_udf, }; @@ -235,7 +235,7 @@ impl ExprSchemable for Expr { Ok(DataType::Null) } Expr::Lambda(Lambda { params: _, body }) => body.get_type(schema), - Expr::LambdaColumn(LambdaColumn { name: _, field, .. }) => { + Expr::LambdaVariable(LambdaVariable { name: _, field, .. 
}) => { Ok(field.data_type().clone()) } } @@ -357,7 +357,7 @@ impl ExprSchemable for Expr { Ok(true) } Expr::Lambda(l) => l.body.nullable(input_schema), - Expr::LambdaColumn(c) => Ok(c.field.is_nullable()), + Expr::LambdaVariable(c) => Ok(c.field.is_nullable()), } } @@ -625,7 +625,7 @@ impl ExprSchemable for Expr { self.get_type(schema)?, self.nullable(schema)?, ))), - Expr::LambdaColumn(c) => Ok(Arc::clone(&c.field)), + Expr::LambdaVariable(c) => Ok(Arc::clone(&c.field)), }?; Ok(( diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index df98f720a0f08..a818c32948d09 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -81,7 +81,7 @@ impl TreeNode for Expr { | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) - | Expr::LambdaColumn(_) => Ok(TreeNodeRecursion::Continue), + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { (left, right).apply_ref_elements(f) } @@ -133,7 +133,7 @@ impl TreeNode for Expr { | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) | Expr::Literal(_, _) - | Expr::LambdaColumn(_) => Transformed::no(self), + | Expr::LambdaVariable(_) => Transformed::no(self), Expr::Unnest(Unnest { expr, .. }) => expr .map_elements(f)? .update_data(|expr| Expr::Unnest(Unnest { expr })), diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 1e5e4c9f5b99a..ab58b1c3f835f 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -309,7 +309,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. 
} | Expr::Lambda(_) - | Expr::LambdaColumn(_) => {} + | Expr::LambdaVariable(_) => {} } Ok(TreeNodeRecursion::Continue) }) diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index dbc08473c246b..123df27b339be 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -33,7 +33,7 @@ use datafusion_expr::{ ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, }; use datafusion_macros::user_doc; -use datafusion_physical_expr::expressions::{LambdaColumn, LambdaExpr}; +use datafusion_physical_expr::expressions::{LambdaVariable, LambdaExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::{any::Any, sync::Arc}; @@ -173,7 +173,7 @@ impl ScalarUDFImpl for ArrayTransform { let values_param = || Ok(Arc::clone(list_values)); let indices_param = || elements_indices(&list_array); - let binded_body = bind_lambda_columns( + let binded_body = bind_lambda_variables( Arc::clone(&lambda.body), &lambda.params, &[&values_param, &indices_param], @@ -275,7 +275,7 @@ impl ScalarUDFImpl for ArrayTransform { } } -fn bind_lambda_columns( +fn bind_lambda_variables( expr: Arc, params: &[FieldRef], args: &[&dyn Fn() -> Result], @@ -284,28 +284,28 @@ fn bind_lambda_columns( .map(|(param, arg)| Ok((param.name().as_str(), (arg()?, 0)))) .collect::>>()?; - expr.rewrite(&mut BindLambdaColumn::new(columns)).data() + expr.rewrite(&mut BindLambdaVariable::new(columns)).data() } -struct BindLambdaColumn<'a> { +struct BindLambdaVariable<'a> { columns: HashMap<&'a str, (ArrayRef, usize)>, } -impl<'a> BindLambdaColumn<'a> { +impl<'a> BindLambdaVariable<'a> { fn new(columns: HashMap<&'a str, (ArrayRef, usize)>) -> Self { Self { columns } } } -impl TreeNodeRewriter for BindLambdaColumn<'_> { +impl TreeNodeRewriter for BindLambdaVariable<'_> { type Node = Arc; fn f_down(&mut self, node: Self::Node) -> Result> { - if let 
Some(lambda_column) = node.as_any().downcast_ref::() { - if let Some((value, shadows)) = self.columns.get(lambda_column.name()) { + if let Some(lambda_variable) = node.as_any().downcast_ref::() { + if let Some((value, shadows)) = self.columns.get(lambda_variable.name()) { if *shadows == 0 { return Ok(Transformed::yes(Arc::new( - lambda_column.clone().with_value(value.clone()), + lambda_variable.clone().with_value(value.clone()), ))); } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 763f693f2f607..626c2ba550594 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -599,7 +599,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { | Expr::Placeholder(_) | Expr::OuterReferenceColumn(_, _) | Expr::Lambda(_) - | Expr::LambdaColumn(_) => Ok(Transformed::no(expr)), + | Expr::LambdaVariable(_) => Ok(Transformed::no(expr)), } } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index f95a0f908b813..74e77011b71ed 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -694,7 +694,7 @@ impl CSEController for ExprCSEController<'_> { } fn is_valid(node: &Expr) -> bool { - !node.is_volatile_node() && !matches!(node, Expr::LambdaColumn(_)) + !node.is_volatile_node() && !matches!(node, Expr::LambdaVariable(_)) } fn is_ignored(&self, node: &Expr) -> bool { @@ -707,7 +707,7 @@ impl CSEController for ExprCSEController<'_> { | Expr::ScalarVariable(..) | Expr::Alias(..) | Expr::Wildcard { .. 
} - | Expr::LambdaColumn(_) + | Expr::LambdaVariable(_) ); let is_aggr = matches!(node, Expr::AggregateFunction(..)); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 77c533ce2f01e..b7e8626aa7bd5 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -289,7 +289,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::InList { .. } | Expr::ScalarFunction(_) | Expr::Lambda(_) - | Expr::LambdaColumn(_) => Ok(TreeNodeRecursion::Continue), + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::AggregateFunction(_) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 779c0acea9963..74794115755ba 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -468,8 +468,8 @@ impl TreeNodeRewriter for Canonicalizer { match (left.as_ref(), right.as_ref(), op.swap()) { // ( - left_col @ (Expr::Column(_) | Expr::LambdaColumn(_)), - right_col @ (Expr::Column(_) | Expr::LambdaColumn(_)), + left_col @ (Expr::Column(_) | Expr::LambdaVariable(_)), + right_col @ (Expr::Column(_) | Expr::LambdaVariable(_)), Some(swapped_op), ) if right_col > left_col => { Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { @@ -481,7 +481,7 @@ impl TreeNodeRewriter for Canonicalizer { // ( Expr::Literal(_, _), - Expr::Column(_) | Expr::LambdaColumn(_), + Expr::Column(_) | Expr::LambdaVariable(_), Some(swapped_op), ) => Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, @@ -655,7 +655,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::GroupingSet(_) | Expr::Wildcard { .. 
} | Expr::Placeholder(_) - | Expr::LambdaColumn(_) => false, + | Expr::LambdaVariable(_) => false, Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } @@ -2012,8 +2012,8 @@ fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool { let left = as_inlist(left); let right = as_inlist(right); if let (Some(lhs), Some(rhs)) = (left, right) { - matches!(lhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaColumn(_)) - && matches!(rhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaColumn(_)) + matches!(lhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaVariable(_)) + && matches!(rhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaVariable(_)) && lhs.expr == rhs.expr && !lhs.negated && !rhs.negated @@ -2028,14 +2028,14 @@ fn as_inlist(expr: &'_ Expr) -> Option> { Expr::InList(inlist) => Some(Cow::Borrowed(inlist)), Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { match (left.as_ref(), right.as_ref()) { - (Expr::Column(_) | Expr::LambdaColumn(_), Expr::Literal(_, _)) => { + (Expr::Column(_) | Expr::LambdaVariable(_), Expr::Literal(_, _)) => { Some(Cow::Owned(InList { expr: left.clone(), list: vec![*right.clone()], negated: false, })) } - (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaColumn(_)) => { + (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaVariable(_)) => { Some(Cow::Owned(InList { expr: right.clone(), list: vec![*left.clone()], @@ -2057,14 +2057,14 @@ fn to_inlist(expr: Expr) -> Option { op: Operator::Eq, right, }) => match (left.as_ref(), right.as_ref()) { - (Expr::Column(_) | Expr::LambdaColumn(_), Expr::Literal(_, _)) => { + (Expr::Column(_) | Expr::LambdaVariable(_), Expr::Literal(_, _)) => { Some(InList { expr: left, list: vec![*right], negated: false, }) } - (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaColumn(_)) => { + (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaVariable(_)) => { Some(InList { expr: right, list: vec![*left], diff --git 
a/datafusion/physical-expr/src/expressions/lambda_column.rs b/datafusion/physical-expr/src/expressions/lambda_variable.rs similarity index 86% rename from datafusion/physical-expr/src/expressions/lambda_column.rs rename to datafusion/physical-expr/src/expressions/lambda_variable.rs index 4aed16186ba6f..305774c3c02da 100644 --- a/datafusion/physical-expr/src/expressions/lambda_column.rs +++ b/datafusion/physical-expr/src/expressions/lambda_variable.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Physical lambda column reference: [`LambdaColumn`] +//! Physical lambda column reference: [`LambdaVariable`] use std::any::Any; use std::hash::Hash; @@ -33,28 +33,28 @@ use datafusion_expr::ColumnarValue; /// Represents the lambda column with a given name and field #[derive(Debug, Clone)] -pub struct LambdaColumn { +pub struct LambdaVariable { name: String, field: FieldRef, value: Option, } -impl Eq for LambdaColumn {} +impl Eq for LambdaVariable {} -impl PartialEq for LambdaColumn { +impl PartialEq for LambdaVariable { fn eq(&self, other: &Self) -> bool { self.name == other.name && self.field == other.field } } -impl Hash for LambdaColumn { +impl Hash for LambdaVariable { fn hash(&self, state: &mut H) { self.name.hash(state); self.field.hash(state); } } -impl LambdaColumn { +impl LambdaVariable { /// Create a new lambda column expression pub fn new(name: &str, field: FieldRef) -> Self { Self { @@ -83,13 +83,13 @@ impl LambdaColumn { } } -impl std::fmt::Display for LambdaColumn { +impl std::fmt::Display for LambdaVariable { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{}@-1", self.name) } } -impl PhysicalExpr for LambdaColumn { +impl PhysicalExpr for LambdaVariable { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -107,7 +107,7 @@ impl PhysicalExpr for LambdaColumn { /// Evaluate the expression fn evaluate(&self, _batch: 
&RecordBatch) -> Result { - self.value.clone().ok_or_else(|| exec_datafusion_err!("Physical LambdaColumn {} missing value", self.name)) + self.value.clone().ok_or_else(|| exec_datafusion_err!("Physical LambdaVariable {} missing value", self.name)) } fn return_field(&self, _input_schema: &Schema) -> Result { @@ -130,7 +130,7 @@ impl PhysicalExpr for LambdaColumn { } } -/// Create a column expression -pub fn lambda_col(name: &str, field: FieldRef) -> Result> { - Ok(Arc::new(LambdaColumn::new(name, field))) +/// Create a lambda variable expression +pub fn lambda_variable(name: &str, field: FieldRef) -> Result> { + Ok(Arc::new(LambdaVariable::new(name, field))) } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 5d044ab848550..990e53fa23b2c 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -23,7 +23,7 @@ mod case; mod cast; mod cast_column; mod column; -mod lambda_column; +mod lambda_variable; mod dynamic_filters; mod in_list; mod is_not_null; @@ -46,7 +46,7 @@ pub use case::{case, CaseExpr}; pub use cast::{cast, CastExpr}; pub use cast_column::CastColumnExpr; pub use column::{col, with_new_schema, Column}; -pub use lambda_column::{lambda_col, LambdaColumn}; +pub use lambda_variable::{lambda_variable, LambdaVariable}; pub use datafusion_expr::utils::format_state_name; pub use dynamic_filters::DynamicFilterPhysicalExpr; pub use in_list::{in_list, InListExpr}; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 4c3d1352cce0f..8a53aa81da8fa 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use crate::expressions::{lambda_col, LambdaExpr}; +use crate::expressions::{lambda_variable, LambdaExpr}; use crate::ScalarFunctionExpr; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, @@ 
-33,7 +33,7 @@ use datafusion_common::{ }; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{ - Alias, Cast, InList, Lambda, LambdaColumn, Placeholder, ScalarFunction, + Alias, Cast, InList, Lambda, LambdaVariable, Placeholder, ScalarFunction, }; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; @@ -392,11 +392,11 @@ pub fn create_physical_expr( Expr::Placeholder(Placeholder { id, .. }) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } - Expr::LambdaColumn(LambdaColumn { + Expr::LambdaVariable(LambdaVariable { name, field, spans: _, - }) => lambda_col( + }) => lambda_variable( name, Arc::clone(field), ), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index c6e33159aa89f..41a9172fff276 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -622,7 +622,7 @@ pub fn serialize_expr( .unwrap_or(HashMap::new()), })), }, - Expr::Lambda(_) | Expr::LambdaColumn(_) => { + Expr::Lambda(_) | Expr::LambdaVariable(_) => { return Err(Error::General( "Proto serialization error: Lambda not implemented".to_string(), )) diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 73be980d686d0..14433e9cf7eba 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -20,7 +20,7 @@ use datafusion_common::{ exec_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, Result, Span, TableReference, }; -use datafusion_expr::expr::LambdaColumn; +use datafusion_expr::expr::LambdaVariable; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{Case, Expr}; use sqlparser::ast::{CaseWhen, Expr as SQLExpr, Ident}; @@ -58,13 +58,13 @@ impl SqlToRel<'_, S> { .lambdas_parameters() .get(&normalize_ident) { - let mut lambda_column = 
LambdaColumn::new(normalize_ident, Arc::clone(field)); + let mut lambda_var = LambdaVariable::new(normalize_ident, Arc::clone(field)); if self.options.collect_spans { if let Some(span) = Span::try_from_sqlparser_span(id_span) { - lambda_column.spans_mut().add_span(span); + lambda_var.spans_mut().add_span(span); } } - return Ok(Expr::LambdaColumn(lambda_column)); + return Ok(Expr::LambdaVariable(lambda_var)); } // Check for qualified field with unqualified name diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 71a1a342a9c5e..013a6f1128957 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -536,7 +536,7 @@ impl Unparser<'_> { body: Box::new(self.expr_to_sql_inner(body)?), })) } - Expr::LambdaColumn(l) => Ok(ast::Expr::Identifier( + Expr::LambdaVariable(l) => Ok(ast::Expr::Identifier( self.new_ident_quoted_if_needs(l.name.clone()), )), } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index b16fd8032877f..b3ca88a690291 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -153,7 +153,7 @@ pub fn to_substrait_rex( } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 - Expr::LambdaColumn(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 + Expr::LambdaVariable(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 } } From e1921eb377a56766b2dd4729311def1c3515c8c4 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 23 Feb 2026 14:18:01 -0300 Subject: [PATCH 05/47] feat: add LambdaUDF --- 
datafusion-examples/examples/sql_frontend.rs | 8 +- datafusion/catalog-listing/src/helpers.rs | 18 +- .../core/src/bin/print_functions_docs.rs | 15 +- datafusion/core/src/execution/context/mod.rs | 16 + .../core/src/execution/session_state.rs | 87 ++- .../src/execution/session_state_defaults.rs | 8 +- datafusion/core/tests/optimizer/mod.rs | 7 +- .../datasource-arrow/src/file_format.rs | 6 +- datafusion/datasource/src/url.rs | 6 +- datafusion/execution/src/task.rs | 41 +- datafusion/expr/src/expr.rs | 49 ++ datafusion/expr/src/expr_schema.rs | 39 ++ datafusion/expr/src/lib.rs | 8 +- datafusion/expr/src/planner.rs | 6 +- datafusion/expr/src/registry.rs | 48 ++ datafusion/expr/src/tree_node.rs | 8 +- datafusion/expr/src/udf_eq.rs | 14 +- datafusion/expr/src/udlf.rs | 649 ++++++++++++++++++ datafusion/expr/src/utils.rs | 1 + .../functions-nested/src/array_transform.rs | 52 +- datafusion/functions-nested/src/lib.rs | 4 +- .../optimizer/src/analyzer/type_coercion.rs | 1 + datafusion/optimizer/src/push_down_filter.rs | 1 + .../simplify_expressions/expr_simplifier.rs | 4 + .../optimizer/tests/optimizer_integration.rs | 4 + .../physical-expr/src/lambda_function.rs | 530 ++++++++++++++ datafusion/physical-expr/src/lib.rs | 2 + datafusion/physical-expr/src/planner.rs | 28 +- datafusion/proto/src/bytes/mod.rs | 60 +- datafusion/proto/src/bytes/registry.rs | 10 + datafusion/proto/src/logical_plan/mod.rs | 10 +- datafusion/proto/src/logical_plan/to_proto.rs | 11 + datafusion/session/src/session.rs | 6 +- datafusion/sql/examples/sql.rs | 6 +- datafusion/sql/src/expr/function.rs | 57 +- datafusion/sql/src/expr/mod.rs | 6 +- datafusion/sql/src/unparser/expr.rs | 18 +- datafusion/sql/tests/common/mod.rs | 7 +- .../src/logical_plan/producer/expr/mod.rs | 5 +- .../producer/expr/scalar_function.rs | 23 +- .../producer/substrait_producer.rs | 15 +- 41 files changed, 1812 insertions(+), 82 deletions(-) create mode 100644 datafusion/expr/src/udlf.rs create mode 100644 
datafusion/physical-expr/src/lambda_function.rs diff --git a/datafusion-examples/examples/sql_frontend.rs b/datafusion-examples/examples/sql_frontend.rs index 1fc9ce24ecbb5..e6341080a9e11 100644 --- a/datafusion-examples/examples/sql_frontend.rs +++ b/datafusion-examples/examples/sql_frontend.rs @@ -20,8 +20,8 @@ use datafusion::common::{plan_err, TableReference}; use datafusion::config::ConfigOptions; use datafusion::error::Result; use datafusion::logical_expr::{ - AggregateUDF, Expr, LogicalPlan, ScalarUDF, TableProviderFilterPushDown, TableSource, - WindowUDF, + AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, TableProviderFilterPushDown, + TableSource, WindowUDF, }; use datafusion::optimizer::{ Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule, @@ -153,6 +153,10 @@ impl ContextProvider for MyContextProvider { None } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, _name: &str) -> Option> { None } diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index eca681e3c604c..34fee4eb6bd41 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -100,6 +100,16 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { } } } + Expr::LambdaFunction(lambda_function) => { + match lambda_function.func.signature().volatility { + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(TreeNodeRecursion::Stop) + } + } + } // TODO other expressions are not handled yet: // - AGGREGATE and WINDOW should not end up in filter conditions, except maybe in some edge cases @@ -555,7 +565,7 @@ mod tests { use super::*; use datafusion_expr::{ - case, col, lit, AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF, + 
case, col, lit, AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF, }; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; @@ -1066,6 +1076,12 @@ mod tests { unimplemented!() } + fn lambda_functions( + &self, + ) -> &std::collections::HashMap> { + unimplemented!() + } + fn aggregate_functions( &self, ) -> &std::collections::HashMap> { diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index 63387c023b11a..97282edf49381 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -18,8 +18,7 @@ use datafusion::execution::SessionStateDefaults; use datafusion_common::{not_impl_err, HashSet, Result}; use datafusion_expr::{ - aggregate_doc_sections, scalar_doc_sections, window_doc_sections, AggregateUDF, - DocSection, Documentation, ScalarUDF, WindowUDF, + AggregateUDF, DocSection, Documentation, LambdaUDF, ScalarUDF, WindowUDF, aggregate_doc_sections, scalar_doc_sections, window_doc_sections }; use itertools::Itertools; use std::env::args; @@ -303,6 +302,18 @@ impl DocProvider for WindowUDF { } } +impl DocProvider for dyn LambdaUDF { + fn get_name(&self) -> String { + self.name().to_string() + } + fn get_aliases(&self) -> Vec { + self.aliases().iter().map(|a| a.to_string()).collect() + } + fn get_documentation(&self) -> Option<&Documentation> { + self.documentation() + } +} + #[allow(clippy::borrowed_box)] #[allow(clippy::ptr_arg)] fn get_names_and_aliases(functions: &Vec<&Box>) -> Vec { diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 687779787ab50..083ecdaf575af 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -75,6 +75,7 @@ use datafusion_common::{ pub use datafusion_execution::config::SessionConfig; use datafusion_execution::registry::SerializerRegistry; 
pub use datafusion_execution::TaskContext; +use datafusion_expr::LambdaUDF; pub use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{ expr_rewriter::FunctionRewrite, @@ -1786,6 +1787,21 @@ impl FunctionRegistry for SessionContext { fn udwfs(&self) -> HashSet { self.state.read().udwfs() } + + fn udlfs(&self) -> HashSet { + self.state.read().udlfs() + } + + fn udlf(&self, name: &str) -> Result> { + self.state.read().udlf(name) + } + + fn register_udlf( + &mut self, + udlf: Arc, + ) -> Result>> { + self.state.write().register_udlf(udlf) + } } /// Create a new task context instance from SessionContext diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index c15b7eae08432..f33f3d3412f4d 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -59,7 +59,7 @@ use datafusion_expr::simplify::SimplifyInfo; #[cfg(feature = "sql")] use datafusion_expr::TableSource; use datafusion_expr::{ - AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, WindowUDF, + AggregateUDF, Explain, Expr, ExprSchemable, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF }; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{ @@ -154,6 +154,8 @@ pub struct SessionState { table_functions: HashMap>, /// Scalar functions that are registered with the context scalar_functions: HashMap>, + /// Lambda functions that are registered with the context + lambda_functions: HashMap>, /// Aggregate functions registered in the context aggregate_functions: HashMap>, /// Window functions registered in the context @@ -252,6 +254,10 @@ impl Session for SessionState { fn scalar_functions(&self) -> &HashMap> { &self.scalar_functions } + + fn lambda_functions(&self) -> &HashMap> { + &self.lambda_functions + } fn aggregate_functions(&self) -> &HashMap> { &self.aggregate_functions @@ -921,6 +927,7 @@ pub struct SessionStateBuilder { 
catalog_list: Option>, table_functions: Option>>, scalar_functions: Option>>, + lambda_functions: Option>>, aggregate_functions: Option>>, window_functions: Option>>, serializer_registry: Option>, @@ -958,6 +965,7 @@ impl SessionStateBuilder { catalog_list: None, table_functions: None, scalar_functions: None, + lambda_functions: None, aggregate_functions: None, window_functions: None, serializer_registry: None, @@ -1008,6 +1016,7 @@ impl SessionStateBuilder { catalog_list: Some(existing.catalog_list), table_functions: Some(existing.table_functions), scalar_functions: Some(existing.scalar_functions.into_values().collect_vec()), + lambda_functions: Some(existing.lambda_functions.into_values().collect_vec()), aggregate_functions: Some( existing.aggregate_functions.into_values().collect_vec(), ), @@ -1048,6 +1057,10 @@ impl SessionStateBuilder { self.scalar_functions .get_or_insert_with(Vec::new) .extend(SessionStateDefaults::default_scalar_functions()); + + self.lambda_functions + .get_or_insert_with(Vec::new) + .extend(SessionStateDefaults::default_lambda_functions()); self.aggregate_functions .get_or_insert_with(Vec::new) @@ -1362,6 +1375,7 @@ impl SessionStateBuilder { catalog_list, table_functions, scalar_functions, + lambda_functions, aggregate_functions, window_functions, serializer_registry, @@ -1395,6 +1409,7 @@ impl SessionStateBuilder { }), table_functions: table_functions.unwrap_or_default(), scalar_functions: HashMap::new(), + lambda_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), serializer_registry: serializer_registry @@ -1446,6 +1461,34 @@ impl SessionStateBuilder { } } } + + if let Some(lambda_functions) = lambda_functions { + for udlf in lambda_functions { + let config_options = state.config().options(); + match udlf.with_updated_config(config_options) { + Some(new_udf) => { + if let Err(err) = state.register_udlf(new_udf) { + debug!( + "Failed to re-register updated UDLF '{}': {}", + udlf.name(), + 
err + ); + } + } + None => match state.register_udlf(Arc::clone(&udlf)) { + Ok(Some(existing)) => { + debug!("Overwrote existing UDLF '{}'", existing.name()); + } + Ok(None) => { + debug!("Registered UDLF '{}'", udlf.name()); + } + Err(err) => { + debug!("Failed to register UDLF '{}': {}", udlf.name(), err); + } + }, + } + } + } if let Some(aggregate_functions) = aggregate_functions { aggregate_functions.into_iter().for_each(|udaf| { @@ -1661,6 +1704,7 @@ impl Debug for SessionStateBuilder { .field("physical_optimizers", &self.physical_optimizers) .field("table_functions", &self.table_functions) .field("scalar_functions", &self.scalar_functions) + .field("lambda_functions", &self.lambda_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) .finish() @@ -1755,6 +1799,10 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.scalar_functions().get(name).cloned() } + fn get_lambda_meta(&self, name: &str) -> Option> { + self.state.lambda_functions().get(name).cloned() + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.state.aggregate_functions().get(name).cloned() } @@ -1918,6 +1966,37 @@ impl FunctionRegistry for SessionState { Ok(udwf) } + fn udlfs(&self) -> HashSet { + self.lambda_functions.keys().cloned().collect() + } + + fn udlf(&self, name: &str) -> datafusion_common::Result> { + self.lambda_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + } + + fn register_udlf( + &mut self, + udlf: Arc, + ) -> datafusion_common::Result>> { + Ok(self.lambda_functions.insert(udlf.name().into(), udlf)) + } + + fn deregister_udlf( + &mut self, + name: &str, + ) -> datafusion_common::Result>> { + let udlf = self.lambda_functions.remove(name); + if let Some(udlf) = &udlf { + for alias in udlf.aliases() { + self.lambda_functions.remove(alias); + } + } + Ok(udlf) + } + fn register_function_rewrite( &mut self, rewrite: Arc, @@ 
-1974,6 +2053,7 @@ impl From<&SessionState> for TaskContext { state.session_id.clone(), state.config.clone(), state.scalar_functions.clone(), + state.lambda_functions.clone(), state.aggregate_functions.clone(), state.window_functions.clone(), Arc::clone(&state.runtime_env), @@ -2062,6 +2142,7 @@ mod tests { use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_optimizer::Optimizer; use datafusion_physical_plan::display::DisplayableExecutionPlan; + use datafusion_session::Session; use datafusion_sql::planner::{PlannerContext, SqlToRel}; use std::collections::HashMap; use std::sync::Arc; @@ -2338,6 +2419,10 @@ mod tests { self.state.scalar_functions().get(name).cloned() } + fn get_lambda_meta(&self, name: &str) -> Option> { + self.state.lambda_functions().get(name).cloned() + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.state.aggregate_functions().get(name).cloned() } diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index 62a575541a5d8..54037c0a96f9c 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -36,7 +36,8 @@ use datafusion_execution::config::SessionConfig; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use datafusion_expr::{AggregateUDF, LambdaUDF, ScalarUDF, WindowUDF}; +use datafusion_functions_nested::array_transform::ArrayTransform; use std::collections::HashMap; use std::sync::Arc; use url::Url; @@ -112,6 +113,11 @@ impl SessionStateDefaults { functions } + /// returns the list of default [`LambdaUDF`]s + pub fn default_lambda_functions() -> Vec> { + vec![Arc::new(ArrayTransform::new())] + } + /// returns the list of default [`AggregateUDF`]s pub fn default_aggregate_functions() -> Vec> { 
functions_aggregate::all_default_aggregate_functions() diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 9b2a5596827d0..44e40143fe171 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -31,8 +31,7 @@ use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{plan_err, DFSchema, Result, ScalarValue, TableReference}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{ - col, lit, AggregateUDF, BinaryExpr, Expr, ExprSchemable, LogicalPlan, Operator, - ScalarUDF, TableSource, WindowUDF, + AggregateUDF, BinaryExpr, Expr, ExprSchemable, LambdaUDF, LogicalPlan, Operator, ScalarUDF, TableSource, WindowUDF, col, lit }; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_optimizer::analyzer::Analyzer; @@ -217,6 +216,10 @@ impl ContextProvider for MyContextProvider { self.udfs.get(name).cloned() } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, _name: &str) -> Option> { None } diff --git a/datafusion/datasource-arrow/src/file_format.rs b/datafusion/datasource-arrow/src/file_format.rs index 3b85640804219..31a880e688aeb 100644 --- a/datafusion/datasource-arrow/src/file_format.rs +++ b/datafusion/datasource-arrow/src/file_format.rs @@ -442,7 +442,7 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; + use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path}; @@ -488,6 +488,10 @@ mod tests { fn scalar_functions(&self) -> &HashMap> { unimplemented!() } + + fn lambda_functions(&self) -> 
&HashMap> { + unimplemented!() + } fn aggregate_functions(&self) -> &HashMap> { unimplemented!() diff --git a/datafusion/datasource/src/url.rs b/datafusion/datasource/src/url.rs index 08e5b6a5df83a..b0b84bd3cc943 100644 --- a/datafusion/datasource/src/url.rs +++ b/datafusion/datasource/src/url.rs @@ -415,7 +415,7 @@ mod tests { use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; + use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use object_store::{ @@ -874,6 +874,10 @@ mod tests { unimplemented!() } + fn lambda_functions(&self) -> &HashMap> { + unimplemented!() + } + fn aggregate_functions(&self) -> &HashMap> { unimplemented!() } diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index c2a6cfe2c833f..70c59b6375943 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -21,7 +21,7 @@ use crate::{ }; use datafusion_common::{internal_datafusion_err, plan_datafusion_err, Result}; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use datafusion_expr::{AggregateUDF, ScalarUDF, LambdaUDF, WindowUDF}; use std::collections::HashSet; use std::{collections::HashMap, sync::Arc}; @@ -42,6 +42,8 @@ pub struct TaskContext { session_config: SessionConfig, /// Scalar functions associated with this task context scalar_functions: HashMap>, + /// Lambda functions associated with this task context + lambda_functions: HashMap>, /// Aggregate functions associated with this task context aggregate_functions: HashMap>, /// Window functions associated with this task context @@ -60,6 +62,7 @@ impl Default for TaskContext { task_id: None, session_config: 
SessionConfig::new(), scalar_functions: HashMap::new(), + lambda_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), runtime, @@ -73,11 +76,13 @@ impl TaskContext { /// Most users will use [`SessionContext::task_ctx`] to create [`TaskContext`]s /// /// [`SessionContext::task_ctx`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.task_ctx + #[allow(clippy::too_many_arguments)] pub fn new( task_id: Option, session_id: String, session_config: SessionConfig, scalar_functions: HashMap>, + lambda_functions: HashMap>, aggregate_functions: HashMap>, window_functions: HashMap>, runtime: Arc, @@ -87,6 +92,7 @@ impl TaskContext { session_id, session_config, scalar_functions, + lambda_functions, aggregate_functions, window_functions, runtime, @@ -198,6 +204,37 @@ impl FunctionRegistry for TaskContext { Ok(self.scalar_functions.insert(udf.name().into(), udf)) } + fn udlfs(&self) -> HashSet { + self.lambda_functions.keys().cloned().collect() + } + + fn udlf(&self, name: &str) -> Result> { + self.lambda_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + } + + fn register_udlf( + &mut self, + udlf: Arc, + ) -> Result>> { + Ok(self.lambda_functions.insert(udlf.name().into(), udlf)) + } + + fn deregister_udlf( + &mut self, + name: &str, + ) -> Result>> { + let udlf = self.lambda_functions.remove(name); + if let Some(udlf) = &udlf { + for alias in udlf.aliases() { + self.lambda_functions.remove(alias); + } + } + Ok(udlf) + } + fn expr_planners(&self) -> Vec> { vec![] } @@ -248,6 +285,7 @@ mod tests { HashMap::default(), HashMap::default(), HashMap::default(), + HashMap::default(), runtime, ); @@ -280,6 +318,7 @@ mod tests { HashMap::default(), HashMap::default(), HashMap::default(), + HashMap::default(), runtime, ); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 07f1bc129c597..ef4ba4ef5cdfd 100644 --- 
a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -27,6 +27,7 @@ use std::sync::Arc; use crate::expr_fn::binary_expr; use crate::function::WindowFunctionSimplification; use crate::logical_plan::Subquery; +use crate::udlf::LambdaUDF; use crate::{AggregateUDF, Volatility}; use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; @@ -398,11 +399,41 @@ pub enum Expr { OuterReferenceColumn(FieldRef, Column), /// Unnest expression Unnest(Unnest), + LambdaFunction(LambdaFunction), /// Lambda expression Lambda(Lambda), LambdaVariable(LambdaVariable), } +#[derive(Clone, Eq, PartialOrd, Debug)] +pub struct LambdaFunction { + pub func: Arc, + pub args: Vec, +} + +impl LambdaFunction { + pub fn new(func: Arc, args: Vec) -> Self { + Self { func, args } + } + + pub fn name(&self) -> &str { + self.func.name() + } +} + +impl Hash for LambdaFunction { + fn hash(&self, state: &mut H) { + self.func.hash(state); + self.args.hash(state); + } +} + +impl PartialEq for LambdaFunction { + fn eq(&self, other: &Self) -> bool { + self.func.as_ref() == other.func.as_ref() && self.args == other.args + } +} + #[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] pub struct LambdaVariable { pub name: String, @@ -1566,6 +1597,7 @@ impl Expr { #[expect(deprecated)] Expr::Wildcard { .. } => "Wildcard", Expr::Unnest { .. } => "Unnest", + Expr::LambdaFunction { .. } => "LambdaFunction", Expr::Lambda { .. } => "Lambda", Expr::LambdaVariable { .. } => "LambdaVariable", } @@ -2083,6 +2115,7 @@ impl Expr { pub fn short_circuits(&self) -> bool { match self { Expr::ScalarFunction(ScalarFunction { func, .. }) => func.short_circuits(), + Expr::LambdaFunction(LambdaFunction { func, .. }) => func.short_circuits(), Expr::BinaryExpr(BinaryExpr { op, .. 
}) => { matches!(op, Operator::And | Operator::Or) } @@ -2719,6 +2752,9 @@ impl HashNode for Expr { column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} + Expr::LambdaFunction(LambdaFunction { func, args: _args }) => { + func.hash(state); + } Expr::Lambda(Lambda { params, body: _ }) => { params.hash(state); } @@ -3043,6 +3079,16 @@ impl Display for SchemaDisplay<'_> { } } } + Expr::LambdaFunction(LambdaFunction { func, args }) => { + match func.schema_name(args) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!(f, "got error from schema_name {e}") + } + } + } Expr::Lambda(Lambda { params, body }) => { write!(f, "({}) -> {body}", display_comma_separated(params)) } @@ -3539,6 +3585,9 @@ impl Display for Expr { Expr::Unnest(Unnest { expr }) => { write!(f, "{UNNEST_COLUMN_PREFIX}({expr})") } + Expr::LambdaFunction(fun) => { + fmt_function(f, fun.name(), false, &fun.args, true) + } Expr::Lambda(Lambda { params, body }) => { write!(f, "({}) -> {body}", params.join(", ")) } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 3e3ff7dacb9d8..6a3a7cbc85e76 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -25,6 +25,7 @@ use crate::expr::{FieldMetadata, LambdaVariable}; use crate::type_coercion::functions::{ fields_with_aggregate_udf, fields_with_window_udf, }; +use crate::udlf::{LambdaReturnFieldArgs, ValueOrLambdaField}; use crate::{ type_coercion::functions::data_types_with_scalar_udf, udf::ReturnFieldArgs, utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition, @@ -234,6 +235,10 @@ impl ExprSchemable for Expr { // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } + Expr::LambdaFunction(_func) => { + let (return_type, _) = self.data_type_and_nullable(schema)?; + Ok(return_type) + } Expr::Lambda(Lambda { params: _, body }) => body.get_type(schema), Expr::LambdaVariable(LambdaVariable { name: _, field, .. 
}) => { Ok(field.data_type().clone()) @@ -356,6 +361,10 @@ impl ExprSchemable for Expr { // in projections Ok(true) } + Expr::LambdaFunction(_func) => { + let (_, nullable) = self.data_type_and_nullable(input_schema)?; + Ok(nullable) + } Expr::Lambda(l) => l.body.nullable(input_schema), Expr::LambdaVariable(c) => Ok(c.field.is_nullable()), } @@ -625,6 +634,36 @@ impl ExprSchemable for Expr { self.get_type(schema)?, self.nullable(schema)?, ))), + Expr::LambdaFunction(func) => { + let arg_fields = func + .args + .iter() + .map(|arg| { + let field = arg.to_field(schema)?.1; + match arg { + Expr::Lambda(_lambda) => { + Ok(ValueOrLambdaField::Lambda(field)) + } + _ => Ok(ValueOrLambdaField::Value(field)), + } + }) + .collect::>>()?; + + let arguments = func.args + .iter() + .map(|e| match e { + Expr::Literal(sv, _) => Some(sv), + _ => None, + }) + .collect::>(); + + let args = LambdaReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &arguments, + }; + + func.func.return_field_from_args(args) + } Expr::LambdaVariable(c) => Ok(Arc::clone(&c.field)), }?; diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 0f26218e74779..e6dfd9fc1483b 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -42,6 +42,7 @@ mod partition_evaluator; mod table_source; mod udaf; mod udf; +mod udlf; mod udwf; pub mod arguments; @@ -117,9 +118,10 @@ pub use udaf::{ udaf_default_window_function_schema_name, AggregateUDF, AggregateUDFImpl, ReversedUDAF, SetMonotonicity, StatisticsArgs, }; -pub use udf::{ - ReturnFieldArgs, ScalarFunctionArgs, ScalarFunctionLambdaArg, ScalarUDF, - ScalarUDFImpl, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, +pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +pub use udlf::{ + LambdaFunctionArgs, LambdaFunctionLambdaArg, LambdaReturnFieldArgs, LambdaUDF, + ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, }; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, 
WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 25a0f83947eee..7696faca0922a 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -22,8 +22,7 @@ use std::sync::Arc; use crate::expr::NullTreatment; use crate::{ - AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource, WindowFrame, - WindowFunctionDefinition, WindowUDF, + AggregateUDF, Expr, GetFieldAccess, LambdaUDF, ScalarUDF, SortExpr, TableSource, WindowFrame, WindowFunctionDefinition, WindowUDF }; use arrow::datatypes::{DataType, Field, SchemaRef}; use datafusion_common::{ @@ -91,6 +90,9 @@ pub trait ContextProvider { /// Return the scalar function with a given name, if any fn get_function_meta(&self, name: &str) -> Option>; + + /// Return the lambda function with a given name, if any + fn get_lambda_meta(&self, name: &str) -> Option>; /// Return the aggregate function with a given name, if any fn get_aggregate_meta(&self, name: &str) -> Option>; diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 9554dd68e1758..92aa39d64c98d 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -19,6 +19,7 @@ use crate::expr_rewriter::FunctionRewrite; use crate::planner::ExprPlanner; +use crate::udlf::LambdaUDF; use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result}; use std::collections::HashSet; @@ -30,6 +31,9 @@ pub trait FunctionRegistry { /// Returns names of all available scalar user defined functions. fn udfs(&self) -> HashSet; + /// Returns names of all available lambda user defined functions. + fn udlfs(&self) -> HashSet; + /// Returns names of all available aggregate user defined functions. fn udafs(&self) -> HashSet; @@ -40,6 +44,10 @@ pub trait FunctionRegistry { /// `name`. 
fn udf(&self, name: &str) -> Result>; + /// Returns a reference to the user defined lambda function (udf) named + /// `name`. + fn udlf(&self, name: &str) -> Result>; + /// Returns a reference to the user defined aggregate function (udaf) named /// `name`. fn udaf(&self, name: &str) -> Result>; @@ -56,6 +64,17 @@ pub trait FunctionRegistry { fn register_udf(&mut self, _udf: Arc) -> Result>> { not_impl_err!("Registering ScalarUDF") } + /// Registers a new [`LambdaUDF`], returning any previously registered + /// implementation. + /// + /// Returns an error (the default) if the function can not be registered, + /// for example if the registry is read only. + fn register_udlf( + &mut self, + _udlf: Arc, + ) -> Result>> { + not_impl_err!("Registering LambdaUDF") + } /// Registers a new [`AggregateUDF`], returning any previously registered /// implementation. /// @@ -85,6 +104,15 @@ pub trait FunctionRegistry { not_impl_err!("Deregistering ScalarUDF") } + /// Deregisters a [`LambdaUDF`], returning the implementation that was + /// deregistered. + /// + /// Returns an error (the default) if the function can not be deregistered, + /// for example if the registry is read only. + fn deregister_udlf(&mut self, _name: &str) -> Result>> { + not_impl_err!("Deregistering LambdaUDF") + } + /// Deregisters a [`AggregateUDF`], returning the implementation that was /// deregistered. 
/// @@ -152,6 +180,8 @@ pub trait SerializerRegistry: Debug + Send + Sync { pub struct MemoryFunctionRegistry { /// Scalar Functions udfs: HashMap>, + /// Lambda Functions + udlfs: HashMap>, /// Aggregate Functions udafs: HashMap>, /// Window Functions @@ -214,4 +244,22 @@ impl FunctionRegistry for MemoryFunctionRegistry { fn udwfs(&self) -> HashSet { self.udwfs.keys().cloned().collect() } + + fn udlfs(&self) -> HashSet { + self.udlfs.keys().cloned().collect() + } + + fn udlf(&self, name: &str) -> Result> { + self.udlfs + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + } + + fn register_udlf( + &mut self, + udlf: Arc, + ) -> Result>> { + Ok(self.udlfs.insert(udlf.name().into(), udlf)) + } } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index a818c32948d09..82179f095937b 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -20,8 +20,8 @@ use crate::{ expr::{ AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, - Cast, GroupingSet, InList, InSubquery, Lambda, Like, Placeholder, ScalarFunction, - TryCast, Unnest, WindowFunction, WindowFunctionParams, + Cast, GroupingSet, InList, InSubquery, Lambda, LambdaFunction, Like, Placeholder, + ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }, Expr, }; @@ -110,6 +110,7 @@ impl TreeNode for Expr { Expr::InList(InList { expr, list, .. }) => { (expr, list).apply_ref_elements(f) } + Expr::LambdaFunction(LambdaFunction { func: _, args}) => args.apply_elements(f), Expr::Lambda (Lambda{ params: _, body}) => body.apply_elements(f) } } @@ -317,6 +318,9 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_list)| { Expr::InList(InList::new(new_expr, new_list, negated)) }), + Expr::LambdaFunction(LambdaFunction { func, args }) => args + .map_elements(f)? 
+ .update_data(|args| Expr::LambdaFunction(LambdaFunction { func, args })), Expr::Lambda(Lambda { params, body }) => body .map_elements(f)? .update_data(|body| Expr::Lambda(Lambda { params, body })), diff --git a/datafusion/expr/src/udf_eq.rs b/datafusion/expr/src/udf_eq.rs index 6664495267129..a003f05c15b5e 100644 --- a/datafusion/expr/src/udf_eq.rs +++ b/datafusion/expr/src/udf_eq.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::{AggregateUDFImpl, ScalarUDFImpl, WindowUDFImpl}; +use crate::{AggregateUDFImpl, LambdaUDF, ScalarUDFImpl, WindowUDFImpl}; use std::fmt::Debug; use std::hash::{DefaultHasher, Hash, Hasher}; use std::ops::Deref; @@ -93,6 +93,18 @@ impl UdfPointer for Arc { } } +impl UdfPointer for Arc { + fn equals(&self, other: &Self::Target) -> bool { + self.as_ref().dyn_eq(other.as_any()) + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.as_ref().dyn_hash(hasher); + hasher.finish() + } +} + impl UdfPointer for Arc { fn equals(&self, other: &(dyn AggregateUDFImpl + '_)) -> bool { self.as_ref().dyn_eq(other.as_any()) diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs new file mode 100644 index 0000000000000..84f9494d2edd8 --- /dev/null +++ b/datafusion/expr/src/udlf.rs @@ -0,0 +1,649 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`LambdaUDF`]: Lambda User Defined Functions + +use crate::expr::schema_name_from_exprs_comma_separated_without_space; +use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; +use crate::sort_properties::{ExprProperties, SortProperties}; +use crate::{ColumnarValue, Documentation, Expr}; +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{Result, ScalarValue, not_impl_err}; +use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; +use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::signature::Signature; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::any::Any; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +impl PartialEq for dyn LambdaUDF { + fn eq(&self, other: &Self) -> bool { + self.dyn_eq(other.as_any()) + } +} + +impl PartialOrd for dyn LambdaUDF { + fn partial_cmp(&self, other: &Self) -> Option { + let mut cmp = self.name().cmp(other.name()); + if cmp == Ordering::Equal { + cmp = self.signature().partial_cmp(other.signature())?; + } + if cmp == Ordering::Equal { + cmp = self.aliases().partial_cmp(other.aliases())?; + } + // Contract for PartialOrd and PartialEq consistency requires that + // a == b if and only if partial_cmp(a, b) == Some(Equal). + if cmp == Ordering::Equal && self != other { + // Functions may have other properties besides name and signature + // that differentiate two instances (e.g. 
type, or arbitrary parameters). + // We cannot return Some(Equal) in such case. + return None; + } + debug_assert!( + cmp == Ordering::Equal || self != other, + "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \ + The functions compare as equal, but they are not equal based on general properties that \ + the PartialOrd implementation observes,", + self.name(), other.name() + ); + Some(cmp) + } +} + +impl Eq for dyn LambdaUDF {} + +impl Hash for dyn LambdaUDF { + fn hash(&self, state: &mut H) { + self.dyn_hash(state) + } +} + +#[derive(Clone, Debug)] +pub enum ValueOrLambdaParameter { + /// A columnar value with the given field + Value(FieldRef), + /// A lambda + Lambda, +} + +/// Arguments passed to [`LambdaUDF::invoke_with_args`] when invoking a +/// lambda function. +#[derive(Debug, Clone)] +pub struct LambdaFunctionArgs { + /// The evaluated arguments to the function + pub args: Vec, + /// Field associated with each arg, if it exists + pub arg_fields: Vec, + /// The number of rows in record batch being evaluated + pub number_rows: usize, + /// The return field of the lambda function returned (from `return_type` + /// or `return_field_from_args`) when creating the physical expression + /// from the logical expression + pub return_field: FieldRef, + /// The config options at execution time + pub config_options: Arc, +} + +/// A lambda argument to a LambdaFunction +#[derive(Clone, Debug)] +pub struct LambdaFunctionLambdaArg { + /// The parameters defined in this lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be `vec![Field::new("v", DataType::Int32, true)]` + pub params: Vec, + /// The body of the lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be the physical expression of `-v` + pub body: Arc, + /// A RecordBatch containing at least the captured columns inside this lambda body, if any + /// Note that it may contain additional, non-specified columns, 
but that's an implementation detail + /// + /// For example, for `array_transform([2], v -> v + a + b)`, + /// this will be a `RecordBatch` with two columns, `a` and `b` + pub captures: Option, +} + +impl LambdaFunctionArgs { + /// The return type of the function. See [`Self::return_field`] for more + /// details. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } +} + +/// An argument to a LambdaUDF that supports lambdas +#[derive(Clone, Debug)] +pub enum ValueOrLambda { + Value(ColumnarValue), + Lambda(LambdaFunctionLambdaArg), +} + +/// Information about arguments passed to the function +/// +/// This structure contains metadata about how the function was called +/// such as the type of the arguments, any scalar arguments and if the +/// arguments can (ever) be null +/// +/// See [`LambdaUDF::return_field_from_args`] for more information +#[derive(Clone, Debug)] +pub struct LambdaReturnFieldArgs<'a> { + /// The data types of the arguments to the function + /// + /// If argument `i` to the function is a lambda, it will be the field returned by the + /// lambda when executed with the arguments returned from `LambdaUDF::lambdas_parameters` + /// + /// For example, with `array_transform([1], v -> v == 5)` + /// this field will be `[Field::new("", DataType::List(DataType::Int32), false), Field::new("", DataType::Boolean, false)]` + pub arg_fields: &'a [ValueOrLambdaField], + /// Is argument `i` to the function a scalar (constant)? 
+ /// + /// If the argument `i` is not a scalar, it will be None + /// + /// For example, if a function is called like `my_function(column_a, 5)` + /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` + pub scalar_arguments: &'a [Option<&'a ScalarValue>], +} + +/// A tagged Field indicating whether it corresponds to a value or a lambda argument +#[derive(Clone, Debug)] +pub enum ValueOrLambdaField { + /// The Field of a ColumnarValue argument + Value(FieldRef), + /// The Field of the return of the lambda body when evaluated with the parameters from LambdaUDF::lambda_parameters + Lambda(FieldRef), +} + +/// Trait for implementing user defined lambda functions. +/// +/// This trait exposes the full API for implementing user defined functions and +/// can be used to implement any function. +/// +/// See [`advanced_udf.rs`] for a full example with complete implementation and +/// [`LambdaUDF`] for other available options. +/// +/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +/// +/// # Basic Example +/// ``` +/// # use std::any::Any; +/// # use std::sync::LazyLock; +/// # use arrow::datatypes::DataType; +/// # use datafusion_common::{DataFusionError, plan_err, Result}; +/// # use datafusion_expr::{col, ColumnarValue, Documentation, LambdaFunctionArgs, Signature, Volatility}; +/// # use datafusion_expr::LambdaUDF; +/// # use datafusion_expr::lambda_doc_sections::DOC_SECTION_MATH; +/// /// This struct for a simple UDF that adds one to an int32 +/// #[derive(Debug, PartialEq, Eq, Hash)] +/// struct AddOne { +/// signature: Signature, +/// } +/// +/// impl AddOne { +/// fn new() -> Self { +/// Self { +/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable), +/// } +/// } +/// } +/// +/// static DOCUMENTATION: LazyLock = LazyLock::new(|| { +/// Documentation::builder(DOC_SECTION_MATH, "Add one to an int32", "add_one(2)") +/// .with_argument("arg1", "The int32 number 
to add one to") +/// .build() +/// }); +/// +/// fn get_doc() -> &'static Documentation { +/// &DOCUMENTATION +/// } +/// +/// /// Implement the LambdaUDF trait for AddOne +/// impl LambdaUDF for AddOne { +/// fn as_any(&self) -> &dyn Any { self } +/// fn name(&self) -> &str { "add_one" } +/// fn signature(&self) -> &Signature { &self.signature } +/// fn return_type(&self, args: &[DataType]) -> Result { +/// if !matches!(args.get(0), Some(&DataType::Int32)) { +/// return plan_err!("add_one only accepts Int32 arguments"); +/// } +/// Ok(DataType::Int32) +/// } +/// // The actual implementation would add one to the argument +/// fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result { +/// unimplemented!() +/// } +/// fn documentation(&self) -> Option<&Documentation> { +/// Some(get_doc()) +/// } +/// } +/// +/// // Create a new LambdaUDF from the implementation +/// let add_one = LambdaUDF::from(AddOne::new()); +/// +/// // Call the function `add_one(col)` +/// let expr = add_one.call(vec![col("a")]); +/// ``` +pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns any aliases (alternate names) for this function. + /// + /// Aliases can be used to invoke the same function using different names. + /// For example in some databases `now()` and `current_timestamp()` are + /// aliases for the same function. This behavior can be obtained by + /// returning `current_timestamp` as an alias for the `now` function. + /// + /// Note: `aliases` should only include names other than [`Self::name`]. 
+ /// Defaults to `[]` (no aliases) + fn aliases(&self) -> &[String] { + &[] + } + + /// Returns the name of the column this expression would create + /// + /// See [`Expr::schema_name`] for details + fn schema_name(&self, args: &[Expr]) -> Result { + Ok(format!( + "{}({})", + self.name(), + schema_name_from_exprs_comma_separated_without_space(args)? + )) + } + + /// Returns a [`Signature`] describing the argument types for which this + /// function has an implementation, and the function's [`Volatility`]. + /// + /// See [`Signature`] for more details on argument type handling + /// and [`Self::return_type`] for computing the return type. + /// + /// [`Volatility`]: datafusion_expr_common::signature::Volatility + fn signature(&self) -> &Signature; + + /// Create a new instance of this function with updated configuration. + /// + /// This method is called when configuration options change at runtime + /// (e.g., via `SET` statements) to allow functions that depend on + /// configuration to update themselves accordingly. + /// + /// Note the current [`ConfigOptions`] are also passed to [`Self::invoke_with_args`] so + /// this API is not needed for functions where the values may + /// depend on the current options. + /// + /// This API is useful for functions where the return + /// **type** depends on the configuration options, such as the `now()` function + /// which depends on the current timezone. + /// + /// # Arguments + /// + /// * `config` - The updated configuration options + /// + /// # Returns + /// + /// * `Some(LambdaUDF)` - A new instance of this function configured with the new settings + /// * `None` - If this function does not change with new configuration settings (the default) + fn with_updated_config(&self, _config: &ConfigOptions) -> Option> { + None + } + + /// What type will be returned by this function, given the arguments? + /// + /// By default, this function calls [`Self::return_type`] with the + /// types of each argument. 
+ /// + /// # Notes + /// + /// For the majority of UDFs, implementing [`Self::return_type`] is sufficient, + /// as the result type is typically a deterministic function of the input types + /// (e.g., `sqrt(f32)` consistently yields `f32`). Implementing this method directly + /// is generally unnecessary unless the return type depends on runtime values. + /// + /// This function can be used for more advanced cases such as: + /// + /// 1. specifying nullability + /// 2. return types based on the **values** of the arguments (rather than + /// their **types**. + /// + /// # Example creating `Field` + /// + /// Note the name of the [`Field`] is ignored, except for structured types such as + /// `DataType::Struct`. + /// + /// ```rust + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field, FieldRef}; + /// # use datafusion_common::Result; + /// # use datafusion_expr::LambdaReturnFieldArgs; + /// # struct Example{} + /// # impl Example { + /// fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result { + /// // report output is only nullable if any one of the arguments are nullable + /// let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true)); + /// Ok(field) + /// } + /// # } + /// ``` + /// + /// # Output Type based on Values + /// + /// For example, the following two function calls get the same argument + /// types (something and a `Utf8` string) but return different types based + /// on the value of the second argument: + /// + /// * `arrow_cast(x, 'Int16')` --> `Int16` + /// * `arrow_cast(x, 'Float32')` --> `Float32` + /// + /// # Requirements + /// + /// This function **must** consistently return the same type for the same + /// logical input even if the input is simplified (e.g. it must return the same + /// value for `('foo' | 'bar')` as it does for ('foobar'). 
+ fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result; + + /// Invoke the function returning the appropriate result. + /// + /// # Performance + /// + /// For the best performance, the implementations should handle the common case + /// when one or more of their arguments are constant values (aka + /// [`ColumnarValue::Scalar`]). + /// + /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments + /// to arrays, which will likely be simpler code, but be slower. + fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result; + + /// Optionally apply per-UDF simplification / rewrite rules. + /// + /// This can be used to apply function specific simplification rules during + /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default + /// implementation does nothing. + /// + /// Note that DataFusion handles simplifying arguments and "constant + /// folding" (replacing a function call with constant arguments such as + /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such + /// optimizations manually for specific UDFs. + /// + /// # Arguments + /// * `args`: The arguments of the function + /// * `info`: The necessary information for simplification + /// + /// # Returns + /// [`ExprSimplifyResult`] indicating the result of the simplification NOTE + /// if the function cannot be simplified, the arguments *MUST* be returned + /// unmodified + /// + /// # Notes + /// + /// The returned expression must have the same schema as the original + /// expression, including both the data type and nullability. For example, + /// if the original expression is nullable, the returned expression must + /// also be nullable, otherwise it may lead to schema verification errors + /// later in query planning. 
+ fn simplify( + &self, + args: Vec, + _info: &dyn SimplifyInfo, + ) -> Result { + Ok(ExprSimplifyResult::Original(args)) + } + + /// Returns true if some of this expression's subexpressions may not be evaluated + /// and thus any side effects (like divide by zero) may not be encountered. + /// + /// Setting this to true prevents certain optimizations such as common + /// subexpression elimination + /// + /// When overriding this function to return `true`, [LambdaUDF::conditional_arguments] can also be + /// overridden to report more accurately which arguments are eagerly evaluated and which ones + /// lazily. + fn short_circuits(&self) -> bool { + false + } + + /// Determines which of the arguments passed to this function are evaluated eagerly + /// and which may be evaluated lazily. + /// + /// If this function returns `None`, all arguments are eagerly evaluated. + /// Returning `None` is a micro optimization that saves a needless `Vec` + /// allocation. + /// + /// If the function returns `Some`, returns (`eager`, `lazy`) where `eager` + /// are the arguments that are always evaluated, and `lazy` are the + /// arguments that may be evaluated lazily (i.e. may not be evaluated at all + /// in some cases). + /// + /// Implementations must ensure that the two returned `Vec`s are disjoint, + /// and that each argument from `args` is present in one of the two `Vec`s. + /// + /// When overriding this function, [LambdaUDF::short_circuits] must + /// be overridden to return `true`. + fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + if self.short_circuits() { + Some((vec![], args.iter().collect())) + } else { + None + } + } + + /// Computes the output [`Interval`] for a [`LambdaUDF`], given the input + /// intervals. + /// + /// # Parameters + /// + /// * `children` are the intervals for the children (inputs) of this function.
+ /// + /// # Example + /// + /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`, + /// then the output interval would be `[0, 3]`. + fn evaluate_bounds(&self, _input: &[&Interval]) -> Result { + // We cannot assume the input datatype is the same of output type. + Interval::make_unbounded(&DataType::Null) + } + + /// Updates bounds for child expressions, given a known [`Interval`]s for this + /// function. + /// + /// This function is used to propagate constraints down through an + /// expression tree. + /// + /// # Parameters + /// + /// * `interval` is the currently known interval for this function. + /// * `inputs` are the current intervals for the inputs (children) of this function. + /// + /// # Returns + /// + /// A `Vec` of new intervals for the children, in order. + /// + /// If constraint propagation reveals an infeasibility for any child, returns + /// [`None`]. If none of the children intervals change as a result of + /// propagation, may return an empty vector instead of cloning `children`. + /// This is the default (and conservative) return value. + /// + /// # Example + /// + /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the + /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`. + fn propagate_constraints( + &self, + _interval: &Interval, + _inputs: &[&Interval], + ) -> Result>> { + Ok(Some(vec![])) + } + + /// Calculates the [`SortProperties`] of this function based on its children's properties. + fn output_ordering(&self, inputs: &[ExprProperties]) -> Result { + if !self.preserves_lex_ordering(inputs)? 
{ + return Ok(SortProperties::Unordered); + } + + let Some(first_order) = inputs.first().map(|p| &p.sort_properties) else { + return Ok(SortProperties::Singleton); + }; + + if inputs + .iter() + .skip(1) + .all(|input| &input.sort_properties == first_order) + { + Ok(*first_order) + } else { + Ok(SortProperties::Unordered) + } + } + + /// Returns true if the function preserves lexicographical ordering based on + /// the input ordering. + /// + /// For example, `concat(a || b)` preserves lexicographical ordering, but `abs(a)` does not. + fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result { + Ok(false) + } + + /// Coerce arguments of a function call to types that the function can evaluate. + /// + /// See the [type coercion module](crate::type_coercion) + /// documentation for more details on type coercion + /// + /// [`TypeSignature`]: crate::TypeSignature + /// + /// For example, if your function requires floating point arguments, but the user calls + /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]` + /// to ensure the argument is converted to `1::double` + /// + /// # Parameters + /// * `arg_types`: The argument types of the arguments this function will be called with + /// + /// # Return value + /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call + /// arguments to these specific types. + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + not_impl_err!("Function {} does not implement coerce_types", self.name()) + } + + /// Returns the documentation for this Lambda UDF. + /// + /// Documentation can be accessed programmatically as well as generating + /// publicly facing documentation.
+ fn documentation(&self) -> Option<&Documentation> { + None + } + + /// Returns the parameters that any lambda supports + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + Ok(vec![None; args.len()]) + } +} + +#[cfg(test)] +mod tests { + use datafusion_expr_common::signature::Volatility; + + use super::*; + use std::hash::DefaultHasher; + + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestLambdaUDF { + name: &'static str, + field: &'static str, + signature: Signature, + } + impl LambdaUDF for TestLambdaUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_field_from_args(&self, _args: LambdaReturnFieldArgs) -> Result { + unimplemented!() + } + + fn invoke_with_args(&self, _args: LambdaFunctionArgs) -> Result { + unimplemented!() + } + } + + // PartialEq and Hash must be consistent, and also PartialEq and PartialOrd + // must be consistent, so they are tested together. 
+ #[test] + fn test_partial_eq_hash_and_partial_ord() { + // A parameterized function + let f = test_func("foo", "a"); + + // Same like `f`, different instance + let f2 = test_func("foo", "a"); + assert_eq!(&f, &f2); + assert_eq!(hash(&f), hash(&f2)); + assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal)); + + // Different parameter + let b = test_func("foo", "b"); + assert_ne!(&f, &b); + assert_ne!(hash(&f), hash(&b)); // hash can collide for different values but does not collide in this test + assert_eq!(f.partial_cmp(&b), None); + + // Different name + let o = test_func("other", "a"); + assert_ne!(&f, &o); + assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test + assert_eq!(f.partial_cmp(&o), Some(Ordering::Less)); + + // Different name and parameter + assert_ne!(&b, &o); + assert_ne!(hash(&b), hash(&o)); // hash can collide for different values but does not collide in this test + assert_eq!(b.partial_cmp(&o), Some(Ordering::Less)); + } + + fn test_func(name: &'static str, parameter: &'static str) -> Arc { + Arc::new(TestLambdaUDF { + name, + field: parameter, + signature: Signature::any(1, Volatility::Immutable), + }) + } + + fn hash(value: &T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } +} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index ab58b1c3f835f..e7beba8c4b090 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -308,6 +308,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Wildcard { .. } | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. 
} + | Expr::LambdaFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => {} } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 123df27b339be..746cfe6b62be5 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -26,24 +26,29 @@ use arrow::{ datatypes::{DataType, Field, FieldRef, Schema}, }; use datafusion_common::{ - HashMap, Result, exec_err, internal_err, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter}, utils::{elements_indices, list_indices, list_values, take_function_args} + exec_err, + tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + }, + utils::{elements_indices, list_indices, list_values, take_function_args}, + HashMap, Result, }; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + ColumnarValue, Documentation, LambdaFunctionArgs, LambdaUDF, Signature, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, }; use datafusion_macros::user_doc; -use datafusion_physical_expr::expressions::{LambdaVariable, LambdaExpr}; +use datafusion_physical_expr::expressions::{LambdaExpr, LambdaVariable}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::{any::Any, sync::Arc}; -make_udf_expr_and_func!( - ArrayTransform, - array_transform, - array lambda, - "transforms the values of a array", - array_transform_udf -); +//make_udf_expr_and_func!( +// ArrayTransform, +// array_transform, +// array lambda, +// "transforms the values of a array", +// array_transform_udf +//); #[user_doc( doc_section(label = "Array Functions"), @@ -84,7 +89,7 @@ impl ArrayTransform { } } -impl ScalarUDFImpl for ArrayTransform { +impl LambdaUDF for ArrayTransform { fn as_any(&self) -> &dyn Any { self } @@ -101,18 +106,12 @@ impl ScalarUDFImpl for ArrayTransform { 
&self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type called instead of return_field_from_args") - } - fn return_field_from_args( &self, - args: datafusion_expr::ReturnFieldArgs, + args: datafusion_expr::LambdaReturnFieldArgs, ) -> Result> { - let args = args.to_lambda_args(); - let [ValueOrLambdaField::Value(list), ValueOrLambdaField::Lambda(lambda)] = - take_function_args(self.name(), &args)? + take_function_args(self.name(), args.arg_fields)? else { return exec_err!( "{} expects a value follewed by a lambda, got {:?}", @@ -141,10 +140,8 @@ impl ScalarUDFImpl for ArrayTransform { Ok(Arc::new(Field::new("", return_type, list.is_nullable()))) } - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - // args.lambda_args allows the convenient match below, instead of inspecting both args.args and args.lambdas - let lambda_args = args.to_lambda_args(); - let [list_value, lambda] = take_function_args(self.name(), &lambda_args)?; + fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result { + let [list_value, lambda] = take_function_args(self.name(), &args.args)?; let (ValueOrLambda::Value(list_value), ValueOrLambda::Lambda(lambda)) = (list_value, lambda) @@ -152,7 +149,7 @@ impl ScalarUDFImpl for ArrayTransform { return exec_err!( "{} expects a value followed by a lambda, got {:?}", self.name(), - &lambda_args + &args.args ); }; @@ -243,8 +240,7 @@ impl ScalarUDFImpl for ArrayTransform { &self, args: &[ValueOrLambdaParameter], ) -> Result>>> { - let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda] = - args + let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda] = args else { return exec_err!( "{} expects a value follewed by a lambda, got {:?}", @@ -305,7 +301,7 @@ impl TreeNodeRewriter for BindLambdaVariable<'_> { if let Some((value, shadows)) = self.columns.get(lambda_variable.name()) { if *shadows == 0 { return Ok(Transformed::yes(Arc::new( - 
lambda_variable.clone().with_value(value.clone()), + lambda_variable.clone().with_value(Arc::clone(value)), ))); } } @@ -317,7 +313,7 @@ impl TreeNodeRewriter for BindLambdaVariable<'_> { } if self.columns.values().all(|(_value, shadows)| *shadows > 0) { - return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)) + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)); } } diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 55acf24ba4657..c93a55cce1a4f 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -79,7 +79,7 @@ pub mod expr_fn { pub use super::array_has::array_has; pub use super::array_has::array_has_all; pub use super::array_has::array_has_any; - pub use super::array_transform::array_transform; + //pub use super::array_transform::array_transform; pub use super::cardinality::cardinality; pub use super::concat::array_append; pub use super::concat::array_concat; @@ -147,7 +147,7 @@ pub fn all_default_nested_functions() -> Vec> { array_has::array_has_udf(), array_has::array_has_all_udf(), array_has::array_has_any_udf(), - array_transform::array_transform_udf(), + //array_transform::array_transform_udf(), empty::array_empty_udf(), length::array_length_udf(), distance::array_distance_udf(), diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 626c2ba550594..e0b1d9096b415 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -598,6 +598,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { | Expr::GroupingSet(_) | Expr::Placeholder(_) | Expr::OuterReferenceColumn(_, _) + | Expr::LambdaFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => Ok(Transformed::no(expr)), } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index b7e8626aa7bd5..63314b48facfd 
100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -288,6 +288,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::TryCast(_) | Expr::InList { .. } | Expr::ScalarFunction(_) + | Expr::LambdaFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), // TODO: remove the next line after `Expr::Wildcard` is removed diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 74794115755ba..8e09fbbae48d8 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -36,6 +36,7 @@ use datafusion_common::{ exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::LambdaFunction; use datafusion_expr::{ and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, @@ -659,6 +660,9 @@ impl<'a> ConstEvaluator<'a> { Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } + Expr::LambdaFunction(LambdaFunction { func, .. }) => { + Self::volatility_ok(func.signature().volatility) + } Expr::Literal(_, _) | Expr::Alias(..) 
| Expr::Unnest(_) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index c0f48b8ebfc40..11a656f2abb4c 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -728,6 +728,10 @@ impl ContextProvider for MyContextProvider { None } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.udafs.get(name).cloned() } diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs new file mode 100644 index 0000000000000..0b0c33cdd6be9 --- /dev/null +++ b/datafusion/physical-expr/src/lambda_function.rs @@ -0,0 +1,530 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Declaration of built-in (lambda) functions. +//! This module contains built-in functions' enumeration and metadata. +//! +//! Generally, a function has: +//! * a signature +//! * a return type, that is a function of the incoming argument's types +//! * the computation, that must accept each valid signature +//! +//! * Signature: see `Signature` +//! * Return type: a function `(arg_types) -> return_type`. 
E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64. +//! +//! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed +//! to a function that supports f64, it is coerced to f64. + +use std::any::Any; +use std::fmt::{self, Debug, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use crate::expressions::{LambdaExpr, Literal}; +use crate::PhysicalExpr; + +use arrow::array::{Array, NullArray, RecordBatch}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; +use datafusion_common::config::{ConfigEntry, ConfigOptions}; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::ExprProperties; +use datafusion_expr::{ + expr_vec_fmt, ColumnarValue, LambdaFunctionArgs, LambdaFunctionLambdaArg, + LambdaReturnFieldArgs, LambdaUDF, ValueOrLambda, ValueOrLambdaField, + ValueOrLambdaParameter, Volatility, +}; + +/// Physical expression of a lambda function +pub struct LambdaFunctionExpr { + fun: Arc, + name: String, + args: Vec>, + return_field: FieldRef, + config_options: Arc, +} + +impl Debug for LambdaFunctionExpr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("LambdaFunctionExpr") + .field("fun", &"") + .field("name", &self.name) + .field("args", &self.args) + .field("return_field", &self.return_field) + .finish() + } +} + +impl LambdaFunctionExpr { + /// Create a new Lambda function + pub fn new( + name: &str, + fun: Arc, + args: Vec>, + return_field: FieldRef, + config_options: Arc, + ) -> Self { + Self { + fun, + name: name.to_owned(), + args, + return_field, + config_options, + } + } + + /// Create a new Lambda function + pub fn try_new( + fun: Arc, + args: Vec>, + schema: &Schema, + config_options: Arc, + ) -> Result { + let name = fun.name().to_string(); + let arg_fields = args + .iter() + .map(|e| { + let field = e.return_field(schema)?; + match e.as_any().downcast_ref::() { 
+ Some(_lambda) => Ok(ValueOrLambdaField::Lambda(field)), + None => Ok(ValueOrLambdaField::Value(field)), + } + }) + .collect::>>()?; + + // TODO: verify that input data types is consistent with function's `TypeSignature` + + let arguments = args + .iter() + .map(|e| { + e.as_any() + .downcast_ref::() + .map(|literal| literal.value()) + }) + .collect::>(); + + let ret_args = LambdaReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &arguments, + }; + + let return_field = fun.return_field_from_args(ret_args)?; + + Ok(Self { + fun, + name, + args, + return_field, + config_options, + }) + } + + /// Get the lambda function implementation + pub fn fun(&self) -> &dyn LambdaUDF { + self.fun.as_ref() + } + + /// The name for this expression + pub fn name(&self) -> &str { + &self.name + } + + /// Input arguments + pub fn args(&self) -> &[Arc] { + &self.args + } + + /// Data type produced by this expression + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } + + pub fn with_nullable(mut self, nullable: bool) -> Self { + self.return_field = self + .return_field + .as_ref() + .clone() + .with_nullable(nullable) + .into(); + self + } + + pub fn nullable(&self) -> bool { + self.return_field.is_nullable() + } + + pub fn config_options(&self) -> &ConfigOptions { + &self.config_options + } + + /// Given an arbitrary PhysicalExpr attempt to downcast it to a LambdaFunctionExpr + /// and verify that its inner function is of type T. + /// If the downcast fails, or the function is not of type T, returns `None`. + /// Otherwise returns `Some(LambdaFunctionExpr)`. 
+ pub fn try_downcast_func(expr: &dyn PhysicalExpr) -> Option<&LambdaFunctionExpr> + where + T: 'static, + { + match expr.as_any().downcast_ref::() { + Some(lambda_expr) + if lambda_expr.fun().as_any().downcast_ref::().is_some() => + { + Some(lambda_expr) + } + _ => None, + } + } +} + +impl fmt::Display for LambdaFunctionExpr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}({})", self.name, expr_vec_fmt!(self.args)) + } +} + +impl PartialEq for LambdaFunctionExpr { + fn eq(&self, o: &Self) -> bool { + if std::ptr::eq(self, o) { + // The equality implementation is somewhat expensive, so let's short-circuit when possible. + return true; + } + let Self { + fun, + name, + args, + return_field, + config_options, + } = self; + fun.eq(&o.fun) + && name.eq(&o.name) + && args.eq(&o.args) + && return_field.eq(&o.return_field) + && (Arc::ptr_eq(config_options, &o.config_options) + || sorted_config_entries(config_options) + == sorted_config_entries(&o.config_options)) + } +} +impl Eq for LambdaFunctionExpr {} +impl Hash for LambdaFunctionExpr { + fn hash(&self, state: &mut H) { + let Self { + fun, + name, + args, + return_field, + config_options: _, // expensive to hash, and often equal + } = self; + fun.hash(state); + name.hash(state); + args.hash(state); + return_field.hash(state); + } +} + +fn sorted_config_entries(config_options: &ConfigOptions) -> Vec { + let mut entries = config_options.entries(); + entries.sort_by(|l, r| l.key.cmp(&r.key)); + entries +} + +impl PhysicalExpr for LambdaFunctionExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.return_field.data_type().clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(self.return_field.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let arg_fields = self + .args + .iter() + .map(|e| { + let field = 
e.return_field(batch.schema_ref())?; + match e.as_any().downcast_ref::() { + Some(_lambda) => Ok(ValueOrLambdaField::Lambda(field)), + None => Ok(ValueOrLambdaField::Value(field)), + } + }) + .collect::>>()?; + + let args_metadata = arg_fields.iter() + .map(|field| match field { + ValueOrLambdaField::Value(field) => ValueOrLambdaParameter::Value(Arc::clone(field)), + ValueOrLambdaField::Lambda(_field) => ValueOrLambdaParameter::Lambda, + }) + .collect::>(); + + let params = self.fun().lambdas_parameters(&args_metadata)?; + + let args = std::iter::zip(&self.args, params) + .map(|(arg, lambda_params)| { + match (arg.as_any().downcast_ref::(), lambda_params) { + (Some(lambda), Some(lambda_params)) => { + if lambda.params().len() > lambda_params.len() { + return exec_err!( + "lambda defined {} params but UDF support only {}", + lambda.params().len(), + lambda_params.len() + ); + } + + let captures = lambda.captures(); + + let params = std::iter::zip(lambda.params(), lambda_params) + .map(|(name, param)| Arc::new(param.with_name(name))) + .collect(); + + let captures = if !captures.is_empty() { + let (fields, columns): (Vec<_>, _) = std::iter::zip( + batch.schema_ref().fields(), + batch.columns(), + ) + .enumerate() + .map(|(column_index, (field, column))| { + if captures.contains(&column_index) { + (Arc::clone(field), Arc::clone(column)) + } else { + ( + Arc::new(Field::new( + field.name(), + DataType::Null, + false, + )), + Arc::new(NullArray::new(column.len())) as _, + ) + } + }) + .unzip(); + + let schema = Arc::new(Schema::new(fields)); + + Some(RecordBatch::try_new(schema, columns)?) 
+ } else { + None + }; + + Ok(ValueOrLambda::Lambda(LambdaFunctionLambdaArg { + params, + body: Arc::clone(lambda.body()), + captures, + })) + } + (Some(_lambda), None) => exec_err!( + "{} don't reported the parameters of one of it's lambdas", + self.fun.name() + ), + (None, Some(_lambda_params)) => exec_err!( + "{} reported parameters for an argument that is not a lambda", + self.fun.name() + ), + (None, None) => Ok(ValueOrLambda::Value(arg.evaluate(batch)?)), + } + }) + .collect::>>()?; + + let input_empty = args.is_empty(); + let input_all_scalar = args + .iter() + .all(|arg| matches!(arg, ValueOrLambda::Value(ColumnarValue::Scalar(_)))); + + // evaluate the function + let output = self.fun.invoke_with_args(LambdaFunctionArgs { + args, + arg_fields, + number_rows: batch.num_rows(), + return_field: Arc::clone(&self.return_field), + config_options: Arc::clone(&self.config_options), + })?; + + if let ColumnarValue::Array(array) = &output { + if array.len() != batch.num_rows() { + // If the arguments are a non-empty slice of scalar values, we can assume that + // returning a one-element array is equivalent to returning a scalar. + let preserve_scalar = + array.len() == 1 && !input_empty && input_all_scalar; + return if preserve_scalar { + ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar) + } else { + internal_err!("UDF {} returned a different number of rows than expected. 
Expected: {}, Got: {}", + self.name, batch.num_rows(), array.len()) + }; + } + } + Ok(output) + } + + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.return_field)) + } + + fn children(&self) -> Vec<&Arc> { + self.args.iter().collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(LambdaFunctionExpr::new( + &self.name, + Arc::clone(&self.fun), + children, + Arc::clone(&self.return_field), + Arc::clone(&self.config_options), + ))) + } + + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + self.fun.evaluate_bounds(children) + } + + fn propagate_constraints( + &self, + interval: &Interval, + children: &[&Interval], + ) -> Result>> { + self.fun.propagate_constraints(interval, children) + } + + fn get_properties(&self, children: &[ExprProperties]) -> Result { + let sort_properties = self.fun.output_ordering(children)?; + let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?; + let children_range = children + .iter() + .map(|props| &props.range) + .collect::>(); + let range = self.fun().evaluate_bounds(&children_range)?; + + Ok(ExprProperties { + sort_properties, + range, + preserves_lex_ordering, + }) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}(", self.name)?; + for (i, expr) in self.args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + expr.fmt_sql(f)?; + } + write!(f, ")") + } + + fn is_volatile_node(&self) -> bool { + self.fun.signature().volatility == Volatility::Volatile + } +} + +#[cfg(test)] +mod tests { + use std::any::Any; + use std::sync::Arc; + + use super::*; + use crate::expressions::Column; + use crate::LambdaFunctionExpr; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_expr::{LambdaFunctionArgs, LambdaUDF, Signature}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_physical_expr_common::physical_expr::is_volatile; + 
use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + + /// Test helper to create a mock UDF with a specific volatility + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockLambdaUDF { + signature: Signature, + } + + impl LambdaUDF for MockLambdaUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "mock_function" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_field_from_args( + &self, + _args: LambdaReturnFieldArgs, + ) -> Result { + Ok(Arc::new(Field::new("", DataType::Int32, false))) + } + + fn invoke_with_args(&self, _args: LambdaFunctionArgs) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42)))) + } + } + + #[test] + fn test_lambda_function_volatile_node() { + // Create a volatile UDF + let volatile_udf = Arc::new(MockLambdaUDF { + signature: Signature::uniform( + 1, + vec![DataType::Float32], + Volatility::Volatile, + ), + }); + + // Create a non-volatile UDF + let stable_udf = Arc::new(MockLambdaUDF { + signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + }); + + let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); + let args = vec![Arc::new(Column::new("a", 0)) as Arc]; + let config_options = Arc::new(ConfigOptions::new()); + + // Test volatile function + let volatile_expr = LambdaFunctionExpr::try_new( + volatile_udf, + args.clone(), + &schema, + Arc::clone(&config_options), + ) + .unwrap(); + + assert!(volatile_expr.is_volatile_node()); + let volatile_arc: Arc = Arc::new(volatile_expr); + assert!(is_volatile(&volatile_arc)); + + // Test non-volatile function + let stable_expr = + LambdaFunctionExpr::try_new(stable_udf, args, &schema, config_options) + .unwrap(); + + assert!(!stable_expr.is_volatile_node()); + let stable_arc: Arc = Arc::new(stable_expr); + assert!(!is_volatile(&stable_arc)); + } +} diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index aa8c9e50fd71e..a05d24d2ba2c5 
100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -31,6 +31,7 @@ pub mod binary_map { pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; } pub mod async_scalar_function; +pub mod lambda_function; pub mod equivalence; pub mod expressions; pub mod intervals; @@ -70,6 +71,7 @@ pub use datafusion_physical_expr_common::sort_expr::{ pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; +pub use lambda_function::LambdaFunctionExpr; pub use simplifier::PhysicalExprSimplifier; pub use utils::{conjunction, conjunction_opt, split_conjunction}; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 8a53aa81da8fa..f7be4aedf555e 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use crate::expressions::{lambda_variable, LambdaExpr}; -use crate::ScalarFunctionExpr; +use crate::{LambdaFunctionExpr, ScalarFunctionExpr}; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, PhysicalExpr, @@ -33,7 +33,7 @@ use datafusion_common::{ }; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{ - Alias, Cast, InList, Lambda, LambdaVariable, Placeholder, ScalarFunction, + Alias, Cast, InList, Lambda, LambdaFunction, LambdaVariable, Placeholder, ScalarFunction }; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; @@ -318,10 +318,6 @@ pub fn create_physical_expr( input_dfschema, execution_props, )?), - Expr::Lambda(Lambda { params, body }) => Ok(Arc::new(LambdaExpr::new( - params.clone(), - create_physical_expr(body, input_dfschema, execution_props)?, - ))), Expr::ScalarFunction(ScalarFunction { func, args }) => { let physical_args = create_physical_exprs(args, input_dfschema, execution_props)?; @@ -392,6 +388,26 @@ pub fn 
create_physical_expr( Expr::Placeholder(Placeholder { id, .. }) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } + Expr::LambdaFunction(LambdaFunction { func, args}) => { + let physical_args = + create_physical_exprs(args, input_dfschema, execution_props)?; + + let config_options = match execution_props.config_options.as_ref() { + Some(config_options) => Arc::clone(config_options), + None => Arc::new(ConfigOptions::default()), + }; + + Ok(Arc::new(LambdaFunctionExpr::try_new( + Arc::clone(func), + physical_args, + input_schema, + config_options, + )?)) + } + Expr::Lambda(Lambda { params, body }) => Ok(Arc::new(LambdaExpr::new( + params.clone(), + create_physical_expr(body, input_dfschema, execution_props)?, + ))), Expr::LambdaVariable(LambdaVariable { name, field, diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 6eab2239015a7..e421ea11c4b2d 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -24,11 +24,11 @@ use crate::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; use crate::protobuf; -use datafusion_common::{plan_datafusion_err, Result}; +use datafusion_common::{not_impl_err, plan_datafusion_err, Result}; use datafusion_execution::TaskContext; use datafusion_expr::{ - create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LogicalPlan, Volatility, - WindowUDF, + create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LambdaUDF, LogicalPlan, + Signature, Volatility, WindowUDF, }; use prost::{ bytes::{Bytes, BytesMut}, @@ -167,6 +167,15 @@ impl Serializeable for Expr { ) } + fn register_udlf( + &mut self, + _udlf: Arc, + ) -> Result>> { + datafusion_common::internal_err!( + "register_udlf called in Placeholder Registry!" 
+ ) + } + fn expr_planners(&self) -> Vec> { vec![] } @@ -178,6 +187,51 @@ impl Serializeable for Expr { fn udwfs(&self) -> std::collections::HashSet { std::collections::HashSet::default() } + + fn udlfs(&self) -> std::collections::HashSet { + std::collections::HashSet::default() + } + + fn udlf(&self, name: &str) -> Result> { + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockLambdaUDF { + name: String, + signature: Signature, + } + + impl LambdaUDF for MockLambdaUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_field_from_args( + &self, + _args: datafusion_expr::LambdaReturnFieldArgs, + ) -> Result { + not_impl_err!("mock LambdaUDF") + } + + fn invoke_with_args( + &self, + _args: datafusion_expr::LambdaFunctionArgs, + ) -> Result { + not_impl_err!("mock LambdaUDF") + } + } + + Ok(Arc::new(MockLambdaUDF { + name: name.to_string(), + signature: Signature::variadic_any(Volatility::Immutable), + })) + } } Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?; diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs index 087e073db21af..98f4928457679 100644 --- a/datafusion/proto/src/bytes/registry.rs +++ b/datafusion/proto/src/bytes/registry.rs @@ -67,4 +67,14 @@ impl FunctionRegistry for NoRegistry { fn udwfs(&self) -> HashSet { HashSet::new() } + + fn udlfs(&self) -> HashSet { + HashSet::new() + } + + fn udlf(&self, name: &str) -> Result> { + plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Lambda Function '{name}'") + } + + } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 9644c9f69feae..1122952771c79 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -63,7 +63,7 @@ use datafusion_expr::{ Statement, WindowUDF, }; use datafusion_expr::{ - 
AggregateUDF, DmlStatement, FetchType, RecursiveQuery, SkipType, TableSource, Unnest, + AggregateUDF, DmlStatement, FetchType, LambdaUDF, RecursiveQuery, SkipType, TableSource, Unnest }; use self::to_proto::{serialize_expr, serialize_exprs}; @@ -153,6 +153,14 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { Ok(()) } + + fn try_decode_udlf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!("LogicalExtensionCodec is not provided for lambda function {name}") + } + + fn try_encode_udlf(&self, _node: &dyn LambdaUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } fn try_decode_udaf(&self, name: &str, _buf: &[u8]) -> Result> { not_impl_err!( diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 41a9172fff276..e080411b49e95 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -622,6 +622,17 @@ pub fn serialize_expr( .unwrap_or(HashMap::new()), })), }, + Expr::LambdaFunction(func) => { + let mut buf = Vec::new(); + let _ = codec.try_encode_udlf(func.func.as_ref(), &mut buf); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { + fun_name: func.name().to_string(), + fun_definition: (!buf.is_empty()).then_some(buf), + args: serialize_exprs(&func.args, codec)?, + })), + } + } Expr::Lambda(_) | Expr::LambdaVariable(_) => { return Err(Error::General( "Proto serialization error: Lambda not implemented".to_string(), diff --git a/datafusion/session/src/session.rs b/datafusion/session/src/session.rs index fd033172f224f..625b3fb77a4d8 100644 --- a/datafusion/session/src/session.rs +++ b/datafusion/session/src/session.rs @@ -22,7 +22,7 @@ use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; 
-use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; +use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF}; use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr}; use parking_lot::{Mutex, RwLock}; use std::any::Any; @@ -109,6 +109,9 @@ pub trait Session: Send + Sync { /// Return reference to scalar_functions fn scalar_functions(&self) -> &HashMap>; + + /// Return reference to lambda_functions + fn lambda_functions(&self) -> &HashMap>; /// Return reference to aggregate_functions fn aggregate_functions(&self) -> &HashMap>; @@ -149,6 +152,7 @@ impl From<&dyn Session> for TaskContext { state.session_id().to_string(), state.config().clone(), state.scalar_functions().clone(), + state.lambda_functions().clone(), state.aggregate_functions().clone(), state.window_functions().clone(), Arc::clone(state.runtime_env()), diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 2c0bb86cd8087..3d2ff0528081c 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result, TableReference}; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::WindowUDF; +use datafusion_expr::{LambdaUDF, WindowUDF}; use datafusion_expr::{ logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource, }; @@ -138,6 +138,10 @@ impl ContextProvider for MyContextProvider { None } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.udafs.get(name).cloned() } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 439e65e8f7e47..47f132d065980 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -24,7 +24,7 @@ use datafusion_common::{ internal_datafusion_err, 
internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Diagnostic, Result, Span, }; -use datafusion_expr::expr::{Lambda, ScalarFunction, Unnest}; +use datafusion_expr::expr::{Lambda, LambdaFunction, ScalarFunction, Unnest}; use datafusion_expr::expr::{NullTreatment, WildcardOptions, WindowFunction}; use datafusion_expr::planner::PlannerResult; use datafusion_expr::planner::{RawAggregateExpr, RawWindowExpr}; @@ -277,8 +277,8 @@ impl SqlToRel<'_, S> { } } } - // User-defined function (UDF) should have precedence - if let Some(fm) = self.context_provider.get_function_meta(&name) { + + if let Some(fm) = self.context_provider.get_lambda_meta(&name) { enum ExprOrLambda { ExprWithName((Expr, Option)), Lambda(sqlparser::ast::LambdaFunction), @@ -312,7 +312,7 @@ impl SqlToRel<'_, S> { }) .collect::>>()?; - let lambdas_parameters = fm.inner().lambdas_parameters(&metadata)?; + let lambdas_parameters = fm.lambdas_parameters(&metadata)?; let pairs = pairs .into_iter() @@ -385,6 +385,55 @@ impl SqlToRel<'_, S> { args }; + // After resolution, all arguments are positional + let inner = LambdaFunction::new(fm, resolved_args); + + if name.eq_ignore_ascii_case(inner.name()) { + return Ok(Expr::LambdaFunction(inner)); + } else { + // If the function is called by an alias, a verbose string representation is created + // (e.g., "my_alias(arg1, arg2)") and the expression is wrapped in an `Alias` + // to ensure the output column name matches the user's query. 
+ let arg_names = inner + .args + .iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(","); + let verbose_alias = format!("{name}({arg_names})"); + + return Ok(Expr::LambdaFunction(inner).alias(verbose_alias)); + } + } + + // User-defined function (UDF) should have precedence + if let Some(fm) = self.context_provider.get_function_meta(&name) { + let (args, arg_names): (Vec, Vec>) = args + .into_iter() + .map(|a| { + self.sql_fn_arg_to_logical_expr_with_name(a, schema, planner_context) + }) + .collect::>>()? + .into_iter() + .unzip(); + + let resolved_args = if arg_names.iter().any(|name| name.is_some()) { + if let Some(param_names) = &fm.signature().parameter_names { + datafusion_expr::arguments::resolve_function_arguments( + param_names, + args, + arg_names, + )? + } else { + return plan_err!( + "Function '{}' does not support named arguments", + fm.name() + ); + } + } else { + args + }; + // After resolution, all arguments are positional let inner = ScalarFunction::new_udf(fm, resolved_args); diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 715a02db8b027..e51f1c04cf157 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -1207,7 +1207,7 @@ mod tests { use datafusion_common::config::ConfigOptions; use datafusion_common::TableReference; use datafusion_expr::logical_plan::builder::LogicalTableSource; - use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; + use datafusion_expr::{AggregateUDF, LambdaUDF, ScalarUDF, TableSource, WindowUDF}; use super::*; @@ -1247,6 +1247,10 @@ mod tests { None } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, name: &str) -> Option> { match name { "sum" => Some(datafusion_functions_aggregate::sum::sum_udaf()), diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 013a6f1128957..3bd669dbec071 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ 
b/datafusion/sql/src/unparser/expr.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. -use datafusion_expr::expr::{AggregateFunctionParams, WindowFunctionParams}; +use datafusion_expr::expr::{AggregateFunctionParams, LambdaFunction, WindowFunctionParams}; use datafusion_expr::expr::{Lambda, Unnest}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, - LambdaFunction, ObjectName, Subscript, TimezoneInfo, UnaryOperator, + ObjectName, Subscript, TimezoneInfo, UnaryOperator, }; use sqlparser::ast::{CaseWhen, DuplicateTreatment, OrderByOptions, ValueWithSpan}; use std::sync::Arc; @@ -528,8 +528,20 @@ impl Unparser<'_> { } Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), Expr::Unnest(unnest) => self.unnest_to_sql(unnest), + Expr::LambdaFunction(LambdaFunction { func, args }) => { + let func_name = func.name(); + + if let Some(expr) = self + .dialect + .scalar_function_to_sql_overrides(self, func_name, args)? 
+ { + return Ok(expr); + } + + self.scalar_function_to_sql(func_name, args) + } Expr::Lambda(Lambda { params, body }) => { - Ok(ast::Expr::Lambda(LambdaFunction { + Ok(ast::Expr::Lambda(ast::LambdaFunction { params: ast::OneOrManyWithParens::Many( params.iter().map(|param| param.as_str().into()).collect(), ), diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 5d9fd9f2c3740..6c9ac4bf70046 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -26,7 +26,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{plan_err, DFSchema, GetExt, Result, TableReference}; use datafusion_expr::planner::{ExprPlanner, PlannerResult, TypePlanner}; -use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF}; +use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, ScalarUDF, TableSource, WindowUDF}; use datafusion_functions_nested::expr_fn::make_array; use datafusion_sql::planner::ContextProvider; @@ -53,6 +53,7 @@ impl Display for MockCsvType { #[derive(Default)] pub(crate) struct MockSessionState { scalar_functions: HashMap>, + lambda_functions: HashMap>, aggregate_functions: HashMap>, expr_planners: Vec>, type_planner: Option>, @@ -240,6 +241,10 @@ impl ContextProvider for MockContextProvider { self.state.scalar_functions.get(name).cloned() } + fn get_lambda_meta(&self, name: &str) -> Option> { + self.state.lambda_functions.get(name).cloned() + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.state.aggregate_functions.get(name).cloned() } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index b3ca88a690291..d1112b99536d9 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -152,8 +152,9 @@ pub fn to_substrait_rex( 
not_impl_err!("Cannot convert {expr:?} to Substrait") } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 - Expr::LambdaVariable(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 + Expr::LambdaFunction(expr) => producer.handle_lambda_function(expr, schema), + Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // not yet implemented in substrait-rs + Expr::LambdaVariable(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // not yet implemented in substrait-rs } } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs index abb26f6f66822..b2057a9d914f8 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs @@ -26,17 +26,34 @@ pub fn from_scalar_function( producer: &mut impl SubstraitProducer, fun: &expr::ScalarFunction, schema: &DFSchemaRef, +) -> datafusion::common::Result { + from_function(producer, fun.name(), &fun.args, schema) +} + +pub fn from_lambda_function( + producer: &mut impl SubstraitProducer, + fun: &expr::LambdaFunction, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + from_function(producer, fun.name(), &fun.args, schema) +} + +fn from_function( + producer: &mut impl SubstraitProducer, + name: &str, + args: &[Expr], + schema: &DFSchemaRef, ) -> datafusion::common::Result { let mut arguments: Vec = vec![]; - for arg in &fun.args { + for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), }); } - let arguments = custom_argument_handler(fun.name(), arguments); + let arguments = custom_argument_handler(name, 
arguments); - let function_anchor = producer.register_function(fun.name().to_string()); + let function_anchor = producer.register_function(name.to_string()); #[allow(deprecated)] Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs index db08e0f7bfd0c..d065bcf41586a 100644 --- a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -17,12 +17,7 @@ use crate::extensions::Extensions; use crate::logical_plan::producer::{ - from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, - from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter, - from_in_list, from_in_subquery, from_join, from_like, from_limit, from_literal, - from_projection, from_repartition, from_scalar_function, from_sort, - from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, - from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex, + from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter, from_in_list, from_in_subquery, from_join, from_lambda_function, from_like, from_limit, from_literal, from_projection, from_repartition, from_scalar_function, from_sort, from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex }; use datafusion::common::{substrait_err, Column, DFSchemaRef, ScalarValue}; use datafusion::execution::registry::SerializerRegistry; @@ -327,6 +322,14 @@ pub trait SubstraitProducer: Send + Sync + Sized { ) -> datafusion::common::Result { from_scalar_function(self, scalar_fn, schema) } + + fn 
handle_lambda_function( + &mut self, + scalar_fn: &expr::LambdaFunction, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_lambda_function(self, scalar_fn, schema) + } fn handle_aggregate_function( &mut self, From 1f19c6406b97e05ffbd5088bfeaf5c0e7320a622 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 23 Feb 2026 14:40:03 -0300 Subject: [PATCH 06/47] feat: remove lambda support for ScalarUDF --- datafusion/expr/src/expr_schema.rs | 6 - datafusion/expr/src/udf.rs | 105 ----------------- datafusion/ffi/src/udf/mod.rs | 8 +- datafusion/ffi/src/udf/return_type_args.rs | 9 +- datafusion/functions-nested/benches/map.rs | 1 - datafusion/functions-nested/src/array_has.rs | 2 - datafusion/functions-nested/src/map_values.rs | 1 - datafusion/functions-nested/src/set_ops.rs | 1 - datafusion/functions/benches/ascii.rs | 4 - .../functions/benches/character_length.rs | 4 - datafusion/functions/benches/chr.rs | 1 - datafusion/functions/benches/concat.rs | 1 - datafusion/functions/benches/cot.rs | 2 - datafusion/functions/benches/date_bin.rs | 1 - datafusion/functions/benches/date_trunc.rs | 1 - datafusion/functions/benches/encoding.rs | 4 - datafusion/functions/benches/find_in_set.rs | 4 - datafusion/functions/benches/gcd.rs | 3 - datafusion/functions/benches/initcap.rs | 3 - datafusion/functions/benches/isnan.rs | 2 - datafusion/functions/benches/iszero.rs | 2 - datafusion/functions/benches/lower.rs | 6 - datafusion/functions/benches/ltrim.rs | 1 - datafusion/functions/benches/make_date.rs | 4 - datafusion/functions/benches/nullif.rs | 1 - datafusion/functions/benches/pad.rs | 1 - datafusion/functions/benches/random.rs | 2 - datafusion/functions/benches/repeat.rs | 1 - datafusion/functions/benches/reverse.rs | 4 - datafusion/functions/benches/signum.rs | 2 - datafusion/functions/benches/strpos.rs | 4 - datafusion/functions/benches/substr.rs | 1 - datafusion/functions/benches/substr_index.rs | 1 - 
datafusion/functions/benches/to_char.rs | 6 - datafusion/functions/benches/to_hex.rs | 2 - datafusion/functions/benches/to_timestamp.rs | 6 - datafusion/functions/benches/trunc.rs | 2 - datafusion/functions/benches/upper.rs | 1 - datafusion/functions/benches/uuid.rs | 1 - .../functions/src/core/union_extract.rs | 3 - datafusion/functions/src/core/union_tag.rs | 2 - datafusion/functions/src/core/version.rs | 1 - datafusion/functions/src/datetime/date_bin.rs | 1 - .../functions/src/datetime/date_trunc.rs | 2 - .../functions/src/datetime/from_unixtime.rs | 2 - .../functions/src/datetime/make_date.rs | 1 - datafusion/functions/src/datetime/now.rs | 2 - datafusion/functions/src/datetime/to_char.rs | 7 -- datafusion/functions/src/datetime/to_date.rs | 1 - .../functions/src/datetime/to_local_time.rs | 2 - .../functions/src/datetime/to_timestamp.rs | 2 - datafusion/functions/src/math/log.rs | 18 --- datafusion/functions/src/math/power.rs | 2 - datafusion/functions/src/math/signum.rs | 2 - datafusion/functions/src/regex/regexpcount.rs | 1 - datafusion/functions/src/regex/regexpinstr.rs | 1 - datafusion/functions/src/string/concat.rs | 1 - datafusion/functions/src/string/concat_ws.rs | 2 - datafusion/functions/src/string/contains.rs | 1 - datafusion/functions/src/string/lower.rs | 1 - datafusion/functions/src/string/upper.rs | 1 - .../functions/src/unicode/find_in_set.rs | 1 - datafusion/functions/src/unicode/strpos.rs | 1 - datafusion/functions/src/utils.rs | 3 - .../src/async_scalar_function.rs | 2 - .../physical-expr/src/scalar_function.rs | 108 ++---------------- datafusion/spark/benches/char.rs | 1 - .../spark/src/function/bitmap/bitmap_count.rs | 1 - .../src/function/datetime/make_dt_interval.rs | 1 - .../src/function/datetime/make_interval.rs | 1 - .../spark/src/function/string/concat.rs | 2 - datafusion/spark/src/function/utils.rs | 3 - 72 files changed, 10 insertions(+), 381 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs 
b/datafusion/expr/src/expr_schema.rs index 6a3a7cbc85e76..f3789ca9fd115 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -595,15 +595,9 @@ impl ExprSchemable for Expr { }) .collect::>(); - let lambdas = args - .iter() - .map(|e| matches!(e, Expr::Lambda { .. })) - .collect::>(); - let args = ReturnFieldArgs { arg_fields: &new_fields, scalar_arguments: &arguments, - lambdas: &lambdas, }; func.return_field_from_args(args) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 911fc890e2bc5..fd54bb13a62f3 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -23,13 +23,11 @@ use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::udf_eq::UdfEq; use crate::{ColumnarValue, Documentation, Expr, Signature}; -use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::any::Any; use std::cmp::Ordering; use std::fmt::Debug; @@ -347,14 +345,6 @@ impl ScalarUDF { } } -#[derive(Clone, Debug)] -pub enum ValueOrLambdaParameter { - /// A columnar value with the given field - Value(FieldRef), - /// A lambda - Lambda, -} - impl From for ScalarUDF where F: ScalarUDFImpl + 'static, @@ -369,7 +359,6 @@ where #[derive(Debug, Clone)] pub struct ScalarFunctionArgs { /// The evaluated arguments to the function - /// If it's a lambda, will be `ColumnarValue::Scalar(ScalarValue::Null)` pub args: Vec, /// Field associated with each arg, if it exists pub arg_fields: Vec, @@ -381,30 +370,6 @@ pub struct ScalarFunctionArgs { pub return_field: FieldRef, /// The config options at execution time pub config_options: Arc, 
- /// The lambdas passed to the function - /// If it's not a lambda it will be `None` - pub lambdas: Option>>, -} - -/// A lambda argument to a ScalarFunction -#[derive(Clone, Debug)] -pub struct ScalarFunctionLambdaArg { - /// The parameters defined in this lambda - /// - /// For example, for `array_transform([2], v -> -v)`, - /// this will be `vec![Field::new("v", DataType::Int32, true)]` - pub params: Vec, - /// The body of the lambda - /// - /// For example, for `array_transform([2], v -> -v)`, - /// this will be the physical expression of `-v` - pub body: Arc, - /// A RecordBatch containing at least the captured columns inside this lambda body, if any - /// Note that it may contain additional, non-specified columns, but that's implementation detail - /// - /// For example, for `array_transform([2], v -> v + a + b)`, - /// this will be a `RecordBatch` with two columns, `a` and `b` - pub captures: Option, } impl ScalarFunctionArgs { @@ -413,25 +378,6 @@ impl ScalarFunctionArgs { pub fn return_type(&self) -> &DataType { self.return_field.data_type() } - - pub fn to_lambda_args(&self) -> Vec> { - match &self.lambdas { - Some(lambdas) => std::iter::zip(&self.args, lambdas) - .map(|(arg, lambda)| match lambda { - Some(lambda) => ValueOrLambda::Lambda(lambda), - None => ValueOrLambda::Value(arg), - }) - .collect(), - None => self.args.iter().map(ValueOrLambda::Value).collect(), - } - } -} - -// An argument to a ScalarUDF that supports lambdas -#[derive(Debug)] -pub enum ValueOrLambda<'a> { - Value(&'a ColumnarValue), - Lambda(&'a ScalarFunctionLambdaArg), } /// Information about arguments passed to the function @@ -444,12 +390,6 @@ pub enum ValueOrLambda<'a> { #[derive(Debug)] pub struct ReturnFieldArgs<'a> { /// The data types of the arguments to the function - /// - /// If argument `i` to the function is a lambda, it will be the field returned by the - /// lambda when executed with the arguments returned from `ScalarUDFImpl::lambdas_parameters` - /// - /// For 
example, with `array_transform([1], v -> v == 5)` - /// this field will be `[Field::new("", DataType::List(DataType::Int32), false), Field::new("", DataType::Boolean, false)]` pub arg_fields: &'a [FieldRef], /// Is argument `i` to the function a scalar (constant)? /// @@ -458,36 +398,6 @@ pub struct ReturnFieldArgs<'a> { /// For example, if a function is called like `my_function(column_a, 5)` /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` pub scalar_arguments: &'a [Option<&'a ScalarValue>], - /// Is argument `i` to the function a lambda? - /// - /// For example, with `array_transform([1], v -> v == 5)` - /// this field will be `[false, true]` - pub lambdas: &'a [bool], -} - -/// A tagged Field indicating whether it correspond to a value or a lambda argument -#[derive(Debug)] -pub enum ValueOrLambdaField<'a> { - /// The Field of a ColumnarValue argument - Value(&'a FieldRef), - /// The Field of the return of the lambda body when evaluated with the parameters from ScalarUDF::lambda_parameters - Lambda(&'a FieldRef), -} - -impl<'a> ReturnFieldArgs<'a> { - /// Based on self.lambdas, encodes self.arg_fields to tagged enums - /// indicating whether it correspond to a value or a lambda argument - pub fn to_lambda_args(&self) -> Vec> { - std::iter::zip(self.arg_fields, self.lambdas) - .map(|(field, is_lambda)| { - if *is_lambda { - ValueOrLambdaField::Lambda(field) - } else { - ValueOrLambdaField::Value(field) - } - }) - .collect() - } } /// Trait for implementing user defined scalar functions. @@ -931,14 +841,6 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { fn documentation(&self) -> Option<&Documentation> { None } - - /// Returns the parameters that any lambda supports - fn lambdas_parameters( - &self, - args: &[ValueOrLambdaParameter], - ) -> Result>>> { - Ok(vec![None; args.len()]) - } } /// ScalarUDF that adds an alias to the underlying function. 
It is better to @@ -1057,13 +959,6 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn documentation(&self) -> Option<&Documentation> { self.inner.documentation() } - - fn lambdas_parameters( - &self, - args: &[ValueOrLambdaParameter], - ) -> Result>>> { - self.inner.lambdas_parameters(args) - } } #[cfg(test)] diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 400ad44696047..5e59cfc5ecb07 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -33,7 +33,7 @@ use arrow::{ }; use arrow_schema::FieldRef; use datafusion::config::ConfigOptions; -use datafusion::{common::exec_err, logical_expr::ReturnFieldArgs}; +use datafusion::logical_expr::ReturnFieldArgs; use datafusion::{ error::DataFusionError, logical_expr::type_coercion::functions::data_types_with_scalar_udf, @@ -210,7 +210,6 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( return_field, // TODO: pass config options: https://github.com/apache/datafusion/issues/17035 config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = rresult_return!(udf @@ -383,15 +382,10 @@ impl ScalarUDFImpl for ForeignScalarUDF { arg_fields, number_rows, return_field, - lambdas, // TODO: pass config options: https://github.com/apache/datafusion/issues/17035 config_options: _config_options, } = invoke_args; - if lambdas.is_some_and(|lambdas| lambdas.iter().any(|l| l.is_some())) { - return exec_err!("ForeignScalarUDF doesn't support lambdas"); - } - let args = args .into_iter() .map(|v| v.to_array(number_rows)) diff --git a/datafusion/ffi/src/udf/return_type_args.rs b/datafusion/ffi/src/udf/return_type_args.rs index d5cbfff1d3a4b..c437c9537be6f 100644 --- a/datafusion/ffi/src/udf/return_type_args.rs +++ b/datafusion/ffi/src/udf/return_type_args.rs @@ -21,7 +21,7 @@ use abi_stable::{ }; use arrow_schema::FieldRef; use datafusion::{ - common::{exec_datafusion_err, exec_err}, error::DataFusionError, logical_expr::ReturnFieldArgs, + 
common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnFieldArgs, scalar::ScalarValue, }; @@ -42,10 +42,6 @@ impl TryFrom> for FFI_ReturnFieldArgs { type Error = DataFusionError; fn try_from(value: ReturnFieldArgs) -> Result { - if value.lambdas.iter().any(|l| *l) { - return exec_err!("FFI_ReturnFieldArgs doesn't support lambdas") - } - let arg_fields = vec_fieldref_to_rvec_wrapped(value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments @@ -81,7 +77,6 @@ pub struct ForeignReturnFieldArgsOwned { pub struct ForeignReturnFieldArgs<'a> { arg_fields: &'a [FieldRef], scalar_arguments: Vec>, - lambdas: Vec, // currently always false, used to return a reference in From<&Self> for ReturnFieldArgs } impl TryFrom<&FFI_ReturnFieldArgs> for ForeignReturnFieldArgsOwned { @@ -121,7 +116,6 @@ impl<'a> From<&'a ForeignReturnFieldArgsOwned> for ForeignReturnFieldArgs<'a> { .iter() .map(|opt| opt.as_ref()) .collect(), - lambdas: vec![false; value.arg_fields.len()] } } } @@ -131,7 +125,6 @@ impl<'a> From<&'a ForeignReturnFieldArgs<'a>> for ReturnFieldArgs<'a> { ReturnFieldArgs { arg_fields: value.arg_fields, scalar_arguments: &value.scalar_arguments, - lambdas: &value.lambdas, } } } diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 3075d2e573e4a..3197cc55cc957 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -117,7 +117,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("map should work on valid values"), ); diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index d6a333c0a0ef3..080b2f16d92f3 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -819,7 +819,6 @@ mod tests { 
number_rows: 1, return_field, config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })?; let output = result.into_array(1)?; @@ -848,7 +847,6 @@ mod tests { number_rows: 1, return_field, config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })?; let output = result.into_array(1)?; diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index ac21ff8acd3f9..6ae8a278063da 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -204,7 +204,6 @@ mod tests { let args = datafusion_expr::ReturnFieldArgs { arg_fields: &[field], scalar_arguments: &[None::<&ScalarValue>], - lambdas: &[false], }; func.return_field_from_args(args).unwrap() diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index f26fc173d8a9f..53642bf1622b0 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -596,7 +596,6 @@ mod tests { number_rows: 1, return_field: input_field.clone().into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })?; assert_eq!( diff --git a/datafusion/functions/benches/ascii.rs b/datafusion/functions/benches/ascii.rs index 97e6ab20ed458..03d25e9c3d4fe 100644 --- a/datafusion/functions/benches/ascii.rs +++ b/datafusion/functions/benches/ascii.rs @@ -60,7 +60,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -82,7 +81,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -110,7 +108,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) 
}, @@ -132,7 +129,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index f98e8a8b1a68b..4a1a63d62765f 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -55,7 +55,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -80,7 +79,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -105,7 +103,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -130,7 +127,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index d51cda4566d64..8356cf7c31726 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -69,7 +69,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 6378328537827..09200139a244b 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -60,7 +60,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, 
return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index 56f50522acc5d..97f21ccd6d55e 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -54,7 +54,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) @@ -81,7 +80,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 1c3713723738a..74390491d538c 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -66,7 +66,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index b757535fb03c5..498a3e63ef290 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -71,7 +71,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("date_trunc should work on valid values"), ) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index 72b033cf5d9ed..98faee91e1911 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -45,7 +45,6 @@ fn 
criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(); @@ -64,7 +63,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) @@ -84,7 +82,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(); @@ -104,7 +101,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index 6fe498a58d84b..a928f5655806c 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -168,7 +168,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })) }) }); @@ -187,7 +186,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })) }) }); @@ -210,7 +208,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })) }) }); @@ -231,7 +228,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index 
2bfec91e290dd..19e196d9a3eab 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -58,7 +58,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("date_bin should work on valid values"), ) @@ -80,7 +79,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("date_bin should work on valid values"), ) @@ -102,7 +100,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index 37d98596deb82..50aee8dbb9161 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -70,7 +70,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -87,7 +86,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -102,7 +100,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index dcce59e46ce41..4a90d45d66223 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -53,7 
+53,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Boolean, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) @@ -78,7 +77,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Boolean, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index 574539fbb6427..961cba7200ce0 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -55,7 +55,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) @@ -83,7 +82,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index e741afd0d8e01..6a5178b87fdce 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -145,7 +145,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }); @@ -168,7 +167,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }); @@ -193,7 +191,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -228,7 +225,6 @@ fn criterion_benchmark(c: &mut Criterion) 
{ number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }), ); @@ -244,7 +240,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }), ); @@ -261,7 +256,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }), ); diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 9b344cc6b143a..4458af614396d 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -153,7 +153,6 @@ fn run_with_string_type( number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 2a681ddedcbe8..15a895468db93 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -81,7 +81,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -112,7 +111,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -143,7 +141,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -171,7 +168,6 @@ fn 
criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("make_date should work on valid values"), ) diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index 15914cd7ee6c5..d649697cc5188 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -54,7 +54,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index c7d46da3d26c6..f92a69bbf4f92 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -116,7 +116,6 @@ fn invoke_pad_with_args( number_rows, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }; if left_pad { diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 2935876685800..88efb2d1b5b93 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -43,7 +43,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 8192, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ); @@ -65,7 +64,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 128, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ); diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 9a7c63ed4f304..80ffa8ee38f1a 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -76,7 +76,6 @@ fn invoke_repeat_with_args( number_rows: repeat_times as 
usize, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) } diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index a8af40cd8cc19..b1eca654fb254 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -58,7 +58,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -81,7 +80,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -109,7 +107,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -134,7 +131,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 805b62c83da6d..24b8861e4d28c 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -55,7 +55,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) @@ -84,7 +83,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index 708ebb5518727..18a99e44bf487 100644 
--- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -128,7 +128,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -147,7 +146,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }); @@ -167,7 +165,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -188,7 +185,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index 58fda73defd25..771413458c1fb 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -116,7 +116,6 @@ fn invoke_substr_with_args( number_rows, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) } diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index a77b961657c5f..d0941d9baedda 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -110,7 +110,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("substr_index should work on valid values"), ) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 61990b4cb8b95..945508aec7405 100644 --- a/datafusion/functions/benches/to_char.rs 
+++ b/datafusion/functions/benches/to_char.rs @@ -149,7 +149,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -177,7 +176,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -205,7 +203,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -232,7 +229,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -260,7 +256,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -293,7 +288,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_char should work on valid values"), ) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index baa2de80c466f..a75ed9258791e 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -44,7 +44,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), 
config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) @@ -63,7 +62,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index e510a7c3fad41..a8f5c5816d4da 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -130,7 +130,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -151,7 +150,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -172,7 +170,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -206,7 +203,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -248,7 +244,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -291,7 +286,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) 
.expect("to_timestamp should work on valid values"), ) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index 0b08791f9ae50..6e225e0e7038b 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -49,7 +49,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) @@ -69,7 +68,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index e9f0941032d8a..7328b32574a4a 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -50,7 +50,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index 8ad79b2866eaf..1368e2f2af5d1 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -37,7 +37,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1024, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }); diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index ac542866f7e43..7f93500b9cfb9 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -209,7 +209,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })?; 
assert_scalar(result, ScalarValue::Utf8(None)); @@ -233,7 +232,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })?; assert_scalar(result, ScalarValue::Utf8(None)); @@ -257,7 +255,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })?; assert_scalar(result, ScalarValue::new_utf8("42")); diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs index 4832c368872bf..aeadb8292ba1e 100644 --- a/datafusion/functions/src/core/union_tag.rs +++ b/datafusion/functions/src/core/union_tag.rs @@ -184,7 +184,6 @@ mod tests { return_field: Field::new("res", return_type, true).into(), arg_fields: vec![], config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }) .unwrap(); @@ -208,7 +207,6 @@ mod tests { return_field: Field::new("res", return_type, true).into(), arg_fields: vec![], config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }) .unwrap(); diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 390111028c8f2..ef3c5aafa4801 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -112,7 +112,6 @@ mod test { number_rows: 0, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }) .unwrap(); diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 5466129314640..92af123dbafac 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -530,7 +530,6 @@ mod tests { number_rows, return_field: Arc::clone(return_field), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; 
DateBinFunc::new().invoke_with_args(args) } diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 5736c221cae84..913e6217af82d 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -892,7 +892,6 @@ mod tests { ) .into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { @@ -1081,7 +1080,6 @@ mod tests { ) .into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index be44be094e5b7..5d6adfb6f119a 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -179,7 +179,6 @@ mod test { number_rows: 1, return_field: Field::new("f", DataType::Timestamp(Second, None), true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); @@ -213,7 +212,6 @@ mod test { ) .into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index afa4ef132147a..0fe5d156a8383 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -250,7 +250,6 @@ mod tests { number_rows, return_field: Field::new("f", DataType::Date32, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; MakeDateFunc::new().invoke_with_args(args) } diff --git 
a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index f18e72a107e28..4723548a45584 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -163,7 +163,6 @@ mod tests { .return_field_from_args(ReturnFieldArgs { arg_fields: &empty_fields, scalar_arguments: &empty_scalars, - lambdas: &[], }) .expect("legacy now() return field"); @@ -171,7 +170,6 @@ mod tests { .return_field_from_args(ReturnFieldArgs { arg_fields: &empty_fields, scalar_arguments: &empty_scalars, - lambdas: &[], }) .expect("configured now() return field"); diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 5d69ce233f643..7d9b2bc241e1a 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -375,7 +375,6 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&Arc::new(ConfigOptions::default())), - lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -481,7 +480,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -576,7 +574,6 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -741,7 +738,6 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -770,7 +766,6 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: 
Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -796,7 +791,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( @@ -818,7 +812,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index f6b313e6a28bb..3840c8d8bbb94 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -186,7 +186,6 @@ mod tests { number_rows, return_field: Field::new("f", DataType::Date32, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; ToDateFunc::new().invoke_with_args(args) } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 4d50a70d37236..6e0a150b0a35f 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -549,7 +549,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", expected.data_type(), true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }) .unwrap(); match res { @@ -621,7 +620,6 @@ mod tests { ) .into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ToLocalTimeFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index f35e170073030..0a0700097770f 100644 --- 
a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -1033,7 +1033,6 @@ mod tests { number_rows: 4, return_field: Field::new("f", rt, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let res = udf .invoke_with_args(args) @@ -1084,7 +1083,6 @@ mod tests { number_rows: 5, return_field: Field::new("f", rt, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let res = udf .invoke_with_args(args) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 1a73ed8436a68..f66f6fcfc1f88 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -370,7 +370,6 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); assert!(result.is_err()); @@ -391,7 +390,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); @@ -409,7 +407,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -440,7 +437,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -475,7 +471,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -510,7 +505,6 @@ mod tests { number_rows: 1, 
return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -543,7 +537,6 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -579,7 +572,6 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -621,7 +613,6 @@ mod tests { number_rows: 5, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -664,7 +655,6 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -846,7 +836,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Decimal128(38, 0), true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -880,7 +869,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -915,7 +903,6 @@ mod tests { number_rows: 6, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -960,7 +947,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: 
Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -1001,7 +987,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -1052,7 +1037,6 @@ mod tests { number_rows: 7, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -1094,7 +1078,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); assert!(result.is_err()); @@ -1118,7 +1101,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); assert!(result.is_err()); diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 21a777abb3295..ad2e795d086e9 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -222,7 +222,6 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = PowerFunc::new() .invoke_with_args(args) @@ -259,7 +258,6 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = PowerFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index d1d49b1bf6f90..bbe6178f39b79 100644 --- a/datafusion/functions/src/math/signum.rs +++ 
b/datafusion/functions/src/math/signum.rs @@ -173,7 +173,6 @@ mod test { number_rows: array.len(), return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = SignumFunc::new() .invoke_with_args(args) @@ -221,7 +220,6 @@ mod test { number_rows: array.len(), return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = SignumFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index ee6f412bb9a16..8bad506217aa5 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -628,7 +628,6 @@ mod tests { number_rows: args.len(), return_field: Field::new("f", Int64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }) } diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs index 1e64f7087ea74..851c182a90dd0 100644 --- a/datafusion/functions/src/regex/regexpinstr.rs +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -494,7 +494,6 @@ mod tests { number_rows: args.len(), return_field: Arc::new(Field::new("f", Int64, true)), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }) } diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 661bcfe4e0fd8..a93e70e714e8b 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -487,7 +487,6 @@ mod tests { number_rows: 3, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ConcatFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 
85704d6b2f468..cdd30ac8755ab 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -495,7 +495,6 @@ mod tests { number_rows: 3, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ConcatWsFunc::new().invoke_with_args(args)?; @@ -533,7 +532,6 @@ mod tests { number_rows: 3, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ConcatWsFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 1edab4c6bf334..7e50676933c8d 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -177,7 +177,6 @@ mod test { number_rows: 2, return_field: Field::new("f", DataType::Boolean, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let actual = udf.invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index 099a3ffd44cc4..ee56a6a549857 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -113,7 +113,6 @@ mod tests { arg_fields, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = match func.invoke_with_args(args)? { diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index d7d2bde94b0a3..8bb2ec1d511cd 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -112,7 +112,6 @@ mod tests { arg_fields: vec![arg_field], return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = match func.invoke_with_args(args)? 
{ diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index 219bd6eaa762c..fa68e539600b0 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -485,7 +485,6 @@ mod tests { number_rows: cardinality, return_field: Field::new("f", return_type, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }); assert!(result.is_ok()); diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index a3734b0c0de4f..4f238b2644bdf 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -336,7 +336,6 @@ mod tests { Field::new("f2", DataType::Utf8, substring_nullable).into(), ], scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>], - lambdas: &[false; 2], }; strpos.return_field_from_args(args).unwrap().is_nullable() diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index d6d56b32722de..932d61e8007cd 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -234,7 +234,6 @@ pub mod test { let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { arg_fields: &field_array, scalar_arguments: &scalar_arguments_refs, - lambdas: &vec![false; scalar_arguments_refs.len()], }); let arg_fields = $ARGS.iter() .enumerate() @@ -253,7 +252,6 @@ pub mod test { arg_fields, number_rows: cardinality, return_field, - lambdas: None, config_options: $CONFIG_OPTIONS }); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); @@ -276,7 +274,6 @@ pub mod test { arg_fields, number_rows: cardinality, return_field, - lambdas: None, config_options: $CONFIG_OPTIONS, }) { Ok(_) => assert!(false, "expected error"), diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index 
a34d3cda47682..b434694a20cc8 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -168,7 +168,6 @@ impl AsyncFuncExpr { number_rows: current_batch.num_rows(), return_field: Arc::clone(&self.return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .await?, ); @@ -188,7 +187,6 @@ impl AsyncFuncExpr { number_rows: batch.num_rows(), return_field: Arc::clone(&self.return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .await?, ); diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 2527e84241fe3..6ad22671ba847 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -34,19 +34,19 @@ use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::expressions::{LambdaExpr, Literal}; +use crate::expressions::Literal; use crate::PhysicalExpr; -use arrow::array::{Array, NullArray, RecordBatch}; -use arrow::datatypes::{DataType, Field, FieldRef, Schema}; +use arrow::array::{Array, RecordBatch}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use datafusion_common::config::{ConfigEntry, ConfigOptions}; -use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{ - expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, - ScalarFunctionLambdaArg, ScalarUDF, ValueOrLambdaParameter, Volatility, + expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, + Volatility, }; /// Physical expression of a scalar function @@ -117,15 +117,9 @@ impl ScalarFunctionExpr { }) .collect::>(); - let lambdas = 
args - .iter() - .map(|e| e.as_any().is::()) - .collect::>(); - let ret_args = ReturnFieldArgs { arg_fields: &arg_fields, scalar_arguments: &arguments, - lambdas: &lambdas, }; let return_field = fun.return_field_from_args(ret_args)?; @@ -270,10 +264,7 @@ impl PhysicalExpr for ScalarFunctionExpr { let args = self .args .iter() - .map(|e| match e.as_any().downcast_ref::() { - Some(_) => Ok(ColumnarValue::Scalar(ScalarValue::Null)), - None => Ok(e.evaluate(batch)?), - }) + .map(|e| e.evaluate(batch)) .collect::>>()?; let arg_fields = self @@ -287,89 +278,6 @@ impl PhysicalExpr for ScalarFunctionExpr { .iter() .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); - let lambdas = if self.args.iter().any(|arg| arg.as_any().is::()) { - let args_metadata = std::iter::zip(&self.args, &arg_fields) - .map( - |(expr, field)| match expr.as_any().downcast_ref::() { - Some(_lambda) => ValueOrLambdaParameter::Lambda, - None => ValueOrLambdaParameter::Value(Arc::clone(field)), - }, - ) - .collect::>(); - - let params = self.fun().inner().lambdas_parameters(&args_metadata)?; - - let lambdas = std::iter::zip(&self.args, params) - .map(|(arg, lambda_params)| { - match (arg.as_any().downcast_ref::(), lambda_params) { - (Some(lambda), Some(lambda_params)) => { - if lambda.params().len() > lambda_params.len() { - return exec_err!( - "lambda defined {} params but UDF support only {}", - lambda.params().len(), - lambda_params.len() - ); - } - - let captures = lambda.captures(); - - let params = std::iter::zip(lambda.params(), lambda_params) - .map(|(name, param)| Arc::new(param.with_name(name))) - .collect(); - - let captures = if !captures.is_empty() { - let (fields, columns): (Vec<_>, _) = std::iter::zip( - batch.schema_ref().fields(), - batch.columns(), - ) - .enumerate() - .map(|(column_index, (field, column))| { - if captures.contains(&column_index) { - (Arc::clone(field), Arc::clone(column)) - } else { - ( - Arc::new(Field::new( - field.name(), - DataType::Null, - false, - )), - 
Arc::new(NullArray::new(column.len())) as _, - ) - } - }) - .unzip(); - - let schema = Arc::new(Schema::new(fields)); - - Some(RecordBatch::try_new(schema, columns)?) - } else { - None - }; - - Ok(Some(ScalarFunctionLambdaArg { - params, - body: Arc::clone(lambda.body()), - captures, - })) - } - (Some(_lambda), None) => exec_err!( - "{} don't reported the parameters of one of it's lambdas", - self.fun.name() - ), - (None, Some(_lambda_params)) => exec_err!( - "{} reported parameters for an argument that is not a lambda", - self.fun.name() - ), - _ => Ok(None), - } - }) - .collect::>>()?; - - Some(lambdas) - } else { - None - }; - // evaluate the function let output = self.fun.invoke_with_args(ScalarFunctionArgs { args, @@ -377,9 +285,9 @@ impl PhysicalExpr for ScalarFunctionExpr { number_rows: batch.num_rows(), return_field: Arc::clone(&self.return_field), config_options: Arc::clone(&self.config_options), - lambdas, })?; + if let ColumnarValue::Array(array) = &output { if array.len() != batch.num_rows() { // If the arguments are a non-empty slice of scalar values, we can assume that diff --git a/datafusion/spark/benches/char.rs b/datafusion/spark/benches/char.rs index 501bfd2a0186d..02eab7630d070 100644 --- a/datafusion/spark/benches/char.rs +++ b/datafusion/spark/benches/char.rs @@ -68,7 +68,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::new(Field::new("f", DataType::Utf8, true)), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/spark/src/function/bitmap/bitmap_count.rs b/datafusion/spark/src/function/bitmap/bitmap_count.rs index e4c12ebe19665..56a9c5edb812c 100644 --- a/datafusion/spark/src/function/bitmap/bitmap_count.rs +++ b/datafusion/spark/src/function/bitmap/bitmap_count.rs @@ -217,7 +217,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", Int64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let udf = 
BitmapCount::new(); let actual = udf.invoke_with_args(args)?; diff --git a/datafusion/spark/src/function/datetime/make_dt_interval.rs b/datafusion/spark/src/function/datetime/make_dt_interval.rs index aaff5400d0c00..bbfba44861344 100644 --- a/datafusion/spark/src/function/datetime/make_dt_interval.rs +++ b/datafusion/spark/src/function/datetime/make_dt_interval.rs @@ -317,7 +317,6 @@ mod tests { number_rows, return_field: Field::new("f", Duration(Microsecond), true).into(), config_options: Arc::new(Default::default()), - lambdas: None, }; SparkMakeDtInterval::new().invoke_with_args(args) } diff --git a/datafusion/spark/src/function/datetime/make_interval.rs b/datafusion/spark/src/function/datetime/make_interval.rs index 9f98c4b5ce9fb..8e3169556b95b 100644 --- a/datafusion/spark/src/function/datetime/make_interval.rs +++ b/datafusion/spark/src/function/datetime/make_interval.rs @@ -516,7 +516,6 @@ mod tests { number_rows, return_field: Field::new("f", Interval(MonthDayNano), true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; SparkMakeInterval::new().invoke_with_args(args) } diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index e2cd8d977fe29..0dcc58d5bb8ed 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -105,7 +105,6 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { number_rows, return_field, config_options, - lambdas, } = args; // Handle zero-argument case: return empty string @@ -131,7 +130,6 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { number_rows, return_field, config_options, - lambdas, }; let result = concat_func.invoke_with_args(func_args)?; diff --git a/datafusion/spark/src/function/utils.rs b/datafusion/spark/src/function/utils.rs index 1064acc342916..b939dabda388d 100644 --- a/datafusion/spark/src/function/utils.rs +++ b/datafusion/spark/src/function/utils.rs @@ -61,7 +61,6 @@ 
pub mod test { let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { arg_fields: &arg_fields, scalar_arguments: &scalar_arguments_refs, - lambdas: &vec![false; arg_fields.len()], }); match expected { @@ -75,7 +74,6 @@ pub mod test { return_field, arg_fields: arg_fields.clone(), config_options: $CONFIG_OPTIONS, - lambdas: None, }) { Ok(col_value) => { match col_value.to_array(cardinality) { @@ -119,7 +117,6 @@ pub mod test { return_field: value, arg_fields, config_options: $CONFIG_OPTIONS, - lambdas: None, }) { Ok(_) => assert!(false, "expected error"), Err(error) => { From 570cc53367e9d4a4697c0dd4d462641f87670c4e Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 1 Mar 2026 02:13:40 -0300 Subject: [PATCH 07/47] temporarily add pr description as DOC.md --- DOC.md | 1166 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1166 insertions(+) create mode 100644 DOC.md diff --git a/DOC.md b/DOC.md new file mode 100644 index 0000000000000..10a0ab4e19407 --- /dev/null +++ b/DOC.md @@ -0,0 +1,1166 @@ +This PR adds support for lambdas with column capture and the `array_transform` function used to test the lambda implementation. Example usage: + +```sql +CREATE TABLE t as SELECT 2 as n; + +SELECT array_transform([2, 3], v -> v != t.n) from t; + +[false, true] + +-- arbitrally nested lambdas are also supported +SELECT array_transform([[[2, 3]]], m -> array_transform(m, l -> array_transform(l, v -> v*2))); + +[[[4, 6]]] +``` + +Some comments on code snippets of this doc show what value each struct, variant or field would hold after planning the first example above. Some literals are simplified pseudo code + +3 new `Expr` variants are added, `LambdaFunction`, owing a new trait `LambdaUDF`, which is like a `ScalarFunction`/`ScalarUDFImpl` with support for lambdas, `Lambda`, for the lambda body and it's parameters names, and `LambdaVariable`, which is like `Column` but for lambdas parameters. 
The reasoning why not using `Column` instead is later on this doc. + +Their logical representations: + +```rust +enum Expr { + LambdaFunction(LambdaFunction), // array_transform([2, 3], v -> v != t.n) + Lambda(Lambda), // v -> v != t.n + LambdaVariable(LambdaVariable), // v, of the lambda body: v != t.n + ... +} + +// array_transform([2, 3], v -> v != t.n) +struct LambdaFunction { + pub func: Arc, // global instance of array_transform + pub args: Vec, // [Expr::ScalarValue([2, 3]), Expr::Lambda(v -> v != n)] +} + +// v -> v != t.n +struct Lambda { + pub params: Vec, // ["v"] + pub body: Box, // v != n +} + +// v, of the lambda body: v != t.n +struct LambdaVariable { + pub name: String, // "v" + pub field: Option, // Some(Field::new("", DataType::Int32, false)) + pub spans: Spans, +} + +``` + +The example would be planned into a tree like this: + +``` +LambdaFunctionExpression + name: array_transform + children: + 1. ListExpression [2,3] + 2. LambdaExpression + parameters: ["v"] + body: + ComparisonExpression (!=) + left: + LambdaVariableExpression("v", Some(Field::new("", Int32, false))) + right: + ColumnExpression("t.n") +``` + +The physical counterparts definition: + +```rust + +struct LambdaFunctionExpr { + fun: Arc, // global instance of array_transform + name: String, // "array_transform" + args: Vec>, // [LiteralExpr([2, 3], LambdaExpr("v -> v != t.n"))] + return_field: FieldRef, // Field::new("", DataType::new_list(DataType::Boolean, false), false) + config_options: Arc, +} + + +struct LambdaExpr { + params: Vec, // ["v"] + body: Arc, // v -> v != t.n +} + +struct LambdaVariable { + name: String, // "v", of the lambda body: v != t.n + field: FieldRef, // Field::new("", DataType::Int32, false) + value: Option, // reasoning later on +} +``` + +Note: For those who primarly wants to check if this lambda implementation supports their usecase and don't want to spend much time here, it's okay to skip most collapsed blocks, as those serve mostly to help code 
reviewers, with the exception of `LambdaUDF` and the `array_transform` implementation of `LambdaUDF` relevant methods, collapsed due to their size + +
Physical planning implementation is trivial: + +```rust +fn create_physical_expr( + e: &Expr, + input_dfschema: &DFSchema, + execution_props: &ExecutionProps, +) -> Result> { + let input_schema = input_dfschema.as_arrow(); + + match e { + ... + Expr::LambdaFunction(LambdaFunction { func, args}) => { + let physical_args = + create_physical_exprs(args, input_dfschema, execution_props)?; + + Ok(Arc::new(LambdaFunctionExpr::try_new( + Arc::clone(func), + physical_args, + input_schema, + config_options: ... // irrelevant + )?)) + } + Expr::Lambda(Lambda { params, body }) => Ok(Arc::new(LambdaExpr::new( + params.clone(), + create_physical_expr(body, input_dfschema, execution_props)?, + ))), + Expr::LambdaVariable(LambdaVariable { + name, + field, + spans: _, + }) => lambda_variable( + name, + Arc::clone(field), + ), + } +} +``` + +
+
+ +The added `LambdaUDF` trait is almost a clone of `ScalarUDFImpl`, with the exception of: +1. `return_field_from_args` and `invoke_with_args`, where now `args.args` is a list of enums with two variants: `Value` or `Lambda` instead of a list of values +2. the addition of `lambdas_parameters`, which return a `Field` for each parameter supported for every lambda argument based on the `Field` of the non lambda arguments +3. the removal of `return_field` and the deprecated ones `is_nullable` and `display_name`. + +
LambdaUDF + +```rust + +trait LambdaUDF { + /// Returns a list of the same size as args where each value is the logic below applied to value at the correspondent position in args: + /// + /// If it's a value, return None + /// If it's a lambda, return the list of all parameters that that lambda supports + /// based on the Field of the non-lambda arguments + /// + /// Example for array_transform: + /// + /// `array_transform([2, 8], v -> v > 4)` + /// + /// let lambdas_parameters = array_transform.lambdas_parameters(&[ + /// ValueOrLambdaParameter::Value(Field::new("", DataType::new_list(DataType::Int32, false)))]), // the Field associated with the literal `[2, 8]` + /// ValueOrLambdaParameter::Lambda, // A lambda + /// ]?; + /// + /// assert_eq!( + /// lambdas_parameters, + /// vec![ + /// None, // it's a value, return None + /// // it's a lambda, return it's supported parameters, regardless of how many are actually used + /// Some(vec![ + /// Field::new("", DataType::Int32, false), // the value being transformed, + /// Field::new("", DataType::Int32, false), // the 1-based index being transformed, not used on the example above, but implementations doesn't need to care about it + /// ]) + /// ] + /// ) + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>>; + fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result; + fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result; + // ... 
omitted methods that are similar in ScalarUDFImpl +} + +pub enum ValueOrLambdaParameter { + /// A columnar value with the given field + Value(FieldRef), + /// A lambda + Lambda, +} + +/// Information about arguments passed to the function +/// +/// This structure contains metadata about how the function was called +/// such as the type of the arguments, any scalar arguments and if the +/// arguments can (ever) be null +/// +/// See [`LambdaUDF::return_field_from_args`] for more information +pub struct LambdaReturnFieldArgs<'a> { + /// The data types of the arguments to the function + /// + /// If argument `i` to the function is a lambda, it will be the field returned by the + /// lambda when executed with the arguments returned from `LambdaUDF::lambdas_parameters` + /// + /// For example, with `array_transform([1], v -> v == 5)` + /// this field will be `[ + // ValueOrLambdaField::Value(Field::new("", DataType::new_list(DataType::Int32, false), false)), + // ValueOrLambdaField::Lambda(Field::new("", DataType::Boolean, false)) + // ]` + pub arg_fields: &'a [ValueOrLambdaField], + /// Is argument `i` to the function a scalar (constant)? + /// + /// If the argument `i` is not a scalar, it will be None + /// + /// For example, if a function is called like `my_function(column_a, 5)` + /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` + pub scalar_arguments: &'a [Option<&'a ScalarValue>], +} + +/// A tagged FieldRef indicating whether it correspond the field of a value or the field of the output of a lambda argument +pub enum ValueOrLambdaField { + /// The FieldRef of a ColumnarValue argument + Value(FieldRef), + /// The return FieldRef of the lambda body when evaluated with the parameters from LambdaUDF::lambda_parameters + Lambda(FieldRef), +} + +/// Arguments passed to [`LambdaUDF::invoke_with_args`] when invoking a +/// lambda function. 
+pub struct LambdaFunctionArgs { + /// The evaluated arguments to the function + pub args: Vec, + /// Field associated with each arg, if it exists + pub arg_fields: Vec, + /// The number of rows in record batch being evaluated + pub number_rows: usize, + /// The return field of the lambda function returned (from `return_type` + /// or `return_field_from_args`) when creating the physical expression + /// from the logical expression + pub return_field: FieldRef, + /// The config options at execution time + pub config_options: Arc, +} + +/// A lambda argument to a LambdaFunction +pub struct LambdaFunctionLambdaArg { + /// The parameters defined in this lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be `vec![Field::new("v", DataType::Int32, true)]` + pub params: Vec, + /// The body of the lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be the physical expression of `-v` + pub body: Arc, + /// A RecordBatch containing at least the captured columns inside this lambda body, if any + /// Note that it may contain additional, non-specified columns, + /// but that's implementation detail and should not be relied upon + /// + /// For example, for `array_transform([2], v -> v + t.a + t.b)`, + /// this will be a `RecordBatch` with at least two columns, `t.a` and `t.b` + pub captures: Option, +} + +// An argument to a LambdaUDF +pub enum ValueOrLambda { + Value(ColumnarValue), + Lambda(LambdaFunctionLambdaArg), +} +``` + + +
+ +
array_transform lambdas_parameters implementation + +```rust +impl LambdaUDF for ArrayTransform { + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + // list is the field of [2, 3]: Field::new("", DataType::new_list(DataType::Int32, false), false) + let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda] = args + else { + return exec_err!( + "{} expects a value follewed by a lambda, got {:?}", + self.name(), + args + ); + }; + + // the field of [2, 3] inner values: Field::new("", DataType::Int32, false) + let (field, index_type) = match list.data_type() { + DataType::List(field) => (field, DataType::Int32), + DataType::LargeList(field) => (field, DataType::Int64), + DataType::FixedSizeList(field, _) => (field, DataType::Int32), + _ => return exec_err!("expected list, got {list}"), + }; + + // we don't need to omit the index in the case the lambda don't specify, e.g. array_transform([], v -> v*2), + // nor check whether the lambda contains more than two parameters, e.g. array_transform([], (v, i, j) -> v+i+j), + // as datafusion will do that for us + let value = Field::new("", field.data_type().clone(), field.is_nullable()) + .with_metadata(field.metadata().clone()); + let index = Field::new("", index_type, false); + + Ok(vec![None, Some(vec![value, index])]) + } +} +``` + +
+ +
array_transform return_field_from_args implementation + +```rust +impl LambdaUDF for ArrayTransform { + fn return_field_from_args( + &self, + args: datafusion_expr::LambdaReturnFieldArgs, + ) -> Result> { + // [ + // Field::new("", DataType::new_list(DataType::Int32, false), false), + // Field::new("", DataType::Boolean, false), + // ] + let [ValueOrLambdaField::Value(list), ValueOrLambdaField::Lambda(lambda)] = + take_function_args(self.name(), args.arg_fields)? + else { + return exec_err!( + "{} expects a value follewed by a lambda, got {:?}", + self.name(), + args + ); + }; + + // lambda is the return_field of the lambda body + // when evaluated with the parameters from lambdas_parameters + let field = Arc::new(Field::new( + Field::LIST_FIELD_DEFAULT_NAME, + lambda.data_type().clone(), + lambda.is_nullable(), + )); + + let return_type = match list.data_type() { + DataType::List(_) => DataType::List(field), + DataType::LargeList(_) => DataType::LargeList(field), + DataType::FixedSizeList(_, size) => DataType::FixedSizeList(field, *size), + other => plan_err!("expected list, got {other}"), + }; + + Ok(Arc::new(Field::new("", return_type, list.is_nullable()))) + } +} +``` + +
+ +
array_transform invoke_with_args implementation + + +```rust +impl LambdaUDF for ArrayTransform { + fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result { + let [list_value, lambda] = take_function_args(self.name(), &args.args)?; + + // list = [2, 3] + // lambda = LambdaFunctionLambdaArg { + // params: vec![Field::new("v", DataType::Int32, false)], + // body: PhysicalExpr("v != t.n"),// the physical expression of the lambda *body*, and not the lambda itself: this is not a LambdaExpr. + // captures: Some(record_batch!("t.n", Int32, [2])) + // } + let (ValueOrLambda::Value(list_value), ValueOrLambda::Lambda(lambda)) = + (list_value, lambda) + else { + return exec_err!( + "{} expects a value followed by a lambda, got {} and {}", + self.name(), + list_value, + lambda, + ); + }; + + let list_array = list_value.to_array(args.number_rows)?; + let list_values = match list_array.data_type() { + DataType::List(_) => list_array.as_list::().values(), + DataType::LargeList(_) => list_array.as_list::().values(), + DataType::FixedSizeList(_, _) => list_array.as_fixed_size_list().values(), + other => exec_err!("expected list, got {other}") + } + + // if any column got captured, we need to adjust it to the values arrays, + // duplicating values of list with mulitple values and removing values of empty lists + // list_indices is not cheap so is important to avoid it when no column is captured + let adjusted_captures = lambda + .captures + .as_ref() + //list_indices return the row_number for each sublist element: [[1, 2], [3], [4]] => [0,0,1,2], not included here + .map(|captures| take_record_batch(captures, &list_indices(&list_array)?)) + .transpose()? 
+ .unwrap_or_else(|| { + RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions::new().with_row_count(Some(list_values.len())), + ) + .unwrap() + }); + + // by using closures, bind_lambda_variables can evaluate only the needed ones avoiding unnecessary computations + let values_param = || Ok(Arc::clone(list_values)); + //elements_indices return the index of each element within its sublist: [[5, 3], [7, 1, 1]] => [1, 2, 1, 2, 3], not included here + let indices_param = || elements_indices(&list_array); + + let binded_body = bind_lambda_variables( + Arc::clone(&lambda.body), + &lambda.params, + &[&values_param, &indices_param], + )?; + + // call the transforming expression with the record batch + let transformed_values = binded_body + .evaluate(&adjusted_captures)? + .into_array(list_values.len())?; + + let field = match args.return_field.data_type() { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => Arc::clone(field), + _ => { + return exec_err!( + "{} expected ScalarFunctionArgs.return_field to be a list, got {}", + self.name(), + args.return_field + ) + } + }; + + let transformed_list = match list_array.data_type() { + DataType::List(_) => { + let list = list_array.as_list(); + + Arc::new(ListArray::new( + field, + list.offsets().clone(), + transformed_values, + list.nulls().cloned(), + )) as ArrayRef + } + DataType::LargeList(_) => { + let large_list = list_array.as_list(); + + Arc::new(LargeListArray::new( + field, + large_list.offsets().clone(), + transformed_values, + large_list.nulls().cloned(), + )) + } + DataType::FixedSizeList(_, value_length) => { + Arc::new(FixedSizeListArray::new( + field, + *value_length, + transformed_values, + list_array.as_fixed_size_list().nulls().cloned(), + )) + } + other => exec_err!("expected list, got {other}")?, + }; + + Ok(ColumnarValue::Array(transformed_list)) + } +} +``` + +
+ +
How relevant LambdaUDF methods would be called and what they would return during planning and evaluation of the example + + +```rust +// this is called at sql planning +let lambdas_parameters = lambda_udf.lambdas_parameters(&[ + ValueOrLambdaParameter::Value(Field::new("", DataType::new_list(DataType::Int32, false), false)), // the Field of the [2, 3] literal + ValueOrLambdaParameter::Lambda, // A unspecified lambda. On the example, v -> v != t.n +])?; + +assert_eq!( + lambdas_parameters, + vec![ + // the [2, 3] argument, not a lambda so no parameters + None, + // the parameters that *can* be declared on the lambda, and not only + // those actually declared: the implementation doesn't need to care + // about it + Some(vec![ + Field::new("", DataType::Int32, false), // the list inner value + Field::new("", DataType::Int32, false), // the 1-based index of the element being transformed + ])] +); + + + +// this is called every time ExprSchemable is called on a LambdaFunction +let return_field = array_transform.return_field_from_args(&LambdaReturnFieldArgs { + arg_fields: &[ + ValueOrLambdaField::Value(Field::new("", DataType::new_list(DataType::Int32, false), false)), + ValueOrLambdaField::Lambda(Field::new("", DataType::Boolean, false)), // the return_field of the expression "v != t.n" when "v" is of the type returned in lambdas_parameters + ], + scalar_arguments // irrelevant +})?; + +assert_eq!(return_field, Field::new("", DataType::new_list(DataType::Boolean, false), false)); + + + +let value = array_transform.evaluate(&LambdaFunctionArgs { + args: vec![ + ValueOrLambda::Value(List([2, 3])), + ValueOrLambda::Lambda(LambdaFunctionLambdaArg { + params: vec![Field::new("v", DataType::Int32, false)], + body: PhysicalExpr("v != t.n"),// the physical expression of the lambda *body*, and not the lambda itself: this is not a LambdaExpr. 
+ captures: Some(record_batch!("t.n", Int32, [2])) + }), + ], + arg_fields, // same as above + number_rows: 1, + return_field, // same as above + config_options, // irrelevant +})?; + +assert_eq!(value, BooleanArray::from([false, true])) +``` + +
+
+
+
+A pair LambdaUDF/LambdaUDFImpl like ScalarFunction was not used because those exist only [to maintain backwards compatibility with the older API](https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html#api-note) #8045
+
+LambdaFunction invocation:
+
+Instead of evaluating all its arguments as ScalarFunction, LambdaFunction does the following:
+
+1. If it's a non-lambda argument, evaluate as usual, and provide the resulting `ColumnarValue` to `LambdaUDF::evaluate` as a `ValueOrLambda::Value`
+2. If it's a lambda, construct a `LambdaFunctionLambdaArg` containing the lambda body physical expression and a record batch containing any captured columns as a `ValueOrLambda::Lambda` and provide it to `LambdaUDF::evaluate`. To avoid costly copies of uncaptured columns, we swap them with a `NullArray` while keeping the number of columns on the batch the same so captured column indices are kept stable across the whole tree. The recent #18329 instead projects-out uncaptured columns and rewrites the expr adjusting column indexes. If that is preferable we can generalize that implementation and use it here too.
+
LambdaFunction evalution + +```rust + +impl PhysicalExpr for LambdaFunctionExpr { + fn evaluate(&self, batch: &RecordBatch) -> Result { + let args = self.args + .map(|arg| { + match arg.as_any().downcast_ref::() { + Some(lambda) => { + // helper method that returns the indices of the captured columns. In the example, the only column available (index 0) is captured, so this would be HashSet(0) + let captures = lambda.captures(); + + let captures = if !captures.is_empty() { + let (fields, columns): (Vec<_>, _) = std::iter::zip( + batch.schema_ref().fields(), + batch.columns(), + ) + .enumerate() + .map(|(column_index, (field, column))| { + if captures.contains(&column_index) { + (Arc::clone(field), Arc::clone(column)) + } else { + ( + Arc::new(Field::new( + field.name(), + DataType::Null, + false, + )), + Arc::new(NullArray::new(column.len())) as _, + ) + } + }) + .unzip(); + + let schema = Arc::new(Schema::new(fields)); + + Some(RecordBatch::try_new(schema, columns)?) + } else { + None + }; + + Ok(ValueOrLambda::Lambda(LambdaFunctionLambdaArg { + params, // irrelevant, + body: Arc::clone(lambda.body()), // use the lambda body and not the lambda itself + captures, + })) + } + None => Ok(ValueOrLambda::Value(arg.evaluate(batch)?)), + } + }) + .collect::>>()?; + + // evaluate the function + let output = self.fun.invoke_with_args(LambdaFunctionArgs { + args, + arg_fields, // irrelevant + number_rows: batch.num_rows(), + return_field: Arc::clone(&self.return_field), + config_options: Arc::clone(&self.config_options), + })?; + + Ok(output) + } +} + +``` + +
+
+
+Why `LambdaVariable` and not `Column`:
+
+Existing tree traversals that operate on columns would break if some column nodes referred to a lambda parameter and not a real column. In the example query, projection pushdown would try to push the lambda parameter "v", which won't exist in table "t".
+
+Example code of another traversal that would break:
+
+```rust
+fn minimize_join_filter(expr: Arc<dyn PhysicalExpr>, ...) -> JoinFilter {
+    let mut used_columns = HashSet::new();
+    expr.apply(|expr| {
+        if let Some(col) = expr.as_any().downcast_ref::<Column>() {
+            // if this is a lambda column, this function will break
+            used_columns.insert(col.index());
+        }
+        Ok(TreeNodeRecursion::Continue)
+    });
+    ...
+}
+```
+
+Furthermore, the implementation of `ExprSchemable` and `PhysicalExpr::return_field` for `Column` expects that the schema it receives as an argument contains an entry for its name, which is not the case for lambda parameters.
+
+By including a `FieldRef` on `LambdaVariable` that should be resolved either during construction time, as in the sql planner, or later by an `AnalyzerRule`, `ExprSchemable` and `PhysicalExpr::return_field` simply return its own Field:
+
LambdaVariable ExprSchemable and PhysicalExpr::return_field implementation + +```rust +impl ExprSchemable for Expr { + fn to_field( + &self, + schema: &dyn ExprSchema, + ) -> Result<(Option, Arc)> { + let (relation, schema_name) = self.qualified_name(); + let field = match self { + Expr::LambdaVariable(l) => Ok(Arc::clone(&l.field.ok_or_else(|| plan_err!("Unresolved LambdaVariable {}", l.name)))), + ... + }?; + + Ok(( + relation, + Arc::new(field.as_ref().clone().with_name(schema_name)), + )) + } + ... +} + +impl PhysicalExpr for LambdaVariable { + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.field)) + } + ... +} +``` + +
+
+
+For reference, [Spark](https://github.com/apache/spark/blob/8b68a172d34d2ed9bd0a2deefcae1840a78143b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L77) and [Substrait](https://substrait.io/expressions/lambda_expressions/#parameter-references) also use a specialized node instead of a regular column
+
+There are also discussions on making every expr own its type: #18845, #12604
+
Possible fixes discarded due to complexity, requiring downstream changes and implementation size: + +1. Add a new set of TreeNode methods that provides the set of lambdas parameters names seen during the traversal, so column nodes can be tested if they refer to a regular column or to a lambda parameter. Any downstream user that wants to support lambdas would need use those methods instead of the existing ones. This also would add 1k+ lines to the PR. + +```rust +impl Expr { + pub fn transform_with_lambdas_params< + F: FnMut(Self, &HashSet) -> Result>, + >( + self, + mut f: F, + ) -> Result> {} +} +``` + +How minimize_join_filter would looks like: + + +```rust +fn minimize_join_filter(expr: Arc, ...) -> JoinFilter { + let mut used_columns = HashSet::new(); + expr.apply_with_lambdas_params(|expr, lambdas_params| { + if let Some(col) = expr.as_any().downcast_ref::() { + // dont include lambdas parameters + if !lambdas_params.contains(col.name()) { + used_columns.insert(col.index()); + } + } + Ok(TreeNodeRecursion::Continue) + }) + ... +} +``` + +2. Add a flag to the Column node indicating if it refers to a lambda parameter. Still requires checking for it on existing tree traversals that works on Columns (30+) and also downstream. + +```rust +//logical +struct Column { + pub relation: Option, + pub name: String, + pub spans: Spans, + pub is_lambda_parameter: bool, +} + +//physical +struct Column { + name: String, + index: usize, + is_lambda_parameter: bool, +} +``` + + +How minimize_join_filter would look like: + +```rust +fn minimize_join_filter(expr: Arc, ...) -> JoinFilter { + let mut used_columns = HashSet::new(); + expr.apply(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + // dont include lambdas parameters + if !col.is_lambda_parameter { + used_columns.insert(col.index()); + } + } + Ok(TreeNodeRecursion::Continue) + }) + ... +} +``` + + +1. 
Add a new set of TreeNode methods that provides a schema that includes the lambdas parameters for the scope of the node being visited/transformed: + +```rust +impl Expr { + pub fn transform_with_schema< + F: FnMut(Self, &DFSchema) -> Result>, + >( + self, + schema: &DFSchema, + f: F, + ) -> Result> { ... } + ... other methods +} +``` + +For any given LambdaFunction found during the traversal, a new schema is created for each lambda argument that contains it's parameter, returned from LambdaUDF::lambdas_parameters +How it would look like: + +```rust + +pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> { + let mut has_placeholder = false; + // Provide the schema as the first argument. + // Transforming closure receive an adjusted_schema as argument + self.transform_with_schema(schema, |mut expr, adjusted_schema| { + match &mut expr { + // Default to assuming the arguments are the same type + Expr::BinaryExpr(BinaryExpr { left, op: _, right }) => { + // use adjusted_schema and not schema. Those expressions may contain + // columns referring to a lambda parameter, which Field would only be + // available in adjusted_schema and not in schema + rewrite_placeholder(left.as_mut(), right.as_ref(), adjusted_schema)?; + rewrite_placeholder(right.as_mut(), left.as_ref(), adjusted_schema)?; + } + .... + +``` + +2. Make available trought LogicalPlan and ExecutionPlan nodes a schema that includes all lambdas parameters from all expressions owned by the node, and use this schema for tree traversals. 
For nodes which won't own any expression, the regular schema can be returned + + +```rust +impl LogicalPlan { + fn lambda_extended_schema(&self) -> &DFSchema; +} + +trait ExecutionPlan { + fn lambda_extended_schema(&self) -> &DFSchema; +} + +//usage +impl LogicalPlan { + pub fn replace_params_with_values( + self, + param_values: &ParamValues, + ) -> Result { + self.transform_up_with_subqueries(|plan| { + // use plan.lambda_extended_schema() containing lambdas parameters + // instead of plan.schema() which wont + let lambda_extended_schema = Arc::clone(plan.lambda_extended_schema()); + let name_preserver = NamePreserver::new(&plan); + plan.map_expressions(|e| { + // if this expression is child of lambda and contain columns referring it's parameters + // the lambda_extended_schema already contain them + let (e, has_placeholder) = e.infer_placeholder_types(&lambda_extended_schema)?; + .... + +``` +
+
+
+`LambdaVariable` evaluation, current implementation:
+
+The physical `LambdaVariable` contains an optional `ColumnarValue` that must be bound for each batch before evaluation with the helper function `bind_lambda_variables`, which rewrites the whole lambda body, binding any variable of the tree.
+
LambdaVariable::evaluate + +```rust +impl PhysicalExpr for LambdaVariable { + fn evaluate(&self, _batch: &RecordBatch) -> Result { + self.value.clone().ok_or_else(|| exec_datafusion_err!("Physical LambdaVariable {} unbinded value", self.name)) + } +} +``` + +
+
+
+Unbound:
+```
+LambdaExpression
+    parameters: ["v"]
+    body:
+        ComparisonExpression(!=)
+            left:
+                LambdaVariableExpression("v", Field::new("", Int32, false), None)
+            right:
+                ColumnExpression("n")
+```
+
+After binding:
+
+```
+LambdaExpression
+    parameters: ["v"]
+    body:
+        ComparisonExpression(!=)
+            left:
+                LambdaVariableExpression("v", Field::new("", Int32, false), Some([2, 3]))
+            right:
+                ColumnExpression("n")
+```
+
+Alternative:
+
+Make the `LambdaVariable` evaluate its value from the batch passed to `PhysicalExpr::evaluate` as a regular column. For that, instead of binding the body, the `LambdaUDF` implementation would merge the captured batch of a lambda with the values of its parameters. So that this happens via an index as a regular column, the schema used to plan the physical `LambdaVariable` must contain the lambda parameters. This would be the only place during planning that a schema would contain those parameters. Otherwise it can only get the value from the batch via name instead of index
+
+1. Add an index to LambdaVariable, similar to Column, and remove the optional value.
+
+```rust
+struct LambdaVariable {
+    name: String, // "v", of the lambda body: v != t.n
+    field: FieldRef, // Field::new("", DataType::Int32, false)
+    index: usize, // 1
+}
+```
+
+2. Insert the lambda parameters only at the Schema used to do the physical planning, to compute the index of a LambdaVariable
+
how physical planning would look like + +```rust +fn create_physical_expr( + e: &Expr, + input_dfschema: &DFSchema, + execution_props: &ExecutionProps, +) -> Result> { + let input_schema = input_dfschema.as_arrow(); + + match e { + ... + Expr::LambdaFunction(LambdaFunction { func, args}) => { + let args_metadata = args.iter() + .map(|arg| if arg.is::() { + Ok(ValueOrLambdaParameter::Lambda) + } else { + Ok(ValueOrLambdaParameter::Value(arg.to_field(input_dfschema)?)) + }) + .collect()?; + + let lambdas_parameters = func.lambdas_parameters(&args_metadata)?; + + let physical_args = std::iter::zip(args, lambdas_parameters) + .map(|(arg, lambda_parameters)| { + match (arg.downcast_ref::(), lambda_parameters) { + (Some(lambda), Some(lambda_parameters)) => { + let extended_dfschema = merge_schema_and_parameters(input_dfschame, lambda_parameters)?; + + create_physical_expr(body, extended_dfschema, execution_props) + } + (None, None) => create_physical_expr(arg, input_dfschema, execution_props), + (Some(_), None) => plan_err!("lambdas_parameters returned None for a lambda") + (None, Some(_)) => plan_err!("lambdas_parameters returned Some for a non lambda") + } + }) + .collect()?; + + Ok(Arc::new(LambdaFunctionExpr::try_new( + Arc::clone(func), + physical_args, + input_schema, + config_options: ... // irrelevant + )?)) + } + } +} +``` + +
+
+ +3. Insert the lambda parameters values into the RecordBatch during the evaluation phase: the LambdaUDF, instead of binding the lambda body variables, inserts it's parameters on the captured RecordBatch it receives on LambdaFunctionLambdaArg. + +How ArrayTransform::invoke_with_args would look like: + +```rust + ... + let values_param = || Ok(Arc::clone(list_values)); + let indices_param = || elements_indices(&list_array); + + let merged_batch = merge_captures_with_params( + adjusted_captures, + &lambda.params, + &[&values_param, &indices_param], + )?; + + // call the transforming expression with the record batch + let transformed_values = lambda.body + .evaluate(&merged_batch)? + .into_array(list_values.len())?; + + ... +``` + +
+
+Why is `LambdaVariable`'s `Field` an `Option`?
+
+So expr_api users can construct a LambdaVariable just by using its name, without having to set its field. An `AnalyzerRule` will then set the `LambdaVariable` field based on the returned values from `LambdaUDF::lambdas_parameters` of any `LambdaFunction` it finds while traversing down an expr tree. We may include that rule on the default rules list for when the plan/expression tree is transformed by another rule in a way that changes the types of non-lambda arguments of a lambda function, as it may change the types of its lambda parameters, which would render `LambdaVariable` fields out of sync, as the rule would fix it. Or to not increase planning time we don't include it by default and instruct `expr_api` users to add it manually if needed
+
+
+
+```rust
+array_transform(
+    col("my_array"),
+    lambda(
+        vec!["current_value"],
+        2 * lambda_variable("current_value")
+    )
+)
+
+//instead of
+
+array_transform(
+    col("my_array"),
+    lambda(
+        vec!["current_value"],
+        2 * lambda_variable("current_value", Field::new("", DataType::Int32, false))
+    )
+)
+```
+
+
+Why set `LambdaVariable` field during sql planning if it's optional and can be set later via an `AnalyzerRule`?
+
+Some parts of sql planning check the type/nullability of the already planned child expressions of the expr it's planning, and would error if doing so on an unresolved `LambdaVariable`
+Take as example this expression: `array_transform([[0, 1]], v -> v[1])`. `FieldAccess` `v[1]` planning is handled by the `ExprPlanner` `FieldAccessPlanner`, which checks the datatype of `v`, a lambda variable, whose `ExprSchemable` implementation depends on its field being resolved, and not on the `PlannerContext` schema, requiring the sql planner to plan `LambdaVariables` with a resolved field
+
+
FieldAccessPlanner + +```rust +pub struct FieldAccessPlanner; + +impl ExprPlanner for FieldAccessPlanner { + fn plan_field_access( + &self, + expr: RawFieldAccessExpr, // "v[1]" + schema: &DFSchema, + ) -> Result> { + // { "v", "[1]" } + let RawFieldAccessExpr { expr, field_access } = expr; + + match field_access { + ... + // expr[idx] ==> array_element(expr, idx) + GetFieldAccess::ListIndex { key: index } => { + match expr { + ... + // ExprSchemable::get_type called + _ if matches!(expr.get_type(schema)?, DataType::Map(_, _)) => { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf( + get_field_inner(), + vec![expr, *index], + ), + ))) + } + } + } + } + } +} +``` + +
+
+
+ Therefore we can't plan all arguments on a single pass, and must first plan the non-lambda arguments, collect their types and nullability, pass them to `LambdaUDF::lambdas_parameters`, which will derive the type of its lambda parameters based on the type of its non-lambda argument, and return it to the planner, which, for each unplanned lambda argument, will create a new `PlannerContext` via `with_lambda_parameters`, which contains a mapping of lambda parameter names to their types. Then, when planning an `ast::Identifier`, it first checks whether a lambda parameter with the given name exists, and if so, plans it into an `Expr::LambdaVariable` with a resolved field, otherwise plans it into a regular `Expr::Column`.
+
+
sql planning + + +```rust +struct PlannerContext { + /// The parameters of all lambdas seen so far + lambdas_parameters: HashMap, + // ... omitted fields +} + +impl PlannerContext { + pub fn with_lambda_parameters( + mut self, + arguments: impl IntoIterator, + ) -> Self { + self.lambdas_parameters + .extend(arguments.into_iter().map(|f| (f.name().clone(), f))); + + self + } +} + +// copied from sqlparser +struct LambdaFunction { + pub params: OneOrManyWithParens, // One("v") + pub body: Box, // v != t.n +} + +// copied from sqlparser +enum OneOrManyWithParens { + One(T), // "v" + Many(Vec), +} + +/// the planning would happens as the following: + +enum ExprOrLambda { + Expr(Expr), // planned [2, 3] + Lambda(ast::LambdaFunction), // unplanned v -> v != t.n +} + +impl SqlToRel { + // example function, won't exist + fn plan_array_transform(&self, array_transform: Arc, args: Vec, schema: &DFSchema, planner_context: &mut PlannerContext) -> Result { + let args = args.into_iter() + .map(|arg| match arg { + ast::Expr::LambdaFunction(l) => Ok(ExprOrLambda::Lambda(l)),//skip planning until we plan non lambda args + arg => Ok(ExprOrLambda::Expr( + self.sql_fn_arg_to_logical_expr_with_name( + arg, + schema, + planner_context, + )?, + )) + }) + .collect::>>()?; + + let args_metadata = args.iter() + .map(|arg| match arg { + Expr(expr) => Ok(ValueOrLambda::Value(expr.to_field(schema)?)), + Lambda(_) => Ok(ValueOrLambda::Lambda), + }) + .collect::>>()?; + + let lambdas_parameters = array_transform.lambdas_parameters(&args_metadata)?; + + let args = std::iter::zip(args, lambdas_parameters) + .map(|(arg, lambdas_parameters)| match (arg, lambdas_parameters) { + (ExprOrLambda::Expr(planned_expr), None) => Ok(planned_expr), + (ExprOrLambda::Lambda(unplanned_lambda), Some(lambda_parameters)) => { + let params = + unplanned_lambda.params + .iter() + .map(|p| p.value.clone()) + .collect(); + + let lambda_parameters = lambda_params + .into_iter() + .zip(¶ms) + .map(|(field, name)| 
Arc::new(field.with_name(name))); + + let mut planner_context = planner_context + .clone() + .with_lambda_parameters(lambda_parameters); + + Ok(( + Expr::Lambda(Lambda { + params, + body: Box::new(self.sql_expr_to_logical_expr( + *lambda.body, + schema, + &mut planner_context, + )?), + }), + None, + )) + } + (ExprOrLambda::Expr(planned_expr), Some(lambda_parameters)) => plan_err!("lambdas_parameters returned Some for a value"), + (ExprOrLambda::Lambda(unplanned_lambda), None) => plan_err!("lambdas_parameters returned None for a lambda"), + }) + .collect::>>()?; + + Ok(Expr::LambdaFunction(LambdaFunction { + func: array_transform, + args, + })) + } + + fn sql_identifier_to_expr( + &self, + id: ast::Ident, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + // simplified implementation + if let Some(field) = planner_context.lambdas_parameters.get(id) { + Ok(Expr::LambdaVariable(LambdaVariable { + name: id, // "v" + field, // Field::new("", DataType::Int32, false) + })) + } else { + Ok(Expr::Column(Column::new(id))) + } + } +} + +``` + +
+
+ +`LambdaFunction` `Signature` is non functional + +Currenty, `LambdaUDF::signature` returns the same `Signature` as `ScalarUDF`, but it's `type_signature` field is never used, as most variants of the `TypeSignature` enum aren't applicable to a lambda, and no type coercion is applied on it's arguments, being currently a implementation responsability. We should either add lambda compatible variants to the `TypeSignature` enum, create a new `LambdaTypeSignature` and `LambdaSignature`, or support no automatic type coercion at all on lambda functions. From 83dfbdd13e98be8a00f90becc26f5b9f277eebff Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 8 Mar 2026 16:56:46 -0300 Subject: [PATCH 08/47] add lambda note in substrait consumer --- .../substrait/src/logical_plan/consumer/expr/scalar_function.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs index f80cf43eb81eb..062c1ac03110c 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs @@ -30,6 +30,7 @@ pub async fn from_scalar_function( f: &ScalarFunction, input_schema: &DFSchema, ) -> Result { + //TODO: handle lambda functions, as they are also encoded as scalar functions let Some(fn_signature) = consumer .get_extensions() .functions From 34137e15ca87bd71a9a0b2bd386a823f81fafad8 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 8 Mar 2026 17:32:41 -0300 Subject: [PATCH 09/47] add LambdaSignature --- datafusion/expr/src/lib.rs | 4 +- datafusion/expr/src/udlf.rs | 46 +++++++++++++------ .../functions-nested/src/array_transform.rs | 8 ++-- .../physical-expr/src/lambda_function.rs | 14 ++---- datafusion/proto/src/bytes/mod.rs | 10 ++-- 5 files changed, 49 insertions(+), 33 deletions(-) diff --git 
a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index e6dfd9fc1483b..4fc6c738ea06b 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -120,8 +120,8 @@ pub use udaf::{ }; pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; pub use udlf::{ - LambdaFunctionArgs, LambdaFunctionLambdaArg, LambdaReturnFieldArgs, LambdaUDF, - ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, + LambdaFunctionArgs, LambdaFunctionLambdaArg, LambdaReturnFieldArgs, LambdaSignature, + LambdaUDF, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, }; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index 84f9494d2edd8..883a083ca9c9b 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -27,7 +27,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::{Result, ScalarValue, not_impl_err}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; -use datafusion_expr_common::signature::Signature; +use datafusion_expr_common::signature::{Volatility}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::any::Any; use std::cmp::Ordering; @@ -35,6 +35,28 @@ use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::sync::Arc; +/// Provides information necessary for calling a lambda function. +/// +/// - [`Volatility`] defines how the output of the function changes with the input. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub struct LambdaSignature { + /// The volatility of the function. See [Volatility] for more information. + pub volatility: Volatility, + /// Optional parameter names for the function arguments. + /// + /// If provided, enables named argument notation for function calls (e.g., `func(a => 1, b => 2)`). 
+ /// + /// Defaults to `None`, meaning only positional arguments are supported. + pub parameter_names: Option>, +} + +impl LambdaSignature { + /// Creates a new Signature from a given volatility. + pub fn new(volatility: Volatility) -> LambdaSignature { + LambdaSignature { volatility, parameter_names: None } + } +} + impl PartialEq for dyn LambdaUDF { fn eq(&self, other: &Self) -> bool { self.dyn_eq(other.as_any()) @@ -190,19 +212,19 @@ pub enum ValueOrLambdaField { /// # use std::sync::LazyLock; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, ColumnarValue, Documentation, LambdaFunctionArgs, Signature, Volatility}; +/// # use datafusion_expr::{col, ColumnarValue, Documentation, LambdaFunctionArgs, LambdaSignature, Volatility}; /// # use datafusion_expr::LambdaUDF; /// # use datafusion_expr::lambda_doc_sections::DOC_SECTION_MATH; /// /// This struct for a simple UDF that adds one to an int32 /// #[derive(Debug, PartialEq, Eq, Hash)] /// struct AddOne { -/// signature: Signature, +/// signature: LambdaSignature, /// } /// /// impl AddOne { /// fn new() -> Self { /// Self { -/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable), +/// signature: LambdaSignature::new(Volatility::Immutable), /// } /// } /// } @@ -221,7 +243,7 @@ pub enum ValueOrLambdaField { /// impl LambdaUDF for AddOne { /// fn as_any(&self) -> &dyn Any { self } /// fn name(&self) -> &str { "add_one" } -/// fn signature(&self) -> &Signature { &self.signature } +/// fn signature(&self) -> &LambdaSignature { &self.signature } /// fn return_type(&self, args: &[DataType]) -> Result { /// if !matches!(args.get(0), Some(&DataType::Int32)) { /// return plan_err!("add_one only accepts Int32 arguments"); @@ -274,14 +296,14 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { )) } - /// Returns a [`Signature`] describing the argument types for which this + /// Returns a 
[`LambdaSignature`] describing the argument types for which this /// function has an implementation, and the function's [`Volatility`]. /// - /// See [`Signature`] for more details on argument type handling + /// See [`LambdaSignature`] for more details on argument type handling /// and [`Self::return_type`] for computing the return type. /// /// [`Volatility`]: datafusion_expr_common::signature::Volatility - fn signature(&self) -> &Signature; + fn signature(&self) -> &LambdaSignature; /// Create a new instance of this function with updated configuration. /// @@ -534,8 +556,6 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// See the [type coercion module](crate::type_coercion) /// documentation for more details on type coercion /// - /// [`TypeSignature`]: crate::TypeSignature - /// /// For example, if your function requires a floating point arguments, but the user calls /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]` /// to ensure the argument is converted to `1::double` @@ -578,7 +598,7 @@ mod tests { struct TestLambdaUDF { name: &'static str, field: &'static str, - signature: Signature, + signature: LambdaSignature, } impl LambdaUDF for TestLambdaUDF { fn as_any(&self) -> &dyn Any { @@ -589,7 +609,7 @@ mod tests { self.name } - fn signature(&self) -> &Signature { + fn signature(&self) -> &LambdaSignature { &self.signature } @@ -637,7 +657,7 @@ mod tests { Arc::new(TestLambdaUDF { name, field: parameter, - signature: Signature::any(1, Volatility::Immutable), + signature: LambdaSignature::new(Volatility::Immutable), }) } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 746cfe6b62be5..b660f83b3ba2b 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -34,7 +34,7 @@ use datafusion_common::{ HashMap, Result, }; use datafusion_expr::{ - 
ColumnarValue, Documentation, LambdaFunctionArgs, LambdaUDF, Signature, + ColumnarValue, Documentation, LambdaFunctionArgs, LambdaSignature, LambdaUDF, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, }; use datafusion_macros::user_doc; @@ -70,7 +70,7 @@ use std::{any::Any, sync::Arc}; )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayTransform { - signature: Signature, + signature: LambdaSignature, aliases: Vec, } @@ -83,7 +83,7 @@ impl Default for ArrayTransform { impl ArrayTransform { pub fn new() -> Self { Self { - signature: Signature::any(2, Volatility::Immutable), + signature: LambdaSignature::new(Volatility::Immutable), aliases: vec![String::from("list_transform")], } } @@ -102,7 +102,7 @@ impl LambdaUDF for ArrayTransform { &self.aliases } - fn signature(&self) -> &Signature { + fn signature(&self) -> &LambdaSignature { &self.signature } diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs index 0b0c33cdd6be9..5ba354f4e8a4f 100644 --- a/datafusion/physical-expr/src/lambda_function.rs +++ b/datafusion/physical-expr/src/lambda_function.rs @@ -449,7 +449,7 @@ mod tests { use crate::LambdaFunctionExpr; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; - use datafusion_expr::{LambdaFunctionArgs, LambdaUDF, Signature}; + use datafusion_expr::{LambdaFunctionArgs, LambdaUDF, LambdaSignature}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr_common::physical_expr::is_volatile; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -457,7 +457,7 @@ mod tests { /// Test helper to create a mock UDF with a specific volatility #[derive(Debug, PartialEq, Eq, Hash)] struct MockLambdaUDF { - signature: Signature, + signature: LambdaSignature, } impl LambdaUDF for MockLambdaUDF { @@ -469,7 +469,7 @@ mod tests { "mock_function" } - fn signature(&self) -> &Signature { + fn signature(&self) -> &LambdaSignature 
{ &self.signature } @@ -489,16 +489,12 @@ mod tests { fn test_lambda_function_volatile_node() { // Create a volatile UDF let volatile_udf = Arc::new(MockLambdaUDF { - signature: Signature::uniform( - 1, - vec![DataType::Float32], - Volatility::Volatile, - ), + signature: LambdaSignature::new(Volatility::Volatile), }); // Create a non-volatile UDF let stable_udf = Arc::new(MockLambdaUDF { - signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + signature: LambdaSignature::new(Volatility::Stable), }); let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index e421ea11c4b2d..732831b9464c7 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -27,8 +27,8 @@ use crate::protobuf; use datafusion_common::{not_impl_err, plan_datafusion_err, Result}; use datafusion_execution::TaskContext; use datafusion_expr::{ - create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LambdaUDF, LogicalPlan, - Signature, Volatility, WindowUDF, + create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LambdaUDF, LambdaSignature, LogicalPlan, + Volatility, WindowUDF, }; use prost::{ bytes::{Bytes, BytesMut}, @@ -196,7 +196,7 @@ impl Serializeable for Expr { #[derive(Debug, PartialEq, Eq, Hash)] struct MockLambdaUDF { name: String, - signature: Signature, + signature: LambdaSignature, } impl LambdaUDF for MockLambdaUDF { @@ -208,7 +208,7 @@ impl Serializeable for Expr { &self.name } - fn signature(&self) -> &Signature { + fn signature(&self) -> &LambdaSignature { &self.signature } @@ -229,7 +229,7 @@ impl Serializeable for Expr { Ok(Arc::new(MockLambdaUDF { name: name.to_string(), - signature: Signature::variadic_any(Volatility::Immutable), + signature: LambdaSignature::new(Volatility::Immutable), })) } } From 3ded1154a4aa7adc167ec033622308c2baa8d524 Mon Sep 17 00:00:00 2001 From: gstvg 
<28798827+gstvg@users.noreply.github.com> Date: Sun, 8 Mar 2026 20:05:16 -0300 Subject: [PATCH 10/47] improve lambda type coercion --- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udlf.rs | 87 ++++++++++++++++--- .../functions-nested/src/array_transform.rs | 4 +- .../optimizer/src/analyzer/type_coercion.rs | 64 +++++++++++++- .../physical-expr/src/lambda_function.rs | 4 +- datafusion/proto/src/bytes/mod.rs | 2 +- 6 files changed, 143 insertions(+), 20 deletions(-) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 4fc6c738ea06b..b03bab622a357 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -121,7 +121,7 @@ pub use udaf::{ pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; pub use udlf::{ LambdaFunctionArgs, LambdaFunctionLambdaArg, LambdaReturnFieldArgs, LambdaSignature, - LambdaUDF, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, + LambdaTypeSignature, LambdaUDF, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, }; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index 883a083ca9c9b..b2191e0883d30 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -35,11 +35,44 @@ use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::sync::Arc; +/// The types of arguments for which a function has implementations. +/// +/// [`LambdaTypeSignature`] **DOES NOT** define the types that a user query could call the +/// function with. DataFusion will automatically coerce (cast) argument types to +/// one of the supported function signatures, if possible. +/// +/// # Overview +/// Functions typically provide implementations for a small number of different +/// argument [`DataType`]s, rather than all possible combinations. 
If a user +/// calls a function with arguments that do not match any of the declared types, +/// DataFusion will attempt to automatically coerce (add casts to) function +/// arguments so they match the [`LambdaTypeSignature`]. See the [`type_coercion`] module +/// for more details +/// +/// [`type_coercion`]: crate::type_coercion +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum LambdaTypeSignature { + /// The acceptable signature and coercions rules are special for this + /// function. + /// + /// If this signature is specified, + /// DataFusion will call [`LambdaUDF::coerce_value_types`] to prepare argument types. + /// + /// [`LambdaUDF::coerce_value_types`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.LambdaUDF.html#method.coerce_value_types + UserDefined, + /// One or more lambdas or arguments with arbitrary types + VariadicAny, + /// The specified number of lambdas or arguments with arbitrary types. + Any(usize), +} + /// Provides information necessary for calling a lambda function. /// /// - [`Volatility`] defines how the output of the function changes with the input. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct LambdaSignature { + /// The data types that the function accepts. See [LambdaTypeSignature] for more information. + pub type_signature: LambdaTypeSignature, /// The volatility of the function. See [Volatility] for more information. pub volatility: Volatility, /// Optional parameter names for the function arguments. @@ -51,9 +84,40 @@ pub struct LambdaSignature { } impl LambdaSignature { - /// Creates a new Signature from a given volatility. - pub fn new(volatility: Volatility) -> LambdaSignature { - LambdaSignature { volatility, parameter_names: None } + /// Creates a new LambdaSignature from a given type signature and volatility. 
+ pub fn new(type_signature: LambdaTypeSignature, volatility: Volatility) -> Self { + LambdaSignature { + type_signature, + volatility, + parameter_names: None, + } + } + + /// User-defined coercion rules for the function. + pub fn user_defined(volatility: Volatility) -> Self { + Self { + type_signature: LambdaTypeSignature::UserDefined, + volatility, + parameter_names: None, + } + } + + /// An arbitrary number of lambdas or arguments of any type. + pub fn variadic_any(volatility: Volatility) -> Self { + Self { + type_signature: LambdaTypeSignature::VariadicAny, + volatility, + parameter_names: None, + } + } + + /// A specified number of arguments of any type + pub fn any(arg_count: usize, volatility: Volatility) -> Self { + Self { + type_signature: LambdaTypeSignature::Any(arg_count), + volatility, + parameter_names: None, + } } } @@ -99,10 +163,10 @@ impl Hash for dyn LambdaUDF { } } -#[derive(Clone, Debug)] -pub enum ValueOrLambdaParameter { - /// A columnar value with the given field - Value(FieldRef), +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ValueOrLambdaParameter { + /// A value with the given associated data + Value(T), /// A lambda Lambda, } @@ -566,7 +630,10 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// # Return value /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call /// arguments to these specific types. 
- fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + fn coerce_value_types( + &self, + _arg_types: &[ValueOrLambdaParameter], + ) -> Result>> { not_impl_err!("Function {} does not implement coerce_types", self.name()) } @@ -581,7 +648,7 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// Returns the parameters that any lambda supports fn lambdas_parameters( &self, - args: &[ValueOrLambdaParameter], + args: &[ValueOrLambdaParameter], ) -> Result>>> { Ok(vec![None; args.len()]) } @@ -657,7 +724,7 @@ mod tests { Arc::new(TestLambdaUDF { name, field: parameter, - signature: LambdaSignature::new(Volatility::Immutable), + signature: LambdaSignature::variadic_any(Volatility::Immutable), }) } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index b660f83b3ba2b..e0c4ab28c1fef 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -83,7 +83,7 @@ impl Default for ArrayTransform { impl ArrayTransform { pub fn new() -> Self { Self { - signature: LambdaSignature::new(Volatility::Immutable), + signature: LambdaSignature::any(2, Volatility::Immutable), aliases: vec![String::from("list_transform")], } } @@ -238,7 +238,7 @@ impl LambdaUDF for ArrayTransform { fn lambdas_parameters( &self, - args: &[ValueOrLambdaParameter], + args: &[ValueOrLambdaParameter], ) -> Result>>> { let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda] = args else { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index e0b1d9096b415..8c6c42e7e630c 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -35,7 +35,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList, - InSubquery, Like, ScalarFunction, 
Sort, WindowFunction, + InSubquery, LambdaFunction, Like, ScalarFunction, Sort, WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; @@ -51,8 +51,9 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_u use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - AggregateUDF, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection, - ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits, + AggregateUDF, Expr, ExprSchemable, Join, LambdaTypeSignature, Limit, LogicalPlan, + Operator, Projection, ScalarUDF, Union, ValueOrLambdaParameter, WindowFrame, + WindowFrameBound, WindowFrameUnits, }; /// Performs type coercion by determining the schema @@ -582,6 +583,62 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }); Ok(Transformed::yes(new_expr)) } + Expr::LambdaFunction(LambdaFunction { func, args }) => { + match func.signature().type_signature { + LambdaTypeSignature::UserDefined => { + let args_types = args + .iter() + .map(|arg| match arg { + Expr::Lambda(_) => Ok(ValueOrLambdaParameter::Lambda), + _ => Ok(ValueOrLambdaParameter::Value( + arg.get_type(self.schema)?, + )), + }) + .collect::>>()?; + + let value_types = func.coerce_value_types(&args_types)?; + + if args_types.iter().eq_by(&value_types, |a, b| match (a, b) { + (ValueOrLambdaParameter::Value(_type), None) => false, + (ValueOrLambdaParameter::Value(from), Some(to)) => from == to, + (ValueOrLambdaParameter::Lambda, None) => true, + (ValueOrLambdaParameter::Lambda, Some(_ty)) => false, + }) { + return Ok(Transformed::no(Expr::LambdaFunction( + LambdaFunction::new(func, args), + ))); + } + + let args = std::iter::zip(args, value_types) + .map(|(arg, ty)| match (&arg, ty) { + (Expr::Lambda(_), None) => Ok(arg), + (Expr::Lambda(_), Some(_ty)) => plan_err!("{} coerce_value_types returned 
Some for a lambda argument", func.name()), + (_, Some(ty)) => arg.cast_to(&ty, self.schema), + (_, None) => plan_err!("{} coerce_value_types returned None for a value argument", func.name()), + }) + .collect::>>()?; + + Ok(Transformed::yes(Expr::LambdaFunction(LambdaFunction::new( + func, args, + )))) + } + LambdaTypeSignature::VariadicAny => Ok(Transformed::no( + Expr::LambdaFunction(LambdaFunction::new(func, args)), + )), + LambdaTypeSignature::Any(number) => { + if args.len() != number { + return plan_err!( + "The function '{}' expected {number} arguments but received {}", + func.name(), args.len() + ); + } + + Ok(Transformed::no(Expr::LambdaFunction(LambdaFunction::new( + func, args, + )))) + } + } + } // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::Alias(_) @@ -598,7 +655,6 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { | Expr::GroupingSet(_) | Expr::Placeholder(_) | Expr::OuterReferenceColumn(_, _) - | Expr::LambdaFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => Ok(Transformed::no(expr)), } diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs index 5ba354f4e8a4f..97af1f9b13891 100644 --- a/datafusion/physical-expr/src/lambda_function.rs +++ b/datafusion/physical-expr/src/lambda_function.rs @@ -489,12 +489,12 @@ mod tests { fn test_lambda_function_volatile_node() { // Create a volatile UDF let volatile_udf = Arc::new(MockLambdaUDF { - signature: LambdaSignature::new(Volatility::Volatile), + signature: LambdaSignature::variadic_any(Volatility::Volatile), }); // Create a non-volatile UDF let stable_udf = Arc::new(MockLambdaUDF { - signature: LambdaSignature::new(Volatility::Stable), + signature: LambdaSignature::variadic_any(Volatility::Stable), }); let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 732831b9464c7..e9060c0f2c986 
100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -229,7 +229,7 @@ impl Serializeable for Expr { Ok(Arc::new(MockLambdaUDF { name: name.to_string(), - signature: LambdaSignature::new(Volatility::Immutable), + signature: LambdaSignature::variadic_any(Volatility::Immutable), })) } } From 82930ec92df0ec3a129a2cfcd34c897019e38a7f Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 8 Mar 2026 23:58:54 -0300 Subject: [PATCH 11/47] lambda function type coercion: stop using unstable Iterator::eq_by --- datafusion/optimizer/src/analyzer/type_coercion.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 8c6c42e7e630c..e97d79e98e977 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -598,12 +598,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { let value_types = func.coerce_value_types(&args_types)?; - if args_types.iter().eq_by(&value_types, |a, b| match (a, b) { - (ValueOrLambdaParameter::Value(_type), None) => false, - (ValueOrLambdaParameter::Value(from), Some(to)) => from == to, - (ValueOrLambdaParameter::Lambda, None) => true, - (ValueOrLambdaParameter::Lambda, Some(_ty)) => false, - }) { + if args_types + .iter() + .map(|a| match a { + ValueOrLambdaParameter::Value(ty) => Some(ty), + ValueOrLambdaParameter::Lambda => None, + }) + .eq(value_types.iter().map(|v| v.as_ref())) + { return Ok(Transformed::no(Expr::LambdaFunction( LambdaFunction::new(func, args), ))); From 86d5999056b49b4b597c1103ee9f57f65cacc882 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 9 Mar 2026 01:30:48 -0300 Subject: [PATCH 12/47] remove signature section from DOC.md --- DOC.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/DOC.md b/DOC.md index 
10a0ab4e19407..a88e69a5689e3 100644 --- a/DOC.md +++ b/DOC.md @@ -1160,7 +1160,3 @@ impl SqlToRel {
- -`LambdaFunction` `Signature` is non functional - -Currenty, `LambdaUDF::signature` returns the same `Signature` as `ScalarUDF`, but it's `type_signature` field is never used, as most variants of the `TypeSignature` enum aren't applicable to a lambda, and no type coercion is applied on it's arguments, being currently a implementation responsability. We should either add lambda compatible variants to the `TypeSignature` enum, create a new `LambdaTypeSignature` and `LambdaSignature`, or support no automatic type coercion at all on lambda functions. From 60cabc02087bcd2a04d78abba8da176fa44637f4 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sat, 14 Mar 2026 22:02:04 -0300 Subject: [PATCH 13/47] polish lambda impl --- Cargo.lock | 1 - datafusion-examples/examples/sql_frontend.rs | 4 + datafusion/common/src/lib.rs | 1 + datafusion/common/src/utils/mod.rs | 378 +++++++++++++++--- .../core/src/execution/session_state.rs | 36 +- datafusion/core/tests/optimizer/mod.rs | 4 + datafusion/core/tests/parquet/mod.rs | 2 +- datafusion/expr/src/expr.rs | 362 ++++++++++++++++- datafusion/expr/src/expr_fn.rs | 19 +- datafusion/expr/src/expr_schema.rs | 42 +- datafusion/expr/src/lib.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 11 +- datafusion/expr/src/planner.rs | 3 + .../expr/src/type_coercion/functions.rs | 69 +++- datafusion/expr/src/udlf.rs | 219 +++++++--- datafusion/functions-nested/Cargo.toml | 1 - .../functions-nested/src/array_transform.rs | 275 ++++++------- datafusion/functions-nested/src/lib.rs | 39 +- .../functions-nested/src/macros_lambda.rs | 113 ++++++ .../functions/src/core/union_extract.rs | 1 - .../optimizer/src/analyzer/type_coercion.rs | 98 ++--- .../simplify_expressions/expr_simplifier.rs | 12 +- .../optimizer/tests/optimizer_integration.rs | 4 + .../src/schema_rewriter.rs | 3 +- .../physical-expr/src/expressions/column.rs | 1 - .../physical-expr/src/expressions/lambda.rs | 150 +++++-- 
.../src/expressions/lambda_variable.rs | 43 +- .../physical-expr/src/expressions/mod.rs | 2 +- .../physical-expr/src/lambda_function.rs | 49 ++- datafusion/physical-expr/src/planner.rs | 21 +- .../physical-expr/src/scalar_function.rs | 14 +- datafusion/proto/src/bytes/mod.rs | 12 +- datafusion/proto/src/bytes/registry.rs | 10 +- datafusion/proto/src/logical_plan/to_proto.rs | 13 +- datafusion/spark/src/function/utils.rs | 2 +- datafusion/sql/examples/sql.rs | 4 + datafusion/sql/src/expr/function.rs | 82 +++- datafusion/sql/src/expr/identifier.rs | 9 +- datafusion/sql/src/expr/mod.rs | 4 + datafusion/sql/src/unparser/dialect.rs | 12 + datafusion/sql/src/unparser/expr.rs | 8 +- datafusion/sql/tests/common/mod.rs | 4 + datafusion/sqllogictest/src/test_context.rs | 36 +- datafusion/sqllogictest/test_files/lambda.slt | 76 +++- 44 files changed, 1693 insertions(+), 560 deletions(-) create mode 100644 datafusion/functions-nested/src/macros_lambda.rs diff --git a/Cargo.lock b/Cargo.lock index 8377a263cd0cf..f500265108ff5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2386,7 +2386,6 @@ dependencies = [ "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", "datafusion-macros", - "datafusion-physical-expr", "datafusion-physical-expr-common", "itertools 0.14.0", "log", diff --git a/datafusion-examples/examples/sql_frontend.rs b/datafusion-examples/examples/sql_frontend.rs index e6341080a9e11..66cf686711e2a 100644 --- a/datafusion-examples/examples/sql_frontend.rs +++ b/datafusion-examples/examples/sql_frontend.rs @@ -177,6 +177,10 @@ impl ContextProvider for MyContextProvider { Vec::new() } + fn udlf_names(&self) -> Vec { + Vec::new() + } + fn udaf_names(&self) -> Vec { Vec::new() } diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 76c7b46e32737..09fefb4eaa749 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -112,6 +112,7 @@ pub type HashMap = hashbrown::HashMap; pub type HashSet = 
hashbrown::HashSet; pub mod hash_map { pub use hashbrown::hash_map::Entry; + pub use hashbrown::hash_map::EntryRef; } pub mod hash_set { pub use hashbrown::hash_set::Entry; diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 3fd0683659caf..1c2ec91568a3d 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -30,7 +30,7 @@ use arrow::array::{ cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, }; -use arrow::array::{ArrowPrimitiveType, PrimitiveArray}; +use arrow::array::{ArrowPrimitiveType, GenericListArray, Int32Array, PrimitiveArray}; use arrow::buffer::OffsetBuffer; use arrow::compute::{partition, SortColumn, SortOptions}; use arrow::datatypes::{ @@ -946,115 +946,166 @@ pub fn take_function_args( } /// [0, 2, 2, 5, 6] -> [0, 0, 2, 2, 2, 3] -pub fn make_list_array_indices( +fn list_array_values_row_number( offsets: &OffsetBuffer, ) -> PrimitiveArray { - let mut indices = Vec::with_capacity( - offsets.last().unwrap().as_usize() - offsets.first().unwrap().as_usize(), + let mut rows_number = Vec::with_capacity( + offsets[offsets.len() - 1].to_usize().unwrap() - offsets[0].to_usize().unwrap(), ); for (i, (&start, &end)) in std::iter::zip(&offsets[..], &offsets[1..]).enumerate() { - indices.extend(repeat_n( + rows_number.extend(repeat_n( T::Native::usize_as(i), end.as_usize() - start.as_usize(), )); } - PrimitiveArray::new(indices.into(), None) + PrimitiveArray::new(rows_number.into(), None) } -/// [0, 2, 2, 5, 6] -> [0, 1, 0, 1, 2, 0] -pub fn make_list_element_indices( +/// [0, 2, 2, 5, 6] -> [1, 2, 1, 2, 3, 1] +fn list_array_values_index( offsets: &OffsetBuffer, ) -> PrimitiveArray { - let mut indices = - Vec::with_capacity(offsets.last().unwrap().as_usize() - offsets[0].as_usize()); + let mut indices = Vec::with_capacity( + offsets[offsets.len() - 1].to_usize().unwrap() - offsets[0].to_usize().unwrap(), + ); for (&start, &end) in 
std::iter::zip(&offsets[..], &offsets[1..]) { - indices.extend( - (0..end.as_usize() - start.as_usize()).map(|i| T::Native::usize_as(i)), - ); + indices + .extend((1..1 + end.as_usize() - start.as_usize()).map(T::Native::usize_as)); } PrimitiveArray::new(indices.into(), None) } -/// (3, 2) -> [0, 0, 1, 1, 2, 2] -pub fn make_fsl_array_indices( - list_size: i32, - array_len: usize, -) -> PrimitiveArray { - let mut indices = Vec::with_capacity(list_size as usize * array_len); +/// (2, 3) -> [0, 0, 1, 1, 2, 2] +fn fsl_values_row_number(list_size: i32, array_len: usize) -> Result { + let list_size = list_size.to_usize().ok_or_else(|| { + _exec_datafusion_err!("fsl_values_index: invalid list_size {list_size}") + })?; + + let mut rows_number = Vec::with_capacity(list_size * array_len); for i in 0..array_len { - indices.extend(repeat_n(i as i32, list_size as usize)); + rows_number.extend(repeat_n(i as i32, list_size)); } - PrimitiveArray::new(indices.into(), None) + Ok(PrimitiveArray::new(rows_number.into(), None)) } -/// (3, 2) -> [0, 1, 0, 1, 0, 1] -pub fn make_fsl_element_indices( - list_size: i32, - array_len: usize, -) -> PrimitiveArray { - let mut indices = Vec::with_capacity(list_size as usize * array_len); +/// (2, 3) -> [1, 2, 1, 2, 1, 2] +fn fsl_values_index(list_size: i32, array_len: usize) -> Result { + let list_size = list_size.to_usize().ok_or_else(|| { + _exec_datafusion_err!("fsl_values_index: invalid list_size {list_size}") + })?; - if array_len > 0 { - indices.extend((0..list_size as usize).map(|j| j as i32)); + let mut indices = Vec::with_capacity(list_size * array_len); - for _ in 1..array_len { - indices.extend_from_within(0..list_size as usize); - } + for _ in 0..array_len { + indices.extend((1..1 + list_size).map(|j| j as i32)); } - PrimitiveArray::new(indices.into(), None) + Ok(PrimitiveArray::new(indices.into(), None)) } -pub fn list_values(array: &dyn Array) -> Result<&ArrayRef> { +/// Returns the inner values of a list, or an error otherwise +/// 
For [`ListArray`] and [`LargeListArray`], if it's sliced, it returns a +/// sliced array too. Therefore, to reconstruct a list using it, +/// you must adjust the offsets using [`adjust_offsets_for_slice`] +pub fn list_values(array: &dyn Array) -> Result { match array.data_type() { - DataType::List(_) => Ok(array.as_list::().values()), - DataType::LargeList(_) => Ok(array.as_list::().values()), - DataType::FixedSizeList(_, _) => Ok(array.as_fixed_size_list().values()), + DataType::List(_) => Ok(sliced_list_values(array.as_list::())), + DataType::LargeList(_) => Ok(sliced_list_values(array.as_list::())), + DataType::FixedSizeList(_, _) => { + Ok(Arc::clone(array.as_fixed_size_list().values())) + } other => _exec_err!("expected list, got {other}"), } } -pub fn list_indices(array: &dyn Array) -> Result { +fn sliced_list_values(list: &GenericListArray) -> ArrayRef { + let values = list.values(); + let offsets = list.offsets(); + + if let (Some(first), Some(last)) = (offsets.first(), offsets.last()) { + let first = first.to_usize().unwrap(); + let last = last.to_usize().unwrap(); + + if first != 0 || last != values.len() { + return values.slice(first, last - first); + } + } + + Arc::clone(values) +} + +/// If `list` is sliced, returns an adjusted offset buffer so that +/// it points to the sliced portion of the list values, and not the whole list values +pub fn adjust_offsets_for_slice( + list: &GenericListArray, +) -> OffsetBuffer { + let offsets = list.offsets(); + + if let (Some(first), Some(last)) = (offsets.first(), offsets.last()) { + if !first.is_zero() || last.to_usize().unwrap() != list.values().len() { + let offsets = offsets.iter().map(|offset| *offset - *first).collect(); + + //todo: use unsafe Offset::new_unchecked? + return OffsetBuffer::new(offsets); + } + } + + offsets.clone() +} + +/// If `array` is a contiguous list, returns a new array of the same length as its inner values +/// where each value is the 1-based index of the sublist it's contained in.
Example: +/// +/// `[[1], [2, 3], [4, 5, 6]] => [1, 2, 2, 3, 3, 3]` +/// +/// If it's not a contiguous list, return an error +pub fn list_values_row_number(array: &dyn Array) -> Result { match array.data_type() { - DataType::List(_) => Ok(Arc::new(make_list_array_indices::( - array.as_list().offsets(), - ))), - DataType::LargeList(_) => Ok(Arc::new(make_list_array_indices::( + DataType::List(_) => Ok(Arc::new(list_array_values_row_number::( array.as_list().offsets(), ))), + DataType::LargeList(_) => Ok(Arc::new( + list_array_values_row_number::(array.as_list().offsets()), + )), DataType::FixedSizeList(_, _) => { let fixed_size_list = array.as_fixed_size_list(); - Ok(Arc::new(make_fsl_array_indices( + Ok(Arc::new(fsl_values_row_number( fixed_size_list.value_length(), fixed_size_list.len(), - ))) + )?)) } other => _exec_err!("expected list, got {other}"), } } -pub fn elements_indices(array: &dyn Array) -> Result { +/// If `array` is a contiguous list, returns a new array of the same length as its inner values +/// where each value is the 1-based index within the sublist it's contained in.
Example: +/// +/// `[[1], [2, 3], [4, 5, 6]] => [1, 1, 2, 1, 2, 3]` +/// +/// If it's not a contiguos list, return an error +pub fn list_values_index(array: &dyn Array) -> Result { match array.data_type() { - DataType::List(_) => Ok(Arc::new(make_list_element_indices::( + DataType::List(_) => Ok(Arc::new(list_array_values_index::( array.as_list::().offsets(), ))), - DataType::LargeList(_) => Ok(Arc::new(make_list_element_indices::( + DataType::LargeList(_) => Ok(Arc::new(list_array_values_index::( array.as_list::().offsets(), ))), DataType::FixedSizeList(_, _) => { let fixed_size_list = array.as_fixed_size_list(); - Ok(Arc::new(make_fsl_element_indices( + Ok(Arc::new(fsl_values_index( fixed_size_list.value_length(), fixed_size_list.len(), - ))) + )?)) } other => _exec_err!("expected list, got {other}"), } @@ -1064,7 +1115,7 @@ pub fn elements_indices(array: &dyn Array) -> Result { mod tests { use super::*; use crate::ScalarValue::Null; - use arrow::array::Float64Array; + use arrow::array::{Float64Array, Int32Array}; use sqlparser::ast::Ident; use sqlparser::tokenizer::Span; @@ -1366,4 +1417,231 @@ mod tests { assert_eq!(expected, transposed); Ok(()) } + + #[test] + fn test_list_array_values_row_number() { + assert_eq!( + list_array_values_row_number::(&OffsetBuffer::from_lengths([ + 1, 3, 0, 2, + ])), + Int32Array::from(vec![0, 1, 1, 1, 3, 3]) + ); + + assert_eq!( + list_array_values_row_number::(&OffsetBuffer::from_lengths([])), + Int32Array::new_null(0) + ); + + assert_eq!( + list_array_values_row_number::(&OffsetBuffer::from_lengths([0])), + Int32Array::new_null(0) + ); + + assert_eq!( + list_array_values_row_number::(&OffsetBuffer::from_lengths([ + 0, 0 + ])), + Int32Array::new_null(0) + ); + + assert_eq!( + list_array_values_row_number::(&OffsetBuffer::from_lengths([1])), + Int32Array::from(vec![0]) + ); + + assert_eq!( + list_array_values_row_number::(&OffsetBuffer::from_lengths([2])), + Int32Array::from(vec![0, 0]) + ); + } + + #[test] + fn 
test_list_array_values_index() { + assert_eq!( + list_array_values_index::(&OffsetBuffer::from_lengths([ + 1, 3, 0, 2, + ])), + Int32Array::from(vec![1, 1, 2, 3, 1, 2]) + ); + + assert_eq!( + list_array_values_index::(&OffsetBuffer::from_lengths([])), + Int32Array::new_null(0) + ); + + assert_eq!( + list_array_values_index::(&OffsetBuffer::from_lengths([0])), + Int32Array::new_null(0) + ); + + assert_eq!( + list_array_values_index::(&OffsetBuffer::from_lengths([0, 0])), + Int32Array::new_null(0) + ); + + assert_eq!( + list_array_values_index::(&OffsetBuffer::from_lengths([1])), + Int32Array::from(vec![1]) + ); + + assert_eq!( + list_array_values_index::(&OffsetBuffer::from_lengths([2])), + Int32Array::from(vec![1, 2]) + ); + } + + #[test] + fn test_fsl_values_row_number() { + assert_eq!( + fsl_values_row_number(2, 3).unwrap(), + Int32Array::from(vec![0, 0, 1, 1, 2, 2]) + ); + + assert_eq!( + fsl_values_row_number(1, 3).unwrap(), + Int32Array::from(vec![0, 1, 2]) + ); + + assert_eq!( + fsl_values_row_number(2, 1).unwrap(), + Int32Array::from(vec![0, 0]) + ); + + assert_eq!( + fsl_values_row_number(2, 0).unwrap(), + Int32Array::new_null(0), + ); + + assert_eq!( + fsl_values_row_number(0, 2).unwrap(), + Int32Array::new_null(0), + ); + + assert_eq!( + fsl_values_row_number(0, 0).unwrap(), + Int32Array::new_null(0), + ); + + fsl_values_row_number(-1, 2).unwrap_err(); + fsl_values_row_number(-1, 0).unwrap_err(); + } + + #[test] + fn test_fsl_values_index() { + assert_eq!( + fsl_values_index(2, 3).unwrap(), + Int32Array::from(vec![1, 2, 1, 2, 1, 2]) + ); + + assert_eq!( + fsl_values_index(1, 3).unwrap(), + Int32Array::from(vec![1, 1, 1]) + ); + + assert_eq!( + fsl_values_index(2, 1).unwrap(), + Int32Array::from(vec![1, 2]) + ); + + assert_eq!(fsl_values_index(2, 0).unwrap(), Int32Array::new_null(0)); + assert_eq!(fsl_values_index(0, 2).unwrap(), Int32Array::new_null(0)); + assert_eq!(fsl_values_index(0, 0).unwrap(), Int32Array::new_null(0)); + + fsl_values_index(-1, 
2).unwrap_err(); + fsl_values_index(-1, 0).unwrap_err(); + } + + fn list() -> ListArray { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ]; + ListArray::from_iter_primitive::(data) + } + + #[test] + fn test_sliced_list_values() { + let list = list(); + + assert_eq!( + sliced_list_values(&list).as_primitive(), + &Int32Array::from(vec![ + Some(0), + Some(1), + Some(2), + Some(3), + None, + Some(5), + Some(6), + Some(7) + ]) + ); + + assert_eq!( + sliced_list_values(&list.slice(0, 1)).as_primitive(), + &Int32Array::from(vec![Some(0), Some(1), Some(2)]) + ); + + assert_eq!( + sliced_list_values(&list.slice(2, 1)).as_primitive(), + &Int32Array::from(vec![Some(3), None, Some(5)]) + ); + + assert_eq!( + sliced_list_values(&list.slice(3, 1)).as_primitive(), + &Int32Array::from(vec![Some(6), Some(7)]) + ); + + assert!(sliced_list_values(&list.slice(0, 0)).is_empty()); + assert!(sliced_list_values(&list.slice(1, 0)).is_empty()); + assert!(sliced_list_values(&list.slice(3, 0)).is_empty()); + } + + #[test] + fn test_adjust_offsets() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ]; + let list = ListArray::from_iter_primitive::(data); + + assert_eq!( + adjust_offsets_for_slice(&list), + OffsetBuffer::from_lengths([3, 0, 3, 2]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(0, 1)), + OffsetBuffer::from_lengths([3]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(1, 2)), + OffsetBuffer::from_lengths([0, 3]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(1, 3)), + OffsetBuffer::from_lengths([0, 3, 2]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(0, 0)), + OffsetBuffer::from_lengths([]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(1, 0)), + OffsetBuffer::from_lengths([]) + ); + + assert_eq!( + 
adjust_offsets_for_slice(&list.slice(3, 0)), + OffsetBuffer::from_lengths([]) + ); + } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index f33f3d3412f4d..805b61b4c6a30 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -219,6 +219,7 @@ impl Debug for SessionState { .field("physical_optimizers", &self.physical_optimizers) .field("table_functions", &self.table_functions) .field("scalar_functions", &self.scalar_functions) + .field("lambda_functions", &self.lambda_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) .field("prepared_plans", &self.prepared_plans) @@ -833,6 +834,11 @@ impl SessionState { pub fn scalar_functions(&self) -> &HashMap> { &self.scalar_functions } + + /// Return reference to lambda_functions + pub fn lambda_functions(&self) -> &HashMap> { + &self.lambda_functions + } /// Return reference to aggregate_functions pub fn aggregate_functions(&self) -> &HashMap> { @@ -1231,6 +1237,15 @@ impl SessionStateBuilder { self.scalar_functions = Some(scalar_functions); self } + + /// Set the map of [`LambdaUDF`]s + pub fn with_lambda_functions( + mut self, + lambda_functions: Vec>, + ) -> Self { + self.lambda_functions = Some(lambda_functions); + self + } /// Set the map of [`AggregateUDF`]s pub fn with_aggregate_functions( @@ -1601,6 +1616,11 @@ impl SessionStateBuilder { pub fn scalar_functions(&mut self) -> &mut Option>> { &mut self.scalar_functions } + + /// Returns the current scalar_functions value + pub fn lambda_functions(&mut self) -> &mut Option>> { + &mut self.lambda_functions + } /// Returns the current aggregate_functions value pub fn aggregate_functions(&mut self) -> &mut Option>> { @@ -1838,6 +1858,10 @@ impl ContextProvider for SessionContextProvider<'_> { fn udf_names(&self) -> Vec { self.state.scalar_functions().keys().cloned().collect() } + + fn 
udlf_names(&self) -> Vec { + self.state.lambda_functions().keys().cloned().collect() + } fn udaf_names(&self) -> Vec { self.state.aggregate_functions().keys().cloned().collect() @@ -1981,6 +2005,10 @@ impl FunctionRegistry for SessionState { &mut self, udlf: Arc, ) -> datafusion_common::Result>> { + udlf.aliases().iter().for_each(|alias| { + self.lambda_functions + .insert(alias.clone(), Arc::clone(&udlf)); + }); Ok(self.lambda_functions.insert(udlf.name().into(), udlf)) } @@ -2139,10 +2167,10 @@ mod tests { use datafusion_common::Result; use datafusion_execution::config::SessionConfig; use datafusion_expr::Expr; + use datafusion_expr::LambdaUDF; use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_optimizer::Optimizer; use datafusion_physical_plan::display::DisplayableExecutionPlan; - use datafusion_session::Session; use datafusion_sql::planner::{PlannerContext, SqlToRel}; use std::collections::HashMap; use std::sync::Arc; @@ -2419,7 +2447,7 @@ mod tests { self.state.scalar_functions().get(name).cloned() } - fn get_lambda_meta(&self, name: &str) -> Option> { + fn get_lambda_meta(&self, name: &str) -> Option> { self.state.lambda_functions().get(name).cloned() } @@ -2442,6 +2470,10 @@ mod tests { fn udf_names(&self) -> Vec { self.state.scalar_functions().keys().cloned().collect() } + + fn udlf_names(&self) -> Vec { + self.state.lambda_functions().keys().cloned().collect() + } fn udaf_names(&self) -> Vec { self.state.aggregate_functions().keys().cloned().collect() diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 44e40143fe171..389f72d96eaa9 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -240,6 +240,10 @@ impl ContextProvider for MyContextProvider { Vec::new() } + fn udlf_names(&self) -> Vec { + Vec::new() + } + fn udaf_names(&self) -> Vec { Vec::new() } diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 
097600e45eadd..27b8c18596476 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -516,7 +516,7 @@ fn make_uint_batches(start: u8, end: u8) -> RecordBatch { Field::new("u64", DataType::UInt64, true), ])); let v8: Vec = (start..end).collect(); - let v16: Vec = (start as _..end as _).collect(); + let v16: Vec = (start as _..end as u16).collect(); let v32: Vec = (start as _..end as _).collect(); let v64: Vec = (start as _..end as _).collect(); RecordBatch::try_new( diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ef4ba4ef5cdfd..f73e6080d9dd0 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -27,17 +27,20 @@ use std::sync::Arc; use crate::expr_fn::binary_expr; use crate::function::WindowFunctionSimplification; use crate::logical_plan::Subquery; +use crate::type_coercion::functions::value_fields_with_lambda_udf; use crate::udlf::LambdaUDF; -use crate::{AggregateUDF, Volatility}; +use crate::{AggregateUDF, ValueOrLambda, Volatility}; use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; +use datafusion_common::hash_map::EntryRef; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; use datafusion_common::{ - Column, DFSchema, HashMap, Result, ScalarValue, Spans, TableReference, + plan_datafusion_err, plan_err, Column, DFSchema, ExprSchema, HashMap, Result, + ScalarValue, Spans, TableReference, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; #[cfg(feature = "sql")] @@ -399,12 +402,15 @@ pub enum Expr { OuterReferenceColumn(FieldRef, Column), /// Unnest expression Unnest(Unnest), + /// Call a lambda function with a set of arguments. 
LambdaFunction(LambdaFunction), - /// Lambda expression + /// A Lambda expression with a set of parameters names and a body Lambda(Lambda), + /// A named reference to a lambda parameter LambdaVariable(LambdaVariable), } +/// Invoke a [`LambdaUDF`] with a set of arguments #[derive(Clone, Eq, PartialOrd, Debug)] pub struct LambdaFunction { pub func: Arc, @@ -412,6 +418,7 @@ pub struct LambdaFunction { } impl LambdaFunction { + /// Create a new `LambdaFunction` from a [`LambdaUDF`] pub fn new(func: Arc, args: Vec) -> Self { Self { func, args } } @@ -419,6 +426,26 @@ impl LambdaFunction { pub fn name(&self) -> &str { self.func.name() } + + /// Invokes the inner function [`LambdaUDF::lambdas_parameters`] + /// using the arguments of this invocation + pub fn lambdas_parameters( + &self, + schema: &dyn ExprSchema, + ) -> Result>>> { + let args = self + .args + .iter() + .map(|e| match e { + Expr::Lambda(_lambda) => Ok(ValueOrLambda::Lambda(())), + _ => Ok(ValueOrLambda::Value(e.to_field(schema)?.1)), + }) + .collect::>>()?; + + let coerced = value_fields_with_lambda_udf(&args, self.func.as_ref())?; + + self.func.lambdas_parameters(&coerced) + } } impl Hash for LambdaFunction { @@ -434,15 +461,38 @@ impl PartialEq for LambdaFunction { } } +/// A named reference to a lambda parameter which includes it's own [`FieldRef`], +/// which is used to implement [`ExprSchemable`], for example. It is an option only to make +/// easier for `expr_api` users to construct lambda variables, but any expression +/// tree or [`LogicalPlan`] containing unresolved variables must be resolved before +/// usage with either [`Expr::resolve_lambdas_variables`] or +/// [`LogicalPlan::resolve_lambdas_variables`]. The default SQL planner produces +/// already resolved variables and no further resolving is required. 
+/// +/// After resolving, if any non-lambda argument from the lambda function +/// which this variables originates from have it's type, nullability or +/// metadata changed, the resolved field may became outdated and must be +/// resolved again. +/// +/// [`LogicalPlan`]: crate::LogicalPlan +/// [`LogicalPlan::resolve_lambdas_variables`]: LogicalPlan::resolve_lambdas_variables +/// +// todo: if substrait come to produce resolved variables, cite it above too #[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] pub struct LambdaVariable { pub name: String, - pub field: FieldRef, + pub field: Option, pub spans: Spans, } impl LambdaVariable { - pub fn new(name: String, field: FieldRef) -> Self { + /// Create a lambda variable from a name and an optional Field. + /// If the field is none, the expression tree or LogicalPlan which + /// owns this variable must be resolved before usage with either + /// [`Expr::resolve_lambdas_variables`] or [`LogicalPlan::resolve_lambdas_variables`]. + /// + /// [`LogicalPlan::resolve_lambdas_variables`]: crate::LogicalPlan::resolve_lambdas_variables + pub fn new(name: String, field: Option) -> Self { Self { name, field, @@ -1266,7 +1316,7 @@ impl GroupingSet { } } -/// Lambda expression. 
+/// A Lambda expression with a set of parameters names and a body #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Lambda { pub params: Vec, @@ -2178,6 +2228,169 @@ impl Expr { None } } + + /// Return a `Expr` with all [`LambdaVariable`] resolved only if all of them + /// are contained in the subtree of the [`LambdaFunction`] it originates from, + /// otherwise returns an error + pub fn resolve_lambdas_variables( + self, + schema: &DFSchema, + ) -> Result> { + resolve_lambdas_variables(self, schema, &mut HashMap::new()) + } +} + +fn resolve_lambdas_variables( + expr: Expr, + schema: &DFSchema, + vars: &mut HashMap>, +) -> Result> { + expr.transform_down(|expr| match expr { + Expr::LambdaFunction(LambdaFunction { func, args }) => { + let args = if !vars.is_empty() { + /* if this is a nested lambda, we must resolve non-lambda args before invoking + lambdas_parameters because it will invoke ExprSchemable::to_field for every + non-lambda parameter, and if one them contains a lambda variable, it will fail + due to it being unresolved. Example query: + + array_transform([[1, 2]], a -> array_transform(a, b -> b+1)) + + the nested array_transform's lambdas_parameters will call Lambdavariable::to_field + on it's first argument, the variable `a`, which must be resolved + */ + args.map_elements(|arg| match arg { + Expr::Lambda(_) => Ok(Transformed::no(arg)), + _ => resolve_lambdas_variables(arg, schema, vars), + })? 
+ } else { + Transformed::no(args) + }; + + let transformed = args.transformed; + let func = LambdaFunction::new(func, args.data); + + let mut lambdas_params = func.lambdas_parameters(schema)?.into_iter(); + + let num_args = func.args.len(); + let num_lambdas_params = lambdas_params.len(); + + let args = func.args.map_elements(|arg| { + let lambda_params = lambdas_params.next().ok_or_else(|| { + plan_datafusion_err!( + "{} lambdas_parameters returned {num_lambdas_params} values for {num_args} args", + func.func.name() + ) + })?; + + match (arg, lambda_params) { + (Expr::Lambda(mut lambda), Some(lambda_params)) => { + if lambda.params.len() > lambda_params.len() { + return plan_err!( + "{} lambda defined {} params ({}), but only {} supported", + func.func.name(), + lambda.params.len(), + display_comma_separated(&lambda.params), + lambda_params.len() + ); + } + + if !all_unique(&lambda.params) { + return plan_err!( + "lambda params must be unique, got ({})", + lambda.params.join(", ") + ); + } + + for (param, field) in + std::iter::zip(&lambda.params, lambda_params) + { + vars.entry_ref(param) + .or_default() + .push(Arc::new(field)); + } + + let transformed = resolve_lambdas_variables(mem::take(lambda.body.as_mut()), schema, vars)?; + + *lambda.body = transformed.data; + + for param in &lambda.params { + match vars.entry_ref(param) { + EntryRef::Occupied(mut v) => { + if v.get().len() == 1 { + v.remove(); + } else { + v.get_mut() + .pop() + .expect("every entry should have at least one field"); + } + }, + EntryRef::Vacant(_v) => { + unreachable!("the loop above should have inserted a value for every param") + }, + } + } + + Ok(Transformed::new(Expr::Lambda(lambda), transformed.transformed, TreeNodeRecursion::Jump)) + } + (Expr::Lambda(_), None) => { + plan_err!( + "{} lambdas_parameters retured None for a lambda argument", + func.func.name() + ) + } + (_, Some(_)) => { + plan_err!( + "{} lambdas_parameters retured Some for a non-lambda argument", + func.func.name() 
+ ) + } + (arg, None) => Ok(Transformed::no(arg)) // resolved above + } + })?; + + Ok(Transformed::new( + Expr::LambdaFunction(LambdaFunction::new(func.func, args.data)), + transformed || args.transformed, + TreeNodeRecursion::Jump, + )) + } + Expr::LambdaVariable(mut var) => { + let fields_chain = vars.get(&var.name).ok_or_else(|| { + plan_datafusion_err!( + "missing field of lambda variable {} while resolving", + var.name + ) + })?; + + let field = fields_chain + .last() + .expect("every entry should have at least one field"); + + let transformed = var.field.as_ref().is_none_or(|old| old != field); + + if transformed { + var.field = Some(Arc::clone(field)); + } + + Ok(Transformed::new_transformed( + Expr::LambdaVariable(var), + transformed, + )) + } + _ => Ok(Transformed::no(expr)), + }) +} + +fn all_unique(params: &[String]) -> bool { + match params.len() { + 0 | 1 => true, + 2 => params[0] != params[1], + _ => { + let mut set = HashSet::with_capacity(params.len()); + + params.iter().all(|p| set.insert(p.as_str())) + } + } } impl Normalizeable for Expr { @@ -3090,7 +3303,12 @@ impl Display for SchemaDisplay<'_> { } } Expr::Lambda(Lambda { params, body }) => { - write!(f, "({}) -> {body}", display_comma_separated(params)) + write!( + f, + "({}) -> {}", + display_comma_separated(params), + SchemaDisplay(body) + ) } Expr::LambdaVariable(c) => { write!(f, "{}", c.name) @@ -3631,13 +3849,15 @@ pub fn physical_name(expr: &Expr) -> Result { mod test { use crate::expr_fn::col; use crate::{ - case, lit, placeholder, qualified_wildcard, wildcard, wildcard_with_options, - ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility, + case, lambda, lambda_var, lit, placeholder, qualified_wildcard, wildcard, + wildcard_with_options, ColumnarValue, LambdaSignature, ScalarFunctionArgs, + ScalarUDF, ScalarUDFImpl, Volatility, }; use arrow::datatypes::{Field, Schema}; use sqlparser::ast; use sqlparser::ast::{Ident, IdentWithAlias}; use std::any::Any; + use 
std::sync::Arc; #[test] fn infer_placeholder_in_clause() { @@ -4118,4 +4338,128 @@ mod test { } } } + + #[test] + fn test_resolve_lambda_variables() { + let schema = DFSchema::try_from(Schema::new(vec![Field::new( + "c", + DataType::new_list(DataType::new_list(DataType::Int32, true), true), + true, + )])) + .unwrap(); + + #[derive(Debug, Hash, PartialEq, Eq)] + struct MockLambdaUDF { + signature: LambdaSignature, + } + + impl LambdaUDF for MockLambdaUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_transform" + } + + fn signature(&self) -> &LambdaSignature { + &self.signature + } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambda], + ) -> Result>>> { + let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(_lambda)) = + (&args[0], &args[1]) + else { + unreachable!() + }; + + let (field, index_type) = match list.data_type() { + DataType::List(field) => (field, DataType::Int32), + _ => unreachable!(), + }; + + let value = + Field::new("", field.data_type().clone(), field.is_nullable()) + .with_metadata(field.metadata().clone()); + let index = Field::new("", index_type, false); + + Ok(vec![None, Some(vec![value, index])]) + } + + fn return_field_from_args( + &self, + _args: crate::LambdaReturnFieldArgs, + ) -> Result { + unimplemented!() + } + + fn invoke_with_args( + &self, + _args: crate::LambdaFunctionArgs, + ) -> Result { + unimplemented!() + } + } + + let func = Arc::new(MockLambdaUDF { + signature: LambdaSignature::variadic_any(Volatility::Immutable), + }) as _; + + // array_transform(c, v -> array_transform(v, (v, i) -> v+i)) + let expr = Expr::LambdaFunction(LambdaFunction::new( + Arc::clone(&func), + vec![ + col("c"), + lambda( + ["v"], + Expr::LambdaFunction(LambdaFunction::new( + Arc::clone(&func), + vec![ + lambda_var("v"), + lambda(["v", "i"], lambda_var("v") + lambda_var("i")), + ], + )), + ), + ], + )); + + let resolved_expr = expr.resolve_lambdas_variables(&schema).unwrap().data; + + let expected = 
Expr::LambdaFunction(LambdaFunction::new( + Arc::clone(&func), + vec![ + col("c"), + lambda( + ["v"], + Expr::LambdaFunction(LambdaFunction::new( + func, + vec![ + resolved_lambda_var( + "v", + DataType::new_list(DataType::Int32, true), + true, + ), + lambda( + ["v", "i"], + resolved_lambda_var("v", DataType::Int32, true) + + resolved_lambda_var("i", DataType::Int32, false), + ), + ], + )), + ), + ], + )); + + assert_eq!(resolved_expr, expected); + } + + fn resolved_lambda_var(name: &str, dt: DataType, nullable: bool) -> Expr { + Expr::LambdaVariable(LambdaVariable::new( + name.into(), + Some(Arc::new(Field::new("", dt, nullable))), + )) + } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c777c4978f99a..ec8331ecbed73 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -18,8 +18,7 @@ //! Functions for creating logical expressions use crate::expr::{ - AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, - NullTreatment, Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, + AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, Lambda, LambdaVariable, NullTreatment, Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, @@ -727,6 +726,22 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), None) } +/// Create a lambda expression +pub fn lambda(params: impl IntoIterator>, body: Expr) -> Expr { + Expr::Lambda(Lambda::new(params.into_iter().map(Into::into).collect(), body)) +} + +/// Create an unresolved lambda variable expression +/// +/// The expression tree or LogicalPlan which +/// owns this variable must be resolved before usage with either +/// [`Expr::resolve_lambdas_variables`] or [`LogicalPlan::resolve_lambdas_variables`]. 
+/// +/// [`LogicalPlan::resolve_lambdas_variables`]: crate::LogicalPlan::resolve_lambdas_variables +pub fn lambda_var(name: impl Into) -> Expr { + Expr::LambdaVariable(LambdaVariable::new(name.into(), None)) +} + /// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] /// /// Adds methods to [`Expr`] that make it easy to set optional options diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index f3789ca9fd115..6c496cf7e158e 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -23,9 +23,10 @@ use crate::expr::{ }; use crate::expr::{FieldMetadata, LambdaVariable}; use crate::type_coercion::functions::{ - fields_with_aggregate_udf, fields_with_window_udf, + fields_with_aggregate_udf, fields_with_window_udf, value_fields_with_lambda_udf, }; -use crate::udlf::{LambdaReturnFieldArgs, ValueOrLambdaField}; +use crate::udlf::LambdaReturnFieldArgs; +use crate::ValueOrLambda; use crate::{ type_coercion::functions::data_types_with_scalar_udf, udf::ReturnFieldArgs, utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition, @@ -240,9 +241,11 @@ impl ExprSchemable for Expr { Ok(return_type) } Expr::Lambda(Lambda { params: _, body }) => body.get_type(schema), - Expr::LambdaVariable(LambdaVariable { name: _, field, .. }) => { - Ok(field.data_type().clone()) - } + Expr::LambdaVariable(LambdaVariable { name, field, .. }) => Ok(field + .as_ref() + .ok_or_else(|| plan_datafusion_err!("unresolved LambdaVariable {name}"))? + .data_type() + .clone()), } } @@ -366,7 +369,13 @@ impl ExprSchemable for Expr { Ok(nullable) } Expr::Lambda(l) => l.body.nullable(input_schema), - Expr::LambdaVariable(c) => Ok(c.field.is_nullable()), + Expr::LambdaVariable(l) => Ok(l + .field + .as_ref() + .ok_or_else(|| { + plan_datafusion_err!("unresolved LambdaVariable {}", l.name) + })? 
+ .is_nullable()), } } @@ -563,7 +572,6 @@ impl ExprSchemable for Expr { .into_iter() .map(|f| (f.data_type().clone(), f)) .unzip(); - // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_types, func) .map_err(|err| { @@ -594,7 +602,6 @@ impl ExprSchemable for Expr { _ => None, }) .collect::>(); - let args = ReturnFieldArgs { arg_fields: &new_fields, scalar_arguments: &arguments, @@ -635,15 +642,16 @@ impl ExprSchemable for Expr { .map(|arg| { let field = arg.to_field(schema)?.1; match arg { - Expr::Lambda(_lambda) => { - Ok(ValueOrLambdaField::Lambda(field)) - } - _ => Ok(ValueOrLambdaField::Value(field)), + Expr::Lambda(_lambda) => Ok(ValueOrLambda::Lambda(field)), + _ => Ok(ValueOrLambda::Value(field)), } }) .collect::>>()?; - let arguments = func.args + let new_fields = value_fields_with_lambda_udf(&arg_fields, func.func.as_ref())?; + + let arguments = func + .args .iter() .map(|e| match e { Expr::Literal(sv, _) => Some(sv), @@ -652,13 +660,17 @@ impl ExprSchemable for Expr { .collect::>(); let args = LambdaReturnFieldArgs { - arg_fields: &arg_fields, + arg_fields: &new_fields, scalar_arguments: &arguments, }; func.func.return_field_from_args(args) } - Expr::LambdaVariable(c) => Ok(Arc::clone(&c.field)), + Expr::LambdaVariable(l) => { + Ok(Arc::clone(l.field.as_ref().ok_or_else(|| { + plan_datafusion_err!("unresolved LambdaVariable {}", l.name) + })?)) + } }?; Ok(( diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index b03bab622a357..4ba328ddaf450 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -120,8 +120,8 @@ pub use udaf::{ }; pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; pub use udlf::{ - LambdaFunctionArgs, LambdaFunctionLambdaArg, LambdaReturnFieldArgs, LambdaSignature, - LambdaTypeSignature, LambdaUDF, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, + 
LambdaFunctionArgs, LambdaArgument, LambdaReturnFieldArgs, LambdaSignature, + LambdaTypeSignature, LambdaUDF, ValueOrLambda, }; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 0f0d81186d68f..1fa9e3aa6ab84 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -41,8 +41,7 @@ use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ - enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, - grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction, + enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, grouping_set_expr_count, grouping_set_to_exprlist, merge_schema, split_conjunction }; use crate::{ build_join_schema, expr_vec_fmt, requalify_sides_if_needed, BinaryExpr, @@ -2053,6 +2052,14 @@ impl LogicalPlan { } Wrapper(self) } + + pub fn resolve_lambdas_variables(self) -> Result> { + self.transform_with_subqueries(|plan| { + let schema = merge_schema(&plan.inputs()); + + plan.map_expressions(|expr| expr.resolve_lambdas_variables(&schema)) + }) + } } impl Display for LogicalPlan { diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 7696faca0922a..765b083beabb5 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -110,6 +110,9 @@ pub trait ContextProvider { /// Return all scalar function names fn udf_names(&self) -> Vec; + + /// Return all lambda function names + fn udlf_names(&self) -> Vec; /// Return all aggregate function names fn udaf_names(&self) -> Vec; diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 
bcaff11bcdb49..fdad0b174f476 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -16,7 +16,10 @@ // under the License. use super::binary::binary_numeric_coercion; -use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; +use crate::{ + AggregateUDF, LambdaTypeSignature, LambdaUDF, ScalarUDF, Signature, TypeSignature, + ValueOrLambda, WindowUDF, +}; use arrow::datatypes::FieldRef; use arrow::{ compute::can_cast_types, @@ -77,6 +80,70 @@ pub fn data_types_with_scalar_udf( try_coerce_types(func.name(), valid_types, current_types, type_signature) } +/// Performs type coercion for lambda function arguments. +/// +/// For value arguments, returns the field to which each +/// argument must be coerced to match `signature`. +/// For lambda arguments, returns a clone of the associated data +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. +pub fn value_fields_with_lambda_udf( + current_fields: &[ValueOrLambda], + func: &dyn LambdaUDF, +) -> Result>> { + match func.signature().type_signature { + LambdaTypeSignature::UserDefined => { + let arg_types = current_fields + .iter() + .map(|p| match p { + ValueOrLambda::Value(field) => { + ValueOrLambda::Value(field.data_type().clone()) + } + ValueOrLambda::Lambda(_) => ValueOrLambda::Lambda(()), + }) + .collect::>(); + + let coerced_types = func.coerce_value_types(&arg_types)?; + + std::iter::zip(current_fields, coerced_types) + .map(|(field, coerce_to)| match (field, coerce_to) { + (ValueOrLambda::Value(field), Some(coerce_to)) => { + Ok(ValueOrLambda::Value(Arc::new( + field.as_ref().clone().with_data_type(coerce_to), + ))) + } + (ValueOrLambda::Lambda(v), None) => { + Ok(ValueOrLambda::Lambda(v.clone())) + } + (ValueOrLambda::Value(_), None) => plan_err!( + "{} coerce_values_types returned None for a value", + func.name() + ), + (ValueOrLambda::Lambda(_), Some(_)) => plan_err!( + 
"{} coerce_values_types returned Some for a lambda", + func.name() + ), + }) + .collect() + } + LambdaTypeSignature::VariadicAny => { + Ok(current_fields.to_vec()) + } + LambdaTypeSignature::Any(number) => { + if current_fields.len() != number { + return plan_err!( + "The function '{}' expected {number} arguments but received {}", + func.name(), + current_fields.len() + ); + } + + Ok(current_fields.to_vec()) + } + } +} + /// Performs type coercion for aggregate function arguments. /// /// Returns the fields to which each argument must be coerced to diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index b2191e0883d30..b9d1f193db3dc 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -21,13 +21,13 @@ use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::{ColumnarValue, Documentation, Expr}; -use arrow::array::RecordBatch; -use arrow::datatypes::{DataType, Field, FieldRef}; +use arrow::array::{ArrayRef, RecordBatch}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{Result, ScalarValue, not_impl_err}; +use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; -use datafusion_expr_common::signature::{Volatility}; +use datafusion_expr_common::signature::Volatility; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::any::Any; use std::cmp::Ordering; @@ -68,6 +68,9 @@ pub enum LambdaTypeSignature { /// Provides information necessary for calling a lambda function. /// +/// - [`LambdaTypeSignature`] defines the argument types that a function has implementations +/// for. 
+/// /// - [`Volatility`] defines how the output of the function changes with the input. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct LambdaSignature { @@ -84,7 +87,7 @@ pub struct LambdaSignature { } impl LambdaSignature { - /// Creates a new LambdaSignature from a given type signature and volatility. + /// Creates a new `LambdaSignature` from a given type signature and volatility. pub fn new(type_signature: LambdaTypeSignature, volatility: Volatility) -> Self { LambdaSignature { type_signature, @@ -163,35 +166,35 @@ impl Hash for dyn LambdaUDF { } } -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum ValueOrLambdaParameter { - /// A value with the given associated data - Value(T), - /// A lambda - Lambda, -} - /// Arguments passed to [`LambdaUDF::invoke_with_args`] when invoking a /// lambda function. #[derive(Debug, Clone)] pub struct LambdaFunctionArgs { /// The evaluated arguments to the function - pub args: Vec, + pub args: Vec>, /// Field associated with each arg, if it exists - pub arg_fields: Vec, + pub arg_fields: Vec>, /// The number of rows in record batch being evaluated pub number_rows: usize, - /// The return field of the lambda function returned (from `return_type` - /// or `return_field_from_args`) when creating the physical expression - /// from the logical expression + /// The return field of the lambda function returned + /// (from `return_field_from_args`) when creating the + /// physical expression from the logical expression pub return_field: FieldRef, /// The config options at execution time pub config_options: Arc, } +impl LambdaFunctionArgs { + /// The return type of the function. See [`Self::return_field`] for more + /// details. 
+ pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } +} + /// A lambda argument to a LambdaFunction #[derive(Clone, Debug)] -pub struct LambdaFunctionLambdaArg { +pub struct LambdaArgument { /// The parameters defined in this lambda /// /// For example, for `array_transform([2], v -> -v)`, @@ -203,26 +206,114 @@ pub struct LambdaFunctionLambdaArg { /// this will be the physical expression of `-v` pub body: Arc, /// A RecordBatch containing at least the captured columns inside this lambda body, if any - /// Note that it may contain additional, non-specified columns, but that's implementation detail + /// Note that it may contain additional, non-specified columns, but that's a implementation detail /// /// For example, for `array_transform([2], v -> v + a + b)`, - /// this will be a `RecordBatch` with two columns, `a` and `b` + /// this will be a `RecordBatch` with at least two columns, `a` and `b` pub captures: Option, } -impl LambdaFunctionArgs { - /// The return type of the function. See [`Self::return_field`] for more - /// details. 
- pub fn return_type(&self) -> &DataType { - self.return_field.data_type() +impl LambdaArgument { + /// For adjusting multiple arrays by indices, use [`take_arrays`] + /// + /// [`take_arrays`]: arrow::compute::take_arrays + pub fn evaluate( + &self, + args: &[&dyn Fn() -> Result], + mut adjust: impl FnMut(&[ArrayRef]) -> Result>, + ) -> Result { + let adjusted_captures = self + .captures + .as_ref() + .map(|captures| { + let adjusted_columns = adjust(captures.columns())?; + + RecordBatch::try_new(captures.schema(), adjusted_columns) + }) + .transpose()?; + + let merged = merge_captures_with_variables( + adjusted_captures.as_ref(), + &self.params, + args, + )?; + + self.body.evaluate(&merged) } } -// An argument to a LambdaUDF that supports lambdas -#[derive(Clone, Debug)] -pub enum ValueOrLambda { - Value(ColumnarValue), - Lambda(LambdaFunctionLambdaArg), +fn merge_captures_with_variables( + captures: Option<&RecordBatch>, + params: &[FieldRef], + variables: &[&dyn Fn() -> Result], +) -> Result { + if variables.len() < params.len() { + return exec_err!( + "expected at least {} lambda arguments to merge with captures, got {}", + params.len(), + variables.len() + ); + } + + match captures { + Some(captures) => { + let old_fields = captures.schema_ref().fields(); + + let mut new_fields = old_fields + .iter() + .map(|field| { + if !fields_contains(params, field.name()) { + return Arc::clone(field); + } + + let mut i = 0; + + loop { + let alias = format!("{}_shadowed_{i}", field.name()); + + if !fields_contains(params, &alias) + && old_fields.find(&alias).is_none() + { + break Arc::new(Field::new( + alias, + field.data_type().clone(), + field.is_nullable(), + )); + } + + i += 1; + } + }) + .collect::>(); + + new_fields.extend_from_slice(params); + + let mut columns = captures.columns().to_vec(); + + for arg in &variables[..params.len()] { + columns.push(arg()?); + } + + let new_schema = Arc::new(Schema::new(new_fields)); + + Ok(RecordBatch::try_new(new_schema, 
columns)?) + } + None => { + let columns = variables + .iter() + .take(params.len()) + .map(|arg| arg()) + .collect::>()?; + + let schema = Arc::new(Schema::new(params)); + + Ok(RecordBatch::try_new(schema, columns)?) + } + } +} + +fn fields_contains(fields: &[FieldRef], name: &str) -> bool { + fields.iter().any(|f| f.name().as_str() == name) } /// Information about arguments passed to the function @@ -236,28 +327,31 @@ pub enum ValueOrLambda { pub struct LambdaReturnFieldArgs<'a> { /// The data types of the arguments to the function /// - /// If argument `i` to the function is a lambda, it will be the field returned by the - /// lambda when executed with the arguments returned from `LambdaUDF::lambdas_parameters` + /// If argument `i` to the function is a lambda, it will be the field of the result of the + /// lambda if evaluated with the parameters returned from [`LambdaUDF::lambdas_parameters`] /// /// For example, with `array_transform([1], v -> v == 5)` - /// this field will be `[Field::new("", DataType::List(DataType::Int32), false), Field::new("", DataType::Boolean, false)]` - pub arg_fields: &'a [ValueOrLambdaField], + /// this field will be `[ + /// ValueOrLambda::Value(Field::new("", DataType::List(DataType::Int32), false)), + /// ValueOrLambda::Lambda(Field::new("", DataType::Boolean, false)) + /// ]` + pub arg_fields: &'a [ValueOrLambda], /// Is argument `i` to the function a scalar (constant)? 
/// /// If the argument `i` is not a scalar, it will be None /// - /// For example, if a function is called like `my_function(column_a, 5)` - /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` + /// For example, if a function is called like `array_transform([1], v -> v == 5)` + /// this field will be `[Some(ScalarValue::List(...), None]` pub scalar_arguments: &'a [Option<&'a ScalarValue>], } -/// A tagged Field indicating whether it correspond to a value or a lambda argument -#[derive(Clone, Debug)] -pub enum ValueOrLambdaField { - /// The Field of a ColumnarValue argument - Value(FieldRef), - /// The Field of the return of the lambda body when evaluated with the parameters from LambdaUDF::lambda_parameters - Lambda(FieldRef), +/// An argument to a lambda function +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ValueOrLambda { + /// A value with associated data + Value(V), + /// A lambda with associated data + Lambda(L), } /// Trait for implementing user defined lambda functions. @@ -265,10 +359,9 @@ pub enum ValueOrLambdaField { /// This trait exposes the full API for implementing user defined functions and /// can be used to implement any function. /// -/// See [`advanced_udf.rs`] for a full example with complete implementation and -/// [`LambdaUDF`] for other available options. +/// See [`array_transform.rs`] for a commented complete implementation /// -/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +/// [`array_transform.rs`]: https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/array_transform.rs /// /// # Basic Example /// ``` @@ -364,7 +457,7 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// function has an implementation, and the function's [`Volatility`]. /// /// See [`LambdaSignature`] for more details on argument type handling - /// and [`Self::return_type`] for computing the return type. 
+ /// and [`Self::return_field_from_args`] for computing the return type. /// /// [`Volatility`]: datafusion_expr_common::signature::Volatility fn signature(&self) -> &LambdaSignature; @@ -391,10 +484,16 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// /// * `Some(LambdaUDF)` - A new instance of this function configured with the new settings /// * `None` - If this function does not change with new configuration settings (the default) - fn with_updated_config(&self, _config: &ConfigOptions) -> Option> { + fn with_updated_config(&self, _config: &ConfigOptions) -> Option> { None } + /// Returns the parameters that any lambda supports + fn lambdas_parameters( + &self, + args: &[ValueOrLambda], + ) -> Result>>>; + /// What type will be returned by this function, given the arguments? /// /// By default, this function calls [`Self::return_type`] with the @@ -632,7 +731,7 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// arguments to these specific types. fn coerce_value_types( &self, - _arg_types: &[ValueOrLambdaParameter], + _arg_types: &[ValueOrLambda], ) -> Result>> { not_impl_err!("Function {} does not implement coerce_types", self.name()) } @@ -644,14 +743,6 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { fn documentation(&self) -> Option<&Documentation> { None } - - /// Returns the parameters that any lambda supports - fn lambdas_parameters( - &self, - args: &[ValueOrLambdaParameter], - ) -> Result>>> { - Ok(vec![None; args.len()]) - } } #[cfg(test)] @@ -680,7 +771,17 @@ mod tests { &self.signature } - fn return_field_from_args(&self, _args: LambdaReturnFieldArgs) -> Result { + fn lambdas_parameters( + &self, + _args: &[ValueOrLambda], + ) -> Result>>> { + unimplemented!() + } + + fn return_field_from_args( + &self, + _args: LambdaReturnFieldArgs, + ) -> Result { unimplemented!() } diff --git a/datafusion/functions-nested/Cargo.toml b/datafusion/functions-nested/Cargo.toml index 0299aebdcac47..6e0d1048f9697 100644 
--- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -53,7 +53,6 @@ datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-macros = { workspace = true } -datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index e0c4ab28c1fef..843029114fa5c 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -20,35 +20,30 @@ use arrow::{ array::{ Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray, - RecordBatch, RecordBatchOptions, - }, - compute::take_record_batch, - datatypes::{DataType, Field, FieldRef, Schema}, + }, compute::take_arrays, datatypes::{DataType, Field, FieldRef} }; use datafusion_common::{ - exec_err, - tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + exec_err, plan_err, + utils::{ + adjust_offsets_for_slice, list_values, list_values_index, list_values_row_number, + take_function_args, }, - utils::{elements_indices, list_indices, list_values, take_function_args}, - HashMap, Result, + Result, }; use datafusion_expr::{ ColumnarValue, Documentation, LambdaFunctionArgs, LambdaSignature, LambdaUDF, - ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, + ValueOrLambda, Volatility, }; use datafusion_macros::user_doc; -use datafusion_physical_expr::expressions::{LambdaExpr, LambdaVariable}; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use std::{any::Any, sync::Arc}; - -//make_udf_expr_and_func!( -// ArrayTransform, -// array_transform, -// array lambda, -// "transforms the values of a array", 
-// array_transform_udf -//); +use std::{any::Any, fmt::Debug, sync::Arc}; + +make_udlf_expr_and_func!( + ArrayTransform, + array_transform, + array lambda, + "transforms the values of a array", + array_transform_udlf +); #[user_doc( doc_section(label = "Array Functions"), @@ -83,7 +78,7 @@ impl Default for ArrayTransform { impl ArrayTransform { pub fn new() -> Self { Self { - signature: LambdaSignature::any(2, Volatility::Immutable), + signature: LambdaSignature::user_defined(Volatility::Immutable), aliases: vec![String::from("list_transform")], } } @@ -106,19 +101,58 @@ impl LambdaUDF for ArrayTransform { &self.signature } + fn coerce_value_types( + &self, + arg_types: &[ValueOrLambda], + ) -> Result>> { + let (list, _lambda) = value_lambda_pair(self.name(), arg_types)?; + + let coerced = match list { + DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) => list.clone(), + DataType::ListView(field) => DataType::List(Arc::clone(field)), + DataType::LargeListView(field) => DataType::LargeList(Arc::clone(field)), + _ => { + return plan_err!( + "{} expected a list as first argument, got {}", + self.name(), + list + ) + } + }; + + Ok(vec![Some(coerced), None]) + } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambda], + ) -> Result>>> { + let (list, _lambda) = value_lambda_pair(self.name(), args)?; + + let (field, index_type) = match list.data_type() { + DataType::List(field) => (field, DataType::Int32), + DataType::LargeList(field) => (field, DataType::Int64), + DataType::FixedSizeList(field, _) => (field, DataType::Int32), + _ => return plan_err!("expected list, got {list}"), + }; + + // we don't need to omit the index in the case the lambda don't specify, e.g. array_transform([], v -> v*2), + // nor check whether the lambda contains more than two parameters, e.g. 
array_transform([], (v, i, j) -> v+i+j), + // as datafusion will do that for us + let value = Field::new("", field.data_type().clone(), field.is_nullable()) + .with_metadata(field.metadata().clone()); + let index = Field::new("", index_type, false); + + Ok(vec![None, Some(vec![value, index])]) + } + fn return_field_from_args( &self, args: datafusion_expr::LambdaReturnFieldArgs, ) -> Result> { - let [ValueOrLambdaField::Value(list), ValueOrLambdaField::Lambda(lambda)] = - take_function_args(self.name(), args.arg_fields)? - else { - return exec_err!( - "{} expects a value follewed by a lambda, got {:?}", - self.name(), - args - ); - }; + let (list, lambda) = value_lambda_pair(self.name(), args.arg_fields)?; //TODO: should metadata be copied into the transformed array? @@ -134,58 +168,45 @@ impl LambdaUDF for ArrayTransform { DataType::List(_) => DataType::List(field), DataType::LargeList(_) => DataType::LargeList(field), DataType::FixedSizeList(_, size) => DataType::FixedSizeList(field, *size), - _ => unreachable!(), + other => plan_err!("expected list, got {other}")?, }; Ok(Arc::new(Field::new("", return_type, list.is_nullable()))) } fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result { - let [list_value, lambda] = take_function_args(self.name(), &args.args)?; - - let (ValueOrLambda::Value(list_value), ValueOrLambda::Lambda(lambda)) = - (list_value, lambda) - else { - return exec_err!( - "{} expects a value followed by a lambda, got {:?}", - self.name(), - &args.args - ); - }; + let (list, lambda) = value_lambda_pair(self.name(), &args.args)?; - let list_array = list_value.to_array(args.number_rows)?; + let list_array = list.to_array(args.number_rows)?; + // as per list_values docs, if list_array is sliced, list_values will be sliced too, + // so before constructing the transformed array below, we must adjust the list offsets with + // adjust_offsets_for_slice let list_values = list_values(&list_array)?; // if any column got captured, we need to adjust it 
to the values arrays, // duplicating values of list with mulitple values and removing values of empty lists - // list_indices is not cheap so is important to avoid it when no column is captured - let adjusted_captures = lambda - .captures - .as_ref() - .map(|captures| take_record_batch(captures, &list_indices(&list_array)?)) - .transpose()?; - - // use closures and merge_captures_with_lazy_args so that it calls only the needed ones based on the number of arguments - // avoiding unnecessary computations - let values_param = || Ok(Arc::clone(list_values)); - let indices_param = || elements_indices(&list_array); - - let binded_body = bind_lambda_variables( - Arc::clone(&lambda.body), - &lambda.params, - &[&values_param, &indices_param], - )?; - - // call the transforming expression with the record batch composed of the list values merged with captured columns - let transformed_values = binded_body - .evaluate(&adjusted_captures.unwrap_or_else(|| { - RecordBatch::try_new_with_options( - Arc::new(Schema::empty()), - vec![], - &RecordBatchOptions::new().with_row_count(Some(list_values.len())), - ) - .unwrap() - }))? + // list_values_row_number is not cheap so is important to avoid it when no column is captured + let mut adjust_indices = None; + + // use closures so that lambda.evaluate calls only the needed ones + // based on the number of arguments avoiding unnecessary computations + let values_param = || Ok(Arc::clone(&list_values)); + let indices_param = || list_values_index(&list_array); + + // call the transforming lambda with the record batch composed of the list values merged with captured columns + let transformed_values = lambda + .evaluate( + &[&values_param, &indices_param], + |arrays| { + let indices = match &adjust_indices { + Some(v) => v, + None => { + adjust_indices.insert(list_values_row_number(&list_array)?) + } + }; + Ok(take_arrays(arrays, indices, None)?) + }, + )? 
.into_array(list_values.len())?; let field = match args.return_field.data_type() { @@ -204,20 +225,26 @@ impl LambdaUDF for ArrayTransform { let transformed_list = match list_array.data_type() { DataType::List(_) => { let list = list_array.as_list(); + // since we called list_values above which would return sliced values for + // a sliced list, we must adjust the offsets here as otherwise they would be invalid + let adjusted_offsets = adjust_offsets_for_slice(list); Arc::new(ListArray::new( field, - list.offsets().clone(), + adjusted_offsets, transformed_values, list.nulls().cloned(), )) as ArrayRef } DataType::LargeList(_) => { let large_list = list_array.as_list(); + // since we called list_values above which would return sliced values for + // a sliced list, we must adjust the offsets here as otherwise they would be invalid + let adjusted_offsets = adjust_offsets_for_slice(large_list); Arc::new(LargeListArray::new( field, - large_list.offsets().clone(), + adjusted_offsets, transformed_values, large_list.nulls().cloned(), )) @@ -236,99 +263,23 @@ impl LambdaUDF for ArrayTransform { Ok(ColumnarValue::Array(transformed_list)) } - fn lambdas_parameters( - &self, - args: &[ValueOrLambdaParameter], - ) -> Result>>> { - let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda] = args - else { - return exec_err!( - "{} expects a value follewed by a lambda, got {:?}", - self.name(), - args - ); - }; - - let (field, index_type) = match list.data_type() { - DataType::List(field) => (field, DataType::Int32), - DataType::LargeList(field) => (field, DataType::Int64), - DataType::FixedSizeList(field, _) => (field, DataType::Int32), - _ => return exec_err!("expected list, got {list}"), - }; - - // we don't need to omit the index in the case the lambda don't specify, e.g. array_transform([], v -> v*2), - // nor check whether the lambda contains more than two parameters, e.g. 
array_transform([], (v, i, j) -> v+i+j), - // as datafusion will do that for us - let value = Field::new("", field.data_type().clone(), field.is_nullable()) - .with_metadata(field.metadata().clone()); - let index = Field::new("", index_type, false); - - Ok(vec![None, Some(vec![value, index])]) - } - fn documentation(&self) -> Option<&Documentation> { self.doc() } } -fn bind_lambda_variables( - expr: Arc, - params: &[FieldRef], - args: &[&dyn Fn() -> Result], -) -> Result> { - let columns = std::iter::zip(params, args) - .map(|(param, arg)| Ok((param.name().as_str(), (arg()?, 0)))) - .collect::>>()?; +fn value_lambda_pair<'a, V: Debug, L: Debug>( + name: &str, + args: &'a [ValueOrLambda], +) -> Result<(&'a V, &'a L)> { + let [value, lambda] = take_function_args(name, args)?; - expr.rewrite(&mut BindLambdaVariable::new(columns)).data() -} - -struct BindLambdaVariable<'a> { - columns: HashMap<&'a str, (ArrayRef, usize)>, -} - -impl<'a> BindLambdaVariable<'a> { - fn new(columns: HashMap<&'a str, (ArrayRef, usize)>) -> Self { - Self { columns } - } -} - -impl TreeNodeRewriter for BindLambdaVariable<'_> { - type Node = Arc; - - fn f_down(&mut self, node: Self::Node) -> Result> { - if let Some(lambda_variable) = node.as_any().downcast_ref::() { - if let Some((value, shadows)) = self.columns.get(lambda_variable.name()) { - if *shadows == 0 { - return Ok(Transformed::yes(Arc::new( - lambda_variable.clone().with_value(Arc::clone(value)), - ))); - } - } - } else if let Some(inner_lambda) = node.as_any().downcast_ref::() { - for param in inner_lambda.params() { - if let Some((_value, shadows)) = self.columns.get_mut(param.as_str()) { - *shadows += 1; - } - } + let (ValueOrLambda::Value(value), ValueOrLambda::Lambda(lambda)) = (value, lambda) + else { + return plan_err!( + "{name} expects a value followed by a lambda, got {value:?} and {lambda:?}" + ); + }; - if self.columns.values().all(|(_value, shadows)| *shadows > 0) { - return Ok(Transformed::new(node, false, 
TreeNodeRecursion::Jump)); - } - } - - Ok(Transformed::no(node)) - } - - fn f_up(&mut self, node: Self::Node) -> Result> { - if let Some(inner_lambda) = node.as_any().downcast_ref::() { - for param in inner_lambda.params() { - if let Some((_value, shadows)) = self.columns.get_mut(param.as_str()) { - *shadows -= 1; - } - } - } - - Ok(Transformed::no(node)) - } + Ok((value, lambda)) } diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index c93a55cce1a4f..e8dc283b4cb88 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -36,6 +36,9 @@ #[macro_use] pub mod macros; +#[macro_use] +pub mod macros_lambda; + pub mod array_has; pub mod array_transform; pub mod cardinality; @@ -70,7 +73,7 @@ pub mod utils; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; -use datafusion_expr::ScalarUDF; +use datafusion_expr::{LambdaUDF, ScalarUDF}; use log::debug; use std::sync::Arc; @@ -79,7 +82,7 @@ pub mod expr_fn { pub use super::array_has::array_has; pub use super::array_has::array_has_all; pub use super::array_has::array_has_any; - //pub use super::array_transform::array_transform; + pub use super::array_transform::array_transform; pub use super::cardinality::cardinality; pub use super::concat::array_append; pub use super::concat::array_concat; @@ -147,7 +150,6 @@ pub fn all_default_nested_functions() -> Vec> { array_has::array_has_udf(), array_has::array_has_all_udf(), array_has::array_has_any_udf(), - //array_transform::array_transform_udf(), empty::array_empty_udf(), length::array_length_udf(), distance::array_distance_udf(), @@ -177,6 +179,10 @@ pub fn all_default_nested_functions() -> Vec> { ] } +pub fn all_default_lambda_functions() -> Vec> { + vec![array_transform::array_transform_udlf()] +} + /// Registers all enabled packages with a [`FunctionRegistry`] pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let functions: Vec> = 
all_default_nested_functions(); @@ -188,25 +194,40 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { Ok(()) as Result<()> })?; + let functions: Vec> = all_default_lambda_functions(); + functions.into_iter().try_for_each(|udlf| { + let existing_udlf = registry.register_udlf(udlf)?; + if let Some(existing_udlf) = existing_udlf { + debug!("Overwrite existing UDLF: {}", existing_udlf.name()); + } + Ok(()) as Result<()> + })?; + Ok(()) } #[cfg(test)] mod tests { - use crate::all_default_nested_functions; + use crate::{all_default_lambda_functions, all_default_nested_functions}; use datafusion_common::Result; use std::collections::HashSet; #[test] fn test_no_duplicate_name() -> Result<()> { + let scalars = all_default_nested_functions(); + let scalars = scalars.iter().map(|s| (s.name(), s.aliases())); + + let lambdas = all_default_lambda_functions(); + let lambdas = lambdas.iter().map(|l| (l.name(), l.aliases())); + let mut names = HashSet::new(); - for func in all_default_nested_functions() { + + for (name, aliases) in scalars.chain(lambdas) { assert!( - names.insert(func.name().to_string().to_lowercase()), - "duplicate function name: {}", - func.name() + names.insert(name.to_string().to_lowercase()), + "duplicate function name: {name}", ); - for alias in func.aliases() { + for alias in aliases { assert!( names.insert(alias.to_string().to_lowercase()), "duplicate function name: {alias}" diff --git a/datafusion/functions-nested/src/macros_lambda.rs b/datafusion/functions-nested/src/macros_lambda.rs new file mode 100644 index 0000000000000..2f426ba8ee9b3 --- /dev/null +++ b/datafusion/functions-nested/src/macros_lambda.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Creates external API functions for an array UDF. Specifically, creates +/// +/// 1. Single `LambdaUDF` instance +/// +/// Creates a singleton `LambdaUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$LAMBDA_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. +/// +/// This is used to ensure creating the list of `LambdaUDF` only happens once. +/// +/// # 2. `expr_fn` style function +/// +/// These are functions that create an `Expr` that invokes the UDF, used +/// primarily to programmatically create expressions. +/// +/// For example: +/// ```text +/// pub fn array_to_string(delimiter: Expr) -> Expr { +/// ... +/// } +/// ``` +/// # Arguments +/// * `UDF`: name of the [`LambdaUDF`] +/// * `EXPR_FN`: name of the expr_fn function to be created +/// * `arg`: 0 or more named arguments for the function +/// * `DOC`: documentation string for the function +/// * `LAMBDA_UDF_FUNC`: name of the function to create (just) the `LambdaUDF` +/// * (optional) `$CTOR`: Pass a custom constructor. When omitted it +/// automatically resolves to `$UDF::new()`. +/// +/// [`LambdaUDF`]: datafusion_expr::LambdaUDF +macro_rules! 
make_udlf_expr_and_func { + ($UDF:ident, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $LAMBDA_UDF_FN:ident) => { + make_udlf_expr_and_func!($UDF, $EXPR_FN, $($arg)*, $DOC, $LAMBDA_UDF_FN, $UDF::new); + }; + ($UDF:ident, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $LAMBDA_UDF_FN:ident, $CTOR:path) => { + paste::paste! { + // "fluent expr_fn" style function + #[doc = $DOC] + pub fn $EXPR_FN($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { + datafusion_expr::Expr::LambdaFunction(datafusion_expr::expr::LambdaFunction::new( + $LAMBDA_UDF_FN(), + vec![$($arg),*], + )) + } + create_lambda!($UDF, $LAMBDA_UDF_FN, $CTOR); + } + }; + ($UDF:ident, $EXPR_FN:ident, $DOC:expr, $LAMBDA_UDF_FN:ident) => { + make_udlf_expr_and_func!($UDF, $EXPR_FN, $DOC, $LAMBDA_UDF_FN, $UDF::new); + }; + ($UDF:ident, $EXPR_FN:ident, $DOC:expr, $LAMBDA_UDF_FN:ident, $CTOR:path) => { + paste::paste! { + // "fluent expr_fn" style function + #[doc = $DOC] + pub fn $EXPR_FN(arg: Vec) -> datafusion_expr::Expr { + datafusion_expr::Expr::LambdaFunction(datafusion_expr::expr::LambdaFunction::new( + $LAMBDA_UDF_FN(), + arg, + )) + } + create_lambda!($UDF, $LAMBDA_UDF_FN, $CTOR); + } + }; +} + +/// Creates a singleton `LambdaUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$LAMBDA_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. +/// +/// This is used to ensure creating the list of `LambdaUDF` only happens once. +/// +/// # Arguments +/// * `UDF`: name of the [`LambdaUDF`] +/// * `LAMBDA_UDF_FUNC`: name of the function to create (just) the `LambdaUDF` +/// * (optional) `$CTOR`: Pass a custom constructor. When omitted it +/// automatically resolves to `$UDF::new()`. +/// +/// [`LambdaUDF`]: datafusion_expr::LambdaUDF +macro_rules! create_lambda { + ($UDF:ident, $LAMBDA_UDF_FN:ident) => { + create_lambda!($UDF, $LAMBDA_UDF_FN, $UDF::new); + }; + ($UDF:ident, $LAMBDA_UDF_FN:ident, $CTOR:path) => { + paste::paste! 
{ + #[doc = concat!("LambdaFunction that returns a [`LambdaUDF`](datafusion_expr::LambdaUDF) for ")] + #[doc = stringify!($UDF)] + pub fn $LAMBDA_UDF_FN() -> std::sync::Arc { + // Singleton instance of [`$UDF`], ensures the UDF is only created once + static INSTANCE: std::sync::LazyLock> = + std::sync::LazyLock::new(|| { + std::sync::Arc::new($CTOR()) + }); + std::sync::Arc::clone(&INSTANCE) + } + } + }; +} diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index 7f93500b9cfb9..a71e2e87388d5 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -248,7 +248,6 @@ mod tests { .iter() .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) .collect::>(); - let result = fun.invoke_with_args(ScalarFunctionArgs { args, arg_fields, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index e97d79e98e977..af1bd90aedf2a 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -42,7 +42,7 @@ use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion}; use datafusion_expr::type_coercion::functions::{ - data_types_with_scalar_udf, fields_with_aggregate_udf, + data_types_with_scalar_udf, fields_with_aggregate_udf, value_fields_with_lambda_udf, }; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_expression, get_coerce_type_for_list, @@ -51,9 +51,8 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_u use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - AggregateUDF, Expr, ExprSchemable, Join, LambdaTypeSignature, Limit, LogicalPlan, - Operator, Projection, ScalarUDF, 
Union, ValueOrLambdaParameter, WindowFrame, - WindowFrameBound, WindowFrameUnits, + AggregateUDF, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection, + ScalarUDF, Union, ValueOrLambda, WindowFrame, WindowFrameBound, WindowFrameUnits, }; /// Performs type coercion by determining the schema @@ -584,62 +583,36 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { Ok(Transformed::yes(new_expr)) } Expr::LambdaFunction(LambdaFunction { func, args }) => { - match func.signature().type_signature { - LambdaTypeSignature::UserDefined => { - let args_types = args - .iter() - .map(|arg| match arg { - Expr::Lambda(_) => Ok(ValueOrLambdaParameter::Lambda), - _ => Ok(ValueOrLambdaParameter::Value( - arg.get_type(self.schema)?, - )), - }) - .collect::>>()?; + let current_fields = args + .iter() + .map(|arg| match arg { + Expr::Lambda(_) => Ok(ValueOrLambda::Lambda(())), + _ => Ok(ValueOrLambda::Value(arg.to_field(self.schema)?.1)), + }) + .collect::>>()?; - let value_types = func.coerce_value_types(&args_types)?; + let new_fields = + value_fields_with_lambda_udf(¤t_fields, func.as_ref())?; - if args_types - .iter() - .map(|a| match a { - ValueOrLambdaParameter::Value(ty) => Some(ty), - ValueOrLambdaParameter::Lambda => None, - }) - .eq(value_types.iter().map(|v| v.as_ref())) - { - return Ok(Transformed::no(Expr::LambdaFunction( - LambdaFunction::new(func, args), - ))); - } - - let args = std::iter::zip(args, value_types) - .map(|(arg, ty)| match (&arg, ty) { - (Expr::Lambda(_), None) => Ok(arg), - (Expr::Lambda(_), Some(_ty)) => plan_err!("{} coerce_value_types returned Some for a lambda argument", func.name()), - (_, Some(ty)) => arg.cast_to(&ty, self.schema), - (_, None) => plan_err!("{} coerce_value_types returned None for a value argument", func.name()), + let transformed = current_fields != new_fields; + + let new_args = if transformed { + std::iter::zip(args, new_fields) + .map(|(arg, new_field)| match (&arg, new_field) { + (Expr::Lambda(_lambda), 
ValueOrLambda::Lambda(_)) => Ok(arg), + (Expr::Lambda(_lambda), ValueOrLambda::Value(_)) => plan_err!("value_fields_with_lambda_udf return a value for a lambda argument"), + (_, ValueOrLambda::Value(new_field)) => arg.cast_to(new_field.data_type(), self.schema), + (_, ValueOrLambda::Lambda(_)) => plan_err!("value_fields_with_lambda_udf return a lambda for a value argument"), }) - .collect::>>()?; + .collect::>()? + } else { + args + }; - Ok(Transformed::yes(Expr::LambdaFunction(LambdaFunction::new( - func, args, - )))) - } - LambdaTypeSignature::VariadicAny => Ok(Transformed::no( - Expr::LambdaFunction(LambdaFunction::new(func, args)), - )), - LambdaTypeSignature::Any(number) => { - if args.len() != number { - return plan_err!( - "The function '{}' expected {number} arguments but received {}", - func.name(), args.len() - ); - } - - Ok(Transformed::no(Expr::LambdaFunction(LambdaFunction::new( - func, args, - )))) - } - } + Ok(Transformed::new_transformed( + Expr::LambdaFunction(LambdaFunction::new(func, new_args)), + transformed, + )) } // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] @@ -854,11 +827,9 @@ fn coerce_arguments_for_signature_with_scalar_udf( return Ok(expressions); } - let current_types = expressions.iter() - .map(|e| match e { - Expr::Lambda { .. } => Ok(DataType::Null), - _ => e.get_type(schema), - }) + let current_types = expressions + .iter() + .map(|e| e.get_type(schema)) .collect::>>()?; let new_types = data_types_with_scalar_udf(¤t_types, func)?; @@ -866,10 +837,7 @@ fn coerce_arguments_for_signature_with_scalar_udf( expressions .into_iter() .enumerate() - .map(|(i, expr)| match expr { - lambda @ Expr::Lambda { .. 
} => Ok(lambda), - _ => expr.cast_to(&new_types[i], schema), - }) + .map(|(i, expr)| expr.cast_to(&new_types[i], schema)) .collect() } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 8e09fbbae48d8..0abf80a622ea5 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -469,10 +469,10 @@ impl TreeNodeRewriter for Canonicalizer { match (left.as_ref(), right.as_ref(), op.swap()) { // ( - left_col @ (Expr::Column(_) | Expr::LambdaVariable(_)), - right_col @ (Expr::Column(_) | Expr::LambdaVariable(_)), + left_ref @ (Expr::Column(_) | Expr::LambdaVariable(_)), + right_ref @ (Expr::Column(_) | Expr::LambdaVariable(_)), Some(swapped_op), - ) if right_col > left_col => { + ) if right_ref > left_ref => { Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, @@ -655,8 +655,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::WindowFunction { .. } | Expr::GroupingSet(_) | Expr::Wildcard { .. } - | Expr::Placeholder(_) - | Expr::LambdaVariable(_) => false, + | Expr::Placeholder(_) => false, Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } @@ -684,7 +683,8 @@ impl<'a> ConstEvaluator<'a> { | Expr::Cast { .. } | Expr::TryCast { .. } | Expr::InList { .. 
} - | Expr::Lambda(_) => true, + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => true, } } diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 11a656f2abb4c..3c6400a1dad2f 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -760,6 +760,10 @@ impl ContextProvider for MyContextProvider { Vec::new() } + fn udlf_names(&self) -> Vec { + Vec::new() + } + fn udaf_names(&self) -> Vec { Vec::new() } diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index d1957ae1892ea..61cc97dae300e 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -21,10 +21,9 @@ use std::sync::Arc; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; -use datafusion_common::tree_node::TreeNode; use datafusion_common::{ exec_err, - tree_node::{Transformed, TransformedResult}, + tree_node::{Transformed, TransformedResult, TreeNode}, Result, ScalarValue, }; use datafusion_functions::core::getfield::GetFieldFunc; diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 7040fa2bfc9b4..9ca464b304306 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -188,7 +188,6 @@ pub fn with_new_schema( ); }; let new_col = Column::new(field.name(), idx); - Ok(Transformed::yes(Arc::new(new_col) as _)) } else { Ok(Transformed::no(expr)) diff --git a/datafusion/physical-expr/src/expressions/lambda.rs b/datafusion/physical-expr/src/expressions/lambda.rs index 38b64e3c7f3e1..c29edf17fca41 100644 --- a/datafusion/physical-expr/src/expressions/lambda.rs +++ b/datafusion/physical-expr/src/expressions/lambda.rs @@ -17,26 +17,31 @@ //! 
Physical lambda expression: [`LambdaExpr`] +use std::any::Any; use std::hash::Hash; use std::sync::Arc; -use std::{any::Any, sync::OnceLock}; -use crate::expressions::Column; +use crate::expressions::{Column, LambdaVariable}; use crate::physical_expr::PhysicalExpr; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{internal_err, HashSet, Result}; +use datafusion_common::{internal_err, tree_node::TreeNodeVisitor, HashSet, Result}; +use datafusion_common::{ + plan_err, + tree_node::{TreeNode, TreeNodeRecursion}, +}; use datafusion_expr::ColumnarValue; +use hashbrown::{hash_map::EntryRef, HashMap}; /// Represents a lambda with the given parameters names and body #[derive(Debug, Eq, Clone)] pub struct LambdaExpr { params: Vec, body: Arc, - captures: OnceLock>, + captured_columns: HashSet, + captured_variables: HashSet, } // Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 [https://github.com/apache/datafusion/issues/13196] @@ -55,11 +60,29 @@ impl Hash for LambdaExpr { impl LambdaExpr { /// Create a new lambda expression with the given parameters and body - pub fn new(params: Vec, body: Arc) -> Self { + pub fn try_new(params: Vec, body: Arc) -> Result { + if all_unique(¶ms) { + Ok(Self::new(params, body)) + } else { + plan_err!("lambda params must be unique, got ({})", params.join(", ")) + } + } + + fn new(params: Vec, body: Arc) -> Self { + let (captured_columns, captured_variables) = { + let mut captures = Captures::new(¶ms); + + body.visit(&mut captures) + .expect("visitor should be infallible"); + + (captures.columns, captures.variables) + }; + Self { params, body, - captures: OnceLock::new(), + captured_columns, + captured_variables, } } @@ -73,22 +96,12 @@ impl LambdaExpr { &self.body } - pub fn captures(&self) -> &HashSet { - self.captures.get_or_init(|| { - let mut indices = HashSet::new(); - - 
self.body - .apply(|expr| { - if let Some(column) = expr.as_any().downcast_ref::() { - indices.insert(column.index()); - } - - Ok(TreeNodeRecursion::Continue) - }) - .expect("closure should be infallibe"); + pub(crate) fn captured_columns(&self) -> &HashSet { + &self.captured_columns + } - indices - }) + pub(crate) fn captured_variables(&self) -> &HashSet { + &self.captured_variables } } @@ -123,14 +136,97 @@ impl PhysicalExpr for LambdaExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(Self { - params: self.params.clone(), - body: Arc::clone(&children[0]), - captures: OnceLock::new(), - })) + Ok(Arc::new(Self::new( + self.params.clone(), + Arc::clone(&children[0]), + ))) } fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "({}) -> {}", self.params.join(", "), self.body) } } + +/// Create a lambda expression +pub fn lambda( + params: impl IntoIterator>, + body: Arc, +) -> Result> { + Ok(Arc::new(LambdaExpr::try_new( + params.into_iter().map(Into::into).collect(), + body, + )?)) +} + +fn all_unique(params: &[String]) -> bool { + match params.len() { + 0 | 1 => true, + 2 => params[0] != params[1], + _ => { + let mut set = HashSet::with_capacity(params.len()); + + params.iter().all(|p| set.insert(p.as_str())) + } + } +} + +struct Captures<'a> { + shadows: HashMap<&'a str, usize>, + columns: HashSet, + variables: HashSet, +} + +impl<'a> Captures<'a> { + fn new(params: &'a [String]) -> Self { + Self { + shadows: params.iter().map(|p| (p.as_str(), 1)).collect(), + columns: HashSet::new(), + variables: HashSet::new(), + } + } +} + +impl<'n> TreeNodeVisitor<'n> for Captures<'n> { + type Node = Arc; + + fn f_down(&mut self, node: &'n Self::Node) -> Result { + if let Some(lambda) = node.as_any().downcast_ref::() { + for param in &lambda.params { + *self.shadows.entry_ref(param.as_str()).or_default() += 1; + } + } else if let Some(lambda_variable) = + node.as_any().downcast_ref::() + { + if 
!self.shadows.contains_key(lambda_variable.name()) { + self.variables.insert(lambda_variable.name().to_owned()); + } + } else if let Some(col) = node.as_any().downcast_ref::() { + self.columns.insert(col.index()); + } + + Ok(TreeNodeRecursion::Continue) + } + + fn f_up(&mut self, node: &'n Self::Node) -> Result { + if let Some(lambda) = node.as_any().downcast_ref::() { + for param in &lambda.params { + match self.shadows.entry_ref(param.as_str()) { + EntryRef::Occupied(mut v) => { + if *v.get() > 1 { + *v.get_mut() -= 1; + } else { + v.remove(); + } + } + EntryRef::Vacant(_v) => { + unreachable!( + "f_down should have inserted a value for every param" + ) + } + } + } + } + + Ok(TreeNodeRecursion::Continue) + } +} diff --git a/datafusion/physical-expr/src/expressions/lambda_variable.rs b/datafusion/physical-expr/src/expressions/lambda_variable.rs index 305774c3c02da..820a88f4ce0f9 100644 --- a/datafusion/physical-expr/src/expressions/lambda_variable.rs +++ b/datafusion/physical-expr/src/expressions/lambda_variable.rs @@ -15,28 +15,27 @@ // specific language governing permissions and limitations // under the License. -//! Physical lambda column reference: [`LambdaVariable`] +//! 
Physical lambda variable reference: [`LambdaVariable`] use std::any::Any; use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; -use arrow::array::ArrayRef; use arrow::datatypes::FieldRef; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::{Result, exec_datafusion_err}; + +use datafusion_common::{exec_err, Result}; use datafusion_expr::ColumnarValue; -/// Represents the lambda column with a given name and field +/// Represents the lambda variable with a given name and field #[derive(Debug, Clone)] pub struct LambdaVariable { name: String, field: FieldRef, - value: Option, } impl Eq for LambdaVariable {} @@ -55,32 +54,23 @@ impl Hash for LambdaVariable { } impl LambdaVariable { - /// Create a new lambda column expression - pub fn new(name: &str, field: FieldRef) -> Self { + /// Create a new lambda variable expression + pub fn new(name: String, field: FieldRef) -> Self { Self { - name: name.to_owned(), + name, field, - value: None, } } - /// Get the column's name + /// Get the variable's name pub fn name(&self) -> &str { &self.name } - /// Get the column's field + /// Get the variable's field pub fn field(&self) -> &FieldRef { &self.field } - - pub fn with_value(self, value: ArrayRef) -> Self { - Self { - name: self.name, - field: self.field, - value: Some(ColumnarValue::Array(value)), - } - } } impl std::fmt::Display for LambdaVariable { @@ -90,24 +80,23 @@ impl std::fmt::Display for LambdaVariable { } impl PhysicalExpr for LambdaVariable { - /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } - /// Get the data type of this expression, given the schema of the input fn data_type(&self, _input_schema: &Schema) -> Result { Ok(self.field.data_type().clone()) } - /// Decide whether this expression is nullable, given the schema of the input fn nullable(&self, _input_schema: &Schema) -> Result { Ok(self.field.is_nullable()) } - /// Evaluate the 
expression - fn evaluate(&self, _batch: &RecordBatch) -> Result { - self.value.clone().ok_or_else(|| exec_datafusion_err!("Physical LambdaVariable {} missing value", self.name)) + fn evaluate(&self, batch: &RecordBatch) -> Result { + match batch.column_by_name(&self.name) { + Some(array) => Ok(ColumnarValue::Array(Arc::clone(array))), + None => exec_err!("LambdaVariable {} not present in batch", self.name), + } } fn return_field(&self, _input_schema: &Schema) -> Result { @@ -131,6 +120,6 @@ impl PhysicalExpr for LambdaVariable { } /// Create a lambda variable expression -pub fn lambda_variable(name: &str, field: FieldRef) -> Result> { - Ok(Arc::new(LambdaVariable::new(name, field))) +pub fn lambda_variable(name: impl Into, field: FieldRef) -> Result> { + Ok(Arc::new(LambdaVariable::new(name.into(), field))) } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 990e53fa23b2c..4c231fbb27629 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -52,7 +52,7 @@ pub use dynamic_filters::DynamicFilterPhysicalExpr; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; -pub use lambda::LambdaExpr; +pub use lambda::{LambdaExpr, lambda}; pub use like::{like, LikeExpr}; pub use literal::{lit, Literal}; pub use negative::{negative, NegativeExpr}; diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs index 97af1f9b13891..c9301c84e7763 100644 --- a/datafusion/physical-expr/src/lambda_function.rs +++ b/datafusion/physical-expr/src/lambda_function.rs @@ -44,9 +44,8 @@ use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::{ - expr_vec_fmt, ColumnarValue, LambdaFunctionArgs, 
LambdaFunctionLambdaArg, - LambdaReturnFieldArgs, LambdaUDF, ValueOrLambda, ValueOrLambdaField, - ValueOrLambdaParameter, Volatility, + expr_vec_fmt, ColumnarValue, LambdaArgument, LambdaFunctionArgs, + LambdaReturnFieldArgs, LambdaUDF, ValueOrLambda, Volatility, }; /// Physical expression of a lambda function @@ -100,8 +99,8 @@ impl LambdaFunctionExpr { .map(|e| { let field = e.return_field(schema)?; match e.as_any().downcast_ref::() { - Some(_lambda) => Ok(ValueOrLambdaField::Lambda(field)), - None => Ok(ValueOrLambdaField::Value(field)), + Some(_lambda) => Ok(ValueOrLambda::Lambda(field)), + None => Ok(ValueOrLambda::Value(field)), } }) .collect::>>()?; @@ -261,17 +260,19 @@ impl PhysicalExpr for LambdaFunctionExpr { .iter() .map(|e| { let field = e.return_field(batch.schema_ref())?; + match e.as_any().downcast_ref::() { - Some(_lambda) => Ok(ValueOrLambdaField::Lambda(field)), - None => Ok(ValueOrLambdaField::Value(field)), + Some(_lambda) => Ok(ValueOrLambda::Lambda(field)), + None => Ok(ValueOrLambda::Value(field)), } }) .collect::>>()?; - let args_metadata = arg_fields.iter() + let args_metadata = arg_fields + .iter() .map(|field| match field { - ValueOrLambdaField::Value(field) => ValueOrLambdaParameter::Value(Arc::clone(field)), - ValueOrLambdaField::Lambda(_field) => ValueOrLambdaParameter::Lambda, + ValueOrLambda::Value(field) => ValueOrLambda::Value(Arc::clone(field)), + ValueOrLambda::Lambda(_field) => ValueOrLambda::Lambda(()), }) .collect::>(); @@ -289,20 +290,19 @@ impl PhysicalExpr for LambdaFunctionExpr { ); } - let captures = lambda.captures(); + let indices = lambda.captured_columns(); + let variables = lambda.captured_variables(); - let params = std::iter::zip(lambda.params(), lambda_params) - .map(|(name, param)| Arc::new(param.with_name(name))) - .collect(); - - let captures = if !captures.is_empty() { + let captures = if !indices.is_empty() || !variables.is_empty() { let (fields, columns): (Vec<_>, _) = std::iter::zip( 
batch.schema_ref().fields(), batch.columns(), ) .enumerate() .map(|(column_index, (field, column))| { - if captures.contains(&column_index) { + if indices.contains(&column_index) + || variables.contains(field.name()) + { (Arc::clone(field), Arc::clone(column)) } else { ( @@ -324,7 +324,11 @@ impl PhysicalExpr for LambdaFunctionExpr { None }; - Ok(ValueOrLambda::Lambda(LambdaFunctionLambdaArg { + let params = std::iter::zip(lambda.params(), lambda_params) + .map(|(name, param)| Arc::new(param.with_name(name))) + .collect(); + + Ok(ValueOrLambda::Lambda(LambdaArgument { params, body: Arc::clone(lambda.body()), captures, @@ -449,7 +453,7 @@ mod tests { use crate::LambdaFunctionExpr; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; - use datafusion_expr::{LambdaFunctionArgs, LambdaUDF, LambdaSignature}; + use datafusion_expr::{LambdaFunctionArgs, LambdaSignature, LambdaUDF}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr_common::physical_expr::is_volatile; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -473,6 +477,13 @@ mod tests { &self.signature } + fn lambdas_parameters( + &self, + _args: &[ValueOrLambda], + ) -> Result>>> { + unimplemented!() + } + fn return_field_from_args( &self, _args: LambdaReturnFieldArgs, diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index f7be4aedf555e..e337d938b4a02 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,23 +17,23 @@ use std::sync::Arc; -use crate::expressions::{lambda_variable, LambdaExpr}; -use crate::{LambdaFunctionExpr, ScalarFunctionExpr}; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, PhysicalExpr, }; +use crate::{LambdaFunctionExpr, ScalarFunctionExpr}; use arrow::datatypes::Schema; use datafusion_common::config::ConfigOptions; use datafusion_common::metadata::FieldMetadata; use 
datafusion_common::{ - exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, + exec_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema, }; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{ - Alias, Cast, InList, Lambda, LambdaFunction, LambdaVariable, Placeholder, ScalarFunction + Alias, Cast, InList, Lambda, LambdaFunction, LambdaVariable, Placeholder, + ScalarFunction, }; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; @@ -321,7 +321,6 @@ pub fn create_physical_expr( Expr::ScalarFunction(ScalarFunction { func, args }) => { let physical_args = create_physical_exprs(args, input_dfschema, execution_props)?; - let config_options = match execution_props.config_options.as_ref() { Some(config_options) => Arc::clone(config_options), None => Arc::new(ConfigOptions::default()), @@ -388,7 +387,7 @@ pub fn create_physical_expr( Expr::Placeholder(Placeholder { id, .. 
}) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } - Expr::LambdaFunction(LambdaFunction { func, args}) => { + Expr::LambdaFunction(LambdaFunction { func, args }) => { let physical_args = create_physical_exprs(args, input_dfschema, execution_props)?; @@ -404,17 +403,19 @@ pub fn create_physical_expr( config_options, )?)) } - Expr::Lambda(Lambda { params, body }) => Ok(Arc::new(LambdaExpr::new( + Expr::Lambda(Lambda { params, body }) => expressions::lambda( params.clone(), create_physical_expr(body, input_dfschema, execution_props)?, - ))), + ), Expr::LambdaVariable(LambdaVariable { name, field, spans: _, - }) => lambda_variable( + }) => expressions::lambda_variable( name, - Arc::clone(field), + Arc::clone(field.as_ref().ok_or_else(|| { + plan_datafusion_err!("unresolved LambdaVariable {name}") + })?), ), other => { not_impl_err!("Physical plan does not support logical expression {other:?}") diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 6ad22671ba847..743d5b99cde95 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -105,7 +105,6 @@ impl ScalarFunctionExpr { .iter() .map(|f| f.data_type().clone()) .collect::>(); - data_types_with_scalar_udf(&arg_types, &fun)?; let arguments = args @@ -116,14 +115,11 @@ impl ScalarFunctionExpr { .map(|literal| literal.value()) }) .collect::>(); - let ret_args = ReturnFieldArgs { arg_fields: &arg_fields, scalar_arguments: &arguments, }; - let return_field = fun.return_field_from_args(ret_args)?; - Ok(Self { fun, name, @@ -287,7 +283,6 @@ impl PhysicalExpr for ScalarFunctionExpr { config_options: Arc::clone(&self.config_options), })?; - if let ColumnarValue::Array(array) = &output { if array.len() != batch.num_rows() { // If the arguments are a non-empty slice of scalar values, we can assume that @@ -372,19 +367,12 @@ impl PhysicalExpr for ScalarFunctionExpr { 
#[cfg(test)] mod tests { - use std::any::Any; - use std::sync::Arc; - use super::*; use crate::expressions::Column; - use crate::ScalarFunctionExpr; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::Result; - use datafusion_expr::{ScalarFunctionArgs, Volatility}; use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature}; - use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr_common::physical_expr::is_volatile; - use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use std::any::Any; /// Test helper to create a mock UDF with a specific volatility #[derive(Debug, PartialEq, Eq, Hash)] diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index e9060c0f2c986..364f0c4f28115 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -27,8 +27,8 @@ use crate::protobuf; use datafusion_common::{not_impl_err, plan_datafusion_err, Result}; use datafusion_execution::TaskContext; use datafusion_expr::{ - create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LambdaUDF, LambdaSignature, LogicalPlan, - Volatility, WindowUDF, + create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LambdaSignature, LambdaUDF, + LogicalPlan, Volatility, WindowUDF, }; use prost::{ bytes::{Bytes, BytesMut}, @@ -193,6 +193,7 @@ impl Serializeable for Expr { } fn udlf(&self, name: &str) -> Result> { + // if a SimpleLambdaFunction get's added, use it instead of MockLambdaUDF #[derive(Debug, PartialEq, Eq, Hash)] struct MockLambdaUDF { name: String, @@ -212,6 +213,13 @@ impl Serializeable for Expr { &self.signature } + fn lambdas_parameters( + &self, + _args: &[datafusion_expr::ValueOrLambda], + ) -> Result>>> { + not_impl_err!("mock LambdaUDF") + } + fn return_field_from_args( &self, _args: datafusion_expr::LambdaReturnFieldArgs, diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs index 98f4928457679..e9ff7f7c36ac3 100644 
--- a/datafusion/proto/src/bytes/registry.rs +++ b/datafusion/proto/src/bytes/registry.rs @@ -36,6 +36,10 @@ impl FunctionRegistry for NoRegistry { plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Function '{name}'") } + fn udlf(&self, name: &str) -> Result> { + plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Lambda Function '{name}'") + } + fn udaf(&self, name: &str) -> Result> { plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Aggregate Function '{name}'") } @@ -71,10 +75,4 @@ impl FunctionRegistry for NoRegistry { fn udlfs(&self) -> HashSet { HashSet::new() } - - fn udlf(&self, name: &str) -> Result> { - plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Lambda Function '{name}'") - } - - } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index e080411b49e95..acc085cbc285c 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -622,18 +622,7 @@ pub fn serialize_expr( .unwrap_or(HashMap::new()), })), }, - Expr::LambdaFunction(func) => { - let mut buf = Vec::new(); - let _ = codec.try_encode_udlf(func.func.as_ref(), &mut buf); - protobuf::LogicalExprNode { - expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { - fun_name: func.name().to_string(), - fun_definition: (!buf.is_empty()).then_some(buf), - args: serialize_exprs(&func.args, codec)?, - })), - } - } - Expr::Lambda(_) | Expr::LambdaVariable(_) => { + Expr::LambdaFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => { return Err(Error::General( "Proto serialization error: Lambda not implemented".to_string(), )) diff --git a/datafusion/spark/src/function/utils.rs b/datafusion/spark/src/function/utils.rs index b939dabda388d..e272d91d8a70e 100644 --- a/datafusion/spark/src/function/utils.rs +++ 
b/datafusion/spark/src/function/utils.rs @@ -60,7 +60,7 @@ pub mod test { let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { arg_fields: &arg_fields, - scalar_arguments: &scalar_arguments_refs, + scalar_arguments: &scalar_arguments_refs }); match expected { diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 3d2ff0528081c..3a306d9bb1a52 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -162,6 +162,10 @@ impl ContextProvider for MyContextProvider { Vec::new() } + fn udlf_names(&self) -> Vec { + Vec::new() + } + fn udaf_names(&self) -> Vec { Vec::new() } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 47f132d065980..41d063b6d05a4 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -22,15 +22,15 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow::datatypes::DataType; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, - DFSchema, Dependency, Diagnostic, Result, Span, + DFSchema, Dependency, Diagnostic, HashSet, Result, Span, }; use datafusion_expr::expr::{Lambda, LambdaFunction, ScalarFunction, Unnest}; use datafusion_expr::expr::{NullTreatment, WildcardOptions, WindowFunction}; use datafusion_expr::planner::PlannerResult; use datafusion_expr::planner::{RawAggregateExpr, RawWindowExpr}; +use datafusion_expr::type_coercion::functions::value_fields_with_lambda_udf; use datafusion_expr::{ - expr, Expr, ExprSchemable, ValueOrLambdaParameter, WindowFrame, - WindowFunctionDefinition, + expr, Expr, ExprSchemable, ValueOrLambda, WindowFrame, WindowFunctionDefinition, }; use sqlparser::ast::{ DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, @@ -279,9 +279,12 @@ impl SqlToRel<'_, S> { } if let Some(fm) = self.context_provider.get_lambda_meta(&name) { + // plan non-lambda arguments first so 
we can get theirs datatype and call + // LambdaUDF::lambdas_parameters to then plan the lambda arguments with + // resolved lambda variables enum ExprOrLambda { - ExprWithName((Expr, Option)), - Lambda(sqlparser::ast::LambdaFunction), + Expr((Expr, Option)), + Lambda((sqlparser::ast::LambdaFunction, Option)), } let pairs = args @@ -289,8 +292,39 @@ impl SqlToRel<'_, S> { .map(|a| match a { FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Lambda( lambda, - ))) => Ok(ExprOrLambda::Lambda(lambda)), - _ => Ok(ExprOrLambda::ExprWithName( + ))) => { + if !all_unique(&lambda.params) { + return plan_err!( + "lambda parameters names must be unique, got {}", + lambda.params + ); + } + + Ok(ExprOrLambda::Lambda((lambda, None))) + } + FunctionArg::Named { + name, + arg: FunctionArgExpr::Expr(SQLExpr::Lambda(lambda)), + operator: _, + } + | FunctionArg::ExprNamed { + name: SQLExpr::Identifier(name), + arg: FunctionArgExpr::Expr(SQLExpr::Lambda(lambda)), + operator: _, + } => { + if !all_unique(&lambda.params) { + return plan_err!( + "lambda parameters names must be unique, got {}", + lambda.params + ); + } + + Ok(ExprOrLambda::Lambda(( + lambda, + Some(crate::utils::normalize_ident(name)), + ))) + } + _ => Ok(ExprOrLambda::Expr( self.sql_fn_arg_to_logical_expr_with_name( a, schema, @@ -300,28 +334,28 @@ impl SqlToRel<'_, S> { }) .collect::>>()?; - let metadata = pairs + let current_fields = pairs .iter() .map(|e| match e { - ExprOrLambda::ExprWithName((expr, _name)) => { - Ok(ValueOrLambdaParameter::Value(expr.to_field(schema)?.1)) + ExprOrLambda::Expr((expr, _name)) => { + Ok(ValueOrLambda::Value(expr.to_field(schema)?.1)) } ExprOrLambda::Lambda(_lambda_function) => { - Ok(ValueOrLambdaParameter::Lambda) + Ok(ValueOrLambda::Lambda(())) } }) .collect::>>()?; - let lambdas_parameters = fm.lambdas_parameters(&metadata)?; + let coerced = value_fields_with_lambda_udf(¤t_fields, fm.as_ref())?; + + let lambdas_parameters = fm.lambdas_parameters(&coerced)?; let pairs = pairs 
.into_iter() .zip(lambdas_parameters) .map(|(e, lambda_parameters)| match (e, lambda_parameters) { - (ExprOrLambda::ExprWithName(expr_with_name), None) => { - Ok(expr_with_name) - } - (ExprOrLambda::Lambda(lambda), Some(lambda_params)) => { + (ExprOrLambda::Expr(expr), None) => Ok(expr), + (ExprOrLambda::Lambda((lambda, name)), Some(lambda_params)) => { if lambda.params.len() > lambda_params.len() { return plan_err!( "lambda defined {} params but UDF support only {}", @@ -351,10 +385,10 @@ impl SqlToRel<'_, S> { &mut planner_context, )?), }), - None, + name, )) } - (ExprOrLambda::ExprWithName(_), Some(_)) => plan_err!( + (ExprOrLambda::Expr(_), Some(_)) => plan_err!( "{} reported parameters for an argument that is not a lambda", fm.name() ), @@ -963,3 +997,15 @@ impl SqlToRel<'_, S> { } } } + +fn all_unique(params: &[sqlparser::ast::Ident]) -> bool { + match params.len() { + 0 | 1 => true, + 2 => params[0].value != params[1].value, + _ => { + let mut set = HashSet::with_capacity(params.len()); + + params.iter().all(|p| set.insert(p.value.as_str())) + } + } +} diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 14433e9cf7eba..7b7c9c195be49 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -54,11 +54,12 @@ impl SqlToRel<'_, S> { // identifier. (e.g. 
it is "foo.bar" not foo.bar) let normalize_ident = self.ident_normalizer.normalize(id); - if let Some(field) = planner_context - .lambdas_parameters() - .get(&normalize_ident) + // lambdas parameters have higher precedence + if let Some(field) = + planner_context.lambdas_parameters().get(&normalize_ident) { - let mut lambda_var = LambdaVariable::new(normalize_ident, Arc::clone(field)); + let mut lambda_var = + LambdaVariable::new(normalize_ident, Some(Arc::clone(field))); if self.options.collect_spans { if let Some(span) = Span::try_from_sqlparser_span(id_span) { lambda_var.spans_mut().add_span(span); diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index e51f1c04cf157..151cafa74f0ff 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -1274,6 +1274,10 @@ mod tests { Vec::new() } + fn udlf_names(&self) -> Vec { + Vec::new() + } + fn udaf_names(&self) -> Vec { vec!["sum".to_string()] } diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 834b0a97a47b0..e09082a350771 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -158,6 +158,18 @@ pub trait Dialect: Send + Sync { ) -> Result> { Ok(None) } + + /// Allows the dialect to override lambda function unparsing if the dialect has specific rules. + /// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is + /// a custom implementation for the function. 
+ fn lambda_function_to_sql_overrides( + &self, + _unparser: &Unparser, + _func_name: &str, + _args: &[Expr], + ) -> Result> { + Ok(None) + } /// Allows the dialect to choose to omit window frame in unparsing /// based on function name and window frame bound diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 3bd669dbec071..6775acceb001b 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -533,12 +533,12 @@ impl Unparser<'_> { if let Some(expr) = self .dialect - .scalar_function_to_sql_overrides(self, func_name, args)? + .lambda_function_to_sql_overrides(self, func_name, args)? { return Ok(expr); } - self.scalar_function_to_sql(func_name, args) + self.function_to_sql_internal(func_name, args) } Expr::Lambda(Lambda { params, body }) => { Ok(ast::Expr::Lambda(ast::LambdaFunction { @@ -566,11 +566,11 @@ impl Unparser<'_> { "get_field" => self.get_field_to_sql(args), "map" => self.map_to_sql(args), // TODO: support for the construct and access functions of the `map` type - _ => self.scalar_function_to_sql_internal(func_name, args), + _ => self.function_to_sql_internal(func_name, args), } } - fn scalar_function_to_sql_internal( + fn function_to_sql_internal( &self, func_name: &str, args: &[Expr], diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 6c9ac4bf70046..66ab3acf0eec2 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -276,6 +276,10 @@ impl ContextProvider for MockContextProvider { fn udf_names(&self) -> Vec { self.state.scalar_functions.keys().cloned().collect() } + + fn udlf_names(&self) -> Vec { + self.state.lambda_functions.keys().cloned().collect() + } fn udaf_names(&self) -> Vec { self.state.aggregate_functions.keys().cloned().collect() diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index b499401e5589c..583b9ebabae4e 100644 --- 
a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -23,11 +23,14 @@ use std::path::Path; use std::sync::Arc; use arrow::array::{ - Array, ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, - LargeStringArray, StringArray, TimestampNanosecondArray, UnionArray, + Array, ArrayRef, BinaryArray, FixedSizeListArray, Float64Array, Int32Array, + LargeBinaryArray, LargeStringArray, ListViewArray, StringArray, + TimestampNanosecondArray, UnionArray, }; use arrow::buffer::ScalarBuffer; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields}; +use arrow::datatypes::{ + DataType, Field, Int32Type, Schema, SchemaRef, TimeUnit, UnionFields, +}; use arrow::record_batch::RecordBatch; use datafusion::catalog::{ CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, Session, @@ -143,6 +146,10 @@ impl TestContext { info!("Registering dummy async udf"); register_async_abs_udf(test_ctx.session_ctx()) } + "lambda.slt" => { + info!("Registering table with ListView column"); + register_table_with_list_view(test_ctx.session_ctx()) + } _ => { info!("Using default SessionContext"); } @@ -513,3 +520,26 @@ fn register_async_abs_udf(ctx: &SessionContext) { let udf = AsyncScalarUDF::new(Arc::new(async_abs)); ctx.register_udf(udf.into_scalar_udf()); } + +fn register_table_with_list_view(session_ctx: &SessionContext) { + let data = vec![ + Some(vec![Some(0), Some(1)]), + Some(vec![Some(3), None]), + Some(vec![None, None]), + None, + ]; + let list_array = FixedSizeListArray::from_iter_primitive::(data, 2); + let list_view: ListViewArray = list_array.into(); + + let schema = Schema::new(vec![Field::new( + "list_view", + list_view.data_type().clone(), + true, + )]); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_view)]).unwrap(); + + session_ctx + .register_batch("table_with_list_view", batch) + .unwrap(); +} diff --git a/datafusion/sqllogictest/test_files/lambda.slt 
b/datafusion/sqllogictest/test_files/lambda.slt index af5334a644421..cfd56298baa2a 100644 --- a/datafusion/sqllogictest/test_files/lambda.slt +++ b/datafusion/sqllogictest/test_files/lambda.slt @@ -16,7 +16,7 @@ # under the License. ############# -## Array Expressions Tests +## Lambda Expressions Tests ############# statement ok @@ -31,27 +31,30 @@ AS VALUES statement ok CREATE TABLE t AS SELECT 1 as f, [ [ [2, 3], [2] ], [ [1] ], [ [] ] ] as v, 1 as n; +#test potentialy sliced list +select array_transform(tt.column_1, v -> v*2) from tt order by column_2 limit 1; + query I? SELECT t.n, array_transform([], e1 -> t.n) from t; ---- 1 [] query ? -SELECT array_transform([1], e1 -> (select n from t)); +SELECT array_transform([1, 2], e1 -> (select n from t)); ---- -[1] +[1, 1] query ? SELECT array_transform(t.v, (v1, i) -> array_transform(v1, (v2, j) -> array_transform(v2, v3 -> j)) ) from t; ---- -[[[0, 1], [0]], [[0]], [[]]] +[[[1, 1], [2]], [[1]], [[]]] query I? SELECT t.n, array_transform([1, 2], (e) -> n) from t; ---- 1 [1, 1] -# selection pushdown not working yet +# projection pushdown query ? SELECT array_transform([1, 2], (e) -> n) from t; ---- @@ -60,22 +63,22 @@ SELECT array_transform([1, 2], (e) -> n) from t; query ? SELECT array_transform([1, 2], (e, i) -> i) from t; ---- -[0, 1] +[1, 2] # type coercion query ? 
-SELECT array_transform([1, 2], (e, i) -> e+i) from t; +SELECT array_transform([1.0, 2.0], v -> v + t.n) from t; ---- -[1, 3] +[2.0, 4.0] query TT -EXPLAIN SELECT array_transform([1, 2], (e, i) -> e+i); +EXPLAIN SELECT array_transform([1.0, 2.0], v -> v + t.n) from t; ---- logical_plan -01)Projection: array_transform(List([1, 2]), (e, i) -> e + CAST(i AS Int64)) AS array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i) +01)Projection: List([2.0, 4.0]) AS array_transform(make_array(Float64(1),Float64(2)),(e, i) -> e + i) 02)--EmptyRelation: rows=1 physical_plan -01)ProjectionExec: expr=[array_transform([1, 2], (e, i) -> e@-1 + CAST(i@-1 AS Int64)) as array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i)] +01)ProjectionExec: expr=[[2.0, 4.0] as array_transform(make_array(Float64(1),Float64(2)),(e, i) -> e + i)] 02)--PlaceholderRowExec #cse @@ -118,10 +121,10 @@ query TT EXPLAIN SELECT array_transform([1,2,3,4,5], v -> v*2); ---- logical_plan -01)Projection: array_transform(List([1, 2, 3, 4, 5]), (v) -> v * Int64(2)) AS array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2)) +01)Projection: List([2, 4, 6, 8, 10]) AS array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2)) 02)--EmptyRelation: rows=1 physical_plan -01)ProjectionExec: expr=[array_transform([1, 2, 3, 4, 5], (v) -> v@-1 * 2) as array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2))] +01)ProjectionExec: expr=[[2, 4, 6, 8, 10] as array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2))] 02)--PlaceholderRowExec query ? 
@@ -148,33 +151,62 @@ SELECT array_transform([1,2,3,4,5], v -> v*2); # expr simplifier query TT -EXPLAIN SELECT v = v, array_transform([1], v -> v = v) from t; +EXPLAIN SELECT v = v, array_transform([t.n], v -> v = v) from t; ---- logical_plan -01)Projection: Boolean(true) AS t.v = t.v, array_transform(List([1]), (v) -> v IS NOT NULL OR Boolean(NULL)) AS array_transform(make_array(Int64(1)),(v) -> v = v) +01)Projection: Boolean(true) AS t.v = t.v, List([true]) AS array_transform(make_array(Int64(1)),(v) -> v = v) 02)--TableScan: t projection=[] physical_plan -01)ProjectionExec: expr=[true as t.v = t.v, array_transform([1], (v) -> v@-1 IS NOT NULL OR NULL) as array_transform(make_array(Int64(1)),(v) -> v = v)] +01)ProjectionExec: expr=[true as t.v = t.v, [true] as array_transform(make_array(Int64(1)),(v) -> v = v)] 02)--DataSourceExec: partitions=1, partition_sizes=[1] query error select array_transform(); ---- -DataFusion error: Execution error: array_transform expects a value follewed by a lambda, got [] +DataFusion error: Execution error: array_transform function requires 2 arguments, got 0 -query error DataFusion error: Execution error: expected list, got Field \{ "Int64\(1\)": Int64 \} +query error DataFusion error: Error during planning: array_transform expected a list as first argument, got Int64 select array_transform(1, v -> v*2); -query error DataFusion error: Execution error: array_transform expects a value follewed by a lambda, got \[Lambda, Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ data_type: Int64, nullable: true \}\), nullable: true \}\)\] +query error DataFusion error: Error during planning: array_transform expects a value followed by a lambda, got Lambda\(\(\)\) and Value\(List\(Field \{ data_type: Int64, nullable: true \}\)\) select array_transform(v -> v*2, [1, 2]); query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 2 SELECT array_transform([1, 2], (e, i, 
j) -> i) from t; -#todo: this should error due to duplicate names -query ? +query error DataFusion error: Error during planning: lambda parameters names must be unique, got \(v, v\) SELECT array_transform([1], (v, v) -> v*2); + +query error DataFusion error: This feature is not implemented: Unsupported ast node in sqltorel: Lambda\(LambdaFunction \{ params: One\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,12\)\.\.Location\(1,13\)\) \}\), body: Identifier\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,17\)\.\.Location\(1,18\)\) \}\) \}\) +SELECT abs(v -> v); + +query error DataFusion error: This feature is not implemented: Unsupported ast node in sqltorel: Lambda\(LambdaFunction \{ params: One\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,8\)\.\.Location\(1,9\)\) \}\), body: Identifier\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,13\)\.\.Location\(1,14\)\) \}\) \}\) +SELECT v -> v; + +query error DataFusion error: This feature is not implemented: Unsupported ast node in sqltorel: Lambda\(LambdaFunction \{ params: One\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,34\)\.\.Location\(1,35\)\) \}\), body: BinaryOp \{ left: Identifier\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,39\)\.\.Location\(1,40\)\) \}\), op: Plus, right: Value\(ValueWithSpan \{ value: Number\("1", false\), span: Span\(Location\(1,41\)\.\.Location\(1,42\)\) \}\) \} \}\) +SELECT array_transform([1], v -> v -> v+1); + + +query ? 
+SELECT array_transform([[2, 3]], v -> array_transform(v, j -> v)); +---- +[[[2, 3], [2, 3]]] + + +query TT +explain select array_transform(list_view, a -> a+1) from table_with_list_view; +---- +initial_logical_plan +01)Projection: array_transform(table_with_list_view.list_view, (a) -> a + Int64(1)) +02)--TableScan: table_with_list_view +logical_plan after resolve_grouping_function SAME TEXT AS ABOVE +logical_plan after type_coercion Error during planning: Cannot automatically convert ListView(nullable Int32) to List(nullable Int32) + +query error +select array_transform(list_view, a -> a+1) from table_with_list_view; ---- -[0] +DataFusion error: type_coercion +caused by +Error during planning: Cannot automatically convert ListView(nullable Int32) to List(nullable Int32) From 41152c35e154b1f12966b49e8e35a7add18f78b0 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 15 Mar 2026 05:17:54 -0300 Subject: [PATCH 14/47] minor improvoments --- datafusion/common/src/utils/mod.rs | 12 +- datafusion/core/src/execution/context/mod.rs | 30 ++-- .../core/src/execution/session_state.rs | 89 +++++----- .../src/execution/session_state_defaults.rs | 2 +- datafusion/core/tests/parquet/mod.rs | 2 +- datafusion/execution/src/task.rs | 40 ++--- datafusion/expr/src/expr.rs | 6 +- datafusion/expr/src/expr_fn.rs | 2 +- datafusion/expr/src/expr_schema.rs | 10 +- datafusion/expr/src/logical_plan/plan.rs | 3 + datafusion/expr/src/registry.rs | 37 ++-- datafusion/expr/src/udlf.rs | 165 +++++++----------- .../functions-nested/src/array_transform.rs | 48 +++-- .../simplify_expressions/expr_simplifier.rs | 11 +- .../src/expressions/lambda_variable.rs | 2 +- .../physical-expr/src/lambda_function.rs | 16 +- datafusion/physical-expr/src/planner.rs | 2 +- datafusion/proto/src/bytes/mod.rs | 125 ++++++------- datafusion/proto/src/bytes/registry.rs | 8 +- datafusion/sql/src/expr/function.rs | 99 ++++++----- .../src/logical_plan/producer/expr/mod.rs | 4 +- 21 
files changed, 333 insertions(+), 380 deletions(-) diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 1c2ec91568a3d..1ceaa46c1bce4 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -953,11 +953,8 @@ fn list_array_values_row_number( offsets[offsets.len() - 1].to_usize().unwrap() - offsets[0].to_usize().unwrap(), ); - for (i, (&start, &end)) in std::iter::zip(&offsets[..], &offsets[1..]).enumerate() { - rows_number.extend(repeat_n( - T::Native::usize_as(i), - end.as_usize() - start.as_usize(), - )); + for (i, len) in offsets.lengths().enumerate() { + rows_number.extend(repeat_n(T::Native::usize_as(i), len)); } PrimitiveArray::new(rows_number.into(), None) @@ -971,9 +968,8 @@ fn list_array_values_index( offsets[offsets.len() - 1].to_usize().unwrap() - offsets[0].to_usize().unwrap(), ); - for (&start, &end) in std::iter::zip(&offsets[..], &offsets[1..]) { - indices - .extend((1..1 + end.as_usize() - start.as_usize()).map(T::Native::usize_as)); + for len in offsets.lengths() { + indices.extend((1..1 + len).map(T::Native::usize_as)); } PrimitiveArray::new(indices.into(), None) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 083ecdaf575af..5a713b6f1486e 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1739,6 +1739,10 @@ impl FunctionRegistry for SessionContext { self.state.read().udf(name) } + fn udlf(&self, name: &str) -> Result> { + self.state.read().udlf(name) + } + fn udaf(&self, name: &str) -> Result> { self.state.read().udaf(name) } @@ -1751,6 +1755,13 @@ impl FunctionRegistry for SessionContext { self.state.write().register_udf(udf) } + fn register_udlf( + &mut self, + udlf: Arc, + ) -> Result>> { + self.state.write().register_udlf(udlf) + } + fn register_udaf( &mut self, udaf: Arc, @@ -1780,27 +1791,16 @@ impl FunctionRegistry for SessionContext { 
self.state.write().register_expr_planner(expr_planner) } - fn udafs(&self) -> HashSet { - self.state.read().udafs() - } - - fn udwfs(&self) -> HashSet { - self.state.read().udwfs() - } - fn udlfs(&self) -> HashSet { self.state.read().udlfs() } - fn udlf(&self, name: &str) -> Result> { - self.state.read().udlf(name) + fn udafs(&self) -> HashSet { + self.state.read().udafs() } - fn register_udlf( - &mut self, - udlf: Arc, - ) -> Result>> { - self.state.write().register_udlf(udlf) + fn udwfs(&self) -> HashSet { + self.state.read().udwfs() } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 805b61b4c6a30..f4269e8ad18cf 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -59,7 +59,8 @@ use datafusion_expr::simplify::SimplifyInfo; #[cfg(feature = "sql")] use datafusion_expr::TableSource; use datafusion_expr::{ - AggregateUDF, Explain, Expr, ExprSchemable, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF + AggregateUDF, Explain, Expr, ExprSchemable, LambdaUDF, LogicalPlan, ScalarUDF, + WindowUDF, }; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{ @@ -255,7 +256,7 @@ impl Session for SessionState { fn scalar_functions(&self) -> &HashMap> { &self.scalar_functions } - + fn lambda_functions(&self) -> &HashMap> { &self.lambda_functions } @@ -834,7 +835,7 @@ impl SessionState { pub fn scalar_functions(&self) -> &HashMap> { &self.scalar_functions } - + /// Return reference to lambda_functions pub fn lambda_functions(&self) -> &HashMap> { &self.lambda_functions @@ -1063,7 +1064,7 @@ impl SessionStateBuilder { self.scalar_functions .get_or_insert_with(Vec::new) .extend(SessionStateDefaults::default_scalar_functions()); - + self.lambda_functions .get_or_insert_with(Vec::new) .extend(SessionStateDefaults::default_lambda_functions()); @@ -1237,7 +1238,7 @@ impl SessionStateBuilder { self.scalar_functions = 
Some(scalar_functions); self } - + /// Set the map of [`LambdaUDF`]s pub fn with_lambda_functions( mut self, @@ -1476,7 +1477,7 @@ impl SessionStateBuilder { } } } - + if let Some(lambda_functions) = lambda_functions { for udlf in lambda_functions { let config_options = state.config().options(); @@ -1616,7 +1617,7 @@ impl SessionStateBuilder { pub fn scalar_functions(&mut self) -> &mut Option>> { &mut self.scalar_functions } - + /// Returns the current lambda_functions value pub fn lambda_functions(&mut self) -> &mut Option>> { &mut self.lambda_functions @@ -1858,7 +1859,7 @@ impl ContextProvider for SessionContextProvider<'_> { fn udf_names(&self) -> Vec { self.state.scalar_functions().keys().cloned().collect() } - + fn udlf_names(&self) -> Vec { self.state.lambda_functions().keys().cloned().collect() } @@ -1902,6 +1903,13 @@ impl FunctionRegistry for SessionState { }) } + fn udlf(&self, name: &str) -> datafusion_common::Result> { + self.lambda_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + } + fn udaf(&self, name: &str) -> datafusion_common::Result> { let result = self.aggregate_functions.get(name); @@ -1929,6 +1937,17 @@ impl FunctionRegistry for SessionState { Ok(self.scalar_functions.insert(udf.name().into(), udf)) } + fn register_udlf( + &mut self, + udlf: Arc, + ) -> datafusion_common::Result>> { + udlf.aliases().iter().for_each(|alias| { + self.lambda_functions + .insert(alias.clone(), Arc::clone(&udlf)); + }); + Ok(self.lambda_functions.insert(udlf.name().into(), udlf)) + } + fn register_udaf( + &mut self, + udaf: Arc, @@ 
-1990,41 +2022,6 @@ impl FunctionRegistry for SessionState { Ok(udwf) } - fn udlfs(&self) -> HashSet { - self.lambda_functions.keys().cloned().collect() - } - - fn udlf(&self, name: &str) -> datafusion_common::Result> { - self.lambda_functions - .get(name) - .cloned() - .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) - } - - fn register_udlf( - &mut self, - udlf: Arc, - ) -> datafusion_common::Result>> { - udlf.aliases().iter().for_each(|alias| { - self.lambda_functions - .insert(alias.clone(), Arc::clone(&udlf)); - }); - Ok(self.lambda_functions.insert(udlf.name().into(), udlf)) - } - - fn deregister_udlf( - &mut self, - name: &str, - ) -> datafusion_common::Result>> { - let udlf = self.lambda_functions.remove(name); - if let Some(udlf) = &udlf { - for alias in udlf.aliases() { - self.lambda_functions.remove(alias); - } - } - Ok(udlf) - } - fn register_function_rewrite( &mut self, rewrite: Arc, @@ -2045,6 +2042,10 @@ impl FunctionRegistry for SessionState { Ok(()) } + fn udlfs(&self) -> HashSet { + self.lambda_functions.keys().cloned().collect() + } + fn udafs(&self) -> HashSet { self.aggregate_functions.keys().cloned().collect() } @@ -2470,7 +2471,7 @@ mod tests { fn udf_names(&self) -> Vec { self.state.scalar_functions().keys().cloned().collect() } - + fn udlf_names(&self) -> Vec { self.state.lambda_functions().keys().cloned().collect() } diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index 54037c0a96f9c..b37241debfb0f 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -115,7 +115,7 @@ impl SessionStateDefaults { /// returns the list of default [`LambdaUDF`]s pub fn default_lambda_functions() -> Vec> { - vec![Arc::new(ArrayTransform::new())] + function_nested::all_default_lambda_functions() } /// returns the list of default [`AggregateUDF`]s diff --git 
a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 27b8c18596476..097600e45eadd 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -516,7 +516,7 @@ fn make_uint_batches(start: u8, end: u8) -> RecordBatch { Field::new("u64", DataType::UInt64, true), ])); let v8: Vec = (start..end).collect(); - let v16: Vec = (start as _..end as u16).collect(); + let v16: Vec = (start as _..end as _).collect(); let v32: Vec = (start as _..end as _).collect(); let v64: Vec = (start as _..end as _).collect(); RecordBatch::try_new( diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index 70c59b6375943..e8f66e1682a52 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -162,6 +162,14 @@ impl FunctionRegistry for TaskContext { }) } + fn udlf(&self, name: &str) -> Result> { + let result = self.lambda_functions.get(name); + + result.cloned().ok_or_else(|| { + plan_datafusion_err!("There is no UDLF named \"{name}\" in the TaskContext") + }) + } + fn udaf(&self, name: &str) -> Result> { let result = self.aggregate_functions.get(name); @@ -204,41 +212,25 @@ impl FunctionRegistry for TaskContext { Ok(self.scalar_functions.insert(udf.name().into(), udf)) } - fn udlfs(&self) -> HashSet { - self.lambda_functions.keys().cloned().collect() - } - - fn udlf(&self, name: &str) -> Result> { - self.lambda_functions - .get(name) - .cloned() - .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) - } - fn register_udlf( &mut self, udlf: Arc, ) -> Result>> { + udlf.aliases().iter().for_each(|alias| { + self.lambda_functions + .insert(alias.clone(), Arc::clone(&udlf)); + }); Ok(self.lambda_functions.insert(udlf.name().into(), udlf)) } - fn deregister_udlf( - &mut self, - name: &str, - ) -> Result>> { - let udlf = self.lambda_functions.remove(name); - if let Some(udlf) = &udlf { - for alias in udlf.aliases() { - self.lambda_functions.remove(alias); - 
} - } - Ok(udlf) - } - fn expr_planners(&self) -> Vec> { vec![] } + fn udlfs(&self) -> HashSet { + self.lambda_functions.keys().cloned().collect() + } + fn udafs(&self) -> HashSet { self.aggregate_functions.keys().cloned().collect() } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index f73e6080d9dd0..fe4bd4cace057 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -413,7 +413,9 @@ pub enum Expr { /// Invoke a [`LambdaUDF`] with a set of arguments #[derive(Clone, Eq, PartialOrd, Debug)] pub struct LambdaFunction { + /// The function pub func: Arc, + /// List of expressions to feed to the functions as arguments pub args: Vec, } @@ -1319,7 +1321,9 @@ impl GroupingSet { /// A Lambda expression with a set of parameters names and a body #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Lambda { + /// The parameters names pub params: Vec, + /// The body expression pub body: Box, } @@ -3810,7 +3814,7 @@ impl Display for Expr { write!(f, "({}) -> {body}", params.join(", ")) } Expr::LambdaVariable(c) => { - write!(f, "{}", c.name) + f.write_str(&c.name) } } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index ec8331ecbed73..ab38cdd885a35 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -733,7 +733,7 @@ pub fn lambda(params: impl IntoIterator>, body: Expr) - /// Create an unresolved lambda variable expression /// -/// The expression tree or LogicalPlan which +/// The expression tree or [`LogicalPlan`] which /// owns this variable must be resolved before usage with either /// [`Expr::resolve_lambdas_variables`] or [`LogicalPlan::resolve_lambdas_variables`]. 
/// diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 6c496cf7e158e..e5286573332bb 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -369,12 +369,9 @@ impl ExprSchemable for Expr { Ok(nullable) } Expr::Lambda(l) => l.body.nullable(input_schema), - Expr::LambdaVariable(l) => Ok(l - .field + Expr::LambdaVariable(LambdaVariable { name, field, .. }) => Ok(field .as_ref() - .ok_or_else(|| { - plan_datafusion_err!("unresolved LambdaVariable {}", l.name) - })? + .ok_or_else(|| plan_datafusion_err!("unresolved LambdaVariable {name}"))? .is_nullable()), } } @@ -648,7 +645,8 @@ impl ExprSchemable for Expr { }) .collect::>>()?; - let new_fields = value_fields_with_lambda_udf(&arg_fields, func.func.as_ref())?; + let new_fields = + value_fields_with_lambda_udf(&arg_fields, func.func.as_ref())?; let arguments = func .args diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 1fa9e3aa6ab84..8ae933c0c42bc 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2053,6 +2053,9 @@ impl LogicalPlan { Wrapper(self) } + /// Return a `LogicalPLan` with all [`LambdaVariable`] resolved + /// + /// [`LambdaVariable`]: crate::expr::LambdaVariable pub fn resolve_lambdas_variables(self) -> Result> { self.transform_with_subqueries(|plan| { let schema = merge_schema(&plan.inputs()); diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 92aa39d64c98d..069bb7b352994 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -44,7 +44,7 @@ pub trait FunctionRegistry { /// `name`. fn udf(&self, name: &str) -> Result>; - /// Returns a reference to the user defined lambda function (udf) named + /// Returns a reference to the user defined lambda function (udlf) named /// `name`. 
fn udlf(&self, name: &str) -> Result>; @@ -206,6 +206,13 @@ impl FunctionRegistry for MemoryFunctionRegistry { .ok_or_else(|| plan_datafusion_err!("Function {name} not found")) } + fn udlf(&self, name: &str) -> Result> { + self.udlfs + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + } + fn udaf(&self, name: &str) -> Result> { self.udafs .get(name) @@ -223,6 +230,12 @@ impl FunctionRegistry for MemoryFunctionRegistry { fn register_udf(&mut self, udf: Arc) -> Result>> { Ok(self.udfs.insert(udf.name().to_string(), udf)) } + fn register_udlf( + &mut self, + udlf: Arc, + ) -> Result>> { + Ok(self.udlfs.insert(udlf.name().into(), udlf)) + } fn register_udaf( &mut self, udaf: Arc, @@ -237,29 +250,15 @@ impl FunctionRegistry for MemoryFunctionRegistry { vec![] } - fn udafs(&self) -> HashSet { - self.udafs.keys().cloned().collect() - } - - fn udwfs(&self) -> HashSet { - self.udwfs.keys().cloned().collect() - } - fn udlfs(&self) -> HashSet { self.udlfs.keys().cloned().collect() } - fn udlf(&self, name: &str) -> Result> { - self.udlfs - .get(name) - .cloned() - .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + fn udafs(&self) -> HashSet { + self.udafs.keys().cloned().collect() } - fn register_udlf( - &mut self, - udlf: Arc, - ) -> Result>> { - Ok(self.udlfs.insert(udlf.name().into(), udlf)) + fn udwfs(&self) -> HashSet { + self.udwfs.keys().cloned().collect() } } diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index b9d1f193db3dc..ce52ce2034e51 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -57,8 +57,6 @@ pub enum LambdaTypeSignature { /// /// If this signature is specified, /// DataFusion will call [`LambdaUDF::coerce_value_types`] to prepare argument types. 
- /// - /// [`LambdaUDF::coerce_value_types`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.LambdaUDF.html#method.coerce_value_types UserDefined, /// One or more lambdas or arguments with arbitrary types VariadicAny, @@ -80,7 +78,7 @@ pub struct LambdaSignature { pub volatility: Volatility, /// Optional parameter names for the function arguments. /// - /// If provided, enables named argument notation for function calls (e.g., `func(a => 1, b => 2)`). + /// If provided, enables named argument notation for function calls (e.g., `func(a => 1, b => v -> v+1)`). /// /// Defaults to `None`, meaning only positional arguments are supported. pub parameter_names: Option>, @@ -170,9 +168,12 @@ impl Hash for dyn LambdaUDF { /// lambda function. #[derive(Debug, Clone)] pub struct LambdaFunctionArgs { - /// The evaluated arguments to the function + /// The evaluated arguments and lambdas to the function pub args: Vec>, /// Field associated with each arg, if it exists + /// For lambdas, it will be the field of the result of + /// the lambda if evaluated with the parameters + /// returned from [`LambdaUDF::lambdas_parameters`] pub arg_fields: Vec>, /// The number of rows in record batch being evaluated pub number_rows: usize, @@ -199,22 +200,41 @@ pub struct LambdaArgument { /// /// For example, for `array_transform([2], v -> -v)`, /// this will be `vec![Field::new("v", DataType::Int32, true)]` - pub params: Vec, + params: Vec, /// The body of the lambda /// /// For example, for `array_transform([2], v -> -v)`, /// this will be the physical expression of `-v` - pub body: Arc, + body: Arc, /// A RecordBatch containing at least the captured columns inside this lambda body, if any /// Note that it may contain additional, non-specified columns, but that's a implementation detail /// /// For example, for `array_transform([2], v -> v + a + b)`, /// this will be a `RecordBatch` with at least two columns, `a` and `b` - pub captures: Option, + captures: Option, } impl 
LambdaArgument { - /// For adjusting multiple arrays by indices, use [`take_arrays`] + pub fn new( + params: Vec, + body: Arc, + captures: Option, + ) -> Self { + Self { + params, + body, + captures, + } + } + + /// Evaluate this lambda + /// `args` should evaluate to the value of each parameter + /// of the correspondent lambda returned in [LambdaUDF::lambdas_parameters]. + /// + /// `adjust` should adjust the captured columns of this + /// lambda, if any, relative to its parameters + /// + /// Tip: For adjusting multiple arrays by indices, use [`take_arrays`] /// /// [`take_arrays`]: arrow::compute::take_arrays pub fn evaluate( @@ -362,66 +382,6 @@ pub enum ValueOrLambda { /// See [`array_transform.rs`] for a commented complete implementation /// /// [`array_transform.rs`]: https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/array_transform.rs -/// -/// # Basic Example -/// ``` -/// # use std::any::Any; -/// # use std::sync::LazyLock; -/// # use arrow::datatypes::DataType; -/// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, ColumnarValue, Documentation, LambdaFunctionArgs, LambdaSignature, Volatility}; -/// # use datafusion_expr::LambdaUDF; -/// # use datafusion_expr::lambda_doc_sections::DOC_SECTION_MATH; -/// /// This struct for a simple UDF that adds one to an int32 -/// #[derive(Debug, PartialEq, Eq, Hash)] -/// struct AddOne { -/// signature: LambdaSignature, -/// } -/// -/// impl AddOne { -/// fn new() -> Self { -/// Self { -/// signature: LambdaSignature::new(Volatility::Immutable), -/// } -/// } -/// } -/// -/// static DOCUMENTATION: LazyLock = LazyLock::new(|| { -/// Documentation::builder(DOC_SECTION_MATH, "Add one to an int32", "add_one(2)") -/// .with_argument("arg1", "The int32 number to add one to") -/// .build() -/// }); -/// -/// fn get_doc() -> &'static Documentation { -/// &DOCUMENTATION -/// } -/// -/// /// Implement the LambdaUDF trait for AddOne -/// impl LambdaUDF 
for AddOne { -/// fn as_any(&self) -> &dyn Any { self } -/// fn name(&self) -> &str { "add_one" } -/// fn signature(&self) -> &LambdaSignature { &self.signature } -/// fn return_type(&self, args: &[DataType]) -> Result { -/// if !matches!(args.get(0), Some(&DataType::Int32)) { -/// return plan_err!("add_one only accepts Int32 arguments"); -/// } -/// Ok(DataType::Int32) -/// } -/// // The actual implementation would add one to the argument -/// fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result { -/// unimplemented!() -/// } -/// fn documentation(&self) -> Option<&Documentation> { -/// Some(get_doc()) -/// } -/// } -/// -/// // Create a new LambdaUDF from the implementation -/// let add_one = LambdaUDF::from(AddOne::new()); -/// -/// // Call the function `add_one(col)` -/// let expr = add_one.call(vec![col("a")]); -/// ``` pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; @@ -488,29 +448,49 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { None } - /// Returns the parameters that any lambda supports + /// Returns a list of the same size as args where each value is the logic below applied to value at the correspondent position in args: + /// + /// If it's a value, return None + /// If it's a lambda, return the list of all parameters that that lambda supports + /// + /// Example for array_transform: + /// + /// `array_transform([2.0, 8.0], v -> v > 4.0)` + /// + /// ```ignore + /// let lambdas_parameters = array_transform.lambdas_parameters(&[ + /// ValueOrLambdaParameter::Value(Field::new("", DataType::new_list(DataType::Float32, false)))]), // the Field of the literal `[2, 8]` + /// ValueOrLambdaParameter::Lambda, // A lambda + /// ]?; + /// + /// assert_eq!( + /// lambdas_parameters, + /// vec![ + /// // it's a value, return None + /// None, + /// // it's a lambda, return it's supported parameters, regardless of how many are actually used + /// 
Some(vec![ + /// // the value being transformed + /// Field::new("", DataType::Float32, false), + /// // the 1-based index being transformed, not used on the example above, + /// //but implementations doesn't need to care about it + /// Field::new("", DataType::Int32, false), + /// ]) + /// ] + /// ) + /// ``` + /// + /// The implementation can assume that some other part of the code has coerced + /// the actual argument types to match [`Self::signature`]. fn lambdas_parameters( &self, args: &[ValueOrLambda], ) -> Result>>>; /// What type will be returned by this function, given the arguments? - /// - /// By default, this function calls [`Self::return_type`] with the - /// types of each argument. - /// - /// # Notes - /// - /// For the majority of UDFs, implementing [`Self::return_type`] is sufficient, - /// as the result type is typically a deterministic function of the input types - /// (e.g., `sqrt(f32)` consistently yields `f32`). Implementing this method directly - /// is generally unnecessary unless the return type depends on runtime values. - /// - /// This function can be used for more advanced cases such as: - /// - /// 1. specifying nullability - /// 2. return types based on the **values** of the arguments (rather than - /// their **types**. + /// + /// The implementation can assume that some other part of the code has coerced + /// the actual argument types to match [`Self::signature`]. 
/// /// # Example creating `Field` /// @@ -532,21 +512,6 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// } /// # } /// ``` - /// - /// # Output Type based on Values - /// - /// For example, the following two function calls get the same argument - /// types (something and a `Utf8` string) but return different types based - /// on the value of the second argument: - /// - /// * `arrow_cast(x, 'Int16')` --> `Int16` - /// * `arrow_cast(x, 'Float32')` --> `Float32` - /// - /// # Requirements - /// - /// This function **must** consistently return the same type for the same - /// logical input even if the input is simplified (e.g. it must return the same - /// value for `('foo' | 'bar')` as it does for ('foobar'). fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result; /// Invoke the function returning the appropriate result. diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 843029114fa5c..21551ddcf9881 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. -//! [`ScalarUDFImpl`] definitions for array_transform function. +//! [`LambdaUDF`] definitions for array_transform function. 
use arrow::{ - array::{ - Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray, - }, compute::take_arrays, datatypes::{DataType, Field, FieldRef} + array::{Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray}, + compute::take_arrays, + datatypes::{DataType, Field, FieldRef}, }; use datafusion_common::{ exec_err, plan_err, @@ -31,8 +31,8 @@ use datafusion_common::{ Result, }; use datafusion_expr::{ - ColumnarValue, Documentation, LambdaFunctionArgs, LambdaSignature, LambdaUDF, - ValueOrLambda, Volatility, + ColumnarValue, Documentation, LambdaFunctionArgs, LambdaReturnFieldArgs, + LambdaSignature, LambdaUDF, ValueOrLambda, Volatility, }; use datafusion_macros::user_doc; use std::{any::Any, fmt::Debug, sync::Arc}; @@ -148,10 +148,7 @@ impl LambdaUDF for ArrayTransform { Ok(vec![None, Some(vec![value, index])]) } - fn return_field_from_args( - &self, - args: datafusion_expr::LambdaReturnFieldArgs, - ) -> Result> { + fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result> { let (list, lambda) = value_lambda_pair(self.name(), args.arg_fields)?; //TODO: should metadata be copied into the transformed array? 
@@ -178,7 +175,8 @@ impl LambdaUDF for ArrayTransform { let (list, lambda) = value_lambda_pair(self.name(), &args.args)?; let list_array = list.to_array(args.number_rows)?; - // as per list_values docs, if list_array is sliced, list_values will be sliced too, + + // as per list_values docs, if list_array is sliced, list_values will be sliced too, // so before constructing the transformed array below, we must adjust the list offsets with // adjust_offsets_for_slice let list_values = list_values(&list_array)?; @@ -188,25 +186,19 @@ impl LambdaUDF for ArrayTransform { // list_values_row_number is not cheap so is important to avoid it when no column is captured let mut adjust_indices = None; - // use closures so that lambda.evaluate calls only the needed ones - // based on the number of arguments avoiding unnecessary computations + // by passing closures, lambda.evaluate can evaluate only those actually needed let values_param = || Ok(Arc::clone(&list_values)); let indices_param = || list_values_index(&list_array); - // call the transforming lambda with the record batch composed of the list values merged with captured columns + // call the transforming lambda let transformed_values = lambda - .evaluate( - &[&values_param, &indices_param], - |arrays| { - let indices = match &adjust_indices { - Some(v) => v, - None => { - adjust_indices.insert(list_values_row_number(&list_array)?) - } - }; - Ok(take_arrays(arrays, indices, None)?) - }, - )? + .evaluate(&[&values_param, &indices_param], |arrays| { + let indices = match &adjust_indices { + Some(v) => v, + None => adjust_indices.insert(list_values_row_number(&list_array)?), + }; + Ok(take_arrays(arrays, indices, None)?) + })? 
.into_array(list_values.len())?; let field = match args.return_field.data_type() { @@ -225,7 +217,7 @@ impl LambdaUDF for ArrayTransform { let transformed_list = match list_array.data_type() { DataType::List(_) => { let list = list_array.as_list(); - // since we called list_values above which would return sliced values for + // since we called list_values above which would return sliced values for // a sliced list, we must adjust the offsets here as otherwise they would be invalid let adjusted_offsets = adjust_offsets_for_slice(list); @@ -238,7 +230,7 @@ impl LambdaUDF for ArrayTransform { } DataType::LargeList(_) => { let large_list = list_array.as_list(); - // since we called list_values above which would return sliced values for + // since we called list_values above which would return sliced values for // a sliced list, we must adjust the offsets here as otherwise they would be invalid let adjusted_offsets = adjust_offsets_for_slice(large_list); diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 0abf80a622ea5..cc16fdbdc90f8 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -17,15 +17,15 @@ //! 
Expression simplification API -use std::collections::HashSet; -use std::ops::Not; -use std::{borrow::Cow, sync::Arc}; - use arrow::{ array::{new_null_array, AsArray}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; +use std::borrow::Cow; +use std::collections::HashSet; +use std::ops::Not; +use std::sync::Arc; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, @@ -33,8 +33,7 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{ - exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, - ScalarValue, + exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::LambdaFunction; use datafusion_expr::{ diff --git a/datafusion/physical-expr/src/expressions/lambda_variable.rs b/datafusion/physical-expr/src/expressions/lambda_variable.rs index 820a88f4ce0f9..e5fab5fdb2366 100644 --- a/datafusion/physical-expr/src/expressions/lambda_variable.rs +++ b/datafusion/physical-expr/src/expressions/lambda_variable.rs @@ -75,7 +75,7 @@ impl LambdaVariable { impl std::fmt::Display for LambdaVariable { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}@-1", self.name) + write!(f, "{}@", self.name) } } diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs index c9301c84e7763..0bb4c66c79b69 100644 --- a/datafusion/physical-expr/src/lambda_function.rs +++ b/datafusion/physical-expr/src/lambda_function.rs @@ -24,7 +24,7 @@ //! * the computation, that must accept each valid signature //! //! * Signature: see `Signature` -//! * Return type: a function `(arg_types) -> return_type`. E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64. +//! * Return type: a function `(arg_types) -> return_type`. E.g. for array_transform, ([[f32]], v -> v*2) -> [f32], ([[f32]], v -> v > 3.0) -> [bool]. //! //! 
This module also has a set of coercion rules to improve user experience: if an argument i32 is passed //! to a function that supports f64, it is coerced to f64. @@ -43,6 +43,7 @@ use datafusion_common::config::{ConfigEntry, ConfigOptions}; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; +use datafusion_expr::type_coercion::functions::value_fields_with_lambda_udf; use datafusion_expr::{ expr_vec_fmt, ColumnarValue, LambdaArgument, LambdaFunctionArgs, LambdaReturnFieldArgs, LambdaUDF, ValueOrLambda, Volatility, @@ -71,7 +72,7 @@ impl Debug for LambdaFunctionExpr { impl LambdaFunctionExpr { /// Create a new Lambda function pub fn new( - name: &str, + name: impl Into, fun: Arc, args: Vec>, return_field: FieldRef, @@ -79,7 +80,7 @@ impl LambdaFunctionExpr { ) -> Self { Self { fun, - name: name.to_owned(), + name: name.into(), args, return_field, config_options, @@ -105,7 +106,8 @@ impl LambdaFunctionExpr { }) .collect::>>()?; - // TODO: verify that input data types is consistent with function's `TypeSignature` + // verify that input data types is consistent with function's `LambdaTypeSignature` + value_fields_with_lambda_udf(&arg_fields, func.as_ref())?; let arguments = args .iter() @@ -328,11 +330,11 @@ impl PhysicalExpr for LambdaFunctionExpr { .map(|(name, param)| Arc::new(param.with_name(name))) .collect(); - Ok(ValueOrLambda::Lambda(LambdaArgument { + Ok(ValueOrLambda::Lambda(LambdaArgument::new( params, - body: Arc::clone(lambda.body()), + Arc::clone(lambda.body()), captures, - })) + ))) } (Some(_lambda), None) => exec_err!( "{} don't reported the parameters of one of it's lambdas", diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index e337d938b4a02..8da8627c56024 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -404,7 +404,7 @@ pub fn 
create_physical_expr( )?)) } Expr::Lambda(Lambda { params, body }) => expressions::lambda( - params.clone(), + params, create_physical_expr(body, input_dfschema, execution_props)?, ), Expr::LambdaVariable(LambdaVariable { diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 364f0c4f28115..a3ed54bc76cb0 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -122,6 +122,59 @@ impl Serializeable for Expr { ))) } + fn udlf(&self, name: &str) -> Result> { + // if a SimpleLambdaFunction get's added, use it instead of MockLambdaUDF + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockLambdaUDF { + name: String, + signature: LambdaSignature, + } + + impl LambdaUDF for MockLambdaUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &LambdaSignature { + &self.signature + } + + fn lambdas_parameters( + &self, + _args: &[datafusion_expr::ValueOrLambda< + arrow::datatypes::FieldRef, + (), + >], + ) -> Result>>> + { + not_impl_err!("mock LambdaUDF") + } + + fn return_field_from_args( + &self, + _args: datafusion_expr::LambdaReturnFieldArgs, + ) -> Result { + not_impl_err!("mock LambdaUDF") + } + + fn invoke_with_args( + &self, + _args: datafusion_expr::LambdaFunctionArgs, + ) -> Result { + not_impl_err!("mock LambdaUDF") + } + } + + Ok(Arc::new(MockLambdaUDF { + name: name.to_string(), + signature: LambdaSignature::variadic_any(Volatility::Immutable), + })) + } + fn udaf(&self, name: &str) -> Result> { Ok(Arc::new(create_udaf( name, @@ -158,15 +211,6 @@ impl Serializeable for Expr { "register_udf called in Placeholder Registry!" ) } - fn register_udwf( - &mut self, - _udaf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udwf called in Placeholder Registry!" - ) - } - fn register_udlf( &mut self, _udlf: Arc, @@ -175,71 +219,30 @@ impl Serializeable for Expr { "register_udlf called in Placeholder Registry!" 
) } + fn register_udwf( + &mut self, + _udaf: Arc, + ) -> Result>> { + datafusion_common::internal_err!( + "register_udwf called in Placeholder Registry!" + ) + } fn expr_planners(&self) -> Vec> { vec![] } - fn udafs(&self) -> std::collections::HashSet { + fn udlfs(&self) -> std::collections::HashSet { std::collections::HashSet::default() } - fn udwfs(&self) -> std::collections::HashSet { + fn udafs(&self) -> std::collections::HashSet { std::collections::HashSet::default() } - fn udlfs(&self) -> std::collections::HashSet { + fn udwfs(&self) -> std::collections::HashSet { std::collections::HashSet::default() } - - fn udlf(&self, name: &str) -> Result> { - // if a SimpleLambdaFunction get's added, use it instead of MockLambdaUDF - #[derive(Debug, PartialEq, Eq, Hash)] - struct MockLambdaUDF { - name: String, - signature: LambdaSignature, - } - - impl LambdaUDF for MockLambdaUDF { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn name(&self) -> &str { - &self.name - } - - fn signature(&self) -> &LambdaSignature { - &self.signature - } - - fn lambdas_parameters( - &self, - _args: &[datafusion_expr::ValueOrLambda], - ) -> Result>>> { - not_impl_err!("mock LambdaUDF") - } - - fn return_field_from_args( - &self, - _args: datafusion_expr::LambdaReturnFieldArgs, - ) -> Result { - not_impl_err!("mock LambdaUDF") - } - - fn invoke_with_args( - &self, - _args: datafusion_expr::LambdaFunctionArgs, - ) -> Result { - not_impl_err!("mock LambdaUDF") - } - } - - Ok(Arc::new(MockLambdaUDF { - name: name.to_string(), - signature: LambdaSignature::variadic_any(Volatility::Immutable), - })) - } } Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?; diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs index e9ff7f7c36ac3..df8281e799e0d 100644 --- a/datafusion/proto/src/bytes/registry.rs +++ b/datafusion/proto/src/bytes/registry.rs @@ -64,15 +64,15 @@ impl FunctionRegistry for NoRegistry { vec![] } - fn udafs(&self) -> HashSet 
{ + fn udlfs(&self) -> HashSet { HashSet::new() } - fn udwfs(&self) -> HashSet { + fn udafs(&self) -> HashSet { HashSet::new() } - - fn udlfs(&self) -> HashSet { + + fn udwfs(&self) -> HashSet { HashSet::new() } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 41d063b6d05a4..9b1c81dbad752 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -277,10 +277,58 @@ impl SqlToRel<'_, S> { } } } + // User-defined function (UDF) should have precedence + if let Some(fm) = self.context_provider.get_function_meta(&name) { + let (args, arg_names): (Vec, Vec>) = args + .into_iter() + .map(|a| { + self.sql_fn_arg_to_logical_expr_with_name(a, schema, planner_context) + }) + .collect::>>()? + .into_iter() + .unzip(); + + let resolved_args = if arg_names.iter().any(|name| name.is_some()) { + if let Some(param_names) = &fm.signature().parameter_names { + datafusion_expr::arguments::resolve_function_arguments( + param_names, + args, + arg_names, + )? + } else { + return plan_err!( + "Function '{}' does not support named arguments", + fm.name() + ); + } + } else { + args + }; + + // After resolution, all arguments are positional + let inner = ScalarFunction::new_udf(fm, resolved_args); + + if name.eq_ignore_ascii_case(inner.name()) { + return Ok(Expr::ScalarFunction(inner)); + } else { + // If the function is called by an alias, a verbose string representation is created + // (e.g., "my_alias(arg1, arg2)") and the expression is wrapped in an `Alias` + // to ensure the output column name matches the user's query. 
+ let arg_names = inner + .args + .iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(","); + let verbose_alias = format!("{name}({arg_names})"); + + return Ok(Expr::ScalarFunction(inner).alias(verbose_alias)); + } + } if let Some(fm) = self.context_provider.get_lambda_meta(&name) { // plan non-lambda arguments first so we can get theirs datatype and call - // LambdaUDF::lambdas_parameters to then plan the lambda arguments with + // LambdaUDF::lambdas_parameters to then plan the lambda arguments with // resolved lambda variables enum ExprOrLambda { Expr((Expr, Option)), @@ -440,55 +488,6 @@ impl SqlToRel<'_, S> { } } - // User-defined function (UDF) should have precedence - if let Some(fm) = self.context_provider.get_function_meta(&name) { - let (args, arg_names): (Vec, Vec>) = args - .into_iter() - .map(|a| { - self.sql_fn_arg_to_logical_expr_with_name(a, schema, planner_context) - }) - .collect::>>()? - .into_iter() - .unzip(); - - let resolved_args = if arg_names.iter().any(|name| name.is_some()) { - if let Some(param_names) = &fm.signature().parameter_names { - datafusion_expr::arguments::resolve_function_arguments( - param_names, - args, - arg_names, - )? - } else { - return plan_err!( - "Function '{}' does not support named arguments", - fm.name() - ); - } - } else { - args - }; - - // After resolution, all arguments are positional - let inner = ScalarFunction::new_udf(fm, resolved_args); - - if name.eq_ignore_ascii_case(inner.name()) { - return Ok(Expr::ScalarFunction(inner)); - } else { - // If the function is called by an alias, a verbose string representation is created - // (e.g., "my_alias(arg1, arg2)") and the expression is wrapped in an `Alias` - // to ensure the output column name matches the user's query. 
- let arg_names = inner - .args - .iter() - .map(|arg| arg.to_string()) - .collect::>() - .join(","); - let verbose_alias = format!("{name}({arg_names})"); - - return Ok(Expr::ScalarFunction(inner).alias(verbose_alias)); - } - } - // Build Unnest expression if name.eq("unnest") { let mut exprs = self.function_args_to_expr(args, schema, planner_context)?; diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index d1112b99536d9..3874a9a04f793 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -153,8 +153,8 @@ pub fn to_substrait_rex( } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::LambdaFunction(expr) => producer.handle_lambda_function(expr, schema), - Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // not yet implemented in substrait-rs - Expr::LambdaVariable(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // not yet implemented in substrait-rs + Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::LambdaVariable(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), } } From 90eb08fa362643b74a396b3a4ed818243f89e5ca Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Tue, 17 Mar 2026 17:36:22 -0300 Subject: [PATCH 15/47] improve lambdas --- datafusion/expr/src/expr.rs | 2 +- datafusion/expr/src/udlf.rs | 79 ------ .../optimizer/src/common_subexpr_eliminate.rs | 11 +- .../physical-expr/src/expressions/lambda.rs | 112 +++++++- .../physical-expr/src/expressions/mod.rs | 2 +- .../physical-expr/src/lambda_function.rs | 31 --- datafusion/sql/src/unparser/expr.rs | 64 ++++- datafusion/sqllogictest/src/test_context.rs | 36 +-- datafusion/sqllogictest/test_files/lambda.slt | 263 +++++++++++------- 9 files changed, 343 insertions(+), 257 deletions(-) diff 
--git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 6556ee9dc2c6a..06b0e29efe6d3 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -3411,7 +3411,7 @@ impl Display for SchemaDisplay<'_> { ) } Expr::LambdaVariable(c) => { - write!(f, "{}", c.name) + f.write_str(&c.name) } } } diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index defb6221e1406..fac9ef5d8c6d7 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -19,14 +19,12 @@ use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyContext}; -use crate::sort_properties::{ExprProperties, SortProperties}; use crate::{ColumnarValue, Documentation, Expr}; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; -use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::signature::Volatility; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::any::Any; @@ -602,83 +600,6 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { } } - /// Computes the output [`Interval`] for a [`LambdaUDF`], given the input - /// intervals. - /// - /// # Parameters - /// - /// * `children` are the intervals for the children (inputs) of this function. - /// - /// # Example - /// - /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`, - /// then the output interval would be `[0, 3]`. - fn evaluate_bounds(&self, _input: &[&Interval]) -> Result { - // We cannot assume the input datatype is the same of output type. - Interval::make_unbounded(&DataType::Null) - } - - /// Updates bounds for child expressions, given a known [`Interval`]s for this - /// function. 
- /// - /// This function is used to propagate constraints down through an - /// expression tree. - /// - /// # Parameters - /// - /// * `interval` is the currently known interval for this function. - /// * `inputs` are the current intervals for the inputs (children) of this function. - /// - /// # Returns - /// - /// A `Vec` of new intervals for the children, in order. - /// - /// If constraint propagation reveals an infeasibility for any child, returns - /// [`None`]. If none of the children intervals change as a result of - /// propagation, may return an empty vector instead of cloning `children`. - /// This is the default (and conservative) return value. - /// - /// # Example - /// - /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the - /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`. - fn propagate_constraints( - &self, - _interval: &Interval, - _inputs: &[&Interval], - ) -> Result>> { - Ok(Some(vec![])) - } - - /// Calculates the [`SortProperties`] of this function based on its children's properties. - fn output_ordering(&self, inputs: &[ExprProperties]) -> Result { - if !self.preserves_lex_ordering(inputs)? { - return Ok(SortProperties::Unordered); - } - - let Some(first_order) = inputs.first().map(|p| &p.sort_properties) else { - return Ok(SortProperties::Singleton); - }; - - if inputs - .iter() - .skip(1) - .all(|input| &input.sort_properties == first_order) - { - Ok(*first_order) - } else { - Ok(SortProperties::Unordered) - } - } - - /// Returns true if the function preserves lexicographical ordering based on - /// the input ordering. - /// - /// For example, `concat(a || b)` preserves lexicographical ordering, but `abs(a)` does not. - fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result { - Ok(false) - } - /// Coerce arguments of a function call to types that the function can evaluate. 
/// /// See the [type coercion module](crate::type_coercion) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index dac0c2ba55909..eed5c46f080f6 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -30,7 +30,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::cse::{CSE, CSEController, FoundCommonNodes}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, qualified_name}; -use datafusion_expr::expr::{Alias, ScalarFunction}; +use datafusion_expr::expr::{Alias, LambdaFunction, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; @@ -651,12 +651,15 @@ impl CSEController for ExprCSEController<'_> { fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> { match node { - // In case of `ScalarFunction`s we don't know which children are surely + // In case of `ScalarFunction`s and `LambdaFunction`s we don't know which children are surely // executed so start visiting all children conditionally and stop the // recursion with `TreeNodeRecursion::Jump`. Expr::ScalarFunction(ScalarFunction { func, args }) => { func.conditional_arguments(args) } + Expr::LambdaFunction(LambdaFunction { func, args }) => { + func.conditional_arguments(args) + } // In case of `And` and `Or` the first child is surely executed, but we // account subexpressions as conditional in the second. @@ -696,7 +699,8 @@ impl CSEController for ExprCSEController<'_> { } fn is_valid(node: &Expr) -> bool { - !node.is_volatile_node() && !matches!(node, Expr::LambdaVariable(_)) + !node.is_volatile_node() + && !matches!(node, Expr::Lambda(_) | Expr::LambdaVariable(_)) } fn is_ignored(&self, node: &Expr) -> bool { @@ -726,6 +730,7 @@ impl CSEController for ExprCSEController<'_> { | Expr::ScalarVariable(..) 
| Expr::Alias(..) | Expr::Wildcard { .. } + | Expr::Lambda(_) | Expr::LambdaVariable(_) ); diff --git a/datafusion/physical-expr/src/expressions/lambda.rs b/datafusion/physical-expr/src/expressions/lambda.rs index c29edf17fca41..cb0b76483505f 100644 --- a/datafusion/physical-expr/src/expressions/lambda.rs +++ b/datafusion/physical-expr/src/expressions/lambda.rs @@ -27,13 +27,13 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::{internal_err, tree_node::TreeNodeVisitor, HashSet, Result}; +use datafusion_common::{HashSet, Result, internal_err, tree_node::TreeNodeVisitor}; use datafusion_common::{ plan_err, tree_node::{TreeNode, TreeNodeRecursion}, }; use datafusion_expr::ColumnarValue; -use hashbrown::{hash_map::EntryRef, HashMap}; +use hashbrown::{HashMap, hash_map::EntryRef}; /// Represents a lambda with the given parameters names and body #[derive(Debug, Eq, Clone)] @@ -100,6 +100,14 @@ impl LambdaExpr { &self.captured_columns } + /// Returns lambdas variables names that aren't of this lambda nor any other lambda down tree. 
+ /// Example: + /// + /// `array_transform([[[1, 2, 3]]], a -> array_transform(a, b -> array_transform(b, c -> length(a) + length(b) + c)))` + /// + /// For the outermost lambda, this would return an empty hash set + /// For the middle one, `HashSet("a")` + /// And for the innermost, `HashSet("a", "b")` pub(crate) fn captured_variables(&self) -> &HashSet { &self.captured_variables } @@ -192,7 +200,7 @@ impl<'n> TreeNodeVisitor<'n> for Captures<'n> { fn f_down(&mut self, node: &'n Self::Node) -> Result { if let Some(lambda) = node.as_any().downcast_ref::() { for param in &lambda.params { - *self.shadows.entry_ref(param.as_str()).or_default() += 1; + *self.shadows.entry(param.as_str()).or_default() += 1; } } else if let Some(lambda_variable) = node.as_any().downcast_ref::() @@ -230,3 +238,101 @@ impl<'n> TreeNodeVisitor<'n> for Captures<'n> { Ok(TreeNodeRecursion::Continue) } } + +#[cfg(test)] +mod tests { + use crate::{ + LambdaFunctionExpr, + expressions::{Column, LambdaExpr, NoOp, lambda::lambda, lambda_variable}, + }; + use arrow::{ + array::RecordBatch, + datatypes::{DataType, Field, FieldRef, Schema}, + }; + use datafusion_common::{HashSet, Result}; + use datafusion_expr::{ColumnarValue, LambdaUDF}; + use std::sync::Arc; + + #[derive(Debug, Hash, Eq, PartialEq)] + struct DummyLambdaUDF; + + impl LambdaUDF for DummyLambdaUDF { + fn as_any(&self) -> &dyn std::any::Any { + unimplemented!() + } + + fn name(&self) -> &str { + "dummy_udlf" + } + + fn signature(&self) -> &datafusion_expr::LambdaSignature { + unimplemented!() + } + + fn lambdas_parameters( + &self, + _args: &[datafusion_expr::ValueOrLambda], + ) -> Result>>> { + unimplemented!() + } + + fn return_field_from_args( + &self, + _args: datafusion_expr::LambdaReturnFieldArgs, + ) -> Result { + unimplemented!() + } + + fn invoke_with_args( + &self, + _args: datafusion_expr::LambdaFunctionArgs, + ) -> Result { + unimplemented!() + } + } + + #[test] + fn test_lambda_captures() { + let null_field = 
Arc::new(Field::new("", DataType::Null, true)); + + //`var_b -> dummy_udlf(var_a, var_b, column@0, var_c -> var_c))` + let inner = LambdaExpr::try_new( + vec![String::from("var_b")], + Arc::new(LambdaFunctionExpr::new( + "dummy_udlf", + Arc::new(DummyLambdaUDF), + vec![ + lambda_variable("var_a", Arc::clone(&null_field)).unwrap(), + lambda_variable("var_b", Arc::clone(&null_field)).unwrap(), + Arc::new(Column::new("column", 0)), + lambda( + ["var_c"], + lambda_variable("var_c", Arc::clone(&null_field)).unwrap(), + ) + .unwrap(), + ], + Arc::clone(&null_field), + Arc::new(Default::default()), + )), + ) + .unwrap(); + + assert_eq!(inner.captured_columns(), &HashSet::from([0])); + assert_eq!( + inner.captured_variables(), + &HashSet::from([String::from("var_a")]) + ); + } + + #[test] + fn test_lambda_evaluate() { + let lambda = lambda(["a"], Arc::new(NoOp::new())).unwrap(); + let batch = RecordBatch::new_empty(Arc::new(Schema::empty())); + assert!(lambda.evaluate(&batch).is_err()); + } + + #[test] + fn test_lambda_duplicate_name() { + assert!(lambda(["a", "a"], Arc::new(NoOp::new())).is_err()); + } +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index e0958d2f351f3..27303e91e4285 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -23,11 +23,11 @@ mod case; mod cast; mod cast_column; mod column; -mod lambda_variable; mod dynamic_filters; mod in_list; mod is_not_null; mod is_null; +mod lambda_variable; mod lambda; mod like; mod literal; diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs index 259aeaa53151f..197342a844cd3 100644 --- a/datafusion/physical-expr/src/lambda_function.rs +++ b/datafusion/physical-expr/src/lambda_function.rs @@ -41,8 +41,6 @@ use arrow::array::{Array, NullArray, RecordBatch}; use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use 
datafusion_common::config::{ConfigEntry, ConfigOptions}; use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; -use datafusion_expr::interval_arithmetic::Interval; -use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::value_fields_with_lambda_udf; use datafusion_expr::{ ColumnarValue, LambdaArgument, LambdaFunctionArgs, LambdaReturnFieldArgs, LambdaUDF, @@ -243,7 +241,6 @@ fn sorted_config_entries(config_options: &ConfigOptions) -> Vec { } impl PhysicalExpr for LambdaFunctionExpr { - /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } @@ -404,34 +401,6 @@ impl PhysicalExpr for LambdaFunctionExpr { ))) } - fn evaluate_bounds(&self, children: &[&Interval]) -> Result { - self.fun.evaluate_bounds(children) - } - - fn propagate_constraints( - &self, - interval: &Interval, - children: &[&Interval], - ) -> Result>> { - self.fun.propagate_constraints(interval, children) - } - - fn get_properties(&self, children: &[ExprProperties]) -> Result { - let sort_properties = self.fun.output_ordering(children)?; - let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?; - let children_range = children - .iter() - .map(|props| &props.range) - .collect::>(); - let range = self.fun().evaluate_bounds(&children_range)?; - - Ok(ExprProperties { - sort_properties, - range, - preserves_lex_ordering, - }) - } - fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{}(", self.name)?; for (i, expr) in self.args.iter().enumerate() { diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 78d54cc75f699..83801fa0f306f 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,13 +15,15 @@ // specific language governing permissions and limitations // under the License. 
-use datafusion_expr::expr::{AggregateFunctionParams, LambdaFunction, WindowFunctionParams}; -use datafusion_expr::expr::{Lambda, Unnest}; use datafusion_common::datatype::DataTypeExt; +use datafusion_expr::expr::{ + AggregateFunctionParams, LambdaFunction, WindowFunctionParams, +}; +use datafusion_expr::expr::{Lambda, Unnest}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, - ObjectName, Subscript, TimezoneInfo, UnaryOperator, + self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName, + Subscript, TimezoneInfo, UnaryOperator, }; use sqlparser::ast::{CaseWhen, DuplicateTreatment, OrderByOptions, ValueWithSpan}; use std::sync::Arc; @@ -1869,10 +1871,11 @@ mod tests { use datafusion_common::{Spans, TableReference}; use datafusion_expr::expr::WildcardOptions; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, - Volatility, WindowFrame, WindowFunctionDefinition, case, cast, col, cube, exists, - grouping_set, interval_datetime_lit, interval_year_month_lit, lit, not, - not_exists, out_ref_col, placeholder, rollup, table_scan, try_cast, when, + ColumnarValue, LambdaUDF, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, WindowFrame, WindowFunctionDefinition, case, cast, col, + cube, exists, grouping_set, interval_datetime_lit, interval_year_month_lit, + lambda, lambda_var, lit, not, not_exists, out_ref_col, placeholder, rollup, + table_scan, try_cast, when, }; use datafusion_expr::{ExprFunctionExt, interval_month_day_nano_lit}; use datafusion_functions::datetime::from_unixtime::FromUnixtimeFunc; @@ -1929,6 +1932,44 @@ mod tests { } // See sql::tests for E2E tests. 
+ #[derive(Debug, Hash, Eq, PartialEq)] + struct DummyLambdaUDF; + + impl LambdaUDF for DummyLambdaUDF { + fn as_any(&self) -> &dyn Any { + unimplemented!() + } + + fn name(&self) -> &str { + "dummy_udlf" + } + + fn signature(&self) -> &datafusion_expr::LambdaSignature { + unimplemented!() + } + + fn lambdas_parameters( + &self, + _args: &[datafusion_expr::ValueOrLambda], + ) -> Result>>> { + unimplemented!() + } + + fn return_field_from_args( + &self, + _args: datafusion_expr::LambdaReturnFieldArgs, + ) -> Result { + unimplemented!() + } + + fn invoke_with_args( + &self, + _args: datafusion_expr::LambdaFunctionArgs, + ) -> Result { + unimplemented!() + } + } + #[test] fn expr_to_sql_ok() -> Result<()> { let dummy_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); @@ -2013,6 +2054,13 @@ mod tests { .is_not_null(), r#"dummy_udf(a, b) IS NOT NULL"#, ), + ( + Expr::LambdaFunction(LambdaFunction::new( + Arc::new(DummyLambdaUDF), + vec![col("a"), lambda(["v"], -lambda_var("v"))], + )), + r#"dummy_udlf(a, (v) -> -v)"#, + ), ( Expr::Like(Like { negated: true, diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 11a6573b7ae8e..8bd0cabcb05b0 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -24,14 +24,11 @@ use std::sync::Arc; use std::vec; use arrow::array::{ - Array, ArrayRef, BinaryArray, FixedSizeListArray, Float64Array, Int32Array, - LargeBinaryArray, LargeStringArray, ListViewArray, StringArray, - TimestampNanosecondArray, UnionArray, + Array, ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, + LargeStringArray, StringArray, TimestampNanosecondArray, UnionArray, }; use arrow::buffer::ScalarBuffer; -use arrow::datatypes::{ - DataType, Field, Int32Type, Schema, SchemaRef, TimeUnit, UnionFields, -}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields}; use arrow::record_batch::RecordBatch; use 
datafusion::catalog::{ CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, Session, @@ -152,10 +149,6 @@ impl TestContext { info!("Registering dummy async udf"); register_async_abs_udf(test_ctx.session_ctx()) } - "lambda.slt" => { - info!("Registering table with ListView column"); - register_table_with_list_view(test_ctx.session_ctx()) - } _ => { info!("Using default SessionContext"); } @@ -625,26 +618,3 @@ fn register_async_abs_udf(ctx: &SessionContext) { let udf = AsyncScalarUDF::new(Arc::new(async_abs)); ctx.register_udf(udf.into_scalar_udf()); } - -fn register_table_with_list_view(session_ctx: &SessionContext) { - let data = vec![ - Some(vec![Some(0), Some(1)]), - Some(vec![Some(3), None]), - Some(vec![None, None]), - None, - ]; - let list_array = FixedSizeListArray::from_iter_primitive::(data, 2); - let list_view: ListViewArray = list_array.into(); - - let schema = Schema::new(vec![Field::new( - "list_view", - list_view.data_type().clone(), - true, - )]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_view)]).unwrap(); - - session_ctx - .register_batch("table_with_list_view", batch) - .unwrap(); -} diff --git a/datafusion/sqllogictest/test_files/lambda.slt b/datafusion/sqllogictest/test_files/lambda.slt index cfd56298baa2a..3c79976e850a7 100644 --- a/datafusion/sqllogictest/test_files/lambda.slt +++ b/datafusion/sqllogictest/test_files/lambda.slt @@ -23,144 +23,225 @@ statement ok set datafusion.sql_parser.dialect = databricks; statement ok -CREATE TABLE tt +CREATE TABLE t (list array, number int) AS VALUES ([1, 50], 10), -([4, 50], 40); +([4, 50], 40), +([7, 50], 60); -statement ok -CREATE TABLE t AS SELECT 1 as f, [ [ [2, 3], [2] ], [ [1] ], [ [] ] ] as v, 1 as n; +query ? +SELECT array_transform([1,2,3,4,5], v -> v*2); +---- +[2, 4, 6, 8, 10] -#test potentialy sliced list -select array_transform(tt.column_1, v -> v*2) from tt order by column_2 limit 1; +query ? 
+SELECT array_transform([1,2,3,4,5], v -> repeat("a", v)); +---- +[a, aa, aaa, aaaa, aaaaa] -query I? -SELECT t.n, array_transform([], e1 -> t.n) from t; +query ? +SELECT array_transform([1,2,3,4,5], v -> list_repeat("a", v)); +---- +[[a], [a, a], [a, a, a], [a, a, a, a], [a, a, a, a, a]] + + +# version without limit/offset of queries below +query ? +select array_transform(t.list, v -> v*2) from t order by t.number; ---- -1 [] +[2, 100] +[8, 100] +[14, 100] +# sliced lists query ? -SELECT array_transform([1, 2], e1 -> (select n from t)); +select array_transform(t.list, v -> v*2) from t order by t.number limit 1; ---- -[1, 1] +[2, 100] query ? -SELECT array_transform(t.v, (v1, i) -> array_transform(v1, (v2, j) -> array_transform(v2, v3 -> j)) ) from t; +select array_transform(t.list, v -> v*2) from t order by t.number offset 1; ---- -[[[1, 1], [2]], [[1]], [[]]] +[8, 100] +[14, 100] +query ? +select array_transform(t.list, v -> v*2) from t order by t.number limit 1 offset 1; +---- +[8, 100] + +# lambda that uses only captured column which also only appears within the lambda +query ? +SELECT array_transform([1, 2], v -> t.number) from t; +---- +[10, 10] +[40, 40] +[60, 60] + +# return scalar query I? -SELECT t.n, array_transform([1, 2], (e) -> n) from t; +SELECT t.number, array_transform([1, 2], e1 -> 24) from t; ---- -1 [1, 1] +10 [24, 24] +40 [24, 24] +60 [24, 24] -# projection pushdown +# uses only the first parameter query ? -SELECT array_transform([1, 2], (e) -> n) from t; +SELECT array_transform([1, 2], (v, i) -> v+1) from t; ---- -[1, 1] +[2, 3] +[2, 3] +[2, 3] +# uses only the second parameter query ? -SELECT array_transform([1, 2], (e, i) -> i) from t; +SELECT array_transform([10, 20], (v, i) -> i) from t; ---- [1, 2] +[1, 2] +[1, 2] -# type coercion +# use only capture of a parent lambda variable query ? 
-SELECT array_transform([1.0, 2.0], v -> v + t.n) from t; +SELECT array_transform([[1, 2]], (a, i) -> array_transform(a, b -> i)) ---- -[2.0, 4.0] +[[1, 1]] -query TT -EXPLAIN SELECT array_transform([1.0, 2.0], v -> v + t.n) from t; +# use only capture of a parent lambda variable and a column +query ? +SELECT array_transform([[1, 2]], (a, i) -> array_transform(a, b -> i + t.number)) from t ---- -logical_plan -01)Projection: List([2.0, 4.0]) AS array_transform(make_array(Float64(1),Float64(2)),(e, i) -> e + i) -02)--EmptyRelation: rows=1 -physical_plan -01)ProjectionExec: expr=[[2.0, 4.0] as array_transform(make_array(Float64(1),Float64(2)),(e, i) -> e + i)] -02)--PlaceholderRowExec +[[11, 11]] +[[41, 41]] +[[61, 61]] -#cse -query TT -explain select n + 1, array_transform([1], v -> v + n + 1) from t; +# use only capture of a parent lambda variable and own variable +query ? +SELECT array_transform([[1, 2]], (a, i) -> array_transform(a, b -> i + b)) ---- -logical_plan -01)Projection: t.n + Int64(1), array_transform(List([1]), (v) -> v + t.n + Int64(1)) AS array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1)) -02)--TableScan: t projection=[n] -physical_plan -01)ProjectionExec: expr=[n@0 + 1 as t.n + Int64(1), array_transform([1], (v) -> v@-1 + n@0 + 1) as array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1))] -02)--DataSourceExec: partitions=1, partition_sizes=[1] +[[2, 3]] +# use capture of a column, a parent lambda variable and own variable query ? -SELECT array_transform([1,2,3,4,5], v -> v*2); +SELECT array_transform([[1, 2]], (a, i) -> array_transform(a, b -> i + b + t.number)) from t ---- -[2, 4, 6, 8, 10] +[[12, 13]] +[[42, 43]] +[[62, 63]] +# shadows parent lambda variable query ? -SELECT array_transform([1,2,3,4,5], v -> 2); +SELECT array_transform([[1, 2]], a -> array_transform(a, a -> a+1)) ---- -[2, 2, 2, 2, 2] +[[2, 3]] +# multiple nesting query ? 
-SELECT array_transform([[1,2],[3,4,5]], v -> array_transform(v, v -> v*2)); +SELECT array_transform([[[1], [2], [3]]], (a, i) -> array_transform(a, (b, j) -> array_transform(b, (c, k) -> c + i + j + k))); ---- -[[2, 4], [6, 8, 10]] +[[[4], [6], [8]]] -query ? -SELECT array_transform([1,2,3,4,5], v -> repeat("a", v)); +# parameter shadows unqualified column +query I? +SELECT number, array_transform([1, 2], number -> number+1) from t; ---- -[a, aa, aaa, aaaa, aaaaa] +10 [2, 3] +40 [2, 3] +60 [2, 3] +# type coercion inside lambda body query ? -SELECT array_transform([1,2,3,4,5], v -> list_repeat("a", v)); +SELECT array_transform([1.0, 2.0], v -> v + t.number) from t; ---- -[[a], [a, a], [a, a, a], [a, a, a, a], [a, a, a, a, a]] +[11.0, 12.0] +[41.0, 42.0] +[61.0, 62.0] query TT -EXPLAIN SELECT array_transform([1,2,3,4,5], v -> v*2); +EXPLAIN SELECT array_transform([1.0, 2.0], v -> v + t.number) from t; ---- logical_plan -01)Projection: List([2, 4, 6, 8, 10]) AS array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2)) -02)--EmptyRelation: rows=1 +01)Projection: array_transform(List([1.0, 2.0]), (v) -> v + CAST(t.number AS Float64)) AS array_transform(make_array(Float64(1),Float64(2)),(v) -> v + t.number) +02)--TableScan: t projection=[number] physical_plan -01)ProjectionExec: expr=[[2, 4, 6, 8, 10] as array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2))] -02)--PlaceholderRowExec +01)ProjectionExec: expr=[array_transform([1.0, 2.0], (v) -> v@ + CAST(number@0 AS Float64)) as array_transform(make_array(Float64(1),Float64(2)),(v) -> v + t.number)] +02)--DataSourceExec: partitions=1, partition_sizes=[1] -query ? 
-SELECT array_transform( - [[1]], - v -> array_concat( - array_transform(v, v -> v), - array_transform(v, v1 -> v1 + v[0]) - ) -); +#cse +query TT +explain select t.number*2, array_transform([1], v -> v + t.number*2) from t; ---- -[[1, NULL]] +logical_plan +01)Projection: __common_expr_1 AS t.number * Int64(2), array_transform(List([1]), (v) -> v + __common_expr_1) AS array_transform(make_array(Int64(1)),(v) -> v + t.number * Int64(2)) +02)--Projection: CAST(t.number AS Int64) * Int64(2) AS __common_expr_1 +03)----TableScan: t projection=[number] +physical_plan +01)ProjectionExec: expr=[__common_expr_1@0 as t.number * Int64(2), array_transform([1], (v) -> v@ + __common_expr_1@0) as array_transform(make_array(Int64(1)),(v) -> v + t.number * Int64(2))] +02)--ProjectionExec: expr=[CAST(number@0 AS Int64) * 2 as __common_expr_1] +03)----DataSourceExec: partitions=1, partition_sizes=[1] -query I?? -SELECT t.n, t.v, array_transform(t.v, (v, i) -> array_transform(v, (v, j) -> n) ) from t; +#cse should not eliminate subtrees containing lambdas +query TT +explain select array_transform([t.number], v -> 5), array_transform([t.number+1], v -> 5) from t; ---- -1 [[[2, 3], [2]], [[1]], [[]]] [[1, 1], [1], [1]] +logical_plan +01)Projection: array_transform(make_array(t.number), (v) -> Int64(5)), array_transform(make_array(CAST(t.number AS Int64) + Int64(1)), (v) -> Int64(5)) +02)--TableScan: t projection=[number] +physical_plan +01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> 5) as array_transform(make_array(t.number),(v) -> Int64(5)), array_transform(make_array(CAST(number@0 AS Int64) + 1), (v) -> 5) as array_transform(make_array(t.number + Int64(1)),(v) -> Int64(5))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +#cse should not eliminate subtrees containing lambda variables +query TT +explain select array_transform([t.number], v -> v*2), array_transform([t.number+1], (v, i) -> v*2) from t; +---- +logical_plan +01)Projection: 
array_transform(make_array(t.number), (v) -> CAST(v AS Int64) * Int64(2)), array_transform(make_array(CAST(t.number AS Int64) + Int64(1)), (v, i) -> v * Int64(2)) +02)--TableScan: t projection=[number] +physical_plan +01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> CAST(v@ AS Int64) * 2) as array_transform(make_array(t.number),(v) -> v * Int64(2)), array_transform(make_array(CAST(number@0 AS Int64) + 1), (v, i) -> v@ * 2) as array_transform(make_array(t.number + Int64(1)),(v, i) -> v * Int64(2))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# test that sql planner plans resolved lambda variables, as v[1] planning checks the datatype of lhs query ? -SELECT array_transform([1,2,3,4,5], v -> v*2); +SELECT array_transform([[10, 20]], v -> v[1]); ---- -[2, 4, 6, 8, 10] +[10] -# expr simplifier +# expr simplifier inside lambda body query TT -EXPLAIN SELECT v = v, array_transform([t.n], v -> v = v) from t; +EXPLAIN SELECT array_transform([t.number], v -> v = v) from t; ---- logical_plan -01)Projection: Boolean(true) AS t.v = t.v, List([true]) AS array_transform(make_array(Int64(1)),(v) -> v = v) -02)--TableScan: t projection=[] +01)Projection: array_transform(make_array(t.number), (v) -> v IS NOT NULL OR Boolean(NULL)) AS array_transform(make_array(t.number),(v) -> v = v) +02)--TableScan: t projection=[number] physical_plan -01)ProjectionExec: expr=[true as t.v = t.v, [true] as array_transform(make_array(Int64(1)),(v) -> v = v)] +01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> v@ IS NOT NULL OR NULL) as array_transform(make_array(t.number),(v) -> v = v)] 02)--DataSourceExec: partitions=1, partition_sizes=[1] +# array_transform coercion rules +query TT +explain select array_transform(arrow_cast(t.list, 'ListView(Int32)'), a -> a+1) from t; +---- +logical_plan +01)Projection: array_transform(CAST(CAST(t.list AS ListView(Int32)) AS List(Int32)), (a) -> CAST(a AS Int64) + Int64(1)) AS 
array_transform(arrow_cast(t.list,Utf8("ListView(Int32)")),(a) -> a + Int64(1)) +02)--TableScan: t projection=[list] +physical_plan +01)ProjectionExec: expr=[array_transform(CAST(CAST(list@0 AS ListView(Int32)) AS List(Int32)), (a) -> CAST(a@ AS Int64) + 1) as array_transform(arrow_cast(t.list,Utf8("ListView(Int32)")),(a) -> a + Int64(1))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query ? +select array_transform(arrow_cast(t.list, 'ListView(Int32)'), a -> a+1) from t; +---- +[2, 51] +[5, 51] +[8, 51] + + query error select array_transform(); ---- @@ -174,39 +255,25 @@ query error DataFusion error: Error during planning: array_transform expects a v select array_transform(v -> v*2, [1, 2]); query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 2 -SELECT array_transform([1, 2], (e, i, j) -> i) from t; +SELECT array_transform([1, 2], (e, i, j) -> i); query error DataFusion error: Error during planning: lambda parameters names must be unique, got \(v, v\) SELECT array_transform([1], (v, v) -> v*2); -query error DataFusion error: This feature is not implemented: Unsupported ast node in sqltorel: Lambda\(LambdaFunction \{ params: One\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,12\)\.\.Location\(1,13\)\) \}\), body: Identifier\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,17\)\.\.Location\(1,18\)\) \}\) \}\) +query error DataFusion error: This feature is not implemented: Unsupported ast node in sqltorel: Lambda\(LambdaFunction \{ params: One\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,12\)\.\.Location\(1,13\)\) \}\), body: Identifier\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,17\)\.\.Location\(1,18\)\) \}\), syntax: Arrow \}\) SELECT abs(v -> v); -query error DataFusion error: This feature is not implemented: Unsupported ast node in sqltorel: Lambda\(LambdaFunction \{ params: One\(Ident \{ value: "v", quote_style: None, span: 
Span\(Location\(1,8\)\.\.Location\(1,9\)\) \}\), body: Identifier\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,13\)\.\.Location\(1,14\)\) \}\) \}\) +query error DataFusion error: This feature is not implemented: Unsupported ast node in sqltorel: Lambda\(LambdaFunction \{ params: One\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,8\)\.\.Location\(1,9\)\) \}\), body: Identifier\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,13\)\.\.Location\(1,14\)\) \}\), syntax: Arrow \}\) SELECT v -> v; -query error DataFusion error: This feature is not implemented: Unsupported ast node in sqltorel: Lambda\(LambdaFunction \{ params: One\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,34\)\.\.Location\(1,35\)\) \}\), body: BinaryOp \{ left: Identifier\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,39\)\.\.Location\(1,40\)\) \}\), op: Plus, right: Value\(ValueWithSpan \{ value: Number\("1", false\), span: Span\(Location\(1,41\)\.\.Location\(1,42\)\) \}\) \} \}\) +query error DataFusion error: This feature is not implemented: Unsupported ast node in sqltorel: Lambda\(LambdaFunction \{ params: One\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,34\)\.\.Location\(1,35\)\) \}\), body: BinaryOp \{ left: Identifier\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,39\)\.\.Location\(1,40\)\) \}\), op: Plus, right: Value\(ValueWithSpan \{ value: Number\("1", false\), span: Span\(Location\(1,41\)\.\.Location\(1,42\)\) \}\) \}, syntax: Arrow \}\) SELECT array_transform([1], v -> v -> v+1); +query error DataFusion error: SQL error: ParserError\("Expected: an expression, found: \) at Line: 1, Column: 30"\) +SELECT array_transform([1], () -> 1); -query ? 
-SELECT array_transform([[2, 3]], v -> array_transform(v, j -> v)); ----- -[[[2, 3], [2, 3]]] - - -query TT -explain select array_transform(list_view, a -> a+1) from table_with_list_view; ----- -initial_logical_plan -01)Projection: array_transform(table_with_list_view.list_view, (a) -> a + Int64(1)) -02)--TableScan: table_with_list_view -logical_plan after resolve_grouping_function SAME TEXT AS ABOVE -logical_plan after type_coercion Error during planning: Cannot automatically convert ListView(nullable Int32) to List(nullable Int32) +statement ok +drop table t; -query error -select array_transform(list_view, a -> a+1) from table_with_list_view; ----- -DataFusion error: type_coercion -caused by -Error during planning: Cannot automatically convert ListView(nullable Int32) to List(nullable Int32) +statement ok +set datafusion.sql_parser.dialect = generic; From d874db7e2881befa6f9f10720647649b8ac28e10 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Tue, 17 Mar 2026 17:41:43 -0300 Subject: [PATCH 16/47] cargo fmt --- .../core/src/bin/print_functions_docs.rs | 3 ++- .../core/src/execution/session_state.rs | 5 ++-- .../datasource-arrow/src/file_format.rs | 6 +++-- datafusion/datasource/src/url.rs | 4 +++- datafusion/execution/src/task.rs | 2 +- datafusion/expr/src/expr.rs | 4 +--- datafusion/expr/src/expr_fn.rs | 11 ++++++--- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 5 ++-- datafusion/expr/src/planner.rs | 7 +++--- datafusion/expr/src/udlf.rs | 23 ++++++++++--------- .../functions-nested/src/array_transform.rs | 9 ++++---- .../optimizer/tests/optimizer_integration.rs | 5 +++- .../src/expressions/lambda_variable.rs | 12 +++++----- .../physical-expr/src/expressions/mod.rs | 4 ++-- datafusion/physical-expr/src/lib.rs | 4 ++-- datafusion/proto/src/bytes/registry.rs | 4 +++- datafusion/proto/src/logical_plan/to_proto.rs | 2 +- datafusion/session/src/session.rs | 2 +- datafusion/sql/src/expr/function.rs | 3 
++- datafusion/sql/src/unparser/dialect.rs | 2 +- datafusion/sql/tests/common/mod.rs | 2 +- .../consumer/expr/scalar_function.rs | 2 +- .../src/logical_plan/producer/expr/mod.rs | 4 +++- 24 files changed, 72 insertions(+), 55 deletions(-) diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index e086a59041212..2b681cdd74cb2 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -18,7 +18,8 @@ use datafusion::execution::SessionStateDefaults; use datafusion_common::{HashSet, Result, not_impl_err}; use datafusion_expr::{ - AggregateUDF, DocSection, Documentation, LambdaUDF, ScalarUDF, WindowUDF, aggregate_doc_sections, scalar_doc_sections, window_doc_sections + AggregateUDF, DocSection, Documentation, LambdaUDF, ScalarUDF, WindowUDF, + aggregate_doc_sections, scalar_doc_sections, window_doc_sections, }; use itertools::Itertools; use std::env::args; diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 1bd0f27381d04..ed9ecc393cb0b 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -57,12 +57,11 @@ use datafusion_expr::planner::ExprPlanner; #[cfg(feature = "sql")] use datafusion_expr::planner::{RelationPlanner, TypePlanner}; use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; +use datafusion_expr::simplify::SimplifyContext; #[cfg(feature = "sql")] use datafusion_expr::{ - AggregateUDF, Explain, Expr, LambdaUDF, LogicalPlan, ScalarUDF, - WindowUDF, + AggregateUDF, Explain, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF, }; -use datafusion_expr::simplify::SimplifyContext; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{ Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerRule, diff --git a/datafusion/datasource-arrow/src/file_format.rs 
b/datafusion/datasource-arrow/src/file_format.rs index ed8dbaac1311d..2f4f29fc94487 100644 --- a/datafusion/datasource-arrow/src/file_format.rs +++ b/datafusion/datasource-arrow/src/file_format.rs @@ -555,7 +555,9 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF}; + use datafusion_expr::{ + AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF, + }; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path}; @@ -601,7 +603,7 @@ mod tests { fn scalar_functions(&self) -> &HashMap> { unimplemented!() } - + fn lambda_functions(&self) -> &HashMap> { unimplemented!() } diff --git a/datafusion/datasource/src/url.rs b/datafusion/datasource/src/url.rs index 3f484c3806a21..1f3b156e4b0b6 100644 --- a/datafusion/datasource/src/url.rs +++ b/datafusion/datasource/src/url.rs @@ -517,7 +517,9 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF}; + use datafusion_expr::{ + AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF, + }; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use object_store::{ diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index 371a94e8298e0..9427f1179c09c 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -21,7 +21,7 @@ use crate::{ }; use datafusion_common::{Result, internal_datafusion_err, plan_datafusion_err}; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, LambdaUDF, 
WindowUDF}; +use datafusion_expr::{AggregateUDF, LambdaUDF, ScalarUDF, WindowUDF}; use std::collections::HashSet; use std::{collections::HashMap, sync::Arc}; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 06b0e29efe6d3..8803c6e2187f8 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -3410,9 +3410,7 @@ impl Display for SchemaDisplay<'_> { SchemaDisplay(body) ) } - Expr::LambdaVariable(c) => { - f.write_str(&c.name) - } + Expr::LambdaVariable(c) => f.write_str(&c.name), } } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a49593ee6ce81..ce0fb09a3126b 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -18,7 +18,9 @@ //! Functions for creating logical expressions use crate::expr::{ - AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, Lambda, LambdaVariable, NullTreatment, Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction + AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, Lambda, + LambdaVariable, NullTreatment, Placeholder, TryCast, Unnest, WildcardOptions, + WindowFunction, }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, @@ -733,11 +735,14 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { /// Create a lambda expression pub fn lambda(params: impl IntoIterator>, body: Expr) -> Expr { - Expr::Lambda(Lambda::new(params.into_iter().map(Into::into).collect(), body)) + Expr::Lambda(Lambda::new( + params.into_iter().map(Into::into).collect(), + body, + )) } /// Create an unresolved lambda variable expression -/// +/// /// The expression tree or [`LogicalPlan`] which /// owns this variable must be resolved before usage with either /// [`Expr::resolve_lambdas_variables`] or [`LogicalPlan::resolve_lambdas_variables`]. 
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 998b8415fee7a..6dd91b1012a61 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -128,7 +128,7 @@ pub use udaf::{ }; pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; pub use udlf::{ - LambdaFunctionArgs, LambdaArgument, LambdaReturnFieldArgs, LambdaSignature, + LambdaArgument, LambdaFunctionArgs, LambdaReturnFieldArgs, LambdaSignature, LambdaTypeSignature, LambdaUDF, ValueOrLambda, }; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 1b3cc94a9eab1..4d8466827fe9c 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -41,7 +41,8 @@ use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ - enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, grouping_set_expr_count, grouping_set_to_exprlist, merge_schema, split_conjunction + enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, + grouping_set_expr_count, grouping_set_to_exprlist, merge_schema, split_conjunction, }; use crate::{ BinaryExpr, CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable, @@ -2111,7 +2112,7 @@ impl LogicalPlan { } /// Return a `LogicalPLan` with all [`LambdaVariable`] resolved - /// + /// /// [`LambdaVariable`]: crate::expr::LambdaVariable pub fn resolve_lambdas_variables(self) -> Result> { self.transform_with_subqueries(|plan| { diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 296a99d55fad8..fe5d71d338941 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -24,7 +24,8 @@ use crate::expr::NullTreatment; #[cfg(feature = "sql")] use 
crate::logical_plan::LogicalPlan; use crate::{ - AggregateUDF, Expr, GetFieldAccess, LambdaUDF, ScalarUDF, SortExpr, TableSource, WindowFrame, WindowFunctionDefinition, WindowUDF + AggregateUDF, Expr, GetFieldAccess, LambdaUDF, ScalarUDF, SortExpr, TableSource, + WindowFrame, WindowFunctionDefinition, WindowUDF, }; use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef}; use datafusion_common::datatype::DataTypeExt; @@ -101,7 +102,7 @@ pub trait ContextProvider { /// Return the scalar function with a given name, if any fn get_function_meta(&self, name: &str) -> Option>; - + /// Return the lambda function with a given name, if any fn get_lambda_meta(&self, name: &str) -> Option>; @@ -132,7 +133,7 @@ pub trait ContextProvider { /// Return all scalar function names fn udf_names(&self) -> Vec; - + /// Return all lambda function names fn udlf_names(&self) -> Vec; diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index fac9ef5d8c6d7..add2c8a3ca39f 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -23,7 +23,7 @@ use crate::{ColumnarValue, Documentation, Expr}; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::signature::Volatility; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -148,7 +148,8 @@ impl PartialOrd for dyn LambdaUDF { "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. 
\ The functions compare as equal, but they are not equal based on general properties that \ the PartialOrd implementation observes,", - self.name(), other.name() + self.name(), + other.name() ); Some(cmp) } @@ -169,8 +170,8 @@ pub struct LambdaFunctionArgs { /// The evaluated arguments and lambdas to the function pub args: Vec>, /// Field associated with each arg, if it exists - /// For lambdas, it will be the field of the result of - /// the lambda if evaluated with the parameters + /// For lambdas, it will be the field of the result of + /// the lambda if evaluated with the parameters /// returned from [`LambdaUDF::lambdas_parameters`] pub arg_fields: Vec>, /// The number of rows in record batch being evaluated @@ -447,14 +448,14 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { } /// Returns a list of the same size as args where each value is the logic below applied to value at the correspondent position in args: - /// + /// /// If it's a value, return None /// If it's a lambda, return the list of all parameters that that lambda supports - /// + /// /// Example for array_transform: - /// + /// /// `array_transform([2.0, 8.0], v -> v > 4.0)` - /// + /// /// ```ignore /// let lambdas_parameters = array_transform.lambdas_parameters(&[ /// ValueOrLambdaParameter::Value(Field::new("", DataType::new_list(DataType::Float32, false)))]), // the Field of the literal `[2, 8]` @@ -470,14 +471,14 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// Some(vec![ /// // the value being transformed /// Field::new("", DataType::Float32, false), - /// // the 1-based index being transformed, not used on the example above, + /// // the 1-based index being transformed, not used on the example above, /// //but implementations doesn't need to care about it /// Field::new("", DataType::Int32, false), /// ]) /// ] /// ) /// ``` - /// + /// /// The implementation can assume that some other part of the code has coerced /// the actual argument types to match 
[`Self::signature`]. fn lambdas_parameters( @@ -486,7 +487,7 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { ) -> Result>>>; /// What type will be returned by this function, given the arguments? - /// + /// /// The implementation can assume that some other part of the code has coerced /// the actual argument types to match [`Self::signature`]. /// diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 21551ddcf9881..ff0c45037289b 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -23,12 +23,11 @@ use arrow::{ datatypes::{DataType, Field, FieldRef}, }; use datafusion_common::{ - exec_err, plan_err, + Result, exec_err, plan_err, utils::{ adjust_offsets_for_slice, list_values, list_values_index, list_values_row_number, take_function_args, }, - Result, }; use datafusion_expr::{ ColumnarValue, Documentation, LambdaFunctionArgs, LambdaReturnFieldArgs, @@ -118,7 +117,7 @@ impl LambdaUDF for ArrayTransform { "{} expected a list as first argument, got {}", self.name(), list - ) + ); } }; @@ -175,7 +174,7 @@ impl LambdaUDF for ArrayTransform { let (list, lambda) = value_lambda_pair(self.name(), &args.args)?; let list_array = list.to_array(args.number_rows)?; - + // as per list_values docs, if list_array is sliced, list_values will be sliced too, // so before constructing the transformed array below, we must adjust the list offsets with // adjust_offsets_for_slice @@ -210,7 +209,7 @@ impl LambdaUDF for ArrayTransform { "{} expected ScalarFunctionArgs.return_field to be a list, got {}", self.name(), args.return_field - ) + ); } }; diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 3e60b08a646fe..27e46a98bed76 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -735,7 +735,10 @@ 
impl ContextProvider for MyContextProvider { None } - fn get_lambda_meta(&self, _name: &str) -> Option> { + fn get_lambda_meta( + &self, + _name: &str, + ) -> Option> { None } diff --git a/datafusion/physical-expr/src/expressions/lambda_variable.rs b/datafusion/physical-expr/src/expressions/lambda_variable.rs index e5fab5fdb2366..2072bf9bb1c1e 100644 --- a/datafusion/physical-expr/src/expressions/lambda_variable.rs +++ b/datafusion/physical-expr/src/expressions/lambda_variable.rs @@ -28,7 +28,7 @@ use arrow::{ record_batch::RecordBatch, }; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::ColumnarValue; /// Represents the lambda variable with a given name and field @@ -56,10 +56,7 @@ impl Hash for LambdaVariable { impl LambdaVariable { /// Create a new lambda variable expression pub fn new(name: String, field: FieldRef) -> Self { - Self { - name, - field, - } + Self { name, field } } /// Get the variable's name @@ -120,6 +117,9 @@ impl PhysicalExpr for LambdaVariable { } /// Create a lambda variable expression -pub fn lambda_variable(name: impl Into, field: FieldRef) -> Result> { +pub fn lambda_variable( + name: impl Into, + field: FieldRef, +) -> Result> { Ok(Arc::new(LambdaVariable::new(name.into(), field))) } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 27303e91e4285..0d49910a3554f 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -27,8 +27,8 @@ mod dynamic_filters; mod in_list; mod is_not_null; mod is_null; -mod lambda_variable; mod lambda; +mod lambda_variable; mod like; mod literal; mod negative; @@ -51,8 +51,8 @@ pub use dynamic_filters::DynamicFilterPhysicalExpr; pub use in_list::{InListExpr, in_list}; pub use is_not_null::{IsNotNullExpr, is_not_null}; pub use is_null::{IsNullExpr, is_null}; -pub use lambda_variable::{lambda_variable, LambdaVariable}; pub use 
lambda::{LambdaExpr, lambda}; +pub use lambda_variable::{LambdaVariable, lambda_variable}; pub use like::{LikeExpr, like}; pub use literal::{Literal, lit}; pub use negative::{NegativeExpr, negative}; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 0345d247001dd..3476f32d4e487 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -32,10 +32,10 @@ pub mod binary_map { pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; } pub mod async_scalar_function; -pub mod lambda_function; pub mod equivalence; pub mod expressions; pub mod intervals; +pub mod lambda_function; mod partitioning; mod physical_expr; pub mod planner; @@ -70,9 +70,9 @@ pub use datafusion_physical_expr_common::sort_expr::{ PhysicalSortRequirement, }; +pub use lambda_function::LambdaFunctionExpr; pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; -pub use lambda_function::LambdaFunctionExpr; pub use simplifier::PhysicalExprSimplifier; pub use utils::{conjunction, conjunction_opt, split_conjunction}; diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs index 270ff88175ae4..880ddc03ecd1f 100644 --- a/datafusion/proto/src/bytes/registry.rs +++ b/datafusion/proto/src/bytes/registry.rs @@ -39,7 +39,9 @@ impl FunctionRegistry for NoRegistry { } fn udlf(&self, name: &str) -> Result> { - plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Lambda Function '{name}'") + plan_err!( + "No function registry provided to deserialize, so can not deserialize User Defined Lambda Function '{name}'" + ) } fn udaf(&self, name: &str) -> Result> { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 58b272aa709ab..9d417b09b3682 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ 
b/datafusion/proto/src/logical_plan/to_proto.rs @@ -629,7 +629,7 @@ pub fn serialize_expr( Expr::LambdaFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => { return Err(Error::General( "Proto serialization error: Lambda not implemented".to_string(), - )) + )); } }; diff --git a/datafusion/session/src/session.rs b/datafusion/session/src/session.rs index c4fb1867a6815..00ac5534debeb 100644 --- a/datafusion/session/src/session.rs +++ b/datafusion/session/src/session.rs @@ -109,7 +109,7 @@ pub trait Session: Send + Sync { /// Return reference to scalar_functions fn scalar_functions(&self) -> &HashMap>; - + /// Return reference to lambda_functions fn lambda_functions(&self) -> &HashMap>; diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 1cddff6c4311a..36ee0f6beabd7 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -21,7 +21,8 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow::datatypes::DataType; use datafusion_common::{ - DFSchema, Dependency, Diagnostic, HashSet, Result, Span, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err + DFSchema, Dependency, Diagnostic, HashSet, Result, Span, internal_datafusion_err, + internal_err, not_impl_err, plan_datafusion_err, plan_err, }; use datafusion_expr::{ Expr, ExprSchemable, SortExpr, ValueOrLambda, WindowFrame, WindowFunctionDefinition, diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 828024d7d61fe..42cd8b1202b32 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -158,7 +158,7 @@ pub trait Dialect: Send + Sync { ) -> Result> { Ok(None) } - + /// Allows the dialect to override lambda function unparsing if the dialect has specific rules. /// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is /// a custom implementation for the function. 
diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 902bb4af72f10..620ddeca5778e 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -301,7 +301,7 @@ impl ContextProvider for MockContextProvider { fn udf_names(&self) -> Vec { self.state.scalar_functions.keys().cloned().collect() } - + fn udlf_names(&self) -> Vec { self.state.lambda_functions.keys().cloned().collect() } diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs index 5631df5d46079..9c7804624fec6 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs @@ -30,7 +30,7 @@ pub async fn from_scalar_function( f: &ScalarFunction, input_schema: &DFSchema, ) -> Result { - //TODO: handle lambda functions, as they are also encoded as scalar functions + //TODO: handle lambda functions, as they are also encoded as scalar functions let Some(fn_signature) = consumer .get_extensions() .functions diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index 37288b5f8c7aa..26eb106702367 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -151,7 +151,9 @@ pub fn to_substrait_rex( Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::LambdaFunction(expr) => producer.handle_lambda_function(expr, schema), Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::LambdaVariable(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::LambdaVariable(expr) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } } } From a59ffe8539d4e0c13c6280bf5e473fa8dc257e49 Mon Sep 17 00:00:00 2001 From: gstvg 
<28798827+gstvg@users.noreply.github.com> Date: Tue, 17 Mar 2026 18:25:59 -0300 Subject: [PATCH 17/47] simplify LambdaUDF coerce_value_types --- .../expr/src/type_coercion/functions.rs | 50 +++++++++++-------- datafusion/expr/src/udlf.rs | 18 +++---- .../functions-nested/src/array_transform.rs | 17 ++++--- datafusion/sqllogictest/test_files/lambda.slt | 4 +- 4 files changed, 50 insertions(+), 39 deletions(-) diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index d401b5166a0b8..a3d4a4a69c0d0 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -167,36 +167,42 @@ pub fn value_fields_with_lambda_udf( LambdaTypeSignature::UserDefined => { let arg_types = current_fields .iter() - .map(|p| match p { - ValueOrLambda::Value(field) => { - ValueOrLambda::Value(field.data_type().clone()) - } - ValueOrLambda::Lambda(_) => ValueOrLambda::Lambda(()), + .filter_map(|p| match p { + ValueOrLambda::Value(field) => Some(field.data_type().clone()), + ValueOrLambda::Lambda(_) => None, }) .collect::>(); let coerced_types = func.coerce_value_types(&arg_types)?; - std::iter::zip(current_fields, coerced_types) - .map(|(field, coerce_to)| match (field, coerce_to) { - (ValueOrLambda::Value(field), Some(coerce_to)) => { - Ok(ValueOrLambda::Value(Arc::new( - field.as_ref().clone().with_data_type(coerce_to), - ))) + if coerced_types.len() != arg_types.len() { + return plan_err!( + "{} coerce_value_types should have returned {} items but returned {}", + func.name(), + arg_types.len(), + coerced_types.len() + ); + } + + let mut coerced_types = coerced_types.into_iter(); + + Ok(current_fields + .iter() + .map(|current_field| match current_field { + ValueOrLambda::Value(field) => { + let data_type = coerced_types + .next() + .expect("coerced_types len should have been checked above"); + + ValueOrLambda::Value(Arc::new( + field.as_ref().clone().with_data_type(data_type), + )) } 
- (ValueOrLambda::Lambda(v), None) => { - Ok(ValueOrLambda::Lambda(v.clone())) + ValueOrLambda::Lambda(lambda) => { + ValueOrLambda::Lambda(lambda.clone()) } - (ValueOrLambda::Value(_), None) => plan_err!( - "{} coerce_values_types returned None for a value", - func.name() - ), - (ValueOrLambda::Lambda(_), Some(_)) => plan_err!( - "{} coerce_values_types returned Some for a lambda", - func.name() - ), }) - .collect() + .collect()) } LambdaTypeSignature::VariadicAny => Ok(current_fields.to_vec()), LambdaTypeSignature::Any(number) => { diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index add2c8a3ca39f..5140f14a99188 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -606,21 +606,21 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// See the [type coercion module](crate::type_coercion) /// documentation for more details on type coercion /// - /// For example, if your function requires a floating point arguments, but the user calls - /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]` - /// to ensure the argument is converted to `1::double` + /// For example, if your function requires a contiguous list argument, but the user calls + /// it like `my_func(c, v -> v+2)` (i.e. with `c` as a ListView), coerce_types can return `[DataType::List(..)]` + /// to ensure the argument is converted to a List /// /// # Parameters - /// * `arg_types`: The argument types of the arguments this function with + /// * `arg_types`: The argument types of the value arguments of this function, excluding lambdas /// /// # Return value /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call /// arguments to these specific types. 
- fn coerce_value_types( - &self, - _arg_types: &[ValueOrLambda], - ) -> Result>> { - not_impl_err!("Function {} does not implement coerce_types", self.name()) + fn coerce_value_types(&self, _arg_types: &[DataType]) -> Result> { + not_impl_err!( + "Function {} does not implement coerce_value_types", + self.name() + ) } /// Returns the documentation for this Lambda UDF. diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index ff0c45037289b..f7d299ceb6283 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -100,11 +100,16 @@ impl LambdaUDF for ArrayTransform { &self.signature } - fn coerce_value_types( - &self, - arg_types: &[ValueOrLambda], - ) -> Result>> { - let (list, _lambda) = value_lambda_pair(self.name(), arg_types)?; + fn coerce_value_types(&self, arg_types: &[DataType]) -> Result> { + let list = if arg_types.len() == 1 { + &arg_types[0] + } else { + return plan_err!( + "{} function requires 1 value arguments, got {}", + self.name(), + arg_types.len() + ); + }; let coerced = match list { DataType::List(_) @@ -121,7 +126,7 @@ impl LambdaUDF for ArrayTransform { } }; - Ok(vec![Some(coerced), None]) + Ok(vec![coerced]) } fn lambdas_parameters( diff --git a/datafusion/sqllogictest/test_files/lambda.slt b/datafusion/sqllogictest/test_files/lambda.slt index 3c79976e850a7..b44ab3203eb9e 100644 --- a/datafusion/sqllogictest/test_files/lambda.slt +++ b/datafusion/sqllogictest/test_files/lambda.slt @@ -245,13 +245,13 @@ select array_transform(arrow_cast(t.list, 'ListView(Int32)'), a -> a+1) from t; query error select array_transform(); ---- -DataFusion error: Execution error: array_transform function requires 2 arguments, got 0 +DataFusion error: Error during planning: array_transform function requires 1 value arguments, got 0 query error DataFusion error: Error during planning: array_transform expected a list as first argument, got 
Int64 select array_transform(1, v -> v*2); -query error DataFusion error: Error during planning: array_transform expects a value followed by a lambda, got Lambda\(\(\)\) and Value\(List\(Field \{ data_type: Int64, nullable: true \}\)\) +query error DataFusion error: Error during planning: array_transform expects a value followed by a lambda, got Lambda\(\(\)\) and Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ data_type: Int64, nullable: true \}\), nullable: true \}\) select array_transform(v -> v*2, [1, 2]); query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 2 From cd22c04eb97c69168f11a9d1931340ef7f5ac34a Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Tue, 17 Mar 2026 23:38:03 -0300 Subject: [PATCH 18/47] remove DOC.md --- DOC.md | 1162 -------------------------------------------------------- 1 file changed, 1162 deletions(-) delete mode 100644 DOC.md diff --git a/DOC.md b/DOC.md deleted file mode 100644 index a88e69a5689e3..0000000000000 --- a/DOC.md +++ /dev/null @@ -1,1162 +0,0 @@ -This PR adds support for lambdas with column capture and the `array_transform` function used to test the lambda implementation. Example usage: - -```sql -CREATE TABLE t as SELECT 2 as n; - -SELECT array_transform([2, 3], v -> v != t.n) from t; - -[false, true] - --- arbitrally nested lambdas are also supported -SELECT array_transform([[[2, 3]]], m -> array_transform(m, l -> array_transform(l, v -> v*2))); - -[[[4, 6]]] -``` - -Some comments on code snippets of this doc show what value each struct, variant or field would hold after planning the first example above. 
Some literals are simplified pseudo code - -3 new `Expr` variants are added, `LambdaFunction`, owing a new trait `LambdaUDF`, which is like a `ScalarFunction`/`ScalarUDFImpl` with support for lambdas, `Lambda`, for the lambda body and it's parameters names, and `LambdaVariable`, which is like `Column` but for lambdas parameters. The reasoning why not using `Column` instead is later on this doc. - -Their logical representations: - -```rust -enum Expr { - LambdaFunction(LambdaFunction), // array_transform([2, 3], v -> v != t.n) - Lambda(Lambda), // v -> v != t.n - LambdaVariable(LambdaVariable), // v, of the lambda body: v != t.n - ... -} - -// array_transform([2, 3], v -> v != t.n) -struct LambdaFunction { - pub func: Arc, // global instance of array_transform - pub args: Vec, // [Expr::ScalarValue([2, 3]), Expr::Lambda(v -> v != n)] -} - -// v -> v != t.n -struct Lambda { - pub params: Vec, // ["v"] - pub body: Box, // v != n -} - -// v, of the lambda body: v != t.n -struct LambdaVariable { - pub name: String, // "v" - pub field: Option, // Some(Field::new("", DataType::Int32, false)) - pub spans: Spans, -} - -``` - -The example would be planned into a tree like this: - -``` -LambdaFunctionExpression - name: array_transform - children: - 1. ListExpression [2,3] - 2. 
LambdaExpression - parameters: ["v"] - body: - ComparisonExpression (!=) - left: - LambdaVariableExpression("v", Some(Field::new("", Int32, false))) - right: - ColumnExpression("t.n") -``` - -The physical counterparts definition: - -```rust - -struct LambdaFunctionExpr { - fun: Arc, // global instance of array_transform - name: String, // "array_transform" - args: Vec>, // [LiteralExpr([2, 3], LambdaExpr("v -> v != t.n"))] - return_field: FieldRef, // Field::new("", DataType::new_list(DataType::Boolean, false), false) - config_options: Arc, -} - - -struct LambdaExpr { - params: Vec, // ["v"] - body: Arc, // v -> v != t.n -} - -struct LambdaVariable { - name: String, // "v", of the lambda body: v != t.n - field: FieldRef, // Field::new("", DataType::Int32, false) - value: Option, // reasoning later on -} -``` - -Note: For those who primarly wants to check if this lambda implementation supports their usecase and don't want to spend much time here, it's okay to skip most collapsed blocks, as those serve mostly to help code reviewers, with the exception of `LambdaUDF` and the `array_transform` implementation of `LambdaUDF` relevant methods, collapsed due to their size - -
Physical planning implementation is trivial: - -```rust -fn create_physical_expr( - e: &Expr, - input_dfschema: &DFSchema, - execution_props: &ExecutionProps, -) -> Result> { - let input_schema = input_dfschema.as_arrow(); - - match e { - ... - Expr::LambdaFunction(LambdaFunction { func, args}) => { - let physical_args = - create_physical_exprs(args, input_dfschema, execution_props)?; - - Ok(Arc::new(LambdaFunctionExpr::try_new( - Arc::clone(func), - physical_args, - input_schema, - config_options: ... // irrelevant - )?)) - } - Expr::Lambda(Lambda { params, body }) => Ok(Arc::new(LambdaExpr::new( - params.clone(), - create_physical_expr(body, input_dfschema, execution_props)?, - ))), - Expr::LambdaVariable(LambdaVariable { - name, - field, - spans: _, - }) => lambda_variable( - name, - Arc::clone(field), - ), - } -} -``` - -
-
- -The added `LambdaUDF` trait is almost a clone of `ScalarUDFImpl`, with the exception of: -1. `return_field_from_args` and `invoke_with_args`, where now `args.args` is a list of enums with two variants: `Value` or `Lambda` instead of a list of values -2. the addition of `lambdas_parameters`, which return a `Field` for each parameter supported for every lambda argument based on the `Field` of the non lambda arguments -3. the removal of `return_field` and the deprecated ones `is_nullable` and `display_name`. - -
LambdaUDF - -```rust - -trait LambdaUDF { - /// Returns a list of the same size as args where each value is the logic below applied to value at the correspondent position in args: - /// - /// If it's a value, return None - /// If it's a lambda, return the list of all parameters that that lambda supports - /// based on the Field of the non-lambda arguments - /// - /// Example for array_transform: - /// - /// `array_transform([2, 8], v -> v > 4)` - /// - /// let lambdas_parameters = array_transform.lambdas_parameters(&[ - /// ValueOrLambdaParameter::Value(Field::new("", DataType::new_list(DataType::Int32, false)))]), // the Field associated with the literal `[2, 8]` - /// ValueOrLambdaParameter::Lambda, // A lambda - /// ]?; - /// - /// assert_eq!( - /// lambdas_parameters, - /// vec![ - /// None, // it's a value, return None - /// // it's a lambda, return it's supported parameters, regardless of how many are actually used - /// Some(vec![ - /// Field::new("", DataType::Int32, false), // the value being transformed, - /// Field::new("", DataType::Int32, false), // the 1-based index being transformed, not used on the example above, but implementations doesn't need to care about it - /// ]) - /// ] - /// ) - fn lambdas_parameters( - &self, - args: &[ValueOrLambdaParameter], - ) -> Result>>>; - fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result; - fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result; - // ... 
omitted methods that are similar in ScalarUDFImpl -} - -pub enum ValueOrLambdaParameter { - /// A columnar value with the given field - Value(FieldRef), - /// A lambda - Lambda, -} - -/// Information about arguments passed to the function -/// -/// This structure contains metadata about how the function was called -/// such as the type of the arguments, any scalar arguments and if the -/// arguments can (ever) be null -/// -/// See [`LambdaUDF::return_field_from_args`] for more information -pub struct LambdaReturnFieldArgs<'a> { - /// The data types of the arguments to the function - /// - /// If argument `i` to the function is a lambda, it will be the field returned by the - /// lambda when executed with the arguments returned from `LambdaUDF::lambdas_parameters` - /// - /// For example, with `array_transform([1], v -> v == 5)` - /// this field will be `[ - // ValueOrLambdaField::Value(Field::new("", DataType::new_list(DataType::Int32, false), false)), - // ValueOrLambdaField::Lambda(Field::new("", DataType::Boolean, false)) - // ]` - pub arg_fields: &'a [ValueOrLambdaField], - /// Is argument `i` to the function a scalar (constant)? - /// - /// If the argument `i` is not a scalar, it will be None - /// - /// For example, if a function is called like `my_function(column_a, 5)` - /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` - pub scalar_arguments: &'a [Option<&'a ScalarValue>], -} - -/// A tagged FieldRef indicating whether it correspond the field of a value or the field of the output of a lambda argument -pub enum ValueOrLambdaField { - /// The FieldRef of a ColumnarValue argument - Value(FieldRef), - /// The return FieldRef of the lambda body when evaluated with the parameters from LambdaUDF::lambda_parameters - Lambda(FieldRef), -} - -/// Arguments passed to [`LambdaUDF::invoke_with_args`] when invoking a -/// lambda function. 
-pub struct LambdaFunctionArgs { - /// The evaluated arguments to the function - pub args: Vec, - /// Field associated with each arg, if it exists - pub arg_fields: Vec, - /// The number of rows in record batch being evaluated - pub number_rows: usize, - /// The return field of the lambda function returned (from `return_type` - /// or `return_field_from_args`) when creating the physical expression - /// from the logical expression - pub return_field: FieldRef, - /// The config options at execution time - pub config_options: Arc, -} - -/// A lambda argument to a LambdaFunction -pub struct LambdaFunctionLambdaArg { - /// The parameters defined in this lambda - /// - /// For example, for `array_transform([2], v -> -v)`, - /// this will be `vec![Field::new("v", DataType::Int32, true)]` - pub params: Vec, - /// The body of the lambda - /// - /// For example, for `array_transform([2], v -> -v)`, - /// this will be the physical expression of `-v` - pub body: Arc, - /// A RecordBatch containing at least the captured columns inside this lambda body, if any - /// Note that it may contain additional, non-specified columns, - /// but that's implementation detail and should not be relied upon - /// - /// For example, for `array_transform([2], v -> v + t.a + t.b)`, - /// this will be a `RecordBatch` with at least two columns, `t.a` and `t.b` - pub captures: Option, -} - -// An argument to a LambdaUDF -pub enum ValueOrLambda { - Value(ColumnarValue), - Lambda(LambdaFunctionLambdaArg), -} -``` - - -
- -
array_transform lambdas_parameters implementation - -```rust -impl LambdaUDF for ArrayTransform { - fn lambdas_parameters( - &self, - args: &[ValueOrLambdaParameter], - ) -> Result>>> { - // list is the field of [2, 3]: Field::new("", DataType::new_list(DataType::Int32, false), false) - let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda] = args - else { - return exec_err!( - "{} expects a value follewed by a lambda, got {:?}", - self.name(), - args - ); - }; - - // the field of [2, 3] inner values: Field::new("", DataType::Int32, false) - let (field, index_type) = match list.data_type() { - DataType::List(field) => (field, DataType::Int32), - DataType::LargeList(field) => (field, DataType::Int64), - DataType::FixedSizeList(field, _) => (field, DataType::Int32), - _ => return exec_err!("expected list, got {list}"), - }; - - // we don't need to omit the index in the case the lambda don't specify, e.g. array_transform([], v -> v*2), - // nor check whether the lambda contains more than two parameters, e.g. array_transform([], (v, i, j) -> v+i+j), - // as datafusion will do that for us - let value = Field::new("", field.data_type().clone(), field.is_nullable()) - .with_metadata(field.metadata().clone()); - let index = Field::new("", index_type, false); - - Ok(vec![None, Some(vec![value, index])]) - } -} -``` - -
- -
array_transform return_field_from_args implementation - -```rust -impl LambdaUDF for ArrayTransform { - fn return_field_from_args( - &self, - args: datafusion_expr::LambdaReturnFieldArgs, - ) -> Result> { - // [ - // Field::new("", DataType::new_list(DataType::Int32, false), false), - // Field::new("", DataType::Boolean, false), - // ] - let [ValueOrLambdaField::Value(list), ValueOrLambdaField::Lambda(lambda)] = - take_function_args(self.name(), args.arg_fields)? - else { - return exec_err!( - "{} expects a value follewed by a lambda, got {:?}", - self.name(), - args - ); - }; - - // lambda is the return_field of the lambda body - // when evaluated with the parameters from lambdas_parameters - let field = Arc::new(Field::new( - Field::LIST_FIELD_DEFAULT_NAME, - lambda.data_type().clone(), - lambda.is_nullable(), - )); - - let return_type = match list.data_type() { - DataType::List(_) => DataType::List(field), - DataType::LargeList(_) => DataType::LargeList(field), - DataType::FixedSizeList(_, size) => DataType::FixedSizeList(field, *size), - other => plan_err!("expected list, got {other}"), - }; - - Ok(Arc::new(Field::new("", return_type, list.is_nullable()))) - } -} -``` - -
- -
array_transform invoke_with_args implementation - - -```rust -impl LambdaUDF for ArrayTransform { - fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result { - let [list_value, lambda] = take_function_args(self.name(), &args.args)?; - - // list = [2, 3] - // lambda = LambdaFunctionLambdaArg { - // params: vec![Field::new("v", DataType::Int32, false)], - // body: PhysicalExpr("v != t.n"),// the physical expression of the lambda *body*, and not the lambda itself: this is not a LambdaExpr. - // captures: Some(record_batch!("t.n", Int32, [2])) - // } - let (ValueOrLambda::Value(list_value), ValueOrLambda::Lambda(lambda)) = - (list_value, lambda) - else { - return exec_err!( - "{} expects a value followed by a lambda, got {} and {}", - self.name(), - list_value, - lambda, - ); - }; - - let list_array = list_value.to_array(args.number_rows)?; - let list_values = match list_array.data_type() { - DataType::List(_) => list_array.as_list::().values(), - DataType::LargeList(_) => list_array.as_list::().values(), - DataType::FixedSizeList(_, _) => list_array.as_fixed_size_list().values(), - other => exec_err!("expected list, got {other}") - } - - // if any column got captured, we need to adjust it to the values arrays, - // duplicating values of list with mulitple values and removing values of empty lists - // list_indices is not cheap so is important to avoid it when no column is captured - let adjusted_captures = lambda - .captures - .as_ref() - //list_indices return the row_number for each sublist element: [[1, 2], [3], [4]] => [0,0,1,2], not included here - .map(|captures| take_record_batch(captures, &list_indices(&list_array)?)) - .transpose()? 
- .unwrap_or_else(|| { - RecordBatch::try_new_with_options( - Arc::new(Schema::empty()), - vec![], - &RecordBatchOptions::new().with_row_count(Some(list_values.len())), - ) - .unwrap() - }); - - // by using closures, bind_lambda_variables can evaluate only the needed ones avoiding unnecessary computations - let values_param = || Ok(Arc::clone(list_values)); - //elements_indices return the index of each element within its sublist: [[5, 3], [7, 1, 1]] => [1, 2, 1, 2, 3], not included here - let indices_param = || elements_indices(&list_array); - - let binded_body = bind_lambda_variables( - Arc::clone(&lambda.body), - &lambda.params, - &[&values_param, &indices_param], - )?; - - // call the transforming expression with the record batch - let transformed_values = binded_body - .evaluate(&adjusted_captures)? - .into_array(list_values.len())?; - - let field = match args.return_field.data_type() { - DataType::List(field) - | DataType::LargeList(field) - | DataType::FixedSizeList(field, _) => Arc::clone(field), - _ => { - return exec_err!( - "{} expected ScalarFunctionArgs.return_field to be a list, got {}", - self.name(), - args.return_field - ) - } - }; - - let transformed_list = match list_array.data_type() { - DataType::List(_) => { - let list = list_array.as_list(); - - Arc::new(ListArray::new( - field, - list.offsets().clone(), - transformed_values, - list.nulls().cloned(), - )) as ArrayRef - } - DataType::LargeList(_) => { - let large_list = list_array.as_list(); - - Arc::new(LargeListArray::new( - field, - large_list.offsets().clone(), - transformed_values, - large_list.nulls().cloned(), - )) - } - DataType::FixedSizeList(_, value_length) => { - Arc::new(FixedSizeListArray::new( - field, - *value_length, - transformed_values, - list_array.as_fixed_size_list().nulls().cloned(), - )) - } - other => exec_err!("expected list, got {other}")?, - }; - - Ok(ColumnarValue::Array(transformed_list)) - } -} -``` - -
- -
How relevant LambdaUDF methods would be called and what they would return during planning and evaluation of the example - - -```rust -// this is called at sql planning -let lambdas_parameters = lambda_udf.lambdas_parameters(&[ - ValueOrLambdaParameter::Value(Field::new("", DataType::new_list(DataType::Int32, false), false)), // the Field of the [2, 3] literal - ValueOrLambdaParameter::Lambda, // A unspecified lambda. On the example, v -> v != t.n -])?; - -assert_eq!( - lambdas_parameters, - vec![ - // the [2, 3] argument, not a lambda so no parameters - None, - // the parameters that *can* be declared on the lambda, and not only - // those actually declared: the implementation doesn't need to care - // about it - Some(vec![ - Field::new("", DataType::Int32, false), // the list inner value - Field::new("", DataType::Int32, false), // the 1-based index of the element being transformed - ])] -); - - - -// this is called every time ExprSchemable is called on a LambdaFunction -let return_field = array_transform.return_field_from_args(&LambdaReturnFieldArgs { - arg_fields: &[ - ValueOrLambdaField::Value(Field::new("", DataType::new_list(DataType::Int32, false), false)), - ValueOrLambdaField::Lambda(Field::new("", DataType::Boolean, false)), // the return_field of the expression "v != t.n" when "v" is of the type returned in lambdas_parameters - ], - scalar_arguments // irrelevant -})?; - -assert_eq!(return_field, Field::new("", DataType::new_list(DataType::Boolean, false), false)); - - - -let value = array_transform.evaluate(&LambdaFunctionArgs { - args: vec![ - ValueOrLambda::Value(List([2, 3])), - ValueOrLambda::Lambda(LambdaFunctionLambdaArg { - params: vec![Field::new("v", DataType::Int32, false)], - body: PhysicalExpr("v != t.n"),// the physical expression of the lambda *body*, and not the lambda itself: this is not a LambdaExpr. 
- captures: Some(record_batch!("t.n", Int32, [2])) - }), - ], - arg_fields, // same as above - number_rows: 1, - return_field, // same as above - config_options, // irrelevant -})?; - -assert_eq!(value, BooleanArray::from([false, true])) -``` - -
-
-
- -A pair LambdaUDF/LambdaUDFImpl like ScalarFunction was not used because those exist only [to maintain backwards compatibility with the older API](https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html#api-note) #8045 - -LambdaFunction invocation: - -Instead of evaluating all its arguments as ScalarFunction, LambdaFunction does the following: - -1. If it's a non lambda argument, evaluate as usual, and provide the resulting `ColumnarValue` to `LambdaUDF::evaluate` as a `ValueOrLambda::Value` -2. If it's a lambda, construct a `LambdaFunctionLambdaArg` containing the lambda body physical expression and a record batch containing any captured columns as a `ValueOrLambda::Lambda` and provide it to `LambdaUDF::evaluate`. To avoid costly copies of uncaptured columns, we swap them with a `NullArray` while keeping the number of columns on the batch the same so captured columns indices are kept stable across the whole tree. The recent #18329 instead projects-out uncaptured columns and rewrites the expr adjusting columns indexes. If that is preferable we can generalize that implementation and use it here too. - -
LambdaFunction evalution - -```rust - -impl PhysicalExpr for LambdaFunctionExpr { - fn evaluate(&self, batch: &RecordBatch) -> Result { - let args = self.args - .map(|arg| { - match arg.as_any().downcast_ref::() { - Some(lambda) => { - // helper method that returns the indices of the captured columns. In the example, the only column available (index 0) is captured, so this would be HashSet(0) - let captures = lambda.captures(); - - let captures = if !captures.is_empty() { - let (fields, columns): (Vec<_>, _) = std::iter::zip( - batch.schema_ref().fields(), - batch.columns(), - ) - .enumerate() - .map(|(column_index, (field, column))| { - if captures.contains(&column_index) { - (Arc::clone(field), Arc::clone(column)) - } else { - ( - Arc::new(Field::new( - field.name(), - DataType::Null, - false, - )), - Arc::new(NullArray::new(column.len())) as _, - ) - } - }) - .unzip(); - - let schema = Arc::new(Schema::new(fields)); - - Some(RecordBatch::try_new(schema, columns)?) - } else { - None - }; - - Ok(ValueOrLambda::Lambda(LambdaFunctionLambdaArg { - params, // irrelevant, - body: Arc::clone(lambda.body()), // use the lambda body and not the lambda itself - captures, - })) - } - None => Ok(ValueOrLambda::Value(arg.evaluate(batch)?)), - } - }) - .collect::>>()?; - - // evaluate the function - let output = self.fun.invoke_with_args(LambdaFunctionArgs { - args, - arg_fields, // irrelevant - number_rows: batch.num_rows(), - return_field: Arc::clone(&self.return_field), - config_options: Arc::clone(&self.config_options), - })?; - - Ok(output) - } -} - -``` - -
-
- -Why `LambdaVariable` and not `Column`: - -Existing tree traversals that operate on columns would break if some column nodes referred to a lambda parameter and not a real column. In the example query, projection pushdown would try to push the lambda parameter "v", which won't exist in table "t". - -Example of code of another traversal that would break: - -```rust -fn minimize_join_filter(expr: Arc, ...) -> JoinFilter { - let mut used_columns = HashSet::new(); - expr.apply(|expr| { - if let Some(col) = expr.as_any().downcast_ref::() { - // if this is a lambda column, this function will break - used_columns.insert(col.index()); - } - Ok(TreeNodeRecursion::Continue) - }); - ... -} -``` - -Furthermore, the implementation of `ExprSchemable` and `PhysicalExpr::return_field` for `Column` expects that the schema it receives as an argument contains an entry for its name, which is not the case for lambda parameters. - -By including a `FieldRef` on `LambdaVariable` that should be resolved either during construction time, as in the sql planner, or later by an `AnalyzerRule`, `ExprSchemable` and `PhysicalExpr::return_field` simply return its own Field: - -
LambdaVariable ExprSchemable and PhysicalExpr::return_field implementation - -```rust -impl ExprSchemable for Expr { - fn to_field( - &self, - schema: &dyn ExprSchema, - ) -> Result<(Option, Arc)> { - let (relation, schema_name) = self.qualified_name(); - let field = match self { - Expr::LambdaVariable(l) => Ok(Arc::clone(&l.field.ok_or_else(|| plan_err!("Unresolved LambdaVariable {}", l.name)))), - ... - }?; - - Ok(( - relation, - Arc::new(field.as_ref().clone().with_name(schema_name)), - )) - } - ... -} - -impl PhysicalExpr for LambdaVariable { - fn return_field(&self, _input_schema: &Schema) -> Result { - Ok(Arc::clone(&self.field)) - } - ... -} -``` - -
-
- -For reference, [Spark](https://github.com/apache/spark/blob/8b68a172d34d2ed9bd0a2deefcae1840a78143b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L77) and [Substrait](https://substrait.io/expressions/lambda_expressions/#parameter-references) also use a specialized node instead of a regular column - -There are also discussions on making every expr own its type: #18845, #12604 - -
Possible fixes discarded due to complexity, requiring downstream changes and implementation size: - -1. Add a new set of TreeNode methods that provides the set of lambdas parameters names seen during the traversal, so column nodes can be tested if they refer to a regular column or to a lambda parameter. Any downstream user that wants to support lambdas would need use those methods instead of the existing ones. This also would add 1k+ lines to the PR. - -```rust -impl Expr { - pub fn transform_with_lambdas_params< - F: FnMut(Self, &HashSet) -> Result>, - >( - self, - mut f: F, - ) -> Result> {} -} -``` - -How minimize_join_filter would looks like: - - -```rust -fn minimize_join_filter(expr: Arc, ...) -> JoinFilter { - let mut used_columns = HashSet::new(); - expr.apply_with_lambdas_params(|expr, lambdas_params| { - if let Some(col) = expr.as_any().downcast_ref::() { - // dont include lambdas parameters - if !lambdas_params.contains(col.name()) { - used_columns.insert(col.index()); - } - } - Ok(TreeNodeRecursion::Continue) - }) - ... -} -``` - -2. Add a flag to the Column node indicating if it refers to a lambda parameter. Still requires checking for it on existing tree traversals that works on Columns (30+) and also downstream. - -```rust -//logical -struct Column { - pub relation: Option, - pub name: String, - pub spans: Spans, - pub is_lambda_parameter: bool, -} - -//physical -struct Column { - name: String, - index: usize, - is_lambda_parameter: bool, -} -``` - - -How minimize_join_filter would look like: - -```rust -fn minimize_join_filter(expr: Arc, ...) -> JoinFilter { - let mut used_columns = HashSet::new(); - expr.apply(|expr| { - if let Some(col) = expr.as_any().downcast_ref::() { - // dont include lambdas parameters - if !col.is_lambda_parameter { - used_columns.insert(col.index()); - } - } - Ok(TreeNodeRecursion::Continue) - }) - ... -} -``` - - -1. 
Add a new set of TreeNode methods that provides a schema that includes the lambdas parameters for the scope of the node being visited/transformed: - -```rust -impl Expr { - pub fn transform_with_schema< - F: FnMut(Self, &DFSchema) -> Result>, - >( - self, - schema: &DFSchema, - f: F, - ) -> Result> { ... } - ... other methods -} -``` - -For any given LambdaFunction found during the traversal, a new schema is created for each lambda argument that contains it's parameter, returned from LambdaUDF::lambdas_parameters -How it would look like: - -```rust - -pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> { - let mut has_placeholder = false; - // Provide the schema as the first argument. - // Transforming closure receive an adjusted_schema as argument - self.transform_with_schema(schema, |mut expr, adjusted_schema| { - match &mut expr { - // Default to assuming the arguments are the same type - Expr::BinaryExpr(BinaryExpr { left, op: _, right }) => { - // use adjusted_schema and not schema. Those expressions may contain - // columns referring to a lambda parameter, which Field would only be - // available in adjusted_schema and not in schema - rewrite_placeholder(left.as_mut(), right.as_ref(), adjusted_schema)?; - rewrite_placeholder(right.as_mut(), left.as_ref(), adjusted_schema)?; - } - .... - -``` - -2. Make available trought LogicalPlan and ExecutionPlan nodes a schema that includes all lambdas parameters from all expressions owned by the node, and use this schema for tree traversals. 
For nodes which won't own any expression, the regular schema can be returned - - -```rust -impl LogicalPlan { - fn lambda_extended_schema(&self) -> &DFSchema; -} - -trait ExecutionPlan { - fn lambda_extended_schema(&self) -> &DFSchema; -} - -//usage -impl LogicalPlan { - pub fn replace_params_with_values( - self, - param_values: &ParamValues, - ) -> Result { - self.transform_up_with_subqueries(|plan| { - // use plan.lambda_extended_schema() containing lambdas parameters - // instead of plan.schema() which wont - let lambda_extended_schema = Arc::clone(plan.lambda_extended_schema()); - let name_preserver = NamePreserver::new(&plan); - plan.map_expressions(|e| { - // if this expression is child of lambda and contain columns referring it's parameters - // the lambda_extended_schema already contain them - let (e, has_placeholder) = e.infer_placeholder_types(&lambda_extended_schema)?; - .... - -``` -
-
- -`LambdaVariable` evaluation, current implementation: - -The physical `LambdaVariable` contains an optional `ColumnarValue` that must be bound for each batch before evaluation with the helper function `bind_lambda_variables`, which rewrites the whole lambda body, binding any variable of the tree. - -
LambdaVariable::evaluate - -```rust -impl PhysicalExpr for LambdaVariable { - fn evaluate(&self, _batch: &RecordBatch) -> Result { - self.value.clone().ok_or_else(|| exec_datafusion_err!("Physical LambdaVariable {} unbinded value", self.name)) - } -} -``` - -
-
- -Unbinded: -``` -LambdaExpression - parameters: ["v"] - body: - ComparisonExpression(!=) - left: - LambdaVariableExpression("v", Field::new("", Int32, false), None) - right: - ColumnExpression("n") -``` - -After binding: - -``` -LambdaExpression - parameters: ["v"] - body: - ComparisonExpression(!=) - left: - LambdaVariableExpression("v", Field::new("", Int32, false), Some([2, 3])) - right: - ColumnExpression("n") -``` - -Alternative: - -Make the `LambdaVariable` evaluate it's value from the batch passed to `PhysicalExpr::evaluate` as a regular column. For that, instead of binding the body, the `LambdaUDF` implementation would merge the captured batch of a lambda with the values of it's parameters. So that it happen via an index as a regular column, the schema used plan to physical `LambdaVariable` must contain the lambda parameters. This would be the only place during planning that a schema would contain those parameters. Otherwise it only can get the value from the batch via name instead of index - -1. Add a index to LambdaVariable, similar to Column, and remove the optional value. - -```rust -struct LambdaVariable { - name: String, // "v", of the lambda body: v != t.n - field: FieldRef, // Field::new("", DataType::Int32, false) - index: usize, // 1 -} -``` - -2. Insert the lambda parameters only at the Schema used to do the physical planning, to compute the index of a LambdaVariable - -
how physical planning would look like - -```rust -fn create_physical_expr( - e: &Expr, - input_dfschema: &DFSchema, - execution_props: &ExecutionProps, -) -> Result> { - let input_schema = input_dfschema.as_arrow(); - - match e { - ... - Expr::LambdaFunction(LambdaFunction { func, args}) => { - let args_metadata = args.iter() - .map(|arg| if arg.is::() { - Ok(ValueOrLambdaParameter::Lambda) - } else { - Ok(ValueOrLambdaParameter::Value(arg.to_field(input_dfschema)?)) - }) - .collect()?; - - let lambdas_parameters = func.lambdas_parameters(&args_metadata)?; - - let physical_args = std::iter::zip(args, lambdas_parameters) - .map(|(arg, lambda_parameters)| { - match (arg.downcast_ref::(), lambda_parameters) { - (Some(lambda), Some(lambda_parameters)) => { - let extended_dfschema = merge_schema_and_parameters(input_dfschame, lambda_parameters)?; - - create_physical_expr(body, extended_dfschema, execution_props) - } - (None, None) => create_physical_expr(arg, input_dfschema, execution_props), - (Some(_), None) => plan_err!("lambdas_parameters returned None for a lambda") - (None, Some(_)) => plan_err!("lambdas_parameters returned Some for a non lambda") - } - }) - .collect()?; - - Ok(Arc::new(LambdaFunctionExpr::try_new( - Arc::clone(func), - physical_args, - input_schema, - config_options: ... // irrelevant - )?)) - } - } -} -``` - -
-
- -3. Insert the lambda parameters values into the RecordBatch during the evaluation phase: the LambdaUDF, instead of binding the lambda body variables, inserts it's parameters on the captured RecordBatch it receives on LambdaFunctionLambdaArg. - -How ArrayTransform::invoke_with_args would look like: - -```rust - ... - let values_param = || Ok(Arc::clone(list_values)); - let indices_param = || elements_indices(&list_array); - - let merged_batch = merge_captures_with_params( - adjusted_captures, - &lambda.params, - &[&values_param, &indices_param], - )?; - - // call the transforming expression with the record batch - let transformed_values = lambda.body - .evaluate(&merged_batch)? - .into_array(list_values.len())?; - - ... -``` - -
- -Why is `LambdaVariable` `Field` is an `Option`? - -So expr_api users can construct a LambdaVariable just by using it's name, without having to set it's field. An `AnalyzerRule` will then set the `LambdaVariable` field based on the returned values from `LambdaUDF::lambdas_parameters` of any `LambdaFunction` it finds while traversing down a expr tree. We may include that rule on the default rules list for when the plan/expression tree is transformed by another rule in a way that changes the types of non lambda arguments of a lambda function, as it may change the types of it's lambda parameters, which would render `LambdaVariable` field's out of sync, as the rule would fix it. Or to not increase planning time we don't include it by default and instruct `expr_api` users to add it manually if needed - - - -```rust -array_transform( - col("my_array"), - lambda( - vec!["current_value"], - 2 * lambda_variable("current_value") - ) -) - -//instead of - -array_transform( - col("my_array"), - lambda( - vec!["current_value"], - 2 * lambda_variable("current_value", Field::new("", DataType::Int32, false)) - ) -) -``` - - -Why set `LambdaVariable` field during sql planning if it's optional and can be set later via an `AnalyzerRule`? - -Some parts of sql planning checks the type/nullability of the already planned children expression of the expr it's planning, and would error if doing so on a unresolved `LambdaVariable` -Take as example this expression: `array_transform([[0, 1]], v -> v[1])`. `FieldAccess` `v[1]` planning is handled by the `ExprPlanner` `FieldAccessPlanner`, which checks the datatype of `v`, a lambda variable, which `ExprSchemable` implementation depends on it's field being resolved, and not on the `PlannerContext` schema, requiring sql planner to plan `LambdaVariables` with a resolved field - - -
FieldAccessPlanner - -```rust -pub struct FieldAccessPlanner; - -impl ExprPlanner for FieldAccessPlanner { - fn plan_field_access( - &self, - expr: RawFieldAccessExpr, // "v[1]" - schema: &DFSchema, - ) -> Result> { - // { "v", "[1]" } - let RawFieldAccessExpr { expr, field_access } = expr; - - match field_access { - ... - // expr[idx] ==> array_element(expr, idx) - GetFieldAccess::ListIndex { key: index } => { - match expr { - ... - // ExprSchemable::get_type called - _ if matches!(expr.get_type(schema)?, DataType::Map(_, _)) => { - Ok(PlannerResult::Planned(Expr::ScalarFunction( - ScalarFunction::new_udf( - get_field_inner(), - vec![expr, *index], - ), - ))) - } - } - } - } - } -} -``` - -
-
- - Therefore we can't plan all arguments on a single pass, and must first plan the non-lambda arguments, collect their types and nullability, pass them to `LambdaUDF::lambdas_parameters`, which will derive the type of its lambda parameters based on the type of its non-lambda argument, and return it to the planner, which, for each unplanned lambda argument, will create a new `PlannerContext` via `with_lambda_parameters`, which contains a mapping of lambda parameter names to their types. Then, when planning an `ast::Identifier`, it first checks whether a lambda parameter with the given name exists, and if so, plans it into an `Expr::LambdaVariable` with a resolved field, otherwise plans it into a regular `Expr::Column`. - - - -
sql planning - - -```rust -struct PlannerContext { - /// The parameters of all lambdas seen so far - lambdas_parameters: HashMap, - // ... omitted fields -} - -impl PlannerContext { - pub fn with_lambda_parameters( - mut self, - arguments: impl IntoIterator, - ) -> Self { - self.lambdas_parameters - .extend(arguments.into_iter().map(|f| (f.name().clone(), f))); - - self - } -} - -// copied from sqlparser -struct LambdaFunction { - pub params: OneOrManyWithParens, // One("v") - pub body: Box, // v != t.n -} - -// copied from sqlparser -enum OneOrManyWithParens { - One(T), // "v" - Many(Vec), -} - -/// the planning would happens as the following: - -enum ExprOrLambda { - Expr(Expr), // planned [2, 3] - Lambda(ast::LambdaFunction), // unplanned v -> v != t.n -} - -impl SqlToRel { - // example function, won't exist - fn plan_array_transform(&self, array_transform: Arc, args: Vec, schema: &DFSchema, planner_context: &mut PlannerContext) -> Result { - let args = args.into_iter() - .map(|arg| match arg { - ast::Expr::LambdaFunction(l) => Ok(ExprOrLambda::Lambda(l)),//skip planning until we plan non lambda args - arg => Ok(ExprOrLambda::Expr( - self.sql_fn_arg_to_logical_expr_with_name( - arg, - schema, - planner_context, - )?, - )) - }) - .collect::>>()?; - - let args_metadata = args.iter() - .map(|arg| match arg { - Expr(expr) => Ok(ValueOrLambda::Value(expr.to_field(schema)?)), - Lambda(_) => Ok(ValueOrLambda::Lambda), - }) - .collect::>>()?; - - let lambdas_parameters = array_transform.lambdas_parameters(&args_metadata)?; - - let args = std::iter::zip(args, lambdas_parameters) - .map(|(arg, lambdas_parameters)| match (arg, lambdas_parameters) { - (ExprOrLambda::Expr(planned_expr), None) => Ok(planned_expr), - (ExprOrLambda::Lambda(unplanned_lambda), Some(lambda_parameters)) => { - let params = - unplanned_lambda.params - .iter() - .map(|p| p.value.clone()) - .collect(); - - let lambda_parameters = lambda_params - .into_iter() - .zip(¶ms) - .map(|(field, name)| 
Arc::new(field.with_name(name))); - - let mut planner_context = planner_context - .clone() - .with_lambda_parameters(lambda_parameters); - - Ok(( - Expr::Lambda(Lambda { - params, - body: Box::new(self.sql_expr_to_logical_expr( - *lambda.body, - schema, - &mut planner_context, - )?), - }), - None, - )) - } - (ExprOrLambda::Expr(planned_expr), Some(lambda_parameters)) => plan_err!("lambdas_parameters returned Some for a value"), - (ExprOrLambda::Lambda(unplanned_lambda), None) => plan_err!("lambdas_parameters returned None for a lambda"), - }) - .collect::>>()?; - - Ok(Expr::LambdaFunction(LambdaFunction { - func: array_transform, - args, - })) - } - - fn sql_identifier_to_expr( - &self, - id: ast::Ident, - schema: &DFSchema, - planner_context: &mut PlannerContext, - ) -> Result { - // simplified implementation - if let Some(field) = planner_context.lambdas_parameters.get(id) { - Ok(Expr::LambdaVariable(LambdaVariable { - name: id, // "v" - field, // Field::new("", DataType::Int32, false) - })) - } else { - Ok(Expr::Column(Column::new(id))) - } - } -} - -``` - -
-
From 0188d4044a9f2e9a65d30df058713cc174658cc1 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Tue, 17 Mar 2026 23:38:33 -0300 Subject: [PATCH 19/47] add physical lambda function comments --- datafusion/physical-expr/src/lambda_function.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs index 197342a844cd3..b5c7cadf77e1c 100644 --- a/datafusion/physical-expr/src/lambda_function.rs +++ b/datafusion/physical-expr/src/lambda_function.rs @@ -304,6 +304,9 @@ impl PhysicalExpr for LambdaFunctionExpr { { (Arc::clone(field), Arc::clone(column)) } else { + // To avoid costly copies of uncaptured columns, we swap them with a NullArray + // while keeping the number of columns on the batch the same + // so captured columns indices are kept stable across the whole tree. ( Arc::new(Field::new( field.name(), From 6f2c92bd0858aa57cf5034a830a8df23829ac9b1 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Wed, 18 Mar 2026 01:11:25 -0300 Subject: [PATCH 20/47] remove secondary lambda features to be added later --- datafusion/common/src/lib.rs | 1 - datafusion/common/src/utils/mod.rs | 386 +----------------- datafusion/expr/src/expr.rs | 298 +------------- datafusion/expr/src/expr_fn.rs | 12 +- datafusion/expr/src/expr_schema.rs | 19 +- datafusion/expr/src/logical_plan/plan.rs | 13 +- datafusion/expr/src/udlf.rs | 195 +-------- .../functions-nested/src/array_transform.rs | 50 +-- .../simplify_expressions/expr_simplifier.rs | 76 ++-- .../physical-expr/src/expressions/lambda.rs | 187 +-------- .../physical-expr/src/lambda_function.rs | 60 +-- datafusion/physical-expr/src/planner.rs | 10 +- datafusion/sql/src/expr/function.rs | 103 ++--- datafusion/sql/src/expr/identifier.rs | 2 +- datafusion/sql/src/unparser/expr.rs | 11 +- datafusion/sqllogictest/test_files/lambda.slt | 97 +---- 16 files changed, 118 insertions(+), 
1402 deletions(-) diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 4dcb6d6cf2ec2..fdd04f752455e 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -114,7 +114,6 @@ pub type HashMap = hashbrown::HashMap; pub type HashSet = hashbrown::HashSet; pub mod hash_map { pub use hashbrown::hash_map::Entry; - pub use hashbrown::hash_map::EntryRef; } pub mod hash_set { pub use hashbrown::hash_set::Entry; diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index a5cf26fbff58e..ee08b70d1597c 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -30,18 +30,14 @@ use arrow::array::{ Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, cast::AsArray, }; -use arrow::array::{ArrowPrimitiveType, GenericListArray, Int32Array, PrimitiveArray}; use arrow::buffer::OffsetBuffer; use arrow::compute::{SortColumn, SortOptions, partition}; -use arrow::datatypes::{ - ArrowNativeType, DataType, Field, Int32Type, Int64Type, SchemaRef, -}; +use arrow::datatypes::{DataType, Field, SchemaRef}; #[cfg(feature = "sql")] use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; use std::borrow::{Borrow, Cow}; use std::cmp::{Ordering, min}; use std::collections::HashSet; -use std::iter::repeat_n; use std::num::NonZero; use std::ops::Range; use std::sync::Arc; @@ -975,74 +971,14 @@ pub fn take_function_args( }) } -/// [0, 2, 2, 5, 6] -> [0, 0, 2, 2, 2, 3] -fn list_array_values_row_number( - offsets: &OffsetBuffer, -) -> PrimitiveArray { - let mut rows_number = Vec::with_capacity( - offsets[offsets.len() - 1].to_usize().unwrap() - offsets[0].to_usize().unwrap(), - ); - - for (i, len) in offsets.lengths().enumerate() { - rows_number.extend(repeat_n(T::Native::usize_as(i), len)); - } - - PrimitiveArray::new(rows_number.into(), None) -} - -/// [0, 2, 2, 5, 6] -> [1, 2, 1, 2, 3, 1] -fn list_array_values_index( - offsets: &OffsetBuffer, -) -> 
PrimitiveArray { - let mut indices = Vec::with_capacity( - offsets[offsets.len() - 1].to_usize().unwrap() - offsets[0].to_usize().unwrap(), - ); - - for len in offsets.lengths() { - indices.extend((1..1 + len).map(T::Native::usize_as)); - } - - PrimitiveArray::new(indices.into(), None) -} - -/// (2, 3) -> [0, 0, 1, 1, 2, 2] -fn fsl_values_row_number(list_size: i32, array_len: usize) -> Result { - let list_size = list_size.to_usize().ok_or_else(|| { - _exec_datafusion_err!("fsl_values_index: invalid list_size {list_size}") - })?; - - let mut rows_number = Vec::with_capacity(list_size * array_len); - - for i in 0..array_len { - rows_number.extend(repeat_n(i as i32, list_size)); - } - - Ok(PrimitiveArray::new(rows_number.into(), None)) -} - -/// (2, 3) -> [1, 2, 1, 2, 1, 2] -fn fsl_values_index(list_size: i32, array_len: usize) -> Result { - let list_size = list_size.to_usize().ok_or_else(|| { - _exec_datafusion_err!("fsl_values_index: invalid list_size {list_size}") - })?; - - let mut indices = Vec::with_capacity(list_size * array_len); - - for _ in 0..array_len { - indices.extend((1..1 + list_size).map(|j| j as i32)); - } - - Ok(PrimitiveArray::new(indices.into(), None)) -} - /// Returns the inner values of a list, or an error otherwise /// For [`ListArray`] and [`LargeListArray`], if it's sliced, it returns a /// sliced array too. 
Therefore, too reconstruct a list using it, /// you must adjust the offsets using [`adjust_offsets_for_slice`] pub fn list_values(array: &dyn Array) -> Result { match array.data_type() { - DataType::List(_) => Ok(sliced_list_values(array.as_list::())), - DataType::LargeList(_) => Ok(sliced_list_values(array.as_list::())), + DataType::List(_) => Ok(Arc::clone(array.as_list::().values())), + DataType::LargeList(_) => Ok(Arc::clone(array.as_list::().values())), DataType::FixedSizeList(_, _) => { Ok(Arc::clone(array.as_fixed_size_list().values())) } @@ -1050,98 +986,11 @@ pub fn list_values(array: &dyn Array) -> Result { } } -fn sliced_list_values(list: &GenericListArray) -> ArrayRef { - let values = list.values(); - let offsets = list.offsets(); - - if let (Some(first), Some(last)) = (offsets.first(), offsets.last()) { - let first = first.to_usize().unwrap(); - let last = last.to_usize().unwrap(); - - if first != 0 || last != values.len() { - return values.slice(first, last - first); - } - } - - Arc::clone(values) -} - -/// If `list` is sliced, returns an adjusted offset buffer so that -/// it points to the sliced portion of the list values, and not the whole list values -pub fn adjust_offsets_for_slice( - list: &GenericListArray, -) -> OffsetBuffer { - let offsets = list.offsets(); - - if let (Some(first), Some(last)) = (offsets.first(), offsets.last()) - && (!first.is_zero() || last.to_usize().unwrap() != list.values().len()) - { - let offsets = offsets.iter().map(|offset| *offset - *first).collect(); - - //todo: use unsafe Offset::new_unchecked? - return OffsetBuffer::new(offsets); - } - - offsets.clone() -} - -/// If `array` is a contiguos list, returns a new array of the same length as it's inner values -/// where each value is the 1-based index of the sublist it's contained. 
Example: -/// -/// `[[1], [2, 3], [4, 5, 6]] => [1, 2, 2, 3, 3, 3]` -/// -/// If it's not a contiguos list, return an error -pub fn list_values_row_number(array: &dyn Array) -> Result { - match array.data_type() { - DataType::List(_) => Ok(Arc::new(list_array_values_row_number::( - array.as_list().offsets(), - ))), - DataType::LargeList(_) => Ok(Arc::new( - list_array_values_row_number::(array.as_list().offsets()), - )), - DataType::FixedSizeList(_, _) => { - let fixed_size_list = array.as_fixed_size_list(); - - Ok(Arc::new(fsl_values_row_number( - fixed_size_list.value_length(), - fixed_size_list.len(), - )?)) - } - other => _exec_err!("expected list, got {other}"), - } -} - -/// If `array` is a contiguos list, returns a new array of the same length as it's inner values -/// where each value is the 1-based index within the sublist it's contained. Example: -/// -/// `[[1], [2, 3], [4, 5, 6]] => [1, 1, 2, 1, 2, 3]` -/// -/// If it's not a contiguos list, return an error -pub fn list_values_index(array: &dyn Array) -> Result { - match array.data_type() { - DataType::List(_) => Ok(Arc::new(list_array_values_index::( - array.as_list::().offsets(), - ))), - DataType::LargeList(_) => Ok(Arc::new(list_array_values_index::( - array.as_list::().offsets(), - ))), - DataType::FixedSizeList(_, _) => { - let fixed_size_list = array.as_fixed_size_list(); - - Ok(Arc::new(fsl_values_index( - fixed_size_list.value_length(), - fixed_size_list.len(), - )?)) - } - other => _exec_err!("expected list, got {other}"), - } -} - #[cfg(test)] mod tests { use super::*; use crate::ScalarValue::Null; - use arrow::array::{Float64Array, Int32Array}; + use arrow::array::Float64Array; use sqlparser::ast::Ident; #[test] @@ -1442,231 +1291,4 @@ mod tests { assert_eq!(expected, transposed); Ok(()) } - - #[test] - fn test_list_array_values_row_number() { - assert_eq!( - list_array_values_row_number::(&OffsetBuffer::from_lengths([ - 1, 3, 0, 2, - ])), - Int32Array::from(vec![0, 1, 1, 1, 3, 3]) - ); - - 
assert_eq!( - list_array_values_row_number::(&OffsetBuffer::from_lengths([])), - Int32Array::new_null(0) - ); - - assert_eq!( - list_array_values_row_number::(&OffsetBuffer::from_lengths([0])), - Int32Array::new_null(0) - ); - - assert_eq!( - list_array_values_row_number::(&OffsetBuffer::from_lengths([ - 0, 0 - ])), - Int32Array::new_null(0) - ); - - assert_eq!( - list_array_values_row_number::(&OffsetBuffer::from_lengths([1])), - Int32Array::from(vec![0]) - ); - - assert_eq!( - list_array_values_row_number::(&OffsetBuffer::from_lengths([2])), - Int32Array::from(vec![0, 0]) - ); - } - - #[test] - fn test_list_array_values_index() { - assert_eq!( - list_array_values_index::(&OffsetBuffer::from_lengths([ - 1, 3, 0, 2, - ])), - Int32Array::from(vec![1, 1, 2, 3, 1, 2]) - ); - - assert_eq!( - list_array_values_index::(&OffsetBuffer::from_lengths([])), - Int32Array::new_null(0) - ); - - assert_eq!( - list_array_values_index::(&OffsetBuffer::from_lengths([0])), - Int32Array::new_null(0) - ); - - assert_eq!( - list_array_values_index::(&OffsetBuffer::from_lengths([0, 0])), - Int32Array::new_null(0) - ); - - assert_eq!( - list_array_values_index::(&OffsetBuffer::from_lengths([1])), - Int32Array::from(vec![1]) - ); - - assert_eq!( - list_array_values_index::(&OffsetBuffer::from_lengths([2])), - Int32Array::from(vec![1, 2]) - ); - } - - #[test] - fn test_fsl_values_row_number() { - assert_eq!( - fsl_values_row_number(2, 3).unwrap(), - Int32Array::from(vec![0, 0, 1, 1, 2, 2]) - ); - - assert_eq!( - fsl_values_row_number(1, 3).unwrap(), - Int32Array::from(vec![0, 1, 2]) - ); - - assert_eq!( - fsl_values_row_number(2, 1).unwrap(), - Int32Array::from(vec![0, 0]) - ); - - assert_eq!( - fsl_values_row_number(2, 0).unwrap(), - Int32Array::new_null(0), - ); - - assert_eq!( - fsl_values_row_number(0, 2).unwrap(), - Int32Array::new_null(0), - ); - - assert_eq!( - fsl_values_row_number(0, 0).unwrap(), - Int32Array::new_null(0), - ); - - fsl_values_row_number(-1, 2).unwrap_err(); - 
fsl_values_row_number(-1, 0).unwrap_err(); - } - - #[test] - fn test_fsl_values_index() { - assert_eq!( - fsl_values_index(2, 3).unwrap(), - Int32Array::from(vec![1, 2, 1, 2, 1, 2]) - ); - - assert_eq!( - fsl_values_index(1, 3).unwrap(), - Int32Array::from(vec![1, 1, 1]) - ); - - assert_eq!( - fsl_values_index(2, 1).unwrap(), - Int32Array::from(vec![1, 2]) - ); - - assert_eq!(fsl_values_index(2, 0).unwrap(), Int32Array::new_null(0)); - assert_eq!(fsl_values_index(0, 2).unwrap(), Int32Array::new_null(0)); - assert_eq!(fsl_values_index(0, 0).unwrap(), Int32Array::new_null(0)); - - fsl_values_index(-1, 2).unwrap_err(); - fsl_values_index(-1, 0).unwrap_err(); - } - - fn list() -> ListArray { - let data = vec![ - Some(vec![Some(0), Some(1), Some(2)]), - None, - Some(vec![Some(3), None, Some(5)]), - Some(vec![Some(6), Some(7)]), - ]; - ListArray::from_iter_primitive::(data) - } - - #[test] - fn test_sliced_list_values() { - let list = list(); - - assert_eq!( - sliced_list_values(&list).as_primitive(), - &Int32Array::from(vec![ - Some(0), - Some(1), - Some(2), - Some(3), - None, - Some(5), - Some(6), - Some(7) - ]) - ); - - assert_eq!( - sliced_list_values(&list.slice(0, 1)).as_primitive(), - &Int32Array::from(vec![Some(0), Some(1), Some(2)]) - ); - - assert_eq!( - sliced_list_values(&list.slice(2, 1)).as_primitive(), - &Int32Array::from(vec![Some(3), None, Some(5)]) - ); - - assert_eq!( - sliced_list_values(&list.slice(3, 1)).as_primitive(), - &Int32Array::from(vec![Some(6), Some(7)]) - ); - - assert!(sliced_list_values(&list.slice(0, 0)).is_empty()); - assert!(sliced_list_values(&list.slice(1, 0)).is_empty()); - assert!(sliced_list_values(&list.slice(3, 0)).is_empty()); - } - - #[test] - fn test_adjust_offsets() { - let data = vec![ - Some(vec![Some(0), Some(1), Some(2)]), - None, - Some(vec![Some(3), None, Some(5)]), - Some(vec![Some(6), Some(7)]), - ]; - let list = ListArray::from_iter_primitive::(data); - - assert_eq!( - adjust_offsets_for_slice(&list), - 
OffsetBuffer::from_lengths([3, 0, 3, 2]) - ); - - assert_eq!( - adjust_offsets_for_slice(&list.slice(0, 1)), - OffsetBuffer::from_lengths([3]) - ); - - assert_eq!( - adjust_offsets_for_slice(&list.slice(1, 2)), - OffsetBuffer::from_lengths([0, 3]) - ); - - assert_eq!( - adjust_offsets_for_slice(&list.slice(1, 3)), - OffsetBuffer::from_lengths([0, 3, 2]) - ); - - assert_eq!( - adjust_offsets_for_slice(&list.slice(0, 0)), - OffsetBuffer::from_lengths([]) - ); - - assert_eq!( - adjust_offsets_for_slice(&list.slice(1, 0)), - OffsetBuffer::from_lengths([]) - ); - - assert_eq!( - adjust_offsets_for_slice(&list.slice(3, 0)), - OffsetBuffer::from_lengths([]) - ); - } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 8803c6e2187f8..e2b3643235933 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -35,14 +35,12 @@ use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; use datafusion_common::datatype::DataTypeExt; -use datafusion_common::hash_map::EntryRef; use datafusion_common::metadata::format_type_and_metadata; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; use datafusion_common::{ Column, DFSchema, ExprSchema, HashMap, Result, ScalarValue, Spans, TableReference, - plan_datafusion_err, plan_err, }; use datafusion_expr_common::placement::ExpressionPlacement; use datafusion_functions_window_common::field::WindowUDFFieldArgs; @@ -489,7 +487,7 @@ impl PartialEq for LambdaFunction { #[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] pub struct LambdaVariable { pub name: String, - pub field: Option, + pub field: FieldRef, pub spans: Spans, } @@ -500,7 +498,7 @@ impl LambdaVariable { /// [`Expr::resolve_lambdas_variables`] or [`LogicalPlan::resolve_lambdas_variables`]. 
/// /// [`LogicalPlan::resolve_lambdas_variables`]: crate::LogicalPlan::resolve_lambdas_variables - pub fn new(name: String, field: Option) -> Self { + pub fn new(name: String, field: FieldRef) -> Self { Self { name, field, @@ -2319,169 +2317,6 @@ impl Expr { None } } - - /// Return a `Expr` with all [`LambdaVariable`] resolved only if all of them - /// are contained in the subtree of the [`LambdaFunction`] it originates from, - /// otherwise returns an error - pub fn resolve_lambdas_variables( - self, - schema: &DFSchema, - ) -> Result> { - resolve_lambdas_variables(self, schema, &mut HashMap::new()) - } -} - -fn resolve_lambdas_variables( - expr: Expr, - schema: &DFSchema, - vars: &mut HashMap>, -) -> Result> { - expr.transform_down(|expr| match expr { - Expr::LambdaFunction(LambdaFunction { func, args }) => { - let args = if !vars.is_empty() { - /* if this is a nested lambda, we must resolve non-lambda args before invoking - lambdas_parameters because it will invoke ExprSchemable::to_field for every - non-lambda parameter, and if one them contains a lambda variable, it will fail - due to it being unresolved. Example query: - - array_transform([[1, 2]], a -> array_transform(a, b -> b+1)) - - the nested array_transform's lambdas_parameters will call Lambdavariable::to_field - on it's first argument, the variable `a`, which must be resolved - */ - args.map_elements(|arg| match arg { - Expr::Lambda(_) => Ok(Transformed::no(arg)), - _ => resolve_lambdas_variables(arg, schema, vars), - })? 
- } else { - Transformed::no(args) - }; - - let transformed = args.transformed; - let func = LambdaFunction::new(func, args.data); - - let mut lambdas_params = func.lambdas_parameters(schema)?.into_iter(); - - let num_args = func.args.len(); - let num_lambdas_params = lambdas_params.len(); - - let args = func.args.map_elements(|arg| { - let lambda_params = lambdas_params.next().ok_or_else(|| { - plan_datafusion_err!( - "{} lambdas_parameters returned {num_lambdas_params} values for {num_args} args", - func.func.name() - ) - })?; - - match (arg, lambda_params) { - (Expr::Lambda(mut lambda), Some(lambda_params)) => { - if lambda.params.len() > lambda_params.len() { - return plan_err!( - "{} lambda defined {} params ({}), but only {} supported", - func.func.name(), - lambda.params.len(), - display_comma_separated(&lambda.params), - lambda_params.len() - ); - } - - if !all_unique(&lambda.params) { - return plan_err!( - "lambda params must be unique, got ({})", - lambda.params.join(", ") - ); - } - - for (param, field) in - std::iter::zip(&lambda.params, lambda_params) - { - vars.entry_ref(param) - .or_default() - .push(Arc::new(field)); - } - - let transformed = resolve_lambdas_variables(mem::take(lambda.body.as_mut()), schema, vars)?; - - *lambda.body = transformed.data; - - for param in &lambda.params { - match vars.entry_ref(param) { - EntryRef::Occupied(mut v) => { - if v.get().len() == 1 { - v.remove(); - } else { - v.get_mut() - .pop() - .expect("every entry should have at least one field"); - } - }, - EntryRef::Vacant(_v) => { - unreachable!("the loop above should have inserted a value for every param") - }, - } - } - - Ok(Transformed::new(Expr::Lambda(lambda), transformed.transformed, TreeNodeRecursion::Jump)) - } - (Expr::Lambda(_), None) => { - plan_err!( - "{} lambdas_parameters retured None for a lambda argument", - func.func.name() - ) - } - (_, Some(_)) => { - plan_err!( - "{} lambdas_parameters retured Some for a non-lambda argument", - func.func.name() 
- ) - } - (arg, None) => Ok(Transformed::no(arg)) // resolved above - } - })?; - - Ok(Transformed::new( - Expr::LambdaFunction(LambdaFunction::new(func.func, args.data)), - transformed || args.transformed, - TreeNodeRecursion::Jump, - )) - } - Expr::LambdaVariable(mut var) => { - let fields_chain = vars.get(&var.name).ok_or_else(|| { - plan_datafusion_err!( - "missing field of lambda variable {} while resolving", - var.name - ) - })?; - - let field = fields_chain - .last() - .expect("every entry should have at least one field"); - - let transformed = var.field.as_ref().is_none_or(|old| old != field); - - if transformed { - var.field = Some(Arc::clone(field)); - } - - Ok(Transformed::new_transformed( - Expr::LambdaVariable(var), - transformed, - )) - } - _ => Ok(Transformed::no(expr)), - }) -} - -fn all_unique(params: &[String]) -> bool { - match params.len() { - 0 | 1 => true, - 2 => params[0] != params[1], - _ => { - let mut set = HashSet::with_capacity(params.len()); - - params.iter().all(|p| set.insert(p.as_str())) - } - } } impl Normalizeable for Expr { @@ -3955,9 +3790,8 @@ pub fn physical_name(expr: &Expr) -> Result { mod test { use crate::expr_fn::col; use crate::{ - ColumnarValue, LambdaSignature, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, - Volatility, case, lambda, lambda_var, lit, placeholder, qualified_wildcard, - wildcard, wildcard_with_options, + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility, case, + lit, placeholder, qualified_wildcard, wildcard, wildcard_with_options, }; use arrow::datatypes::{Field, Schema}; use sqlparser::ast; @@ -4467,128 +4301,4 @@ mod test { } } } - - #[test] - fn test_resolve_lambda_variables() { - let schema = DFSchema::try_from(Schema::new(vec![Field::new( - "c", - DataType::new_list(DataType::new_list(DataType::Int32, true), true), - true, - )])) - .unwrap(); - - #[derive(Debug, Hash, PartialEq, Eq)] - struct MockLambdaUDF { - signature: LambdaSignature, - } - - impl LambdaUDF for MockLambdaUDF 
{ - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - "array_transform" - } - - fn signature(&self) -> &LambdaSignature { - &self.signature - } - - fn lambdas_parameters( - &self, - args: &[ValueOrLambda], - ) -> Result>>> { - let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(_lambda)) = - (&args[0], &args[1]) - else { - unreachable!() - }; - - let (field, index_type) = match list.data_type() { - DataType::List(field) => (field, DataType::Int32), - _ => unreachable!(), - }; - - let value = - Field::new("", field.data_type().clone(), field.is_nullable()) - .with_metadata(field.metadata().clone()); - let index = Field::new("", index_type, false); - - Ok(vec![None, Some(vec![value, index])]) - } - - fn return_field_from_args( - &self, - _args: crate::LambdaReturnFieldArgs, - ) -> Result { - unimplemented!() - } - - fn invoke_with_args( - &self, - _args: crate::LambdaFunctionArgs, - ) -> Result { - unimplemented!() - } - } - - let func = Arc::new(MockLambdaUDF { - signature: LambdaSignature::variadic_any(Volatility::Immutable), - }) as _; - - // array_transform(c, v -> array_transform(v, (v, i) -> v+i)) - let expr = Expr::LambdaFunction(LambdaFunction::new( - Arc::clone(&func), - vec![ - col("c"), - lambda( - ["v"], - Expr::LambdaFunction(LambdaFunction::new( - Arc::clone(&func), - vec![ - lambda_var("v"), - lambda(["v", "i"], lambda_var("v") + lambda_var("i")), - ], - )), - ), - ], - )); - - let resolved_expr = expr.resolve_lambdas_variables(&schema).unwrap().data; - - let expected = Expr::LambdaFunction(LambdaFunction::new( - Arc::clone(&func), - vec![ - col("c"), - lambda( - ["v"], - Expr::LambdaFunction(LambdaFunction::new( - func, - vec![ - resolved_lambda_var( - "v", - DataType::new_list(DataType::Int32, true), - true, - ), - lambda( - ["v", "i"], - resolved_lambda_var("v", DataType::Int32, true) - + resolved_lambda_var("i", DataType::Int32, false), - ), - ], - )), - ), - ], - )); - - assert_eq!(resolved_expr, expected); - } - - fn 
resolved_lambda_var(name: &str, dt: DataType, nullable: bool) -> Expr { - Expr::LambdaVariable(LambdaVariable::new( - name.into(), - Some(Arc::new(Field::new("", dt, nullable))), - )) - } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index ce0fb09a3126b..79884e677a02d 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -741,15 +741,9 @@ pub fn lambda(params: impl IntoIterator>, body: Expr) - )) } -/// Create an unresolved lambda variable expression -/// -/// The expression tree or [`LogicalPlan`] which -/// owns this variable must be resolved before usage with either -/// [`Expr::resolve_lambdas_variables`] or [`LogicalPlan::resolve_lambdas_variables`]. -/// -/// [`LogicalPlan::resolve_lambdas_variables`]: crate::LogicalPlan::resolve_lambdas_variables -pub fn lambda_var(name: impl Into) -> Expr { - Expr::LambdaVariable(LambdaVariable::new(name.into(), None)) +/// Create an lambda variable expression +pub fn lambda_var(name: impl Into, field: FieldRef) -> Expr { + Expr::LambdaVariable(LambdaVariable::new(name.into(), field)) } /// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 9ae182aa20013..c40d483f25143 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -207,11 +207,9 @@ impl ExprSchemable for Expr { Ok(self.to_field(schema)?.1.data_type().clone()) } Expr::Lambda(Lambda { params: _, body }) => body.get_type(schema), - Expr::LambdaVariable(LambdaVariable { name, field, .. }) => Ok(field - .as_ref() - .ok_or_else(|| plan_datafusion_err!("unresolved LambdaVariable {name}"))? - .data_type() - .clone()), + Expr::LambdaVariable(LambdaVariable { field, .. 
}) => { + Ok(field.data_type().clone()) + } } } @@ -369,10 +367,7 @@ impl ExprSchemable for Expr { Ok(self.to_field(input_schema)?.1.is_nullable()) } Expr::Lambda(l) => l.body.nullable(input_schema), - Expr::LambdaVariable(LambdaVariable { name, field, .. }) => Ok(field - .as_ref() - .ok_or_else(|| plan_datafusion_err!("unresolved LambdaVariable {name}"))? - .is_nullable()), + Expr::LambdaVariable(LambdaVariable { field, .. }) => Ok(field.is_nullable()), } } @@ -645,11 +640,7 @@ impl ExprSchemable for Expr { func.func.return_field_from_args(args) } - Expr::LambdaVariable(l) => { - Ok(Arc::clone(l.field.as_ref().ok_or_else(|| { - plan_datafusion_err!("unresolved LambdaVariable {}", l.name) - })?)) - } + Expr::LambdaVariable(l) => Ok(Arc::clone(&l.field)), }?; Ok(( diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4d8466827fe9c..b2a56971837f0 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -42,7 +42,7 @@ use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, - grouping_set_expr_count, grouping_set_to_exprlist, merge_schema, split_conjunction, + grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction, }; use crate::{ BinaryExpr, CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable, @@ -2110,17 +2110,6 @@ impl LogicalPlan { } Wrapper(self) } - - /// Return a `LogicalPLan` with all [`LambdaVariable`] resolved - /// - /// [`LambdaVariable`]: crate::expr::LambdaVariable - pub fn resolve_lambdas_variables(self) -> Result> { - self.transform_with_subqueries(|plan| { - let schema = merge_schema(&plan.inputs()); - - plan.map_expressions(|expr| expr.resolve_lambdas_variables(&schema)) - }) - } } impl Display for LogicalPlan { diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index 
5140f14a99188..820195dbc9971 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -23,7 +23,7 @@ use crate::{ColumnarValue, Documentation, Expr}; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; +use datafusion_common::{Result, ScalarValue, not_impl_err}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::signature::Volatility; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -74,12 +74,6 @@ pub struct LambdaSignature { pub type_signature: LambdaTypeSignature, /// The volatility of the function. See [Volatility] for more information. pub volatility: Volatility, - /// Optional parameter names for the function arguments. - /// - /// If provided, enables named argument notation for function calls (e.g., `func(a => 1, b => v -> v+1)`). - /// - /// Defaults to `None`, meaning only positional arguments are supported. 
- pub parameter_names: Option>, } impl LambdaSignature { @@ -88,7 +82,6 @@ impl LambdaSignature { LambdaSignature { type_signature, volatility, - parameter_names: None, } } @@ -97,7 +90,6 @@ impl LambdaSignature { Self { type_signature: LambdaTypeSignature::UserDefined, volatility, - parameter_names: None, } } @@ -106,7 +98,6 @@ impl LambdaSignature { Self { type_signature: LambdaTypeSignature::VariadicAny, volatility, - parameter_names: None, } } @@ -115,7 +106,6 @@ impl LambdaSignature { Self { type_signature: LambdaTypeSignature::Any(arg_count), volatility, - parameter_names: None, } } } @@ -205,136 +195,34 @@ pub struct LambdaArgument { /// For example, for `array_transform([2], v -> -v)`, /// this will be the physical expression of `-v` body: Arc, - /// A RecordBatch containing at least the captured columns inside this lambda body, if any - /// Note that it may contain additional, non-specified columns, but that's a implementation detail - /// - /// For example, for `array_transform([2], v -> v + a + b)`, - /// this will be a `RecordBatch` with at least two columns, `a` and `b` - captures: Option, } impl LambdaArgument { - pub fn new( - params: Vec, - body: Arc, - captures: Option, - ) -> Self { - Self { - params, - body, - captures, - } + pub fn new(params: Vec, body: Arc) -> Self { + Self { params, body } } /// Evaluate this lambda /// `args` should evalute to the value of each parameter /// of the correspondent lambda returned in [LambdaUDF::lambdas_parameters]. 
- /// - /// `adjust` should adjust the captured columns of this - /// lambda, if any, relative to it's parameters - /// - /// Tip: For adjusting multiple arrays by indices, use [`take_arrays`] - /// - /// [`take_arrays`]: arrow::compute::take_arrays pub fn evaluate( &self, args: &[&dyn Fn() -> Result], - mut adjust: impl FnMut(&[ArrayRef]) -> Result>, ) -> Result { - let adjusted_captures = self - .captures - .as_ref() - .map(|captures| { - let adjusted_columns = adjust(captures.columns())?; - - RecordBatch::try_new(captures.schema(), adjusted_columns) - }) - .transpose()?; - - let merged = merge_captures_with_variables( - adjusted_captures.as_ref(), - &self.params, - args, - )?; - - self.body.evaluate(&merged) - } -} - -fn merge_captures_with_variables( - captures: Option<&RecordBatch>, - params: &[FieldRef], - variables: &[&dyn Fn() -> Result], -) -> Result { - if variables.len() < params.len() { - return exec_err!( - "expected at least {} lambda arguments to merge with captures, got {}", - params.len(), - variables.len() - ); - } - - match captures { - Some(captures) => { - let old_fields = captures.schema_ref().fields(); - - let mut new_fields = old_fields - .iter() - .map(|field| { - if !fields_contains(params, field.name()) { - return Arc::clone(field); - } - - let mut i = 0; - - loop { - let alias = format!("{}_shadowed_{i}", field.name()); - - if !fields_contains(params, &alias) - && old_fields.find(&alias).is_none() - { - break Arc::new(Field::new( - alias, - field.data_type().clone(), - field.is_nullable(), - )); - } - - i += 1; - } - }) - .collect::>(); - - new_fields.extend_from_slice(params); - - let mut columns = captures.columns().to_vec(); - - for arg in &variables[..params.len()] { - columns.push(arg()?); - } - - let new_schema = Arc::new(Schema::new(new_fields)); + let columns = args + .iter() + .take(self.params.len()) + .map(|arg| arg()) + .collect::>()?; - Ok(RecordBatch::try_new(new_schema, columns)?) 
- } - None => { - let columns = variables - .iter() - .take(params.len()) - .map(|arg| arg()) - .collect::>()?; + let schema = Arc::new(Schema::new(self.params.clone())); - let schema = Arc::new(Schema::new(params)); + let batch = RecordBatch::try_new(schema, columns)?; - Ok(RecordBatch::try_new(schema, columns)?) - } + self.body.evaluate(&batch) } } -fn fields_contains(fields: &[FieldRef], name: &str) -> bool { - fields.iter().any(|f| f.name().as_str() == name) -} - /// Information about arguments passed to the function /// /// This structure contains metadata about how the function was called @@ -421,32 +309,6 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// [`Volatility`]: datafusion_expr_common::signature::Volatility fn signature(&self) -> &LambdaSignature; - /// Create a new instance of this function with updated configuration. - /// - /// This method is called when configuration options change at runtime - /// (e.g., via `SET` statements) to allow functions that depend on - /// configuration to update themselves accordingly. - /// - /// Note the current [`ConfigOptions`] are also passed to [`Self::invoke_with_args`] so - /// this API is not needed for functions where the values may - /// depend on the current options. - /// - /// This API is useful for functions where the return - /// **type** depends on the configuration options, such as the `now()` function - /// which depends on the current timezone. 
- /// - /// # Arguments - /// - /// * `config` - The updated configuration options - /// - /// # Returns - /// - /// * `Some(LambdaUDF)` - A new instance of this function configured with the new settings - /// * `None` - If this function does not change with new configuration settings (the default) - fn with_updated_config(&self, _config: &ConfigOptions) -> Option> { - None - } - /// Returns a list of the same size as args where each value is the logic below applied to value at the correspondent position in args: /// /// If it's a value, return None @@ -525,41 +387,6 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// to arrays, which will likely be simpler code, but be slower. fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result; - /// Optionally apply per-UDF simplification / rewrite rules. - /// - /// This can be used to apply function specific simplification rules during - /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default - /// implementation does nothing. - /// - /// Note that DataFusion handles simplifying arguments and "constant - /// folding" (replacing a function call with constant arguments such as - /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such - /// optimizations manually for specific UDFs. - /// - /// # Arguments - /// * `args`: The arguments of the function - /// * `info`: The necessary information for simplification - /// - /// # Returns - /// [`ExprSimplifyResult`] indicating the result of the simplification NOTE - /// if the function cannot be simplified, the arguments *MUST* be returned - /// unmodified - /// - /// # Notes - /// - /// The returned expression must have the same schema as the original - /// expression, including both the data type and nullability. For example, - /// if the original expression is nullable, the returned expression must - /// also be nullable, otherwise it may lead to schema verification errors - /// later in query planning. 
- fn simplify( - &self, - args: Vec, - _info: &SimplifyContext, - ) -> Result { - Ok(ExprSimplifyResult::Original(args)) - } - /// Returns true if some of this `exprs` subexpressions may not be evaluated /// and thus any side effects (like divide by zero) may not be encountered. /// diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index f7d299ceb6283..446c9b9a45f05 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -19,15 +19,11 @@ use arrow::{ array::{Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray}, - compute::take_arrays, datatypes::{DataType, Field, FieldRef}, }; use datafusion_common::{ Result, exec_err, plan_err, - utils::{ - adjust_offsets_for_slice, list_values, list_values_index, list_values_row_number, - take_function_args, - }, + utils::{list_values, take_function_args}, }; use datafusion_expr::{ ColumnarValue, Documentation, LambdaFunctionArgs, LambdaReturnFieldArgs, @@ -135,21 +131,19 @@ impl LambdaUDF for ArrayTransform { ) -> Result>>> { let (list, _lambda) = value_lambda_pair(self.name(), args)?; - let (field, index_type) = match list.data_type() { - DataType::List(field) => (field, DataType::Int32), - DataType::LargeList(field) => (field, DataType::Int64), - DataType::FixedSizeList(field, _) => (field, DataType::Int32), + let field = match list.data_type() { + DataType::List(field) => field, + DataType::LargeList(field) => field, + DataType::FixedSizeList(field, _) => field, _ => return plan_err!("expected list, got {list}"), }; - // we don't need to omit the index in the case the lambda don't specify, e.g. array_transform([], v -> v*2), - // nor check whether the lambda contains more than two parameters, e.g. array_transform([], (v, i, j) -> v+i+j), - // as datafusion will do that for us + // we don't need to check whether the lambda contains more than two parameters, + // e.g. 
array_transform([], (v, i, j) -> v+i+j), as datafusion will do that for us let value = Field::new("", field.data_type().clone(), field.is_nullable()) .with_metadata(field.metadata().clone()); - let index = Field::new("", index_type, false); - Ok(vec![None, Some(vec![value, index])]) + Ok(vec![None, Some(vec![value])]) } fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result> { @@ -179,30 +173,14 @@ impl LambdaUDF for ArrayTransform { let (list, lambda) = value_lambda_pair(self.name(), &args.args)?; let list_array = list.to_array(args.number_rows)?; - - // as per list_values docs, if list_array is sliced, list_values will be sliced too, - // so before constructing the transformed array below, we must adjust the list offsets with - // adjust_offsets_for_slice let list_values = list_values(&list_array)?; - // if any column got captured, we need to adjust it to the values arrays, - // duplicating values of list with mulitple values and removing values of empty lists - // list_values_row_number is not cheap so is important to avoid it when no column is captured - let mut adjust_indices = None; - // by passing closures, lambda.evaluate can evaluate only those actually needed let values_param = || Ok(Arc::clone(&list_values)); - let indices_param = || list_values_index(&list_array); // call the transforming lambda let transformed_values = lambda - .evaluate(&[&values_param, &indices_param], |arrays| { - let indices = match &adjust_indices { - Some(v) => v, - None => adjust_indices.insert(list_values_row_number(&list_array)?), - }; - Ok(take_arrays(arrays, indices, None)?) - })? + .evaluate(&[&values_param])? 
.into_array(list_values.len())?; let field = match args.return_field.data_type() { @@ -221,26 +199,20 @@ impl LambdaUDF for ArrayTransform { let transformed_list = match list_array.data_type() { DataType::List(_) => { let list = list_array.as_list(); - // since we called list_values above which would return sliced values for - // a sliced list, we must adjust the offsets here as otherwise they would be invalid - let adjusted_offsets = adjust_offsets_for_slice(list); Arc::new(ListArray::new( field, - adjusted_offsets, + list.offsets().clone(), transformed_values, list.nulls().cloned(), )) as ArrayRef } DataType::LargeList(_) => { let large_list = list_array.as_list(); - // since we called list_values above which would return sliced values for - // a sliced list, we must adjust the offsets here as otherwise they would be invalid - let adjusted_offsets = adjust_offsets_for_slice(large_list); Arc::new(LargeListArray::new( field, - adjusted_offsets, + large_list.offsets().clone(), transformed_values, large_list.nulls().cloned(), )) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 84f16dd5e51d5..0273e6ffd2eee 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -448,11 +448,9 @@ impl TreeNodeRewriter for Canonicalizer { }; match (left.as_ref(), right.as_ref(), op.swap()) { // - ( - left_ref @ (Expr::Column(_) | Expr::LambdaVariable(_)), - right_ref @ (Expr::Column(_) | Expr::LambdaVariable(_)), - Some(swapped_op), - ) if right_ref > left_ref => { + (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op)) + if right_col > left_col => + { Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, @@ -460,15 +458,13 @@ impl TreeNodeRewriter for Canonicalizer { }))) } // - ( - Expr::Literal(_, _), - Expr::Column(_) | Expr::LambdaVariable(_), - 
Some(swapped_op), - ) => Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { - left: right, - op: swapped_op, - right: left, - }))), + (Expr::Literal(_a, _), Expr::Column(_b), Some(swapped_op)) => { + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { + left: right, + op: swapped_op, + right: left, + }))) + } _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, @@ -2212,8 +2208,8 @@ fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool { let left = as_inlist(left); let right = as_inlist(right); if let (Some(lhs), Some(rhs)) = (left, right) { - matches!(lhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaVariable(_)) - && matches!(rhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaVariable(_)) + matches!(lhs.expr.as_ref(), Expr::Column(_)) + && matches!(rhs.expr.as_ref(), Expr::Column(_)) && lhs.expr == rhs.expr && !lhs.negated && !rhs.negated @@ -2228,20 +2224,16 @@ fn as_inlist(expr: &'_ Expr) -> Option> { Expr::InList(inlist) => Some(Cow::Borrowed(inlist)), Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { match (left.as_ref(), right.as_ref()) { - (Expr::Column(_) | Expr::LambdaVariable(_), Expr::Literal(_, _)) => { - Some(Cow::Owned(InList { - expr: left.clone(), - list: vec![*right.clone()], - negated: false, - })) - } - (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaVariable(_)) => { - Some(Cow::Owned(InList { - expr: right.clone(), - list: vec![*left.clone()], - negated: false, - })) - } + (Expr::Column(_), Expr::Literal(_, _)) => Some(Cow::Owned(InList { + expr: left.clone(), + list: vec![*right.clone()], + negated: false, + })), + (Expr::Literal(_, _), Expr::Column(_)) => Some(Cow::Owned(InList { + expr: right.clone(), + list: vec![*left.clone()], + negated: false, + })), _ => None, } } @@ -2257,20 +2249,16 @@ fn to_inlist(expr: Expr) -> Option { op: Operator::Eq, right, }) => match (left.as_ref(), right.as_ref()) { - (Expr::Column(_) | Expr::LambdaVariable(_), Expr::Literal(_, _)) => { - Some(InList { - expr: left, - 
list: vec![*right], - negated: false, - }) - } - (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaVariable(_)) => { - Some(InList { - expr: right, - list: vec![*left], - negated: false, - }) - } + (Expr::Column(_), Expr::Literal(_, _)) => Some(InList { + expr: left, + list: vec![*right], + negated: false, + }), + (Expr::Literal(_, _), Expr::Column(_)) => Some(InList { + expr: right, + list: vec![*left], + negated: false, + }), _ => None, }, _ => None, diff --git a/datafusion/physical-expr/src/expressions/lambda.rs b/datafusion/physical-expr/src/expressions/lambda.rs index cb0b76483505f..c06b48591b70e 100644 --- a/datafusion/physical-expr/src/expressions/lambda.rs +++ b/datafusion/physical-expr/src/expressions/lambda.rs @@ -21,27 +21,20 @@ use std::any::Any; use std::hash::Hash; use std::sync::Arc; -use crate::expressions::{Column, LambdaVariable}; use crate::physical_expr::PhysicalExpr; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::{HashSet, Result, internal_err, tree_node::TreeNodeVisitor}; -use datafusion_common::{ - plan_err, - tree_node::{TreeNode, TreeNodeRecursion}, -}; +use datafusion_common::plan_err; +use datafusion_common::{HashSet, Result, internal_err}; use datafusion_expr::ColumnarValue; -use hashbrown::{HashMap, hash_map::EntryRef}; /// Represents a lambda with the given parameters names and body #[derive(Debug, Eq, Clone)] pub struct LambdaExpr { params: Vec, body: Arc, - captured_columns: HashSet, - captured_variables: HashSet, } // Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 [https://github.com/apache/datafusion/issues/13196] @@ -69,21 +62,7 @@ impl LambdaExpr { } fn new(params: Vec, body: Arc) -> Self { - let (captured_columns, captured_variables) = { - let mut captures = Captures::new(¶ms); - - body.visit(&mut captures) - .expect("visitor should be infallible"); - - (captures.columns, captures.variables) - }; - - Self { - params, - body, - 
captured_columns, - captured_variables, - } + Self { params, body } } /// Get the lambda's params names @@ -95,22 +74,6 @@ impl LambdaExpr { pub fn body(&self) -> &Arc { &self.body } - - pub(crate) fn captured_columns(&self) -> &HashSet { - &self.captured_columns - } - - /// Returns lambdas variables names that aren't of this lambda nor any other lambda down tree. - /// Example: - /// - /// `array_transform([[[1, 2, 3]]], a -> array_transform(a, b -> array_transform(b, c -> length(a) + length(b) + c)))` - /// - /// For the outermost lambda, this would return an empty hash set - /// For the middle one, `HashSet("a")` - /// And for the innermost, `HashSet("a", "b")` - pub(crate) fn captured_variables(&self) -> &HashSet { - &self.captured_variables - } } impl std::fmt::Display for LambdaExpr { @@ -178,152 +141,12 @@ fn all_unique(params: &[String]) -> bool { } } -struct Captures<'a> { - shadows: HashMap<&'a str, usize>, - columns: HashSet, - variables: HashSet, -} - -impl<'a> Captures<'a> { - fn new(params: &'a [String]) -> Self { - Self { - shadows: params.iter().map(|p| (p.as_str(), 1)).collect(), - columns: HashSet::new(), - variables: HashSet::new(), - } - } -} - -impl<'n> TreeNodeVisitor<'n> for Captures<'n> { - type Node = Arc; - - fn f_down(&mut self, node: &'n Self::Node) -> Result { - if let Some(lambda) = node.as_any().downcast_ref::() { - for param in &lambda.params { - *self.shadows.entry(param.as_str()).or_default() += 1; - } - } else if let Some(lambda_variable) = - node.as_any().downcast_ref::() - { - if !self.shadows.contains_key(lambda_variable.name()) { - self.variables.insert(lambda_variable.name().to_owned()); - } - } else if let Some(col) = node.as_any().downcast_ref::() { - self.columns.insert(col.index()); - } - - Ok(TreeNodeRecursion::Continue) - } - - fn f_up(&mut self, node: &'n Self::Node) -> Result { - if let Some(lambda) = node.as_any().downcast_ref::() { - for param in &lambda.params { - match self.shadows.entry_ref(param.as_str()) { - 
EntryRef::Occupied(mut v) => { - if *v.get() > 1 { - *v.get_mut() -= 1; - } else { - v.remove(); - } - } - EntryRef::Vacant(_v) => { - unreachable!( - "f_down should have inserted a value for every param" - ) - } - } - } - } - - Ok(TreeNodeRecursion::Continue) - } -} - #[cfg(test)] mod tests { - use crate::{ - LambdaFunctionExpr, - expressions::{Column, LambdaExpr, NoOp, lambda::lambda, lambda_variable}, - }; - use arrow::{ - array::RecordBatch, - datatypes::{DataType, Field, FieldRef, Schema}, - }; - use datafusion_common::{HashSet, Result}; - use datafusion_expr::{ColumnarValue, LambdaUDF}; + use crate::expressions::{NoOp, lambda::lambda}; + use arrow::{array::RecordBatch, datatypes::Schema}; use std::sync::Arc; - #[derive(Debug, Hash, Eq, PartialEq)] - struct DummyLambdaUDF; - - impl LambdaUDF for DummyLambdaUDF { - fn as_any(&self) -> &dyn std::any::Any { - unimplemented!() - } - - fn name(&self) -> &str { - "dummy_udlf" - } - - fn signature(&self) -> &datafusion_expr::LambdaSignature { - unimplemented!() - } - - fn lambdas_parameters( - &self, - _args: &[datafusion_expr::ValueOrLambda], - ) -> Result>>> { - unimplemented!() - } - - fn return_field_from_args( - &self, - _args: datafusion_expr::LambdaReturnFieldArgs, - ) -> Result { - unimplemented!() - } - - fn invoke_with_args( - &self, - _args: datafusion_expr::LambdaFunctionArgs, - ) -> Result { - unimplemented!() - } - } - - #[test] - fn test_lambda_captures() { - let null_field = Arc::new(Field::new("", DataType::Null, true)); - - //`var_b -> dummy_udlf(var_a, var_b, column@0, var_c -> var_c))` - let inner = LambdaExpr::try_new( - vec![String::from("var_b")], - Arc::new(LambdaFunctionExpr::new( - "dummy_udlf", - Arc::new(DummyLambdaUDF), - vec![ - lambda_variable("var_a", Arc::clone(&null_field)).unwrap(), - lambda_variable("var_b", Arc::clone(&null_field)).unwrap(), - Arc::new(Column::new("column", 0)), - lambda( - ["var_c"], - lambda_variable("var_c", Arc::clone(&null_field)).unwrap(), - ) - .unwrap(), - 
], - Arc::clone(&null_field), - Arc::new(Default::default()), - )), - ) - .unwrap(); - - assert_eq!(inner.captured_columns(), &HashSet::from([0])); - assert_eq!( - inner.captured_variables(), - &HashSet::from([String::from("var_a")]) - ); - } - #[test] fn test_lambda_evaluate() { let lambda = lambda(["a"], Arc::new(NoOp::new())).unwrap(); diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs index b5c7cadf77e1c..756d21c748b21 100644 --- a/datafusion/physical-expr/src/lambda_function.rs +++ b/datafusion/physical-expr/src/lambda_function.rs @@ -37,8 +37,8 @@ use std::sync::Arc; use crate::PhysicalExpr; use crate::expressions::{LambdaExpr, Literal}; -use arrow::array::{Array, NullArray, RecordBatch}; -use arrow::datatypes::{DataType, Field, FieldRef, Schema}; +use arrow::array::{Array, RecordBatch}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use datafusion_common::config::{ConfigEntry, ConfigOptions}; use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::type_coercion::functions::value_fields_with_lambda_udf; @@ -169,24 +169,6 @@ impl LambdaFunctionExpr { pub fn config_options(&self) -> &ConfigOptions { &self.config_options } - - /// Given an arbitrary PhysicalExpr attempt to downcast it to a LambdaFunctionExpr - /// and verify that its inner function is of type T. - /// If the downcast fails, or the function is not of type T, returns `None`. - /// Otherwise returns `Some(LambdaFunctionExpr)`. 
- pub fn try_downcast_func(expr: &dyn PhysicalExpr) -> Option<&LambdaFunctionExpr> - where - T: 'static, - { - match expr.as_any().downcast_ref::() { - Some(lambda_expr) - if lambda_expr.fun().as_any().downcast_ref::().is_some() => - { - Some(lambda_expr) - } - _ => None, - } - } } impl fmt::Display for LambdaFunctionExpr { @@ -289,43 +271,6 @@ impl PhysicalExpr for LambdaFunctionExpr { ); } - let indices = lambda.captured_columns(); - let variables = lambda.captured_variables(); - - let captures = if !indices.is_empty() || !variables.is_empty() { - let (fields, columns): (Vec<_>, _) = std::iter::zip( - batch.schema_ref().fields(), - batch.columns(), - ) - .enumerate() - .map(|(column_index, (field, column))| { - if indices.contains(&column_index) - || variables.contains(field.name()) - { - (Arc::clone(field), Arc::clone(column)) - } else { - // To avoid costly copies of uncaptured columns, we swap them with a NullArray - // while keeping the number of columns on the batch the same - // so captured columns indices are kept stable across the whole tree. - ( - Arc::new(Field::new( - field.name(), - DataType::Null, - false, - )), - Arc::new(NullArray::new(column.len())) as _, - ) - } - }) - .unzip(); - - let schema = Arc::new(Schema::new(fields)); - - Some(RecordBatch::try_new(schema, columns)?) 
- } else { - None - }; - let params = std::iter::zip(lambda.params(), lambda_params) .map(|(name, param)| Arc::new(param.with_name(name))) .collect(); @@ -333,7 +278,6 @@ impl PhysicalExpr for LambdaFunctionExpr { Ok(ValueOrLambda::Lambda(LambdaArgument::new( params, Arc::clone(lambda.body()), - captures, ))) } (Some(_lambda), None) => exec_err!( diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 63683ae3b68e2..21f63273a1ed6 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -27,8 +27,7 @@ use arrow::datatypes::Schema; use datafusion_common::config::ConfigOptions; use datafusion_common::metadata::{FieldMetadata, format_type_and_metadata}; use datafusion_common::{ - DFSchema, Result, ScalarValue, ToDFSchema, exec_err, not_impl_err, - plan_datafusion_err, plan_err, + DFSchema, Result, ScalarValue, ToDFSchema, exec_err, not_impl_err, plan_err, }; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{ @@ -440,12 +439,7 @@ pub fn create_physical_expr( name, field, spans: _, - }) => expressions::lambda_variable( - name, - Arc::clone(field.as_ref().ok_or_else(|| { - plan_datafusion_err!("unresolved LambdaVariable {name}") - })?), - ), + }) => expressions::lambda_variable(name, Arc::clone(field)), other => { not_impl_err!("Physical plan does not support logical expression {other:?}") } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 36ee0f6beabd7..d16f6189ba33a 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -328,14 +328,8 @@ impl SqlToRel<'_, S> { } // User-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { - let (args, arg_names): (Vec<_>, Vec<_>) = args - .into_iter() - .map(|a| { - self.sql_fn_arg_to_logical_expr_with_name(a, schema, planner_context) - }) - .collect::>>()? 
- .into_iter() - .unzip(); + let (args, arg_names) = + self.function_args_to_expr_with_names(args, schema, planner_context)?; let resolved_args = if arg_names.iter().any(|name| name.is_some()) { if let Some(param_names) = &fm.signature().parameter_names { @@ -380,11 +374,11 @@ impl SqlToRel<'_, S> { // LambdaUDF::lambdas_parameters to then plan the lambda arguments with // resolved lambda variables enum ExprOrLambda { - Expr((Expr, Option)), - Lambda((sqlparser::ast::LambdaFunction, Option)), + Expr(Expr), + Lambda(sqlparser::ast::LambdaFunction), } - let pairs = args + let partially_planned = args .into_iter() .map(|a| match a { FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Lambda( @@ -397,46 +391,20 @@ impl SqlToRel<'_, S> { ); } - Ok(ExprOrLambda::Lambda((lambda, None))) - } - FunctionArg::Named { - name, - arg: FunctionArgExpr::Expr(SQLExpr::Lambda(lambda)), - operator: _, - } - | FunctionArg::ExprNamed { - name: SQLExpr::Identifier(name), - arg: FunctionArgExpr::Expr(SQLExpr::Lambda(lambda)), - operator: _, - } => { - if !all_unique(&lambda.params) { - return plan_err!( - "lambda parameters names must be unique, got {}", - lambda.params - ); - } - - let arg_name = ArgumentName { - value: name.value, - is_quoted: name.quote_style.is_some(), - }; - - Ok(ExprOrLambda::Lambda((lambda, Some(arg_name)))) + Ok(ExprOrLambda::Lambda(lambda)) } - _ => Ok(ExprOrLambda::Expr( - self.sql_fn_arg_to_logical_expr_with_name( - a, - schema, - planner_context, - )?, - )), + _ => Ok(ExprOrLambda::Expr(self.sql_fn_arg_to_logical_expr( + a, + schema, + planner_context, + )?)), }) .collect::>>()?; - let current_fields = pairs + let current_fields = partially_planned .iter() .map(|e| match e { - ExprOrLambda::Expr((expr, _name)) => { + ExprOrLambda::Expr(expr) => { Ok(ValueOrLambda::Value(expr.to_field(schema)?.1)) } ExprOrLambda::Lambda(_lambda_function) => { @@ -449,12 +417,12 @@ impl SqlToRel<'_, S> { let lambdas_parameters = fm.lambdas_parameters(&coerced)?; - let pairs = 
pairs + let args = partially_planned .into_iter() .zip(lambdas_parameters) .map(|(e, lambda_parameters)| match (e, lambda_parameters) { (ExprOrLambda::Expr(expr), None) => Ok(expr), - (ExprOrLambda::Lambda((lambda, name)), Some(lambda_params)) => { + (ExprOrLambda::Lambda(lambda), Some(lambda_params)) => { if lambda.params.len() > lambda_params.len() { return plan_err!( "lambda defined {} params but UDF support only {}", @@ -475,17 +443,14 @@ impl SqlToRel<'_, S> { .clone() .with_lambda_parameters(lambda_parameters); - Ok(( - Expr::Lambda(Lambda { - params, - body: Box::new(self.sql_expr_to_logical_expr( - *lambda.body, - schema, - &mut planner_context, - )?), - }), - name, - )) + Ok(Expr::Lambda(Lambda { + params, + body: Box::new(self.sql_expr_to_logical_expr( + *lambda.body, + schema, + &mut planner_context, + )?), + })) } (ExprOrLambda::Expr(_), Some(_)) => plan_err!( "{} reported parameters for an argument that is not a lambda", @@ -498,27 +463,7 @@ impl SqlToRel<'_, S> { }) .collect::>>()?; - let (args, arg_names): (Vec<_>, Vec<_>) = pairs.into_iter().unzip(); - - let resolved_args = if arg_names.iter().any(|name| name.is_some()) { - if let Some(param_names) = &fm.signature().parameter_names { - datafusion_expr::arguments::resolve_function_arguments( - param_names, - args, - arg_names, - )? 
- } else { - return plan_err!( - "Function '{}' does not support named arguments", - fm.name() - ); - } - } else { - args - }; - - // After resolution, all arguments are positional - let inner = LambdaFunction::new(fm, resolved_args); + let inner = LambdaFunction::new(fm, args); if name.eq_ignore_ascii_case(inner.name()) { return Ok(Expr::LambdaFunction(inner)); diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 4ec00c13719ff..4ae0f51e9f993 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -65,7 +65,7 @@ impl SqlToRel<'_, S> { planner_context.lambdas_parameters().get(&normalize_ident) { let mut lambda_var = - LambdaVariable::new(normalize_ident, Some(Arc::clone(field))); + LambdaVariable::new(normalize_ident, Arc::clone(field)); if self.options.collect_spans && let Some(span) = Span::try_from_sqlparser_span(id_span) { diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 83801fa0f306f..a8da0234d3713 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -2057,7 +2057,16 @@ mod tests { ( Expr::LambdaFunction(LambdaFunction::new( Arc::new(DummyLambdaUDF), - vec![col("a"), lambda(["v"], -lambda_var("v"))], + vec![ + col("a"), + lambda( + ["v"], + -lambda_var( + "v", + Arc::new(Field::new("", DataType::Null, true)), + ), + ), + ], )), r#"dummy_udlf(a, (v) -> -v)"#, ), diff --git a/datafusion/sqllogictest/test_files/lambda.slt b/datafusion/sqllogictest/test_files/lambda.slt index b44ab3203eb9e..585a6a07f6b88 100644 --- a/datafusion/sqllogictest/test_files/lambda.slt +++ b/datafusion/sqllogictest/test_files/lambda.slt @@ -44,40 +44,6 @@ SELECT array_transform([1,2,3,4,5], v -> list_repeat("a", v)); ---- [[a], [a, a], [a, a, a], [a, a, a, a], [a, a, a, a, a]] - -# version without limit/offset of queries below -query ? 
-select array_transform(t.list, v -> v*2) from t order by t.number; ----- -[2, 100] -[8, 100] -[14, 100] - -# sliced lists -query ? -select array_transform(t.list, v -> v*2) from t order by t.number limit 1; ----- -[2, 100] - -query ? -select array_transform(t.list, v -> v*2) from t order by t.number offset 1; ----- -[8, 100] -[14, 100] - -query ? -select array_transform(t.list, v -> v*2) from t order by t.number limit 1 offset 1; ----- -[8, 100] - -# lambda that uses only captured column which also only appears within the lambda -query ? -SELECT array_transform([1, 2], v -> t.number) from t; ----- -[10, 10] -[40, 40] -[60, 60] - # return scalar query I? SELECT t.number, array_transform([1, 2], e1 -> 24) from t; @@ -86,50 +52,6 @@ SELECT t.number, array_transform([1, 2], e1 -> 24) from t; 40 [24, 24] 60 [24, 24] -# uses only the first parameter -query ? -SELECT array_transform([1, 2], (v, i) -> v+1) from t; ----- -[2, 3] -[2, 3] -[2, 3] - -# uses only the second parameter -query ? -SELECT array_transform([10, 20], (v, i) -> i) from t; ----- -[1, 2] -[1, 2] -[1, 2] - -# use only capture of a parent lambda variable -query ? -SELECT array_transform([[1, 2]], (a, i) -> array_transform(a, b -> i)) ----- -[[1, 1]] - -# use only capture of a parent lambda variable and a column -query ? -SELECT array_transform([[1, 2]], (a, i) -> array_transform(a, b -> i + t.number)) from t ----- -[[11, 11]] -[[41, 41]] -[[61, 61]] - -# use only capture of a parent lambda variable and own variable -query ? -SELECT array_transform([[1, 2]], (a, i) -> array_transform(a, b -> i + b)) ----- -[[2, 3]] - -# use capture of a column, a parent lambda variable and own variable -query ? -SELECT array_transform([[1, 2]], (a, i) -> array_transform(a, b -> i + b + t.number)) from t ----- -[[12, 13]] -[[42, 43]] -[[62, 63]] - # shadows parent lambda variable query ? 
SELECT array_transform([[1, 2]], a -> array_transform(a, a -> a+1)) @@ -138,7 +60,7 @@ SELECT array_transform([[1, 2]], a -> array_transform(a, a -> a+1)) # multiple nesting query ? -SELECT array_transform([[[1], [2], [3]]], (a, i) -> array_transform(a, (b, j) -> array_transform(b, (c, k) -> c + i + j + k))); +SELECT array_transform([[[1], [2], [3]]], a -> array_transform(a, b -> array_transform(b, c -> c*2))); ---- [[[4], [6], [8]]] @@ -152,14 +74,14 @@ SELECT number, array_transform([1, 2], number -> number+1) from t; # type coercion inside lambda body query ? -SELECT array_transform([1.0, 2.0], v -> v + t.number) from t; +SELECT array_transform([1.0, 2.0], v -> v + 3) from t; ---- [11.0, 12.0] [41.0, 42.0] [61.0, 62.0] query TT -EXPLAIN SELECT array_transform([1.0, 2.0], v -> v + t.number) from t; +EXPLAIN SELECT array_transform([1.0, 2.0], v -> v + 3) from t; ---- logical_plan 01)Projection: array_transform(List([1.0, 2.0]), (v) -> v + CAST(t.number AS Float64)) AS array_transform(make_array(Float64(1),Float64(2)),(v) -> v + t.number) @@ -168,19 +90,6 @@ physical_plan 01)ProjectionExec: expr=[array_transform([1.0, 2.0], (v) -> v@ + CAST(number@0 AS Float64)) as array_transform(make_array(Float64(1),Float64(2)),(v) -> v + t.number)] 02)--DataSourceExec: partitions=1, partition_sizes=[1] -#cse -query TT -explain select t.number*2, array_transform([1], v -> v + t.number*2) from t; ----- -logical_plan -01)Projection: __common_expr_1 AS t.number * Int64(2), array_transform(List([1]), (v) -> v + __common_expr_1) AS array_transform(make_array(Int64(1)),(v) -> v + t.number * Int64(2)) -02)--Projection: CAST(t.number AS Int64) * Int64(2) AS __common_expr_1 -03)----TableScan: t projection=[number] -physical_plan -01)ProjectionExec: expr=[__common_expr_1@0 as t.number * Int64(2), array_transform([1], (v) -> v@ + __common_expr_1@0) as array_transform(make_array(Int64(1)),(v) -> v + t.number * Int64(2))] -02)--ProjectionExec: expr=[CAST(number@0 AS Int64) * 2 as 
__common_expr_1] -03)----DataSourceExec: partitions=1, partition_sizes=[1] - #cse should not eliminate subtrees containing lambdas query TT explain select array_transform([t.number], v -> 5), array_transform([t.number+1], v -> 5) from t; From b3bdc48d554b9687a95f31021b825e6c3274e463 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Wed, 18 Mar 2026 09:30:06 -0300 Subject: [PATCH 21/47] fix removal of lambda features --- .../core/src/execution/session_state.rs | 30 ++++++------------- datafusion/expr/src/udlf.rs | 1 - 2 files changed, 9 insertions(+), 22 deletions(-) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index ed9ecc393cb0b..7389f82611dd8 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -1576,28 +1576,16 @@ impl SessionStateBuilder { if let Some(lambda_functions) = lambda_functions { for udlf in lambda_functions { - let config_options = state.config().options(); - match udlf.with_updated_config(config_options) { - Some(new_udf) => { - if let Err(err) = state.register_udlf(new_udf) { - debug!( - "Failed to re-register updated UDLF '{}': {}", - udlf.name(), - err - ); - } + match state.register_udlf(Arc::clone(&udlf)) { + Ok(Some(existing)) => { + debug!("Overwrote existing UDLF '{}'", existing.name()); + } + Ok(None) => { + debug!("Registered UDLF '{}'", udlf.name()); + } + Err(err) => { + debug!("Failed to register UDLF '{}': {}", udlf.name(), err); } - None => match state.register_udlf(Arc::clone(&udlf)) { - Ok(Some(existing)) => { - debug!("Overwrote existing UDLF '{}'", existing.name()); - } - Ok(None) => { - debug!("Registered UDLF '{}'", udlf.name()); - } - Err(err) => { - debug!("Failed to register UDLF '{}': {}", udlf.name(), err); - } - }, } } } diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index 820195dbc9971..6cad63d297429 100644 --- 
a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -18,7 +18,6 @@ //! [`LambdaUDF`]: Lambda User Defined Functions use crate::expr::schema_name_from_exprs_comma_separated_without_space; -use crate::simplify::{ExprSimplifyResult, SimplifyContext}; use crate::{ColumnarValue, Documentation, Expr}; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::{DataType, Field, FieldRef, Schema}; From 9728a2ec06d6a2ba6f55a17a2f8aff01a691e1a7 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Wed, 18 Mar 2026 19:19:22 -0300 Subject: [PATCH 22/47] fix typo --- datafusion/expr/src/udlf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index 6cad63d297429..bcd7f837fe843 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -202,7 +202,7 @@ impl LambdaArgument { } /// Evaluate this lambda - /// `args` should evalute to the value of each parameter + /// `args` should evaluate to the value of each parameter /// of the correspondent lambda returned in [LambdaUDF::lambdas_parameters]. pub fn evaluate( &self, From f724ef501d020774545d29d621f98d93ebdd11eb Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Wed, 18 Mar 2026 21:33:24 -0300 Subject: [PATCH 23/47] remove paste! from lambda macros --- .../functions-nested/src/macros_lambda.rs | 56 +++++++++---------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/datafusion/functions-nested/src/macros_lambda.rs b/datafusion/functions-nested/src/macros_lambda.rs index 2f426ba8ee9b3..cd8b4dbdfd263 100644 --- a/datafusion/functions-nested/src/macros_lambda.rs +++ b/datafusion/functions-nested/src/macros_lambda.rs @@ -50,33 +50,29 @@ macro_rules! 
make_udlf_expr_and_func { make_udlf_expr_and_func!($UDF, $EXPR_FN, $($arg)*, $DOC, $LAMBDA_UDF_FN, $UDF::new); }; ($UDF:ident, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $LAMBDA_UDF_FN:ident, $CTOR:path) => { - paste::paste! { - // "fluent expr_fn" style function - #[doc = $DOC] - pub fn $EXPR_FN($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { - datafusion_expr::Expr::LambdaFunction(datafusion_expr::expr::LambdaFunction::new( - $LAMBDA_UDF_FN(), - vec![$($arg),*], - )) - } - create_lambda!($UDF, $LAMBDA_UDF_FN, $CTOR); + // "fluent expr_fn" style function + #[doc = $DOC] + pub fn $EXPR_FN($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { + datafusion_expr::Expr::LambdaFunction(datafusion_expr::expr::LambdaFunction::new( + $LAMBDA_UDF_FN(), + vec![$($arg),*], + )) } + create_lambda!($UDF, $LAMBDA_UDF_FN, $CTOR); }; ($UDF:ident, $EXPR_FN:ident, $DOC:expr, $LAMBDA_UDF_FN:ident) => { make_udlf_expr_and_func!($UDF, $EXPR_FN, $DOC, $LAMBDA_UDF_FN, $UDF::new); }; ($UDF:ident, $EXPR_FN:ident, $DOC:expr, $LAMBDA_UDF_FN:ident, $CTOR:path) => { - paste::paste! { - // "fluent expr_fn" style function - #[doc = $DOC] - pub fn $EXPR_FN(arg: Vec) -> datafusion_expr::Expr { - datafusion_expr::Expr::LambdaFunction(datafusion_expr::expr::LambdaFunction::new( - $LAMBDA_UDF_FN(), - arg, - )) - } - create_lambda!($UDF, $LAMBDA_UDF_FN, $CTOR); + // "fluent expr_fn" style function + #[doc = $DOC] + pub fn $EXPR_FN(arg: Vec) -> datafusion_expr::Expr { + datafusion_expr::Expr::LambdaFunction(datafusion_expr::expr::LambdaFunction::new( + $LAMBDA_UDF_FN(), + arg, + )) } + create_lambda!($UDF, $LAMBDA_UDF_FN, $CTOR); }; } @@ -97,17 +93,15 @@ macro_rules! create_lambda { create_lambda!($UDF, $LAMBDA_UDF_FN, $UDF::new); }; ($UDF:ident, $LAMBDA_UDF_FN:ident, $CTOR:path) => { - paste::paste! 
{ - #[doc = concat!("LambdaFunction that returns a [`LambdaUDF`](datafusion_expr::LambdaUDF) for ")] - #[doc = stringify!($UDF)] - pub fn $LAMBDA_UDF_FN() -> std::sync::Arc { - // Singleton instance of [`$UDF`], ensures the UDF is only created once - static INSTANCE: std::sync::LazyLock> = - std::sync::LazyLock::new(|| { - std::sync::Arc::new($CTOR()) - }); - std::sync::Arc::clone(&INSTANCE) - } + #[doc = concat!("LambdaFunction that returns a [`LambdaUDF`](datafusion_expr::LambdaUDF) for ")] + #[doc = stringify!($UDF)] + pub fn $LAMBDA_UDF_FN() -> std::sync::Arc { + // Singleton instance of [`$UDF`], ensures the UDF is only created once + static INSTANCE: std::sync::LazyLock> = + std::sync::LazyLock::new(|| { + std::sync::Arc::new($CTOR()) + }); + std::sync::Arc::clone(&INSTANCE) } }; } From 5380884d42133a65fd943a928b2e05d2cdf54d0b Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Wed, 18 Mar 2026 21:41:38 -0300 Subject: [PATCH 24/47] fix lambda sqllogictests --- datafusion/sqllogictest/test_files/lambda.slt | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/datafusion/sqllogictest/test_files/lambda.slt b/datafusion/sqllogictest/test_files/lambda.slt index 585a6a07f6b88..08d34ae9bd391 100644 --- a/datafusion/sqllogictest/test_files/lambda.slt +++ b/datafusion/sqllogictest/test_files/lambda.slt @@ -62,7 +62,7 @@ SELECT array_transform([[1, 2]], a -> array_transform(a, a -> a+1)) query ? SELECT array_transform([[[1], [2], [3]]], a -> array_transform(a, b -> array_transform(b, c -> c*2))); ---- -[[[4], [6], [8]]] +[[[2], [4], [6]]] # parameter shadows unqualified column query I? @@ -74,20 +74,20 @@ SELECT number, array_transform([1, 2], number -> number+1) from t; # type coercion inside lambda body query ? 
-SELECT array_transform([1.0, 2.0], v -> v + 3) from t; +SELECT array_transform([t.number], v -> v + 3.0) from t; ---- -[11.0, 12.0] -[41.0, 42.0] -[61.0, 62.0] +[13.0] +[43.0] +[63.0] query TT -EXPLAIN SELECT array_transform([1.0, 2.0], v -> v + 3) from t; +EXPLAIN SELECT array_transform([t.number], v -> v + 3.0) from t; ---- logical_plan -01)Projection: array_transform(List([1.0, 2.0]), (v) -> v + CAST(t.number AS Float64)) AS array_transform(make_array(Float64(1),Float64(2)),(v) -> v + t.number) +01)Projection: array_transform(make_array(t.number), (v) -> CAST(v AS Float64) + Float64(3)) 02)--TableScan: t projection=[number] physical_plan -01)ProjectionExec: expr=[array_transform([1.0, 2.0], (v) -> v@ + CAST(number@0 AS Float64)) as array_transform(make_array(Float64(1),Float64(2)),(v) -> v + t.number)] +01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> CAST(v@ AS Float64) + 3) as array_transform(make_array(t.number),(v) -> v + Float64(3))] 02)--DataSourceExec: partitions=1, partition_sizes=[1] #cse should not eliminate subtrees containing lambdas @@ -103,14 +103,16 @@ physical_plan #cse should not eliminate subtrees containing lambda variables query TT -explain select array_transform([t.number], v -> v*2), array_transform([t.number+1], (v, i) -> v*2) from t; +explain select array_transform([t.number], v -> v*2), array_transform([t.number], v -> v*2-1) from t; ---- logical_plan -01)Projection: array_transform(make_array(t.number), (v) -> CAST(v AS Int64) * Int64(2)), array_transform(make_array(CAST(t.number AS Int64) + Int64(1)), (v, i) -> v * Int64(2)) -02)--TableScan: t projection=[number] +01)Projection: array_transform(__common_expr_1 AS make_array(t.number), (v) -> CAST(v AS Int64) * Int64(2)), array_transform(__common_expr_1 AS make_array(t.number), (v) -> CAST(v AS Int64) * Int64(2) - Int64(1)) +02)--Projection: make_array(t.number) AS __common_expr_1 +03)----TableScan: t projection=[number] physical_plan -01)ProjectionExec: 
expr=[array_transform(make_array(number@0), (v) -> CAST(v@ AS Int64) * 2) as array_transform(make_array(t.number),(v) -> v * Int64(2)), array_transform(make_array(CAST(number@0 AS Int64) + 1), (v, i) -> v@ * 2) as array_transform(make_array(t.number + Int64(1)),(v, i) -> v * Int64(2))] -02)--DataSourceExec: partitions=1, partition_sizes=[1] +01)ProjectionExec: expr=[array_transform(__common_expr_1@0, (v) -> CAST(v@ AS Int64) * 2) as array_transform(make_array(t.number),(v) -> v * Int64(2)), array_transform(__common_expr_1@0, (v) -> CAST(v@ AS Int64) * 2 - 1) as array_transform(make_array(t.number),(v) -> v * Int64(2) - Int64(1))] +02)--ProjectionExec: expr=[make_array(number@0) as __common_expr_1] +03)----DataSourceExec: partitions=1, partition_sizes=[1] # test that sql planner plans resolved lambda variables, as v[1] planning checks the datatype of lhs @@ -163,7 +165,7 @@ select array_transform(1, v -> v*2); query error DataFusion error: Error during planning: array_transform expects a value followed by a lambda, got Lambda\(\(\)\) and Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ data_type: Int64, nullable: true \}\), nullable: true \}\) select array_transform(v -> v*2, [1, 2]); -query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 2 +query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 1 SELECT array_transform([1, 2], (e, i, j) -> i); query error DataFusion error: Error during planning: lambda parameters names must be unique, got \(v, v\) From d75dfe3cad6560a1b105dbf013b33e19d72a469b Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 23 Mar 2026 12:59:48 -0300 Subject: [PATCH 25/47] improve Expr::Lambda docs --- datafusion/expr/src/expr.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index e2b3643235933..54146a3c37fb9 
100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -407,6 +407,18 @@ pub enum Expr { /// Unnest expression Unnest(Unnest), /// Call a lambda function with a set of arguments. + /// + /// For example, `array_transform([1,2,3], v -> v+1)` would be equivalent to: + /// + /// ```ignore + /// LambdaFunction(array_transform) + /// ├── args[0]: Literal([1,2,3]) + /// └── args[1]: Lambda + /// ├── params: ["v"] + /// └── body: BinaryExpr(+) + /// ├── LambdaVariable("v") + /// └── Literal(1) + /// ``` LambdaFunction(LambdaFunction), /// A Lambda expression with a set of parameters names and a body Lambda(Lambda), From a241a5148830ae7e374e3add89d206d8f48a08b0 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 23 Mar 2026 13:00:09 -0300 Subject: [PATCH 26/47] add clarifying comment on lambda type coercion --- datafusion/expr/src/type_coercion/functions.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index a3d4a4a69c0d0..c5161ce3a3638 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -184,6 +184,14 @@ pub fn value_fields_with_lambda_udf( ); } + + // coerced_types has been partitioned from current_fields + // and refers only to values and not to lambdas, so instead + // of zipping them, we iterate over current_fields and only + // consume from coerced_types when a given argument is a value + // to reconstruct the arguments list with the correct order + // this supports any value and lambda positioning including + // multiple lambdas interleaved with values let mut coerced_types = coerced_types.into_iter(); Ok(current_fields From 547c1488b37a929c55ed6adeb625a6fba35e341c Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 23 Mar 2026 13:00:26 -0300 Subject: [PATCH 27/47] simplify lambda type coercion --- 
.../optimizer/src/analyzer/type_coercion.rs | 29 +++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 38edcd3524391..8da23892e4744 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -758,25 +758,18 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { let new_fields = value_fields_with_lambda_udf(¤t_fields, func.as_ref())?; - let transformed = current_fields != new_fields; - - let new_args = if transformed { - std::iter::zip(args, new_fields) - .map(|(arg, new_field)| match (&arg, new_field) { - (Expr::Lambda(_lambda), ValueOrLambda::Lambda(_)) => Ok(arg), - (Expr::Lambda(_lambda), ValueOrLambda::Value(_)) => plan_err!("value_fields_with_lambda_udf return a value for a lambda argument"), - (_, ValueOrLambda::Value(new_field)) => arg.cast_to(new_field.data_type(), self.schema), - (_, ValueOrLambda::Lambda(_)) => plan_err!("value_fields_with_lambda_udf return a lambda for a value argument"), - }) - .collect::>()? 
- } else { - args - }; + let new_args = std::iter::zip(args, new_fields) + .map(|(arg, new_field)| match (&arg, new_field) { + (Expr::Lambda(_lambda), ValueOrLambda::Lambda(_)) => Ok(arg), + (Expr::Lambda(_lambda), ValueOrLambda::Value(_)) => plan_err!("value_fields_with_lambda_udf return a value for a lambda argument"), + (_, ValueOrLambda::Value(new_field)) => arg.cast_to(new_field.data_type(), self.schema), + (_, ValueOrLambda::Lambda(_)) => plan_err!("value_fields_with_lambda_udf return a lambda for a value argument"), + }) + .collect::>()?; - Ok(Transformed::new_transformed( - Expr::LambdaFunction(LambdaFunction::new(func, new_args)), - transformed, - )) + Ok(Transformed::yes(Expr::LambdaFunction( + LambdaFunction::new(func, new_args) + ))) } // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] From 6ae73cb4a569d6693d37b601cca38a4c689722df Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Tue, 24 Mar 2026 12:32:26 -0300 Subject: [PATCH 28/47] handle null values in array_transform --- Cargo.lock | 1 + datafusion/common/src/utils/mod.rs | 136 +++++++- datafusion/functions-nested/Cargo.toml | 2 + .../functions-nested/src/array_transform.rs | 329 +++++++++++++++++- 4 files changed, 460 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8101e2fe22d92..bcce6b218f440 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2319,6 +2319,7 @@ dependencies = [ "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", "datafusion-macros", + "datafusion-physical-expr", "datafusion-physical-expr-common", "hashbrown 0.16.1", "itertools 0.14.0", diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index ee08b70d1597c..f0bb6a18903d9 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -26,6 +26,7 @@ pub mod string_utils; use crate::assert_or_internal_err; use crate::error::{_exec_datafusion_err, _exec_err, 
_internal_datafusion_err}; use crate::{Result, ScalarValue}; +use arrow::array::GenericListArray; use arrow::array::{ Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, cast::AsArray, @@ -977,8 +978,8 @@ pub fn take_function_args( /// you must adjust the offsets using [`adjust_offsets_for_slice`] pub fn list_values(array: &dyn Array) -> Result { match array.data_type() { - DataType::List(_) => Ok(Arc::clone(array.as_list::().values())), - DataType::LargeList(_) => Ok(Arc::clone(array.as_list::().values())), + DataType::List(_) => Ok(sliced_list_values(array.as_list::())), + DataType::LargeList(_) => Ok(sliced_list_values(array.as_list::())), DataType::FixedSizeList(_, _) => { Ok(Arc::clone(array.as_fixed_size_list().values())) } @@ -986,11 +987,49 @@ pub fn list_values(array: &dyn Array) -> Result { } } +fn sliced_list_values(list: &GenericListArray) -> ArrayRef { + let values = list.values(); + let offsets = list.offsets(); + + if let (Some(first), Some(last)) = (offsets.first(), offsets.last()) { + let first = first.to_usize().unwrap(); + let last = last.to_usize().unwrap(); + + if first != 0 || last != values.len() { + return values.slice(first, last - first); + } + } + + Arc::clone(values) +} + +/// If `list` is sliced, returns an adjusted offset buffer so that +/// it points to the sliced portion of the list values, and not the whole list values +pub fn adjust_offsets_for_slice( + list: &GenericListArray, +) -> OffsetBuffer { + let offsets = list.offsets(); + + if let (Some(first), Some(last)) = (offsets.first(), offsets.last()) + && (!first.is_zero() || last.to_usize().unwrap() != list.values().len()) + { + let offsets = offsets.iter().map(|offset| *offset - *first).collect(); + + //todo: use unsafe Offset::new_unchecked? 
+ return OffsetBuffer::new(offsets); + } + + offsets.clone() +} + #[cfg(test)] mod tests { use super::*; use crate::ScalarValue::Null; - use arrow::array::Float64Array; + use arrow::{ + array::{Float64Array, Int32Array}, + datatypes::Int32Type, + }; use sqlparser::ast::Ident; #[test] @@ -1291,4 +1330,95 @@ mod tests { assert_eq!(expected, transposed); Ok(()) } + + #[test] + fn test_sliced_list_values() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ]; + + let list = ListArray::from_iter_primitive::(data); + + assert_eq!( + sliced_list_values(&list).as_primitive(), + &Int32Array::from(vec![ + Some(0), + Some(1), + Some(2), + Some(3), + None, + Some(5), + Some(6), + Some(7) + ]) + ); + + assert_eq!( + sliced_list_values(&list.slice(0, 1)).as_primitive(), + &Int32Array::from(vec![Some(0), Some(1), Some(2)]) + ); + + assert_eq!( + sliced_list_values(&list.slice(2, 1)).as_primitive(), + &Int32Array::from(vec![Some(3), None, Some(5)]) + ); + + assert_eq!( + sliced_list_values(&list.slice(3, 1)).as_primitive(), + &Int32Array::from(vec![Some(6), Some(7)]) + ); + + assert!(sliced_list_values(&list.slice(0, 0)).is_empty()); + assert!(sliced_list_values(&list.slice(1, 0)).is_empty()); + assert!(sliced_list_values(&list.slice(3, 0)).is_empty()); + } + + #[test] + fn test_adjust_offsets() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ]; + let list = ListArray::from_iter_primitive::(data); + + assert_eq!( + adjust_offsets_for_slice(&list), + OffsetBuffer::from_lengths([3, 0, 3, 2]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(0, 1)), + OffsetBuffer::from_lengths([3]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(1, 2)), + OffsetBuffer::from_lengths([0, 3]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(1, 3)), + OffsetBuffer::from_lengths([0, 3, 2]) 
+ ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(0, 0)), + OffsetBuffer::from_lengths([]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(1, 0)), + OffsetBuffer::from_lengths([]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(3, 0)), + OffsetBuffer::from_lengths([]) + ); + } } diff --git a/datafusion/functions-nested/Cargo.toml b/datafusion/functions-nested/Cargo.toml index 5fce3e854eb33..9877aa03c34f1 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -65,6 +65,8 @@ log = { workspace = true } [dev-dependencies] criterion = { workspace = true, features = ["async_tokio"] } rand = { workspace = true } +# used to test array_transform +datafusion-physical-expr = { workspace = true } [[bench]] harness = false diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 446c9b9a45f05..f336ed9aff543 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -18,12 +18,17 @@ //! [`LambdaUDF`] definitions for array_transform function. 
use arrow::{ - array::{Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray}, + array::{ + Array, ArrayRef, AsArray, FixedSizeListArray, GenericListArray, LargeListArray, + ListArray, OffsetSizeTrait, UInt64Array, new_null_array, + }, + buffer::OffsetBuffer, + compute::take, datatypes::{DataType, Field, FieldRef}, }; use datafusion_common::{ - Result, exec_err, plan_err, - utils::{list_values, take_function_args}, + Result, exec_err, internal_datafusion_err, plan_err, + utils::{adjust_offsets_for_slice, list_values, take_function_args}, }; use datafusion_expr::{ ColumnarValue, Documentation, LambdaFunctionArgs, LambdaReturnFieldArgs, @@ -173,6 +178,20 @@ impl LambdaUDF for ArrayTransform { let (list, lambda) = value_lambda_pair(self.name(), &args.args)?; let list_array = list.to_array(args.number_rows)?; + + // Fast path for fully null input array and also the only way to safely work with + // a fully null fixed size list array as it can't be handled by remove_list_null_values below + if list_array.null_count() == list_array.len() { + return Ok(ColumnarValue::Array(new_null_array(args.return_type(), list_array.len()))) + } + + // null sublists may contain values that cause problems, like a 0 used on a division + // use remove_list_null_values to remove them + let list_array = remove_list_null_values(&list_array)?; + + // as per list_values docs, if list_array is sliced, list_values will be sliced too, + // so before constructing the transformed array below, we must adjust the list offsets with + // adjust_offsets_for_slice let list_values = list_values(&list_array)?; // by passing closures, lambda.evaluate can evaluate only those actually needed @@ -200,9 +219,13 @@ impl LambdaUDF for ArrayTransform { DataType::List(_) => { let list = list_array.as_list(); + // since we called list_values above which would return sliced values for + // a sliced list, we must adjust the offsets here as otherwise they would be invalid + let adjusted_offsets = 
adjust_offsets_for_slice(list); + Arc::new(ListArray::new( field, - list.offsets().clone(), + adjusted_offsets, transformed_values, list.nulls().cloned(), )) as ArrayRef @@ -210,9 +233,13 @@ impl LambdaUDF for ArrayTransform { DataType::LargeList(_) => { let large_list = list_array.as_list(); + // since we called list_values above which would return sliced values for + // a sliced list, we must adjust the offsets here as otherwise they would be invalid + let adjusted_offsets = adjust_offsets_for_slice(large_list); + Arc::new(LargeListArray::new( field, - large_list.offsets().clone(), + adjusted_offsets, transformed_values, large_list.nulls().cloned(), )) @@ -251,3 +278,295 @@ fn value_lambda_pair<'a, V: Debug, L: Debug>( Ok((value, lambda)) } + +//todo: make this function public and move to a more generic crate like datafusion-common +fn remove_list_null_values(list: &dyn Array) -> Result { + match list.data_type() { + DataType::List(_) => { + Ok(Arc::new(truncate_nulls(list.as_list::())?)) + } + DataType::LargeList(_) => { + Ok(Arc::new(truncate_nulls(list.as_list::())?)) + } + DataType::FixedSizeList(_, _) => { + Ok(Arc::new(replace_nulls_with_valid(list.as_fixed_size_list())?)) + } + dt => exec_err!("expected list, got {dt}"), + } +} + +fn replace_nulls_with_valid(list: &FixedSizeListArray) -> Result { + if let Some(nulls) = list.nulls() { + let null_count = list.null_count(); + + if null_count > 0 { + if null_count == list.len() { + return exec_err!("no valid value to use"); + } + + let first_valid = nulls + .inner() + .set_indices() + .next() + .ok_or_else(|| internal_datafusion_err!("fixed size list should have been checked to contain at least one valid value"))? 
as u64; + + let mut indices = Vec::with_capacity(list.values().len()); + + let size = list.value_length() as u64; + + for (i, is_valid) in nulls.iter().enumerate() { + let range = if is_valid { + let i = i as u64; + + i * size..(i + 1) * size + } else { + first_valid * size..(first_valid + 1) * size + }; + + indices.extend(range) + } + + let indices = UInt64Array::from(indices); + let values = take(list.values(), &indices, None)?; + + let DataType::FixedSizeList(field, size) = list.data_type() else { + unreachable!() + }; + + return Ok(FixedSizeListArray::try_new( + Arc::clone(field), + *size, + values, + list.nulls().cloned(), + )?); + } + } + + Ok(list.clone()) +} + +fn truncate_nulls( + list: &GenericListArray, +) -> Result> { + if let Some(nulls) = list.nulls() { + if list.null_count() > 0 { + let contains_null_and_non_empty = + std::iter::zip(list.offsets().lengths(), nulls) + .any(|(len, is_valid)| len > 0 && !is_valid); + + if contains_null_and_non_empty { + let mut indices = Vec::with_capacity(list.values().len()); + + let lengths = list.offsets().windows(2).enumerate().map(|(i, window)| { + let start = window[0].as_usize(); + let end = window[1].as_usize(); + + if list.is_valid(i) { + indices.extend((start..end).map(|i| i as u64)); + + end - start + } else { + 0 + } + }); + + let offsets = OffsetBuffer::from_lengths(lengths); + let indices = UInt64Array::from(indices); + let values = take(list.values(), &indices, None)?; + + let field = match list.data_type() { + DataType::List(field) => field, + DataType::LargeList(field) => field, + _ => unreachable!(), + }; + + return Ok(GenericListArray::try_new( + Arc::clone(field), + offsets, + values, + list.nulls().cloned(), + )?); + } + } + } + + Ok(list.clone()) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::{ + array::{Array, ArrayRef, AsArray, FixedSizeListArray, Int32Array, ListArray}, + buffer::{NullBuffer, OffsetBuffer}, + datatypes::{DataType, Field, FieldRef}, + }; + use 
datafusion_common::{DFSchema, Result, config::ConfigOptions}; + use datafusion_expr::{ + LambdaArgument, LambdaFunctionArgs, ValueOrLambda, + execution_props::ExecutionProps, lambda_var, lit, + }; + use datafusion_physical_expr::create_physical_expr; + + use crate::array_transform::array_transform_udlf; + + fn create_i32_list( + values: impl Into, + offsets: OffsetBuffer, + nulls: Option, + ) -> ListArray { + let list_field = Arc::new(Field::new_list_field(DataType::Int32, true)); + + ListArray::new(list_field, offsets, Arc::new(values.into()), nulls) + } + + fn create_i32_fsl( + size: i32, + values: Vec, + nulls: Option, + ) -> FixedSizeListArray { + FixedSizeListArray::new( + Arc::new(Field::new_list_field(DataType::Int32, true)), + size, + Arc::new(Int32Array::from(values)), + nulls, + ) + } + + fn int32_field() -> FieldRef { + Arc::new(Field::new("", DataType::Int32, true)) + } + + fn divide_100_by(list: impl Array + Clone + 'static) -> Result { + let array_transform = array_transform_udlf(); + + let lambda = create_physical_expr( + &(lit(100i32) / lambda_var("v", int32_field())), + &DFSchema::empty(), + &ExecutionProps::new(), + )?; + + array_transform + .invoke_with_args(LambdaFunctionArgs { + args: vec![ + ValueOrLambda::Value(datafusion_expr::ColumnarValue::Array( + Arc::new(list.clone()), + )), + ValueOrLambda::Lambda(LambdaArgument::new( + vec![Arc::new(Field::new("v", DataType::Int32, true))], + lambda, + )), + ], + arg_fields: vec![ + ValueOrLambda::Value(Arc::new(Field::new( + "", + list.data_type().clone(), + list.is_nullable(), + ))), + ValueOrLambda::Lambda(int32_field()), + ], + number_rows: list.len(), + return_field: Arc::new(Field::new_list( + "", + Field::new_list_field(DataType::Int32, true), + list.is_nullable(), + )), + config_options: Arc::new(ConfigOptions::new()), + })? 
+ .into_array(list.len()) + } + + #[test] + fn transform_on_sliced_list_should_not_evaluate_on_unreachable_values() { + let list = create_i32_list( + vec![ + // Have 0 here so if the expression is called on data that it will fail + 0, 4, 100, 25, 20, 5, 2, 1, 10, + ], + OffsetBuffer::::from_lengths(vec![1, 3, 4, 1]), + None, + ) + .slice(1, 3); + + let res = divide_100_by(list).unwrap(); + + let actual_list = res.as_list::(); + + let expected_list = create_i32_list( + vec![25, 1, 4, 5, 20, 50, 100, 10], + OffsetBuffer::::from_lengths(vec![3, 4, 1]), + None, + ); + + assert_eq!(actual_list, &expected_list); + } + + #[test] + fn transform_on_sliced_fsl_should_not_evaluate_on_unreachable_values() { + let list = create_i32_fsl( + 3, + vec![ + // Have 0 here so if the expression is called on data that it will fail + 0, 4, 100, 25, 20, 5, 2, 1, 10, + ], + None, + ) + .slice(1, 2); + + let res = divide_100_by(list).unwrap(); + + let actual_list = res.as_fixed_size_list(); + + let expected_list = create_i32_fsl(3, vec![4, 5, 20, 50, 100, 10], None); + + assert_eq!(actual_list, &expected_list); + } + + #[test] + fn transform_function_should_not_be_evaluated_on_values_underlying_null() { + let list = create_i32_list( + // 0 here for one of the values behind null, so if it will be evaluated + // it will fail due to divide by 0 + vec![100, 20, 10, 0, 1, 2, 0, 1, 50], + OffsetBuffer::::from_lengths(vec![3, 4, 2]), + Some(NullBuffer::from(vec![true, false, true])), + ); + + let res = divide_100_by(list).unwrap(); + + let actual_list = res.as_list::(); + + let expected_list = create_i32_list( + vec![1, 5, 10, 100, 2], + OffsetBuffer::::from_lengths(vec![3, 0, 2]), + Some(NullBuffer::from(vec![true, false, true])), + ); + + assert_eq!(actual_list.data_type(), expected_list.data_type()); + assert_eq!(actual_list, &expected_list); + } + + #[test] + fn transform_function_should_not_be_evaluated_on_values_underlying_null_fsl() { + let list = create_i32_fsl( + 3, + // 0 here for one of 
the values behind null, so if it will be evaluated + // it will fail due to divide by 0 + vec![100, 20, 10, 0, 1, 2, 0, 1, 50], + Some(NullBuffer::from(vec![true, false, false])), + ); + + let res = divide_100_by(list).unwrap(); + + let actual_list = res.as_fixed_size_list(); + + let expected_list = create_i32_fsl( + 3, + vec![1, 5, 10, 1, 5, 10, 1, 5, 10], + Some(NullBuffer::from(vec![true, false, false])), + ); + + assert_eq!(actual_list, &expected_list); + } +} From 39db62ba532d7a1349c1c275c14b8db71a291b6d Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Tue, 24 Mar 2026 23:53:47 +0000 Subject: [PATCH 29/47] simplify LambdaUDF::lambdas_parameters --- datafusion/expr/src/expr.rs | 15 +-- datafusion/expr/src/udlf.rs | 30 +++--- .../functions-nested/src/array_transform.rs | 17 ++-- .../physical-expr/src/lambda_function.rs | 96 +++++++++++-------- datafusion/proto/src/bytes/mod.rs | 8 +- datafusion/sql/src/expr/function.rs | 56 ++++++++--- datafusion/sql/src/unparser/expr.rs | 4 +- datafusion/sqllogictest/test_files/lambda.slt | 2 +- 8 files changed, 136 insertions(+), 92 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 54146a3c37fb9..08add22f7330f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -447,10 +447,7 @@ impl LambdaFunction { /// Invokes the inner function [`LambdaUDF::lambdas_parameters`] /// using the arguments of this invocation - pub fn lambdas_parameters( - &self, - schema: &dyn ExprSchema, - ) -> Result>>> { + pub fn lambdas_parameters(&self, schema: &dyn ExprSchema) -> Result>> { let args = self .args .iter() @@ -460,9 +457,15 @@ impl LambdaFunction { }) .collect::>>()?; - let coerced = value_fields_with_lambda_udf(&args, self.func.as_ref())?; + let coerced_values = value_fields_with_lambda_udf(&args, self.func.as_ref())? 
+ .into_iter() + .filter_map(|arg| match arg { + ValueOrLambda::Value(value) => Some(value), + ValueOrLambda::Lambda(_lambda) => None, + }) + .collect::>(); - self.func.lambdas_parameters(&coerced) + self.func.lambdas_parameters(&coerced_values) } } diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index bcd7f837fe843..26a902c83f140 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -308,10 +308,10 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// [`Volatility`]: datafusion_expr_common::signature::Volatility fn signature(&self) -> &LambdaSignature; - /// Returns a list of the same size as args where each value is the logic below applied to value at the correspondent position in args: - /// - /// If it's a value, return None - /// If it's a lambda, return the list of all parameters that that lambda supports + /// Return the field of all the parameters supported by all the supported lambdas of this function + /// based on the field of the value arguments. 
If a lambda support multiple parameters, or if multiple + /// lambdas are supported and some are optional, all should be returned, + /// regardless of whether they are used on a particular invocation /// /// Example for array_transform: /// @@ -319,33 +319,27 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// /// ```ignore /// let lambdas_parameters = array_transform.lambdas_parameters(&[ - /// ValueOrLambdaParameter::Value(Field::new("", DataType::new_list(DataType::Float32, false)))]), // the Field of the literal `[2, 8]` - /// ValueOrLambdaParameter::Lambda, // A lambda - /// ]?; + /// Arc::new(Field::new("", DataType::new_list(DataType::Float32, false))), // the Field of the literal `[2, 8]` + /// ])?; /// /// assert_eq!( /// lambdas_parameters, /// vec![ - /// // it's a value, return None - /// None, - /// // it's a lambda, return it's supported parameters, regardless of how many are actually used - /// Some(vec![ + /// // the lambda supported parameters, regardless of how many are actually used + /// vec![ /// // the value being transformed /// Field::new("", DataType::Float32, false), /// // the 1-based index being transformed, not used on the example above, /// //but implementations doesn't need to care about it /// Field::new("", DataType::Int32, false), - /// ]) + /// ] /// ] /// ) /// ``` /// /// The implementation can assume that some other part of the code has coerced /// the actual argument types to match [`Self::signature`]. - fn lambdas_parameters( - &self, - args: &[ValueOrLambda], - ) -> Result>>>; + fn lambdas_parameters(&self, value_fields: &[FieldRef]) -> Result>>; /// What type will be returned by this function, given the arguments? 
/// @@ -486,8 +480,8 @@ mod tests { fn lambdas_parameters( &self, - _args: &[ValueOrLambda], - ) -> Result>>> { + _value_fields: &[FieldRef], + ) -> Result>> { unimplemented!() } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index f336ed9aff543..1f6a9fc47ca56 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -130,11 +130,16 @@ impl LambdaUDF for ArrayTransform { Ok(vec![coerced]) } - fn lambdas_parameters( - &self, - args: &[ValueOrLambda], - ) -> Result>>> { - let (list, _lambda) = value_lambda_pair(self.name(), args)?; + fn lambdas_parameters(&self, value_fields: &[FieldRef]) -> Result>> { + let list = if value_fields.len() == 1 { + &value_fields[0] + } else { + return plan_err!( + "{} function requires 1 value arguments, got {}", + self.name(), + value_fields.len() + ); + }; let field = match list.data_type() { DataType::List(field) => field, @@ -148,7 +153,7 @@ impl LambdaUDF for ArrayTransform { let value = Field::new("", field.data_type().clone(), field.is_nullable()) .with_metadata(field.metadata().clone()); - Ok(vec![None, Some(vec![value])]) + Ok(vec![vec![value]]) } fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result> { diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs index 756d21c748b21..9d1adc7a459f7 100644 --- a/datafusion/physical-expr/src/lambda_function.rs +++ b/datafusion/physical-expr/src/lambda_function.rs @@ -40,7 +40,9 @@ use crate::expressions::{LambdaExpr, Literal}; use arrow::array::{Array, RecordBatch}; use arrow::datatypes::{DataType, FieldRef, Schema}; use datafusion_common::config::{ConfigEntry, ConfigOptions}; -use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; +use datafusion_common::{ + Result, ScalarValue, exec_err, internal_datafusion_err, internal_err, +}; use 
datafusion_expr::type_coercion::functions::value_fields_with_lambda_udf; use datafusion_expr::{ ColumnarValue, LambdaArgument, LambdaFunctionArgs, LambdaReturnFieldArgs, LambdaUDF, @@ -249,47 +251,65 @@ impl PhysicalExpr for LambdaFunctionExpr { }) .collect::>>()?; - let args_metadata = arg_fields + let value_fields = arg_fields .iter() - .map(|field| match field { - ValueOrLambda::Value(field) => ValueOrLambda::Value(Arc::clone(field)), - ValueOrLambda::Lambda(_field) => ValueOrLambda::Lambda(()), + .filter_map(|field| match field { + ValueOrLambda::Value(field) => Some(Arc::clone(field)), + ValueOrLambda::Lambda(_field) => None, }) .collect::>(); - let params = self.fun().lambdas_parameters(&args_metadata)?; - - let args = std::iter::zip(&self.args, params) - .map(|(arg, lambda_params)| { - match (arg.as_any().downcast_ref::(), lambda_params) { - (Some(lambda), Some(lambda_params)) => { - if lambda.params().len() > lambda_params.len() { - return exec_err!( - "lambda defined {} params but UDF support only {}", - lambda.params().len(), - lambda_params.len() - ); - } - - let params = std::iter::zip(lambda.params(), lambda_params) - .map(|(name, param)| Arc::new(param.with_name(name))) - .collect(); - - Ok(ValueOrLambda::Lambda(LambdaArgument::new( - params, - Arc::clone(lambda.body()), - ))) + // lambdas_parameters refers only to lambdas and not to values, so instead + // of zipping it with self.args, we iterate over self.args and only + // consume from lambdas_parameters when a given argument is a lambda + // to reconstruct the arguments list with the correct order + // this supports any value and lambda positioning including + // multiple lambdas interleaved with values + let mut lambdas_parameters = + self.fun().lambdas_parameters(&value_fields)?.into_iter(); + let num_lambdas = self.args.len() - value_fields.len(); + + // functions can support multiple lambdas where some trailing ones are optional, + // but to simplify the implementor, lambdas_parameters returns 
the parameters of all of them, + // so we can't do equality check. one example is spark reduce: + // https://spark.apache.org/docs/latest/api/sql/index.html#reduce + if lambdas_parameters.len() < num_lambdas { + return exec_err!( + "{} invocation defined {num_lambdas} but lambdas_parameters returned only {}", + self.name(), + lambdas_parameters.len() + ); + } + + let args = self + .args + .iter() + .map(|arg| match arg.as_any().downcast_ref::() { + Some(lambda) => { + let lambda_params = lambdas_parameters.next().ok_or_else(|| { + internal_datafusion_err!( + "params len should have been checked above" + ) + })?; + + if lambda.params().len() > lambda_params.len() { + return exec_err!( + "lambda defined {} params but UDF support only {}", + lambda.params().len(), + lambda_params.len() + ); } - (Some(_lambda), None) => exec_err!( - "{} don't reported the parameters of one of it's lambdas", - self.fun.name() - ), - (None, Some(_lambda_params)) => exec_err!( - "{} reported parameters for an argument that is not a lambda", - self.fun.name() - ), - (None, None) => Ok(ValueOrLambda::Value(arg.evaluate(batch)?)), + + let params = std::iter::zip(lambda.params(), lambda_params) + .map(|(name, param)| Arc::new(param.with_name(name))) + .collect(); + + Ok(ValueOrLambda::Lambda(LambdaArgument::new( + params, + Arc::clone(lambda.body()), + ))) } + None => Ok(ValueOrLambda::Value(arg.evaluate(batch)?)), }) .collect::>>()?; @@ -400,8 +420,8 @@ mod tests { fn lambdas_parameters( &self, - _args: &[ValueOrLambda], - ) -> Result>>> { + _value_fields: &[FieldRef], + ) -> Result>> { unimplemented!() } diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index b0bd108ba4956..c314a6ae21244 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -146,12 +146,8 @@ impl Serializeable for Expr { fn lambdas_parameters( &self, - _args: &[datafusion_expr::ValueOrLambda< - arrow::datatypes::FieldRef, - (), - >], - ) -> Result>>> - { + 
_value_fields: &[arrow::datatypes::FieldRef], + ) -> Result>> { not_impl_err!("mock LambdaUDF") } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index d16f6189ba33a..1f7574f5a8eaf 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -413,16 +413,50 @@ impl SqlToRel<'_, S> { }) .collect::>>()?; - let coerced = value_fields_with_lambda_udf(¤t_fields, fm.as_ref())?; - - let lambdas_parameters = fm.lambdas_parameters(&coerced)?; + let coerced_values = + value_fields_with_lambda_udf(¤t_fields, fm.as_ref())? + .into_iter() + .filter_map(|arg| match arg { + ValueOrLambda::Value(value) => Some(value), + ValueOrLambda::Lambda(_lambda) => None, + }) + .collect::>(); + + // lambdas_parameters refers only to lambdas and not to values, so instead + // of zipping it with partially_planned, we iterate over partially_planned and only + // consume from lambdas_parameters when a given argument is a lambda + // to reconstruct the arguments list with the correct order + // this supports any value and lambda positioning including + // multiple lambdas interleaved with values + let mut lambdas_parameters = + fm.lambdas_parameters(&coerced_values)?.into_iter(); + + let num_lambdas = partially_planned.len() - coerced_values.len(); + + // functions can support multiple lambdas where some trailing ones are optional, + // but to simplify the implementor, lambdas_parameters returns the parameters of all of them, + // so we can't do equality check. 
one example is spark reduce: + // https://spark.apache.org/docs/latest/api/sql/index.html#reduce + if lambdas_parameters.len() < num_lambdas { + return plan_err!( + "{} invocation defined {num_lambdas} but lambdas_parameters returned only {}", + fm.name(), + lambdas_parameters.len() + ); + } let args = partially_planned .into_iter() - .zip(lambdas_parameters) - .map(|(e, lambda_parameters)| match (e, lambda_parameters) { - (ExprOrLambda::Expr(expr), None) => Ok(expr), - (ExprOrLambda::Lambda(lambda), Some(lambda_params)) => { + .map(|arg| match arg { + ExprOrLambda::Expr(expr) => Ok(expr), + ExprOrLambda::Lambda(lambda) => { + let lambda_params = + lambdas_parameters.next().ok_or_else(|| { + internal_datafusion_err!( + "lambdas_parameters len should have been checked above" + ) + })?; + if lambda.params.len() > lambda_params.len() { return plan_err!( "lambda defined {} params but UDF support only {}", @@ -452,14 +486,6 @@ impl SqlToRel<'_, S> { )?), })) } - (ExprOrLambda::Expr(_), Some(_)) => plan_err!( - "{} reported parameters for an argument that is not a lambda", - fm.name() - ), - (ExprOrLambda::Lambda(_), None) => plan_err!( - "{} don't reported the parameters of one of it's lambdas", - fm.name() - ), }) .collect::>>()?; diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 6c4d309dd8e16..1e5642646f0c0 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1949,8 +1949,8 @@ mod tests { fn lambdas_parameters( &self, - _args: &[datafusion_expr::ValueOrLambda], - ) -> Result>>> { + _value_fields: &[FieldRef], + ) -> Result>> { unimplemented!() } diff --git a/datafusion/sqllogictest/test_files/lambda.slt b/datafusion/sqllogictest/test_files/lambda.slt index 08d34ae9bd391..fae77cc29a7a7 100644 --- a/datafusion/sqllogictest/test_files/lambda.slt +++ b/datafusion/sqllogictest/test_files/lambda.slt @@ -162,7 +162,7 @@ DataFusion error: Error during planning: array_transform function 
requires 1 val query error DataFusion error: Error during planning: array_transform expected a list as first argument, got Int64 select array_transform(1, v -> v*2); -query error DataFusion error: Error during planning: array_transform expects a value followed by a lambda, got Lambda\(\(\)\) and Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ data_type: Int64, nullable: true \}\), nullable: true \}\) +query error DataFusion error: Error during planning: array_transform expects a value followed by a lambda, got Lambda\(Field \{ name: "\(v\) \-> v \* Int64\(2\)", data_type: Int64, nullable: true \}\) and Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ data_type: Int64, nullable: true \}\), nullable: true \}\) select array_transform(v -> v*2, [1, 2]); query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 1 From 7a7d371ae2a58e23c696980c7273f28b6ae49494 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Tue, 24 Mar 2026 23:55:03 +0000 Subject: [PATCH 30/47] cargo fmt --- .../expr/src/type_coercion/functions.rs | 1 - .../functions-nested/src/array_transform.rs | 19 +++++++++---------- .../optimizer/src/analyzer/type_coercion.rs | 6 +++--- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index c5161ce3a3638..703e416ba280c 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -184,7 +184,6 @@ pub fn value_fields_with_lambda_udf( ); } - // coerced_types has been partitioned from current_fields // and refers only to values and not to lambdas, so instead // of zipping them, we iterate over current_fields and only diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 
1f6a9fc47ca56..35a3a42299658 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -187,7 +187,10 @@ impl LambdaUDF for ArrayTransform { // Fast path for fully null input array and also the only way to safely work with // a fully null fixed size list array as it can't be handled by remove_list_null_values below if list_array.null_count() == list_array.len() { - return Ok(ColumnarValue::Array(new_null_array(args.return_type(), list_array.len()))) + return Ok(ColumnarValue::Array(new_null_array( + args.return_type(), + list_array.len(), + ))); } // null sublists may contain values that cause problems, like a 0 used on a division @@ -287,15 +290,11 @@ fn value_lambda_pair<'a, V: Debug, L: Debug>( //todo: make this function public and move to a more generic crate like datafusion-common fn remove_list_null_values(list: &dyn Array) -> Result { match list.data_type() { - DataType::List(_) => { - Ok(Arc::new(truncate_nulls(list.as_list::())?)) - } - DataType::LargeList(_) => { - Ok(Arc::new(truncate_nulls(list.as_list::())?)) - } - DataType::FixedSizeList(_, _) => { - Ok(Arc::new(replace_nulls_with_valid(list.as_fixed_size_list())?)) - } + DataType::List(_) => Ok(Arc::new(truncate_nulls(list.as_list::())?)), + DataType::LargeList(_) => Ok(Arc::new(truncate_nulls(list.as_list::())?)), + DataType::FixedSizeList(_, _) => Ok(Arc::new(replace_nulls_with_valid( + list.as_fixed_size_list(), + )?)), dt => exec_err!("expected list, got {dt}"), } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 8da23892e4744..c77fcb090d290 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -767,9 +767,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }) .collect::>()?; - Ok(Transformed::yes(Expr::LambdaFunction( - LambdaFunction::new(func, new_args) - ))) + 
Ok(Transformed::yes(Expr::LambdaFunction(LambdaFunction::new( + func, new_args, + )))) } // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] From 83fb18d6ccf286f977fb6d7a0d296cc6d4cef42b Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Wed, 25 Mar 2026 00:06:29 +0000 Subject: [PATCH 31/47] add tip on LambdaUDF::lambdas_parameters docs to LambdaFunction helper --- datafusion/expr/src/udlf.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index 26a902c83f140..c6481a2e3dc1f 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -313,6 +313,12 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// lambdas are supported and some are optional, all should be returned, /// regardless of whether they are used on a particular invocation /// + /// Tip: If you have a [`LambdaFunction`] invocation, you can call the helper + /// [`LambdaFunction::lambdas_parameters`] instead of this method directly + /// + /// [`LambdaFunction`]: crate::expr::LambdaFunction + /// [`LambdaFunction::lambdas_parameters`]: crate::expr::LambdaFunction::lambdas_parameters + /// /// Example for array_transform: /// /// `array_transform([2.0, 8.0], v -> v > 4.0)` From e76ff254b1d0f1e93f9d0d843cf4aab25e76e035 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Wed, 25 Mar 2026 02:15:28 +0000 Subject: [PATCH 32/47] minor fixes --- .../src/execution/session_state_defaults.rs | 5 +- datafusion/expr/src/expr.rs | 26 +------ datafusion/expr/src/udlf.rs | 2 - .../functions-nested/src/array_transform.rs | 75 +++++++++---------- datafusion/sql/src/unparser/expr.rs | 2 +- 5 files changed, 45 insertions(+), 65 deletions(-) diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index 9235839c9ea12..487565ead3093 100644 --- 
a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -115,7 +115,10 @@ impl SessionStateDefaults { /// returns the list of default [`LambdaUDF`]s pub fn default_lambda_functions() -> Vec> { #[cfg(feature = "nested_expressions")] - functions_nested::all_default_lambda_functions() + return functions_nested::all_default_lambda_functions(); + + #[cfg(not(feature = "nested_expressions"))] + return Vec::new(); } /// returns the list of default [`AggregateUDF`]s diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 08add22f7330f..8214fa4248c7f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -410,7 +410,7 @@ pub enum Expr { /// /// For example, `array_transform([1,2,3], v -> v+1)` would be equivalent to: /// - /// ```ignore + /// ```text /// LambdaFunction(array_transform) /// ├── args[0]: Literal([1,2,3]) /// └── args[1]: Lambda @@ -483,22 +483,7 @@ impl PartialEq for LambdaFunction { } /// A named reference to a lambda parameter which includes it's own [`FieldRef`], -/// which is used to implement [`ExprSchemable`], for example. It is an option only to make -/// easier for `expr_api` users to construct lambda variables, but any expression -/// tree or [`LogicalPlan`] containing unresolved variables must be resolved before -/// usage with either [`Expr::resolve_lambdas_variables`] or -/// [`LogicalPlan::resolve_lambdas_variables`]. The default SQL planner produces -/// already resolved variables and no further resolving is required. -/// -/// After resolving, if any non-lambda argument from the lambda function -/// which this variables originates from have it's type, nullability or -/// metadata changed, the resolved field may became outdated and must be -/// resolved again. 
-/// -/// [`LogicalPlan`]: crate::LogicalPlan -/// [`LogicalPlan::resolve_lambdas_variables`]: LogicalPlan::resolve_lambdas_variables -/// -// todo: if substrait come to produce resolved variables, cite it above too +/// which is used to implement [`ExprSchemable`], for example #[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] pub struct LambdaVariable { pub name: String, @@ -507,12 +492,7 @@ pub struct LambdaVariable { } impl LambdaVariable { - /// Create a lambda variable from a name and an optional Field. - /// If the field is none, the expression tree or LogicalPlan which - /// owns this variable must be resolved before usage with either - /// [`Expr::resolve_lambdas_variables`] or [`LogicalPlan::resolve_lambdas_variables`]. - /// - /// [`LogicalPlan::resolve_lambdas_variables`]: crate::LogicalPlan::resolve_lambdas_variables + /// Create a lambda variable from a name and a Field. pub fn new(name: String, field: FieldRef) -> Self { Self { name, diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index c6481a2e3dc1f..27c49df7699ad 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -365,8 +365,6 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// # struct Example{} /// # impl Example { /// fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result { - /// // report output is only nullable if any one of the arguments are nullable - /// let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true)); /// Ok(field) /// } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 35a3a42299658..d35b3b03e8d8f 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -352,45 +352,44 @@ fn replace_nulls_with_valid(list: &FixedSizeListArray) -> Result( list: &GenericListArray, ) -> Result> 
{ - if let Some(nulls) = list.nulls() { - if list.null_count() > 0 { - let contains_null_and_non_empty = - std::iter::zip(list.offsets().lengths(), nulls) - .any(|(len, is_valid)| len > 0 && !is_valid); - - if contains_null_and_non_empty { - let mut indices = Vec::with_capacity(list.values().len()); - - let lengths = list.offsets().windows(2).enumerate().map(|(i, window)| { - let start = window[0].as_usize(); - let end = window[1].as_usize(); - - if list.is_valid(i) { - indices.extend((start..end).map(|i| i as u64)); - - end - start - } else { - 0 - } - }); - - let offsets = OffsetBuffer::from_lengths(lengths); - let indices = UInt64Array::from(indices); - let values = take(list.values(), &indices, None)?; - - let field = match list.data_type() { - DataType::List(field) => field, - DataType::LargeList(field) => field, - _ => unreachable!(), - }; + if let Some(nulls) = list.nulls() + && nulls.null_count() > 0 + { + let contains_null_and_non_empty = std::iter::zip(list.offsets().lengths(), nulls) + .any(|(len, is_valid)| len > 0 && !is_valid); - return Ok(GenericListArray::try_new( - Arc::clone(field), - offsets, - values, - list.nulls().cloned(), - )?); - } + if contains_null_and_non_empty { + let mut indices = Vec::with_capacity(list.values().len()); + + let lengths = list.offsets().windows(2).enumerate().map(|(i, window)| { + let start = window[0].as_usize(); + let end = window[1].as_usize(); + + if list.is_valid(i) { + indices.extend((start..end).map(|i| i as u64)); + + end - start + } else { + 0 + } + }); + + let offsets = OffsetBuffer::from_lengths(lengths); + let indices = UInt64Array::from(indices); + let values = take(list.values(), &indices, None)?; + + let field = match list.data_type() { + DataType::List(field) => field, + DataType::LargeList(field) => field, + _ => unreachable!(), + }; + + return Ok(GenericListArray::try_new( + Arc::clone(field), + offsets, + values, + list.nulls().cloned(), + )?); } } diff --git a/datafusion/sql/src/unparser/expr.rs 
b/datafusion/sql/src/unparser/expr.rs index 1e5642646f0c0..5f3c97d7b1c2f 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -573,7 +573,7 @@ impl Unparser<'_> { params.iter().map(|param| param.as_str().into()).collect(), ), body: Box::new(self.expr_to_sql_inner(body)?), - syntax: ast::LambdaSyntax::LambdaKeyword, + syntax: ast::LambdaSyntax::Arrow, })) } Expr::LambdaVariable(l) => Ok(ast::Expr::Identifier( From 6c32ef8e1b1a7d8deef29b32792cc652d34629a8 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Wed, 25 Mar 2026 03:16:32 +0000 Subject: [PATCH 33/47] minor fixes --- datafusion/core/src/execution/session_state.rs | 1 - datafusion/spark/src/lib.rs | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index fa6f065065888..790086a37bd95 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -58,7 +58,6 @@ use datafusion_expr::planner::ExprPlanner; use datafusion_expr::planner::{RelationPlanner, TypePlanner}; use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; use datafusion_expr::simplify::SimplifyContext; -#[cfg(feature = "sql")] use datafusion_expr::{ AggregateUDF, Explain, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF, }; diff --git a/datafusion/spark/src/lib.rs b/datafusion/spark/src/lib.rs index 9575f560b8d0e..71ee9c5f5d2e8 100644 --- a/datafusion/spark/src/lib.rs +++ b/datafusion/spark/src/lib.rs @@ -43,7 +43,7 @@ //! //! ``` //! # use datafusion_execution::FunctionRegistry; -//! # use datafusion_expr::{ScalarUDF, AggregateUDF, WindowUDF}; +//! # use datafusion_expr::{ScalarUDF, AggregateUDF, WindowUDF, LambdaUDF}; //! # use datafusion_expr::planner::ExprPlanner; //! # use datafusion_common::Result; //! # use std::collections::HashSet; @@ -55,9 +55,11 @@ //! 
# impl FunctionRegistry for SessionContext { //! # fn register_udf(&mut self, _udf: Arc) -> Result>> { Ok (None) } //! # fn udfs(&self) -> HashSet { unimplemented!() } +//! # fn udlfs(&self) -> HashSet { unimplemented!() } //! # fn udafs(&self) -> HashSet { unimplemented!() } //! # fn udwfs(&self) -> HashSet { unimplemented!() } //! # fn udf(&self, _name: &str) -> Result> { unimplemented!() } +//! # fn udlf(&self, name: &str) -> Result> { unimplemented!() } //! # fn udaf(&self, name: &str) -> Result> {unimplemented!() } //! # fn udwf(&self, name: &str) -> Result> { unimplemented!() } //! # fn expr_planners(&self) -> Vec> { unimplemented!() } From 27a3e2474671ef9be5c86097f9d4a88d2f28db74 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Fri, 27 Mar 2026 20:16:59 +0000 Subject: [PATCH 34/47] simplify array_transform tests --- .../functions-nested/src/array_transform.rs | 80 +++++++++---------- 1 file changed, 39 insertions(+), 41 deletions(-) diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index d35b3b03e8d8f..160fa8ba6c4df 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -398,17 +398,20 @@ fn truncate_nulls( #[cfg(test)] mod tests { - use std::sync::Arc; + use std::{collections::HashMap, sync::Arc}; use arrow::{ - array::{Array, ArrayRef, AsArray, FixedSizeListArray, Int32Array, ListArray}, + array::{ + Array, ArrayRef, AsArray, FixedSizeListArray, Int32Array, ListArray, + RecordBatch, + }, buffer::{NullBuffer, OffsetBuffer}, - datatypes::{DataType, Field, FieldRef}, + datatypes::{DataType, Field}, }; - use datafusion_common::{DFSchema, Result, config::ConfigOptions}; + use datafusion_common::{DFSchema, Result}; use datafusion_expr::{ - LambdaArgument, LambdaFunctionArgs, ValueOrLambda, - execution_props::ExecutionProps, lambda_var, lit, + Expr, col, execution_props::ExecutionProps, 
expr::LambdaFunction, lambda, + lambda_var, lit, }; use datafusion_physical_expr::create_physical_expr; @@ -437,47 +440,42 @@ mod tests { ) } - fn int32_field() -> FieldRef { - Arc::new(Field::new("", DataType::Int32, true)) - } - fn divide_100_by(list: impl Array + Clone + 'static) -> Result { let array_transform = array_transform_udlf(); - let lambda = create_physical_expr( - &(lit(100i32) / lambda_var("v", int32_field())), - &DFSchema::empty(), - &ExecutionProps::new(), + let schema = DFSchema::from_unqualified_fields( + vec![Field::new( + "list", + list.data_type().clone(), + list.is_nullable(), + )] + .into(), + HashMap::new(), )?; - array_transform - .invoke_with_args(LambdaFunctionArgs { - args: vec![ - ValueOrLambda::Value(datafusion_expr::ColumnarValue::Array( - Arc::new(list.clone()), - )), - ValueOrLambda::Lambda(LambdaArgument::new( - vec![Arc::new(Field::new("v", DataType::Int32, true))], - lambda, - )), + create_physical_expr( + &Expr::LambdaFunction(LambdaFunction::new( + array_transform, + vec![ + col("list"), + lambda( + ["v"], + lit(100i32) + / lambda_var( + "v", + Arc::new(Field::new("v", DataType::Int32, true)), + ), + ), ], - arg_fields: vec![ - ValueOrLambda::Value(Arc::new(Field::new( - "", - list.data_type().clone(), - list.is_nullable(), - ))), - ValueOrLambda::Lambda(int32_field()), - ], - number_rows: list.len(), - return_field: Arc::new(Field::new_list( - "", - Field::new_list_field(DataType::Int32, true), - list.is_nullable(), - )), - config_options: Arc::new(ConfigOptions::new()), - })? - .into_array(list.len()) + )), + &schema, + &ExecutionProps::new(), + )? + .evaluate(&RecordBatch::try_new( + Arc::clone(schema.inner()), + vec![Arc::new(list.clone())], + )?)? 
+ .into_array(list.len()) } #[test] From 93e66f7c618cfab4acb8f85588d1727d6b4b7abf Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Fri, 27 Mar 2026 20:23:11 +0000 Subject: [PATCH 35/47] evaluate LambdaVariable by index instead of name --- .../src/expressions/lambda_variable.rs | 39 +++++++++-- datafusion/physical-expr/src/planner.rs | 70 ++++++++++++++++--- 2 files changed, 93 insertions(+), 16 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/lambda_variable.rs b/datafusion/physical-expr/src/expressions/lambda_variable.rs index 2072bf9bb1c1e..f4d4a294d491b 100644 --- a/datafusion/physical-expr/src/expressions/lambda_variable.rs +++ b/datafusion/physical-expr/src/expressions/lambda_variable.rs @@ -28,13 +28,14 @@ use arrow::{ record_batch::RecordBatch, }; -use datafusion_common::{Result, exec_err}; +use datafusion_common::{Result, internal_err, plan_err}; use datafusion_expr::ColumnarValue; /// Represents the lambda variable with a given name and field #[derive(Debug, Clone)] pub struct LambdaVariable { name: String, + index: usize, field: FieldRef, } @@ -55,8 +56,8 @@ impl Hash for LambdaVariable { impl LambdaVariable { /// Create a new lambda variable expression - pub fn new(name: String, field: FieldRef) -> Self { - Self { name, field } + pub fn new(name: String, index: usize, field: FieldRef) -> Self { + Self { name, index, field } } /// Get the variable's name @@ -90,10 +91,22 @@ impl PhysicalExpr for LambdaVariable { } fn evaluate(&self, batch: &RecordBatch) -> Result { - match batch.column_by_name(&self.name) { - Some(array) => Ok(ColumnarValue::Array(Arc::clone(array))), - None => exec_err!("LambdaVariable {} not present in batch", self.name), + if self.index >= batch.num_columns() { + return internal_err!( + "PhysicalExpr LambdaVariable references column '{}' at index {} (zero-based) but batch only has {} columns: {:?}", + self.name, + self.index, + batch.num_columns(), + batch + .schema_ref() + 
.fields() + .iter() + .map(|f| f.name()) + .collect::>() + ); } + + Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } fn return_field(&self, _input_schema: &Schema) -> Result { @@ -120,6 +133,18 @@ impl PhysicalExpr for LambdaVariable { pub fn lambda_variable( name: impl Into, field: FieldRef, + schema: &Schema, ) -> Result> { - Ok(Arc::new(LambdaVariable::new(name.into(), field))) + let name = name.into(); + let index = schema.index_of(&name)?; + + let schema_field = schema.field(index); + + if field.as_ref() != schema_field { + return plan_err!( + "LambdaVariable owned field differ from schema field {field} != {schema_field}" + ); + } + + Ok(Arc::new(LambdaVariable::new(name, index, field))) } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 5e391f08f131c..2885fcf398fe2 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::sync::Arc; use crate::{LambdaFunctionExpr, ScalarFunctionExpr}; @@ -27,7 +28,8 @@ use arrow::datatypes::Schema; use datafusion_common::config::ConfigOptions; use datafusion_common::metadata::{FieldMetadata, format_type_and_metadata}; use datafusion_common::{ - DFSchema, Result, ScalarValue, ToDFSchema, exec_err, not_impl_err, plan_err, + DFSchema, Result, ScalarValue, ToDFSchema, exec_err, internal_datafusion_err, + not_impl_err, plan_err, }; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{ @@ -415,9 +417,49 @@ pub fn create_physical_expr( Expr::Placeholder(Placeholder { id, .. 
}) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } - Expr::LambdaFunction(LambdaFunction { func, args }) => { - let physical_args = - create_physical_exprs(args, input_dfschema, execution_props)?; + Expr::LambdaFunction(invocation @ LambdaFunction { func, args }) => { + let num_lambdas = args + .iter() + .filter(|arg| matches!(arg, Expr::Lambda(_))) + .count(); + + let mut lambdas_parameters = + invocation.lambdas_parameters(input_dfschema)?.into_iter(); + + if num_lambdas > lambdas_parameters.len() { + return plan_err!( + "{} lambdas_parameters returned only {} values for {num_lambdas} lambdas", + func.name(), + lambdas_parameters.len() + ); + } + + let physical_args = args + .iter() + .map(|arg| match arg { + Expr::Lambda(lambda) => { + let lambda_parameters = lambdas_parameters + .next() + .ok_or_else(|| { + internal_datafusion_err!( + "lambdas_parameters len should have been checked above" + ) + })? + .into_iter() + .zip(&lambda.params) + .map(|(field, name)| field.with_name(name)) + .collect(); + + let lambda_schema = DFSchema::from_unqualified_fields( + lambda_parameters, + HashMap::new(), + )?; + + create_physical_expr(arg, &lambda_schema, execution_props) + } + _ => create_physical_expr(arg, input_dfschema, execution_props), + }) + .collect::>()?; let config_options = match execution_props.config_options.as_ref() { Some(config_options) => Arc::clone(config_options), @@ -431,15 +473,25 @@ pub fn create_physical_expr( config_options, )?)) } - Expr::Lambda(Lambda { params, body }) => expressions::lambda( - params, - create_physical_expr(body, input_dfschema, execution_props)?, - ), + Expr::Lambda(Lambda { params, body }) => { + if body.any_column_refs() { + return plan_err!("lambda doesn't support column capture"); + } + + expressions::lambda( + params, + create_physical_expr(body, input_dfschema, execution_props)?, + ) + } Expr::LambdaVariable(LambdaVariable { name, field, spans: _, - }) => expressions::lambda_variable(name, 
Arc::clone(field)), + }) => expressions::lambda_variable( + name, + Arc::clone(field), + input_dfschema.as_arrow(), + ), other => { not_impl_err!("Physical plan does not support logical expression {other:?}") } From a9d0e6cbf4b7ddd0e3d70a4076193f93073af62a Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 29 Mar 2026 04:04:39 -0300 Subject: [PATCH 36/47] add LambdaUDF::clean_null_values --- datafusion/common/src/utils/mod.rs | 190 +++++++++++++++++- datafusion/expr/src/udlf.rs | 13 ++ .../functions-nested/src/array_transform.rs | 121 +---------- .../physical-expr/src/lambda_function.rs | 16 +- 4 files changed, 220 insertions(+), 120 deletions(-) diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 49ddf6d934aca..58e1eccc1d42d 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -26,12 +26,18 @@ pub mod string_utils; use crate::assert_or_internal_err; use crate::error::{_exec_datafusion_err, _exec_err, _internal_datafusion_err}; use crate::{Result, ScalarValue}; -use arrow::array::GenericListArray; use arrow::array::{ Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, cast::AsArray, }; +use arrow::array::{ + BooleanArray, Datum, GenericListArray, Int32Array, Int64Array, MutableArrayData, + Scalar, make_array, +}; use arrow::buffer::OffsetBuffer; +use arrow::compute::kernels::cmp::neq; +use arrow::compute::kernels::length::length; +use arrow::compute::kernels::zip::zip; use arrow::compute::{SortColumn, SortOptions, partition}; use arrow::datatypes::{DataType, Field, SchemaRef}; #[cfg(feature = "sql")] @@ -1027,12 +1033,123 @@ pub fn adjust_offsets_for_slice( offsets.clone() } +/// For lists and large lists, truncates the sublist of null values +/// +/// For fixed size lists, if there's any valid value, replace all null values with it, +/// otherwise return the array unchanged +pub fn remove_list_null_values(array: 
&ArrayRef) -> Result { + // todo: handle list view and map + match array.data_type() { + DataType::List(_) => Ok(Arc::new(truncate_list_nulls(array.as_list::())?)), + DataType::LargeList(_) => { + Ok(Arc::new(truncate_list_nulls(array.as_list::())?)) + } + DataType::FixedSizeList(_, _) => replace_nulls_with_first_valid(array), + dt => _exec_err!("expected list, got {dt}"), + } +} + +fn replace_nulls_with_first_valid(array: &ArrayRef) -> Result { + if let Some(nulls) = array.nulls() { + let null_count = array.null_count(); + + if null_count > 0 { + if null_count == array.len() { + return Ok(Arc::clone(array)); + } + + let first_valid = nulls + .inner() + .set_indices() + .next() + .ok_or_else(|| _internal_datafusion_err!("fixed size list should have been checked to contain at least one valid value"))?; + + let mask = BooleanArray::new(nulls.inner().clone(), None); + // perf: remove the null buffer so zip doesn't unnecessarly zip it too + let without_null_buffer = + make_array(array.to_data().into_builder().nulls(None).build()?); + let first_valid = array.slice(first_valid, 1); + let zipped = zip(&mask, &without_null_buffer, &Scalar::new(first_valid))?; + let zipped_with_null_buffer = make_array( + zipped + .to_data() + .into_builder() + .nulls(Some(nulls.clone())) + .build()?, + ); + + return Ok(zipped_with_null_buffer); + } + } + + Ok(Arc::clone(array)) +} + +fn truncate_list_nulls( + list: &GenericListArray, +) -> Result> { + if let Some(nulls) = list.nulls() + && nulls.null_count() > 0 + { + let lengths = length(list)?; + let zero: &dyn Datum = if lengths.data_type() == &DataType::Int32 { + &Int32Array::new_scalar(0) + } else { + &Int64Array::new_scalar(0) + }; + + let not_empty = neq(&lengths, zero)?; + let null_and_non_empty = &!nulls.inner() & not_empty.values(); + + if null_and_non_empty.count_set_bits() > 0 { + let array_data = list.values().to_data(); + let offsets = list.offsets(); + let capacity = offsets[offsets.len() - 1] - offsets[0]; + let mut 
mutable_array_data = + MutableArrayData::new(vec![&array_data], false, capacity.as_usize()); + + let valid_or_empty = nulls.inner() | &!not_empty.values(); + + for (start, end) in valid_or_empty.set_slices() { + mutable_array_data.extend( + 0, + offsets[start].as_usize(), + offsets[end].as_usize(), + ); + } + + let lengths = std::iter::zip(offsets.lengths(), nulls) + .map(|(length, is_valid)| if is_valid { length } else { 0 }); + + let offsets = OffsetBuffer::from_lengths(lengths); + let values = make_array(mutable_array_data.freeze()); + + let field = match list.data_type() { + DataType::List(field) => field, + DataType::LargeList(field) => field, + _ => unreachable!(), + }; + + return Ok(GenericListArray::try_new( + Arc::clone(field), + offsets, + values, + list.nulls().cloned(), + )?); + } + } + Ok(list.clone()) +} + #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::ScalarValue::Null; use arrow::{ array::{Float64Array, Int32Array}, + buffer::NullBuffer, datatypes::Int32Type, }; use sqlparser::ast::Ident; @@ -1426,4 +1543,75 @@ mod tests { OffsetBuffer::from_lengths([]) ); } + + fn create_i32_list( + values: impl Into, + offsets: OffsetBuffer, + nulls: Option, + ) -> ListArray { + let list_field = Arc::new(Field::new_list_field(DataType::Int32, true)); + + ListArray::new(list_field, offsets, Arc::new(values.into()), nulls) + } + + fn create_i32_fsl( + size: i32, + values: Vec, + nulls: Option, + ) -> FixedSizeListArray { + FixedSizeListArray::new( + Arc::new(Field::new_list_field(DataType::Int32, true)), + size, + Arc::new(Int32Array::from(values)), + nulls, + ) + } + + #[test] + fn test_remove_list_null_values_list() { + let list = Arc::new(create_i32_list( + vec![100, 20, 10, 0, 0, 0, 0, 1, 50], + OffsetBuffer::::from_lengths(vec![3, 4, 0, 2, 0]), + Some(NullBuffer::from(vec![true, false, false, true, false])), + )) as ArrayRef; + + let res = remove_list_null_values(&list).unwrap(); + let res = res.as_list::(); + + let expected = 
Arc::new(create_i32_list( + vec![100, 20, 10, 1, 50], + OffsetBuffer::::from_lengths(vec![3, 0, 0, 2, 0]), + Some(NullBuffer::from(vec![true, false, false, true, false])), + )) as ArrayRef; + let expected = expected.as_list::(); + + assert_eq!(res, expected); + // check above skips inner value of nulls + assert_eq!(res.values(), expected.values()); + assert_eq!(res.offsets(), expected.offsets()); + } + + #[test] + fn test_remove_list_null_values_fsl() { + let list = Arc::new(create_i32_fsl( + 3, + vec![100, 20, 10, 0, 0, 0, 0, 0, 0], + Some(NullBuffer::from(vec![true, false, false])), + )) as ArrayRef; + + let res = remove_list_null_values(&list).unwrap(); + + let expected = Arc::new(create_i32_fsl( + 3, + vec![100, 20, 10, 100, 20, 10, 100, 20, 10], + Some(NullBuffer::from(vec![true, false, false])), + )) as ArrayRef; + + assert_eq!(&res, &expected); + // check above skips inner value of nulls + assert_eq!( + res.as_fixed_size_list().values(), + expected.as_fixed_size_list().values() + ); + } } diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index 27c49df7699ad..f7f40fa615ded 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -372,6 +372,19 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// ``` fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result; + /// Whether List, LargeList and FixedSizeList arguments should have it's + /// non-empty null sublists cleaned by Datafusion before invoking this function + /// + /// The default implementation always returns true and should only be implemented + /// if you want to handle non-empty null sublists yourself + /// + /// fully null fixed size list arrays should always be handled regardless of + /// the return of this function + // todo: extend this to listview and maps when remove_list_null_values supports it + fn clear_null_values(&self) -> bool { + true + } + /// Invoke the function returning the appropriate result. 
/// /// # Performance diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 160fa8ba6c4df..d148415db3206 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -19,15 +19,13 @@ use arrow::{ array::{ - Array, ArrayRef, AsArray, FixedSizeListArray, GenericListArray, LargeListArray, - ListArray, OffsetSizeTrait, UInt64Array, new_null_array, + Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray, + new_null_array, }, - buffer::OffsetBuffer, - compute::take, datatypes::{DataType, Field, FieldRef}, }; use datafusion_common::{ - Result, exec_err, internal_datafusion_err, plan_err, + Result, exec_err, plan_err, utils::{adjust_offsets_for_slice, list_values, take_function_args}, }; use datafusion_expr::{ @@ -193,10 +191,6 @@ impl LambdaUDF for ArrayTransform { ))); } - // null sublists may contain values that cause problems, like a 0 used on a division - // use remove_list_null_values to remove them - let list_array = remove_list_null_values(&list_array)?; - // as per list_values docs, if list_array is sliced, list_values will be sliced too, // so before constructing the transformed array below, we must adjust the list offsets with // adjust_offsets_for_slice @@ -287,115 +281,6 @@ fn value_lambda_pair<'a, V: Debug, L: Debug>( Ok((value, lambda)) } -//todo: make this function public and move to a more generic crate like datafusion-common -fn remove_list_null_values(list: &dyn Array) -> Result { - match list.data_type() { - DataType::List(_) => Ok(Arc::new(truncate_nulls(list.as_list::())?)), - DataType::LargeList(_) => Ok(Arc::new(truncate_nulls(list.as_list::())?)), - DataType::FixedSizeList(_, _) => Ok(Arc::new(replace_nulls_with_valid( - list.as_fixed_size_list(), - )?)), - dt => exec_err!("expected list, got {dt}"), - } -} - -fn replace_nulls_with_valid(list: &FixedSizeListArray) -> Result { - if let Some(nulls) = 
list.nulls() { - let null_count = list.null_count(); - - if null_count > 0 { - if null_count == list.len() { - return exec_err!("no valid value to use"); - } - - let first_valid = nulls - .inner() - .set_indices() - .next() - .ok_or_else(|| internal_datafusion_err!("fixed size list should have been checked to contain at least one valid value"))? as u64; - - let mut indices = Vec::with_capacity(list.values().len()); - - let size = list.value_length() as u64; - - for (i, is_valid) in nulls.iter().enumerate() { - let range = if is_valid { - let i = i as u64; - - i * size..(i + 1) * size - } else { - first_valid * size..(first_valid + 1) * size - }; - - indices.extend(range) - } - - let indices = UInt64Array::from(indices); - let values = take(list.values(), &indices, None)?; - - let DataType::FixedSizeList(field, size) = list.data_type() else { - unreachable!() - }; - - return Ok(FixedSizeListArray::try_new( - Arc::clone(field), - *size, - values, - list.nulls().cloned(), - )?); - } - } - - Ok(list.clone()) -} - -fn truncate_nulls( - list: &GenericListArray, -) -> Result> { - if let Some(nulls) = list.nulls() - && nulls.null_count() > 0 - { - let contains_null_and_non_empty = std::iter::zip(list.offsets().lengths(), nulls) - .any(|(len, is_valid)| len > 0 && !is_valid); - - if contains_null_and_non_empty { - let mut indices = Vec::with_capacity(list.values().len()); - - let lengths = list.offsets().windows(2).enumerate().map(|(i, window)| { - let start = window[0].as_usize(); - let end = window[1].as_usize(); - - if list.is_valid(i) { - indices.extend((start..end).map(|i| i as u64)); - - end - start - } else { - 0 - } - }); - - let offsets = OffsetBuffer::from_lengths(lengths); - let indices = UInt64Array::from(indices); - let values = take(list.values(), &indices, None)?; - - let field = match list.data_type() { - DataType::List(field) => field, - DataType::LargeList(field) => field, - _ => unreachable!(), - }; - - return Ok(GenericListArray::try_new( - 
Arc::clone(field), - offsets, - values, - list.nulls().cloned(), - )?); - } - } - - Ok(list.clone()) -} - #[cfg(test)] mod tests { use std::{collections::HashMap, sync::Arc}; diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs index 9d1adc7a459f7..b5b79285cbe6d 100644 --- a/datafusion/physical-expr/src/lambda_function.rs +++ b/datafusion/physical-expr/src/lambda_function.rs @@ -40,6 +40,7 @@ use crate::expressions::{LambdaExpr, Literal}; use arrow::array::{Array, RecordBatch}; use arrow::datatypes::{DataType, FieldRef, Schema}; use datafusion_common::config::{ConfigEntry, ConfigOptions}; +use datafusion_common::utils::remove_list_null_values; use datafusion_common::{ Result, ScalarValue, exec_err, internal_datafusion_err, internal_err, }; @@ -309,7 +310,20 @@ impl PhysicalExpr for LambdaFunctionExpr { Arc::clone(lambda.body()), ))) } - None => Ok(ValueOrLambda::Value(arg.evaluate(batch)?)), + None => { + let value = arg.evaluate(batch)?; + + let value = + if self.fun.clear_null_values() && value.data_type().is_list() { + ColumnarValue::Array(remove_list_null_values( + &value.into_array(batch.num_rows())?, + )?) 
+ } else { + value + }; + + Ok(ValueOrLambda::Value(value)) + } }) .collect::>>()?; From 69c44fc334150a241f8bf53792fc3d9377883497 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 29 Mar 2026 04:22:43 -0300 Subject: [PATCH 37/47] rename LambdaUDF::lambdas_parameters to lambda_parameters --- datafusion/expr/src/expr.rs | 6 ++--- datafusion/expr/src/udlf.rs | 18 +++++++-------- .../functions-nested/src/array_transform.rs | 4 ++-- .../physical-expr/src/lambda_function.rs | 20 ++++++++--------- datafusion/physical-expr/src/planner.rs | 14 ++++++------ datafusion/proto/src/bytes/mod.rs | 2 +- datafusion/sql/src/expr/function.rs | 22 +++++++++---------- datafusion/sql/src/expr/identifier.rs | 3 +-- datafusion/sql/src/planner.rs | 14 ++++++------ datafusion/sql/src/unparser/expr.rs | 2 +- 10 files changed, 52 insertions(+), 53 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 8214fa4248c7f..4ba918b7783ed 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -445,9 +445,9 @@ impl LambdaFunction { self.func.name() } - /// Invokes the inner function [`LambdaUDF::lambdas_parameters`] + /// Invokes the inner function [`LambdaUDF::lambda_parameters`] /// using the arguments of this invocation - pub fn lambdas_parameters(&self, schema: &dyn ExprSchema) -> Result>> { + pub fn lambda_parameters(&self, schema: &dyn ExprSchema) -> Result>> { let args = self .args .iter() @@ -465,7 +465,7 @@ impl LambdaFunction { }) .collect::>(); - self.func.lambdas_parameters(&coerced_values) + self.func.lambda_parameters(&coerced_values) } } diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index f7f40fa615ded..f82f281f70e5e 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -161,7 +161,7 @@ pub struct LambdaFunctionArgs { /// Field associated with each arg, if it exists /// For lambdas, it will be the field of the result of /// the lambda if 
evaluated with the parameters - /// returned from [`LambdaUDF::lambdas_parameters`] + /// returned from [`LambdaUDF::lambda_parameters`] pub arg_fields: Vec>, /// The number of rows in record batch being evaluated pub number_rows: usize, @@ -203,7 +203,7 @@ impl LambdaArgument { /// Evaluate this lambda /// `args` should evaluate to the value of each parameter - /// of the correspondent lambda returned in [LambdaUDF::lambdas_parameters]. + /// of the correspondent lambda returned in [LambdaUDF::lambda_parameters]. pub fn evaluate( &self, args: &[&dyn Fn() -> Result], @@ -234,7 +234,7 @@ pub struct LambdaReturnFieldArgs<'a> { /// The data types of the arguments to the function /// /// If argument `i` to the function is a lambda, it will be the field of the result of the - /// lambda if evaluated with the parameters returned from [`LambdaUDF::lambdas_parameters`] + /// lambda if evaluated with the parameters returned from [`LambdaUDF::lambda_parameters`] /// /// For example, with `array_transform([1], v -> v == 5)` /// this field will be `[ @@ -314,22 +314,22 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// regardless of whether they are used on a particular invocation /// /// Tip: If you have a [`LambdaFunction`] invocation, you can call the helper - /// [`LambdaFunction::lambdas_parameters`] instead of this method directly + /// [`LambdaFunction::lambda_parameters`] instead of this method directly /// /// [`LambdaFunction`]: crate::expr::LambdaFunction - /// [`LambdaFunction::lambdas_parameters`]: crate::expr::LambdaFunction::lambdas_parameters + /// [`LambdaFunction::lambda_parameters`]: crate::expr::LambdaFunction::lambda_parameters /// /// Example for array_transform: /// /// `array_transform([2.0, 8.0], v -> v > 4.0)` /// /// ```ignore - /// let lambdas_parameters = array_transform.lambdas_parameters(&[ + /// let lambda_parameters = array_transform.lambda_parameters(&[ /// Arc::new(Field::new("", DataType::new_list(DataType::Float32, 
false))), // the Field of the literal `[2, 8]` /// ])?; /// /// assert_eq!( - /// lambdas_parameters, + /// lambda_parameters, /// vec![ /// // the lambda supported parameters, regardless of how many are actually used /// vec![ @@ -345,7 +345,7 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// /// The implementation can assume that some other part of the code has coerced /// the actual argument types to match [`Self::signature`]. - fn lambdas_parameters(&self, value_fields: &[FieldRef]) -> Result>>; + fn lambda_parameters(&self, value_fields: &[FieldRef]) -> Result>>; /// What type will be returned by this function, given the arguments? /// @@ -495,7 +495,7 @@ mod tests { &self.signature } - fn lambdas_parameters( + fn lambda_parameters( &self, _value_fields: &[FieldRef], ) -> Result>> { diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index d148415db3206..e9ed199d55dae 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -128,7 +128,7 @@ impl LambdaUDF for ArrayTransform { Ok(vec![coerced]) } - fn lambdas_parameters(&self, value_fields: &[FieldRef]) -> Result>> { + fn lambda_parameters(&self, value_fields: &[FieldRef]) -> Result>> { let list = if value_fields.len() == 1 { &value_fields[0] } else { @@ -160,7 +160,7 @@ impl LambdaUDF for ArrayTransform { //TODO: should metadata be copied into the transformed array? 
// lambda is the resulting field of executing the lambda body - // with the parameters returned in lambdas_parameters + // with the parameters returned in lambda_parameters let field = Arc::new(Field::new( Field::LIST_FIELD_DEFAULT_NAME, lambda.data_type().clone(), diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs index b5b79285cbe6d..6823ff4d4d4f4 100644 --- a/datafusion/physical-expr/src/lambda_function.rs +++ b/datafusion/physical-expr/src/lambda_function.rs @@ -260,25 +260,25 @@ impl PhysicalExpr for LambdaFunctionExpr { }) .collect::>(); - // lambdas_parameters refers only to lambdas and not to values, so instead + // lambda_parameters refers only to lambdas and not to values, so instead // of zipping it with self.args, we iterate over self.args and only - // consume from lambdas_parameters when a given argument is a lambda + // consume from lambda_parameters when a given argument is a lambda // to reconstruct the arguments list with the correct order // this supports any value and lambda positioning including // multiple lambdas interleaved with values - let mut lambdas_parameters = - self.fun().lambdas_parameters(&value_fields)?.into_iter(); + let mut lambda_parameters = + self.fun().lambda_parameters(&value_fields)?.into_iter(); let num_lambdas = self.args.len() - value_fields.len(); // functions can support multiple lambdas where some trailing ones are optional, - // but to simplify the implementor, lambdas_parameters returns the parameters of all of them, + // but to simplify the implementor, lambda_parameters returns the parameters of all of them, // so we can't do equality check. 
one example is spark reduce: // https://spark.apache.org/docs/latest/api/sql/index.html#reduce - if lambdas_parameters.len() < num_lambdas { + if lambda_parameters.len() < num_lambdas { return exec_err!( - "{} invocation defined {num_lambdas} but lambdas_parameters returned only {}", + "{} invocation defined {num_lambdas} but lambda_parameters returned only {}", self.name(), - lambdas_parameters.len() + lambda_parameters.len() ); } @@ -287,7 +287,7 @@ impl PhysicalExpr for LambdaFunctionExpr { .iter() .map(|arg| match arg.as_any().downcast_ref::() { Some(lambda) => { - let lambda_params = lambdas_parameters.next().ok_or_else(|| { + let lambda_params = lambda_parameters.next().ok_or_else(|| { internal_datafusion_err!( "params len should have been checked above" ) @@ -432,7 +432,7 @@ mod tests { &self.signature } - fn lambdas_parameters( + fn lambda_parameters( &self, _value_fields: &[FieldRef], ) -> Result>> { diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 2885fcf398fe2..f530c3f9f6b8a 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -423,14 +423,14 @@ pub fn create_physical_expr( .filter(|arg| matches!(arg, Expr::Lambda(_))) .count(); - let mut lambdas_parameters = - invocation.lambdas_parameters(input_dfschema)?.into_iter(); + let mut lambda_parameters = + invocation.lambda_parameters(input_dfschema)?.into_iter(); - if num_lambdas > lambdas_parameters.len() { + if num_lambdas > lambda_parameters.len() { return plan_err!( - "{} lambdas_parameters returned only {} values for {num_lambdas} lambdas", + "{} lambda_parameters returned only {} values for {num_lambdas} lambdas", func.name(), - lambdas_parameters.len() + lambda_parameters.len() ); } @@ -438,11 +438,11 @@ pub fn create_physical_expr( .iter() .map(|arg| match arg { Expr::Lambda(lambda) => { - let lambda_parameters = lambdas_parameters + let lambda_parameters = lambda_parameters .next() .ok_or_else(|| { 
internal_datafusion_err!( - "lambdas_parameters len should have been checked above" + "lambda_parameters len should have been checked above" ) })? .into_iter() diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index c314a6ae21244..5343dc00648fb 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -144,7 +144,7 @@ impl Serializeable for Expr { &self.signature } - fn lambdas_parameters( + fn lambda_parameters( &self, _value_fields: &[arrow::datatypes::FieldRef], ) -> Result>> { diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 1f7574f5a8eaf..a61653e472103 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -371,7 +371,7 @@ impl SqlToRel<'_, S> { if let Some(fm) = self.context_provider.get_lambda_meta(&name) { // plan non-lambda arguments first so we can get theirs datatype and call - // LambdaUDF::lambdas_parameters to then plan the lambda arguments with + // LambdaUDF::lambda_parameters to then plan the lambda arguments with // resolved lambda variables enum ExprOrLambda { Expr(Expr), @@ -422,26 +422,26 @@ impl SqlToRel<'_, S> { }) .collect::>(); - // lambdas_parameters refers only to lambdas and not to values, so instead + // lambda_parameters refers only to lambdas and not to values, so instead // of zipping it with partially_planned, we iterate over partially_planned and only - // consume from lambdas_parameters when a given argument is a lambda + // consume from lambda_parameters when a given argument is a lambda // to reconstruct the arguments list with the correct order // this supports any value and lambda positioning including // multiple lambdas interleaved with values - let mut lambdas_parameters = - fm.lambdas_parameters(&coerced_values)?.into_iter(); + let mut lambda_parameters = + fm.lambda_parameters(&coerced_values)?.into_iter(); let num_lambdas = partially_planned.len() - coerced_values.len(); // 
functions can support multiple lambdas where some trailing ones are optional, - // but to simplify the implementor, lambdas_parameters returns the parameters of all of them, + // but to simplify the implementor, lambda_parameters returns the parameters of all of them, // so we can't do equality check. one example is spark reduce: // https://spark.apache.org/docs/latest/api/sql/index.html#reduce - if lambdas_parameters.len() < num_lambdas { + if lambda_parameters.len() < num_lambdas { return plan_err!( - "{} invocation defined {num_lambdas} but lambdas_parameters returned only {}", + "{} invocation defined {num_lambdas} but lambda_parameters returned only {}", fm.name(), - lambdas_parameters.len() + lambda_parameters.len() ); } @@ -451,9 +451,9 @@ impl SqlToRel<'_, S> { ExprOrLambda::Expr(expr) => Ok(expr), ExprOrLambda::Lambda(lambda) => { let lambda_params = - lambdas_parameters.next().ok_or_else(|| { + lambda_parameters.next().ok_or_else(|| { internal_datafusion_err!( - "lambdas_parameters len should have been checked above" + "lambda_parameters len should have been checked above" ) })?; diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 4ae0f51e9f993..0fd769a7b916d 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -61,8 +61,7 @@ impl SqlToRel<'_, S> { let normalize_ident = self.ident_normalizer.normalize(id); // lambdas parameters have higher precedence - if let Some(field) = - planner_context.lambdas_parameters().get(&normalize_ident) + if let Some(field) = planner_context.lambda_parameters().get(&normalize_ident) { let mut lambda_var = LambdaVariable::new(normalize_ident, Arc::clone(field)); diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index b43b1454a4d9d..b12152f12dbca 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -271,7 +271,7 @@ pub struct PlannerContext { /// The query schema defined by the table 
create_table_schema: Option, /// The parameters of all lambdas seen so far - lambdas_parameters: HashMap, + lambda_parameters: HashMap, } impl Default for PlannerContext { @@ -289,7 +289,7 @@ impl PlannerContext { outer_queries_schemas_stack: vec![], outer_from_schema: None, create_table_schema: None, - lambdas_parameters: HashMap::new(), + lambda_parameters: HashMap::new(), } } @@ -399,16 +399,16 @@ impl PlannerContext { self.ctes.get(cte_name).map(|cte| cte.as_ref()) } - pub fn lambdas_parameters(&self) -> &HashMap { - &self.lambdas_parameters + pub fn lambda_parameters(&self) -> &HashMap { + &self.lambda_parameters } pub fn with_lambda_parameters( mut self, - arguments: impl IntoIterator, + parameters: impl IntoIterator, ) -> Self { - self.lambdas_parameters - .extend(arguments.into_iter().map(|f| (f.name().clone(), f))); + self.lambda_parameters + .extend(parameters.into_iter().map(|f| (f.name().clone(), f))); self } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 5f3c97d7b1c2f..8f2850c807282 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1947,7 +1947,7 @@ mod tests { unimplemented!() } - fn lambdas_parameters( + fn lambda_parameters( &self, _value_fields: &[FieldRef], ) -> Result>> { From 66839d3a76c8b6a3029c97b7472e2534ce543bd3 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 29 Mar 2026 06:30:06 -0300 Subject: [PATCH 38/47] rename LambdaFunction to HigherOrderFunction, LambdaUDF to HigherOrderUDF --- .../examples/sql_ops/frontend.rs | 8 +- datafusion/catalog-listing/src/helpers.rs | 4 +- .../core/src/bin/print_functions_docs.rs | 4 +- .../src/datasource/listing_table_factory.rs | 4 +- datafusion/core/src/execution/context/mod.rs | 18 +-- .../core/src/execution/session_state.rs | 130 ++++++++++-------- .../src/execution/session_state_defaults.rs | 8 +- datafusion/core/tests/optimizer/mod.rs | 6 +- 
.../datasource-arrow/src/file_format.rs | 4 +- datafusion/datasource/src/url.rs | 4 +- datafusion/execution/src/task.rs | 38 ++--- datafusion/expr/src/expr.rs | 55 ++++---- datafusion/expr/src/expr_schema.rs | 14 +- datafusion/expr/src/lib.rs | 8 +- datafusion/expr/src/planner.rs | 10 +- datafusion/expr/src/registry.rs | 51 +++---- datafusion/expr/src/tree_node.rs | 14 +- .../expr/src/type_coercion/functions.rs | 16 +-- datafusion/expr/src/udf_eq.rs | 4 +- datafusion/expr/src/{udlf.rs => udhof.rs} | 122 ++++++++-------- datafusion/expr/src/utils.rs | 2 +- datafusion/ffi/src/session/mod.rs | 10 +- .../functions-nested/src/array_transform.rs | 33 +++-- datafusion/functions-nested/src/lib.rs | 20 +-- .../functions-nested/src/macros_lambda.rs | 66 ++++----- .../optimizer/src/analyzer/type_coercion.rs | 20 +-- .../optimizer/src/common_subexpr_eliminate.rs | 6 +- datafusion/optimizer/src/push_down_filter.rs | 2 +- .../simplify_expressions/expr_simplifier.rs | 4 +- .../optimizer/tests/optimizer_integration.rs | 6 +- ...a_function.rs => higher_order_function.rs} | 87 ++++++------ datafusion/physical-expr/src/lib.rs | 4 +- datafusion/physical-expr/src/planner.rs | 8 +- datafusion/proto/src/bytes/mod.rs | 43 +++--- datafusion/proto/src/bytes/registry.rs | 6 +- datafusion/proto/src/logical_plan/mod.rs | 18 ++- datafusion/proto/src/logical_plan/to_proto.rs | 2 +- datafusion/session/src/session.rs | 10 +- datafusion/spark/src/lib.rs | 6 +- datafusion/sql/examples/sql.rs | 6 +- datafusion/sql/src/expr/function.rs | 18 +-- datafusion/sql/src/expr/mod.rs | 8 +- datafusion/sql/src/unparser/dialect.rs | 4 +- datafusion/sql/src/unparser/expr.rs | 26 ++-- datafusion/sql/tests/common/mod.rs | 14 +- .../{lambda.slt => higher_order.slt} | 2 +- .../consumer/expr/scalar_function.rs | 2 +- .../src/logical_plan/producer/expr/mod.rs | 4 +- .../producer/expr/scalar_function.rs | 4 +- .../producer/substrait_producer.rs | 8 +- 50 files changed, 515 insertions(+), 456 deletions(-) rename 
datafusion/expr/src/{udlf.rs => udhof.rs} (82%) rename datafusion/physical-expr/src/{lambda_function.rs => higher_order_function.rs} (86%) rename datafusion/sqllogictest/test_files/{lambda.slt => higher_order.slt} (99%) diff --git a/datafusion-examples/examples/sql_ops/frontend.rs b/datafusion-examples/examples/sql_ops/frontend.rs index fe8e4bd066ebc..d6995ce4fe611 100644 --- a/datafusion-examples/examples/sql_ops/frontend.rs +++ b/datafusion-examples/examples/sql_ops/frontend.rs @@ -22,8 +22,8 @@ use datafusion::common::{TableReference, plan_err}; use datafusion::config::ConfigOptions; use datafusion::error::Result; use datafusion::logical_expr::{ - AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, TableProviderFilterPushDown, - TableSource, WindowUDF, + AggregateUDF, Expr, HigherOrderUDF, LogicalPlan, ScalarUDF, + TableProviderFilterPushDown, TableSource, WindowUDF, }; use datafusion::optimizer::{ Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule, @@ -155,7 +155,7 @@ impl ContextProvider for MyContextProvider { None } - fn get_lambda_meta(&self, _name: &str) -> Option> { + fn get_higher_order_meta(&self, _name: &str) -> Option> { None } @@ -179,7 +179,7 @@ impl ContextProvider for MyContextProvider { Vec::new() } - fn udlf_names(&self) -> Vec { + fn udhof_names(&self) -> Vec { Vec::new() } diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index faa218835000f..bd497b6fd162e 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -99,8 +99,8 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { } } } - Expr::LambdaFunction(lambda_function) => { - match lambda_function.func.signature().volatility { + Expr::HigherOrderFunction(hof) => { + match hof.func.signature().volatility { Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require 
access to the context Volatility::Stable | Volatility::Volatile => { diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index 2b681cdd74cb2..68d47b3c01b9b 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -18,7 +18,7 @@ use datafusion::execution::SessionStateDefaults; use datafusion_common::{HashSet, Result, not_impl_err}; use datafusion_expr::{ - AggregateUDF, DocSection, Documentation, LambdaUDF, ScalarUDF, WindowUDF, + AggregateUDF, DocSection, Documentation, HigherOrderUDF, ScalarUDF, WindowUDF, aggregate_doc_sections, scalar_doc_sections, window_doc_sections, }; use itertools::Itertools; @@ -282,7 +282,7 @@ impl DocProvider for WindowUDF { } } -impl DocProvider for dyn LambdaUDF { +impl DocProvider for dyn HigherOrderUDF { fn get_name(&self) -> String { self.name().to_string() } diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index b7211d92d8370..db4e2a4440c38 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -595,9 +595,9 @@ mod tests { ) -> &HashMap> { unimplemented!() } - fn lambda_functions( + fn higher_order_functions( &self, - ) -> &HashMap> { + ) -> &HashMap> { unimplemented!() } fn aggregate_functions( diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index dde62be038143..f1deb9366ffbd 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -83,7 +83,7 @@ use datafusion_execution::disk_manager::{ DEFAULT_MAX_TEMP_DIRECTORY_SIZE, DiskManagerBuilder, }; use datafusion_execution::registry::SerializerRegistry; -use datafusion_expr::LambdaUDF; +use datafusion_expr::HigherOrderUDF; pub use datafusion_expr::execution_props::ExecutionProps; #[cfg(feature = "sql")] 
use datafusion_expr::planner::RelationPlanner; @@ -1978,8 +1978,8 @@ impl FunctionRegistry for SessionContext { self.state.read().udf(name) } - fn udlf(&self, name: &str) -> Result> { - self.state.read().udlf(name) + fn udhof(&self, name: &str) -> Result> { + self.state.read().udhof(name) } fn udaf(&self, name: &str) -> Result> { @@ -1994,11 +1994,11 @@ impl FunctionRegistry for SessionContext { self.state.write().register_udf(udf) } - fn register_udlf( + fn register_udhof( &mut self, - udlf: Arc, - ) -> Result>> { - self.state.write().register_udlf(udlf) + udhof: Arc, + ) -> Result>> { + self.state.write().register_udhof(udhof) } fn register_udaf( @@ -2030,8 +2030,8 @@ impl FunctionRegistry for SessionContext { self.state.write().register_expr_planner(expr_planner) } - fn udlfs(&self) -> HashSet { - self.state.read().udlfs() + fn udhofs(&self) -> HashSet { + self.state.read().udhofs() } fn udafs(&self) -> HashSet { diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 790086a37bd95..16bc0195004e1 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -59,7 +59,7 @@ use datafusion_expr::planner::{RelationPlanner, TypePlanner}; use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ - AggregateUDF, Explain, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF, + AggregateUDF, Explain, Expr, HigherOrderUDF, LogicalPlan, ScalarUDF, WindowUDF, }; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{ @@ -156,8 +156,8 @@ pub struct SessionState { table_functions: HashMap>, /// Scalar functions that are registered with the context scalar_functions: HashMap>, - /// Lambda functions that are registered with the context - lambda_functions: HashMap>, + /// Higher order functions that are registered with the context + 
higher_order_functions: HashMap>, /// Aggregate functions registered in the context aggregate_functions: HashMap>, /// Window functions registered in the context @@ -226,7 +226,7 @@ impl Debug for SessionState { .field("physical_optimizers", &self.physical_optimizers) .field("table_functions", &self.table_functions) .field("scalar_functions", &self.scalar_functions) - .field("lambda_functions", &self.lambda_functions) + .field("higher_order_functions", &self.higher_order_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) .field("prepared_plans", &self.prepared_plans) @@ -263,8 +263,8 @@ impl Session for SessionState { &self.scalar_functions } - fn lambda_functions(&self) -> &HashMap> { - &self.lambda_functions + fn higher_order_functions(&self) -> &HashMap> { + &self.higher_order_functions } fn aggregate_functions(&self) -> &HashMap> { @@ -898,9 +898,9 @@ impl SessionState { &self.scalar_functions } - /// Return reference to lambda_functions - pub fn lambda_functions(&self) -> &HashMap> { - &self.lambda_functions + /// Return reference to higher_order_functions + pub fn higher_order_functions(&self) -> &HashMap> { + &self.higher_order_functions } /// Return reference to aggregate_functions @@ -999,7 +999,7 @@ pub struct SessionStateBuilder { catalog_list: Option>, table_functions: Option>>, scalar_functions: Option>>, - lambda_functions: Option>>, + higher_order_functions: Option>>, aggregate_functions: Option>>, window_functions: Option>>, serializer_registry: Option>, @@ -1040,7 +1040,7 @@ impl SessionStateBuilder { catalog_list: None, table_functions: None, scalar_functions: None, - lambda_functions: None, + higher_order_functions: None, aggregate_functions: None, window_functions: None, serializer_registry: None, @@ -1094,7 +1094,9 @@ impl SessionStateBuilder { catalog_list: Some(existing.catalog_list), table_functions: Some(existing.table_functions), scalar_functions: 
Some(existing.scalar_functions.into_values().collect_vec()), - lambda_functions: Some(existing.lambda_functions.into_values().collect_vec()), + higher_order_functions: Some( + existing.higher_order_functions.into_values().collect_vec(), + ), aggregate_functions: Some( existing.aggregate_functions.into_values().collect_vec(), ), @@ -1136,9 +1138,9 @@ impl SessionStateBuilder { .get_or_insert_with(Vec::new) .extend(SessionStateDefaults::default_scalar_functions()); - self.lambda_functions + self.higher_order_functions .get_or_insert_with(Vec::new) - .extend(SessionStateDefaults::default_lambda_functions()); + .extend(SessionStateDefaults::default_higher_order_functions()); self.aggregate_functions .get_or_insert_with(Vec::new) @@ -1320,12 +1322,12 @@ impl SessionStateBuilder { self } - /// Set the map of [`LambdaUDF`]s - pub fn with_lambda_functions( + /// Set the map of [`HigherOrderUDF`]s + pub fn with_higher_order_functions( mut self, - lambda_functions: Vec>, + higher_order_functions: Vec>, ) -> Self { - self.lambda_functions = Some(lambda_functions); + self.higher_order_functions = Some(higher_order_functions); self } @@ -1483,7 +1485,7 @@ impl SessionStateBuilder { catalog_list, table_functions, scalar_functions, - lambda_functions, + higher_order_functions, aggregate_functions, window_functions, serializer_registry, @@ -1520,7 +1522,7 @@ impl SessionStateBuilder { }), table_functions: table_functions.unwrap_or_default(), scalar_functions: HashMap::new(), - lambda_functions: HashMap::new(), + higher_order_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), serializer_registry: serializer_registry @@ -1574,17 +1576,17 @@ impl SessionStateBuilder { } } - if let Some(lambda_functions) = lambda_functions { - for udlf in lambda_functions { - match state.register_udlf(Arc::clone(&udlf)) { + if let Some(higher_order_functions) = higher_order_functions { + for udhof in higher_order_functions { + match 
state.register_udhof(Arc::clone(&udhof)) { Ok(Some(existing)) => { - debug!("Overwrote existing UDLF '{}'", existing.name()); + debug!("Overwrote existing UDHOF '{}'", existing.name()); } Ok(None) => { - debug!("Registered UDLF '{}'", udlf.name()); + debug!("Registered UDHOF '{}'", udhof.name()); } Err(err) => { - debug!("Failed to register UDLF '{}': {}", udlf.name(), err); + debug!("Failed to register UDHOF '{}': {}", udhof.name(), err); } } } @@ -1709,8 +1711,10 @@ impl SessionStateBuilder { } /// Returns the current scalar_functions value - pub fn lambda_functions(&mut self) -> &mut Option>> { - &mut self.lambda_functions + pub fn higher_order_functions( + &mut self, + ) -> &mut Option>> { + &mut self.higher_order_functions } /// Returns the current aggregate_functions value @@ -1821,7 +1825,7 @@ impl Debug for SessionStateBuilder { .field("physical_optimizers", &self.physical_optimizers) .field("table_functions", &self.table_functions) .field("scalar_functions", &self.scalar_functions) - .field("lambda_functions", &self.lambda_functions) + .field("higher_order_functions", &self.higher_order_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) .finish() @@ -1929,8 +1933,8 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.scalar_functions().get(name).cloned() } - fn get_lambda_meta(&self, name: &str) -> Option> { - self.state.lambda_functions().get(name).cloned() + fn get_higher_order_meta(&self, name: &str) -> Option> { + self.state.higher_order_functions().get(name).cloned() } fn get_aggregate_meta(&self, name: &str) -> Option> { @@ -1969,8 +1973,12 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.scalar_functions().keys().cloned().collect() } - fn udlf_names(&self) -> Vec { - self.state.lambda_functions().keys().cloned().collect() + fn udhof_names(&self) -> Vec { + self.state + .higher_order_functions() + .keys() + .cloned() + .collect() } fn 
udaf_names(&self) -> Vec { @@ -2012,11 +2020,11 @@ impl FunctionRegistry for SessionState { }) } - fn udlf(&self, name: &str) -> datafusion_common::Result> { - self.lambda_functions + fn udhof(&self, name: &str) -> datafusion_common::Result> { + self.higher_order_functions .get(name) .cloned() - .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + .ok_or_else(|| plan_datafusion_err!("Higher Order Function {name} not found")) } fn udaf(&self, name: &str) -> datafusion_common::Result> { @@ -2046,15 +2054,17 @@ impl FunctionRegistry for SessionState { Ok(self.scalar_functions.insert(udf.name().into(), udf)) } - fn register_udlf( + fn register_udhof( &mut self, - udlf: Arc, - ) -> datafusion_common::Result>> { - udlf.aliases().iter().for_each(|alias| { - self.lambda_functions - .insert(alias.clone(), Arc::clone(&udlf)); + udhof: Arc, + ) -> datafusion_common::Result>> { + udhof.aliases().iter().for_each(|alias| { + self.higher_order_functions + .insert(alias.clone(), Arc::clone(&udhof)); }); - Ok(self.lambda_functions.insert(udlf.name().into(), udlf)) + Ok(self + .higher_order_functions + .insert(udhof.name().into(), udhof)) } fn register_udaf( @@ -2092,17 +2102,17 @@ impl FunctionRegistry for SessionState { Ok(udf) } - fn deregister_udlf( + fn deregister_udhof( &mut self, name: &str, - ) -> datafusion_common::Result>> { - let udlf = self.lambda_functions.remove(name); - if let Some(udlf) = &udlf { - for alias in udlf.aliases() { - self.lambda_functions.remove(alias); + ) -> datafusion_common::Result>> { + let udhof = self.higher_order_functions.remove(name); + if let Some(udhof) = &udhof { + for alias in udhof.aliases() { + self.higher_order_functions.remove(alias); } } - Ok(udlf) + Ok(udhof) } fn deregister_udaf( @@ -2151,8 +2161,8 @@ impl FunctionRegistry for SessionState { Ok(()) } - fn udlfs(&self) -> HashSet { - self.lambda_functions.keys().cloned().collect() + fn udhofs(&self) -> HashSet { + 
self.higher_order_functions.keys().cloned().collect() } fn udafs(&self) -> HashSet { @@ -2197,7 +2207,7 @@ impl From<&SessionState> for TaskContext { state.session_id.clone(), state.config.clone(), state.scalar_functions.clone(), - state.lambda_functions.clone(), + state.higher_order_functions.clone(), state.aggregate_functions.clone(), state.window_functions.clone(), Arc::clone(&state.runtime_env), @@ -2267,7 +2277,7 @@ mod tests { use datafusion_common::config::Dialect; use datafusion_execution::config::SessionConfig; use datafusion_expr::Expr; - use datafusion_expr::LambdaUDF; + use datafusion_expr::HigherOrderUDF; use datafusion_optimizer::Optimizer; use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_physical_plan::display::DisplayableExecutionPlan; @@ -2581,8 +2591,8 @@ mod tests { self.state.scalar_functions().get(name).cloned() } - fn get_lambda_meta(&self, name: &str) -> Option> { - self.state.lambda_functions().get(name).cloned() + fn get_higher_order_meta(&self, name: &str) -> Option> { + self.state.higher_order_functions().get(name).cloned() } fn get_aggregate_meta(&self, name: &str) -> Option> { @@ -2605,8 +2615,12 @@ mod tests { self.state.scalar_functions().keys().cloned().collect() } - fn udlf_names(&self) -> Vec { - self.state.lambda_functions().keys().cloned().collect() + fn udhof_names(&self) -> Vec { + self.state + .higher_order_functions() + .keys() + .cloned() + .collect() } fn udaf_names(&self) -> Vec { diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index 487565ead3093..1bfd177462c53 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -36,7 +36,7 @@ use datafusion_execution::config::SessionConfig; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::planner::ExprPlanner; -use 
datafusion_expr::{AggregateUDF, LambdaUDF, ScalarUDF, WindowUDF}; +use datafusion_expr::{AggregateUDF, HigherOrderUDF, ScalarUDF, WindowUDF}; use std::collections::HashMap; use std::sync::Arc; use url::Url; @@ -112,10 +112,10 @@ impl SessionStateDefaults { functions } - /// returns the list of default [`LambdaUDF`]s - pub fn default_lambda_functions() -> Vec> { + /// returns the list of default [`HigherOrderUDF`]s + pub fn default_higher_order_functions() -> Vec> { #[cfg(feature = "nested_expressions")] - return functions_nested::all_default_lambda_functions(); + return functions_nested::all_default_higher_order_functions(); #[cfg(not(feature = "nested_expressions"))] return Vec::new(); diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 1dee923ee2356..3821f88187b4a 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -31,7 +31,7 @@ use datafusion_common::tree_node::TransformedResult; use datafusion_common::{DFSchema, Result, ScalarValue, TableReference, plan_err}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{ - AggregateUDF, BinaryExpr, Expr, ExprSchemable, LambdaUDF, LogicalPlan, Operator, + AggregateUDF, BinaryExpr, Expr, ExprSchemable, HigherOrderUDF, LogicalPlan, Operator, ScalarUDF, TableSource, WindowUDF, col, lit, }; use datafusion_functions::core::expr_ext::FieldAccessor; @@ -217,7 +217,7 @@ impl ContextProvider for MyContextProvider { self.udfs.get(name).cloned() } - fn get_lambda_meta(&self, _name: &str) -> Option> { + fn get_higher_order_meta(&self, _name: &str) -> Option> { None } @@ -241,7 +241,7 @@ impl ContextProvider for MyContextProvider { Vec::new() } - fn udlf_names(&self) -> Vec { + fn udhof_names(&self) -> Vec { Vec::new() } diff --git a/datafusion/datasource-arrow/src/file_format.rs b/datafusion/datasource-arrow/src/file_format.rs index 23a60bea6f884..a51875bd4b7c7 100644 --- 
a/datafusion/datasource-arrow/src/file_format.rs +++ b/datafusion/datasource-arrow/src/file_format.rs @@ -556,7 +556,7 @@ mod tests { use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{ - AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF, + AggregateUDF, Expr, HigherOrderUDF, LogicalPlan, ScalarUDF, WindowUDF, }; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use object_store::{chunked::ChunkedStore, memory::InMemory}; @@ -604,7 +604,7 @@ mod tests { unimplemented!() } - fn lambda_functions(&self) -> &HashMap> { + fn higher_order_functions(&self) -> &HashMap> { unimplemented!() } diff --git a/datafusion/datasource/src/url.rs b/datafusion/datasource/src/url.rs index c0190235524d1..5945b00c9fa21 100644 --- a/datafusion/datasource/src/url.rs +++ b/datafusion/datasource/src/url.rs @@ -518,7 +518,7 @@ mod tests { use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{ - AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF, + AggregateUDF, Expr, HigherOrderUDF, LogicalPlan, ScalarUDF, WindowUDF, }; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; @@ -1209,7 +1209,7 @@ mod tests { unimplemented!() } - fn lambda_functions(&self) -> &HashMap> { + fn higher_order_functions(&self) -> &HashMap> { unimplemented!() } diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index 9427f1179c09c..ed3de93ff0e6d 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -21,7 +21,7 @@ use crate::{ }; use datafusion_common::{Result, internal_datafusion_err, plan_datafusion_err}; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, LambdaUDF, ScalarUDF, WindowUDF}; +use datafusion_expr::{AggregateUDF, HigherOrderUDF, ScalarUDF, WindowUDF}; use 
std::collections::HashSet; use std::{collections::HashMap, sync::Arc}; @@ -42,8 +42,8 @@ pub struct TaskContext { session_config: SessionConfig, /// Scalar functions associated with this task context scalar_functions: HashMap>, - /// Lambda functions associated with this task context - lambda_functions: HashMap>, + /// Higher order functions associated with this task context + higher_order_functions: HashMap>, /// Aggregate functions associated with this task context aggregate_functions: HashMap>, /// Window functions associated with this task context @@ -62,7 +62,7 @@ impl Default for TaskContext { task_id: None, session_config: SessionConfig::new(), scalar_functions: HashMap::new(), - lambda_functions: HashMap::new(), + higher_order_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), runtime, @@ -82,7 +82,7 @@ impl TaskContext { session_id: String, session_config: SessionConfig, scalar_functions: HashMap>, - lambda_functions: HashMap>, + higher_order_functions: HashMap>, aggregate_functions: HashMap>, window_functions: HashMap>, runtime: Arc, @@ -92,7 +92,7 @@ impl TaskContext { session_id, session_config, scalar_functions, - lambda_functions, + higher_order_functions, aggregate_functions, window_functions, runtime, @@ -162,11 +162,11 @@ impl FunctionRegistry for TaskContext { }) } - fn udlf(&self, name: &str) -> Result> { - let result = self.lambda_functions.get(name); + fn udhof(&self, name: &str) -> Result> { + let result = self.higher_order_functions.get(name); result.cloned().ok_or_else(|| { - plan_datafusion_err!("There is no UDLF named \"{name}\" in the TaskContext") + plan_datafusion_err!("There is no UDHOF named \"{name}\" in the TaskContext") }) } @@ -212,23 +212,25 @@ impl FunctionRegistry for TaskContext { Ok(self.scalar_functions.insert(udf.name().into(), udf)) } - fn register_udlf( + fn register_udhof( &mut self, - udlf: Arc, - ) -> Result>> { - udlf.aliases().iter().for_each(|alias| { - 
self.lambda_functions - .insert(alias.clone(), Arc::clone(&udlf)); + udhof: Arc, + ) -> Result>> { + udhof.aliases().iter().for_each(|alias| { + self.higher_order_functions + .insert(alias.clone(), Arc::clone(&udhof)); }); - Ok(self.lambda_functions.insert(udlf.name().into(), udlf)) + Ok(self + .higher_order_functions + .insert(udhof.name().into(), udhof)) } fn expr_planners(&self) -> Vec> { vec![] } - fn udlfs(&self) -> HashSet { - self.lambda_functions.keys().cloned().collect() + fn udhofs(&self) -> HashSet { + self.higher_order_functions.keys().cloned().collect() } fn udafs(&self) -> HashSet { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 4ba918b7783ed..c4355478852a3 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -27,8 +27,8 @@ use std::sync::Arc; use crate::expr_fn::binary_expr; use crate::function::WindowFunctionSimplification; use crate::logical_plan::Subquery; -use crate::type_coercion::functions::value_fields_with_lambda_udf; -use crate::udlf::LambdaUDF; +use crate::type_coercion::functions::value_fields_with_higher_order_udf; +use crate::udhof::HigherOrderUDF; use crate::{AggregateUDF, ValueOrLambda, Volatility}; use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; @@ -406,12 +406,12 @@ pub enum Expr { OuterReferenceColumn(FieldRef, Column), /// Unnest expression Unnest(Unnest), - /// Call a lambda function with a set of arguments. + /// Call a higher order function with a set of arguments. 
/// /// For example, `array_transform([1,2,3], v -> v+1)` would be equivalent to: /// /// ```text - /// LambdaFunction(array_transform) + /// HigherOrderFunction(array_transform) /// ├── args[0]: Literal([1,2,3]) /// └── args[1]: Lambda /// ├── params: ["v"] @@ -419,25 +419,25 @@ pub enum Expr { /// ├── LambdaVariable("v") /// └── Literal(1) /// ``` - LambdaFunction(LambdaFunction), + HigherOrderFunction(HigherOrderFunction), /// A Lambda expression with a set of parameters names and a body Lambda(Lambda), /// A named reference to a lambda parameter LambdaVariable(LambdaVariable), } -/// Invoke a [`LambdaUDF`] with a set of arguments +/// Invoke a [`HigherOrderUDF`] with a set of arguments #[derive(Clone, Eq, PartialOrd, Debug)] -pub struct LambdaFunction { +pub struct HigherOrderFunction { /// The function - pub func: Arc, + pub func: Arc, /// List of expressions to feed to the functions as arguments pub args: Vec, } -impl LambdaFunction { - /// Create a new `LambdaFunction` from a [`LambdaUDF`] - pub fn new(func: Arc, args: Vec) -> Self { +impl HigherOrderFunction { + /// Create a new `HigherOrderFunction` from a [`HigherOrderUDF`] + pub fn new(func: Arc, args: Vec) -> Self { Self { func, args } } @@ -445,7 +445,7 @@ impl LambdaFunction { self.func.name() } - /// Invokes the inner function [`LambdaUDF::lambda_parameters`] + /// Invokes the inner function [`HigherOrderUDF::lambda_parameters`] /// using the arguments of this invocation pub fn lambda_parameters(&self, schema: &dyn ExprSchema) -> Result>> { let args = self @@ -457,26 +457,27 @@ impl LambdaFunction { }) .collect::>>()?; - let coerced_values = value_fields_with_lambda_udf(&args, self.func.as_ref())? - .into_iter() - .filter_map(|arg| match arg { - ValueOrLambda::Value(value) => Some(value), - ValueOrLambda::Lambda(_lambda) => None, - }) - .collect::>(); + let coerced_values = + value_fields_with_higher_order_udf(&args, self.func.as_ref())? 
+ .into_iter() + .filter_map(|arg| match arg { + ValueOrLambda::Value(value) => Some(value), + ValueOrLambda::Lambda(_lambda) => None, + }) + .collect::>(); self.func.lambda_parameters(&coerced_values) } } -impl Hash for LambdaFunction { +impl Hash for HigherOrderFunction { fn hash(&self, state: &mut H) { self.func.hash(state); self.args.hash(state); } } -impl PartialEq for LambdaFunction { +impl PartialEq for HigherOrderFunction { fn eq(&self, other: &Self) -> bool { self.func.as_ref() == other.func.as_ref() && self.args == other.args } @@ -1730,7 +1731,7 @@ impl Expr { #[expect(deprecated)] Expr::Wildcard { .. } => "Wildcard", Expr::Unnest { .. } => "Unnest", - Expr::LambdaFunction { .. } => "LambdaFunction", + Expr::HigherOrderFunction { .. } => "HigherOrderFunction", Expr::Lambda { .. } => "Lambda", Expr::LambdaVariable { .. } => "LambdaVariable", } @@ -2248,7 +2249,9 @@ impl Expr { pub fn short_circuits(&self) -> bool { match self { Expr::ScalarFunction(ScalarFunction { func, .. }) => func.short_circuits(), - Expr::LambdaFunction(LambdaFunction { func, .. }) => func.short_circuits(), + Expr::HigherOrderFunction(HigherOrderFunction { func, .. }) => { + func.short_circuits() + } Expr::BinaryExpr(BinaryExpr { op, .. 
}) => { matches!(op, Operator::And | Operator::Or) } @@ -2890,7 +2893,7 @@ impl HashNode for Expr { column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} - Expr::LambdaFunction(LambdaFunction { func, args: _args }) => { + Expr::HigherOrderFunction(HigherOrderFunction { func, args: _args }) => { func.hash(state); } Expr::Lambda(Lambda { params, body: _ }) => { @@ -3222,7 +3225,7 @@ impl Display for SchemaDisplay<'_> { } } } - Expr::LambdaFunction(LambdaFunction { func, args }) => { + Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => { match func.schema_name(args) { Ok(name) => { write!(f, "{name}") @@ -3741,7 +3744,7 @@ impl Display for Expr { Expr::Unnest(Unnest { expr }) => { write!(f, "{UNNEST_COLUMN_PREFIX}({expr})") } - Expr::LambdaFunction(fun) => { + Expr::HigherOrderFunction(fun) => { fmt_function(f, fun.name(), false, &fun.args, true) } Expr::Lambda(Lambda { params, body }) => { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 396c00ef7970a..5fc1aa8f8420c 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -23,10 +23,10 @@ use crate::expr::{ WindowFunctionParams, }; use crate::expr::{FieldMetadata, LambdaVariable}; -use crate::type_coercion::functions::value_fields_with_lambda_udf; +use crate::type_coercion::functions::value_fields_with_higher_order_udf; use crate::type_coercion::functions::{UDFCoercionExt, fields_with_udf}; use crate::udf::ReturnFieldArgs; -use crate::udlf::LambdaReturnFieldArgs; +use crate::udhof::HigherOrderReturnFieldArgs; use crate::{LogicalPlan, Projection, Subquery, WindowFunctionDefinition, utils}; use arrow::compute::can_cast_types; use arrow::datatypes::FieldRef; @@ -203,7 +203,7 @@ impl ExprSchemable for Expr { // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } - Expr::LambdaFunction(_func) => { + Expr::HigherOrderFunction(_func) => { 
Ok(self.to_field(schema)?.1.data_type().clone()) } Expr::Lambda(Lambda { params: _, body }) => body.get_type(schema), @@ -363,7 +363,7 @@ impl ExprSchemable for Expr { // in projections Ok(true) } - Expr::LambdaFunction(_func) => { + Expr::HigherOrderFunction(_func) => { Ok(self.to_field(input_schema)?.1.is_nullable()) } Expr::Lambda(l) => l.body.nullable(input_schema), @@ -608,7 +608,7 @@ impl ExprSchemable for Expr { self.get_type(schema)?, self.nullable(schema)?, ))), - Expr::LambdaFunction(func) => { + Expr::HigherOrderFunction(func) => { let arg_fields = func .args .iter() @@ -622,7 +622,7 @@ impl ExprSchemable for Expr { .collect::>>()?; let new_fields = - value_fields_with_lambda_udf(&arg_fields, func.func.as_ref())?; + value_fields_with_higher_order_udf(&arg_fields, func.func.as_ref())?; let arguments = func .args @@ -633,7 +633,7 @@ impl ExprSchemable for Expr { }) .collect::>(); - let args = LambdaReturnFieldArgs { + let args = HigherOrderReturnFieldArgs { arg_fields: &new_fields, scalar_arguments: &arguments, }; diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 6dd91b1012a61..fee13461f50ca 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -43,7 +43,7 @@ mod partition_evaluator; mod table_source; mod udaf; mod udf; -mod udlf; +mod udhof; mod udwf; pub mod arguments; @@ -127,9 +127,9 @@ pub use udaf::{ udaf_default_window_function_schema_name, }; pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; -pub use udlf::{ - LambdaArgument, LambdaFunctionArgs, LambdaReturnFieldArgs, LambdaSignature, - LambdaTypeSignature, LambdaUDF, ValueOrLambda, +pub use udhof::{ + HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, HigherOrderSignature, + HigherOrderTypeSignature, HigherOrderUDF, LambdaArgument, ValueOrLambda, }; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git 
a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index fe5d71d338941..36299abf7a096 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -24,7 +24,7 @@ use crate::expr::NullTreatment; #[cfg(feature = "sql")] use crate::logical_plan::LogicalPlan; use crate::{ - AggregateUDF, Expr, GetFieldAccess, LambdaUDF, ScalarUDF, SortExpr, TableSource, + AggregateUDF, Expr, GetFieldAccess, HigherOrderUDF, ScalarUDF, SortExpr, TableSource, WindowFrame, WindowFunctionDefinition, WindowUDF, }; use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef}; @@ -103,8 +103,8 @@ pub trait ContextProvider { /// Return the scalar function with a given name, if any fn get_function_meta(&self, name: &str) -> Option>; - /// Return the lambda function with a given name, if any - fn get_lambda_meta(&self, name: &str) -> Option>; + /// Return the higher order function with a given name, if any + fn get_higher_order_meta(&self, name: &str) -> Option>; /// Return the aggregate function with a given name, if any fn get_aggregate_meta(&self, name: &str) -> Option>; @@ -134,8 +134,8 @@ pub trait ContextProvider { /// Return all scalar function names fn udf_names(&self) -> Vec; - /// Return all lambda function names - fn udlf_names(&self) -> Vec; + /// Return all higher order function names + fn udhof_names(&self) -> Vec; /// Return all aggregate function names fn udaf_names(&self) -> Vec; diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index d563b96b2551d..1356cdc8565de 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -19,7 +19,7 @@ use crate::expr_rewriter::FunctionRewrite; use crate::planner::ExprPlanner; -use crate::udlf::LambdaUDF; +use crate::udhof::HigherOrderUDF; use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use datafusion_common::{HashMap, Result, not_impl_err, plan_datafusion_err}; use std::collections::HashSet; @@ -31,8 +31,8 @@ pub trait 
FunctionRegistry { /// Returns names of all available scalar user defined functions. fn udfs(&self) -> HashSet; - /// Returns names of all available lambda user defined functions. - fn udlfs(&self) -> HashSet; + /// Returns names of all available higher order user defined functions. + fn udhofs(&self) -> HashSet; /// Returns names of all available aggregate user defined functions. fn udafs(&self) -> HashSet; @@ -44,9 +44,9 @@ pub trait FunctionRegistry { /// `name`. fn udf(&self, name: &str) -> Result>; - /// Returns a reference to the user defined lambda function (udlf) named + /// Returns a reference to the user defined higher order function (udhof) named /// `name`. - fn udlf(&self, name: &str) -> Result>; + fn udhof(&self, name: &str) -> Result>; /// Returns a reference to the user defined aggregate function (udaf) named /// `name`. @@ -64,16 +64,16 @@ pub trait FunctionRegistry { fn register_udf(&mut self, _udf: Arc) -> Result>> { not_impl_err!("Registering ScalarUDF") } - /// Registers a new [`LambdaUDF`], returning any previously registered + /// Registers a new [`HigherOrderUDF`], returning any previously registered /// implementation. /// /// Returns an error (the default) if the function can not be registered, /// for example if the registry is read only. - fn register_udlf( + fn register_udhof( &mut self, - _udlf: Arc, - ) -> Result>> { - not_impl_err!("Registering LambdaUDF") + _udhof: Arc, + ) -> Result>> { + not_impl_err!("Registering HigherOrderUDF") } /// Registers a new [`AggregateUDF`], returning any previously registered /// implementation. @@ -104,13 +104,16 @@ pub trait FunctionRegistry { not_impl_err!("Deregistering ScalarUDF") } - /// Deregisters a [`LambdaUDF`], returning the implementation that was + /// Deregisters a [`HigherOrderUDF`], returning the implementation that was /// deregistered. /// /// Returns an error (the default) if the function can not be deregistered, /// for example if the registry is read only. 
- fn deregister_udlf(&mut self, _name: &str) -> Result>> { - not_impl_err!("Deregistering LambdaUDF") + fn deregister_udhof( + &mut self, + _name: &str, + ) -> Result>> { + not_impl_err!("Deregistering HigherOrderUDF") } /// Deregisters a [`AggregateUDF`], returning the implementation that was @@ -180,8 +183,8 @@ pub trait SerializerRegistry: Debug + Send + Sync { pub struct MemoryFunctionRegistry { /// Scalar Functions udfs: HashMap>, - /// Lambda Functions - udlfs: HashMap>, + /// Higher Order Functions + udhof: HashMap>, /// Aggregate Functions udafs: HashMap>, /// Window Functions @@ -206,11 +209,11 @@ impl FunctionRegistry for MemoryFunctionRegistry { .ok_or_else(|| plan_datafusion_err!("Function {name} not found")) } - fn udlf(&self, name: &str) -> Result> { - self.udlfs + fn udhof(&self, name: &str) -> Result> { + self.udhof .get(name) .cloned() - .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + .ok_or_else(|| plan_datafusion_err!("Higher Order Function {name} not found")) } fn udaf(&self, name: &str) -> Result> { @@ -230,11 +233,11 @@ impl FunctionRegistry for MemoryFunctionRegistry { fn register_udf(&mut self, udf: Arc) -> Result>> { Ok(self.udfs.insert(udf.name().to_string(), udf)) } - fn register_udlf( + fn register_udhof( &mut self, - udlf: Arc, - ) -> Result>> { - Ok(self.udlfs.insert(udlf.name().into(), udlf)) + udhof: Arc, + ) -> Result>> { + Ok(self.udhof.insert(udhof.name().into(), udhof)) } fn register_udaf( &mut self, @@ -250,8 +253,8 @@ impl FunctionRegistry for MemoryFunctionRegistry { vec![] } - fn udlfs(&self) -> HashSet { - self.udlfs.keys().cloned().collect() + fn udhofs(&self) -> HashSet { + self.udhof.keys().cloned().collect() } fn udafs(&self) -> HashSet { diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 0dff4e7a14328..2d86c5560a4ca 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -21,8 +21,8 @@ use crate::{ Expr, expr::{ 
AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, - Cast, GroupingSet, InList, InSubquery, Lambda, LambdaFunction, Like, Placeholder, - ScalarFunction, SetComparison, TryCast, Unnest, WindowFunction, + Cast, GroupingSet, HigherOrderFunction, InList, InSubquery, Lambda, Like, + Placeholder, ScalarFunction, SetComparison, TryCast, Unnest, WindowFunction, WindowFunctionParams, }, }; @@ -112,7 +112,7 @@ impl TreeNode for Expr { Expr::InList(InList { expr, list, .. }) => { (expr, list).apply_ref_elements(f) } - Expr::LambdaFunction(LambdaFunction { func: _, args}) => args.apply_elements(f), + Expr::HigherOrderFunction(HigherOrderFunction { func: _, args}) => args.apply_elements(f), Expr::Lambda (Lambda{ params: _, body}) => body.apply_elements(f) } } @@ -333,9 +333,11 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_list)| { Expr::InList(InList::new(new_expr, new_list, negated)) }), - Expr::LambdaFunction(LambdaFunction { func, args }) => args - .map_elements(f)? - .update_data(|args| Expr::LambdaFunction(LambdaFunction { func, args })), + Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => { + args.map_elements(f)?.update_data(|args| { + Expr::HigherOrderFunction(HigherOrderFunction { func, args }) + }) + } Expr::Lambda(Lambda { params, body }) => body .map_elements(f)? 
.update_data(|body| Expr::Lambda(Lambda { params, body })), diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 374fb3d0e8a48..4378962e148df 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -17,8 +17,8 @@ use super::binary::binary_numeric_coercion; use crate::{ - AggregateUDF, LambdaTypeSignature, LambdaUDF, ScalarUDF, Signature, TypeSignature, - ValueOrLambda, WindowUDF, + AggregateUDF, HigherOrderTypeSignature, HigherOrderUDF, ScalarUDF, Signature, + TypeSignature, ValueOrLambda, WindowUDF, }; use arrow::datatypes::{Field, FieldRef}; use arrow::{ @@ -151,7 +151,7 @@ pub fn fields_with_udf( .collect()) } -/// Performs type coercion for lambda function arguments. +/// Performs type coercion for higher order function arguments. /// /// For value arguments, returns the field to which each /// argument must be coerced to match `signature`. @@ -159,12 +159,12 @@ pub fn fields_with_udf( /// /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. 
-pub fn value_fields_with_lambda_udf( +pub fn value_fields_with_higher_order_udf( current_fields: &[ValueOrLambda], - func: &dyn LambdaUDF, + func: &dyn HigherOrderUDF, ) -> Result>> { match func.signature().type_signature { - LambdaTypeSignature::UserDefined => { + HigherOrderTypeSignature::UserDefined => { let arg_types = current_fields .iter() .filter_map(|p| match p { @@ -211,8 +211,8 @@ pub fn value_fields_with_lambda_udf( }) .collect()) } - LambdaTypeSignature::VariadicAny => Ok(current_fields.to_vec()), - LambdaTypeSignature::Any(number) => { + HigherOrderTypeSignature::VariadicAny => Ok(current_fields.to_vec()), + HigherOrderTypeSignature::Any(number) => { if current_fields.len() != number { return plan_err!( "The function '{}' expected {number} arguments but received {}", diff --git a/datafusion/expr/src/udf_eq.rs b/datafusion/expr/src/udf_eq.rs index 3e59f2004f362..86ae9964492c3 100644 --- a/datafusion/expr/src/udf_eq.rs +++ b/datafusion/expr/src/udf_eq.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::{AggregateUDFImpl, LambdaUDF, ScalarUDFImpl, WindowUDFImpl}; +use crate::{AggregateUDFImpl, HigherOrderUDF, ScalarUDFImpl, WindowUDFImpl}; use std::fmt::Debug; use std::hash::{DefaultHasher, Hash, Hasher}; use std::ops::Deref; @@ -93,7 +93,7 @@ impl UdfPointer for Arc { } } -impl UdfPointer for Arc { +impl UdfPointer for Arc { fn equals(&self, other: &Self::Target) -> bool { self.as_ref().dyn_eq(other.as_any()) } diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udhof.rs similarity index 82% rename from datafusion/expr/src/udlf.rs rename to datafusion/expr/src/udhof.rs index f82f281f70e5e..0092f098ad56c 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udhof.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`LambdaUDF`]: Lambda User Defined Functions +//! 
[`HigherOrderUDF`]: User Defined Higher Order Functions use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::{ColumnarValue, Documentation, Expr}; @@ -34,7 +34,7 @@ use std::sync::Arc; /// The types of arguments for which a function has implementations. /// -/// [`LambdaTypeSignature`] **DOES NOT** define the types that a user query could call the +/// [`HigherOrderTypeSignature`] **DOES NOT** define the types that a user query could call the /// function with. DataFusion will automatically coerce (cast) argument types to /// one of the supported function signatures, if possible. /// @@ -43,17 +43,17 @@ use std::sync::Arc; /// argument [`DataType`]s, rather than all possible combinations. If a user /// calls a function with arguments that do not match any of the declared types, /// DataFusion will attempt to automatically coerce (add casts to) function -/// arguments so they match the [`LambdaTypeSignature`]. See the [`type_coercion`] module +/// arguments so they match the [`HigherOrderTypeSignature`]. See the [`type_coercion`] module /// for more details /// /// [`type_coercion`]: crate::type_coercion #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] -pub enum LambdaTypeSignature { +pub enum HigherOrderTypeSignature { /// The acceptable signature and coercions rules are special for this /// function. /// /// If this signature is specified, - /// DataFusion will call [`LambdaUDF::coerce_value_types`] to prepare argument types. + /// DataFusion will call [`HigherOrderUDF::coerce_value_types`] to prepare argument types. UserDefined, /// One or more lambdas or arguments with arbitrary types VariadicAny, @@ -61,24 +61,24 @@ pub enum LambdaTypeSignature { Any(usize), } -/// Provides information necessary for calling a lambda function. +/// Provides information necessary for calling a higher order function. 
/// -/// - [`LambdaTypeSignature`] defines the argument types that a function has implementations +/// - [`HigherOrderTypeSignature`] defines the argument types that a function has implementations /// for. /// /// - [`Volatility`] defines how the output of the function changes with the input. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] -pub struct LambdaSignature { - /// The data types that the function accepts. See [LambdaTypeSignature] for more information. - pub type_signature: LambdaTypeSignature, +pub struct HigherOrderSignature { + /// The data types that the function accepts. See [HigherOrderTypeSignature] for more information. + pub type_signature: HigherOrderTypeSignature, /// The volatility of the function. See [Volatility] for more information. pub volatility: Volatility, } -impl LambdaSignature { - /// Creates a new `LambdaSignature` from a given type signature and volatility. - pub fn new(type_signature: LambdaTypeSignature, volatility: Volatility) -> Self { - LambdaSignature { +impl HigherOrderSignature { + /// Creates a new `HigherOrderSignature` from a given type signature and volatility. + pub fn new(type_signature: HigherOrderTypeSignature, volatility: Volatility) -> Self { + HigherOrderSignature { type_signature, volatility, } @@ -87,7 +87,7 @@ impl LambdaSignature { /// User-defined coercion rules for the function. pub fn user_defined(volatility: Volatility) -> Self { Self { - type_signature: LambdaTypeSignature::UserDefined, + type_signature: HigherOrderTypeSignature::UserDefined, volatility, } } @@ -95,7 +95,7 @@ impl LambdaSignature { /// An arbitrary number of lambdas or arguments of any type. 
pub fn variadic_any(volatility: Volatility) -> Self { Self { - type_signature: LambdaTypeSignature::VariadicAny, + type_signature: HigherOrderTypeSignature::VariadicAny, volatility, } } @@ -103,19 +103,19 @@ impl LambdaSignature { /// A specified number of arguments of any type pub fn any(arg_count: usize, volatility: Volatility) -> Self { Self { - type_signature: LambdaTypeSignature::Any(arg_count), + type_signature: HigherOrderTypeSignature::Any(arg_count), volatility, } } } -impl PartialEq for dyn LambdaUDF { +impl PartialEq for dyn HigherOrderUDF { fn eq(&self, other: &Self) -> bool { self.dyn_eq(other.as_any()) } } -impl PartialOrd for dyn LambdaUDF { +impl PartialOrd for dyn HigherOrderUDF { fn partial_cmp(&self, other: &Self) -> Option { let mut cmp = self.name().cmp(other.name()); if cmp == Ordering::Equal { @@ -144,28 +144,28 @@ impl PartialOrd for dyn LambdaUDF { } } -impl Eq for dyn LambdaUDF {} +impl Eq for dyn HigherOrderUDF {} -impl Hash for dyn LambdaUDF { +impl Hash for dyn HigherOrderUDF { fn hash(&self, state: &mut H) { self.dyn_hash(state) } } -/// Arguments passed to [`LambdaUDF::invoke_with_args`] when invoking a -/// lambda function. +/// Arguments passed to [`HigherOrderUDF::invoke_with_args`] when invoking a +/// higher order function. 
#[derive(Debug, Clone)] -pub struct LambdaFunctionArgs { +pub struct HigherOrderFunctionArgs { /// The evaluated arguments and lambdas to the function pub args: Vec>, /// Field associated with each arg, if it exists /// For lambdas, it will be the field of the result of /// the lambda if evaluated with the parameters - /// returned from [`LambdaUDF::lambda_parameters`] + /// returned from [`HigherOrderUDF::lambda_parameters`] pub arg_fields: Vec>, /// The number of rows in record batch being evaluated pub number_rows: usize, - /// The return field of the lambda function returned + /// The return field of the higher order function returned /// (from `return_field_from_args`) when creating the /// physical expression from the logical expression pub return_field: FieldRef, @@ -173,7 +173,7 @@ pub struct LambdaFunctionArgs { pub config_options: Arc, } -impl LambdaFunctionArgs { +impl HigherOrderFunctionArgs { /// The return type of the function. See [`Self::return_field`] for more /// details. pub fn return_type(&self) -> &DataType { @@ -181,7 +181,7 @@ impl LambdaFunctionArgs { } } -/// A lambda argument to a LambdaFunction +/// A lambda argument to a HigherOrderFunction #[derive(Clone, Debug)] pub struct LambdaArgument { /// The parameters defined in this lambda @@ -203,7 +203,7 @@ impl LambdaArgument { /// Evaluate this lambda /// `args` should evaluate to the value of each parameter - /// of the correspondent lambda returned in [LambdaUDF::lambda_parameters]. + /// of the correspondent lambda returned in [HigherOrderUDF::lambda_parameters]. 
pub fn evaluate( &self, args: &[&dyn Fn() -> Result], @@ -228,13 +228,13 @@ impl LambdaArgument { /// such as the type of the arguments, any scalar arguments and if the /// arguments can (ever) be null /// -/// See [`LambdaUDF::return_field_from_args`] for more information +/// See [`HigherOrderUDF::return_field_from_args`] for more information #[derive(Clone, Debug)] -pub struct LambdaReturnFieldArgs<'a> { +pub struct HigherOrderReturnFieldArgs<'a> { /// The data types of the arguments to the function /// /// If argument `i` to the function is a lambda, it will be the field of the result of the - /// lambda if evaluated with the parameters returned from [`LambdaUDF::lambda_parameters`] + /// lambda if evaluated with the parameters returned from [`HigherOrderUDF::lambda_parameters`] /// /// For example, with `array_transform([1], v -> v == 5)` /// this field will be `[ @@ -251,7 +251,7 @@ pub struct LambdaReturnFieldArgs<'a> { pub scalar_arguments: &'a [Option<&'a ScalarValue>], } -/// An argument to a lambda function +/// An argument to a higher order function #[derive(Clone, Debug, PartialEq, Eq)] pub enum ValueOrLambda { /// A value with associated data @@ -260,7 +260,7 @@ pub enum ValueOrLambda { Lambda(L), } -/// Trait for implementing user defined lambda functions. +/// Trait for implementing user defined higher order functions. /// /// This trait exposes the full API for implementing user defined functions and /// can be used to implement any function. 
@@ -268,7 +268,7 @@ pub enum ValueOrLambda { /// See [`array_transform.rs`] for a commented complete implementation /// /// [`array_transform.rs`]: https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/array_transform.rs -pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { +pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync { /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; @@ -299,25 +299,25 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { )) } - /// Returns a [`LambdaSignature`] describing the argument types for which this + /// Returns a [`HigherOrderSignature`] describing the argument types for which this /// function has an implementation, and the function's [`Volatility`]. /// - /// See [`LambdaSignature`] for more details on argument type handling + /// See [`HigherOrderSignature`] for more details on argument type handling /// and [`Self::return_field_from_args`] for computing the return type. /// /// [`Volatility`]: datafusion_expr_common::signature::Volatility - fn signature(&self) -> &LambdaSignature; + fn signature(&self) -> &HigherOrderSignature; /// Return the field of all the parameters supported by all the supported lambdas of this function /// based on the field of the value arguments. 
If a lambda support multiple parameters, or if multiple /// lambdas are supported and some are optional, all should be returned, /// regardless of whether they are used on a particular invocation /// - /// Tip: If you have a [`LambdaFunction`] invocation, you can call the helper - /// [`LambdaFunction::lambda_parameters`] instead of this method directly + /// Tip: If you have a [`HigherOrderFunction`] invocation, you can call the helper + /// [`HigherOrderFunction::lambda_parameters`] instead of this method directly /// - /// [`LambdaFunction`]: crate::expr::LambdaFunction - /// [`LambdaFunction::lambda_parameters`]: crate::expr::LambdaFunction::lambda_parameters + /// [`HigherOrderFunction`]: crate::expr::HigherOrderFunction + /// [`HigherOrderFunction::lambda_parameters`]: crate::expr::HigherOrderFunction::lambda_parameters /// /// Example for array_transform: /// @@ -361,16 +361,19 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// # use std::sync::Arc; /// # use arrow::datatypes::{DataType, Field, FieldRef}; /// # use datafusion_common::Result; - /// # use datafusion_expr::LambdaReturnFieldArgs; + /// # use datafusion_expr::HigherOrderReturnFieldArgs; /// # struct Example{} /// # impl Example { - /// fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result { + /// fn return_field_from_args(&self, args: HigherOrderReturnFieldArgs) -> Result { /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true)); /// Ok(field) /// } /// # } /// ``` - fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result; + fn return_field_from_args( + &self, + args: HigherOrderReturnFieldArgs, + ) -> Result; /// Whether List, LargeList and FixedSizeList arguments should have it's /// non-empty null sublists cleaned by Datafusion before invoking this function @@ -395,7 +398,7 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments /// to 
arrays, which will likely be simpler code, but be slower. - fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result; + fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result; /// Returns true if some of this `exprs` subexpressions may not be evaluated /// and thus any side effects (like divide by zero) may not be encountered. @@ -403,7 +406,7 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// Setting this to true prevents certain optimizations such as common /// subexpression elimination /// - /// When overriding this function to return `true`, [LambdaUDF::conditional_arguments] can also be + /// When overriding this function to return `true`, [HigherOrderUDF::conditional_arguments] can also be /// overridden to report more accurately which arguments are eagerly evaluated and which ones /// lazily. fn short_circuits(&self) -> bool { @@ -425,7 +428,7 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// Implementations must ensure that the two returned `Vec`s are disjunct, /// and that each argument from `args` is present in one the two `Vec`s. /// - /// When overriding this function, [LambdaUDF::short_circuits] must + /// When overriding this function, [HigherOrderUDF::short_circuits] must /// be overridden to return `true`. fn conditional_arguments<'a>( &self, @@ -460,7 +463,7 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { ) } - /// Returns the documentation for this Lambda UDF. + /// Returns the documentation for this HigherOrderUDF. /// /// Documentation can be accessed programmatically as well as generating /// publicly facing documentation. 
@@ -477,12 +480,12 @@ mod tests { use std::hash::DefaultHasher; #[derive(Debug, PartialEq, Eq, Hash)] - struct TestLambdaUDF { + struct TestHigherOrderUDF { name: &'static str, field: &'static str, - signature: LambdaSignature, + signature: HigherOrderSignature, } - impl LambdaUDF for TestLambdaUDF { + impl HigherOrderUDF for TestHigherOrderUDF { fn as_any(&self) -> &dyn Any { self } @@ -491,7 +494,7 @@ mod tests { self.name } - fn signature(&self) -> &LambdaSignature { + fn signature(&self) -> &HigherOrderSignature { &self.signature } @@ -504,12 +507,15 @@ mod tests { fn return_field_from_args( &self, - _args: LambdaReturnFieldArgs, + _args: HigherOrderReturnFieldArgs, ) -> Result { unimplemented!() } - fn invoke_with_args(&self, _args: LambdaFunctionArgs) -> Result { + fn invoke_with_args( + &self, + _args: HigherOrderFunctionArgs, + ) -> Result { unimplemented!() } } @@ -545,11 +551,11 @@ mod tests { assert_eq!(b.partial_cmp(&o), Some(Ordering::Less)); } - fn test_func(name: &'static str, parameter: &'static str) -> Arc { - Arc::new(TestLambdaUDF { + fn test_func(name: &'static str, parameter: &'static str) -> Arc { + Arc::new(TestHigherOrderUDF { name, field: parameter, - signature: LambdaSignature::variadic_any(Volatility::Immutable), + signature: HigherOrderSignature::variadic_any(Volatility::Immutable), }) } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 85cd5af37b7a8..0445816f81cb9 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -317,7 +317,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Wildcard { .. } | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. 
} - | Expr::LambdaFunction(_) + | Expr::HigherOrderFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => {} } diff --git a/datafusion/ffi/src/session/mod.rs b/datafusion/ffi/src/session/mod.rs index e6f96450aee27..c2b1ade8265a5 100644 --- a/datafusion/ffi/src/session/mod.rs +++ b/datafusion/ffi/src/session/mod.rs @@ -33,7 +33,7 @@ use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{ - AggregateUDF, AggregateUDFImpl, Expr, LambdaUDF, LogicalPlan, ScalarUDF, + AggregateUDF, AggregateUDFImpl, Expr, HigherOrderUDF, LogicalPlan, ScalarUDF, ScalarUDFImpl, WindowUDF, WindowUDFImpl, }; use datafusion_physical_expr::PhysicalExpr; @@ -376,7 +376,7 @@ pub struct ForeignSession { session: FFI_SessionRef, config: SessionConfig, scalar_functions: HashMap>, - lambda_functions: HashMap>, + higher_order_functions: HashMap>, aggregate_functions: HashMap>, window_functions: HashMap>, table_options: TableOptions, @@ -445,7 +445,7 @@ impl TryFrom<&FFI_SessionRef> for ForeignSession { config, table_options, scalar_functions, - lambda_functions: HashMap::new(), + higher_order_functions: HashMap::new(), aggregate_functions, window_functions, runtime_env: Default::default(), @@ -586,8 +586,8 @@ impl Session for ForeignSession { &self.scalar_functions } - fn lambda_functions(&self) -> &HashMap> { - &self.lambda_functions + fn higher_order_functions(&self) -> &HashMap> { + &self.higher_order_functions } fn aggregate_functions(&self) -> &HashMap> { diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index e9ed199d55dae..2653d874af200 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! 
[`LambdaUDF`] definitions for array_transform function. +//! [`HigherOrderUDF`] definitions for array_transform function. use arrow::{ array::{ @@ -29,18 +29,18 @@ use datafusion_common::{ utils::{adjust_offsets_for_slice, list_values, take_function_args}, }; use datafusion_expr::{ - ColumnarValue, Documentation, LambdaFunctionArgs, LambdaReturnFieldArgs, - LambdaSignature, LambdaUDF, ValueOrLambda, Volatility, + ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, + HigherOrderSignature, HigherOrderUDF, ValueOrLambda, Volatility, }; use datafusion_macros::user_doc; use std::{any::Any, fmt::Debug, sync::Arc}; -make_udlf_expr_and_func!( +make_udhof_expr_and_func!( ArrayTransform, array_transform, array lambda, "transforms the values of a array", - array_transform_udlf + array_transform_udhof ); #[user_doc( @@ -63,7 +63,7 @@ make_udlf_expr_and_func!( )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayTransform { - signature: LambdaSignature, + signature: HigherOrderSignature, aliases: Vec, } @@ -76,13 +76,13 @@ impl Default for ArrayTransform { impl ArrayTransform { pub fn new() -> Self { Self { - signature: LambdaSignature::user_defined(Volatility::Immutable), + signature: HigherOrderSignature::user_defined(Volatility::Immutable), aliases: vec![String::from("list_transform")], } } } -impl LambdaUDF for ArrayTransform { +impl HigherOrderUDF for ArrayTransform { fn as_any(&self) -> &dyn Any { self } @@ -95,7 +95,7 @@ impl LambdaUDF for ArrayTransform { &self.aliases } - fn signature(&self) -> &LambdaSignature { + fn signature(&self) -> &HigherOrderSignature { &self.signature } @@ -154,7 +154,10 @@ impl LambdaUDF for ArrayTransform { Ok(vec![vec![value]]) } - fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result> { + fn return_field_from_args( + &self, + args: HigherOrderReturnFieldArgs, + ) -> Result> { let (list, lambda) = value_lambda_pair(self.name(), args.arg_fields)?; //TODO: should metadata be copied into the 
transformed array? @@ -177,7 +180,7 @@ impl LambdaUDF for ArrayTransform { Ok(Arc::new(Field::new("", return_type, list.is_nullable()))) } - fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result { + fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result { let (list, lambda) = value_lambda_pair(self.name(), &args.args)?; let list_array = list.to_array(args.number_rows)?; @@ -295,12 +298,12 @@ mod tests { }; use datafusion_common::{DFSchema, Result}; use datafusion_expr::{ - Expr, col, execution_props::ExecutionProps, expr::LambdaFunction, lambda, + Expr, col, execution_props::ExecutionProps, expr::HigherOrderFunction, lambda, lambda_var, lit, }; use datafusion_physical_expr::create_physical_expr; - use crate::array_transform::array_transform_udlf; + use crate::array_transform::array_transform_udhof; fn create_i32_list( values: impl Into, @@ -326,7 +329,7 @@ mod tests { } fn divide_100_by(list: impl Array + Clone + 'static) -> Result { - let array_transform = array_transform_udlf(); + let array_transform = array_transform_udhof(); let schema = DFSchema::from_unqualified_fields( vec![Field::new( @@ -339,7 +342,7 @@ mod tests { )?; create_physical_expr( - &Expr::LambdaFunction(LambdaFunction::new( + &Expr::HigherOrderFunction(HigherOrderFunction::new( array_transform, vec![ col("list"), diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 71442d33ffb9b..d997461e2803e 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -75,7 +75,7 @@ pub mod utils; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; -use datafusion_expr::{LambdaUDF, ScalarUDF}; +use datafusion_expr::{HigherOrderUDF, ScalarUDF}; use log::debug; use std::sync::Arc; @@ -183,8 +183,8 @@ pub fn all_default_nested_functions() -> Vec> { ] } -pub fn all_default_lambda_functions() -> Vec> { - vec![array_transform::array_transform_udlf()] +pub fn 
all_default_higher_order_functions() -> Vec> { + vec![array_transform::array_transform_udhof()] } /// Registers all enabled packages with a [`FunctionRegistry`] @@ -198,11 +198,11 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { Ok(()) as Result<()> })?; - let functions: Vec> = all_default_lambda_functions(); - functions.into_iter().try_for_each(|udlf| { - let existing_udlf = registry.register_udlf(udlf)?; - if let Some(existing_udlf) = existing_udlf { - debug!("Overwrite existing UDLF: {}", existing_udlf.name()); + let functions: Vec> = all_default_higher_order_functions(); + functions.into_iter().try_for_each(|udhof| { + let existing_udhof = registry.register_udhof(udhof)?; + if let Some(existing_udhof) = existing_udhof { + debug!("Overwrite existing UDHOF: {}", existing_udhof.name()); } Ok(()) as Result<()> })?; @@ -212,7 +212,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { #[cfg(test)] mod tests { - use crate::{all_default_lambda_functions, all_default_nested_functions}; + use crate::{all_default_higher_order_functions, all_default_nested_functions}; use datafusion_common::Result; use std::collections::HashSet; @@ -221,7 +221,7 @@ mod tests { let scalars = all_default_nested_functions(); let scalars = scalars.iter().map(|s| (s.name(), s.aliases())); - let lambdas = all_default_lambda_functions(); + let lambdas = all_default_higher_order_functions(); let lambdas = lambdas.iter().map(|l| (l.name(), l.aliases())); let mut names = HashSet::new(); diff --git a/datafusion/functions-nested/src/macros_lambda.rs b/datafusion/functions-nested/src/macros_lambda.rs index cd8b4dbdfd263..7d210ddb2a744 100644 --- a/datafusion/functions-nested/src/macros_lambda.rs +++ b/datafusion/functions-nested/src/macros_lambda.rs @@ -17,12 +17,12 @@ /// Creates external API functions for an array UDF. Specifically, creates /// -/// 1. Single `LambdaUDF` instance +/// 1. 
Single `HigherOrderUDF` instance /// -/// Creates a singleton `LambdaUDF` of the `$UDF` function named `STATIC_$(UDF)` and a -/// function named `$LAMBDA_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. +/// Creates a singleton `HigherOrderUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$HIGHER_ORDER_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. /// -/// This is used to ensure creating the list of `LambdaUDF` only happens once. +/// This is used to ensure creating the list of `HigherOrderUDF` only happens once. /// /// # 2. `expr_fn` style function /// @@ -36,68 +36,68 @@ /// } /// ``` /// # Arguments -/// * `UDF`: name of the [`LambdaUDF`] +/// * `UDF`: name of the [`HigherOrderUDF`] /// * `EXPR_FN`: name of the expr_fn function to be created /// * `arg`: 0 or more named arguments for the function /// * `DOC`: documentation string for the function -/// * `LAMBDA_UDF_FUNC`: name of the function to create (just) the `LambdaUDF` +/// * `HIGHER_ORDER_UDF_FUNC`: name of the function to create (just) the `HigherOrderUDF` /// * (optional) `$CTOR`: Pass a custom constructor. When omitted it /// automatically resolves to `$UDF::new()`. /// -/// [`LambdaUDF`]: datafusion_expr::LambdaUDF -macro_rules! make_udlf_expr_and_func { - ($UDF:ident, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $LAMBDA_UDF_FN:ident) => { - make_udlf_expr_and_func!($UDF, $EXPR_FN, $($arg)*, $DOC, $LAMBDA_UDF_FN, $UDF::new); +/// [`HigherOrderUDF`]: datafusion_expr::HigherOrderUDF +macro_rules! 
make_udhof_expr_and_func { + ($UDF:ident, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $HIGHER_ORDER_UDF_FN:ident) => { + make_udhof_expr_and_func!($UDF, $EXPR_FN, $($arg)*, $DOC, $HIGHER_ORDER_UDF_FN, $UDF::new); }; - ($UDF:ident, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $LAMBDA_UDF_FN:ident, $CTOR:path) => { + ($UDF:ident, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $HIGHER_ORDER_UDF_FN:ident, $CTOR:path) => { // "fluent expr_fn" style function #[doc = $DOC] pub fn $EXPR_FN($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { - datafusion_expr::Expr::LambdaFunction(datafusion_expr::expr::LambdaFunction::new( - $LAMBDA_UDF_FN(), + datafusion_expr::Expr::HigherOrderFunction(datafusion_expr::expr::HigherOrderFunction::new( + $HIGHER_ORDER_UDF_FN(), vec![$($arg),*], )) } - create_lambda!($UDF, $LAMBDA_UDF_FN, $CTOR); + create_higher_order!($UDF, $HIGHER_ORDER_UDF_FN, $CTOR); }; - ($UDF:ident, $EXPR_FN:ident, $DOC:expr, $LAMBDA_UDF_FN:ident) => { - make_udlf_expr_and_func!($UDF, $EXPR_FN, $DOC, $LAMBDA_UDF_FN, $UDF::new); + ($UDF:ident, $EXPR_FN:ident, $DOC:expr, $HIGHER_ORDER_UDF_FN:ident) => { + make_udhof_expr_and_func!($UDF, $EXPR_FN, $DOC, $HIGHER_ORDER_UDF_FN, $UDF::new); }; - ($UDF:ident, $EXPR_FN:ident, $DOC:expr, $LAMBDA_UDF_FN:ident, $CTOR:path) => { + ($UDF:ident, $EXPR_FN:ident, $DOC:expr, $HIGHER_ORDER_UDF_FN:ident, $CTOR:path) => { // "fluent expr_fn" style function #[doc = $DOC] pub fn $EXPR_FN(arg: Vec) -> datafusion_expr::Expr { - datafusion_expr::Expr::LambdaFunction(datafusion_expr::expr::LambdaFunction::new( - $LAMBDA_UDF_FN(), + datafusion_expr::Expr::HigherOrderFunction(datafusion_expr::expr::HigherOrderFunction::new( + $HIGHER_ORDER_UDF_FN(), arg, )) } - create_lambda!($UDF, $LAMBDA_UDF_FN, $CTOR); + create_higher_order!($UDF, $HIGHER_ORDER_UDF_FN, $CTOR); }; } -/// Creates a singleton `LambdaUDF` of the `$UDF` function named `STATIC_$(UDF)` and a -/// function named `$LAMBDA_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. 
+/// Creates a singleton `HigherOrderUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$HIGHER_ORDER_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. /// -/// This is used to ensure creating the list of `LambdaUDF` only happens once. +/// This is used to ensure creating the list of `HigherOrderUDF` only happens once. /// /// # Arguments -/// * `UDF`: name of the [`LambdaUDF`] -/// * `LAMBDA_UDF_FUNC`: name of the function to create (just) the `LambdaUDF` +/// * `UDF`: name of the [`HigherOrderUDF`] +/// * `HIGHER_ORDER_UDF_FUNC`: name of the function to create (just) the `HigherOrderUDF` /// * (optional) `$CTOR`: Pass a custom constructor. When omitted it /// automatically resolves to `$UDF::new()`. /// -/// [`LambdaUDF`]: datafusion_expr::LambdaUDF -macro_rules! create_lambda { - ($UDF:ident, $LAMBDA_UDF_FN:ident) => { - create_lambda!($UDF, $LAMBDA_UDF_FN, $UDF::new); +/// [`HigherOrderUDF`]: datafusion_expr::HigherOrderUDF +macro_rules! create_higher_order { + ($UDF:ident, $HIGHER_ORDER_UDF_FN:ident) => { + create_higher_order!($UDF, $HIGHER_ORDER_UDF_FN, $UDF::new); }; - ($UDF:ident, $LAMBDA_UDF_FN:ident, $CTOR:path) => { - #[doc = concat!("LambdaFunction that returns a [`LambdaUDF`](datafusion_expr::LambdaUDF) for ")] + ($UDF:ident, $HIGHER_ORDER_UDF_FN:ident, $CTOR:path) => { + #[doc = concat!("HigherOrderFunction that returns a [`HigherOrderUDF`](datafusion_expr::HigherOrderUDF) for ")] #[doc = stringify!($UDF)] - pub fn $LAMBDA_UDF_FN() -> std::sync::Arc { + pub fn $HIGHER_ORDER_UDF_FN() -> std::sync::Arc { // Singleton instance of [`$UDF`], ensures the UDF is only created once - static INSTANCE: std::sync::LazyLock> = + static INSTANCE: std::sync::LazyLock> = std::sync::LazyLock::new(|| { std::sync::Arc::new($CTOR()) }); diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 37e1aef0ac94a..18e4e91c7a5f7 100644 --- 
a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -35,15 +35,15 @@ use datafusion_common::{ plan_err, }; use datafusion_expr::expr::{ - self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList, - InSubquery, LambdaFunction, Like, ScalarFunction, SetComparison, Sort, + self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, + HigherOrderFunction, InList, InSubquery, Like, ScalarFunction, SetComparison, Sort, WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion}; -use datafusion_expr::type_coercion::functions::value_fields_with_lambda_udf; +use datafusion_expr::type_coercion::functions::value_fields_with_higher_order_udf; use datafusion_expr::type_coercion::functions::{UDFCoercionExt, fields_with_udf}; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_expression, get_coerce_type_for_list, @@ -762,7 +762,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }); Ok(Transformed::yes(new_expr)) } - Expr::LambdaFunction(LambdaFunction { func, args }) => { + Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => { let current_fields = args .iter() .map(|arg| match arg { @@ -772,20 +772,20 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { .collect::>>()?; let new_fields = - value_fields_with_lambda_udf(¤t_fields, func.as_ref())?; + value_fields_with_higher_order_udf(¤t_fields, func.as_ref())?; let new_args = std::iter::zip(args, new_fields) .map(|(arg, new_field)| match (&arg, new_field) { (Expr::Lambda(_lambda), ValueOrLambda::Lambda(_)) => Ok(arg), - (Expr::Lambda(_lambda), ValueOrLambda::Value(_)) => plan_err!("value_fields_with_lambda_udf return a value for a lambda argument"), + (Expr::Lambda(_lambda), ValueOrLambda::Value(_)) 
=> plan_err!("value_fields_with_higher_order_udf return a value for a lambda argument"), (_, ValueOrLambda::Value(new_field)) => arg.cast_to(new_field.data_type(), self.schema), - (_, ValueOrLambda::Lambda(_)) => plan_err!("value_fields_with_lambda_udf return a lambda for a value argument"), + (_, ValueOrLambda::Lambda(_)) => plan_err!("value_fields_with_higher_order_udf return a lambda for a value argument"), }) .collect::>()?; - Ok(Transformed::yes(Expr::LambdaFunction(LambdaFunction::new( - func, new_args, - )))) + Ok(Transformed::yes(Expr::HigherOrderFunction( + HigherOrderFunction::new(func, new_args), + ))) } // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index eed5c46f080f6..2a7bc2fc50e1b 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -30,7 +30,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::cse::{CSE, CSEController, FoundCommonNodes}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, qualified_name}; -use datafusion_expr::expr::{Alias, LambdaFunction, ScalarFunction}; +use datafusion_expr::expr::{Alias, HigherOrderFunction, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; @@ -651,13 +651,13 @@ impl CSEController for ExprCSEController<'_> { fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> { match node { - // In case of `ScalarFunction`s and `LambdaFunction`s we don't know which children are surely + // In case of `ScalarFunction`s and `HigherOrderFunction`s we don't know which children are surely // executed so start visiting all children conditionally and stop the // recursion with `TreeNodeRecursion::Jump`. 
Expr::ScalarFunction(ScalarFunction { func, args }) => { func.conditional_arguments(args) } - Expr::LambdaFunction(LambdaFunction { func, args }) => { + Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => { func.conditional_arguments(args) } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 8597a452fe5f2..0ddfd072f9b88 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -275,7 +275,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::TryCast(_) | Expr::InList { .. } | Expr::ScalarFunction(_) - | Expr::LambdaFunction(_) + | Expr::HigherOrderFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), // TODO: remove the next line after `Expr::Wildcard` is removed diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index fda21cd2ad4d6..43bc3c0a7fda2 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -38,7 +38,7 @@ use datafusion_common::{ metadata::FieldMetadata, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; -use datafusion_expr::expr::LambdaFunction; +use datafusion_expr::expr::HigherOrderFunction; use datafusion_expr::{ BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Like, Operator, Volatility, and, binary::BinaryTypeCoercer, lit, or, preimage::PreimageResult, @@ -647,7 +647,7 @@ impl ConstEvaluator { Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } - Expr::LambdaFunction(LambdaFunction { func, .. }) => { + Expr::HigherOrderFunction(HigherOrderFunction { func, .. 
}) => { Self::volatility_ok(func.signature().volatility) } Expr::Cast(Cast { expr, field }) | Expr::TryCast(TryCast { expr, field }) => { diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 27e46a98bed76..76d193112ee01 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -735,10 +735,10 @@ impl ContextProvider for MyContextProvider { None } - fn get_lambda_meta( + fn get_higher_order_meta( &self, _name: &str, - ) -> Option> { + ) -> Option> { None } @@ -770,7 +770,7 @@ impl ContextProvider for MyContextProvider { Vec::new() } - fn udlf_names(&self) -> Vec { + fn udhof_names(&self) -> Vec { Vec::new() } diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/higher_order_function.rs similarity index 86% rename from datafusion/physical-expr/src/lambda_function.rs rename to datafusion/physical-expr/src/higher_order_function.rs index 6823ff4d4d4f4..d77f521d65515 100644 --- a/datafusion/physical-expr/src/lambda_function.rs +++ b/datafusion/physical-expr/src/higher_order_function.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Declaration of built-in (lambda) functions. +//! Declaration of built-in (higher order) functions. //! This module contains built-in functions' enumeration and metadata. //! //! 
Generally, a function has: @@ -44,24 +44,24 @@ use datafusion_common::utils::remove_list_null_values; use datafusion_common::{ Result, ScalarValue, exec_err, internal_datafusion_err, internal_err, }; -use datafusion_expr::type_coercion::functions::value_fields_with_lambda_udf; +use datafusion_expr::type_coercion::functions::value_fields_with_higher_order_udf; use datafusion_expr::{ - ColumnarValue, LambdaArgument, LambdaFunctionArgs, LambdaReturnFieldArgs, LambdaUDF, - ValueOrLambda, Volatility, expr_vec_fmt, + ColumnarValue, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, HigherOrderUDF, + LambdaArgument, ValueOrLambda, Volatility, expr_vec_fmt, }; -/// Physical expression of a lambda function -pub struct LambdaFunctionExpr { - fun: Arc, +/// Physical expression of a higher order function +pub struct HigherOrderFunctionExpr { + fun: Arc, name: String, args: Vec>, return_field: FieldRef, config_options: Arc, } -impl Debug for LambdaFunctionExpr { +impl Debug for HigherOrderFunctionExpr { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("LambdaFunctionExpr") + f.debug_struct("HigherOrderFunctionExpr") .field("fun", &"") .field("name", &self.name) .field("args", &self.args) @@ -70,11 +70,11 @@ impl Debug for LambdaFunctionExpr { } } -impl LambdaFunctionExpr { - /// Create a new Lambda function +impl HigherOrderFunctionExpr { + /// Create a new Higher Order function pub fn new( name: impl Into, - fun: Arc, + fun: Arc, args: Vec>, return_field: FieldRef, config_options: Arc, @@ -88,9 +88,9 @@ impl LambdaFunctionExpr { } } - /// Create a new Lambda function + /// Create a new Higher Order function pub fn try_new( - fun: Arc, + fun: Arc, args: Vec>, schema: &Schema, config_options: Arc, @@ -107,8 +107,8 @@ impl LambdaFunctionExpr { }) .collect::>>()?; - // verify that input data types is consistent with function's `LambdaTypeSignature` - value_fields_with_lambda_udf(&arg_fields, fun.as_ref())?; + // verify that input data types is consistent with 
function's `HigherOrderTypeSignature` + value_fields_with_higher_order_udf(&arg_fields, fun.as_ref())?; let arguments = args .iter() @@ -119,7 +119,7 @@ impl LambdaFunctionExpr { }) .collect::>(); - let ret_args = LambdaReturnFieldArgs { + let ret_args = HigherOrderReturnFieldArgs { arg_fields: &arg_fields, scalar_arguments: &arguments, }; @@ -135,8 +135,8 @@ impl LambdaFunctionExpr { }) } - /// Get the lambda function implementation - pub fn fun(&self) -> &dyn LambdaUDF { + /// Get the higher order function implementation + pub fn fun(&self) -> &dyn HigherOrderUDF { self.fun.as_ref() } @@ -174,13 +174,13 @@ impl LambdaFunctionExpr { } } -impl fmt::Display for LambdaFunctionExpr { +impl fmt::Display for HigherOrderFunctionExpr { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "{}({})", self.name, expr_vec_fmt!(self.args)) } } -impl PartialEq for LambdaFunctionExpr { +impl PartialEq for HigherOrderFunctionExpr { fn eq(&self, o: &Self) -> bool { if std::ptr::eq(self, o) { // The equality implementation is somewhat expensive, so let's short-circuit when possible. 
@@ -202,8 +202,8 @@ impl PartialEq for LambdaFunctionExpr { == sorted_config_entries(&o.config_options)) } } -impl Eq for LambdaFunctionExpr {} -impl Hash for LambdaFunctionExpr { +impl Eq for HigherOrderFunctionExpr {} +impl Hash for HigherOrderFunctionExpr { fn hash(&self, state: &mut H) { let Self { fun, @@ -225,7 +225,7 @@ fn sorted_config_entries(config_options: &ConfigOptions) -> Vec { entries } -impl PhysicalExpr for LambdaFunctionExpr { +impl PhysicalExpr for HigherOrderFunctionExpr { fn as_any(&self) -> &dyn Any { self } @@ -333,7 +333,7 @@ impl PhysicalExpr for LambdaFunctionExpr { .all(|arg| matches!(arg, ValueOrLambda::Value(ColumnarValue::Scalar(_)))); // evaluate the function - let output = self.fun.invoke_with_args(LambdaFunctionArgs { + let output = self.fun.invoke_with_args(HigherOrderFunctionArgs { args, arg_fields, number_rows: batch.num_rows(), @@ -373,7 +373,7 @@ impl PhysicalExpr for LambdaFunctionExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(LambdaFunctionExpr::new( + Ok(Arc::new(HigherOrderFunctionExpr::new( &self.name, Arc::clone(&self.fun), children, @@ -404,22 +404,24 @@ mod tests { use std::sync::Arc; use super::*; - use crate::LambdaFunctionExpr; + use crate::HigherOrderFunctionExpr; use crate::expressions::Column; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; - use datafusion_expr::{LambdaFunctionArgs, LambdaSignature, LambdaUDF}; + use datafusion_expr::{ + HigherOrderFunctionArgs, HigherOrderSignature, HigherOrderUDF, + }; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::physical_expr::is_volatile; /// Test helper to create a mock UDF with a specific volatility #[derive(Debug, PartialEq, Eq, Hash)] - struct MockLambdaUDF { - signature: LambdaSignature, + struct MockHigherOrderUDF { + signature: HigherOrderSignature, } - impl LambdaUDF for MockLambdaUDF { + impl 
HigherOrderUDF for MockHigherOrderUDF { fn as_any(&self) -> &dyn Any { self } @@ -428,7 +430,7 @@ mod tests { "mock_function" } - fn signature(&self) -> &LambdaSignature { + fn signature(&self) -> &HigherOrderSignature { &self.signature } @@ -441,26 +443,29 @@ mod tests { fn return_field_from_args( &self, - _args: LambdaReturnFieldArgs, + _args: HigherOrderReturnFieldArgs, ) -> Result { Ok(Arc::new(Field::new("", DataType::Int32, false))) } - fn invoke_with_args(&self, _args: LambdaFunctionArgs) -> Result { + fn invoke_with_args( + &self, + _args: HigherOrderFunctionArgs, + ) -> Result { Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42)))) } } #[test] - fn test_lambda_function_volatile_node() { + fn test_higher_order_function_volatile_node() { // Create a volatile UDF - let volatile_udf = Arc::new(MockLambdaUDF { - signature: LambdaSignature::variadic_any(Volatility::Volatile), + let volatile_udf = Arc::new(MockHigherOrderUDF { + signature: HigherOrderSignature::variadic_any(Volatility::Volatile), }); // Create a non-volatile UDF - let stable_udf = Arc::new(MockLambdaUDF { - signature: LambdaSignature::variadic_any(Volatility::Stable), + let stable_udf = Arc::new(MockHigherOrderUDF { + signature: HigherOrderSignature::variadic_any(Volatility::Stable), }); let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); @@ -468,7 +473,7 @@ mod tests { let config_options = Arc::new(ConfigOptions::new()); // Test volatile function - let volatile_expr = LambdaFunctionExpr::try_new( + let volatile_expr = HigherOrderFunctionExpr::try_new( volatile_udf, args.clone(), &schema, @@ -482,7 +487,7 @@ mod tests { // Test non-volatile function let stable_expr = - LambdaFunctionExpr::try_new(stable_udf, args, &schema, config_options) + HigherOrderFunctionExpr::try_new(stable_udf, args, &schema, config_options) .unwrap(); assert!(!stable_expr.is_volatile_node()); diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 
3476f32d4e487..b1666ca95b2e2 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -34,8 +34,8 @@ pub mod binary_map { pub mod async_scalar_function; pub mod equivalence; pub mod expressions; +pub mod higher_order_function; pub mod intervals; -pub mod lambda_function; mod partitioning; mod physical_expr; pub mod planner; @@ -70,7 +70,7 @@ pub use datafusion_physical_expr_common::sort_expr::{ PhysicalSortRequirement, }; -pub use lambda_function::LambdaFunctionExpr; +pub use higher_order_function::HigherOrderFunctionExpr; pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; pub use simplifier::PhysicalExprSimplifier; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index f530c3f9f6b8a..30dee27c3014f 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -18,7 +18,7 @@ use std::collections::HashMap; use std::sync::Arc; -use crate::{LambdaFunctionExpr, ScalarFunctionExpr}; +use crate::{HigherOrderFunctionExpr, ScalarFunctionExpr}; use crate::{ PhysicalExpr, expressions::{self, Column, Literal, binary, like, similar_to}, @@ -33,7 +33,7 @@ use datafusion_common::{ }; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{ - Alias, Cast, InList, Lambda, LambdaFunction, LambdaVariable, Placeholder, + Alias, Cast, HigherOrderFunction, InList, Lambda, LambdaVariable, Placeholder, ScalarFunction, }; use datafusion_expr::var_provider::VarType; @@ -417,7 +417,7 @@ pub fn create_physical_expr( Expr::Placeholder(Placeholder { id, .. 
}) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } - Expr::LambdaFunction(invocation @ LambdaFunction { func, args }) => { + Expr::HigherOrderFunction(invocation @ HigherOrderFunction { func, args }) => { let num_lambdas = args .iter() .filter(|arg| matches!(arg, Expr::Lambda(_))) @@ -466,7 +466,7 @@ pub fn create_physical_expr( None => Arc::new(ConfigOptions::default()), }; - Ok(Arc::new(LambdaFunctionExpr::try_new( + Ok(Arc::new(HigherOrderFunctionExpr::try_new( Arc::clone(func), physical_args, input_schema, diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 5343dc00648fb..44c6ee3779778 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -28,8 +28,8 @@ use crate::protobuf; use datafusion_common::{Result, not_impl_err, plan_datafusion_err}; use datafusion_execution::TaskContext; use datafusion_expr::{ - AggregateUDF, Expr, LambdaSignature, LambdaUDF, LogicalPlan, Volatility, WindowUDF, - create_udaf, create_udf, create_udwf, + AggregateUDF, Expr, HigherOrderSignature, HigherOrderUDF, LogicalPlan, Volatility, + WindowUDF, create_udaf, create_udf, create_udwf, }; use prost::{ Message, @@ -123,15 +123,18 @@ impl Serializeable for Expr { ))) } - fn udlf(&self, name: &str) -> Result> { - // if a SimpleLambdaFunction get's added, use it instead of MockLambdaUDF + fn udhof( + &self, + name: &str, + ) -> Result> { + // if a SimpleHigherOrderFunction get's added, use it instead of MockHigherOrderUDF #[derive(Debug, PartialEq, Eq, Hash)] - struct MockLambdaUDF { + struct MockHigherOrderUDF { name: String, - signature: LambdaSignature, + signature: HigherOrderSignature, } - impl LambdaUDF for MockLambdaUDF { + impl HigherOrderUDF for MockHigherOrderUDF { fn as_any(&self) -> &dyn std::any::Any { self } @@ -140,7 +143,7 @@ impl Serializeable for Expr { &self.name } - fn signature(&self) -> &LambdaSignature { + fn signature(&self) -> &HigherOrderSignature { &self.signature 
} @@ -148,27 +151,27 @@ impl Serializeable for Expr { &self, _value_fields: &[arrow::datatypes::FieldRef], ) -> Result>> { - not_impl_err!("mock LambdaUDF") + not_impl_err!("mock HigherOrderUDF") } fn return_field_from_args( &self, - _args: datafusion_expr::LambdaReturnFieldArgs, + _args: datafusion_expr::HigherOrderReturnFieldArgs, ) -> Result { - not_impl_err!("mock LambdaUDF") + not_impl_err!("mock HigherOrderUDF") } fn invoke_with_args( &self, - _args: datafusion_expr::LambdaFunctionArgs, + _args: datafusion_expr::HigherOrderFunctionArgs, ) -> Result { - not_impl_err!("mock LambdaUDF") + not_impl_err!("mock HigherOrderUDF") } } - Ok(Arc::new(MockLambdaUDF { + Ok(Arc::new(MockHigherOrderUDF { name: name.to_string(), - signature: LambdaSignature::variadic_any(Volatility::Immutable), + signature: HigherOrderSignature::variadic_any(Volatility::Immutable), })) } @@ -208,12 +211,12 @@ impl Serializeable for Expr { "register_udf called in Placeholder Registry!" ) } - fn register_udlf( + fn register_udhof( &mut self, - _udlf: Arc, - ) -> Result>> { + _udhof: Arc, + ) -> Result>> { datafusion_common::internal_err!( - "register_udlf called in Placeholder Registry!" + "register_udhof called in Placeholder Registry!" 
) } fn register_udwf( @@ -229,7 +232,7 @@ impl Serializeable for Expr { vec![] } - fn udlfs(&self) -> std::collections::HashSet { + fn udhofs(&self) -> std::collections::HashSet { std::collections::HashSet::default() } diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs index 880ddc03ecd1f..10285382fdfae 100644 --- a/datafusion/proto/src/bytes/registry.rs +++ b/datafusion/proto/src/bytes/registry.rs @@ -38,9 +38,9 @@ impl FunctionRegistry for NoRegistry { ) } - fn udlf(&self, name: &str) -> Result> { + fn udhof(&self, name: &str) -> Result> { plan_err!( - "No function registry provided to deserialize, so can not deserialize User Defined Lambda Function '{name}'" + "No function registry provided to deserialize, so can not deserialize User Defined Higher Order Function '{name}'" ) } @@ -81,7 +81,7 @@ impl FunctionRegistry for NoRegistry { vec![] } - fn udlfs(&self) -> HashSet { + fn udhofs(&self) -> HashSet { HashSet::new() } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 69a56f1192f5e..9788b50ddcb13 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -54,7 +54,7 @@ use datafusion_datasource_json::file_format::{ #[cfg(feature = "parquet")] use datafusion_datasource_parquet::file_format::{ParquetFormat, ParquetFormatFactory}; use datafusion_expr::{ - AggregateUDF, DmlStatement, FetchType, LambdaUDF, RecursiveQuery, SkipType, + AggregateUDF, DmlStatement, FetchType, HigherOrderUDF, RecursiveQuery, SkipType, TableSource, Unnest, }; use datafusion_expr::{ @@ -156,11 +156,21 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync + std::any::Any { Ok(()) } - fn try_decode_udlf(&self, name: &str, _buf: &[u8]) -> Result> { - not_impl_err!("LogicalExtensionCodec is not provided for lambda function {name}") + fn try_decode_udhof( + &self, + name: &str, + _buf: &[u8], + ) -> Result> { + not_impl_err!( + "LogicalExtensionCodec is 
not provided for higher order function {name}" + ) } - fn try_encode_udlf(&self, _node: &dyn LambdaUDF, _buf: &mut Vec) -> Result<()> { + fn try_encode_udhof( + &self, + _node: &dyn HigherOrderUDF, + _buf: &mut Vec, + ) -> Result<()> { Ok(()) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 9d417b09b3682..888727fbe0883 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -626,7 +626,7 @@ pub fn serialize_expr( .unwrap_or(HashMap::new()), })), }, - Expr::LambdaFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => { + Expr::HigherOrderFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => { return Err(Error::General( "Proto serialization error: Lambda not implemented".to_string(), )); diff --git a/datafusion/session/src/session.rs b/datafusion/session/src/session.rs index 00ac5534debeb..0e192788862ee 100644 --- a/datafusion/session/src/session.rs +++ b/datafusion/session/src/session.rs @@ -22,7 +22,9 @@ use datafusion_execution::TaskContext; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF}; +use datafusion_expr::{ + AggregateUDF, Expr, HigherOrderUDF, LogicalPlan, ScalarUDF, WindowUDF, +}; use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr}; use parking_lot::{Mutex, RwLock}; use std::any::Any; @@ -110,8 +112,8 @@ pub trait Session: Send + Sync { /// Return reference to scalar_functions fn scalar_functions(&self) -> &HashMap>; - /// Return reference to lambda_functions - fn lambda_functions(&self) -> &HashMap>; + /// Return reference to higher_order_functions + fn higher_order_functions(&self) -> &HashMap>; /// Return reference to aggregate_functions fn aggregate_functions(&self) -> &HashMap>; @@ -152,7 +154,7 @@ impl From<&dyn Session> 
for TaskContext { state.session_id().to_string(), state.config().clone(), state.scalar_functions().clone(), - state.lambda_functions().clone(), + state.higher_order_functions().clone(), state.aggregate_functions().clone(), state.window_functions().clone(), Arc::clone(state.runtime_env()), diff --git a/datafusion/spark/src/lib.rs b/datafusion/spark/src/lib.rs index 71ee9c5f5d2e8..0a8f617bc59de 100644 --- a/datafusion/spark/src/lib.rs +++ b/datafusion/spark/src/lib.rs @@ -43,7 +43,7 @@ //! //! ``` //! # use datafusion_execution::FunctionRegistry; -//! # use datafusion_expr::{ScalarUDF, AggregateUDF, WindowUDF, LambdaUDF}; +//! # use datafusion_expr::{ScalarUDF, AggregateUDF, WindowUDF, HigherOrderUDF}; //! # use datafusion_expr::planner::ExprPlanner; //! # use datafusion_common::Result; //! # use std::collections::HashSet; @@ -55,11 +55,11 @@ //! # impl FunctionRegistry for SessionContext { //! # fn register_udf(&mut self, _udf: Arc) -> Result>> { Ok (None) } //! # fn udfs(&self) -> HashSet { unimplemented!() } -//! # fn udlfs(&self) -> HashSet { unimplemented!() } +//! # fn udhofs(&self) -> HashSet { unimplemented!() } //! # fn udafs(&self) -> HashSet { unimplemented!() } //! # fn udwfs(&self) -> HashSet { unimplemented!() } //! # fn udf(&self, _name: &str) -> Result> { unimplemented!() } -//! # fn udlf(&self, name: &str) -> Result> { unimplemented!() } +//! # fn udhof(&self, name: &str) -> Result> { unimplemented!() } //! # fn udaf(&self, name: &str) -> Result> {unimplemented!() } //! # fn udwf(&self, name: &str) -> Result> { unimplemented!() } //! 
# fn expr_planners(&self) -> Vec> { unimplemented!() } diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 09c2898aa1b2a..42a68e298e4ab 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -25,7 +25,7 @@ use datafusion_expr::planner::ExprPlanner; use datafusion_expr::{ AggregateUDF, ScalarUDF, TableSource, logical_plan::builder::LogicalTableSource, }; -use datafusion_expr::{LambdaUDF, WindowUDF}; +use datafusion_expr::{HigherOrderUDF, WindowUDF}; use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; @@ -138,7 +138,7 @@ impl ContextProvider for MyContextProvider { None } - fn get_lambda_meta(&self, _name: &str) -> Option> { + fn get_higher_order_meta(&self, _name: &str) -> Option> { None } @@ -162,7 +162,7 @@ impl ContextProvider for MyContextProvider { Vec::new() } - fn udlf_names(&self) -> Vec { + fn udhof_names(&self) -> Vec { Vec::new() } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index a61653e472103..749a5b44197dc 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -28,11 +28,11 @@ use datafusion_expr::{ Expr, ExprSchemable, SortExpr, ValueOrLambda, WindowFrame, WindowFunctionDefinition, arguments::ArgumentName, expr::{ - self, Lambda, LambdaFunction, NullTreatment, ScalarFunction, Unnest, + self, HigherOrderFunction, Lambda, NullTreatment, ScalarFunction, Unnest, WildcardOptions, WindowFunction, }, planner::{PlannerResult, RawAggregateExpr, RawWindowExpr}, - type_coercion::functions::value_fields_with_lambda_udf, + type_coercion::functions::value_fields_with_higher_order_udf, }; use sqlparser::ast::{ DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, @@ -62,7 +62,7 @@ pub fn suggest_valid_function( let mut funcs = Vec::new(); funcs.extend(ctx.udf_names()); - 
funcs.extend(ctx.udlf_names()); + funcs.extend(ctx.udhof_names()); funcs.extend(ctx.udaf_names()); funcs @@ -369,9 +369,9 @@ impl SqlToRel<'_, S> { } } - if let Some(fm) = self.context_provider.get_lambda_meta(&name) { + if let Some(fm) = self.context_provider.get_higher_order_meta(&name) { // plan non-lambda arguments first so we can get theirs datatype and call - // LambdaUDF::lambda_parameters to then plan the lambda arguments with + // HigherOrderUDF::lambda_parameters to then plan the lambda arguments with // resolved lambda variables enum ExprOrLambda { Expr(Expr), @@ -414,7 +414,7 @@ impl SqlToRel<'_, S> { .collect::>>()?; let coerced_values = - value_fields_with_lambda_udf(¤t_fields, fm.as_ref())? + value_fields_with_higher_order_udf(¤t_fields, fm.as_ref())? .into_iter() .filter_map(|arg| match arg { ValueOrLambda::Value(value) => Some(value), @@ -489,10 +489,10 @@ impl SqlToRel<'_, S> { }) .collect::>>()?; - let inner = LambdaFunction::new(fm, args); + let inner = HigherOrderFunction::new(fm, args); if name.eq_ignore_ascii_case(inner.name()) { - return Ok(Expr::LambdaFunction(inner)); + return Ok(Expr::HigherOrderFunction(inner)); } else { // If the function is called by an alias, a verbose string representation is created // (e.g., "my_alias(arg1, arg2)") and the expression is wrapped in an `Alias` @@ -505,7 +505,7 @@ impl SqlToRel<'_, S> { .join(","); let verbose_alias = format!("{name}({arg_names})"); - return Ok(Expr::LambdaFunction(inner).alias(verbose_alias)); + return Ok(Expr::HigherOrderFunction(inner).alias(verbose_alias)); } } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index b7f3cc4e1affa..362d6a46a5861 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -1300,7 +1300,9 @@ mod tests { use datafusion_common::TableReference; use datafusion_common::config::ConfigOptions; use datafusion_expr::logical_plan::builder::LogicalTableSource; - use datafusion_expr::{AggregateUDF, LambdaUDF, 
ScalarUDF, TableSource, WindowUDF}; + use datafusion_expr::{ + AggregateUDF, HigherOrderUDF, ScalarUDF, TableSource, WindowUDF, + }; use super::*; @@ -1340,7 +1342,7 @@ mod tests { None } - fn get_lambda_meta(&self, _name: &str) -> Option> { + fn get_higher_order_meta(&self, _name: &str) -> Option> { None } @@ -1367,7 +1369,7 @@ mod tests { Vec::new() } - fn udlf_names(&self) -> Vec { + fn udhof_names(&self) -> Vec { Vec::new() } diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 1b12b6b8c7432..b1b1a7a485a62 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -159,10 +159,10 @@ pub trait Dialect: Send + Sync { Ok(None) } - /// Allows the dialect to override lambda function unparsing if the dialect has specific rules. + /// Allows the dialect to override higher order function unparsing if the dialect has specific rules. /// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is /// a custom implementation for the function. 
- fn lambda_function_to_sql_overrides( + fn higher_order_function_to_sql_overrides( &self, _unparser: &Unparser, _func_name: &str, diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 8f2850c807282..92ace03e34795 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -17,7 +17,7 @@ use datafusion_common::datatype::DataTypeExt; use datafusion_expr::expr::{ - AggregateFunctionParams, LambdaFunction, WindowFunctionParams, + AggregateFunctionParams, HigherOrderFunction, WindowFunctionParams, }; use datafusion_expr::expr::{Lambda, Unnest}; use sqlparser::ast::Value::SingleQuotedString; @@ -555,12 +555,12 @@ impl Unparser<'_> { } Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), Expr::Unnest(unnest) => self.unnest_to_sql(unnest), - Expr::LambdaFunction(LambdaFunction { func, args }) => { + Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => { let func_name = func.name(); if let Some(expr) = self .dialect - .lambda_function_to_sql_overrides(self, func_name, args)? + .higher_order_function_to_sql_overrides(self, func_name, args)? { return Ok(expr); } @@ -1870,7 +1870,7 @@ mod tests { use datafusion_common::{Spans, TableReference}; use datafusion_expr::expr::WildcardOptions; use datafusion_expr::{ - ColumnarValue, LambdaUDF, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + ColumnarValue, HigherOrderUDF, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, case, cast, col, cube, exists, grouping_set, interval_datetime_lit, interval_year_month_lit, lambda, lambda_var, lit, not, not_exists, out_ref_col, placeholder, rollup, @@ -1932,18 +1932,18 @@ mod tests { // See sql::tests for E2E tests. 
#[derive(Debug, Hash, Eq, PartialEq)] - struct DummyLambdaUDF; + struct DummyHigherOrderUDF; - impl LambdaUDF for DummyLambdaUDF { + impl HigherOrderUDF for DummyHigherOrderUDF { fn as_any(&self) -> &dyn Any { unimplemented!() } fn name(&self) -> &str { - "dummy_udlf" + "dummy_udhof" } - fn signature(&self) -> &datafusion_expr::LambdaSignature { + fn signature(&self) -> &datafusion_expr::HigherOrderSignature { unimplemented!() } @@ -1956,14 +1956,14 @@ mod tests { fn return_field_from_args( &self, - _args: datafusion_expr::LambdaReturnFieldArgs, + _args: datafusion_expr::HigherOrderReturnFieldArgs, ) -> Result { unimplemented!() } fn invoke_with_args( &self, - _args: datafusion_expr::LambdaFunctionArgs, + _args: datafusion_expr::HigherOrderFunctionArgs, ) -> Result { unimplemented!() } @@ -2054,8 +2054,8 @@ mod tests { r#"dummy_udf(a, b) IS NOT NULL"#, ), ( - Expr::LambdaFunction(LambdaFunction::new( - Arc::new(DummyLambdaUDF), + Expr::HigherOrderFunction(HigherOrderFunction::new( + Arc::new(DummyHigherOrderUDF), vec![ col("a"), lambda( @@ -2067,7 +2067,7 @@ mod tests { ), ], )), - r#"dummy_udlf(a, (v) -> -v)"#, + r#"dummy_udhof(a, (v) -> -v)"#, ), ( Expr::Like(Like { diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 620ddeca5778e..aa02cb22d0068 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -27,7 +27,9 @@ use datafusion_common::datatype::DataTypeExt; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{DFSchema, GetExt, Result, TableReference, plan_err}; use datafusion_expr::planner::{ExprPlanner, PlannerResult, TypePlanner}; -use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, ScalarUDF, TableSource, WindowUDF}; +use datafusion_expr::{ + AggregateUDF, Expr, HigherOrderUDF, ScalarUDF, TableSource, WindowUDF, +}; use datafusion_functions_nested::expr_fn::make_array; use datafusion_sql::planner::ContextProvider; @@ -54,7 +56,7 @@ impl Display for 
MockCsvType { #[derive(Default)] pub(crate) struct MockSessionState { scalar_functions: HashMap>, - lambda_functions: HashMap>, + higher_order_functions: HashMap>, aggregate_functions: HashMap>, expr_planners: Vec>, type_planner: Option>, @@ -263,8 +265,8 @@ impl ContextProvider for MockContextProvider { self.state.scalar_functions.get(name).cloned() } - fn get_lambda_meta(&self, name: &str) -> Option> { - self.state.lambda_functions.get(name).cloned() + fn get_higher_order_meta(&self, name: &str) -> Option> { + self.state.higher_order_functions.get(name).cloned() } fn get_aggregate_meta(&self, name: &str) -> Option> { @@ -302,8 +304,8 @@ impl ContextProvider for MockContextProvider { self.state.scalar_functions.keys().cloned().collect() } - fn udlf_names(&self) -> Vec { - self.state.lambda_functions.keys().cloned().collect() + fn udhof_names(&self) -> Vec { + self.state.higher_order_functions.keys().cloned().collect() } fn udaf_names(&self) -> Vec { diff --git a/datafusion/sqllogictest/test_files/lambda.slt b/datafusion/sqllogictest/test_files/higher_order.slt similarity index 99% rename from datafusion/sqllogictest/test_files/lambda.slt rename to datafusion/sqllogictest/test_files/higher_order.slt index fae77cc29a7a7..546f1d39f3f2d 100644 --- a/datafusion/sqllogictest/test_files/lambda.slt +++ b/datafusion/sqllogictest/test_files/higher_order.slt @@ -16,7 +16,7 @@ # under the License. 
############# -## Lambda Expressions Tests +## Higher Order Function Tests ############# statement ok diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs index 9c7804624fec6..1a0fb3f55f609 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs @@ -30,7 +30,7 @@ pub async fn from_scalar_function( f: &ScalarFunction, input_schema: &DFSchema, ) -> Result { - //TODO: handle lambda functions, as they are also encoded as scalar functions + //TODO: handle higher order functions, as they are also encoded as scalar functions let Some(fn_signature) = consumer .get_extensions() .functions diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index 26eb106702367..c255c9ca84afa 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -149,7 +149,9 @@ pub fn to_substrait_rex( not_impl_err!("Cannot convert {expr:?} to Substrait") } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::LambdaFunction(expr) => producer.handle_lambda_function(expr, schema), + Expr::HigherOrderFunction(expr) => { + producer.handle_higher_order_function(expr, schema) + } Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::LambdaVariable(expr) => { not_impl_err!("Cannot convert {expr:?} to Substrait") diff --git a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs index c3c7defa43bd9..e36d5128cd293 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs @@ -30,9 +30,9 @@ pub fn 
from_scalar_function( from_function(producer, fun.name(), &fun.args, schema) } -pub fn from_lambda_function( +pub fn from_higher_order_function( producer: &mut impl SubstraitProducer, - fun: &expr::LambdaFunction, + fun: &expr::HigherOrderFunction, schema: &DFSchemaRef, ) -> datafusion::common::Result { from_function(producer, fun.name(), &fun.args, schema) diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs index 64bb50b173c32..a38c1abea0de6 100644 --- a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -19,7 +19,7 @@ use crate::extensions::Extensions; use crate::logical_plan::producer::{ from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, from_case, from_cast, from_column, from_distinct, from_empty_relation, from_exists, - from_filter, from_in_list, from_in_subquery, from_join, from_lambda_function, + from_filter, from_higher_order_function, from_in_list, from_in_subquery, from_join, from_like, from_limit, from_literal, from_projection, from_repartition, from_scalar_function, from_scalar_subquery, from_set_comparison, from_sort, from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, @@ -334,12 +334,12 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_scalar_function(self, scalar_fn, schema) } - fn handle_lambda_function( + fn handle_higher_order_function( &mut self, - scalar_fn: &expr::LambdaFunction, + scalar_fn: &expr::HigherOrderFunction, schema: &DFSchemaRef, ) -> datafusion::common::Result { - from_lambda_function(self, scalar_fn, schema) + from_higher_order_function(self, scalar_fn, schema) } fn handle_aggregate_function( From f0cf8d7cf2484b305683184cc10e936eeceddb06 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 30 Mar 2026 02:49:55 -0300 Subject: 
[PATCH 39/47] fix typo --- datafusion/common/src/utils/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 58e1eccc1d42d..66914ccaa6387 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -1065,7 +1065,7 @@ fn replace_nulls_with_first_valid(array: &ArrayRef) -> Result { .ok_or_else(|| _internal_datafusion_err!("fixed size list should have been checked to contain at least one valid value"))?; let mask = BooleanArray::new(nulls.inner().clone(), None); - // perf: remove the null buffer so zip doesn't unnecessarly zip it too + // perf: remove the null buffer so zip doesn't unnecessarily zip it too let without_null_buffer = make_array(array.to_data().into_builder().nulls(None).build()?); let first_valid = array.slice(first_valid, 1); From 1b7f4bfd2c3aa9d9e3c431775e199dc4113e27d3 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 30 Mar 2026 03:38:31 -0300 Subject: [PATCH 40/47] handle CaseWhen optimization --- .../physical-expr/src/expressions/case.rs | 24 ++++++++++++++- .../src/expressions/lambda_variable.rs | 5 ++++ .../sqllogictest/test_files/higher_order.slt | 29 +++++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 35cfed228121e..3568f1e856dd9 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -19,7 +19,7 @@ mod literal_lookup_table; use super::{Column, Literal}; use crate::PhysicalExpr; -use crate::expressions::{lit, try_cast}; +use crate::expressions::{LambdaExpr, LambdaVariable, lit, try_cast}; use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{ @@ -133,6 +133,13 @@ impl CaseBody { expr.apply(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { 
used_column_indices.insert(column.index()); + } else if let Some(lambda_variable) = + expr.as_any().downcast_ref::() + { + used_column_indices.insert(lambda_variable.index()); + } else if expr.as_any().is::() { + //todo: remove this branch when lambda supports column capture + return Ok(TreeNodeRecursion::Jump); } Ok(TreeNodeRecursion::Continue) }) @@ -171,6 +178,21 @@ impl CaseBody { projected, )))); } + } else if let Some(lambda_variable) = + expr.as_any().downcast_ref::() + { + let original = lambda_variable.index(); + let projected = *column_index_map.get(&original).unwrap(); + if projected != original { + return Ok(Transformed::yes(Arc::new(LambdaVariable::new( + lambda_variable.name().to_owned(), + projected, + Arc::clone(lambda_variable.field()), + )))); + } + } else if expr.as_any().is::() { + //todo: remove this branch when lambda supports column capture + return Ok(Transformed::new(e, false, TreeNodeRecursion::Jump)); } Ok(Transformed::no(e)) }) diff --git a/datafusion/physical-expr/src/expressions/lambda_variable.rs b/datafusion/physical-expr/src/expressions/lambda_variable.rs index f4d4a294d491b..25c7d869a5801 100644 --- a/datafusion/physical-expr/src/expressions/lambda_variable.rs +++ b/datafusion/physical-expr/src/expressions/lambda_variable.rs @@ -65,6 +65,11 @@ impl LambdaVariable { &self.name } + /// Get the variable's index + pub fn index(&self) -> usize { + self.index + } + /// Get the variable's field pub fn field(&self) -> &FieldRef { &self.field diff --git a/datafusion/sqllogictest/test_files/higher_order.slt b/datafusion/sqllogictest/test_files/higher_order.slt index 546f1d39f3f2d..7c6a6b1045901 100644 --- a/datafusion/sqllogictest/test_files/higher_order.slt +++ b/datafusion/sqllogictest/test_files/higher_order.slt @@ -152,6 +152,35 @@ select array_transform(arrow_cast(t.list, 'ListView(Int32)'), a -> a+1) from t; [5, 51] [8, 51] +# higher order function with inner case using lambda variables only +query ? 
+select array_transform([3, 5, 0], v -> case when v > 1 then 2 when v > 4 then 6 else 8 end); +---- +[2, 2, 8] + +#case with inner higher order function +query I?? +select + t.number, + t.list, + case + when t.number > 30 then array_transform(t.list, v -> v+1) + else array_transform(t.list, v -> v*2) + end +from t +order by t.number; +---- +10 [1, 50] [2, 100] +40 [4, 50] [5, 51] +60 [7, 50] [8, 51] + +# higher order function with inner case using lambda variables and captured column(capture not supported yet) +query error DataFusion error: Error during planning: lambda doesn't support column capture +select array_transform([3, 5, 9], v -> case when v > 1 then t.number when v > 4 then 6 else 8 end) from t; + +# higher order function with inner case using captured column only(capture not supported yet) +query error DataFusion error: Error during planning: lambda doesn't support column capture +select array_transform([3, 5, 9], v -> case when t.number > 1 then 2 when t.number > 4 then 6 else 8 end) from t; query error select array_transform(); From 6d7c52aa3b132d1dc7e8dcebfccbad209a8c4a15 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 30 Mar 2026 03:39:25 -0300 Subject: [PATCH 41/47] remove HigherOrderUDF::as_any --- datafusion/expr/src/udf_eq.rs | 2 +- datafusion/expr/src/udhof.rs | 11 ++--------- datafusion/functions-nested/src/array_transform.rs | 6 +----- datafusion/physical-expr/src/higher_order_function.rs | 5 ----- datafusion/proto/src/bytes/mod.rs | 4 ---- datafusion/sql/src/unparser/expr.rs | 4 ---- 6 files changed, 4 insertions(+), 28 deletions(-) diff --git a/datafusion/expr/src/udf_eq.rs b/datafusion/expr/src/udf_eq.rs index 86ae9964492c3..68b69dd30cf86 100644 --- a/datafusion/expr/src/udf_eq.rs +++ b/datafusion/expr/src/udf_eq.rs @@ -95,7 +95,7 @@ impl UdfPointer for Arc { impl UdfPointer for Arc { fn equals(&self, other: &Self::Target) -> bool { - self.as_ref().dyn_eq(other.as_any()) + 
self.as_ref().dyn_eq(other) } fn hash_value(&self) -> u64 { diff --git a/datafusion/expr/src/udhof.rs b/datafusion/expr/src/udhof.rs index 0092f098ad56c..b2818f6101491 100644 --- a/datafusion/expr/src/udhof.rs +++ b/datafusion/expr/src/udhof.rs @@ -111,7 +111,7 @@ impl HigherOrderSignature { impl PartialEq for dyn HigherOrderUDF { fn eq(&self, other: &Self) -> bool { - self.dyn_eq(other.as_any()) + self.dyn_eq(other as _) } } @@ -268,10 +268,7 @@ pub enum ValueOrLambda { /// See [`array_transform.rs`] for a commented complete implementation /// /// [`array_transform.rs`]: https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/array_transform.rs -pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync { - /// Returns this object as an [`Any`] trait object - fn as_any(&self) -> &dyn Any; - +pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any { /// Returns this function's name fn name(&self) -> &str; @@ -486,10 +483,6 @@ mod tests { signature: HigherOrderSignature, } impl HigherOrderUDF for TestHigherOrderUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { self.name } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 2653d874af200..86e54abe812ec 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -33,7 +33,7 @@ use datafusion_expr::{ HigherOrderSignature, HigherOrderUDF, ValueOrLambda, Volatility, }; use datafusion_macros::user_doc; -use std::{any::Any, fmt::Debug, sync::Arc}; +use std::{fmt::Debug, sync::Arc}; make_udhof_expr_and_func!( ArrayTransform, @@ -83,10 +83,6 @@ impl ArrayTransform { } impl HigherOrderUDF for ArrayTransform { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_transform" } diff --git a/datafusion/physical-expr/src/higher_order_function.rs b/datafusion/physical-expr/src/higher_order_function.rs index 
d77f521d65515..0afc2de9388df 100644 --- a/datafusion/physical-expr/src/higher_order_function.rs +++ b/datafusion/physical-expr/src/higher_order_function.rs @@ -400,7 +400,6 @@ impl PhysicalExpr for HigherOrderFunctionExpr { #[cfg(test)] mod tests { - use std::any::Any; use std::sync::Arc; use super::*; @@ -422,10 +421,6 @@ mod tests { } impl HigherOrderUDF for MockHigherOrderUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "mock_function" } diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 44c6ee3779778..ec22ce9d6f9ba 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -135,10 +135,6 @@ impl Serializeable for Expr { } impl HigherOrderUDF for MockHigherOrderUDF { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { &self.name } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 92ace03e34795..615337007565e 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1935,10 +1935,6 @@ mod tests { struct DummyHigherOrderUDF; impl HigherOrderUDF for DummyHigherOrderUDF { - fn as_any(&self) -> &dyn Any { - unimplemented!() - } - fn name(&self) -> &str { "dummy_udhof" } From 72558201bbc00641d2a3b426417e0eb42cbd7169 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 30 Mar 2026 11:04:36 -0300 Subject: [PATCH 42/47] add TaskContext::higher_order_functions --- datafusion/execution/src/task.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index ed3de93ff0e6d..093c66c416ccb 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -128,6 +128,10 @@ impl TaskContext { &self.scalar_functions } + pub fn higher_order_functions(&self) -> &HashMap> { + &self.higher_order_functions + } + pub fn aggregate_functions(&self) -> &HashMap> 
{ &self.aggregate_functions } From 5c5ca195d1fc2c708a1e4964e8392337fc3b0b02 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Wed, 1 Apr 2026 10:41:58 -0300 Subject: [PATCH 43/47] support lambda column capture, casewhen aware --- datafusion/common/src/utils/mod.rs | 135 +++++++++++++++++- datafusion/expr/src/udhof.rs | 124 ++++++++++++++-- .../functions-nested/src/array_transform.rs | 12 +- .../physical-expr/src/expressions/case.rs | 55 ++++--- .../physical-expr/src/expressions/lambda.rs | 79 +++++++++- .../src/higher_order_function.rs | 15 +- datafusion/physical-expr/src/planner.rs | 59 +++++--- .../sqllogictest/test_files/higher_order.slt | 24 ++-- 8 files changed, 439 insertions(+), 64 deletions(-) diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 66914ccaa6387..b48f83d7c1b9b 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -31,20 +31,20 @@ use arrow::array::{ cast::AsArray, }; use arrow::array::{ - BooleanArray, Datum, GenericListArray, Int32Array, Int64Array, MutableArrayData, - Scalar, make_array, + ArrowPrimitiveType, BooleanArray, Datum, GenericListArray, Int32Array, Int64Array, MutableArrayData, PrimitiveArray, Scalar, make_array }; use arrow::buffer::OffsetBuffer; use arrow::compute::kernels::cmp::neq; use arrow::compute::kernels::length::length; use arrow::compute::kernels::zip::zip; use arrow::compute::{SortColumn, SortOptions, partition}; -use arrow::datatypes::{DataType, Field, SchemaRef}; +use arrow::datatypes::{ArrowNativeType, DataType, Field, Int32Type, Int64Type, SchemaRef}; #[cfg(feature = "sql")] use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; use std::borrow::{Borrow, Cow}; use std::cmp::{Ordering, min}; use std::collections::HashSet; +use std::iter::repeat_n; use std::num::NonZero; use std::ops::Range; use std::sync::{Arc, LazyLock}; @@ -1141,6 +1141,62 @@ fn truncate_list_nulls( Ok(list.clone()) } +/// 
If `array` is a contiguos list, returns a new array of the same length as it's inner values +/// where each value is the 1-based index of the sublist it's contained. Example: +/// +/// `[[1], [2, 3], [4, 5, 6]] => [1, 2, 2, 3, 3, 3]` +/// +/// If it's not a contiguos list, return an error +pub fn list_values_row_number(array: &dyn Array) -> Result { + match array.data_type() { + DataType::List(_) => Ok(Arc::new(list_array_values_row_number::( + array.as_list().offsets(), + ))), + DataType::LargeList(_) => Ok(Arc::new( + list_array_values_row_number::(array.as_list().offsets()), + )), + DataType::FixedSizeList(_, _) => { + let fixed_size_list = array.as_fixed_size_list(); + + Ok(Arc::new(fsl_values_row_number( + fixed_size_list.value_length(), + fixed_size_list.len(), + )?)) + } + other => _exec_err!("expected list, got {other}"), + } +} + +/// [0, 2, 2, 5, 6] -> [0, 0, 2, 2, 2, 3] +fn list_array_values_row_number( + offsets: &OffsetBuffer, +) -> PrimitiveArray { + let mut rows_number = Vec::with_capacity( + offsets[offsets.len() - 1].to_usize().unwrap() - offsets[0].to_usize().unwrap(), + ); + + for (i, len) in offsets.lengths().enumerate() { + rows_number.extend(repeat_n(T::Native::usize_as(i), len)); + } + + PrimitiveArray::new(rows_number.into(), None) +} + +/// (2, 3) -> [0, 0, 1, 1, 2, 2] +fn fsl_values_row_number(list_size: i32, array_len: usize) -> Result { + let list_size = list_size.to_usize().ok_or_else(|| { + _exec_datafusion_err!("fsl_values_index: invalid list_size {list_size}") + })?; + + let mut rows_number = Vec::with_capacity(list_size * array_len); + + for i in 0..array_len { + rows_number.extend(repeat_n(i as i32, list_size)); + } + + Ok(PrimitiveArray::new(rows_number.into(), None)) +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -1614,4 +1670,77 @@ mod tests { expected.as_fixed_size_list().values() ); } + + #[test] + fn test_list_array_values_row_number() { + assert_eq!( + list_array_values_row_number::(&OffsetBuffer::from_lengths([ + 1, 
3, 0, 2, + ])), + Int32Array::from(vec![0, 1, 1, 1, 3, 3]) + ); + + assert_eq!( + list_array_values_row_number::(&OffsetBuffer::from_lengths([])), + Int32Array::new_null(0) + ); + + assert_eq!( + list_array_values_row_number::(&OffsetBuffer::from_lengths([0])), + Int32Array::new_null(0) + ); + + assert_eq!( + list_array_values_row_number::(&OffsetBuffer::from_lengths([ + 0, 0 + ])), + Int32Array::new_null(0) + ); + + assert_eq!( + list_array_values_row_number::(&OffsetBuffer::from_lengths([1])), + Int32Array::from(vec![0]) + ); + + assert_eq!( + list_array_values_row_number::(&OffsetBuffer::from_lengths([2])), + Int32Array::from(vec![0, 0]) + ); + } + + #[test] + fn test_fsl_values_row_number() { + assert_eq!( + fsl_values_row_number(2, 3).unwrap(), + Int32Array::from(vec![0, 0, 1, 1, 2, 2]) + ); + + assert_eq!( + fsl_values_row_number(1, 3).unwrap(), + Int32Array::from(vec![0, 1, 2]) + ); + + assert_eq!( + fsl_values_row_number(2, 1).unwrap(), + Int32Array::from(vec![0, 0]) + ); + + assert_eq!( + fsl_values_row_number(2, 0).unwrap(), + Int32Array::new_null(0), + ); + + assert_eq!( + fsl_values_row_number(0, 2).unwrap(), + Int32Array::new_null(0), + ); + + assert_eq!( + fsl_values_row_number(0, 0).unwrap(), + Int32Array::new_null(0), + ); + + fsl_values_row_number(-1, 2).unwrap_err(); + fsl_values_row_number(-1, 0).unwrap_err(); + } } diff --git a/datafusion/expr/src/udhof.rs b/datafusion/expr/src/udhof.rs index b2818f6101491..a54dcf34560d8 100644 --- a/datafusion/expr/src/udhof.rs +++ b/datafusion/expr/src/udhof.rs @@ -22,7 +22,7 @@ use crate::{ColumnarValue, Documentation, Expr}; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{Result, ScalarValue, not_impl_err}; +use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::signature::Volatility; use 
datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -194,34 +194,136 @@ pub struct LambdaArgument { /// For example, for `array_transform([2], v -> -v)`, /// this will be the physical expression of `-v` body: Arc, + /// A RecordBatch containing at least the captured columns inside this lambda body, if any + /// Note that it may contain additional, non-specified columns, but that's a implementation detail + /// + /// For example, for `array_transform([2], v -> v + a + b)`, + /// this will be a `RecordBatch` with at least two columns, `a` and `b` + captures: Option, } impl LambdaArgument { - pub fn new(params: Vec, body: Arc) -> Self { - Self { params, body } + pub fn new( + params: Vec, + body: Arc, + captures: Option, + ) -> Self { + Self { + params, + body, + captures, + } } /// Evaluate this lambda /// `args` should evaluate to the value of each parameter /// of the correspondent lambda returned in [HigherOrderUDF::lambda_parameters]. + /// + /// `adjust` should adjust the captured columns of this + /// lambda, if any, relative to it's parameters + /// + /// Tip: For adjusting multiple arrays by indices, use [`take_arrays`] + /// + /// [`take_arrays`]: arrow::compute::take_arrays pub fn evaluate( &self, args: &[&dyn Fn() -> Result], + adjust: impl FnOnce(&[ArrayRef]) -> Result>, ) -> Result { - let columns = args - .iter() - .take(self.params.len()) - .map(|arg| arg()) - .collect::>()?; + let adjusted_captures = self + .captures + .as_ref() + .map(|captures| { + let adjusted_columns = adjust(captures.columns())?; + + RecordBatch::try_new(captures.schema(), adjusted_columns) + }) + .transpose()?; + + let merged = merge_captures_with_variables( + adjusted_captures.as_ref(), + &self.params, + args, + )?; + + self.body.evaluate(&merged) + } +} + +fn merge_captures_with_variables( + captures: Option<&RecordBatch>, + params: &[FieldRef], + variables: &[&dyn Fn() -> Result], +) -> Result { + if variables.len() < params.len() { + return exec_err!( + 
"expected at least {} lambda arguments to merge with captures, got {}", + params.len(), + variables.len() + ); + } + + match captures { + Some(captures) => { + let old_fields = captures.schema_ref().fields(); + + let mut new_fields = old_fields + .iter() + .map(|field| { + if !fields_contains(params, field.name()) { + return Arc::clone(field); + } - let schema = Arc::new(Schema::new(self.params.clone())); + let mut i = 0; - let batch = RecordBatch::try_new(schema, columns)?; + loop { + let alias = format!("{}_shadowed_{i}", field.name()); - self.body.evaluate(&batch) + if !fields_contains(params, &alias) + && old_fields.find(&alias).is_none() + { + break Arc::new(Field::new( + alias, + field.data_type().clone(), + field.is_nullable(), + )); + } + + i += 1; + } + }) + .collect::>(); + + new_fields.extend_from_slice(params); + + let mut columns = captures.columns().to_vec(); + + for arg in &variables[..params.len()] { + columns.push(arg()?); + } + + let new_schema = Arc::new(Schema::new(new_fields)); + + Ok(RecordBatch::try_new(new_schema, columns)?) + } + None => { + let columns = variables + .iter() + .take(params.len()) + .map(|arg| arg()) + .collect::>()?; + + let schema = Arc::new(Schema::new(params)); + + Ok(RecordBatch::try_new(schema, columns)?) 
+ } } } +fn fields_contains(fields: &[FieldRef], name: &str) -> bool { + fields.iter().any(|f| f.name().as_str() == name) +} + /// Information about arguments passed to the function /// /// This structure contains metadata about how the function was called diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 86e54abe812ec..c0a48fad2b132 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -22,11 +22,14 @@ use arrow::{ Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray, new_null_array, }, + compute::take_arrays, datatypes::{DataType, Field, FieldRef}, }; use datafusion_common::{ Result, exec_err, plan_err, - utils::{adjust_offsets_for_slice, list_values, take_function_args}, + utils::{ + adjust_offsets_for_slice, list_values, list_values_row_number, take_function_args, + }, }; use datafusion_expr::{ ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, @@ -200,7 +203,12 @@ impl HigherOrderUDF for ArrayTransform { // call the transforming lambda let transformed_values = lambda - .evaluate(&[&values_param])? + .evaluate(&[&values_param], |arrays| { + // if any column got captured, we need to adjust it to the values arrays, + // duplicating values of list with mulitple values and removing values of empty lists + let indices = list_values_row_number(&list_array)?; + Ok(take_arrays(arrays, &indices, None)?) + })? 
.into_array(list_values.len())?; let field = match args.return_field.data_type() { diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 3568f1e856dd9..842474b6cc798 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -19,7 +19,7 @@ mod literal_lookup_table; use super::{Column, Literal}; use crate::PhysicalExpr; -use crate::expressions::{LambdaExpr, LambdaVariable, lit, try_cast}; +use crate::expressions::{LambdaVariable, lit, try_cast}; use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{ @@ -137,9 +137,6 @@ impl CaseBody { expr.as_any().downcast_ref::() { used_column_indices.insert(lambda_variable.index()); - } else if expr.as_any().is::() { - //todo: remove this branch when lambda supports column capture - return Ok(TreeNodeRecursion::Jump); } Ok(TreeNodeRecursion::Continue) }) @@ -179,7 +176,7 @@ impl CaseBody { )))); } } else if let Some(lambda_variable) = - expr.as_any().downcast_ref::() + e.as_any().downcast_ref::() { let original = lambda_variable.index(); let projected = *column_index_map.get(&original).unwrap(); @@ -190,9 +187,6 @@ impl CaseBody { Arc::clone(lambda_variable.field()), )))); } - } else if expr.as_any().is::() { - //todo: remove this branch when lambda supports column capture - return Ok(Transformed::new(e, false, TreeNodeRecursion::Jump)); } Ok(Transformed::no(e)) }) @@ -1039,8 +1033,15 @@ impl CaseExpr { projected: &ProjectedCaseBody, ) -> Result { let return_type = self.data_type(&batch.schema())?; - if projected.projection.len() < batch.num_columns() { - let projected_batch = batch.project(&projected.projection)?; + // projected.projection may include indexes of lambda variables not available on this batch + let projection = projected + .projection + .iter() + .copied() + .filter(|index| *index < batch.num_columns()) + .collect::>(); + if projection.len() < batch.num_columns() { 
+ let projected_batch = batch.project(&projection)?; projected .body .case_when_with_expr(&projected_batch, &return_type) @@ -1062,8 +1063,15 @@ impl CaseExpr { projected: &ProjectedCaseBody, ) -> Result { let return_type = self.data_type(&batch.schema())?; - if projected.projection.len() < batch.num_columns() { - let projected_batch = batch.project(&projected.projection)?; + // projected.projection may include indexes of lambda variables not available on this batch + let projection = projected + .projection + .iter() + .copied() + .filter(|index| *index < batch.num_columns()) + .collect::>(); + if projection.len() < batch.num_columns() { + let projected_batch = batch.project(&projection)?; projected .body .case_when_no_expr(&projected_batch, &return_type) @@ -1181,14 +1189,23 @@ impl CaseExpr { )?)) } } - } else if projected.projection.len() < batch.num_columns() { - // The case expressions do not use all the columns of the input batch. - // Project first to reduce time spent filtering. - let projected_batch = batch.project(&projected.projection)?; - projected.body.expr_or_expr(&projected_batch, when_value) } else { - // All columns are used in the case expressions, so there is no need to project. - self.body.expr_or_expr(batch, when_value) + // projected.projection may include indexes of lambda variables not available on this batch + let projection = projected + .projection + .iter() + .copied() + .filter(|index| *index < batch.num_columns()) + .collect::>(); + if projection.len() < batch.num_columns() { + // The case expressions do not use all the columns of the input batch. + // Project first to reduce time spent filtering. + let projected_batch = batch.project(&projection)?; + projected.body.expr_or_expr(&projected_batch, when_value) + } else { + // All columns are used in the case expressions, so there is no need to project. 
+ self.body.expr_or_expr(batch, when_value) + } } } diff --git a/datafusion/physical-expr/src/expressions/lambda.rs b/datafusion/physical-expr/src/expressions/lambda.rs index c06b48591b70e..aa0247bab8699 100644 --- a/datafusion/physical-expr/src/expressions/lambda.rs +++ b/datafusion/physical-expr/src/expressions/lambda.rs @@ -21,12 +21,18 @@ use std::any::Any; use std::hash::Hash; use std::sync::Arc; -use crate::physical_expr::PhysicalExpr; +use crate::{ + expressions::{Column, LambdaVariable}, + physical_expr::PhysicalExpr, +}; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::plan_err; +use datafusion_common::{ + HashMap, plan_err, + tree_node::{Transformed, TreeNode, TreeNodeRecursion}, +}; use datafusion_common::{HashSet, Result, internal_err}; use datafusion_expr::ColumnarValue; @@ -35,6 +41,8 @@ use datafusion_expr::ColumnarValue; pub struct LambdaExpr { params: Vec, body: Arc, + projected_body: Arc, + projection: Vec, } // Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 [https://github.com/apache/datafusion/issues/13196] @@ -62,7 +70,64 @@ impl LambdaExpr { } fn new(params: Vec, body: Arc) -> Self { - Self { params, body } + let mut used_column_indices = HashSet::new(); + + body.apply(|node| { + if let Some(col) = node.as_any().downcast_ref::() { + used_column_indices.insert(col.index()); + } else if let Some(var) = node.as_any().downcast_ref::() { + used_column_indices.insert(var.index()); + } + + Ok(TreeNodeRecursion::Continue) + }) + .expect("closure should be infallible"); + + let mut projection = used_column_indices.into_iter().collect::>(); + + projection.sort(); + + let column_index_map = projection + .iter() + .enumerate() + .map(|(projected, original)| (*original, projected)) + .collect::>(); + + let projected_body = Arc::clone(&body) + .transform(|e| { + if let Some(column) = e.as_any().downcast_ref::() { + let original = column.index(); + let 
projected = *column_index_map.get(&original).unwrap(); + if projected != original { + return Ok(Transformed::yes(Arc::new(Column::new( + column.name(), + projected, + )))); + } + } else if let Some(lambda_variable) = + e.as_any().downcast_ref::() + { + let original = lambda_variable.index(); + let projected = *column_index_map.get(&original).unwrap(); + if projected != original { + return Ok(Transformed::yes(Arc::new(LambdaVariable::new( + lambda_variable.name().to_owned(), + projected, + Arc::clone(lambda_variable.field()), + )))); + } + } + Ok(Transformed::no(e)) + }) + .expect("closure should be infallible") + .data; + + Self { + params, + body, + projected_body, + projection, + } } /// Get the lambda's params names @@ -74,6 +139,14 @@ impl LambdaExpr { pub fn body(&self) -> &Arc { &self.body } + + pub(crate) fn projection(&self) -> &[usize] { + &self.projection + } + + pub(crate) fn projected_body(&self) -> &Arc { + &self.projected_body + } } impl std::fmt::Display for LambdaExpr { diff --git a/datafusion/physical-expr/src/higher_order_function.rs b/datafusion/physical-expr/src/higher_order_function.rs index 0afc2de9388df..45321ed7cd844 100644 --- a/datafusion/physical-expr/src/higher_order_function.rs +++ b/datafusion/physical-expr/src/higher_order_function.rs @@ -305,9 +305,22 @@ impl PhysicalExpr for HigherOrderFunctionExpr { .map(|(name, param)| Arc::new(param.with_name(name))) .collect(); + // lambda.projection may include indexes of nested lambda variables not present on this batch + let projection = lambda + .projection() + .iter() + .copied() + .filter(|i| *i < batch.num_columns()) + .collect::>(); + Ok(ValueOrLambda::Lambda(LambdaArgument::new( params, - Arc::clone(lambda.body()), + Arc::clone(lambda.projected_body()), + if projection.is_empty() { + None + } else { + Some(batch.project(&projection)?) 
+ }, ))) } None => { diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 30dee27c3014f..f2d69103ab0ef 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashMap; use std::sync::Arc; use crate::{HigherOrderFunctionExpr, ScalarFunctionExpr}; @@ -24,7 +23,7 @@ use crate::{ expressions::{self, Column, Literal, binary, like, similar_to}, }; -use arrow::datatypes::Schema; +use arrow::datatypes::{Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::metadata::{FieldMetadata, format_type_and_metadata}; use datafusion_common::{ @@ -438,6 +437,37 @@ pub fn create_physical_expr( .iter() .map(|arg| match arg { Expr::Lambda(lambda) => { + let mut new_fields = input_dfschema + .iter() + .map(|(qualifier, field)| { + if !lambda.params.contains(field.name()) { + return (qualifier.cloned(), Arc::clone(field)); + } + + let mut i = 0; + + loop { + let alias = format!("{}_shadowed_{i}", field.name()); + + if !lambda.params.contains(&alias) + && !input_dfschema + .has_column_with_unqualified_name(&alias) + { + break ( + qualifier.cloned(), + Arc::new(Field::new( + alias, + field.data_type().clone(), + field.is_nullable(), + )), + ); + } + + i += 1; + } + }) + .collect::>(); + let lambda_parameters = lambda_parameters .next() .ok_or_else(|| { @@ -447,12 +477,13 @@ pub fn create_physical_expr( })? 
.into_iter() .zip(&lambda.params) - .map(|(field, name)| field.with_name(name)) - .collect(); + .map(|(field, name)| (None, Arc::new(field.with_name(name)))); + + new_fields.extend(lambda_parameters); - let lambda_schema = DFSchema::from_unqualified_fields( - lambda_parameters, - HashMap::new(), + let lambda_schema = DFSchema::new_with_metadata( + new_fields, + input_dfschema.metadata().clone(), )?; create_physical_expr(arg, &lambda_schema, execution_props) @@ -473,16 +504,10 @@ pub fn create_physical_expr( config_options, )?)) } - Expr::Lambda(Lambda { params, body }) => { - if body.any_column_refs() { - return plan_err!("lambda doesn't support column capture"); - } - - expressions::lambda( - params, - create_physical_expr(body, input_dfschema, execution_props)?, - ) - } + Expr::Lambda(Lambda { params, body }) => expressions::lambda( + params, + create_physical_expr(body, input_dfschema, execution_props)?, + ), Expr::LambdaVariable(LambdaVariable { name, field, diff --git a/datafusion/sqllogictest/test_files/higher_order.slt b/datafusion/sqllogictest/test_files/higher_order.slt index 7c6a6b1045901..c31f0e871b35c 100644 --- a/datafusion/sqllogictest/test_files/higher_order.slt +++ b/datafusion/sqllogictest/test_files/higher_order.slt @@ -154,9 +154,9 @@ select array_transform(arrow_cast(t.list, 'ListView(Int32)'), a -> a+1) from t; # higher order function with inner case using lambda variables only query ? -select array_transform([3, 5, 0], v -> case when v > 1 then 2 when v > 4 then 6 else 8 end); +select array_transform([1, 5, 9], v -> case when v = 1 then 2 when v = 5 then 6 else 8 end); ---- -[2, 2, 8] +[2, 6, 8] #case with inner higher order function query I?? 
@@ -174,13 +174,21 @@ order by t.number; 40 [4, 50] [5, 51] 60 [7, 50] [8, 51] -# higher order function with inner case using lambda variables and captured column(capture not supported yet) -query error DataFusion error: Error during planning: lambda doesn't support column capture -select array_transform([3, 5, 9], v -> case when v > 1 then t.number when v > 4 then 6 else 8 end) from t; +# higher order function with inner case using lambda variables and captured column +query ? +select array_transform([1, 5, 9], v -> case when v = 1 then t.number when v = 5 then 6 else 8 end) from t; +---- +[10, 6, 8] +[40, 6, 8] +[60, 6, 8] -# higher order function with inner case using captured column only(capture not supported yet) -query error DataFusion error: Error during planning: lambda doesn't support column capture -select array_transform([3, 5, 9], v -> case when t.number > 1 then 2 when t.number > 4 then 6 else 8 end) from t; +# higher order function with inner case using captured column only +query ? 
+select array_transform([3, 5, 9], v -> case when t.number = 10 then 2 when t.number = 40 then 6 else 8 end) from t; +---- +[2, 2, 2] +[6, 6, 6] +[8, 8, 8] query error select array_transform(); From a4eaeb2f48b9bd9c64ed8f4fa72dd112ed0d22ec Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 6 Apr 2026 15:54:00 -0300 Subject: [PATCH 44/47] add support for listview and map --- datafusion/common/src/utils/mod.rs | 69 ++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 23 deletions(-) diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index b48f83d7c1b9b..cc3b95c1cd528 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -31,14 +31,17 @@ use arrow::array::{ cast::AsArray, }; use arrow::array::{ - ArrowPrimitiveType, BooleanArray, Datum, GenericListArray, Int32Array, Int64Array, MutableArrayData, PrimitiveArray, Scalar, make_array + ArrowPrimitiveType, BooleanArray, Datum, GenericListArray, Int32Array, Int64Array, + MutableArrayData, PrimitiveArray, Scalar, make_array, }; use arrow::buffer::OffsetBuffer; use arrow::compute::kernels::cmp::neq; use arrow::compute::kernels::length::length; use arrow::compute::kernels::zip::zip; use arrow::compute::{SortColumn, SortOptions, partition}; -use arrow::datatypes::{ArrowNativeType, DataType, Field, Int32Type, Int64Type, SchemaRef}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Field, Int32Type, Int64Type, SchemaRef, +}; #[cfg(feature = "sql")] use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; use std::borrow::{Borrow, Cow}; @@ -1141,20 +1144,28 @@ fn truncate_list_nulls( Ok(list.clone()) } -/// If `array` is a contiguos list, returns a new array of the same length as it's inner values +/// If `array` is a list or a map, returns a new array of the same length as it's inner values /// where each value is the 1-based index of the sublist it's contained. 
Example: /// /// `[[1], [2, 3], [4, 5, 6]] => [1, 2, 2, 3, 3, 3]` /// -/// If it's not a contiguos list, return an error +/// Otherwise returns an error pub fn list_values_row_number(array: &dyn Array) -> Result { match array.data_type() { - DataType::List(_) => Ok(Arc::new(list_array_values_row_number::( - array.as_list().offsets(), - ))), - DataType::LargeList(_) => Ok(Arc::new( - list_array_values_row_number::(array.as_list().offsets()), - )), + DataType::List(_) => Ok(Arc::new(variable_size_list_values_row_number::< + Int32Type, + >(array.as_list().offsets()))), + DataType::LargeList(_) => Ok(Arc::new(variable_size_list_values_row_number::< + Int64Type, + >(array.as_list().offsets()))), + DataType::ListView(_) => Ok(Arc::new(variable_size_list_values_row_number::< + Int32Type, + >(array.as_list_view().offsets()))), + DataType::LargeListView(_) => { + Ok(Arc::new(variable_size_list_values_row_number::( + array.as_list_view().offsets(), + ))) + } DataType::FixedSizeList(_, _) => { let fixed_size_list = array.as_fixed_size_list(); @@ -1163,19 +1174,23 @@ pub fn list_values_row_number(array: &dyn Array) -> Result { fixed_size_list.len(), )?)) } + DataType::Map(_, _) => Ok(Arc::new(variable_size_list_values_row_number::< + Int32Type, + >(array.as_map().offsets()))), other => _exec_err!("expected list, got {other}"), } } /// [0, 2, 2, 5, 6] -> [0, 0, 2, 2, 2, 3] -fn list_array_values_row_number( - offsets: &OffsetBuffer, +fn variable_size_list_values_row_number( + offsets: &[T::Native], ) -> PrimitiveArray { let mut rows_number = Vec::with_capacity( offsets[offsets.len() - 1].to_usize().unwrap() - offsets[0].to_usize().unwrap(), ); - for (i, len) in offsets.lengths().enumerate() { + for (i, w) in offsets.windows(2).enumerate() { + let len = w[1].as_usize() - w[0].as_usize(); rows_number.extend(repeat_n(T::Native::usize_as(i), len)); } @@ -1674,36 +1689,44 @@ mod tests { #[test] fn test_list_array_values_row_number() { assert_eq!( - 
list_array_values_row_number::(&OffsetBuffer::from_lengths([ - 1, 3, 0, 2, - ])), + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([1, 3, 0, 2,]) + ), Int32Array::from(vec![0, 1, 1, 1, 3, 3]) ); assert_eq!( - list_array_values_row_number::(&OffsetBuffer::from_lengths([])), + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([]) + ), Int32Array::new_null(0) ); assert_eq!( - list_array_values_row_number::(&OffsetBuffer::from_lengths([0])), + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([0]) + ), Int32Array::new_null(0) ); assert_eq!( - list_array_values_row_number::(&OffsetBuffer::from_lengths([ - 0, 0 - ])), + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([0, 0]) + ), Int32Array::new_null(0) ); assert_eq!( - list_array_values_row_number::(&OffsetBuffer::from_lengths([1])), + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([1]) + ), Int32Array::from(vec![0]) ); assert_eq!( - list_array_values_row_number::(&OffsetBuffer::from_lengths([2])), + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([2]) + ), Int32Array::from(vec![0, 0]) ); } From 071919d0e6c42cfcd639b61763437710b8e9bdd7 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 6 Apr 2026 15:58:13 -0300 Subject: [PATCH 45/47] include index in lambda variable formatting --- .../physical-expr/src/expressions/lambda_variable.rs | 4 ++-- datafusion/sqllogictest/test_files/higher_order.slt | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/lambda_variable.rs b/datafusion/physical-expr/src/expressions/lambda_variable.rs index 25c7d869a5801..3a7006e02b67f 100644 --- a/datafusion/physical-expr/src/expressions/lambda_variable.rs +++ b/datafusion/physical-expr/src/expressions/lambda_variable.rs @@ -78,7 +78,7 @@ impl LambdaVariable { impl std::fmt::Display for LambdaVariable { fn fmt(&self, 
f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}@", self.name) + write!(f, "{}@{}", self.name, self.index) } } @@ -130,7 +130,7 @@ impl PhysicalExpr for LambdaVariable { } fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.name) + write!(f, "{}@{}", self.name, self.index) } } diff --git a/datafusion/sqllogictest/test_files/higher_order.slt b/datafusion/sqllogictest/test_files/higher_order.slt index c31f0e871b35c..65c7f5bfdc404 100644 --- a/datafusion/sqllogictest/test_files/higher_order.slt +++ b/datafusion/sqllogictest/test_files/higher_order.slt @@ -87,7 +87,7 @@ logical_plan 01)Projection: array_transform(make_array(t.number), (v) -> CAST(v AS Float64) + Float64(3)) 02)--TableScan: t projection=[number] physical_plan -01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> CAST(v@ AS Float64) + 3) as array_transform(make_array(t.number),(v) -> v + Float64(3))] +01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> CAST(v@1 AS Float64) + 3) as array_transform(make_array(t.number),(v) -> v + Float64(3))] 02)--DataSourceExec: partitions=1, partition_sizes=[1] #cse should not eliminate subtrees containing lambdas @@ -110,7 +110,7 @@ logical_plan 02)--Projection: make_array(t.number) AS __common_expr_1 03)----TableScan: t projection=[number] physical_plan -01)ProjectionExec: expr=[array_transform(__common_expr_1@0, (v) -> CAST(v@ AS Int64) * 2) as array_transform(make_array(t.number),(v) -> v * Int64(2)), array_transform(__common_expr_1@0, (v) -> CAST(v@ AS Int64) * 2 - 1) as array_transform(make_array(t.number),(v) -> v * Int64(2) - Int64(1))] +01)ProjectionExec: expr=[array_transform(__common_expr_1@0, (v) -> CAST(v@1 AS Int64) * 2) as array_transform(make_array(t.number),(v) -> v * Int64(2)), array_transform(__common_expr_1@0, (v) -> CAST(v@1 AS Int64) * 2 - 1) as array_transform(make_array(t.number),(v) -> v * Int64(2) - Int64(1))] 02)--ProjectionExec: 
expr=[make_array(number@0) as __common_expr_1] 03)----DataSourceExec: partitions=1, partition_sizes=[1] @@ -130,7 +130,7 @@ logical_plan 01)Projection: array_transform(make_array(t.number), (v) -> v IS NOT NULL OR Boolean(NULL)) AS array_transform(make_array(t.number),(v) -> v = v) 02)--TableScan: t projection=[number] physical_plan -01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> v@ IS NOT NULL OR NULL) as array_transform(make_array(t.number),(v) -> v = v)] +01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> v@1 IS NOT NULL OR NULL) as array_transform(make_array(t.number),(v) -> v = v)] 02)--DataSourceExec: partitions=1, partition_sizes=[1] @@ -142,7 +142,7 @@ logical_plan 01)Projection: array_transform(CAST(CAST(t.list AS ListView(Int32)) AS List(Int32)), (a) -> CAST(a AS Int64) + Int64(1)) AS array_transform(arrow_cast(t.list,Utf8("ListView(Int32)")),(a) -> a + Int64(1)) 02)--TableScan: t projection=[list] physical_plan -01)ProjectionExec: expr=[array_transform(CAST(CAST(list@0 AS ListView(Int32)) AS List(Int32)), (a) -> CAST(a@ AS Int64) + 1) as array_transform(arrow_cast(t.list,Utf8("ListView(Int32)")),(a) -> a + Int64(1))] +01)ProjectionExec: expr=[array_transform(CAST(CAST(list@0 AS ListView(Int32)) AS List(Int32)), (a) -> CAST(a@1 AS Int64) + 1) as array_transform(arrow_cast(t.list,Utf8("ListView(Int32)")),(a) -> a + Int64(1))] 02)--DataSourceExec: partitions=1, partition_sizes=[1] query ? 
From 4932cae74e4338bc6e148887cd862f7ab5f4d43c Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 6 Apr 2026 16:06:19 -0300 Subject: [PATCH 46/47] add more sqllogictests --- .../sqllogictest/test_files/higher_order.slt | 83 ++++++++++++++++++- 1 file changed, 79 insertions(+), 4 deletions(-) diff --git a/datafusion/sqllogictest/test_files/higher_order.slt b/datafusion/sqllogictest/test_files/higher_order.slt index 65c7f5bfdc404..79c3eb361d08d 100644 --- a/datafusion/sqllogictest/test_files/higher_order.slt +++ b/datafusion/sqllogictest/test_files/higher_order.slt @@ -23,11 +23,11 @@ statement ok set datafusion.sql_parser.dialect = databricks; statement ok -CREATE TABLE t (list array, number int) +CREATE TABLE t (text varchar, list array, number int) AS VALUES -([1, 50], 10), -([4, 50], 40), -([7, 50], 60); +('a', [1, 50], 10), +('b', [4, 50], 40), +('c', [7, 50], 60); query ? SELECT array_transform([1,2,3,4,5], v -> v*2); @@ -174,6 +174,81 @@ order by t.number; 40 [4, 50] [5, 51] 60 [7, 50] [8, 51] +#case with inner nested higher order function +query T?I? 
+select + t.text, + t.list, + t.number, + case + when t.number > 30 then array_transform( + [[t.list]], + list -> array_transform( + list, + list -> array_transform( + list, + v -> number + v + list[1] + ) + ) + ) + else array_transform( + [[t.list]], + list -> array_transform( + list, + list -> array_transform( + list, + v -> number + list[1] + ) + ) + ) + end +from t +order by t.number; +---- +a [1, 50] 10 [[[11, 11]]] +b [4, 50] 40 [[[48, 94]]] +c [7, 50] 60 [[[74, 117]]] + +#explain case with inner nested higher order function +query TT +explain select + t.text, + t.list, + t.number, + case + when t.number > 30 then array_transform( + [[t.list]], + list -> array_transform( + list, + list -> array_transform( + list, + v -> number + v + list[1] + ) + ) + ) + else array_transform( + [[t.list]], + list -> array_transform( + list, + list -> array_transform( + list, + v -> number + list[1] + ) + ) + ) + end +from t +order by t.number; +---- +logical_plan +01)Sort: t.number ASC NULLS LAST +02)--Projection: t.text, t.list, t.number, CASE WHEN t.number > Int32(30) THEN array_transform(make_array(make_array(t.list)), (list) -> array_transform(list, (list) -> array_transform(list, (v) -> t.number + v + array_element(list, Int64(1))))) ELSE array_transform(make_array(make_array(t.list)), (list) -> array_transform(list, (list) -> array_transform(list, (v) -> t.number + array_element(list, Int64(1))))) END AS CASE WHEN t.number > Int64(30) THEN array_transform(make_array(make_array(t.list)),(list) -> array_transform(list,(list) -> array_transform(list,(v) -> t.number + v + list[Int64(1)]))) ELSE array_transform(make_array(make_array(t.list)),(list) -> array_transform(list,(list) -> array_transform(list,(v) -> t.number + list[Int64(1)]))) END +03)----TableScan: t projection=[text, list, number] +physical_plan +01)SortExec: expr=[number@2 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[text@0 as text, list@1 as list, number@2 as number, CASE WHEN 
number@2 > 30 THEN array_transform(make_array(make_array(list@1)), (list) -> array_transform(list@3, (list) -> array_transform(list@4, (v) -> number@2 + v@5 + array_element(list@4, 1)))) ELSE array_transform(make_array(make_array(list@1)), (list) -> array_transform(list@3, (list) -> array_transform(list@4, (v) -> number@2 + array_element(list@4, 1)))) END as CASE WHEN t.number > Int64(30) THEN array_transform(make_array(make_array(t.list)),(list) -> array_transform(list,(list) -> array_transform(list,(v) -> t.number + v + list[Int64(1)]))) ELSE array_transform(make_array(make_array(t.list)),(list) -> array_transform(list,(list) -> array_transform(list,(v) -> t.number + list[Int64(1)]))) END] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + # higher order function with inner case using lambda variables and captured column query ? select array_transform([1, 5, 9], v -> case when v = 1 then t.number when v = 5 then 6 else 8 end) from t; From 5ae3868c761d9e0adad2c90ad2306d468661c708 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Wed, 8 Apr 2026 12:20:50 -0300 Subject: [PATCH 47/47] qualify lambda fields during physical planning --- datafusion/expr/src/execution_props.rs | 18 +++- datafusion/expr/src/udhof.rs | 36 +------- .../physical-expr/src/expressions/lambda.rs | 2 +- .../src/expressions/lambda_variable.rs | 22 +---- .../physical-expr/src/expressions/mod.rs | 2 +- datafusion/physical-expr/src/planner.rs | 87 ++++++++++--------- .../sqllogictest/test_files/higher_order.slt | 34 ++++++++ 7 files changed, 104 insertions(+), 97 deletions(-) diff --git a/datafusion/expr/src/execution_props.rs b/datafusion/expr/src/execution_props.rs index 3bf6978eb60ee..1932328c7bde5 100644 --- a/datafusion/expr/src/execution_props.rs +++ b/datafusion/expr/src/execution_props.rs @@ -17,9 +17,9 @@ use crate::var_provider::{VarProvider, VarType}; use chrono::{DateTime, Utc}; -use datafusion_common::HashMap; use 
datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; +use datafusion_common::{HashMap, TableReference}; use std::sync::Arc; /// Holds per-query execution properties and data (such as statement @@ -42,6 +42,7 @@ pub struct ExecutionProps { pub config_options: Option>, /// Providers for scalar variables pub var_providers: Option>>, + pub lambda_variable_qualifier: HashMap, } impl Default for ExecutionProps { @@ -58,6 +59,7 @@ impl ExecutionProps { alias_generator: Arc::new(AliasGenerator::new()), config_options: None, var_providers: None, + lambda_variable_qualifier: HashMap::new(), } } @@ -117,6 +119,20 @@ impl ExecutionProps { pub fn config_options(&self) -> Option<&Arc> { self.config_options.as_ref() } + + pub fn with_qualified_lambda_variables( + mut self, + qualifier: &TableReference, + variables: &[String], + ) -> Self { + for var in variables { + self.lambda_variable_qualifier + .entry_ref(var) + .insert(qualifier.clone()); + } + + self + } } #[cfg(test)] diff --git a/datafusion/expr/src/udhof.rs b/datafusion/expr/src/udhof.rs index a54dcf34560d8..da502cee461e8 100644 --- a/datafusion/expr/src/udhof.rs +++ b/datafusion/expr/src/udhof.rs @@ -265,37 +265,13 @@ fn merge_captures_with_variables( match captures { Some(captures) => { - let old_fields = captures.schema_ref().fields(); - - let mut new_fields = old_fields + let new_fields = captures.schema_ref() + .fields() .iter() - .map(|field| { - if !fields_contains(params, field.name()) { - return Arc::clone(field); - } - - let mut i = 0; - - loop { - let alias = format!("{}_shadowed_{i}", field.name()); - - if !fields_contains(params, &alias) - && old_fields.find(&alias).is_none() - { - break Arc::new(Field::new( - alias, - field.data_type().clone(), - field.is_nullable(), - )); - } - - i += 1; - } - }) + .chain(params) + .cloned() .collect::>(); - new_fields.extend_from_slice(params); - let mut columns = captures.columns().to_vec(); for arg in &variables[..params.len()] { @@ 
-320,10 +296,6 @@ fn merge_captures_with_variables( } } -fn fields_contains(fields: &[FieldRef], name: &str) -> bool { - fields.iter().any(|f| f.name().as_str() == name) -} - /// Information about arguments passed to the function /// /// This structure contains metadata about how the function was called diff --git a/datafusion/physical-expr/src/expressions/lambda.rs b/datafusion/physical-expr/src/expressions/lambda.rs index aa0247bab8699..4632079dd3907 100644 --- a/datafusion/physical-expr/src/expressions/lambda.rs +++ b/datafusion/physical-expr/src/expressions/lambda.rs @@ -94,7 +94,7 @@ impl LambdaExpr { .collect::>(); let projected_body = Arc::clone(&body) - .transform(|e| { + .transform_down(|e| { if let Some(column) = e.as_any().downcast_ref::() { let original = column.index(); let projected = *column_index_map.get(&original).unwrap(); diff --git a/datafusion/physical-expr/src/expressions/lambda_variable.rs b/datafusion/physical-expr/src/expressions/lambda_variable.rs index 3a7006e02b67f..68514f8fa8570 100644 --- a/datafusion/physical-expr/src/expressions/lambda_variable.rs +++ b/datafusion/physical-expr/src/expressions/lambda_variable.rs @@ -28,7 +28,7 @@ use arrow::{ record_batch::RecordBatch, }; -use datafusion_common::{Result, internal_err, plan_err}; +use datafusion_common::{Result, internal_err}; use datafusion_expr::ColumnarValue; /// Represents the lambda variable with a given name and field @@ -133,23 +133,3 @@ impl PhysicalExpr for LambdaVariable { write!(f, "{}@{}", self.name, self.index) } } - -/// Create a lambda variable expression -pub fn lambda_variable( - name: impl Into, - field: FieldRef, - schema: &Schema, -) -> Result> { - let name = name.into(); - let index = schema.index_of(&name)?; - - let schema_field = schema.field(index); - - if field.as_ref() != schema_field { - return plan_err!( - "LambdaVariable owned field differ from schema field {field} != {schema_field}" - ); - } - - Ok(Arc::new(LambdaVariable::new(name, index, field))) -} 
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 0d49910a3554f..b994e4ac98481 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -52,7 +52,7 @@ pub use in_list::{InListExpr, in_list}; pub use is_not_null::{IsNotNullExpr, is_not_null}; pub use is_null::{IsNullExpr, is_null}; pub use lambda::{LambdaExpr, lambda}; -pub use lambda_variable::{LambdaVariable, lambda_variable}; +pub use lambda_variable::LambdaVariable; pub use like::{LikeExpr, like}; pub use literal::{Literal, lit}; pub use negative::{NegativeExpr, negative}; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index f2d69103ab0ef..f4b27ff75ff4a 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -23,12 +23,12 @@ use crate::{ expressions::{self, Column, Literal, binary, like, similar_to}, }; -use arrow::datatypes::{Field, Schema}; +use arrow::datatypes::Schema; use datafusion_common::config::ConfigOptions; use datafusion_common::metadata::{FieldMetadata, format_type_and_metadata}; use datafusion_common::{ - DFSchema, Result, ScalarValue, ToDFSchema, exec_err, internal_datafusion_err, - not_impl_err, plan_err, + DFSchema, Result, ScalarValue, TableReference, ToDFSchema, exec_err, + internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, }; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{ @@ -433,40 +433,24 @@ pub fn create_physical_expr( ); } + let mut lambda_qualifier = input_dfschema + .iter() + .filter_map(|(qualifier, _field)| { + qualifier.and_then(|tbl| { + tbl.table().strip_prefix("lambda_")?.parse::().ok() + }) + }) + .max() + .unwrap_or_default(); + let physical_args = args .iter() .map(|arg| match arg { Expr::Lambda(lambda) => { - let mut new_fields = input_dfschema - .iter() - .map(|(qualifier, field)| { - if 
!lambda.params.contains(field.name()) { - return (qualifier.cloned(), Arc::clone(field)); - } - - let mut i = 0; - - loop { - let alias = format!("{}_shadowed_{i}", field.name()); - - if !lambda.params.contains(&alias) - && !input_dfschema - .has_column_with_unqualified_name(&alias) - { - break ( - qualifier.cloned(), - Arc::new(Field::new( - alias, - field.data_type().clone(), - field.is_nullable(), - )), - ); - } - - i += 1; - } - }) - .collect::>(); + lambda_qualifier += 1; + + let qualifier = + TableReference::bare(format!("lambda_{lambda_qualifier}")); let lambda_parameters = lambda_parameters .next() @@ -477,16 +461,26 @@ pub fn create_physical_expr( })? .into_iter() .zip(&lambda.params) - .map(|(field, name)| (None, Arc::new(field.with_name(name)))); + .map(|(field, name)| { + (Some(qualifier.clone()), Arc::new(field.with_name(name))) + }); - new_fields.extend(lambda_parameters); + let new_fields = input_dfschema + .iter() + .map(|(tbl, field)| (tbl.cloned(), Arc::clone(field))) + .chain(lambda_parameters) + .collect(); let lambda_schema = DFSchema::new_with_metadata( new_fields, input_dfschema.metadata().clone(), )?; - create_physical_expr(arg, &lambda_schema, execution_props) + let execution_props = execution_props + .clone() + .with_qualified_lambda_variables(&qualifier, &lambda.params); + + create_physical_expr(arg, &lambda_schema, &execution_props) } _ => create_physical_expr(arg, input_dfschema, execution_props), }) @@ -512,11 +506,22 @@ pub fn create_physical_expr( name, field, spans: _, - }) => expressions::lambda_variable( - name, - Arc::clone(field), - input_dfschema.as_arrow(), - ), + }) => { + let qualifier = execution_props + .lambda_variable_qualifier + .get(name) + .ok_or_else(|| plan_datafusion_err!(""))?; + + let index = input_dfschema + .index_of_column_by_name(Some(qualifier), name) + .ok_or_else(|| plan_datafusion_err!(""))?; + + Ok(Arc::new(expressions::LambdaVariable::new( + name.clone(), + index, + Arc::clone(field), + ))) + } other => 
{ not_impl_err!("Physical plan does not support logical expression {other:?}") } diff --git a/datafusion/sqllogictest/test_files/higher_order.slt b/datafusion/sqllogictest/test_files/higher_order.slt index 79c3eb361d08d..760dcace8b304 100644 --- a/datafusion/sqllogictest/test_files/higher_order.slt +++ b/datafusion/sqllogictest/test_files/higher_order.slt @@ -300,3 +300,37 @@ drop table t; statement ok set datafusion.sql_parser.dialect = generic; + + +statement ok +create table t as select 1 as c0, true as c1, 'a' as c2; + +query I +SELECT + case + when c0 = 1 then + case + when c1 = true then + case + when c2 = 'a' then 33 + end + end + end +FROM t; +---- +33 + +# create innermost case +# c2=2 +# projected c2=0 +# projection = [2] + +# create middle case +# c1=1, c2=2 +# projected c1=0, c2=1 +# projection = [1, 2] + +# create outermost case +# c0=0, c1=1, c2=2 +# projected c0=0, c1=1, c2=2 +# projection = [0, 1, 2]