From 4a3e55362bdf87ab8da537a4cc296f7b4033e990 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Fri, 20 Feb 2026 14:21:30 -0500 Subject: [PATCH 01/28] Initial implementation --- datafusion/core/Cargo.toml | 4 + .../core/benches/scalar_subquery_sql.rs | 119 +++++ datafusion/core/src/physical_planner.rs | 361 ++++++++++++--- datafusion/expr/src/execution_props.rs | 20 +- .../optimizer/src/scalar_subquery_to_join.rs | 105 ++--- .../optimizer/tests/optimizer_integration.rs | 14 +- datafusion/physical-expr/src/lib.rs | 1 + datafusion/physical-expr/src/planner.rs | 24 + .../physical-expr/src/scalar_subquery.rs | 225 ++++++++++ datafusion/physical-plan/src/lib.rs | 1 + .../physical-plan/src/scalar_subquery.rs | 416 ++++++++++++++++++ datafusion/proto/proto/datafusion.proto | 14 + datafusion/proto/src/bytes/mod.rs | 12 +- datafusion/proto/src/generated/pbjson.rs | 264 +++++++++++ datafusion/proto/src/generated/prost.rs | 24 +- .../proto/src/physical_plan/from_proto.rs | 25 +- datafusion/proto/src/physical_plan/mod.rs | 138 +++++- .../proto/src/physical_plan/to_proto.rs | 14 +- .../tests/cases/roundtrip_physical_plan.rs | 108 ++++- .../sqllogictest/test_files/metadata.slt | 6 +- .../sqllogictest/test_files/subquery.slt | 255 +++++++++-- 21 files changed, 1968 insertions(+), 182 deletions(-) create mode 100644 datafusion/core/benches/scalar_subquery_sql.rs create mode 100644 datafusion/physical-expr/src/scalar_subquery.rs create mode 100644 datafusion/physical-plan/src/scalar_subquery.rs diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index c1230f7d5daa6..f193acebc0b29 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -297,3 +297,7 @@ required-features = ["parquet"] [[bench]] harness = false name = "reset_plan_states" + +[[bench]] +harness = false +name = "scalar_subquery_sql" diff --git a/datafusion/core/benches/scalar_subquery_sql.rs b/datafusion/core/benches/scalar_subquery_sql.rs new file mode 100644 index 
0000000000000..bd5d25a96123f --- /dev/null +++ b/datafusion/core/benches/scalar_subquery_sql.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for uncorrelated scalar subquery evaluation. +//! +//! Measures the overhead of subquery execution machinery by using simple +//! arithmetic and comparison operators that don't have specialized scalar +//! fast paths, keeping the comparison between the old (join-based) and new +//! (ScalarSubqueryExec-based) approaches apples-to-apples. 
+ +use arrow::array::Int64Array; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion::datasource::MemTable; +use datafusion::error::Result; +use datafusion::prelude::SessionContext; +use std::hint::black_box; +use std::sync::Arc; +use tokio::runtime::Runtime; + +fn query(ctx: &SessionContext, rt: &Runtime, sql: &str) { + let df = rt.block_on(ctx.sql(sql)).unwrap(); + black_box(rt.block_on(df.collect()).unwrap()); +} + +fn create_context(num_rows: usize) -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int64, false), + ])); + + let batch_size = 4096; + let batches = (0..num_rows / batch_size) + .map(|i| { + let values: Vec = + ((i * batch_size) as i64..((i + 1) * batch_size) as i64).collect(); + RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(values))]) + .unwrap() + }) + .collect::>(); + + // Small lookup table for the subquery to read from. + let sq_schema = Arc::new(Schema::new(vec![ + Field::new("v", DataType::Int64, false), + ])); + let sq_batch = RecordBatch::try_new( + sq_schema.clone(), + vec![Arc::new(Int64Array::from(vec![10, 20, 30]))], + )?; + + let ctx = SessionContext::new(); + ctx.register_table("main_t", Arc::new(MemTable::try_new(schema, vec![batches])?))?; + ctx.register_table( + "lookup", + Arc::new(MemTable::try_new(sq_schema, vec![vec![sq_batch]])?), + )?; + + Ok(ctx) +} + +fn criterion_benchmark(c: &mut Criterion) { + let num_rows = 1_048_576; // 2^20 + let rt = Runtime::new().unwrap(); + + // Scalar subquery in a filter (WHERE clause). + c.bench_function("scalar_subquery_filter", |b| { + let ctx = create_context(num_rows).unwrap(); + b.iter(|| { + query( + &ctx, + &rt, + "SELECT x FROM main_t WHERE x > (SELECT max(v) FROM lookup)", + ) + }) + }); + + // Scalar subquery in a projection (SELECT expression). 
+ c.bench_function("scalar_subquery_projection", |b| { + let ctx = create_context(num_rows).unwrap(); + b.iter(|| { + query( + &ctx, + &rt, + "SELECT x + (SELECT max(v) FROM lookup) AS y FROM main_t", + ) + }) + }); + + // Two scalar subqueries in one query. + c.bench_function("scalar_subquery_two_subqueries", |b| { + let ctx = create_context(num_rows).unwrap(); + b.iter(|| { + query( + &ctx, + &rt, + "SELECT x FROM main_t \ + WHERE x > (SELECT min(v) FROM lookup) \ + AND x < (SELECT max(v) FROM lookup) + 1000000", + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index e25969903521c..14734a774a165 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -19,7 +19,7 @@ use std::borrow::Cow; use std::collections::{HashMap, HashSet}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use crate::datasource::file_format::file_type_to_format; use crate::datasource::listing::ListingTableUrl; @@ -63,6 +63,7 @@ use arrow::datatypes::Schema; use arrow_schema::Field; use datafusion_catalog::ScanArgs; use datafusion_common::Column; +use datafusion_common::HashMap as DFHashMap; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::format::ExplainAnalyzeCategories; use datafusion_common::tree_node::{ @@ -83,6 +84,7 @@ use datafusion_expr::expr::{ WindowFunction, WindowFunctionParams, physical_name, }; use datafusion_expr::expr_rewriter::unnormalize_cols; +use datafusion_expr::logical_plan::Subquery; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::utils::{expr_to_columns, split_conjunction}; use datafusion_expr::{ @@ -101,6 +103,7 @@ use datafusion_physical_plan::execution_plan::InvariantLevel; use datafusion_physical_plan::joins::PiecewiseMergeJoinExec; use 
datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::recursive_query::RecursiveQueryExec; +use datafusion_physical_plan::scalar_subquery::{ScalarSubqueryExec, ScalarSubqueryLink}; use datafusion_physical_plan::unnest::ListUnnest; use async_trait::async_trait; @@ -379,8 +382,98 @@ impl DefaultPhysicalPlanner { Ok(()) } - /// Create a physical plan from a logical plan - async fn create_initial_plan( + /// Collect uncorrelated scalar subqueries. We don't descend into nested + /// subqueries here: each call to `create_initial_plan` handles subqueries + /// at its level and then recurses in order to handle nested subqueries. + fn collect_scalar_subqueries(plan: &LogicalPlan) -> Vec { + let mut subqueries = Vec::new(); + let mut seen = HashSet::new(); + plan.apply(|node| { + for expr in node.expressions() { + expr.apply(|e| { + if let Expr::ScalarSubquery(sq) = e + && sq.outer_ref_columns.is_empty() + && seen.insert(sq.clone()) + { + subqueries.push(sq.clone()); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("infallible"); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("infallible"); + subqueries + } + + /// Create a physical plan from a logical plan. + /// + /// Uncorrelated scalar subqueries in the plan's own expressions are + /// collected, planned as separate physical plans, and each assigned a + /// shared [`OnceLock`] slot that will hold its result at execution time. + /// These slots are registered in [`ExecutionProps`] so that + /// [`create_physical_expr`] can convert `Expr::ScalarSubquery` into + /// [`ScalarSubqueryExpr`] nodes that read from the slots. + /// + /// The resulting physical plan is wrapped in a [`ScalarSubqueryExec`] + /// that owns and executes those subquery plans before any data flows + /// through the main plan. 
If a subquery itself contains nested + /// uncorrelated subqueries, the recursive call produces its own + /// [`ScalarSubqueryExec`] inside the subquery plan — each level + /// manages only its own subqueries. + /// + /// Returns a [`BoxFuture`] rather than using `async fn` because of + /// this recursion. + /// + /// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr + /// [`BoxFuture`]: futures::future::BoxFuture + fn create_initial_plan<'a>( + &'a self, + logical_plan: &'a LogicalPlan, + session_state: &'a SessionState, + ) -> futures::future::BoxFuture<'a, Result>> { + Box::pin(async move { + let all_subqueries = Self::collect_scalar_subqueries(logical_plan); + let all_sq_refs: Vec<&_> = all_subqueries.iter().collect(); + let (links, index_map) = self + .plan_scalar_subqueries(&all_sq_refs, session_state) + .await?; + + // Create the shared results container and register it (along with + // the index map) in ExecutionProps so that `create_physical_expr` + // can resolve `Expr::ScalarSubquery` into `ScalarSubqueryExpr` + // nodes. We clone the SessionState so these are available + // throughout physical planning without mutating the caller's state. + // + // Ideally, the subquery state would live in a dedicated planning + // context rather than on ExecutionProps (which is meant for + // session-level configuration). It's here because + // `create_physical_expr` only receives `&ExecutionProps`, and + // changing that signature would be a breaking public API change. 
+ let results: Arc>> = + Arc::new((0..links.len()).map(|_| OnceLock::new()).collect()); + let session_state = if links.is_empty() { + Cow::Borrowed(session_state) + } else { + let mut owned = session_state.clone(); + owned.execution_props_mut().subquery_indexes = index_map; + owned.execution_props_mut().subquery_results = Arc::clone(&results); + Cow::Owned(owned) + }; + + let plan = self + .create_initial_plan_inner(logical_plan, &session_state) + .await?; + Ok(Self::wrap_scalar_subquery_exec_if_needed( + plan, links, results, + )) + }) + } + + /// Inner physical planning that converts a logical plan tree into an + /// execution plan tree without collecting scalar subqueries. + async fn create_initial_plan_inner( &self, logical_plan: &LogicalPlan, session_state: &SessionState, @@ -545,6 +638,7 @@ impl DefaultPhysicalPlanner { session_state: &SessionState, children: ChildrenContainer, ) -> Result> { + let execution_props = session_state.execution_props(); let exec_node: Arc = match node { // Leaves (no children) LogicalPlan::TableScan(scan) => { @@ -601,7 +695,7 @@ impl DefaultPhysicalPlanner { .map(|row| { row.iter() .map(|expr| { - self.create_physical_expr(expr, schema, session_state) + create_physical_expr(expr, schema, execution_props) }) .collect::>>>() }) @@ -868,13 +962,7 @@ impl DefaultPhysicalPlanner { let logical_schema = node.schema(); let window_expr = window_expr .iter() - .map(|e| { - create_window_expr( - e, - logical_schema, - session_state.execution_props(), - ) - }) + .map(|e| create_window_expr(e, logical_schema, execution_props)) .collect::>>()?; let can_repartition = session_state.config().target_partitions() > 1 @@ -979,7 +1067,7 @@ impl DefaultPhysicalPlanner { group_expr, logical_input_schema, &physical_input_schema, - session_state, + execution_props, )?; let agg_filter = aggr_expr @@ -989,7 +1077,7 @@ impl DefaultPhysicalPlanner { e, logical_input_schema, &physical_input_schema, - session_state.execution_props(), + execution_props, ) }) 
.collect::>>()?; @@ -1083,8 +1171,8 @@ impl DefaultPhysicalPlanner { )?) } LogicalPlan::Projection(Projection { input, expr, .. }) => self - .create_project_physical_exec( - session_state, + .create_project_physical_exec_with_props( + execution_props, children.one()?, input, expr, @@ -1094,9 +1182,8 @@ impl DefaultPhysicalPlanner { }) => { let physical_input = children.one()?; let input_dfschema = input.schema(); - let runtime_expr = - self.create_physical_expr(predicate, input_dfschema, session_state)?; + create_physical_expr(predicate, input_dfschema, execution_props)?; let input_schema = input.schema(); let filter = match self.try_plan_async_exprs( @@ -1144,7 +1231,9 @@ impl DefaultPhysicalPlanner { .options() .optimizer .default_filter_selectivity; - Arc::new(filter.with_default_selectivity(selectivity)?) + let filter_exec: Arc = + Arc::new(filter.with_default_selectivity(selectivity)?); + filter_exec } LogicalPlan::Repartition(Repartition { input, @@ -1160,11 +1249,7 @@ impl DefaultPhysicalPlanner { let runtime_expr = expr .iter() .map(|e| { - self.create_physical_expr( - e, - input_dfschema, - session_state, - ) + create_physical_expr(e, input_dfschema, execution_props) }) .collect::>>()?; Partitioning::Hash(runtime_expr, *n) @@ -1185,11 +1270,8 @@ impl DefaultPhysicalPlanner { }) => { let physical_input = children.one()?; let input_dfschema = input.as_ref().schema(); - let sort_exprs = create_physical_sort_exprs( - expr, - input_dfschema, - session_state.execution_props(), - )?; + let sort_exprs = + create_physical_sort_exprs(expr, input_dfschema, execution_props)?; let Some(ordering) = LexOrdering::new(sort_exprs) else { return internal_err!( "SortExec requires at least one sort expression" @@ -1309,8 +1391,8 @@ impl DefaultPhysicalPlanner { ( true, LogicalPlan::Projection(Projection { input, expr, .. 
}), - ) => self.create_project_physical_exec( - session_state, + ) => self.create_project_physical_exec_with_props( + execution_props, physical_left, input, expr, @@ -1322,8 +1404,8 @@ impl DefaultPhysicalPlanner { ( true, LogicalPlan::Projection(Projection { input, expr, .. }), - ) => self.create_project_physical_exec( - session_state, + ) => self.create_project_physical_exec_with_props( + execution_props, physical_right, input, expr, @@ -1388,7 +1470,6 @@ impl DefaultPhysicalPlanner { // All equi-join keys are columns now, create physical join plan let left_df_schema = left.schema(); let right_df_schema = right.schema(); - let execution_props = session_state.execution_props(); let join_on = keys .iter() .map(|(l, r)| { @@ -1494,7 +1575,7 @@ impl DefaultPhysicalPlanner { let filter_expr = create_physical_expr( expr, &filter_df_schema, - session_state.execution_props(), + execution_props, )?; let column_indices = join_utils::JoinFilter::build_column_indices( left_field_indices, @@ -1614,12 +1695,12 @@ impl DefaultPhysicalPlanner { let on_left = create_physical_expr( lhs_logical, left_df_schema, - session_state.execution_props(), + execution_props, )?; let on_right = create_physical_expr( rhs_logical, right_df_schema, - session_state.execution_props(), + execution_props, )?; Arc::new(PiecewiseMergeJoinExec::try_new( @@ -1689,7 +1770,12 @@ impl DefaultPhysicalPlanner { // If plan was mutated previously then need to create the ExecutionPlan // for the new Projection that was applied on top. if let Some((input, expr)) = new_project { - self.create_project_physical_exec(session_state, join, input, expr)? + self.create_project_physical_exec_with_props( + execution_props, + join, + input, + expr, + )? 
} else { join } @@ -1783,7 +1869,7 @@ impl DefaultPhysicalPlanner { group_expr: &[Expr], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { if group_expr.len() == 1 { match &group_expr[0] { @@ -1792,25 +1878,25 @@ impl DefaultPhysicalPlanner { grouping_sets, input_dfschema, input_schema, - session_state, + execution_props, ) } Expr::GroupingSet(GroupingSet::Cube(exprs)) => create_cube_physical_expr( exprs, input_dfschema, input_schema, - session_state, + execution_props, ), Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { create_rollup_physical_expr( exprs, input_dfschema, input_schema, - session_state, + execution_props, ) } expr => Ok(PhysicalGroupBy::new_single(vec![tuple_err(( - self.create_physical_expr(expr, input_dfschema, session_state), + create_physical_expr(expr, input_dfschema, execution_props), physical_name(expr), ))?])), } @@ -1824,7 +1910,7 @@ impl DefaultPhysicalPlanner { .iter() .map(|e| { tuple_err(( - self.create_physical_expr(e, input_dfschema, session_state), + create_physical_expr(e, input_dfschema, execution_props), physical_name(e), )) }) @@ -1848,7 +1934,7 @@ fn merge_grouping_set_physical_expr( grouping_sets: &[Vec], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { let num_groups = grouping_sets.len(); let mut all_exprs: Vec = vec![]; @@ -1862,14 +1948,14 @@ fn merge_grouping_set_physical_expr( grouping_set_expr.push(get_physical_expr_pair( expr, input_dfschema, - session_state, + execution_props, )?); null_exprs.push(get_null_physical_expr_pair( expr, input_dfschema, input_schema, - session_state, + execution_props, )?); } } @@ -1899,7 +1985,7 @@ fn create_cube_physical_expr( exprs: &[Expr], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { let num_of_exprs = exprs.len(); let num_groups = 
num_of_exprs * num_of_exprs; @@ -1914,10 +2000,14 @@ fn create_cube_physical_expr( expr, input_dfschema, input_schema, - session_state, + execution_props, )?); - all_exprs.push(get_physical_expr_pair(expr, input_dfschema, session_state)?) + all_exprs.push(get_physical_expr_pair( + expr, + input_dfschema, + execution_props, + )?) } let mut groups: Vec> = Vec::with_capacity(num_groups); @@ -1941,7 +2031,7 @@ fn create_rollup_physical_expr( exprs: &[Expr], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { let num_of_exprs = exprs.len(); @@ -1957,10 +2047,14 @@ fn create_rollup_physical_expr( expr, input_dfschema, input_schema, - session_state, + execution_props, )?); - all_exprs.push(get_physical_expr_pair(expr, input_dfschema, session_state)?) + all_exprs.push(get_physical_expr_pair( + expr, + input_dfschema, + execution_props, + )?) } for total in 0..=num_of_exprs { @@ -1985,10 +2079,9 @@ fn get_null_physical_expr_pair( expr: &Expr, input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result<(Arc, String)> { - let physical_expr = - create_physical_expr(expr, input_dfschema, session_state.execution_props())?; + let physical_expr = create_physical_expr(expr, input_dfschema, execution_props)?; let physical_name = physical_name(&expr.clone())?; let data_type = physical_expr.data_type(input_schema)?; @@ -2057,10 +2150,9 @@ fn qualify_join_schema_sides( fn get_physical_expr_pair( expr: &Expr, input_dfschema: &DFSchema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result<(Arc, String)> { - let physical_expr = - create_physical_expr(expr, input_dfschema, session_state.execution_props())?; + let physical_expr = create_physical_expr(expr, input_dfschema, execution_props)?; let physical_name = physical_name(expr)?; Ok((physical_expr, physical_name)) } @@ -2839,9 +2931,49 @@ impl 
DefaultPhysicalPlanner { Ok(mem_exec) } - fn create_project_physical_exec( + /// Build physical plans for scalar subqueries and assign each an ordinal + /// index. Returns the links (plan + index) and a map from logical + /// `Subquery` to its index. + async fn plan_scalar_subqueries( &self, + subqueries: &[&Subquery], session_state: &SessionState, + ) -> Result<(Vec, DFHashMap)> { + let mut links = Vec::with_capacity(subqueries.len()); + let mut index_map = DFHashMap::with_capacity(subqueries.len()); + for &sq in subqueries { + // Callers deduplicate, but guard against accidental double-planning. + if index_map.contains_key(sq) { + continue; + } + let physical_plan = self + .create_initial_plan(&sq.subquery, session_state) + .await?; + let index = links.len(); + links.push(ScalarSubqueryLink { + plan: physical_plan, + index, + }); + index_map.insert(sq.clone(), index); + } + Ok((links, index_map)) + } + + fn wrap_scalar_subquery_exec_if_needed( + input: Arc, + subqueries: Vec, + results: Arc>>, + ) -> Arc { + if subqueries.is_empty() { + input + } else { + Arc::new(ScalarSubqueryExec::new(input, subqueries, results)) + } + } + + fn create_project_physical_exec_with_props( + &self, + execution_props: &ExecutionProps, input_exec: Arc, input: &Arc, expr: &[Expr], @@ -2880,7 +3012,7 @@ impl DefaultPhysicalPlanner { }; let physical_expr = - self.create_physical_expr(e, input_logical_schema, session_state); + create_physical_expr(e, input_logical_schema, execution_props); tuple_err((physical_expr, physical_name)) }) @@ -3125,7 +3257,7 @@ mod tests { use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; use datafusion_common::{ - DFSchemaRef, TableReference, ToDFSchema as _, assert_contains, + DFSchemaRef, TableReference, ToDFSchema as _, assert_batches_eq, assert_contains, }; use datafusion_execution::TaskContext; use datafusion_execution::runtime_env::RuntimeEnv; @@ -3159,6 +3291,16 @@ mod tests { .await } + async fn plan_sql(query: &str) -> 
Result> { + let ctx = SessionContext::new(); + ctx.sql(query).await?.create_physical_plan().await + } + + async fn collect_sql(query: &str) -> Result> { + let ctx = SessionContext::new(); + ctx.sql(query).await?.collect().await + } + #[tokio::test] async fn test_all_operators() -> Result<()> { let logical_plan = test_csv_scan() @@ -3182,6 +3324,99 @@ mod tests { Ok(()) } + #[tokio::test] + async fn scalar_subquery_in_sort_expr_plans() -> Result<()> { + let plan = plan_sql( + "SELECT x \ + FROM (VALUES (2), (1)) AS t(x) \ + ORDER BY x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y))", + ) + .await?; + + assert_contains!(format!("{plan:?}"), "ScalarSubqueryExec"); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_sort_expr_executes() -> Result<()> { + let batches = collect_sql( + "SELECT x \ + FROM (VALUES (2), (1), (3)) AS t(x) \ + ORDER BY x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y)) DESC", + ) + .await?; + + assert_batches_eq!( + &[ + "+---+", "| x |", "+---+", "| 3 |", "| 2 |", "| 1 |", "+---+", + ], + &batches + ); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_aggregate_arg_plans() -> Result<()> { + let plan = plan_sql( + "SELECT sum(x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y))) \ + FROM (VALUES (2), (1)) AS t(x)", + ) + .await?; + + assert_contains!(format!("{plan:?}"), "ScalarSubqueryExec"); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_aggregate_arg_executes() -> Result<()> { + let batches = collect_sql( + "SELECT sum(x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y))) AS s \ + FROM (VALUES (2), (1)) AS t(x)", + ) + .await?; + + assert_batches_eq!( + &["+----+", "| s |", "+----+", "| 43 |", "+----+",], + &batches + ); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_join_on_plans() -> Result<()> { + let plan = plan_sql( + "SELECT l.x, r.y \ + FROM (VALUES (1), (2)) AS l(x) \ + JOIN (VALUES (11), (12)) AS r(y) \ + ON l.x + (SELECT 10) = r.y", + ) + .await?; + + let formatted = 
format!("{plan:?}"); + assert_contains!(&formatted, "ScalarSubqueryExec"); + assert!( + formatted.contains("HashJoinExec") + || formatted.contains("SortMergeJoinExec") + || formatted.contains("NestedLoopJoinExec") + ); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_projection_and_filter_plans() -> Result<()> { + let plan = plan_sql( + "SELECT x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y)) \ + FROM (VALUES (2), (1)) AS t(x) \ + WHERE x > (SELECT min(y) FROM (VALUES (0), (1)) AS v(y))", + ) + .await?; + + let formatted = format!("{plan:?}"); + // All uncorrelated scalar subqueries are hoisted to a single root node. + assert_eq!(formatted.matches("ScalarSubqueryExec").count(), 1); + Ok(()) + } + #[tokio::test] async fn test_create_cube_expr() -> Result<()> { let logical_plan = test_csv_scan().await?.build()?; @@ -3199,7 +3434,7 @@ mod tests { &exprs, logical_input_schema, physical_input_schema, - &session_state, + session_state.execution_props(), ); insta::assert_debug_snapshot!(cube, @r#" @@ -3330,7 +3565,7 @@ mod tests { &exprs, logical_input_schema, physical_input_schema, - &session_state, + session_state.execution_props(), ); insta::assert_debug_snapshot!(rollup, @r#" diff --git a/datafusion/expr/src/execution_props.rs b/datafusion/expr/src/execution_props.rs index 3bf6978eb60ee..0c9cfe2e06d1a 100644 --- a/datafusion/expr/src/execution_props.rs +++ b/datafusion/expr/src/execution_props.rs @@ -18,9 +18,17 @@ use crate::var_provider::{VarProvider, VarType}; use chrono::{DateTime, Utc}; use datafusion_common::HashMap; +use datafusion_common::ScalarValue; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; + +/// Shared results container for uncorrelated scalar subqueries. +/// +/// Each entry corresponds to one scalar subquery, identified by its index. 
+/// The `OnceLock` is populated at execution time by `ScalarSubqueryExec` and +/// read by `ScalarSubqueryExpr` instances that share this container. +pub type ScalarSubqueryResults = Arc>>; /// Holds per-query execution properties and data (such as statement /// starting timestamps). @@ -42,6 +50,12 @@ pub struct ExecutionProps { pub config_options: Option>, /// Providers for scalar variables pub var_providers: Option>>, + /// Maps each logical `Subquery` to its index in `subquery_results`. + /// Populated by the physical planner before calling `create_physical_expr`. + pub subquery_indexes: HashMap, + /// Shared results container for uncorrelated scalar subquery values. + /// Populated at execution time by `ScalarSubqueryExec`. + pub subquery_results: ScalarSubqueryResults, } impl Default for ExecutionProps { @@ -58,6 +72,8 @@ impl ExecutionProps { alias_generator: Arc::new(AliasGenerator::new()), config_options: None, var_providers: None, + subquery_indexes: HashMap::new(), + subquery_results: Arc::new(vec![]), } } @@ -126,7 +142,7 @@ mod test { fn debug() { let props = ExecutionProps::new(); assert_eq!( - "ExecutionProps { query_execution_start_time: None, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None }", + "ExecutionProps { query_execution_start_time: None, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None, subquery_indexes: {}, subquery_results: [] }", format!("{props:?}") ); } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 975c234b38836..5315e4dcc2f2b 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`ScalarSubqueryToJoin`] rewriting scalar subquery filters to `JOIN`s +//! 
[`ScalarSubqueryToJoin`] rewriting correlated scalar subquery filters to `JOIN`s use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; @@ -36,9 +36,9 @@ use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::conjunction; use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, expr}; -/// Optimizer rule for rewriting subquery filters to joins -/// and places additional projection on top of the filter, to preserve -/// original schema. +/// Optimizer rule that rewrites correlated scalar subquery filters to joins and +/// places an additional projection on top of the filter, to preserve original +/// schema. #[derive(Default, Debug)] pub struct ScalarSubqueryToJoin {} @@ -48,10 +48,15 @@ impl ScalarSubqueryToJoin { Self::default() } - /// Finds expressions that have a scalar subquery in them (and recurses when found) + /// Finds expressions that contain correlated scalar subqueries (and + /// recurses when found) /// /// # Arguments - /// * `predicate` - A conjunction to split and search + /// * `predicate` - A conjunction to split and search. + /// * `alias_gen` - Generator used to produce unique aliases for each + /// extracted scalar subquery (e.g. `__scalar_sq_1`, `__scalar_sq_2`). + /// Each subquery is replaced by a column reference using the generated + /// alias, and the same alias is later used to construct the join. 
/// /// Returns a tuple (subqueries, alias) fn extract_subquery_exprs( @@ -137,8 +142,8 @@ impl OptimizerRule for ScalarSubqueryToJoin { Ok(Transformed::yes(new_plan)) } LogicalPlan::Projection(projection) => { - // Optimization: skip the rest of the rule and its copies if - // there are no scalar subqueries + // Optimization: skip the rest of the rule and its copies if there + // are no correlated scalar subqueries if !projection.expr.iter().any(contains_scalar_subquery) { return Ok(Transformed::no(LogicalPlan::Projection(projection))); } @@ -222,11 +227,14 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } -/// Returns true if the expression has a scalar subquery somewhere in it -/// false otherwise +/// Returns true if the expression contains a correlated scalar subquery, false +/// otherwise. Uncorrelated scalar subqueries are handled by the physical +/// planner via `ScalarSubqueryExec` and do not need to be converted to joins. fn contains_scalar_subquery(expr: &Expr) -> bool { - expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_)))) - .expect("Inner is always Ok") + expr.exists(|expr| { + Ok(matches!(expr, Expr::ScalarSubquery(sq) if !sq.outer_ref_columns.is_empty())) + }) + .expect("Inner is always Ok") } struct ExtractScalarSubQuery<'a> { @@ -239,7 +247,11 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { fn f_down(&mut self, expr: Expr) -> Result> { match expr { - Expr::ScalarSubquery(subquery) => { + // Skip uncorrelated scalar subqueries + Expr::ScalarSubquery(ref subquery) + if !subquery.outer_ref_columns.is_empty() => + { + let subquery = subquery.clone(); let subqry_alias = self.alias_gen.next("__scalar_sq"); self.sub_query_info .push((subquery.clone(), subqry_alias.clone())); @@ -624,15 +636,13 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, 
c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } @@ -1029,14 +1039,12 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey < 
() [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } @@ -1059,14 +1067,12 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } @@ -1158,19 +1164,16 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] - Left Join: Filter: 
Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N] - Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey BETWEEN () AND () [c_custkey:Int64, c_name:Utf8] + Subquery: [min(orders.o_custkey):Int64;N] + Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index fd4991c24413f..ef97d603acb6f 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -128,15 +128,13 @@ fn subquery_filter_with_cast() -> Result<()> { 
assert_snapshot!( format!("{plan}"), @r#" - Projection: test.col_int32 - Inner Join: Filter: CAST(test.col_int32 AS Float64) > __scalar_sq_1.avg(test.col_int32) - TableScan: test projection=[col_int32] - SubqueryAlias: __scalar_sq_1 + Filter: CAST(test.col_int32 AS Float64) > () + Subquery: + Projection: avg(test.col_int32) Aggregate: groupBy=[[]], aggr=[[avg(CAST(test.col_int32 AS Float64))]] - Projection: test.col_int32 - Filter: __common_expr_4 >= Date32("2002-05-08") AND __common_expr_4 <= Date32("2002-05-13") - Projection: CAST(test.col_utf8 AS Date32) AS __common_expr_4, test.col_int32 - TableScan: test projection=[col_int32, col_utf8] + Filter: CAST(test.col_utf8 AS Date32) >= Date32("2002-05-08") AND CAST(test.col_utf8 AS Date32) <= Date32("2002-05-13") + TableScan: test + TableScan: test projection=[col_int32] "# ); Ok(()) diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index bedd348dab92f..9c567ce862149 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -40,6 +40,7 @@ mod physical_expr; pub mod planner; pub mod projection; mod scalar_function; +pub mod scalar_subquery; pub mod simplifier; pub mod statistics; pub mod utils; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index fd2de812e4664..1f1a2000a0bd3 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use crate::ScalarFunctionExpr; +use crate::scalar_subquery::ScalarSubqueryExpr; use crate::{ PhysicalExpr, expressions::{self, Column, Literal, binary, like, similar_to}, @@ -409,6 +410,29 @@ pub fn create_physical_expr( expressions::in_list(value_expr, list_exprs, negated, input_schema) } }, + Expr::ScalarSubquery(sq) => { + match execution_props.subquery_indexes.get(sq) { + Some(&index) => { + let schema = sq.subquery.schema(); + let dt = schema.field(0).data_type().clone(); + let nullable = 
schema.field(0).is_nullable(); + Ok(Arc::new(ScalarSubqueryExpr::new( + dt, + nullable, + index, + Arc::clone(&execution_props.subquery_results), + ))) + } + None => { + // Not found: either a correlated subquery that wasn't + // rewritten to a join, or an uncorrelated one that wasn't + // registered by the physical planner. + not_impl_err!( + "Physical plan does not support logical expression {e:?}" + ) + } + } + } Expr::Placeholder(Placeholder { id, .. }) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } diff --git a/datafusion/physical-expr/src/scalar_subquery.rs b/datafusion/physical-expr/src/scalar_subquery.rs new file mode 100644 index 0000000000000..40d5ca66d9ec3 --- /dev/null +++ b/datafusion/physical-expr/src/scalar_subquery.rs @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical expression for uncorrelated scalar subqueries. +//! +//! [`ScalarSubqueryExpr`] reads a cached [`ScalarValue`] that is populated +//! at execution time by [`ScalarSubqueryExec`]. +//! +//! 
[`ScalarSubqueryExec`]: datafusion_physical_plan::scalar_subquery::ScalarSubqueryExec + +use std::any::Any; +use std::fmt; +use std::hash::Hash; +use std::sync::{Arc, OnceLock}; + +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion_common::{Result, ScalarValue, internal_datafusion_err}; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +/// A physical expression whose value is provided by a scalar subquery. +/// +/// Subquery execution is handled by `ScalarSubqueryExec`, which stores the +/// result in a shared results container. This expression simply reads from the +/// shared results container at the appropriate index. +/// +/// If the same subquery appears multiple times in a query, there will be +/// multiple `ScalarSubqueryExec` with the same result index. +#[derive(Debug)] +pub struct ScalarSubqueryExpr { + data_type: DataType, + nullable: bool, + /// Index of this subquery in the shared results container. + index: usize, + /// Shared results container populated by `ScalarSubqueryExec`. + results: Arc>>, +} + +impl ScalarSubqueryExpr { + pub fn new( + data_type: DataType, + nullable: bool, + index: usize, + results: Arc>>, + ) -> Self { + Self { + data_type, + nullable, + index, + results, + } + } + + /// Returns the index of this subquery in the shared results container. 
+ pub fn index(&self) -> usize { + self.index + } + + pub fn results(&self) -> &Arc>> { + &self.results + } +} + +impl fmt::Display for ScalarSubqueryExpr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.results.get(self.index).and_then(|slot| slot.get()) { + Some(v) => write!(f, "scalar_subquery({v})"), + None => write!(f, "scalar_subquery()"), + } + } +} + +// Two ScalarSubqueryExprs are the "same" if they share the same results +// container and have the same index. This follows the DynamicFilterPhysicalExpr +// precedent of identity-based equality. +impl Hash for ScalarSubqueryExpr { + fn hash(&self, state: &mut H) { + Arc::as_ptr(&self.results).hash(state); + self.index.hash(state); + } +} + +impl PartialEq for ScalarSubqueryExpr { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.results, &other.results) && self.index == other.index + } +} + +impl Eq for ScalarSubqueryExpr {} + +impl PhysicalExpr for ScalarSubqueryExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::new(Field::new( + format!("{self}"), + self.data_type.clone(), + self.nullable, + ))) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + let value = self + .results + .get(self.index) + .and_then(|slot| slot.get()) + .ok_or_else(|| { + internal_datafusion_err!( + "ScalarSubqueryExpr evaluated before the subquery was executed" + ) + })?; + Ok(ColumnarValue::Scalar(value.clone())) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn get_properties(&self, _children: &[ExprProperties]) -> Result { + Ok(ExprProperties::new_unknown().with_order(SortProperties::Singleton)) + } + + fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "(scalar subquery)") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use arrow::array::Int32Array; + use arrow::datatypes::Field; 
+ + fn make_results(values: Vec>) -> Arc>> { + Arc::new( + values + .into_iter() + .map(|v| { + let lock = OnceLock::new(); + if let Some(val) = v { + lock.set(val).unwrap(); + } + lock + }) + .collect(), + ) + } + + #[test] + fn test_evaluate_with_value() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![1, 2, 3]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + + let results = make_results(vec![Some(ScalarValue::Int32(Some(42)))]); + let expr = ScalarSubqueryExpr::new(DataType::Int32, false, 0, results); + + let result = expr.evaluate(&batch)?; + match result { + ColumnarValue::Scalar(ScalarValue::Int32(Some(42))) => {} + other => panic!("Expected Scalar(Int32(42)), got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_evaluate_before_populated() { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![1]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap(); + + let results = Arc::new(vec![OnceLock::new()]); + let expr = ScalarSubqueryExpr::new(DataType::Int32, false, 0, results); + + let result = expr.evaluate(&batch); + assert!(result.is_err()); + } + + #[test] + fn test_identity_equality() { + let results = make_results(vec![None, None]); + + let e1a = + ScalarSubqueryExpr::new(DataType::Int32, false, 0, Arc::clone(&results)); + let e1b = + ScalarSubqueryExpr::new(DataType::Int32, false, 0, Arc::clone(&results)); + let e2 = ScalarSubqueryExpr::new(DataType::Int32, false, 1, Arc::clone(&results)); + + // Same container + same index → equal + assert_eq!(e1a, e1b); + // Same container, different index → not equal + assert_ne!(e1a, e2); + + // Different container, same index → not equal + let other_results = make_results(vec![None]); + let e3 = ScalarSubqueryExpr::new(DataType::Int32, false, 0, other_results); + assert_ne!(e1a, e3); + } +} diff --git 
a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 6467d7a2e389d..94b7e2a1e0cae 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -84,6 +84,7 @@ pub mod placeholder_row; pub mod projection; pub mod recursive_query; pub mod repartition; +pub mod scalar_subquery; pub mod sort_pushdown; pub mod sorts; pub mod spill; diff --git a/datafusion/physical-plan/src/scalar_subquery.rs b/datafusion/physical-plan/src/scalar_subquery.rs new file mode 100644 index 0000000000000..62ad72c27b2f6 --- /dev/null +++ b/datafusion/physical-plan/src/scalar_subquery.rs @@ -0,0 +1,416 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution plan for uncorrelated scalar subqueries. +//! +//! [`ScalarSubqueryExec`] wraps a main input plan and a set of subquery plans. +//! At execution time, it runs each subquery exactly once, extracts the scalar +//! result, and populates the shared results container that +//! [`ScalarSubqueryExpr`] instances read from by index. +//! +//! 
[`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr + +use std::any::Any; +use std::fmt; +use std::sync::Arc; + +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; +use datafusion_execution::TaskContext; +use datafusion_expr::execution_props::ScalarSubqueryResults; +use datafusion_physical_expr::PhysicalExpr; + +use crate::execution_plan::{CardinalityEffect, ExecutionPlan, PlanProperties}; +use crate::joins::utils::OnceAsync; +use crate::stream::RecordBatchStreamAdapter; +use crate::{DisplayAs, DisplayFormatType, SendableRecordBatchStream}; + +use futures::stream::StreamExt; + +/// Links a scalar subquery's execution plan to its index in the shared results +/// container. The [`ScalarSubqueryExec`] that owns these links populates +/// `results[index]` at execution time, and [`ScalarSubqueryExpr`] instances +/// with the same index read from it. +/// +/// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr +#[derive(Debug, Clone)] +pub struct ScalarSubqueryLink { + /// The physical plan for the subquery. + pub plan: Arc, + /// Index into the shared results container. + pub index: usize, +} + +/// Manages execution of uncorrelated scalar subqueries for a single plan +/// level. +/// +/// This node has an asymmetric set of children: the first child is the +/// **main input plan**, whose batches are passed through unchanged. The +/// remaining children are **subquery plans**, each of which must produce +/// exactly zero or one row. Before any batches from the main input are +/// yielded, all subquery plans are executed and their scalar results are +/// stored in a shared results container ([`ScalarSubqueryResults`]). +/// [`ScalarSubqueryExpr`] nodes embedded in the main input's expressions +/// read from this container by index. 
+/// +/// All subqueries are evaluated eagerly when the first output partition is +/// requested, before any rows from the main input are produced. +/// +/// TODO: Consider overlapping computation of the subqueries with evaluating the +/// main query. +/// +/// TODO: Subqueries are evaluated sequentially. Consider parallel evaluation in +/// the future. +/// +/// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr +#[derive(Debug)] +pub struct ScalarSubqueryExec { + /// The main input plan whose output is passed through. + input: Arc, + /// Subquery plans and their result indexes. + subqueries: Vec, + /// Shared one-time async computation of subquery results. + subquery_future: Arc>>, + /// Shared results container; same instance held by ScalarSubqueryExpr nodes. + results: ScalarSubqueryResults, + /// Cached plan properties (copied from input). + cache: Arc, +} + +impl ScalarSubqueryExec { + pub fn new( + input: Arc, + subqueries: Vec, + results: ScalarSubqueryResults, + ) -> Self { + let cache = Arc::clone(input.properties()); + Self { + input, + subqueries, + subquery_future: Arc::default(), + results, + cache, + } + } + + pub fn input(&self) -> &Arc { + &self.input + } + + pub fn subqueries(&self) -> &[ScalarSubqueryLink] { + &self.subqueries + } + + pub fn results(&self) -> &ScalarSubqueryResults { + &self.results + } + + /// Returns a per-child bool vec that is `true` for the main input + /// (child 0) and `false` for every subquery child. 
+ fn true_for_input_only(&self) -> Vec { + std::iter::once(true) + .chain(std::iter::repeat_n(false, self.subqueries.len())) + .collect() + } +} + +impl DisplayAs for ScalarSubqueryExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "ScalarSubqueryExec: subqueries={}", + self.subqueries.len() + ) + } + DisplayFormatType::TreeRender => { + write!(f, "") + } + } + } +} + +impl ExecutionPlan for ScalarSubqueryExec { + fn name(&self) -> &'static str { + "ScalarSubqueryExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + let mut children = vec![&self.input]; + for sq in &self.subqueries { + children.push(&sq.plan); + } + children + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + // First child is the main input, the rest are subquery plans. + let input = children.remove(0); + let subqueries = self + .subqueries + .iter() + .zip(children) + .map(|(sq, new_plan)| ScalarSubqueryLink { + plan: new_plan, + index: sq.index, + }) + .collect(); + Ok(Arc::new(ScalarSubqueryExec::new( + input, + subqueries, + Arc::clone(&self.results), + ))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let subqueries = self.subqueries.clone(); + let results = Arc::clone(&self.results); + let ctx = Arc::clone(&context); + + // Use OnceAsync to ensure all subqueries are executed exactly once, + // even when multiple partitions call execute() concurrently. + let mut once_fut = self.subquery_future.try_once(move || { + Ok(async move { execute_subqueries(subqueries, results, ctx).await }) + })?; + + let input = self.input.execute(partition, context)?; + let schema = input.schema(); + + // Create a stream that first waits for subquery results, then + // delegates to the inner input stream. 
+ Ok(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::once(async move { + // Wait for subquery execution to complete. + let result: Result<()> = + std::future::poll_fn(|cx| once_fut.get(cx).map(|r| r.map(|_| ()))) + .await; + result + }) + .filter_map(|result| async move { + match result { + Ok(()) => None, // Subqueries done, proceed to input + Err(e) => Some(Err(e)), // Propagate error + } + }) + .chain(input), + ))) + } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } + + fn maintains_input_order(&self) -> Vec { + // Only the main input (first child); subquery children don't + // contribute to ordering. + self.true_for_input_only() + } + + fn benefits_from_input_partitioning(&self) -> Vec { + // Only the main input; subquery children produce at most one + // row, so repartitioning them adds overhead with no benefit. + self.true_for_input_only() + } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } +} + +/// Execute all subquery plans, extract their scalar results, and populate +/// the shared results container. +async fn execute_subqueries( + subqueries: Vec, + results: ScalarSubqueryResults, + context: Arc, +) -> Result> { + let mut values = Vec::with_capacity(subqueries.len()); + for sq in &subqueries { + let value = + execute_scalar_subquery(Arc::clone(&sq.plan), Arc::clone(&context)).await?; + let _ = results[sq.index].set(value.clone()); + values.push(value); + } + Ok(values) +} + +/// Execute a single subquery plan and extract the scalar value. +/// Returns NULL for 0 rows, the scalar value for exactly 1 row, +/// or an error for >1 rows. 
+async fn execute_scalar_subquery( + plan: Arc, + context: Arc, +) -> Result { + let schema = plan.schema(); + if schema.fields().len() != 1 { + return internal_err!( + "Scalar subquery must return exactly one column, got {}", + schema.fields().len() + ); + } + + let mut stream = crate::execute_stream(plan, context)?; + + let mut total_rows = 0usize; + let mut result_value: Option = None; + while let Some(batch) = stream.next().await.transpose()? { + total_rows += batch.num_rows(); + if total_rows > 1 { + return exec_err!( + "Scalar subquery returned more than one row (got at least {total_rows})" + ); + } + if batch.num_rows() == 1 { + result_value = Some(ScalarValue::try_from_array(batch.column(0), 0)?); + } + } + + // 0 rows → NULL of the appropriate type + Ok(result_value.unwrap_or_else(|| { + ScalarValue::try_from(schema.field(0).data_type()).unwrap_or(ScalarValue::Null) + })) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::{self, TestMemoryExec}; + + use std::sync::OnceLock; + + use arrow::array::{Int32Array, Int64Array}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + fn make_subquery_plan(batches: Vec) -> Arc { + let schema = batches[0].schema(); + TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap() + } + + fn make_results(n: usize) -> ScalarSubqueryResults { + Arc::new((0..n).map(|_| OnceLock::new()).collect()) + } + + #[tokio::test] + async fn test_single_row_subquery() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![42]))], + )?; + + let results = make_results(1); + let subquery_plan = make_subquery_plan(vec![batch]); + let sq = ScalarSubqueryLink { + plan: subquery_plan, + index: 0, + }; + + let main_input = Arc::new(crate::placeholder_row::PlaceholderRowExec::new( + test::aggr_test_schema(), + )); + let exec = 
ScalarSubqueryExec::new(main_input, vec![sq], Arc::clone(&results)); + + let ctx = Arc::new(TaskContext::default()); + let stream = exec.execute(0, ctx)?; + let _batches = crate::common::collect(stream).await?; + + assert_eq!(results[0].get(), Some(&ScalarValue::Int32(Some(42)))); + Ok(()) + } + + #[tokio::test] + async fn test_zero_row_subquery_returns_null() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int64Array::from(vec![] as Vec))], + )?; + + let results = make_results(1); + let subquery_plan = make_subquery_plan(vec![batch]); + let sq = ScalarSubqueryLink { + plan: subquery_plan, + index: 0, + }; + + let main_input = Arc::new(crate::placeholder_row::PlaceholderRowExec::new( + test::aggr_test_schema(), + )); + let exec = ScalarSubqueryExec::new(main_input, vec![sq], Arc::clone(&results)); + + let ctx = Arc::new(TaskContext::default()); + let stream = exec.execute(0, ctx)?; + let _batches = crate::common::collect(stream).await?; + + assert_eq!(results[0].get(), Some(&ScalarValue::Int64(None))); + Ok(()) + } + + #[tokio::test] + async fn test_multi_row_subquery_errors() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; + + let results = make_results(1); + let subquery_plan = make_subquery_plan(vec![batch]); + let sq = ScalarSubqueryLink { + plan: subquery_plan, + index: 0, + }; + + let main_input = Arc::new(crate::placeholder_row::PlaceholderRowExec::new( + test::aggr_test_schema(), + )); + let exec = ScalarSubqueryExec::new(main_input, vec![sq], Arc::clone(&results)); + + let ctx = Arc::new(TaskContext::default()); + let stream = exec.execute(0, ctx)?; + let result = crate::common::collect(stream).await; + + assert!(result.is_err()); + let err_msg = 
result.unwrap_err().to_string(); + assert!( + err_msg.contains("more than one row"), + "Expected 'more than one row' error, got: {err_msg}" + ); + Ok(()) + } +} diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9464e85727d4f..82aff7abce12a 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -775,6 +775,7 @@ message PhysicalPlanNode { AsyncFuncExecNode async_func = 36; BufferExecNode buffer = 37; ArrowScanExecNode arrow_scan = 38; + ScalarSubqueryExecNode scalar_subquery = 39; } } @@ -920,6 +921,8 @@ message PhysicalExprNode { UnknownColumn unknown_column = 20; PhysicalHashExprNode hash_expr = 21; + + PhysicalScalarSubqueryExprNode scalar_subquery = 22; } } @@ -1474,4 +1477,15 @@ message AsyncFuncExecNode { message BufferExecNode { PhysicalPlanNode input = 1; uint64 capacity = 2; +} + +message ScalarSubqueryExecNode { + PhysicalPlanNode input = 1; + repeated PhysicalPlanNode subqueries = 2; +} + +message PhysicalScalarSubqueryExprNode { + datafusion_common.ArrowType data_type = 1; + bool nullable = 2; + uint32 index = 3; } \ No newline at end of file diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 84b15ea9a8920..d6f511476befe 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -277,7 +277,7 @@ pub fn logical_plan_from_json_with_extension_codec( /// Serialize a PhysicalPlan as bytes pub fn physical_plan_to_bytes(plan: Arc) -> Result { let extension_codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); physical_plan_to_bytes_with_proto_converter(plan, &extension_codec, &proto_converter) } @@ -285,7 +285,7 @@ pub fn physical_plan_to_bytes(plan: Arc) -> Result { #[cfg(feature = "json")] pub fn physical_plan_to_json(plan: Arc) -> Result { let extension_codec = 
DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); let protobuf = proto_converter .execution_plan_to_proto(&plan, &extension_codec) .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; @@ -298,7 +298,7 @@ pub fn physical_plan_to_bytes_with_extension_codec( plan: Arc, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result { - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); physical_plan_to_bytes_with_proto_converter(plan, extension_codec, &proto_converter) } @@ -326,7 +326,7 @@ pub fn physical_plan_from_json( let back: protobuf::PhysicalPlanNode = serde_json::from_str(json) .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; let extension_codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); proto_converter.proto_to_execution_plan(ctx, &extension_codec, &back) } @@ -336,7 +336,7 @@ pub fn physical_plan_from_bytes( ctx: &TaskContext, ) -> Result> { let extension_codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); physical_plan_from_bytes_with_proto_converter( bytes, ctx, @@ -351,7 +351,7 @@ pub fn physical_plan_from_bytes_with_extension_codec( ctx: &TaskContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); physical_plan_from_bytes_with_proto_converter( bytes, ctx, diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index c81e2fabe1185..49617d74d8fc6 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ 
-16593,6 +16593,9 @@ impl serde::Serialize for PhysicalExprNode { physical_expr_node::ExprType::HashExpr(v) => { struct_ser.serialize_field("hashExpr", v)?; } + physical_expr_node::ExprType::ScalarSubquery(v) => { + struct_ser.serialize_field("scalarSubquery", v)?; + } } } struct_ser.end() @@ -16639,6 +16642,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "unknownColumn", "hash_expr", "hashExpr", + "scalar_subquery", + "scalarSubquery", ]; #[allow(clippy::enum_variant_names)] @@ -16663,6 +16668,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { Extension, UnknownColumn, HashExpr, + ScalarSubquery, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16704,6 +16710,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "extension" => Ok(GeneratedField::Extension), "unknownColumn" | "unknown_column" => Ok(GeneratedField::UnknownColumn), "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr), + "scalarSubquery" | "scalar_subquery" => Ok(GeneratedField::ScalarSubquery), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16866,6 +16873,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { return Err(serde::de::Error::duplicate_field("hashExpr")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::HashExpr) +; + } + GeneratedField::ScalarSubquery => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarSubquery")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarSubquery) ; } } @@ -18104,6 +18118,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::ArrowScan(v) => { struct_ser.serialize_field("arrowScan", v)?; } + physical_plan_node::PhysicalPlanType::ScalarSubquery(v) => { + struct_ser.serialize_field("scalarSubquery", v)?; + } } } struct_ser.end() @@ -18174,6 +18191,8 @@ impl<'de> 
serde::Deserialize<'de> for PhysicalPlanNode { "buffer", "arrow_scan", "arrowScan", + "scalar_subquery", + "scalarSubquery", ]; #[allow(clippy::enum_variant_names)] @@ -18215,6 +18234,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { AsyncFunc, Buffer, ArrowScan, + ScalarSubquery, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18273,6 +18293,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "asyncFunc" | "async_func" => Ok(GeneratedField::AsyncFunc), "buffer" => Ok(GeneratedField::Buffer), "arrowScan" | "arrow_scan" => Ok(GeneratedField::ArrowScan), + "scalarSubquery" | "scalar_subquery" => Ok(GeneratedField::ScalarSubquery), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18552,6 +18573,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("arrowScan")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ArrowScan) +; + } + GeneratedField::ScalarSubquery => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarSubquery")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ScalarSubquery) ; } } @@ -18564,6 +18592,134 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { deserializer.deserialize_struct("datafusion.PhysicalPlanNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PhysicalScalarSubqueryExprNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.data_type.is_some() { + len += 1; + } + if self.nullable { + len += 1; + } + if self.index != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalScalarSubqueryExprNode", len)?; + if let Some(v) 
= self.data_type.as_ref() { + struct_ser.serialize_field("dataType", v)?; + } + if self.nullable { + struct_ser.serialize_field("nullable", &self.nullable)?; + } + if self.index != 0 { + struct_ser.serialize_field("index", &self.index)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PhysicalScalarSubqueryExprNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "data_type", + "dataType", + "nullable", + "index", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + DataType, + Nullable, + Index, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "dataType" | "data_type" => Ok(GeneratedField::DataType), + "nullable" => Ok(GeneratedField::Nullable), + "index" => Ok(GeneratedField::Index), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PhysicalScalarSubqueryExprNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PhysicalScalarSubqueryExprNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut data_type__ = None; + let mut nullable__ = None; + let mut index__ = None; + while let Some(k) = 
map_.next_key()? { + match k { + GeneratedField::DataType => { + if data_type__.is_some() { + return Err(serde::de::Error::duplicate_field("dataType")); + } + data_type__ = map_.next_value()?; + } + GeneratedField::Nullable => { + if nullable__.is_some() { + return Err(serde::de::Error::duplicate_field("nullable")); + } + nullable__ = Some(map_.next_value()?); + } + GeneratedField::Index => { + if index__.is_some() { + return Err(serde::de::Error::duplicate_field("index")); + } + index__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(PhysicalScalarSubqueryExprNode { + data_type: data_type__, + nullable: nullable__.unwrap_or_default(), + index: index__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.PhysicalScalarSubqueryExprNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PhysicalScalarUdfNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -21233,6 +21389,114 @@ impl<'de> serde::Deserialize<'de> for RollupNode { deserializer.deserialize_struct("datafusion.RollupNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ScalarSubqueryExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if !self.subqueries.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ScalarSubqueryExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if !self.subqueries.is_empty() { + struct_ser.serialize_field("subqueries", &self.subqueries)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarSubqueryExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + 
"input", + "subqueries", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Subqueries, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "subqueries" => Ok(GeneratedField::Subqueries), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarSubqueryExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ScalarSubqueryExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut subqueries__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Subqueries => { + if subqueries__.is_some() { + return Err(serde::de::Error::duplicate_field("subqueries")); + } + subqueries__ = Some(map_.next_value()?); + } + } + } + Ok(ScalarSubqueryExecNode { + input: input__, + subqueries: subqueries__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.ScalarSubqueryExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarUdfExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ff9133b1ced5b..32d8b60034074 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1102,7 +1102,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39" )] pub physical_plan_type: ::core::option::Option, } @@ -1186,6 +1186,8 @@ pub mod physical_plan_node { Buffer(::prost::alloc::boxed::Box), #[prost(message, tag = "38")] ArrowScan(super::ArrowScanExecNode), + #[prost(message, tag = "39")] + ScalarSubquery(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1317,7 +1319,7 @@ pub struct PhysicalExprNode { pub expr_id: ::core::option::Option, #[prost( oneof = "physical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21, 22" )] pub 
expr_type: ::core::option::Option, } @@ -1370,6 +1372,8 @@ pub mod physical_expr_node { UnknownColumn(super::UnknownColumn), #[prost(message, tag = "21")] HashExpr(super::PhysicalHashExprNode), + #[prost(message, tag = "22")] + ScalarSubquery(super::PhysicalScalarSubqueryExprNode), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -2198,6 +2202,22 @@ pub struct BufferExecNode { #[prost(uint64, tag = "2")] pub capacity: u64, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarSubqueryExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "2")] + pub subqueries: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PhysicalScalarSubqueryExprNode { + #[prost(message, optional, tag = "1")] + pub data_type: ::core::option::Option, + #[prost(bool, tag = "2")] + pub nullable: bool, + #[prost(uint32, tag = "3")] + pub index: u32, +} /// Identifies a built-in file format supported by DataFusion. /// Used by DefaultLogicalExtensionCodec to serialize/deserialize /// FileFormatFactory instances (e.g. in CopyTo plans). 
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 73a347331cb6b..c01714e765e29 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -39,6 +39,7 @@ use datafusion_execution::{FunctionRegistry, TaskContext}; use datafusion_expr::WindowFunctionDefinition; use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::projection::{ProjectionExpr, ProjectionExprs}; +use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion_physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, Literal, @@ -240,7 +241,7 @@ pub fn parse_physical_expr( ctx, input_schema, codec, - &DefaultPhysicalProtoConverter {}, + &DefaultPhysicalProtoConverter::default(), ) } @@ -490,6 +491,28 @@ pub fn parse_physical_expr_with_converter( hash_expr.description.clone(), )) } + ExprType::ScalarSubquery(sq) => { + let data_type: arrow::datatypes::DataType = sq + .data_type + .as_ref() + .ok_or_else(|| { + proto_error("Missing data_type in PhysicalScalarSubqueryExprNode") + })? + .try_into()?; + // Use the results container from the converter if available + // (set by ScalarSubqueryExec deserialization), otherwise fall + // back to an empty placeholder for standalone expression + // deserialization. 
+ let results = proto_converter + .scalar_subquery_results() + .unwrap_or_else(|| Arc::new(vec![])); + Arc::new(ScalarSubqueryExpr::new( + data_type, + sq.nullable, + sq.index as usize, + results, + )) + } ExprType::Extension(extension) => { let inputs: Vec> = extension .inputs diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 0f37e9ad3f942..f746df2ecc2fb 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -51,6 +51,7 @@ use datafusion_datasource_parquet::source::ParquetSource; #[cfg(feature = "parquet")] use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{FunctionRegistry, TaskContext}; +use datafusion_expr::execution_props::ScalarSubqueryResults; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use datafusion_functions_table::generate_series::{ Empty, GenSeriesArgs, GenerateSeriesTable, GenericSeriesState, TimestampValue, @@ -83,6 +84,7 @@ use datafusion_physical_plan::metrics::{MetricCategory, MetricType}; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::scalar_subquery::{ScalarSubqueryExec, ScalarSubqueryLink}; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::{InterleaveExec, UnionExec}; @@ -145,7 +147,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { self.try_into_physical_plan_with_converter( ctx, codec, - &DefaultPhysicalProtoConverter {}, + &DefaultPhysicalProtoConverter::default(), ) } @@ -159,7 +161,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { Self::try_from_physical_plan_with_converter( plan, codec, - &DefaultPhysicalProtoConverter {}, + &DefaultPhysicalProtoConverter::default(), 
) } } @@ -321,6 +323,8 @@ impl protobuf::PhysicalPlanNode { PhysicalPlanType::Buffer(buffer) => { self.try_into_buffer_physical_plan(buffer, ctx, codec, proto_converter) } + PhysicalPlanType::ScalarSubquery(sq) => self + .try_into_scalar_subquery_physical_plan(sq, ctx, codec, proto_converter), } } @@ -569,6 +573,14 @@ impl protobuf::PhysicalPlanNode { ); } + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_scalar_subquery_exec( + exec, + codec, + proto_converter, + ); + } + let mut buf: Vec = vec![]; match codec.try_encode(Arc::clone(&plan_clone), &mut buf) { Ok(_) => { @@ -2251,6 +2263,46 @@ impl protobuf::PhysicalPlanNode { Ok(Arc::new(BufferExec::new(input, buffer.capacity as usize))) } + fn try_into_scalar_subquery_physical_plan( + &self, + sq: &protobuf::ScalarSubqueryExecNode, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + // Create the results container upfront — we know the count from the + // proto. Making it available on the converter before deserializing + // the input plan allows ScalarSubqueryExpr nodes to pick up the + // shared container at construction time. + let results = Arc::new( + (0..sq.subqueries.len()) + .map(|_| std::sync::OnceLock::new()) + .collect::>(), + ); + let prev = + proto_converter.set_scalar_subquery_results(Some(Arc::clone(&results))); + let input = into_physical_plan(&sq.input, ctx, codec, proto_converter); + // Restore previous state before propagating errors, so nested + // ScalarSubqueryExec deserialization doesn't see stale state. 
+ proto_converter.set_scalar_subquery_results(prev); + let input: Arc = input?; + + let subqueries: Vec = sq + .subqueries + .iter() + .enumerate() + .map(|(index, sq_plan)| { + let plan = sq_plan.try_into_physical_plan_with_converter( + ctx, + codec, + proto_converter, + )?; + Ok(ScalarSubqueryLink { plan, index }) + }) + .collect::>>()?; + Ok(Arc::new(ScalarSubqueryExec::new(input, subqueries, results))) + } + fn try_from_explain_exec( exec: &ExplainExec, _codec: &dyn PhysicalExtensionCodec, @@ -3645,6 +3697,38 @@ impl protobuf::PhysicalPlanNode { ))), }) } + + fn try_from_scalar_subquery_exec( + exec: &ScalarSubqueryExec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(exec.input()), + codec, + proto_converter, + )?; + let subqueries = exec + .subqueries() + .iter() + .map(|sq| { + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(&sq.plan), + codec, + proto_converter, + ) + }) + .collect::>>()?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::ScalarSubquery(Box::new( + protobuf::ScalarSubqueryExecNode { + input: Some(Box::new(input)), + subqueries, + }, + ))), + }) + } } pub trait AsExecutionPlan: Debug + Send + Sync + Clone { @@ -3777,6 +3861,24 @@ pub trait PhysicalProtoConverterExtension { expr: &Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result; + + /// Sets the scalar subquery results container for expression deserialization. + /// Returns the previous value (for save/restore around nested subqueries). + /// + /// During `ScalarSubqueryExec` deserialization, this is called before + /// deserializing the input plan so that `ScalarSubqueryExpr` nodes can + /// pick up the shared results container at construction time. 
+ fn set_scalar_subquery_results( + &self, + _results: Option, + ) -> Option { + None + } + + /// Returns the current scalar subquery results container, if any. + fn scalar_subquery_results(&self) -> Option { + None + } } /// DataEncoderTuple captures the position of the encoder @@ -3792,7 +3894,11 @@ struct DataEncoderTuple { pub blob: Vec, } -pub struct DefaultPhysicalProtoConverter; +#[derive(Default)] +pub struct DefaultPhysicalProtoConverter { + scalar_subquery_results: RefCell>, +} + impl PhysicalProtoConverterExtension for DefaultPhysicalProtoConverter { fn proto_to_execution_plan( &self, @@ -3839,6 +3945,17 @@ impl PhysicalProtoConverterExtension for DefaultPhysicalProtoConverter { ) -> Result { serialize_physical_expr_with_converter(expr, codec, self) } + + fn set_scalar_subquery_results( + &self, + results: Option, + ) -> Option { + self.scalar_subquery_results.replace(results) + } + + fn scalar_subquery_results(&self) -> Option { + self.scalar_subquery_results.borrow().clone() + } } /// Internal serializer that adds expr_id to expressions. @@ -3921,6 +4038,10 @@ impl PhysicalProtoConverterExtension for DeduplicatingSerializer { struct DeduplicatingDeserializer { /// Cache mapping expr_id to deserialized expressions. cache: RefCell>>, + /// Scalar subquery results container for the current deserialization scope. + /// Set by `ScalarSubqueryExec` deserialization so that `ScalarSubqueryExpr` + /// nodes in the input plan can pick up the shared results container. 
+ scalar_subquery_results: RefCell>, } impl PhysicalProtoConverterExtension for DeduplicatingDeserializer { @@ -3981,6 +4102,17 @@ impl PhysicalProtoConverterExtension for DeduplicatingDeserializer { ) -> Result { internal_err!("DeduplicatingDeserializer cannot serialize physical expressions") } + + fn set_scalar_subquery_results( + &self, + results: Option, + ) -> Option { + self.scalar_subquery_results.replace(results) + } + + fn scalar_subquery_results(&self) -> Option { + self.scalar_subquery_results.borrow().clone() + } } /// A proto converter that adds expression deduplication during serialization diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 990a54cf94c7a..e41ea85b93dcd 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -32,6 +32,7 @@ use datafusion_datasource_json::file_format::JsonSink; use datafusion_datasource_parquet::file_format::ParquetSink; use datafusion_expr::WindowFrame; use datafusion_physical_expr::ScalarFunctionExpr; +use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; @@ -242,7 +243,7 @@ pub fn serialize_physical_expr( serialize_physical_expr_with_converter( value, codec, - &DefaultPhysicalProtoConverter {}, + &DefaultPhysicalProtoConverter::default(), ) } @@ -504,6 +505,17 @@ pub fn serialize_physical_expr_with_converter( }, )), }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarSubquery( + protobuf::PhysicalScalarSubqueryExprNode { + data_type: Some((&expr.data_type(&Schema::empty())?).try_into()?), + nullable: expr.nullable(&Schema::empty())?, + 
index: expr.index() as u32, + }, + )), + }) } else { let mut buf: Vec = vec![]; match codec.try_encode_expr(&value, &mut buf) { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 2b744a8fbdd19..2d02640a34442 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::collections::HashMap; use std::fmt::{Display, Formatter}; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, OnceLock, RwLock}; use std::vec; use arrow::array::RecordBatch; @@ -54,6 +54,7 @@ use datafusion::physical_expr::window::{SlidingAggregateWindowExpr, StandardWind use datafusion::physical_expr::{ LexOrdering, PhysicalSortRequirement, ScalarFunctionExpr, }; +use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion::physical_plan::aggregates::{ AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, }; @@ -66,6 +67,7 @@ use datafusion::physical_plan::expressions::{ BinaryExpr, Column, NotExpr, PhysicalSortExpr, binary, cast, col, in_list, like, lit, }; use datafusion::physical_plan::filter::{FilterExec, FilterExecBuilder}; +use datafusion::physical_plan::scalar_subquery::{ScalarSubqueryExec, ScalarSubqueryLink}; use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, @@ -136,7 +138,7 @@ use crate::cases::{ fn roundtrip_test(exec_plan: Arc) -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); roundtrip_test_and_return(exec_plan, &ctx, &codec, &proto_converter)?; Ok(()) } @@ -183,7 +185,7 @@ fn roundtrip_test_with_context( ctx: &SessionContext, ) -> Result<()> { let codec = DefaultPhysicalExtensionCodec {}; - 
let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); roundtrip_test_and_return(exec_plan, ctx, &codec, &proto_converter)?; Ok(()) } @@ -192,7 +194,7 @@ fn roundtrip_test_with_context( /// query results are identical. async fn roundtrip_test_sql_with_context(sql: &str, ctx: &SessionContext) -> Result<()> { let codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); let initial_plan = ctx.sql(sql).await?.create_physical_plan().await?; roundtrip_test_and_return(initial_plan, ctx, &codec, &proto_converter)?; @@ -964,7 +966,7 @@ fn roundtrip_parquet_exec_attaches_cached_reader_factory_after_roundtrip() -> Re let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); let roundtripped = roundtrip_test_and_return(exec_plan, &ctx, &codec, &proto_converter)?; @@ -1192,7 +1194,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { exec_plan, &ctx, &CustomPhysicalExtensionCodec {}, - &DefaultPhysicalProtoConverter {}, + &DefaultPhysicalProtoConverter::default(), )?; Ok(()) } @@ -1399,7 +1401,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1447,7 +1449,7 @@ fn roundtrip_udwf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); roundtrip_test_and_return(window, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1519,7 +1521,7 
@@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1647,7 +1649,7 @@ fn roundtrip_csv_sink() -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); let roundtrip_plan = roundtrip_test_and_return( Arc::new(DataSinkExec::new(input, data_sink, Some(sort_order))), @@ -2856,7 +2858,7 @@ fn test_backward_compatibility_no_expr_id() -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); // Should deserialize without error let result = proto_converter.proto_to_physical_expr( @@ -3193,3 +3195,87 @@ fn roundtrip_lead_with_default_value() -> Result<()> { true, )?)) } + +/// Verify that ScalarSubqueryExpr nodes in the input plan are connected to the +/// same shared results container as ScalarSubqueryExec after a proto round-trip. +#[test] +fn roundtrip_scalar_subquery_exec() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + ])); + + // Create a shared results container with one slot. + let results = Arc::new(vec![OnceLock::new()]); + + // Build the input plan: a filter whose predicate references the + // scalar subquery result via ScalarSubqueryExpr. 
+ let sq_expr = Arc::new(ScalarSubqueryExpr::new( + DataType::Int64, + true, + 0, + Arc::clone(&results), + )); + let predicate = binary(col("a", &schema)?, Operator::Eq, sq_expr, &schema)?; + let filter = FilterExec::try_new(predicate, Arc::new(EmptyExec::new(schema.clone())))?; + + // Build a trivial subquery plan. + let subquery_plan = Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int64, true), + ])))); + + let exec: Arc = Arc::new(ScalarSubqueryExec::new( + Arc::new(filter), + vec![ScalarSubqueryLink { + plan: subquery_plan, + index: 0, + }], + results, + )); + + // Perform the round-trip using DeduplicatingProtoConverter, which + // creates a DeduplicatingDeserializer that threads scalar subquery + // results through expression deserialization. + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let bytes = + physical_plan_to_bytes_with_proto_converter(Arc::clone(&exec), &codec, &converter)?; + let ctx = SessionContext::new(); + let deserialized = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &converter, + )?; + + // Verify the deserialized ScalarSubqueryExec's results container is + // shared with the ScalarSubqueryExpr in the input plan. + let sq_exec = deserialized + .as_any() + .downcast_ref::() + .expect("expected ScalarSubqueryExec"); + let exec_results = sq_exec.results(); + + // Walk the input plan to find the ScalarSubqueryExpr and verify it + // points to the same results container. 
+ let filter_exec = sq_exec + .input() + .as_any() + .downcast_ref::() + .expect("expected FilterExec"); + let binary_expr = filter_exec + .predicate() + .as_any() + .downcast_ref::() + .expect("expected BinaryExpr"); + let deserialized_sq_expr = binary_expr + .right() + .as_any() + .downcast_ref::() + .expect("expected ScalarSubqueryExpr"); + + assert!( + Arc::ptr_eq(exec_results, deserialized_sq_expr.results()), + "ScalarSubqueryExpr should share the same results container as ScalarSubqueryExec" + ); + Ok(()) +} diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt index f3836b23ec321..fc090102c4066 100644 --- a/datafusion/sqllogictest/test_files/metadata.slt +++ b/datafusion/sqllogictest/test_files/metadata.slt @@ -47,16 +47,12 @@ select id, name from table_with_metadata; NULL bar 3 baz -query I rowsort +query error DataFusion error: Execution error: Scalar subquery returned more than one row SELECT ( SELECT id FROM table_with_metadata ) UNION ( SELECT id FROM table_with_metadata ); ----- -1 -3 -NULL query I rowsort SELECT "data"."id" diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index e5ca9d674e19f..5054ce85d1a76 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -689,10 +689,10 @@ query TT explain SELECT t1_id, (SELECT t2_id FROM t2 limit 0) FROM t1 ---- logical_plan -01)Projection: t1.t1_id, __scalar_sq_1.t2_id AS t2_id -02)--Left Join: -03)----TableScan: t1 projection=[t1_id] -04)----EmptyRelation: rows=0 +01)Projection: t1.t1_id, () +02)--Subquery: +03)----EmptyRelation: rows=0 +04)--TableScan: t1 projection=[t1_id] query II rowsort SELECT t1_id, (SELECT t2_id FROM t2 limit 0) FROM t1 @@ -723,26 +723,28 @@ query TT explain select (select count(*) from t1) as b ---- logical_plan -01)Projection: __scalar_sq_1.count(*) AS b -02)--SubqueryAlias: __scalar_sq_1 +01)Projection: 
() AS b +02)--Subquery: 03)----Projection: count(Int64(1)) AS count(*) 04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -05)--------TableScan: t1 projection=[] +05)--------TableScan: t1 +06)--EmptyRelation: rows=1 #simple_uncorrelated_scalar_subquery2 query TT explain select (select count(*) from t1) as b, (select count(1) from t2) ---- logical_plan -01)Projection: __scalar_sq_1.count(*) AS b, __scalar_sq_2.count(Int64(1)) AS count(Int64(1)) -02)--Left Join: -03)----SubqueryAlias: __scalar_sq_1 -04)------Projection: count(Int64(1)) AS count(*) -05)--------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -06)----------TableScan: t1 projection=[] -07)----SubqueryAlias: __scalar_sq_2 +01)Projection: () AS b, () +02)--Subquery: +03)----Projection: count(Int64(1)) AS count(*) +04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +05)--------TableScan: t1 +06)--Subquery: +07)----Projection: count(Int64(1)) 08)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -09)--------TableScan: t2 projection=[] +09)--------TableScan: t2 +10)--EmptyRelation: rows=1 statement ok set datafusion.explain.logical_plan_only = false; @@ -751,22 +753,24 @@ query TT explain select (select count(*) from t1) as b, (select count(1) from t2) ---- logical_plan -01)Projection: __scalar_sq_1.count(*) AS b, __scalar_sq_2.count(Int64(1)) AS count(Int64(1)) -02)--Left Join: -03)----SubqueryAlias: __scalar_sq_1 -04)------Projection: count(Int64(1)) AS count(*) -05)--------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -06)----------TableScan: t1 projection=[] -07)----SubqueryAlias: __scalar_sq_2 +01)Projection: () AS b, () +02)--Subquery: +03)----Projection: count(Int64(1)) AS count(*) +04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +05)--------TableScan: t1 +06)--Subquery: +07)----Projection: count(Int64(1)) 08)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -09)--------TableScan: t2 projection=[] +09)--------TableScan: t2 +10)--EmptyRelation: rows=1 
physical_plan -01)ProjectionExec: expr=[count(*)@0 as b, count(Int64(1))@1 as count(Int64(1))] -02)--NestedLoopJoinExec: join_type=Left -03)----ProjectionExec: expr=[4 as count(*)] -04)------PlaceholderRowExec -05)----ProjectionExec: expr=[4 as count(Int64(1))] -06)------PlaceholderRowExec +01)ScalarSubqueryExec: subqueries=2 +02)--ProjectionExec: expr=[scalar_subquery() as b, scalar_subquery() as count(Int64(1))] +03)----PlaceholderRowExec +04)--ProjectionExec: expr=[4 as count(*)] +05)----PlaceholderRowExec +06)--ProjectionExec: expr=[4 as count(Int64(1))] +07)----PlaceholderRowExec statement ok set datafusion.explain.logical_plan_only = true; @@ -1671,3 +1675,196 @@ drop table employees; statement count 0 drop table project_assignments; + +############# +## Uncorrelated scalar subquery row-count semantics +## A scalar subquery must return at most one row; returning more is an error. +############# + +statement ok +CREATE TABLE sq_values(v INT) AS VALUES (1), (2), (3); + +statement ok +CREATE TABLE sq_main(x INT) AS VALUES (10), (20); + +statement ok +CREATE TABLE sq_empty(v INT) AS VALUES (1) LIMIT 0; + +# Scalar subquery returning multiple rows in SELECT position → error +query error DataFusion error: Execution error: Scalar subquery returned more than one row +SELECT (SELECT v FROM sq_values); + +# Scalar subquery returning multiple rows in WHERE position → error +query error DataFusion error: Execution error: Scalar subquery returned more than one row +SELECT x FROM sq_main WHERE x > (SELECT v FROM sq_values); + +# Scalar subquery returning multiple rows as a function argument → error +query error DataFusion error: Execution error: Scalar subquery returned more than one row +SELECT x + (SELECT v FROM sq_values) FROM sq_main; + +# Scalar subquery returning exactly one row → success +query I +SELECT (SELECT v FROM sq_values LIMIT 1); +---- +1 + +# Scalar subquery returning exactly one row in WHERE → success +query I rowsort +SELECT x FROM sq_main WHERE x > 
(SELECT v FROM sq_values LIMIT 1); +---- +10 +20 + +# Scalar subquery returning zero rows → NULL +query I +SELECT (SELECT v FROM sq_empty); +---- +NULL + +# Scalar subquery returning zero rows in arithmetic → NULL propagation +query I +SELECT x + (SELECT v FROM sq_empty) FROM sq_main; +---- +NULL +NULL + +# Scalar subquery returning zero rows in WHERE comparison → no matching rows +query I +SELECT x FROM sq_main WHERE x > (SELECT v FROM sq_empty); +---- + +# Aggregated subquery always returns one row, even on empty input → success +query I +SELECT (SELECT count(*) FROM sq_empty); +---- +0 + +# Aggregated subquery on multi-row table → success (aggregation reduces to 1 row) +query I +SELECT (SELECT max(v) FROM sq_values); +---- +3 + +# Multiple scalar subqueries, one returns multiple rows → error +query error DataFusion error: Execution error: Scalar subquery returned more than one row +SELECT (SELECT count(*) FROM sq_empty), (SELECT v FROM sq_values); + +############# +## Uncorrelated scalar subqueries in various expression contexts +############# + +# HAVING clause with uncorrelated scalar subquery +query II rowsort +SELECT x, count(*) AS cnt FROM sq_main GROUP BY x +HAVING count(*) > (SELECT min(v) FROM sq_values); +---- + +# CASE WHEN with uncorrelated scalar subquery as condition +query T rowsort +SELECT CASE WHEN x > (SELECT min(v) FROM sq_values) + THEN 'big' ELSE 'small' END AS label +FROM sq_main; +---- +big +big + +# ORDER BY with uncorrelated scalar subquery +query I +SELECT x FROM sq_main ORDER BY x + (SELECT max(v) FROM sq_values); +---- +10 +20 + +# Aggregate function argument containing uncorrelated scalar subquery +query I +SELECT sum(x + (SELECT max(v) FROM sq_values)) AS s FROM sq_main; +---- +36 + +# JOIN ON condition with uncorrelated scalar subquery +query II rowsort +SELECT l.x, r.x AS rx +FROM sq_main AS l JOIN sq_main AS r +ON l.x = r.x + (SELECT min(v) FROM sq_values); +---- + +# Nested uncorrelated-in-uncorrelated scalar subquery. 
+query I +SELECT (SELECT max(v) + (SELECT min(v) FROM sq_values) FROM sq_values); +---- +4 + +# Verify nested subqueries are not hoisted: the root ScalarSubqueryExec +# should manage only the outer subquery (subqueries=1), not both. +query TT +EXPLAIN SELECT (SELECT max(v) + (SELECT min(v) FROM sq_values) FROM sq_values); +---- +logical_plan +01)Projection: () +02)--Subquery: +03)----Projection: max(sq_values.v) + () +04)------Subquery: +05)--------Projection: min(sq_values.v) +06)----------Aggregate: groupBy=[[]], aggr=[[min(sq_values.v)]] +07)------------TableScan: sq_values +08)------Aggregate: groupBy=[[]], aggr=[[max(sq_values.v)]] +09)--------TableScan: sq_values +10)--EmptyRelation: rows=1 +physical_plan +01)ScalarSubqueryExec: subqueries=1 +02)--ProjectionExec: expr=[scalar_subquery() as max(sq_values.v) + min(sq_values.v)] +03)----PlaceholderRowExec +04)--ScalarSubqueryExec: subqueries=1 +05)----ProjectionExec: expr=[max(sq_values.v)@0 + scalar_subquery() as max(sq_values.v) + min(sq_values.v)] +06)------AggregateExec: mode=Single, gby=[], aggr=[max(sq_values.v)] +07)--------DataSourceExec: partitions=1, partition_sizes=[1] +08)----AggregateExec: mode=Single, gby=[], aggr=[min(sq_values.v)] +09)------DataSourceExec: partitions=1, partition_sizes=[1] + +# CTE as source inside uncorrelated scalar subquery +query I +SELECT (SELECT s FROM (WITH cte AS (SELECT max(v) AS s FROM sq_values) SELECT s FROM cte)); +---- +3 + +# Window function with uncorrelated scalar subquery +query II rowsort +SELECT x, sum(x + (SELECT max(v) FROM sq_values)) OVER () AS win_sum FROM sq_main; +---- +10 36 +20 36 + +# Duplicate uncorrelated scalar subqueries only appear in the query plan once +statement ok +set datafusion.explain.logical_plan_only = false; + +query TT +explain SELECT (SELECT max(v) FROM sq_values) + (SELECT max(v) FROM sq_values) AS doubled; +---- +logical_plan +01)Projection: __common_expr_1 + __common_expr_1 AS doubled +02)--Projection: () AS __common_expr_1 
+03)----Subquery: +04)------Projection: max(sq_values.v) +05)--------Aggregate: groupBy=[[]], aggr=[[max(sq_values.v)]] +06)----------TableScan: sq_values +07)----EmptyRelation: rows=1 +physical_plan +01)ScalarSubqueryExec: subqueries=1 +02)--ProjectionExec: expr=[__common_expr_1@0 + __common_expr_1@0 as doubled] +03)----ProjectionExec: expr=[scalar_subquery() as __common_expr_1] +04)------PlaceholderRowExec +05)--AggregateExec: mode=Single, gby=[], aggr=[max(sq_values.v)] +06)----DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +RESET datafusion.explain.logical_plan_only; + +statement count 0 +DROP TABLE sq_values; + +statement count 0 +DROP TABLE sq_main; + +statement count 0 +DROP TABLE sq_empty; From 6b4f5c02a60af9815ca783cb9e9357be2f4ff1d8 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Sun, 29 Mar 2026 15:02:26 -0400 Subject: [PATCH 02/28] cargo fmt --- .../core/benches/scalar_subquery_sql.rs | 13 ++++----- datafusion/proto/src/physical_plan/mod.rs | 4 ++- .../tests/cases/roundtrip_physical_plan.rs | 29 ++++++++++++------- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/datafusion/core/benches/scalar_subquery_sql.rs b/datafusion/core/benches/scalar_subquery_sql.rs index bd5d25a96123f..3642bdd1a0150 100644 --- a/datafusion/core/benches/scalar_subquery_sql.rs +++ b/datafusion/core/benches/scalar_subquery_sql.rs @@ -39,9 +39,7 @@ fn query(ctx: &SessionContext, rt: &Runtime, sql: &str) { } fn create_context(num_rows: usize) -> Result { - let schema = Arc::new(Schema::new(vec![ - Field::new("x", DataType::Int64, false), - ])); + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)])); let batch_size = 4096; let batches = (0..num_rows / batch_size) @@ -54,16 +52,17 @@ fn create_context(num_rows: usize) -> Result { .collect::>(); // Small lookup table for the subquery to read from. 
- let sq_schema = Arc::new(Schema::new(vec![ - Field::new("v", DataType::Int64, false), - ])); + let sq_schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)])); let sq_batch = RecordBatch::try_new( sq_schema.clone(), vec![Arc::new(Int64Array::from(vec![10, 20, 30]))], )?; let ctx = SessionContext::new(); - ctx.register_table("main_t", Arc::new(MemTable::try_new(schema, vec![batches])?))?; + ctx.register_table( + "main_t", + Arc::new(MemTable::try_new(schema, vec![batches])?), + )?; ctx.register_table( "lookup", Arc::new(MemTable::try_new(sq_schema, vec![vec![sq_batch]])?), diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index f746df2ecc2fb..a28fbb463ddc7 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -2300,7 +2300,9 @@ impl protobuf::PhysicalPlanNode { Ok(ScalarSubqueryLink { plan, index }) }) .collect::>>()?; - Ok(Arc::new(ScalarSubqueryExec::new(input, subqueries, results))) + Ok(Arc::new(ScalarSubqueryExec::new( + input, subqueries, results, + ))) } fn try_from_explain_exec( diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 2d02640a34442..fc817ebda1e48 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -54,7 +54,6 @@ use datafusion::physical_expr::window::{SlidingAggregateWindowExpr, StandardWind use datafusion::physical_expr::{ LexOrdering, PhysicalSortRequirement, ScalarFunctionExpr, }; -use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion::physical_plan::aggregates::{ AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, }; @@ -67,7 +66,6 @@ use datafusion::physical_plan::expressions::{ BinaryExpr, Column, NotExpr, PhysicalSortExpr, binary, cast, col, in_list, like, lit, }; use datafusion::physical_plan::filter::{FilterExec, 
FilterExecBuilder}; -use datafusion::physical_plan::scalar_subquery::{ScalarSubqueryExec, ScalarSubqueryLink}; use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, @@ -77,6 +75,9 @@ use datafusion::physical_plan::metrics::MetricType; use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion::physical_plan::repartition::RepartitionExec; +use datafusion::physical_plan::scalar_subquery::{ + ScalarSubqueryExec, ScalarSubqueryLink, +}; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; @@ -112,6 +113,7 @@ use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::min_max::max_udaf; use datafusion_functions_aggregate::nth_value::nth_value_udaf; use datafusion_functions_aggregate::string_agg::string_agg_udaf; +use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion_proto::bytes::{ physical_plan_from_bytes_with_proto_converter, physical_plan_to_bytes_with_proto_converter, @@ -3200,9 +3202,7 @@ fn roundtrip_lead_with_default_value() -> Result<()> { /// same shared results container as ScalarSubqueryExec after a proto round-trip. #[test] fn roundtrip_scalar_subquery_exec() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, false), - ])); + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); // Create a shared results container with one slot. 
let results = Arc::new(vec![OnceLock::new()]); @@ -3216,12 +3216,16 @@ fn roundtrip_scalar_subquery_exec() -> Result<()> { Arc::clone(&results), )); let predicate = binary(col("a", &schema)?, Operator::Eq, sq_expr, &schema)?; - let filter = FilterExec::try_new(predicate, Arc::new(EmptyExec::new(schema.clone())))?; + let filter = + FilterExec::try_new(predicate, Arc::new(EmptyExec::new(schema.clone())))?; // Build a trivial subquery plan. - let subquery_plan = Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![ - Field::new("x", DataType::Int64, true), - ])))); + let subquery_plan = + Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Int64, + true, + )])))); let exec: Arc = Arc::new(ScalarSubqueryExec::new( Arc::new(filter), @@ -3237,8 +3241,11 @@ fn roundtrip_scalar_subquery_exec() -> Result<()> { // results through expression deserialization. let codec = DefaultPhysicalExtensionCodec {}; let converter = DeduplicatingProtoConverter {}; - let bytes = - physical_plan_to_bytes_with_proto_converter(Arc::clone(&exec), &codec, &converter)?; + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec), + &codec, + &converter, + )?; let ctx = SessionContext::new(); let deserialized = physical_plan_from_bytes_with_proto_converter( bytes.as_ref(), From 1412ab17511100ed9ef2d45199ea054b23187105 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Sun, 29 Mar 2026 16:05:25 -0400 Subject: [PATCH 03/28] Properly wait for subquery exec to complete before exec'ing main input --- .../physical-plan/src/scalar_subquery.rs | 53 ++++++++----------- 1 file changed, 23 insertions(+), 30 deletions(-) diff --git a/datafusion/physical-plan/src/scalar_subquery.rs b/datafusion/physical-plan/src/scalar_subquery.rs index 62ad72c27b2f6..bea3cf4a5e892 100644 --- a/datafusion/physical-plan/src/scalar_subquery.rs +++ b/datafusion/physical-plan/src/scalar_subquery.rs @@ -35,11 +35,11 @@ use datafusion_expr::execution_props::ScalarSubqueryResults; use 
datafusion_physical_expr::PhysicalExpr; use crate::execution_plan::{CardinalityEffect, ExecutionPlan, PlanProperties}; -use crate::joins::utils::OnceAsync; use crate::stream::RecordBatchStreamAdapter; use crate::{DisplayAs, DisplayFormatType, SendableRecordBatchStream}; -use futures::stream::StreamExt; +use futures::StreamExt; +use futures::TryStreamExt; /// Links a scalar subquery's execution plan to its index in the shared results /// container. The [`ScalarSubqueryExec`] that owns these links populates @@ -83,8 +83,9 @@ pub struct ScalarSubqueryExec { input: Arc, /// Subquery plans and their result indexes. subqueries: Vec, - /// Shared one-time async computation of subquery results. - subquery_future: Arc>>, + /// Ensures subqueries are executed exactly once, even when multiple + /// partitions call `execute()` concurrently. + subquery_barrier: Arc>, /// Shared results container; same instance held by ScalarSubqueryExpr nodes. results: ScalarSubqueryResults, /// Cached plan properties (copied from input). @@ -101,7 +102,7 @@ impl ScalarSubqueryExec { Self { input, subqueries, - subquery_future: Arc::default(), + subquery_barrier: Arc::new(tokio::sync::OnceCell::new()), results, cache, } @@ -195,35 +196,27 @@ impl ExecutionPlan for ScalarSubqueryExec { ) -> Result { let subqueries = self.subqueries.clone(); let results = Arc::clone(&self.results); - let ctx = Arc::clone(&context); + let barrier = Arc::clone(&self.subquery_barrier); + let input = Arc::clone(&self.input); + let schema = self.schema(); - // Use OnceAsync to ensure all subqueries are executed exactly once, - // even when multiple partitions call execute() concurrently. - let mut once_fut = self.subquery_future.try_once(move || { - Ok(async move { execute_subqueries(subqueries, results, ctx).await }) - })?; - - let input = self.input.execute(partition, context)?; - let schema = input.schema(); - - // Create a stream that first waits for subquery results, then - // delegates to the inner input stream. 
Ok(Box::pin(RecordBatchStreamAdapter::new( - Arc::clone(&schema), + schema, futures::stream::once(async move { - // Wait for subquery execution to complete. - let result: Result<()> = - std::future::poll_fn(|cx| once_fut.get(cx).map(|r| r.map(|_| ()))) - .await; - result - }) - .filter_map(|result| async move { - match result { - Ok(()) => None, // Subqueries done, proceed to input - Err(e) => Some(Err(e)), // Propagate error - } + // Execute all subqueries exactly once + barrier + .get_or_try_init(|| async { + execute_subqueries(subqueries, results, context.clone()) + .await + .map(|_| ()) + }) + .await?; + + // Now that the subqueries have been computed, we can safely + // start the main input + input.execute(partition, context) }) - .chain(input), + .try_flatten(), ))) } From cedfa5c0e80d8146603768ba62c8932cafa7fab7 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Sun, 29 Mar 2026 17:14:34 -0400 Subject: [PATCH 04/28] Better fix for async exec issue --- .../physical-plan/src/scalar_subquery.rs | 142 +++++++++++++++--- 1 file changed, 122 insertions(+), 20 deletions(-) diff --git a/datafusion/physical-plan/src/scalar_subquery.rs b/datafusion/physical-plan/src/scalar_subquery.rs index bea3cf4a5e892..3db81669ffd27 100644 --- a/datafusion/physical-plan/src/scalar_subquery.rs +++ b/datafusion/physical-plan/src/scalar_subquery.rs @@ -35,6 +35,7 @@ use datafusion_expr::execution_props::ScalarSubqueryResults; use datafusion_physical_expr::PhysicalExpr; use crate::execution_plan::{CardinalityEffect, ExecutionPlan, PlanProperties}; +use crate::joins::utils::{OnceAsync, OnceFut}; use crate::stream::RecordBatchStreamAdapter; use crate::{DisplayAs, DisplayFormatType, SendableRecordBatchStream}; @@ -83,9 +84,8 @@ pub struct ScalarSubqueryExec { input: Arc, /// Subquery plans and their result indexes. subqueries: Vec, - /// Ensures subqueries are executed exactly once, even when multiple - /// partitions call `execute()` concurrently. 
- subquery_barrier: Arc>, + /// Shared one-time async computation of subquery results. + subquery_future: Arc>, /// Shared results container; same instance held by ScalarSubqueryExpr nodes. results: ScalarSubqueryResults, /// Cached plan properties (copied from input). @@ -102,7 +102,7 @@ impl ScalarSubqueryExec { Self { input, subqueries, - subquery_barrier: Arc::new(tokio::sync::OnceCell::new()), + subquery_future: Arc::default(), results, cache, } @@ -196,24 +196,22 @@ impl ExecutionPlan for ScalarSubqueryExec { ) -> Result { let subqueries = self.subqueries.clone(); let results = Arc::clone(&self.results); - let barrier = Arc::clone(&self.subquery_barrier); + let subquery_ctx = Arc::clone(&context); + let mut subquery_future = self.subquery_future.try_once(move || { + Ok(async move { execute_subqueries(subqueries, results, subquery_ctx).await }) + })?; let input = Arc::clone(&self.input); let schema = self.schema(); Ok(Box::pin(RecordBatchStreamAdapter::new( schema, futures::stream::once(async move { - // Execute all subqueries exactly once - barrier - .get_or_try_init(|| async { - execute_subqueries(subqueries, results, context.clone()) - .await - .map(|_| ()) - }) - .await?; - - // Now that the subqueries have been computed, we can safely - // start the main input + // Execute all subqueries exactly once, even when multiple + // partitions call execute() concurrently. + wait_for_subqueries(&mut subquery_future).await?; + + // Now that the subqueries have finished execution, we can + // safely execute the main input input.execute(partition, context) }) .try_flatten(), @@ -246,19 +244,22 @@ impl ExecutionPlan for ScalarSubqueryExec { /// Execute all subquery plans, extract their scalar results, and populate /// the shared results container. 
+async fn wait_for_subqueries(fut: &mut OnceFut<()>) -> Result<()> { + std::future::poll_fn(|cx| fut.get_shared(cx)).await?; + Ok(()) +} + async fn execute_subqueries( subqueries: Vec, results: ScalarSubqueryResults, context: Arc, -) -> Result> { - let mut values = Vec::with_capacity(subqueries.len()); +) -> Result<()> { for sq in &subqueries { let value = execute_scalar_subquery(Arc::clone(&sq.plan), Arc::clone(&context)).await?; let _ = results[sq.index].set(value.clone()); - values.push(value); } - Ok(values) + Ok(()) } /// Execute a single subquery plan and extract the scalar value. @@ -304,11 +305,83 @@ mod tests { use crate::test::{self, TestMemoryExec}; use std::sync::OnceLock; + use std::sync::atomic::{AtomicUsize, Ordering}; + use crate::test::exec::ErrorExec; use arrow::array::{Int32Array, Int64Array}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + #[derive(Debug)] + struct CountingExec { + inner: Arc, + execute_calls: Arc, + } + + impl CountingExec { + fn new(inner: Arc, execute_calls: Arc) -> Self { + Self { + inner, + execute_calls, + } + } + } + + impl DisplayAs for CountingExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "CountingExec") + } + DisplayFormatType::TreeRender => write!(f, ""), + } + } + } + + impl ExecutionPlan for CountingExec { + fn name(&self) -> &'static str { + "CountingExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &Arc { + self.inner.properties() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.inner] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::new( + children.remove(0), + Arc::clone(&self.execute_calls), + ))) + } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } + + fn 
execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + self.execute_calls.fetch_add(1, Ordering::SeqCst); + self.inner.execute(partition, context) + } + } + fn make_subquery_plan(batches: Vec) -> Arc { let schema = batches[0].schema(); TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap() @@ -406,4 +479,33 @@ mod tests { ); Ok(()) } + + #[tokio::test] + async fn test_failed_subquery_is_not_retried() -> Result<()> { + let execute_calls = Arc::new(AtomicUsize::new(0)); + let subquery_plan = Arc::new(CountingExec::new( + Arc::new(ErrorExec::new()), + Arc::clone(&execute_calls), + )); + let results = make_results(1); + let sq = ScalarSubqueryLink { + plan: subquery_plan, + index: 0, + }; + + let main_input = Arc::new(crate::placeholder_row::PlaceholderRowExec::new( + test::aggr_test_schema(), + )); + let exec = ScalarSubqueryExec::new(main_input, vec![sq], results); + + let ctx = Arc::new(TaskContext::default()); + let stream = exec.execute(0, Arc::clone(&ctx))?; + assert!(crate::common::collect(stream).await.is_err()); + + let stream = exec.execute(0, ctx)?; + assert!(crate::common::collect(stream).await.is_err()); + + assert_eq!(execute_calls.load(Ordering::SeqCst), 1); + Ok(()) + } } From d80569fef1a68995159b34b3d1b084228561b25f Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Sun, 29 Mar 2026 19:27:38 -0400 Subject: [PATCH 05/28] Fix doc lint error --- datafusion/physical-expr/src/scalar_subquery.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/scalar_subquery.rs b/datafusion/physical-expr/src/scalar_subquery.rs index 40d5ca66d9ec3..bc6a8ebfbd5cf 100644 --- a/datafusion/physical-expr/src/scalar_subquery.rs +++ b/datafusion/physical-expr/src/scalar_subquery.rs @@ -18,9 +18,7 @@ //! Physical expression for uncorrelated scalar subqueries. //! //! [`ScalarSubqueryExpr`] reads a cached [`ScalarValue`] that is populated -//! at execution time by [`ScalarSubqueryExec`]. -//! -//! 
[`ScalarSubqueryExec`]: datafusion_physical_plan::scalar_subquery::ScalarSubqueryExec +//! at execution time by `ScalarSubqueryExec`. use std::any::Any; use std::fmt; From 9f606fb428f1f0803eddea18fba9f4a68c46066e Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Sun, 29 Mar 2026 21:36:11 -0400 Subject: [PATCH 06/28] Implement logical plan serialization/deserialization for subqueries --- datafusion/proto/proto/datafusion.proto | 24 + datafusion/proto/src/generated/pbjson.rs | 475 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 45 +- .../proto/src/logical_plan/from_proto.rs | 68 ++- datafusion/proto/src/logical_plan/mod.rs | 16 +- datafusion/proto/src/logical_plan/to_proto.rs | 62 ++- 6 files changed, 667 insertions(+), 23 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 82aff7abce12a..36f5f882a4ba3 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -426,6 +426,10 @@ message LogicalExprNode { Unnest unnest = 35; + // Subquery expressions + ScalarSubqueryExprNode scalar_subquery_expr = 36; + InSubqueryExprNode in_subquery_expr = 37; + ExistsExprNode exists_expr = 38; } } @@ -433,6 +437,26 @@ message Wildcard { TableReference qualifier = 1; } +message SubqueryNode { + LogicalPlanNode subquery = 1; + repeated LogicalExprNode outer_ref_columns = 2; +} + +message ScalarSubqueryExprNode { + SubqueryNode subquery = 1; +} + +message InSubqueryExprNode { + LogicalExprNode expr = 1; + SubqueryNode subquery = 2; + bool negated = 3; +} + +message ExistsExprNode { + SubqueryNode subquery = 1; + bool negated = 2; +} + message PlaceholderNode { string id = 1; // We serialize the data type, metadata, and nullability separately to maintain diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 49617d74d8fc6..b32b7f9a63f94 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ 
b/datafusion/proto/src/generated/pbjson.rs @@ -5914,6 +5914,114 @@ impl<'de> serde::Deserialize<'de> for EmptyRelationNode { deserializer.deserialize_struct("datafusion.EmptyRelationNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ExistsExprNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.subquery.is_some() { + len += 1; + } + if self.negated { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ExistsExprNode", len)?; + if let Some(v) = self.subquery.as_ref() { + struct_ser.serialize_field("subquery", v)?; + } + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ExistsExprNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "subquery", + "negated", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Subquery, + Negated, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "subquery" => Ok(GeneratedField::Subquery), + "negated" => Ok(GeneratedField::Negated), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for 
GeneratedVisitor { + type Value = ExistsExprNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ExistsExprNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut subquery__ = None; + let mut negated__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Subquery => { + if subquery__.is_some() { + return Err(serde::de::Error::duplicate_field("subquery")); + } + subquery__ = map_.next_value()?; + } + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); + } + negated__ = Some(map_.next_value()?); + } + } + } + Ok(ExistsExprNode { + subquery: subquery__, + negated: negated__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.ExistsExprNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ExplainExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -9261,6 +9369,131 @@ impl<'de> serde::Deserialize<'de> for InListNode { deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for InSubqueryExprNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.expr.is_some() { + len += 1; + } + if self.subquery.is_some() { + len += 1; + } + if self.negated { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.InSubqueryExprNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if let Some(v) = self.subquery.as_ref() { + struct_ser.serialize_field("subquery", v)?; + } + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for 
InSubqueryExprNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "expr", + "subquery", + "negated", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Expr, + Subquery, + Negated, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "expr" => Ok(GeneratedField::Expr), + "subquery" => Ok(GeneratedField::Subquery), + "negated" => Ok(GeneratedField::Negated), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = InSubqueryExprNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.InSubqueryExprNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut expr__ = None; + let mut subquery__ = None; + let mut negated__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = map_.next_value()?; + } + GeneratedField::Subquery => { + if subquery__.is_some() { + return Err(serde::de::Error::duplicate_field("subquery")); + } + subquery__ = map_.next_value()?; + } + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); + } + negated__ = Some(map_.next_value()?); + } + } + } + Ok(InSubqueryExprNode { + expr: expr__, + subquery: subquery__, + negated: negated__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.InSubqueryExprNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for InsertOp { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -12193,6 +12426,15 @@ impl serde::Serialize for LogicalExprNode { logical_expr_node::ExprType::Unnest(v) => { struct_ser.serialize_field("unnest", v)?; } + logical_expr_node::ExprType::ScalarSubqueryExpr(v) => { + struct_ser.serialize_field("scalarSubqueryExpr", v)?; + } + logical_expr_node::ExprType::InSubqueryExpr(v) => { + struct_ser.serialize_field("inSubqueryExpr", v)?; + } + logical_expr_node::ExprType::ExistsExpr(v) => { + struct_ser.serialize_field("existsExpr", v)?; + } } } struct_ser.end() @@ -12254,6 +12496,12 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "similarTo", "placeholder", "unnest", + "scalar_subquery_expr", + "scalarSubqueryExpr", + "in_subquery_expr", + "inSubqueryExpr", + "exists_expr", + "existsExpr", ]; #[allow(clippy::enum_variant_names)] @@ -12289,6 +12537,9 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { SimilarTo, Placeholder, Unnest, + ScalarSubqueryExpr, + InSubqueryExpr, + ExistsExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12341,6 +12592,9 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "similarTo" 
| "similar_to" => Ok(GeneratedField::SimilarTo), "placeholder" => Ok(GeneratedField::Placeholder), "unnest" => Ok(GeneratedField::Unnest), + "scalarSubqueryExpr" | "scalar_subquery_expr" => Ok(GeneratedField::ScalarSubqueryExpr), + "inSubqueryExpr" | "in_subquery_expr" => Ok(GeneratedField::InSubqueryExpr), + "existsExpr" | "exists_expr" => Ok(GeneratedField::ExistsExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12578,6 +12832,27 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { return Err(serde::de::Error::duplicate_field("unnest")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Unnest) +; + } + GeneratedField::ScalarSubqueryExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarSubqueryExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarSubqueryExpr) +; + } + GeneratedField::InSubqueryExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("inSubqueryExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::InSubqueryExpr) +; + } + GeneratedField::ExistsExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("existsExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ExistsExpr) ; } } @@ -21497,6 +21772,97 @@ impl<'de> serde::Deserialize<'de> for ScalarSubqueryExecNode { deserializer.deserialize_struct("datafusion.ScalarSubqueryExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ScalarSubqueryExprNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.subquery.is_some() { + len += 1; + } + let mut struct_ser = 
serializer.serialize_struct("datafusion.ScalarSubqueryExprNode", len)?; + if let Some(v) = self.subquery.as_ref() { + struct_ser.serialize_field("subquery", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarSubqueryExprNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "subquery", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Subquery, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "subquery" => Ok(GeneratedField::Subquery), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarSubqueryExprNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ScalarSubqueryExprNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut subquery__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Subquery => { + if subquery__.is_some() { + return Err(serde::de::Error::duplicate_field("subquery")); + } + subquery__ = map_.next_value()?; + } + } + } + Ok(ScalarSubqueryExprNode { + subquery: subquery__, + }) + } + } + deserializer.deserialize_struct("datafusion.ScalarSubqueryExprNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarUdfExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -23174,6 +23540,115 @@ impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { deserializer.deserialize_struct("datafusion.SubqueryAliasNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for SubqueryNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.subquery.is_some() { + len += 1; + } + if !self.outer_ref_columns.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SubqueryNode", len)?; + if let Some(v) = self.subquery.as_ref() { + struct_ser.serialize_field("subquery", v)?; + } + if !self.outer_ref_columns.is_empty() { + struct_ser.serialize_field("outerRefColumns", &self.outer_ref_columns)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SubqueryNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "subquery", + "outer_ref_columns", + "outerRefColumns", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Subquery, + OuterRefColumns, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "subquery" => Ok(GeneratedField::Subquery), + "outerRefColumns" | "outer_ref_columns" => Ok(GeneratedField::OuterRefColumns), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SubqueryNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SubqueryNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut subquery__ = None; + let mut outer_ref_columns__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Subquery => { + if subquery__.is_some() { + return Err(serde::de::Error::duplicate_field("subquery")); + } + subquery__ = map_.next_value()?; + } + GeneratedField::OuterRefColumns => { + if outer_ref_columns__.is_some() { + return Err(serde::de::Error::duplicate_field("outerRefColumns")); + } + outer_ref_columns__ = Some(map_.next_value()?); + } + } + } + Ok(SubqueryNode { + subquery: subquery__, + outer_ref_columns: outer_ref_columns__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SubqueryNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for SymmetricHashJoinExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 32d8b60034074..6c4f898909fb1 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -195,8 +195,8 @@ pub mod projection_node { pub struct 
SelectionNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, tag = "2")] - pub expr: ::core::option::Option, + #[prost(message, optional, boxed, tag = "2")] + pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortNode { @@ -382,8 +382,8 @@ pub struct JoinNode { pub right_join_key: ::prost::alloc::vec::Vec, #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] pub null_equality: i32, - #[prost(message, optional, tag = "8")] - pub filter: ::core::option::Option, + #[prost(message, optional, boxed, tag = "8")] + pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct DistinctNode { @@ -578,7 +578,7 @@ pub struct SubqueryAliasNode { pub struct LogicalExprNode { #[prost( oneof = "logical_expr_node::ExprType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38" )] pub expr_type: ::core::option::Option, } @@ -656,6 +656,13 @@ pub mod logical_expr_node { Placeholder(super::PlaceholderNode), #[prost(message, tag = "35")] Unnest(super::Unnest), + /// Subquery expressions + #[prost(message, tag = "36")] + ScalarSubqueryExpr(::prost::alloc::boxed::Box), + #[prost(message, tag = "37")] + InSubqueryExpr(::prost::alloc::boxed::Box), + #[prost(message, tag = "38")] + ExistsExpr(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] @@ -664,6 +671,34 @@ pub struct Wildcard { pub qualifier: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct SubqueryNode { + #[prost(message, optional, boxed, tag = "1")] + pub subquery: 
::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "2")] + pub outer_ref_columns: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarSubqueryExprNode { + #[prost(message, optional, boxed, tag = "1")] + pub subquery: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InSubqueryExprNode { + #[prost(message, optional, boxed, tag = "1")] + pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub subquery: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(bool, tag = "3")] + pub negated: bool, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ExistsExprNode { + #[prost(message, optional, boxed, tag = "1")] + pub subquery: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(bool, tag = "2")] + pub negated: bool, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PlaceholderNode { #[prost(string, tag = "1")] pub id: ::prost::alloc::string::String, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ed33d9fab1820..04e21b04844ea 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -25,8 +25,9 @@ use datafusion_common::{ }; use datafusion_execution::registry::FunctionRegistry; use datafusion_expr::dml::InsertOp; -use datafusion_expr::expr::{Alias, NullTreatment, Placeholder, Sort}; +use datafusion_expr::expr::{Alias, Exists, InSubquery, NullTreatment, Placeholder, Sort}; use datafusion_expr::expr::{Unnest, WildcardOptions}; +use datafusion_expr::logical_plan::Subquery; use datafusion_expr::{ Between, BinaryExpr, Case, Cast, Expr, GroupingSet, GroupingSet::GroupingSets, @@ -52,7 +53,7 @@ use crate::protobuf::{ }, }; -use super::LogicalExtensionCodec; +use super::{AsLogicalPlan, LogicalExtensionCodec}; impl 
From<&protobuf::UnnestOptions> for UnnestOptions { fn from(opts: &protobuf::UnnestOptions) -> Self { @@ -657,9 +658,72 @@ pub fn parse_expr( ))) } }, + ExprType::ScalarSubqueryExpr(sq) => { + let subquery = parse_subquery( + sq.subquery.as_deref().ok_or_else(|| { + Error::required("ScalarSubqueryExprNode.subquery") + })?, + registry, + codec, + )?; + Ok(Expr::ScalarSubquery(subquery)) + } + ExprType::InSubqueryExpr(sq) => { + let expr = parse_required_expr( + sq.expr.as_deref(), + registry, + "InSubqueryExprNode.expr", + codec, + )?; + let subquery = parse_subquery( + sq.subquery.as_deref().ok_or_else(|| { + Error::required("InSubqueryExprNode.subquery") + })?, + registry, + codec, + )?; + Ok(Expr::InSubquery(InSubquery::new( + Box::new(expr), + subquery, + sq.negated, + ))) + } + ExprType::ExistsExpr(sq) => { + let subquery = parse_subquery( + sq.subquery.as_deref().ok_or_else(|| { + Error::required("ExistsExprNode.subquery") + })?, + registry, + codec, + )?; + Ok(Expr::Exists(Exists::new(subquery, sq.negated))) + } } } +fn parse_subquery( + proto: &protobuf::SubqueryNode, + registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, +) -> Result { + let plan_node = proto + .subquery + .as_ref() + .ok_or_else(|| Error::required("SubqueryNode.subquery"))?; + // TaskContext is needed for try_into_logical_plan but only used for + // table provider deserialization (cache manager), not for the plan + // structure itself. + let ctx = datafusion_execution::TaskContext::default(); + let plan = plan_node.try_into_logical_plan(&ctx, codec)?; + let outer_ref_columns = + parse_exprs(&proto.outer_ref_columns, registry, codec)?; + Ok(Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: Default::default(), + }) +} + /// Parse a vector of `protobuf::LogicalExprNode`s. 
pub fn parse_exprs<'a, I>( protos: I, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 1e0264690d4fb..57826c31f695e 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1325,10 +1325,10 @@ impl AsLogicalPlan for LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Selection(Box::new( protobuf::SelectionNode { input: Some(Box::new(input)), - expr: Some(serialize_expr( + expr: Some(Box::new(serialize_expr( &filter.predicate, extension_codec, - )?), + )?)), }, ))), }) @@ -1444,7 +1444,7 @@ impl AsLogicalPlan for LogicalPlanNode { null_equality.to_owned().into(); let filter = filter .as_ref() - .map(|e| serialize_expr(e, extension_codec)) + .map(|e| serialize_expr(e, extension_codec).map(Box::new)) .map_or(Ok(None), |v| v.map(Some))?; Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( @@ -1461,8 +1461,14 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Subquery(_) => { - not_impl_err!("LogicalPlan serde is not yet implemented for subqueries") + LogicalPlan::Subquery(subquery) => { + // Serialize the inner subquery plan directly — the + // LogicalPlan::Subquery wrapper is reconstructed during + // expression deserialization. + LogicalPlanNode::try_from_logical_plan( + &subquery.subquery, + extension_codec, + ) } LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. 
}) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6fcb7389922ad..d6ccf8bd1b78b 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -25,9 +25,11 @@ use datafusion_common::{NullEquality, TableReference, UnnestOptions}; use datafusion_expr::WriteOp; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ - self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, - Like, NullTreatment, Placeholder, ScalarFunction, Unnest, + self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, Exists, + GroupingSet, InList, InSubquery, Like, NullTreatment, Placeholder, ScalarFunction, + Unnest, }; +use datafusion_expr::logical_plan::Subquery; use datafusion_expr::{ Expr, JoinConstraint, JoinType, SortExpr, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, logical_plan::PlanType, @@ -48,7 +50,8 @@ use crate::protobuf::{ }, }; -use super::LogicalExtensionCodec; +use super::{AsLogicalPlan, LogicalExtensionCodec}; +use crate::protobuf::LogicalPlanNode; impl From<&UnnestOptions> for protobuf::UnnestOptions { fn from(opts: &UnnestOptions) -> Self { @@ -579,14 +582,38 @@ pub fn serialize_expr( qualifier: qualifier.to_owned().map(|x| x.into()), })), }, - Expr::ScalarSubquery(_) - | Expr::InSubquery(_) - | Expr::Exists { .. } - | Expr::SetComparison(_) - | Expr::OuterReferenceColumn { .. } => { - // we would need to add logical plan operators to datafusion.proto to support this - // see discussion in https://github.com/apache/datafusion/issues/2565 - return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. 
} | Exp:OuterReferenceColumn not supported".to_string())); + Expr::ScalarSubquery(subquery) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::ScalarSubqueryExpr(Box::new( + protobuf::ScalarSubqueryExprNode { + subquery: Some(Box::new(serialize_subquery(subquery, codec)?)), + }, + ))), + }, + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::InSubqueryExpr(Box::new( + protobuf::InSubqueryExprNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + subquery: Some(Box::new(serialize_subquery(subquery, codec)?)), + negated: *negated, + }, + ))), + }, + Expr::Exists(Exists { subquery, negated }) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::ExistsExpr(Box::new( + protobuf::ExistsExprNode { + subquery: Some(Box::new(serialize_subquery(subquery, codec)?)), + negated: *negated, + }, + ))), + }, + Expr::OuterReferenceColumn(_, _) | Expr::SetComparison(_) => { + return Err(Error::General(format!( + "Proto serialization error: {expr} is not yet supported" + ))); } Expr::GroupingSet(GroupingSet::Cube(exprs)) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Cube(CubeNode { @@ -631,6 +658,19 @@ pub fn serialize_expr( Ok(expr_node) } +fn serialize_subquery( + subquery: &Subquery, + codec: &dyn LogicalExtensionCodec, +) -> Result { + let plan = LogicalPlanNode::try_from_logical_plan(&subquery.subquery, codec) + .map_err(|e| Error::General(e.to_string()))?; + let outer_ref_columns = serialize_exprs(&subquery.outer_ref_columns, codec)?; + Ok(protobuf::SubqueryNode { + subquery: Some(Box::new(plan)), + outer_ref_columns, + }) +} + pub fn serialize_sorts<'a, I>( sorts: I, codec: &dyn LogicalExtensionCodec, From b07491b081a3f3cadde9dc5e4dda94bfe02ed9f0 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Sun, 29 Mar 2026 21:38:01 -0400 Subject: [PATCH 07/28] cargo fmt --- .../proto/src/logical_plan/from_proto.rs | 25 ++++++++++--------- 
datafusion/proto/src/logical_plan/to_proto.rs | 15 +++++------ 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 04e21b04844ea..aabb700e51b70 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -25,7 +25,9 @@ use datafusion_common::{ }; use datafusion_execution::registry::FunctionRegistry; use datafusion_expr::dml::InsertOp; -use datafusion_expr::expr::{Alias, Exists, InSubquery, NullTreatment, Placeholder, Sort}; +use datafusion_expr::expr::{ + Alias, Exists, InSubquery, NullTreatment, Placeholder, Sort, +}; use datafusion_expr::expr::{Unnest, WildcardOptions}; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::{ @@ -660,9 +662,9 @@ pub fn parse_expr( }, ExprType::ScalarSubqueryExpr(sq) => { let subquery = parse_subquery( - sq.subquery.as_deref().ok_or_else(|| { - Error::required("ScalarSubqueryExprNode.subquery") - })?, + sq.subquery + .as_deref() + .ok_or_else(|| Error::required("ScalarSubqueryExprNode.subquery"))?, registry, codec, )?; @@ -676,9 +678,9 @@ pub fn parse_expr( codec, )?; let subquery = parse_subquery( - sq.subquery.as_deref().ok_or_else(|| { - Error::required("InSubqueryExprNode.subquery") - })?, + sq.subquery + .as_deref() + .ok_or_else(|| Error::required("InSubqueryExprNode.subquery"))?, registry, codec, )?; @@ -690,9 +692,9 @@ pub fn parse_expr( } ExprType::ExistsExpr(sq) => { let subquery = parse_subquery( - sq.subquery.as_deref().ok_or_else(|| { - Error::required("ExistsExprNode.subquery") - })?, + sq.subquery + .as_deref() + .ok_or_else(|| Error::required("ExistsExprNode.subquery"))?, registry, codec, )?; @@ -715,8 +717,7 @@ fn parse_subquery( // structure itself. 
let ctx = datafusion_execution::TaskContext::default(); let plan = plan_node.try_into_logical_plan(&ctx, codec)?; - let outer_ref_columns = - parse_exprs(&proto.outer_ref_columns, registry, codec)?; + let outer_ref_columns = parse_exprs(&proto.outer_ref_columns, registry, codec)?; Ok(Subquery { subquery: Arc::new(plan), outer_ref_columns, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index d6ccf8bd1b78b..1a5bd7b5ae41e 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -25,9 +25,8 @@ use datafusion_common::{NullEquality, TableReference, UnnestOptions}; use datafusion_expr::WriteOp; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ - self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, Exists, - GroupingSet, InList, InSubquery, Like, NullTreatment, Placeholder, ScalarFunction, - Unnest, + self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, Exists, GroupingSet, + InList, InSubquery, Like, NullTreatment, Placeholder, ScalarFunction, Unnest, }; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::{ @@ -603,12 +602,10 @@ pub fn serialize_expr( ))), }, Expr::Exists(Exists { subquery, negated }) => protobuf::LogicalExprNode { - expr_type: Some(ExprType::ExistsExpr(Box::new( - protobuf::ExistsExprNode { - subquery: Some(Box::new(serialize_subquery(subquery, codec)?)), - negated: *negated, - }, - ))), + expr_type: Some(ExprType::ExistsExpr(Box::new(protobuf::ExistsExprNode { + subquery: Some(Box::new(serialize_subquery(subquery, codec)?)), + negated: *negated, + }))), }, Expr::OuterReferenceColumn(_, _) | Expr::SetComparison(_) => { return Err(Error::General(format!( From 27a1ac29f453d34d3eaa4cf9173029d7036ac8a0 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Mon, 30 Mar 2026 12:39:56 -0400 Subject: [PATCH 08/28] Refactor logical plan deserialization --- datafusion/proto/src/bytes/mod.rs | 121 
+++------------- .../proto/src/logical_plan/from_proto.rs | 133 +++++++++--------- .../tests/cases/roundtrip_logical_plan.rs | 9 +- datafusion/proto/tests/cases/serialize.rs | 6 +- 4 files changed, 90 insertions(+), 179 deletions(-) diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index d6f511476befe..67a591d92bfbb 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -27,19 +27,13 @@ use crate::physical_plan::{ use crate::protobuf; use datafusion_common::{Result, plan_datafusion_err}; use datafusion_execution::TaskContext; -use datafusion_expr::{ - AggregateUDF, Expr, LogicalPlan, Volatility, WindowUDF, create_udaf, create_udf, - create_udwf, -}; +use datafusion_expr::{Expr, LogicalPlan}; use prost::{ Message, bytes::{Bytes, BytesMut}, }; use std::sync::Arc; -// Reexport Bytes which appears in the API -use datafusion_execution::registry::FunctionRegistry; -use datafusion_expr::planner::ExprPlanner; use datafusion_physical_plan::ExecutionPlan; mod registry; @@ -65,26 +59,21 @@ pub trait Serializeable: Sized { /// Convert `self` to an opaque byte stream fn to_bytes(&self) -> Result; - /// Convert `bytes` (the output of [`to_bytes`]) back into an - /// object. This will error if the serialized bytes contain any - /// user defined functions, in which case use - /// [`from_bytes_with_registry`] + /// Convert `bytes` (the output of [`to_bytes`]) back into an object. 
This + /// will error if the serialized bytes contain any user defined functions, + /// in which case use [`from_bytes_with_ctx`] /// /// [`to_bytes`]: Self::to_bytes - /// [`from_bytes_with_registry`]: Self::from_bytes_with_registry + /// [`from_bytes_with_ctx`]: Self::from_bytes_with_ctx fn from_bytes(bytes: &[u8]) -> Result { - Self::from_bytes_with_registry(bytes, ®istry::NoRegistry {}) + Self::from_bytes_with_ctx(bytes, &TaskContext::default()) } - /// Convert `bytes` (the output of [`to_bytes`]) back into an - /// object resolving user defined functions with the specified - /// `registry` + /// Convert `bytes` (the output of [`to_bytes`]) back into an object + /// resolving user defined functions with the specified `ctx` /// /// [`to_bytes`]: Self::to_bytes - fn from_bytes_with_registry( - bytes: &[u8], - registry: &dyn FunctionRegistry, - ) -> Result; + fn from_bytes_with_ctx(bytes: &[u8], ctx: &TaskContext) -> Result; } impl Serializeable for Expr { @@ -100,100 +89,22 @@ impl Serializeable for Expr { let bytes: Bytes = buffer.into(); - // the produced byte stream may lead to "recursion limit" errors, see + // The produced byte stream may lead to "recursion limit" errors, see // https://github.com/apache/datafusion/issues/3968 - // Until the underlying prost issue ( https://github.com/tokio-rs/prost/issues/736 ) is fixed, we try to - // deserialize the data here and check for errors. 
- // - // Need to provide some placeholder registry because the stream may contain UDFs - struct PlaceHolderRegistry; - - impl FunctionRegistry for PlaceHolderRegistry { - fn udfs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() - } - - fn udf(&self, name: &str) -> Result> { - Ok(Arc::new(create_udf( - name, - vec![], - arrow::datatypes::DataType::Null, - Volatility::Immutable, - Arc::new(|_| unimplemented!()), - ))) - } - - fn udaf(&self, name: &str) -> Result> { - Ok(Arc::new(create_udaf( - name, - vec![arrow::datatypes::DataType::Null], - Arc::new(arrow::datatypes::DataType::Null), - Volatility::Immutable, - Arc::new(|_| unimplemented!()), - Arc::new(vec![]), - ))) - } - - fn udwf(&self, name: &str) -> Result> { - Ok(Arc::new(create_udwf( - name, - arrow::datatypes::DataType::Null, - Arc::new(arrow::datatypes::DataType::Null), - Volatility::Immutable, - Arc::new(|| unimplemented!()), - ))) - } - fn register_udaf( - &mut self, - _udaf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udaf called in Placeholder Registry!" - ) - } - fn register_udf( - &mut self, - _udf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udf called in Placeholder Registry!" - ) - } - fn register_udwf( - &mut self, - _udaf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udwf called in Placeholder Registry!" - ) - } - - fn expr_planners(&self) -> Vec> { - vec![] - } - - fn udafs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() - } - - fn udwfs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() - } - } - Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?; + // Until the underlying prost issue ( https://github.com/tokio-rs/prost/issues/736 ) + // is fixed, verify the bytes can be decoded without hitting that limit. 
+ protobuf::LogicalExprNode::decode(bytes.as_ref()) + .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; Ok(bytes) } - fn from_bytes_with_registry( - bytes: &[u8], - registry: &dyn FunctionRegistry, - ) -> Result { + fn from_bytes_with_ctx(bytes: &[u8], ctx: &TaskContext) -> Result { let protobuf = protobuf::LogicalExprNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; let extension_codec = DefaultLogicalExtensionCodec {}; - logical_plan::from_proto::parse_expr(&protobuf, registry, &extension_codec) + logical_plan::from_proto::parse_expr(&protobuf, ctx, &extension_codec) .map_err(|e| plan_datafusion_err!("Error parsing protobuf into Expr: {e}")) } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index aabb700e51b70..6254507509724 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -23,6 +23,7 @@ use datafusion_common::{ NullEquality, RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, exec_datafusion_err, internal_err, plan_datafusion_err, }; +use datafusion_execution::TaskContext; use datafusion_execution::registry::FunctionRegistry; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ @@ -259,7 +260,7 @@ impl From for NullTreatment { pub fn parse_expr( proto: &protobuf::LogicalExprNode, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result { use protobuf::{logical_expr_node::ExprType, window_expr_node}; @@ -272,7 +273,7 @@ pub fn parse_expr( match expr_type { ExprType::BinaryExpr(binary_expr) => { let op = from_proto_binary_op(&binary_expr.op)?; - let operands = parse_exprs(&binary_expr.operands, registry, codec)?; + let operands = parse_exprs(&binary_expr.operands, ctx, codec)?; if operands.len() < 2 { return Err(proto_error( @@ -299,8 +300,8 @@ pub fn parse_expr( .window_function 
.as_ref() .ok_or_else(|| Error::required("window_function"))?; - let partition_by = parse_exprs(&expr.partition_by, registry, codec)?; - let mut order_by = parse_sorts(&expr.order_by, registry, codec)?; + let partition_by = parse_exprs(&expr.partition_by, ctx, codec)?; + let mut order_by = parse_sorts(&expr.order_by, ctx, codec)?; let window_frame = expr .window_frame .as_ref() @@ -332,7 +333,7 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, - None => registry + None => ctx .udaf(udaf_name) .or_else(|_| codec.try_decode_udaf(udaf_name, &[]))?, }; @@ -341,7 +342,7 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, - None => registry + None => ctx .udwf(udwf_name) .or_else(|_| codec.try_decode_udwf(udwf_name, &[]))?, }; @@ -349,7 +350,7 @@ pub fn parse_expr( } }; - let args = parse_exprs(&expr.exprs, registry, codec)?; + let args = parse_exprs(&expr.exprs, ctx, codec)?; let mut builder = Expr::from(WindowFunction::new(agg_fn, args)) .partition_by(partition_by) .order_by(order_by) @@ -360,8 +361,7 @@ pub fn parse_expr( builder = builder.distinct(); }; - if let Some(filter) = - parse_optional_expr(expr.filter.as_deref(), registry, codec)? + if let Some(filter) = parse_optional_expr(expr.filter.as_deref(), ctx, codec)? 
{ builder = builder.filter(filter); } @@ -369,7 +369,7 @@ pub fn parse_expr( builder.build().map_err(Error::DataFusionError) } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( - parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(alias.expr.as_deref(), ctx, "expr", codec)?, alias .relation .first() @@ -379,69 +379,69 @@ pub fn parse_expr( ))), ExprType::IsNullExpr(is_null) => Ok(Expr::IsNull(Box::new(parse_required_expr( is_null.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::IsNotNull(Box::new( - parse_required_expr(is_not_null.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(is_not_null.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::NotExpr(not) => Ok(Expr::Not(Box::new(parse_required_expr( not.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsTrue(msg) => Ok(Expr::IsTrue(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsFalse(msg) => Ok(Expr::IsFalse(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsUnknown(msg) => Ok(Expr::IsUnknown(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotTrue(msg) => Ok(Expr::IsNotTrue(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotFalse(msg) => Ok(Expr::IsNotFalse(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotUnknown(msg) => Ok(Expr::IsNotUnknown(Box::new( - parse_required_expr(msg.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(msg.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::Between(between) => Ok(Expr::Between(Between::new( Box::new(parse_required_expr( between.expr.as_deref(), - registry, + ctx, "expr", codec, )?), between.negated, Box::new(parse_required_expr( 
between.low.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( between.high.as_deref(), - registry, + ctx, "expr", codec, )?), @@ -450,13 +450,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -467,13 +467,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -484,13 +484,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -504,13 +504,13 @@ pub fn parse_expr( .map(|e| { let when_expr = parse_required_expr( e.when_expr.as_ref(), - registry, + ctx, "when_expr", codec, )?; let then_expr = parse_required_expr( e.then_expr.as_ref(), - registry, + ctx, "then_expr", codec, )?; @@ -518,16 +518,15 @@ pub fn parse_expr( }) .collect::, Box)>, Error>>()?; Ok(Expr::Case(Case::new( - parse_optional_expr(case.expr.as_deref(), registry, codec)?.map(Box::new), + parse_optional_expr(case.expr.as_deref(), ctx, codec)?.map(Box::new), when_then_expr, - parse_optional_expr(case.else_expr.as_deref(), registry, codec)? 
- .map(Box::new), + parse_optional_expr(case.else_expr.as_deref(), ctx, codec)?.map(Box::new), ))) } ExprType::Cast(cast) => { let expr = Box::new(parse_required_expr( cast.expr.as_deref(), - registry, + ctx, "expr", codec, )?); @@ -540,7 +539,7 @@ pub fn parse_expr( ExprType::TryCast(cast) => { let expr = Box::new(parse_required_expr( cast.expr.as_deref(), - registry, + ctx, "expr", codec, )?); @@ -554,10 +553,10 @@ pub fn parse_expr( ))) } ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( - parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(negative.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::Unnest(unnest) => { - let mut exprs = parse_exprs(&unnest.exprs, registry, codec)?; + let mut exprs = parse_exprs(&unnest.exprs, ctx, codec)?; if exprs.len() != 1 { return Err(proto_error("Unnest must have exactly one expression")); } @@ -566,11 +565,11 @@ pub fn parse_expr( ExprType::InList(in_list) => Ok(Expr::InList(InList::new( Box::new(parse_required_expr( in_list.expr.as_deref(), - registry, + ctx, "expr", codec, )?), - parse_exprs(&in_list.list, registry, codec)?, + parse_exprs(&in_list.list, ctx, codec)?, in_list.negated, ))), ExprType::Wildcard(protobuf::Wildcard { qualifier }) => { @@ -588,19 +587,19 @@ pub fn parse_expr( }) => { let scalar_fn = match fun_definition { Some(buf) => codec.try_decode_udf(fun_name, buf)?, - None => registry + None => ctx .udf(fun_name.as_str()) .or_else(|_| codec.try_decode_udf(fun_name, &[]))?, }; Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, - parse_exprs(args, registry, codec)?, + parse_exprs(args, ctx, codec)?, ))) } ExprType::AggregateUdfExpr(pb) => { let agg_fn = match &pb.fun_definition { Some(buf) => codec.try_decode_udaf(&pb.fun_name, buf)?, - None => registry + None => ctx .udaf(&pb.fun_name) .or_else(|_| codec.try_decode_udaf(&pb.fun_name, &[]))?, }; @@ -619,10 +618,10 @@ pub fn parse_expr( 
Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, - parse_exprs(&pb.args, registry, codec)?, + parse_exprs(&pb.args, ctx, codec)?, pb.distinct, - parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), - parse_sorts(&pb.order_by, registry, codec)?, + parse_optional_expr(pb.filter.as_deref(), ctx, codec)?.map(Box::new), + parse_sorts(&pb.order_by, ctx, codec)?, null_treatment, ))) } @@ -630,15 +629,15 @@ pub fn parse_expr( ExprType::GroupingSet(GroupingSetNode { expr }) => { Ok(Expr::GroupingSet(GroupingSets( expr.iter() - .map(|expr_list| parse_exprs(&expr_list.expr, registry, codec)) + .map(|expr_list| parse_exprs(&expr_list.expr, ctx, codec)) .collect::, Error>>()?, ))) } ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube( - parse_exprs(expr, registry, codec)?, + parse_exprs(expr, ctx, codec)?, ))), ExprType::Rollup(RollupNode { expr }) => Ok(Expr::GroupingSet( - GroupingSet::Rollup(parse_exprs(expr, registry, codec)?), + GroupingSet::Rollup(parse_exprs(expr, ctx, codec)?), )), ExprType::Placeholder(PlaceholderNode { id, @@ -665,7 +664,7 @@ pub fn parse_expr( sq.subquery .as_deref() .ok_or_else(|| Error::required("ScalarSubqueryExprNode.subquery"))?, - registry, + ctx, codec, )?; Ok(Expr::ScalarSubquery(subquery)) @@ -673,7 +672,7 @@ pub fn parse_expr( ExprType::InSubqueryExpr(sq) => { let expr = parse_required_expr( sq.expr.as_deref(), - registry, + ctx, "InSubqueryExprNode.expr", codec, )?; @@ -681,7 +680,7 @@ pub fn parse_expr( sq.subquery .as_deref() .ok_or_else(|| Error::required("InSubqueryExprNode.subquery"))?, - registry, + ctx, codec, )?; Ok(Expr::InSubquery(InSubquery::new( @@ -695,7 +694,7 @@ pub fn parse_expr( sq.subquery .as_deref() .ok_or_else(|| Error::required("ExistsExprNode.subquery"))?, - registry, + ctx, codec, )?; Ok(Expr::Exists(Exists::new(subquery, sq.negated))) @@ -705,19 +704,15 @@ pub fn parse_expr( fn parse_subquery( proto: &protobuf::SubqueryNode, - registry: &dyn 
FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result { let plan_node = proto .subquery .as_ref() .ok_or_else(|| Error::required("SubqueryNode.subquery"))?; - // TaskContext is needed for try_into_logical_plan but only used for - // table provider deserialization (cache manager), not for the plan - // structure itself. - let ctx = datafusion_execution::TaskContext::default(); - let plan = plan_node.try_into_logical_plan(&ctx, codec)?; - let outer_ref_columns = parse_exprs(&proto.outer_ref_columns, registry, codec)?; + let plan = plan_node.try_into_logical_plan(ctx, codec)?; + let outer_ref_columns = parse_exprs(&proto.outer_ref_columns, ctx, codec)?; Ok(Subquery { subquery: Arc::new(plan), outer_ref_columns, @@ -728,7 +723,7 @@ fn parse_subquery( /// Parse a vector of `protobuf::LogicalExprNode`s. pub fn parse_exprs<'a, I>( protos: I, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> where @@ -737,7 +732,7 @@ where let res = protos .into_iter() .map(|elem| { - parse_expr(elem, registry, codec).map_err(|e| plan_datafusion_err!("{}", e)) + parse_expr(elem, ctx, codec).map_err(|e| plan_datafusion_err!("{}", e)) }) .collect::>>()?; Ok(res) @@ -745,7 +740,7 @@ where pub fn parse_sorts<'a, I>( protos: I, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> where @@ -753,17 +748,17 @@ where { protos .into_iter() - .map(|sort| parse_sort(sort, registry, codec)) + .map(|sort| parse_sort(sort, ctx, codec)) .collect::, Error>>() } pub fn parse_sort( sort: &protobuf::SortExprNode, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result { Ok(Sort::new( - parse_required_expr(sort.expr.as_ref(), registry, "expr", codec)?, + parse_required_expr(sort.expr.as_ref(), ctx, "expr", codec)?, sort.asc, sort.nulls_first, )) @@ -819,23 +814,23 @@ pub fn from_proto_binary_op(op: &str) -> Result { 
fn parse_optional_expr( p: Option<&protobuf::LogicalExprNode>, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> { match p { - Some(expr) => parse_expr(expr, registry, codec).map(Some), + Some(expr) => parse_expr(expr, ctx, codec).map(Some), None => Ok(None), } } fn parse_required_expr( p: Option<&protobuf::LogicalExprNode>, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, field: impl Into, codec: &dyn LogicalExtensionCodec, ) -> Result { match p { - Some(expr) => parse_expr(expr, registry, codec), + Some(expr) => parse_expr(expr, ctx, codec), None => Err(Error::required(field)), } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 049841980587a..e300677d19f30 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -133,7 +133,8 @@ fn roundtrip_expr_test_with_codec( ) { let proto: protobuf::LogicalExprNode = serialize_expr(&initial_struct, codec) .unwrap_or_else(|e| panic!("Error serializing expression: {e:?}")); - let round_trip: Expr = from_proto::parse_expr(&proto, &ctx, codec).unwrap(); + let round_trip: Expr = + from_proto::parse_expr(&proto, ctx.task_ctx().as_ref(), codec).unwrap(); assert_eq!(format!("{:?}", &initial_struct), format!("{round_trip:?}")); @@ -2562,7 +2563,8 @@ fn roundtrip_scalar_udf_extension_codec() { let ctx = SessionContext::new(); let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr"); let round_trip = - from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr"); + from_proto::parse_expr(&proto, ctx.task_ctx().as_ref(), &UDFExtensionCodec) + .expect("parse expr"); assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); roundtrip_json_test(&proto); @@ -2575,7 +2577,8 @@ fn roundtrip_aggregate_udf_extension_codec() { let ctx = SessionContext::new(); let proto 
= serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr"); let round_trip = - from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr"); + from_proto::parse_expr(&proto, ctx.task_ctx().as_ref(), &UDFExtensionCodec) + .expect("parse expr"); assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); roundtrip_json_test(&proto); diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index bb955a426ca78..850fd42ce131b 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -77,7 +77,8 @@ fn udf_roundtrip_with_registry() { .call(vec![lit("")]); let bytes = expr.to_bytes().unwrap(); - let deserialized_expr = Expr::from_bytes_with_registry(&bytes, &ctx).unwrap(); + let deserialized_expr = + Expr::from_bytes_with_ctx(&bytes, ctx.task_ctx().as_ref()).unwrap(); assert_eq!(expr, deserialized_expr); } @@ -281,7 +282,8 @@ fn test_expression_serialization_roundtrip() { let extension_codec = DefaultLogicalExtensionCodec {}; let proto = serialize_expr(&expr, &extension_codec).unwrap(); - let deserialize = parse_expr(&proto, &ctx, &extension_codec).unwrap(); + let deserialize = + parse_expr(&proto, ctx.task_ctx().as_ref(), &extension_codec).unwrap(); let serialize_name = extract_function_name(&expr); let deserialize_name = extract_function_name(&deserialize); From 7071001d1db45f3f6027614add1505524d685247 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Mon, 30 Mar 2026 12:42:35 -0400 Subject: [PATCH 09/28] Increase large files size check --- .github/workflows/large_files.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/large_files.yml b/.github/workflows/large_files.yml index 12a5599189216..63cf18027c80f 100644 --- a/.github/workflows/large_files.yml +++ b/.github/workflows/large_files.yml @@ -34,9 +34,9 @@ jobs: fetch-depth: 0 - name: Check size of new Git objects env: - # 1 MB ought to be enough for 
anybody. + # 2 MB ought to be enough for anybody. # TODO in case we may want to consciously commit a bigger file to the repo without using Git LFS we may disable the check e.g. with a label - MAX_FILE_SIZE_BYTES: 1048576 + MAX_FILE_SIZE_BYTES: 2097152 shell: bash run: | if [ "${{ github.event_name }}" = "merge_group" ]; then From b9bce91f77ffff5444c045bef09ae8bac7d7bc1c Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Mon, 30 Mar 2026 12:44:01 -0400 Subject: [PATCH 10/28] fix clippy --- datafusion/proto/src/bytes/mod.rs | 2 - datafusion/proto/src/bytes/registry.rs | 85 -------------------------- 2 files changed, 87 deletions(-) delete mode 100644 datafusion/proto/src/bytes/registry.rs diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 67a591d92bfbb..820bbc20af46c 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -36,8 +36,6 @@ use std::sync::Arc; use datafusion_physical_plan::ExecutionPlan; -mod registry; - /// Encodes something (such as [`Expr`]) to/from a stream of /// bytes. /// diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs deleted file mode 100644 index a3f74787e2b50..0000000000000 --- a/datafusion/proto/src/bytes/registry.rs +++ /dev/null @@ -1,85 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::{collections::HashSet, sync::Arc}; - -use datafusion_common::Result; -use datafusion_common::plan_err; -use datafusion_execution::registry::FunctionRegistry; -use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; - -/// A default [`FunctionRegistry`] registry that does not resolve any -/// user defined functions -pub(crate) struct NoRegistry {} - -impl FunctionRegistry for NoRegistry { - fn udfs(&self) -> HashSet { - HashSet::new() - } - - fn udf(&self, name: &str) -> Result> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize User Defined Function '{name}'" - ) - } - - fn udaf(&self, name: &str) -> Result> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize User Defined Aggregate Function '{name}'" - ) - } - - fn udwf(&self, name: &str) -> Result> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize User Defined Window Function '{name}'" - ) - } - fn register_udaf( - &mut self, - udaf: Arc, - ) -> Result>> { - plan_err!( - "No function registry provided to deserialize, so can not register User Defined Aggregate Function '{}'", - udaf.inner().name() - ) - } - fn register_udf(&mut self, udf: Arc) -> Result>> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize User Defined Function '{}'", - udf.inner().name() - ) - } - fn register_udwf(&mut self, udwf: Arc) -> Result>> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize 
User Defined Window Function '{}'", - udwf.inner().name() - ) - } - - fn expr_planners(&self) -> Vec> { - vec![] - } - - fn udafs(&self) -> HashSet { - HashSet::new() - } - - fn udwfs(&self) -> HashSet { - HashSet::new() - } -} From 7c965aae54f4b8537dd52e9d3edc371e10cbded5 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Mon, 30 Mar 2026 13:55:58 -0400 Subject: [PATCH 11/28] Update expected TPC-H plans --- .../test_files/tpch/plans/q11.slt.part | 110 +++++++++--------- .../test_files/tpch/plans/q15.slt.part | 80 ++++++------- .../test_files/tpch/plans/q22.slt.part | 73 ++++++------ 3 files changed, 132 insertions(+), 131 deletions(-) diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part index a31579eb1e09d..8f1fe62d64201 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part @@ -49,61 +49,59 @@ limit 10; logical_plan 01)Sort: value DESC NULLS FIRST, fetch=10 02)--Projection: partsupp.ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS value -03)----Inner Join: Filter: CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > __scalar_sq_1.sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001) -04)------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] -05)--------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost -06)----------Inner Join: supplier.s_nationkey = nation.n_nationkey -07)------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey +03)----Filter: CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > () +04)------Subquery: +05)--------Projection: CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15)) 
+06)----------Aggregate: groupBy=[[]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] +07)------------Inner Join: supplier.s_nationkey = nation.n_nationkey 08)--------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey -09)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], partial_filters=[Boolean(true)] -10)----------------TableScan: supplier projection=[s_suppkey, s_nationkey] -11)------------Projection: nation.n_nationkey -12)--------------Filter: nation.n_name = Utf8View("GERMANY") -13)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] -14)------SubqueryAlias: __scalar_sq_1 -15)--------Projection: CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15)) -16)----------Aggregate: groupBy=[[]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] -17)------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost -18)--------------Inner Join: supplier.s_nationkey = nation.n_nationkey -19)----------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey -20)------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey -21)--------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] -22)--------------------TableScan: supplier projection=[s_suppkey, s_nationkey] -23)----------------Projection: nation.n_nationkey -24)------------------Filter: nation.n_name = Utf8View("GERMANY") -25)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] +09)----------------TableScan: partsupp +10)----------------TableScan: supplier +11)--------------Filter: nation.n_name = Utf8View("GERMANY") +12)----------------TableScan: nation, partial_filters=[nation.n_name = Utf8View("GERMANY")] 
+13)------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] +14)--------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost +15)----------Inner Join: supplier.s_nationkey = nation.n_nationkey +16)------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey +17)--------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey +18)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] +19)----------------TableScan: supplier projection=[s_suppkey, s_nationkey] +20)------------Projection: nation.n_nationkey +21)--------------Filter: nation.n_name = Utf8View("GERMANY") +22)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] physical_plan -01)SortExec: TopK(fetch=10), expr=[value@1 DESC], preserve_partitioning=[false] -02)--ProjectionExec: expr=[ps_partkey@0 as ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value] -03)----NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@1 > sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)@0, projection=[ps_partkey@0, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1, sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)@3] -04)------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as sum(partsupp.ps_supplycost * partsupp.ps_availqty), CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 AS Decimal128(38, 15)) as join_proj_push_down_1] -05)--------CoalescePartitionsExec -06)----------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] -07)------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 -08)--------------AggregateExec: 
mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] -09)----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[ps_partkey@0, ps_availqty@1, ps_supplycost@2] -10)------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4 -11)--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)], projection=[ps_partkey@0, ps_availqty@2, ps_supplycost@3, s_nationkey@5] -12)----------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4 -13)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], file_type=csv, has_header=false -14)----------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 -15)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], file_type=csv, has_header=false -16)------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -17)--------------------FilterExec: n_name@1 = GERMANY, projection=[n_nationkey@0] -18)----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -19)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], file_type=csv, has_header=false 
-20)------ProjectionExec: expr=[CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Float64) * 0.0001 AS Decimal128(38, 15)) as sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)] -21)--------AggregateExec: mode=Final, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] -22)----------CoalescePartitionsExec -23)------------AggregateExec: mode=Partial, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] -24)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_availqty@0, ps_supplycost@1] -25)----------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4 -26)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@0, s_suppkey@0)], projection=[ps_availqty@1, ps_supplycost@2, s_nationkey@4] -27)--------------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 -28)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_suppkey, ps_availqty, ps_supplycost], file_type=csv, has_header=false -29)--------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 -30)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], file_type=csv, has_header=false -31)----------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -32)------------------FilterExec: n_name@1 = GERMANY, projection=[n_nationkey@0] 
-33)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -34)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], file_type=csv, has_header=false +01)ScalarSubqueryExec: subqueries=1 +02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true +03)----SortPreservingMergeExec: [value@1 DESC], fetch=10 +04)------SortExec: TopK(fetch=10), expr=[value@1 DESC], preserve_partitioning=[true] +05)--------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value] +06)----------FilterExec: CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 AS Decimal128(38, 15)) > scalar_subquery() +07)------------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +08)--------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +09)----------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[ps_partkey@0, ps_availqty@1, ps_supplycost@2] +11)--------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4 +12)----------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)], projection=[ps_partkey@0, ps_availqty@2, ps_supplycost@3, s_nationkey@5] +13)------------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4 +14)--------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], file_type=csv, has_header=false +15)------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 +16)--------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], file_type=csv, has_header=false +17)--------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +18)----------------------FilterExec: n_name@1 = GERMANY, projection=[n_nationkey@0] +19)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +20)--------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], file_type=csv, has_header=false +21)--ProjectionExec: expr=[CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Float64) * 0.0001 AS Decimal128(38, 15)) as sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)] +22)----AggregateExec: mode=Final, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +23)------CoalescePartitionsExec +24)--------AggregateExec: mode=Partial, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +25)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@9, n_nationkey@0)] +26)------------RepartitionExec: partitioning=Hash([s_nationkey@9], 4), input_partitions=4 +27)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)] +28)----------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 
4), input_partitions=4 +29)------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment, ps_rev], file_type=csv, has_header=false +30)----------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 +31)------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment, s_rev], file_type=csv, has_header=false +32)------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +33)--------------FilterExec: n_name@1 = GERMANY +34)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +35)------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name, n_regionkey, n_comment, n_rev], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part index ae0c0a93a3552..5bdca882218b0 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part @@ -52,43 +52,45 @@ order by logical_plan 01)Sort: supplier.s_suppkey ASC NULLS LAST 02)--Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_phone, revenue0.total_revenue -03)----Inner Join: revenue0.total_revenue = __scalar_sq_1.max(revenue0.total_revenue) 
-04)------Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_phone, revenue0.total_revenue -05)--------Inner Join: supplier.s_suppkey = revenue0.supplier_no -06)----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_phone], partial_filters=[Boolean(true)] -07)----------SubqueryAlias: revenue0 -08)------------Projection: lineitem.l_suppkey AS supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue -09)--------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] -10)----------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount -11)------------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") -12)--------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] -13)------SubqueryAlias: __scalar_sq_1 -14)--------Aggregate: groupBy=[[]], aggr=[[max(revenue0.total_revenue)]] -15)----------SubqueryAlias: revenue0 -16)------------Projection: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue -17)--------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] -18)----------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount -19)------------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") -20)--------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= 
Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] +03)----Inner Join: supplier.s_suppkey = revenue0.supplier_no +04)------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_phone] +05)------SubqueryAlias: revenue0 +06)--------Projection: lineitem.l_suppkey AS supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue +07)----------Filter: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) = () +08)------------Subquery: +09)--------------Projection: max(revenue0.total_revenue) +10)----------------Aggregate: groupBy=[[]], aggr=[[max(revenue0.total_revenue)]] +11)------------------SubqueryAlias: revenue0 +12)--------------------Projection: lineitem.l_suppkey AS supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue +13)----------------------Projection: lineitem.l_suppkey, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) +14)------------------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +15)--------------------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") +16)----------------------------TableScan: lineitem, partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] +17)------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +18)--------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount +19)----------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") +20)------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, 
l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] physical_plan -01)SortPreservingMergeExec: [s_suppkey@0 ASC NULLS LAST] -02)--SortExec: expr=[s_suppkey@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(max(revenue0.total_revenue)@0, total_revenue@4)], projection=[s_suppkey@1, s_name@2, s_address@3, s_phone@4, total_revenue@5] -04)------AggregateExec: mode=Final, gby=[], aggr=[max(revenue0.total_revenue)] -05)--------CoalescePartitionsExec -06)----------AggregateExec: mode=Partial, gby=[], aggr=[max(revenue0.total_revenue)] -07)------------ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] -08)--------------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -09)----------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 -10)------------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -11)--------------------FilterExec: l_shipdate@3 >= 1996-01-01 AND l_shipdate@3 < 1996-04-01, projection=[l_suppkey@0, l_extendedprice@1, l_discount@2] -12)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], file_type=csv, has_header=false -13)------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_suppkey@0, supplier_no@0)], projection=[s_suppkey@0, 
s_name@1, s_address@2, s_phone@3, total_revenue@5] -14)--------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 -15)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_phone], file_type=csv, has_header=false -16)--------ProjectionExec: expr=[l_suppkey@0 as supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] -17)----------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -18)------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 -19)--------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -20)----------------FilterExec: l_shipdate@3 >= 1996-01-01 AND l_shipdate@3 < 1996-04-01, projection=[l_suppkey@0, l_extendedprice@1, l_discount@2] -21)------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], file_type=csv, has_header=false +01)ScalarSubqueryExec: subqueries=1 +02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true +03)----SortPreservingMergeExec: [s_suppkey@0 ASC NULLS LAST] +04)------SortExec: expr=[s_suppkey@0 ASC NULLS LAST], preserve_partitioning=[true] +05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_suppkey@0, supplier_no@0)], projection=[s_suppkey@0, s_name@1, s_address@2, 
s_phone@3, total_revenue@5] +06)----------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 +07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_phone], file_type=csv, has_header=false +08)----------ProjectionExec: expr=[l_suppkey@0 as supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] +09)------------FilterExec: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 = scalar_subquery() +10)--------------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +11)----------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 +12)------------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +13)--------------------FilterExec: l_shipdate@3 >= 1996-01-01 AND l_shipdate@3 < 1996-04-01, projection=[l_suppkey@0, l_extendedprice@1, l_discount@2] +14)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], file_type=csv, has_header=false +15)--AggregateExec: mode=Final, gby=[], aggr=[max(revenue0.total_revenue)] +16)----CoalescePartitionsExec +17)------AggregateExec: mode=Partial, gby=[], aggr=[max(revenue0.total_revenue)] +18)--------ProjectionExec: expr=[l_suppkey@0 as supplier_no, sum(lineitem.l_extendedprice * Int64(1) - 
lineitem.l_discount)@1 as total_revenue] +19)----------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +20)------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 +21)--------------AggregateExec: mode=Partial, gby=[l_suppkey@2 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +22)----------------FilterExec: l_shipdate@10 >= 1996-01-01 AND l_shipdate@10 < 1996-04-01 +23)------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment, l_rev], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part index add578c3b079d..306deec725fe3 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part @@ -61,40 +61,41 @@ logical_plan 03)----Aggregate: groupBy=[[custsale.cntrycode]], aggr=[[count(Int64(1)), sum(custsale.c_acctbal)]] 04)------SubqueryAlias: custsale 05)--------Projection: substr(customer.c_phone, Int64(1), Int64(2)) AS cntrycode, customer.c_acctbal -06)----------Inner Join: Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_2.avg(customer.c_acctbal) -07)------------Projection: customer.c_phone, customer.c_acctbal 
-08)--------------LeftAnti Join: customer.c_custkey = __correlated_sq_1.o_custkey -09)----------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) -10)------------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]), Boolean(true)] -11)----------------SubqueryAlias: __correlated_sq_1 -12)------------------TableScan: orders projection=[o_custkey] -13)------------SubqueryAlias: __scalar_sq_2 -14)--------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] -15)----------------Projection: customer.c_acctbal -16)------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) -17)--------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] +06)----------LeftAnti Join: customer.c_custkey = __correlated_sq_1.o_custkey +07)------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) AND CAST(customer.c_acctbal AS Decimal128(19, 6)) > () +08)--------------Subquery: +09)----------------Projection: avg(customer.c_acctbal) +10)------------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] +11)--------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), 
Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) +12)----------------------TableScan: customer, partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] +13)--------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]), CAST(customer.c_acctbal AS Decimal128(19, 6)) > ()] +14)----------------Subquery: +15)------------------Projection: avg(customer.c_acctbal) +16)--------------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] +17)----------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) +18)------------------------TableScan: customer, partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] +19)------------SubqueryAlias: __correlated_sq_1 +20)--------------TableScan: orders projection=[o_custkey] physical_plan -01)SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST] -02)--SortExec: expr=[cntrycode@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[cntrycode@0 as cntrycode, count(Int64(1))@1 as numcust, sum(custsale.c_acctbal)@2 as totacctbal] -04)------AggregateExec: mode=FinalPartitioned, gby=[cntrycode@0 as cntrycode], aggr=[count(Int64(1)), sum(custsale.c_acctbal)] -05)--------RepartitionExec: partitioning=Hash([cntrycode@0], 4), 
input_partitions=4 -06)----------AggregateExec: mode=Partial, gby=[cntrycode@0 as cntrycode], aggr=[count(Int64(1)), sum(custsale.c_acctbal)] -07)------------ProjectionExec: expr=[substr(c_phone@0, 1, 2) as cntrycode, c_acctbal@1 as c_acctbal] -08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -09)----------------NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@1 > avg(customer.c_acctbal)@0, projection=[c_phone@0, c_acctbal@1, avg(customer.c_acctbal)@3] -10)------------------ProjectionExec: expr=[c_phone@0 as c_phone, c_acctbal@1 as c_acctbal, CAST(c_acctbal@1 AS Decimal128(19, 6)) as join_proj_push_down_1] -11)--------------------CoalescePartitionsExec -12)----------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(c_custkey@0, o_custkey@0)], projection=[c_phone@1, c_acctbal@2] -13)------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 -14)--------------------------FilterExec: substr(c_phone@1, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]) -15)----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -16)------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false -17)------------------------RepartitionExec: partitioning=Hash([o_custkey@0], 4), input_partitions=4 -18)--------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_custkey], file_type=csv, has_header=false 
-19)------------------AggregateExec: mode=Final, gby=[], aggr=[avg(customer.c_acctbal)] -20)--------------------CoalescePartitionsExec -21)----------------------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] -22)------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]), projection=[c_acctbal@1] -23)--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -24)----------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false +01)ScalarSubqueryExec: subqueries=1 +02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true +03)----SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST] +04)------SortExec: expr=[cntrycode@0 ASC NULLS LAST], preserve_partitioning=[true] +05)--------ProjectionExec: expr=[cntrycode@0 as cntrycode, count(Int64(1))@1 as numcust, sum(custsale.c_acctbal)@2 as totacctbal] +06)----------AggregateExec: mode=FinalPartitioned, gby=[cntrycode@0 as cntrycode], aggr=[count(Int64(1)), sum(custsale.c_acctbal)] +07)------------RepartitionExec: partitioning=Hash([cntrycode@0], 4), input_partitions=4 +08)--------------AggregateExec: mode=Partial, gby=[cntrycode@0 as cntrycode], aggr=[count(Int64(1)), sum(custsale.c_acctbal)] +09)----------------ProjectionExec: expr=[substr(c_phone@0, 1, 2) as cntrycode, c_acctbal@1 as c_acctbal] +10)------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(c_custkey@0, o_custkey@0)], projection=[c_phone@1, c_acctbal@2] +11)--------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 +12)----------------------FilterExec: substr(c_phone@1, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]) AND CAST(c_acctbal@2 AS Decimal128(19, 6)) > scalar_subquery() 
+13)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)--------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false +15)--------------------RepartitionExec: partitioning=Hash([o_custkey@0], 4), input_partitions=4 +16)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_custkey], file_type=csv, has_header=false +17)--AggregateExec: mode=Final, gby=[], aggr=[avg(customer.c_acctbal)] +18)----CoalescePartitionsExec +19)------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] +20)--------FilterExec: c_acctbal@5 > Some(0),15,2 AND substr(c_phone@4, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]) +21)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +22)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment, c_rev], file_type=csv, has_header=false From 09f167ac32239ff789464beefa73c0600a35fa4d Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Mon, 30 Mar 2026 14:43:08 -0400 Subject: [PATCH 12/28] Implement statistics --- datafusion/physical-plan/src/scalar_subquery.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/scalar_subquery.rs b/datafusion/physical-plan/src/scalar_subquery.rs index 3db81669ffd27..ad0cbc4dff1f0 100644 --- 
a/datafusion/physical-plan/src/scalar_subquery.rs +++ b/datafusion/physical-plan/src/scalar_subquery.rs @@ -29,7 +29,7 @@ use std::fmt; use std::sync::Arc; use datafusion_common::tree_node::TreeNodeRecursion; -use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; +use datafusion_common::{Result, ScalarValue, Statistics, exec_err, internal_err}; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ScalarSubqueryResults; use datafusion_physical_expr::PhysicalExpr; @@ -237,6 +237,10 @@ impl ExecutionPlan for ScalarSubqueryExec { self.true_for_input_only() } + fn partition_statistics(&self, partition: Option) -> Result> { + self.input.partition_statistics(partition) + } + fn cardinality_effect(&self) -> CardinalityEffect { CardinalityEffect::Equal } From 54a9f7995748831bb0d68eafc4bd07bff86163bc Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Mon, 30 Mar 2026 15:16:32 -0400 Subject: [PATCH 13/28] Tweak comments --- datafusion/core/src/physical_planner.rs | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 14734a774a165..2c09fbaae0b22 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -416,12 +416,11 @@ impl DefaultPhysicalPlanner { /// [`create_physical_expr`] can convert `Expr::ScalarSubquery` into /// [`ScalarSubqueryExpr`] nodes that read from the slots. /// - /// The resulting physical plan is wrapped in a [`ScalarSubqueryExec`] - /// that owns and executes those subquery plans before any data flows - /// through the main plan. If a subquery itself contains nested - /// uncorrelated subqueries, the recursive call produces its own - /// [`ScalarSubqueryExec`] inside the subquery plan — each level - /// manages only its own subqueries. 
+ /// The resulting physical plan is wrapped in a [`ScalarSubqueryExec`] node + /// that executes those subquery plans before any data flows through the + /// main plan. If a subquery itself contains nested uncorrelated subqueries, + /// the recursive call produces its own [`ScalarSubqueryExec`] inside the + /// subquery plan — each level manages only its own subqueries. /// /// Returns a [`BoxFuture`] rather than using `async fn` because of /// this recursion. @@ -440,15 +439,14 @@ impl DefaultPhysicalPlanner { .plan_scalar_subqueries(&all_sq_refs, session_state) .await?; - // Create the shared results container and register it (along with - // the index map) in ExecutionProps so that `create_physical_expr` - // can resolve `Expr::ScalarSubquery` into `ScalarSubqueryExpr` - // nodes. We clone the SessionState so these are available - // throughout physical planning without mutating the caller's state. + // Create the shared results container and register it in + // ExecutionProps so that `create_physical_expr` can resolve + // `Expr::ScalarSubquery` into `ScalarSubqueryExpr` nodes. We clone + // the SessionState so these are available throughout physical + // planning without mutating the caller's state. // // Ideally, the subquery state would live in a dedicated planning - // context rather than on ExecutionProps (which is meant for - // session-level configuration). It's here because + // context rather than on ExecutionProps. It's here because // `create_physical_expr` only receives `&ExecutionProps`, and // changing that signature would be a breaking public API change. 
let results: Arc>> = From 2c256e789fec2a5d5c5f6fcaf7bdd40c0e86c551 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Mon, 30 Mar 2026 17:21:47 -0400 Subject: [PATCH 14/28] Ensure projection pushdown works inside uncorrelated subqueries --- .../optimizer/src/common_subexpr_eliminate.rs | 15 +++++++++++++-- datafusion/optimizer/src/eliminate_cross_join.rs | 14 +++++++++++++- .../optimizer/src/optimize_projections/mod.rs | 15 +++++++++++++++ .../optimizer/tests/optimizer_integration.rs | 6 +++--- datafusion/physical-expr/src/scalar_subquery.rs | 5 ++--- datafusion/physical-plan/src/scalar_subquery.rs | 15 +++++++-------- datafusion/sqllogictest/test_files/subquery.slt | 12 ++++++++++++ 7 files changed, 65 insertions(+), 17 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 88dba57d75b1d..1e574ad871454 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -590,8 +590,19 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) => { // This rule handles recursion itself in a `ApplyOrder::TopDown` like - // manner. - plan.map_children(|c| self.rewrite(c, config))? + // manner. Process uncorrelated subqueries in expressions + // (e.g., Expr::ScalarSubquery), then direct children. + // Correlated subqueries are skipped to avoid interfering + // with decorrelation in subsequent optimizer passes. + plan.map_subqueries(|c| match &c { + LogicalPlan::Subquery(sq) if sq.outer_ref_columns.is_empty() => { + self.rewrite(c, config) + } + _ => Ok(Transformed::no(c)), + })? + .transform_sibling(|plan| { + plan.map_children(|c| self.rewrite(c, config)) + })? 
} }; diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 3cb0516a6d296..f5ca2e4779bad 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -212,7 +212,19 @@ fn rewrite_children( plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result<Transformed<LogicalPlan>> { - let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?; + // Process uncorrelated subqueries in expressions, then direct + // children. Correlated subqueries are skipped to avoid interfering + // with decorrelation in subsequent optimizer passes. + let transformed_plan = plan + .map_subqueries(|input| match &input { + LogicalPlan::Subquery(sq) if sq.outer_ref_columns.is_empty() => { + optimizer.rewrite(input, config) + } + _ => Ok(Transformed::no(input)), + })? + .transform_sibling(|plan| { + plan.map_children(|input| optimizer.rewrite(input, config)) + })?; // recompute schema if the plan was transformed if transformed_plan.transformed { diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 93df300bb50b4..56c7df1aef366 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -458,6 +458,21 @@ fn optimize_projections( ) })?; + // Also optimize uncorrelated subquery plans embedded in expressions + // (e.g., Expr::ScalarSubquery). map_children only visits direct plan + // inputs, so subqueries must be handled separately. We skip correlated + // subqueries because modifying their plan structure can interfere with + // decorrelation in subsequent optimizer passes. 
+ let transformed_plan = transformed_plan.transform_data(|plan| { + plan.map_subqueries(|subquery_plan| match &subquery_plan { + LogicalPlan::Subquery(sq) if sq.outer_ref_columns.is_empty() => { + let indices = RequiredIndices::new_for_all_exprs(&subquery_plan); + optimize_projections(subquery_plan, config, indices) + } + _ => Ok(Transformed::no(subquery_plan)), + }) + })?; + // If any of the children are transformed, we need to potentially update the plan's schema if transformed_plan.transformed { transformed_plan.map_data(|plan| plan.recompute_schema()) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index ef97d603acb6f..391e78e8d3b0c 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -130,10 +130,10 @@ fn subquery_filter_with_cast() -> Result<()> { @r#" Filter: CAST(test.col_int32 AS Float64) > () Subquery: - Projection: avg(test.col_int32) - Aggregate: groupBy=[[]], aggr=[[avg(CAST(test.col_int32 AS Float64))]] + Aggregate: groupBy=[[]], aggr=[[avg(CAST(test.col_int32 AS Float64))]] + Projection: test.col_int32 Filter: CAST(test.col_utf8 AS Date32) >= Date32("2002-05-08") AND CAST(test.col_utf8 AS Date32) <= Date32("2002-05-13") - TableScan: test + TableScan: test projection=[col_int32, col_utf8] TableScan: test projection=[col_int32] "# ); diff --git a/datafusion/physical-expr/src/scalar_subquery.rs b/datafusion/physical-expr/src/scalar_subquery.rs index bc6a8ebfbd5cf..86548fac882c7 100644 --- a/datafusion/physical-expr/src/scalar_subquery.rs +++ b/datafusion/physical-expr/src/scalar_subquery.rs @@ -39,7 +39,7 @@ use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// shared results container at the appropriate index. /// /// If the same subquery appears multiple times in a query, there will be -/// multiple `ScalarSubqueryExec` with the same result index. 
+/// multiple `ScalarSubqueryExpr` with the same result index. #[derive(Debug)] pub struct ScalarSubqueryExpr { data_type: DataType, @@ -85,8 +85,7 @@ impl fmt::Display for ScalarSubqueryExpr { } // Two ScalarSubqueryExprs are the "same" if they share the same results -// container and have the same index. This follows the DynamicFilterPhysicalExpr -// precedent of identity-based equality. +// container and have the same index. impl Hash for ScalarSubqueryExpr { fn hash<H: Hasher>(&self, state: &mut H) { Arc::as_ptr(&self.results).hash(state); diff --git a/datafusion/physical-plan/src/scalar_subquery.rs b/datafusion/physical-plan/src/scalar_subquery.rs index ad0cbc4dff1f0..980348a91a24a 100644 --- a/datafusion/physical-plan/src/scalar_subquery.rs +++ b/datafusion/physical-plan/src/scalar_subquery.rs @@ -59,14 +59,13 @@ pub struct ScalarSubqueryLink { /// Manages execution of uncorrelated scalar subqueries for a single plan /// level. /// -/// This node has an asymmetric set of children: the first child is the -/// **main input plan**, whose batches are passed through unchanged. The -/// remaining children are **subquery plans**, each of which must produce -/// exactly zero or one row. Before any batches from the main input are -/// yielded, all subquery plans are executed and their scalar results are -/// stored in a shared results container ([`ScalarSubqueryResults`]). -/// [`ScalarSubqueryExpr`] nodes embedded in the main input's expressions -/// read from this container by index. +/// The first child node is the **main input plan**, whose batches are passed +/// through unchanged. The remaining children are **subquery plans**, each of +/// which must produce exactly zero or one row. Before any batches from the main +/// input are yielded, all subquery plans are executed and their scalar results +/// are stored in a shared results container ([`ScalarSubqueryResults`]). 
+/// [`ScalarSubqueryExpr`] nodes embedded in the main input's expressions read +/// from this container by index. /// /// All subqueries are evaluated eagerly when the first output partition is /// requested, before any rows from the main input are produced. diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 5054ce85d1a76..2916dfef07202 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -746,6 +746,18 @@ logical_plan 09)--------TableScan: t2 10)--EmptyRelation: rows=1 +# Verify projection pushdown works inside uncorrelated scalar subqueries: +# the TableScan inside the subquery should only read t2_id, not all columns. +query TT +explain select t1_id from t1 where t1_id > (select max(t2_id) from t2) +---- +logical_plan +01)Filter: t1.t1_id > () +02)--Subquery: +03)----Aggregate: groupBy=[[]], aggr=[[max(t2.t2_id)]] +04)------TableScan: t2 projection=[t2_id] +05)--TableScan: t1 projection=[t1_id] + statement ok set datafusion.explain.logical_plan_only = false; From 99d9bcffebee6c58f97ada080d80dd48d1687d1f Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Mon, 30 Mar 2026 19:53:44 -0400 Subject: [PATCH 15/28] Update expected plans --- .../test_files/tpch/plans/q11.slt.part | 51 ++++++++++--------- .../test_files/tpch/plans/q15.slt.part | 25 +++++---- .../test_files/tpch/plans/q22.slt.part | 29 +++++++---- 3 files changed, 59 insertions(+), 46 deletions(-) diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part index 8f1fe62d64201..0c5b6d76dc1e1 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part @@ -53,22 +53,25 @@ logical_plan 04)------Subquery: 05)--------Projection: CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 
15)) 06)----------Aggregate: groupBy=[[]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] -07)------------Inner Join: supplier.s_nationkey = nation.n_nationkey -08)--------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey -09)----------------TableScan: partsupp -10)----------------TableScan: supplier -11)--------------Filter: nation.n_name = Utf8View("GERMANY") -12)----------------TableScan: nation, partial_filters=[nation.n_name = Utf8View("GERMANY")] -13)------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] -14)--------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost -15)----------Inner Join: supplier.s_nationkey = nation.n_nationkey -16)------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey -17)--------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey -18)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] -19)----------------TableScan: supplier projection=[s_suppkey, s_nationkey] -20)------------Projection: nation.n_nationkey -21)--------------Filter: nation.n_name = Utf8View("GERMANY") -22)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] +07)------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost +08)--------------Inner Join: supplier.s_nationkey = nation.n_nationkey +09)----------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey +10)------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey +11)--------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] +12)--------------------TableScan: supplier projection=[s_suppkey, s_nationkey] +13)----------------Projection: nation.n_nationkey +14)------------------Filter: nation.n_name = 
Utf8View("GERMANY") +15)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] +16)------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] +17)--------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost +18)----------Inner Join: supplier.s_nationkey = nation.n_nationkey +19)------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey +20)--------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey +21)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] +22)----------------TableScan: supplier projection=[s_suppkey, s_nationkey] +23)------------Projection: nation.n_nationkey +24)--------------Filter: nation.n_name = Utf8View("GERMANY") +25)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] physical_plan 01)ScalarSubqueryExec: subqueries=1 02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true @@ -94,14 +97,14 @@ physical_plan 22)----AggregateExec: mode=Final, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] 23)------CoalescePartitionsExec 24)--------AggregateExec: mode=Partial, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] -25)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@9, n_nationkey@0)] -26)------------RepartitionExec: partitioning=Hash([s_nationkey@9], 4), input_partitions=4 -27)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)] -28)----------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4 -29)------------------DataSourceExec: file_groups={4 groups: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment, ps_rev], file_type=csv, has_header=false +25)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_availqty@0, ps_supplycost@1] +26)------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4 +27)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@0, s_suppkey@0)], projection=[ps_availqty@1, ps_supplycost@2, s_nationkey@4] +28)----------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 +29)------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_suppkey, ps_availqty, ps_supplycost], file_type=csv, has_header=false 30)----------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 -31)------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment, s_rev], file_type=csv, has_header=false +31)------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, 
projection=[s_suppkey, s_nationkey], file_type=csv, has_header=false 32)------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -33)--------------FilterExec: n_name@1 = GERMANY +33)--------------FilterExec: n_name@1 = GERMANY, projection=[n_nationkey@0] 34)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -35)------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name, n_regionkey, n_comment, n_rev], file_type=csv, has_header=false +35)------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part index 5bdca882218b0..ecf75c4712a99 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part @@ -58,15 +58,14 @@ logical_plan 06)--------Projection: lineitem.l_suppkey AS supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue 07)----------Filter: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) = () 08)------------Subquery: -09)--------------Projection: max(revenue0.total_revenue) -10)----------------Aggregate: groupBy=[[]], aggr=[[max(revenue0.total_revenue)]] -11)------------------SubqueryAlias: revenue0 -12)--------------------Projection: lineitem.l_suppkey AS supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue -13)----------------------Projection: lineitem.l_suppkey, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) -14)------------------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - 
lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] -15)--------------------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") -16)----------------------------TableScan: lineitem, partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] -17)------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +09)--------------Aggregate: groupBy=[[]], aggr=[[max(revenue0.total_revenue)]] +10)----------------SubqueryAlias: revenue0 +11)------------------Projection: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue +12)--------------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +13)----------------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount +14)------------------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") +15)--------------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] +16)------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] 18)--------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount 19)----------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") 20)------------------TableScan: lineitem projection=[l_suppkey, 
l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] @@ -88,9 +87,9 @@ physical_plan 15)--AggregateExec: mode=Final, gby=[], aggr=[max(revenue0.total_revenue)] 16)----CoalescePartitionsExec 17)------AggregateExec: mode=Partial, gby=[], aggr=[max(revenue0.total_revenue)] -18)--------ProjectionExec: expr=[l_suppkey@0 as supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] +18)--------ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] 19)----------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 20)------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 -21)--------------AggregateExec: mode=Partial, gby=[l_suppkey@2 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -22)----------------FilterExec: l_shipdate@10 >= 1996-01-01 AND l_shipdate@10 < 1996-04-01 -23)------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment, l_rev], file_type=csv, has_header=false +21)--------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +22)----------------FilterExec: l_shipdate@3 >= 
1996-01-01 AND l_shipdate@3 < 1996-04-01, projection=[l_suppkey@0, l_extendedprice@1, l_discount@2] +23)------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part index 306deec725fe3..9784423f72a8d 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part @@ -64,20 +64,25 @@ logical_plan 06)----------LeftAnti Join: customer.c_custkey = __correlated_sq_1.o_custkey 07)------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) AND CAST(customer.c_acctbal AS Decimal128(19, 6)) > () 08)--------------Subquery: -09)----------------Projection: avg(customer.c_acctbal) -10)------------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] +09)----------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] +10)------------------Projection: customer.c_acctbal 11)--------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) -12)----------------------TableScan: customer, partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN 
([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] -13)--------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]), CAST(customer.c_acctbal AS Decimal128(19, 6)) > ()] +12)----------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] +13)--------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]), CAST(customer.c_acctbal AS Decimal128(19, 6)) > (), CAST(customer.c_acctbal AS Decimal128(19, 6)) > ()] 14)----------------Subquery: 15)------------------Projection: avg(customer.c_acctbal) 16)--------------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] 17)----------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) 18)------------------------TableScan: customer, partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] -19)------------SubqueryAlias: __correlated_sq_1 -20)--------------TableScan: orders projection=[o_custkey] +19)----------------Subquery: +20)------------------Aggregate: groupBy=[[]], 
aggr=[[avg(customer.c_acctbal)]] +21)--------------------Projection: customer.c_acctbal +22)----------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) +23)------------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] +24)------------SubqueryAlias: __correlated_sq_1 +25)--------------TableScan: orders projection=[o_custkey] physical_plan -01)ScalarSubqueryExec: subqueries=1 +01)ScalarSubqueryExec: subqueries=2 02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true 03)----SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST] 04)------SortExec: expr=[cntrycode@0 ASC NULLS LAST], preserve_partitioning=[true] @@ -96,6 +101,12 @@ physical_plan 17)--AggregateExec: mode=Final, gby=[], aggr=[avg(customer.c_acctbal)] 18)----CoalescePartitionsExec 19)------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] -20)--------FilterExec: c_acctbal@5 > Some(0),15,2 AND substr(c_phone@4, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]) +20)--------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]), projection=[c_acctbal@1] 21)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -22)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment, c_rev], file_type=csv, has_header=false +22)------------DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false +23)--AggregateExec: mode=Final, gby=[], aggr=[avg(customer.c_acctbal)] +24)----CoalescePartitionsExec +25)------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] +26)--------FilterExec: c_acctbal@5 > Some(0),15,2 AND substr(c_phone@4, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]) +27)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +28)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment, c_rev], file_type=csv, has_header=false From 9a11d62f2468d5ededd7c070f9d4c107382ccc27 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Mon, 30 Mar 2026 20:36:57 -0400 Subject: [PATCH 16/28] Fix overlooked cases for projection pushdown --- .../optimizer/src/optimize_projections/mod.rs | 51 +++++++----- .../sqllogictest/test_files/subquery.slt | 82 +++++++++++++------ .../test_files/tpch/plans/q22.slt.part | 25 ++---- 3 files changed, 97 insertions(+), 61 deletions(-) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 56c7df1aef366..d6adee95e10e7 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -136,9 +136,11 @@ fn optimize_projections( // their parents' required indices. match plan { LogicalPlan::Projection(proj) => { - return merge_consecutive_projections(proj)?.transform_data(|proj| { - rewrite_projection_given_requirements(proj, config, &indices) - }); + return merge_consecutive_projections(proj)? + .transform_data(|proj| { + rewrite_projection_given_requirements(proj, config, &indices) + })? 
+ .transform_data(|plan| optimize_subqueries(plan, config)); } LogicalPlan::Aggregate(aggregate) => { // Split parent requirements to GROUP BY and aggregate sections: @@ -210,7 +212,8 @@ fn optimize_projections( new_aggr_expr, ) .map(LogicalPlan::Aggregate) - }); + })? + .transform_data(|plan| optimize_subqueries(plan, config)); } LogicalPlan::Window(window) => { let input_schema = Arc::clone(window.input.schema()); @@ -250,7 +253,8 @@ fn optimize_projections( .map(LogicalPlan::Window) .map(Transformed::yes) } - }); + })? + .transform_data(|plan| optimize_subqueries(plan, config)); } LogicalPlan::TableScan(table_scan) => { let TableScan { @@ -271,7 +275,8 @@ fn optimize_projections( let new_scan = TableScan::try_new(table_name, source, Some(projection), filters, fetch)?; - return Ok(Transformed::yes(LogicalPlan::TableScan(new_scan))); + return Transformed::yes(LogicalPlan::TableScan(new_scan)) + .transform_data(|plan| optimize_subqueries(plan, config)); } // Other node types are handled below _ => {} @@ -458,20 +463,8 @@ fn optimize_projections( ) })?; - // Also optimize uncorrelated subquery plans embedded in expressions - // (e.g., Expr::ScalarSubquery). map_children only visits direct plan - // inputs, so subqueries must be handled separately. We skip correlated - // subqueries because modifying their plan structure can interfere with - // decorrelation in subsequent optimizer passes. 
- let transformed_plan = transformed_plan.transform_data(|plan| { - plan.map_subqueries(|subquery_plan| match &subquery_plan { - LogicalPlan::Subquery(sq) if sq.outer_ref_columns.is_empty() => { - let indices = RequiredIndices::new_for_all_exprs(&subquery_plan); - optimize_projections(subquery_plan, config, indices) - } - _ => Ok(Transformed::no(subquery_plan)), - }) - })?; + let transformed_plan = + transformed_plan.transform_data(|plan| optimize_subqueries(plan, config))?; // If any of the children are transformed, we need to potentially update the plan's schema if transformed_plan.transformed { @@ -483,6 +476,24 @@ fn optimize_projections( /// Merges consecutive projections. /// +/// Optimizes uncorrelated subquery plans embedded in expressions of the given +/// plan node (e.g., `Expr::ScalarSubquery`). `map_children` only visits direct +/// plan inputs, so subqueries must be handled separately. Correlated subqueries +/// are skipped to avoid interfering with decorrelation in subsequent optimizer +/// passes. +fn optimize_subqueries( + plan: LogicalPlan, + config: &dyn OptimizerConfig, +) -> Result<Transformed<LogicalPlan>> { + plan.map_subqueries(|subquery_plan| match &subquery_plan { + LogicalPlan::Subquery(sq) if sq.outer_ref_columns.is_empty() => { + let indices = RequiredIndices::new_for_all_exprs(&subquery_plan); + optimize_projections(subquery_plan, config, indices) + } + _ => Ok(Transformed::no(subquery_plan)), + }) +} + /// Given a projection `proj`, this function attempts to merge it with a previous /// projection if it exists and if merging is beneficial. 
Merging is considered /// beneficial when expressions in the current projection are non-trivial and diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 2916dfef07202..ba57ed6361184 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -727,7 +727,7 @@ logical_plan 02)--Subquery: 03)----Projection: count(Int64(1)) AS count(*) 04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -05)--------TableScan: t1 +05)--------TableScan: t1 projection=[] 06)--EmptyRelation: rows=1 #simple_uncorrelated_scalar_subquery2 @@ -739,15 +739,19 @@ logical_plan 02)--Subquery: 03)----Projection: count(Int64(1)) AS count(*) 04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -05)--------TableScan: t1 +05)--------TableScan: t1 projection=[] 06)--Subquery: -07)----Projection: count(Int64(1)) -08)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -09)--------TableScan: t2 -10)--EmptyRelation: rows=1 +07)----Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +08)------TableScan: t2 projection=[] +09)--EmptyRelation: rows=1 -# Verify projection pushdown works inside uncorrelated scalar subqueries: -# the TableScan inside the subquery should only read t2_id, not all columns. +# Verify projection pushdown works inside uncorrelated scalar subqueries. +# Each test targets a different early-return path in OptimizeProjections +# to ensure subquery plans are optimized regardless of where the subquery +# expression appears. + +# Subquery in a Filter predicate: the TableScan inside the subquery should +# only read t2_id, not all columns. 
query TT explain select t1_id from t1 where t1_id > (select max(t2_id) from t2) ---- @@ -758,6 +762,41 @@ logical_plan 04)------TableScan: t2 projection=[t2_id] 05)--TableScan: t1 projection=[t1_id] +# Subquery in a Projection expression +query TT +explain select t1_id, (select max(t2_id) from t2) as max_t2 from t1 +---- +logical_plan +01)Projection: t1.t1_id, () AS max_t2 +02)--Subquery: +03)----Aggregate: groupBy=[[]], aggr=[[max(t2.t2_id)]] +04)------TableScan: t2 projection=[t2_id] +05)--TableScan: t1 projection=[t1_id] + +# Subquery in an Aggregate expression +query TT +explain select sum(t1_int + (select min(t2_int) from t2)) as s from t1 +---- +logical_plan +01)Projection: sum(t1.t1_int + min(t2.t2_int)) AS s +02)--Aggregate: groupBy=[[]], aggr=[[sum(CAST(t1.t1_int + () AS Int64))]] +03)----Subquery: +04)------Aggregate: groupBy=[[]], aggr=[[min(t2.t2_int)]] +05)--------TableScan: t2 projection=[t2_int] +06)----TableScan: t1 projection=[t1_int] + +# Subquery in a Window expression +query TT +explain select t1_id, sum(t1_int + (select min(t2_int) from t2)) over () as win from t1 +---- +logical_plan +01)Projection: t1.t1_id, sum(t1.t1_int + min(t2.t2_int)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS win +02)--WindowAggr: windowExpr=[[sum(CAST(t1.t1_int + () AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +03)----Subquery: +04)------Aggregate: groupBy=[[]], aggr=[[min(t2.t2_int)]] +05)--------TableScan: t2 projection=[t2_int] +06)----TableScan: t1 projection=[t1_id, t1_int] + statement ok set datafusion.explain.logical_plan_only = false; @@ -769,12 +808,11 @@ logical_plan 02)--Subquery: 03)----Projection: count(Int64(1)) AS count(*) 04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -05)--------TableScan: t1 +05)--------TableScan: t1 projection=[] 06)--Subquery: -07)----Projection: count(Int64(1)) -08)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -09)--------TableScan: t2 -10)--EmptyRelation: rows=1 
+07)----Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +08)------TableScan: t2 projection=[] +09)--EmptyRelation: rows=1 physical_plan 01)ScalarSubqueryExec: subqueries=2 02)--ProjectionExec: expr=[scalar_subquery() as b, scalar_subquery() as count(Int64(1))] @@ -1816,12 +1854,11 @@ logical_plan 02)--Subquery: 03)----Projection: max(sq_values.v) + () 04)------Subquery: -05)--------Projection: min(sq_values.v) -06)----------Aggregate: groupBy=[[]], aggr=[[min(sq_values.v)]] -07)------------TableScan: sq_values -08)------Aggregate: groupBy=[[]], aggr=[[max(sq_values.v)]] -09)--------TableScan: sq_values -10)--EmptyRelation: rows=1 +05)--------Aggregate: groupBy=[[]], aggr=[[min(sq_values.v)]] +06)----------TableScan: sq_values projection=[v] +07)------Aggregate: groupBy=[[]], aggr=[[max(sq_values.v)]] +08)--------TableScan: sq_values projection=[v] +09)--EmptyRelation: rows=1 physical_plan 01)ScalarSubqueryExec: subqueries=1 02)--ProjectionExec: expr=[scalar_subquery() as max(sq_values.v) + min(sq_values.v)] @@ -1857,10 +1894,9 @@ logical_plan 01)Projection: __common_expr_1 + __common_expr_1 AS doubled 02)--Projection: () AS __common_expr_1 03)----Subquery: -04)------Projection: max(sq_values.v) -05)--------Aggregate: groupBy=[[]], aggr=[[max(sq_values.v)]] -06)----------TableScan: sq_values -07)----EmptyRelation: rows=1 +04)------Aggregate: groupBy=[[]], aggr=[[max(sq_values.v)]] +05)--------TableScan: sq_values projection=[v] +06)----EmptyRelation: rows=1 physical_plan 01)ScalarSubqueryExec: subqueries=1 02)--ProjectionExec: expr=[__common_expr_1@0 + __common_expr_1@0 as doubled] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part index 9784423f72a8d..fb63f9ede2b81 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part @@ -68,21 +68,16 @@ logical_plan 10)------------------Projection: customer.c_acctbal 
11)--------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) 12)----------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] -13)--------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]), CAST(customer.c_acctbal AS Decimal128(19, 6)) > (), CAST(customer.c_acctbal AS Decimal128(19, 6)) > ()] +13)--------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]), CAST(customer.c_acctbal AS Decimal128(19, 6)) > ()] 14)----------------Subquery: -15)------------------Projection: avg(customer.c_acctbal) -16)--------------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] +15)------------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] +16)--------------------Projection: customer.c_acctbal 17)----------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) -18)------------------------TableScan: customer, partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), 
Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] -19)----------------Subquery: -20)------------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] -21)--------------------Projection: customer.c_acctbal -22)----------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) -23)------------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] -24)------------SubqueryAlias: __correlated_sq_1 -25)--------------TableScan: orders projection=[o_custkey] +18)------------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] +19)------------SubqueryAlias: __correlated_sq_1 +20)--------------TableScan: orders projection=[o_custkey] physical_plan -01)ScalarSubqueryExec: subqueries=2 +01)ScalarSubqueryExec: subqueries=1 02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true 03)----SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST] 04)------SortExec: expr=[cntrycode@0 ASC NULLS LAST], preserve_partitioning=[true] @@ -104,9 +99,3 @@ physical_plan 20)--------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]), projection=[c_acctbal@1] 21)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 22)------------DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false -23)--AggregateExec: mode=Final, gby=[], aggr=[avg(customer.c_acctbal)] -24)----CoalescePartitionsExec -25)------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] -26)--------FilterExec: c_acctbal@5 > Some(0),15,2 AND substr(c_phone@4, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]) -27)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -28)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment, c_rev], file_type=csv, has_header=false From 5aef67edfb1db0461a9641cc912f403b0cb92640 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Tue, 31 Mar 2026 10:29:34 -0400 Subject: [PATCH 17/28] Fix line numbers in expected EXPLAIN --- datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part index ecf75c4712a99..3e1aca318b5c7 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part @@ -66,9 +66,9 @@ logical_plan 14)------------------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") 15)--------------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] 16)------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - 
lineitem.l_discount)]] -18)--------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount -19)----------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") -20)------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] +17)--------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount +18)----------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") +19)------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] physical_plan 01)ScalarSubqueryExec: subqueries=1 02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true From 3d0b99fec88829618060aab4f9e35609827674e5 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Tue, 31 Mar 2026 11:19:14 -0400 Subject: [PATCH 18/28] Evaluate subqueries in parallel --- .../physical-plan/src/scalar_subquery.rs | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/datafusion/physical-plan/src/scalar_subquery.rs b/datafusion/physical-plan/src/scalar_subquery.rs index 980348a91a24a..d311a4535b7f1 100644 --- a/datafusion/physical-plan/src/scalar_subquery.rs +++ b/datafusion/physical-plan/src/scalar_subquery.rs @@ -73,9 +73,6 @@ pub struct ScalarSubqueryLink { /// TODO: Consider overlapping computation of the subqueries with evaluating the /// main query. /// -/// TODO: Subqueries are evaluated sequentially. Consider parallel evaluation in -/// the future. 
-/// /// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr #[derive(Debug)] pub struct ScalarSubqueryExec { @@ -257,11 +254,20 @@ async fn execute_subqueries( results: ScalarSubqueryResults, context: Arc, ) -> Result<()> { - for sq in &subqueries { - let value = - execute_scalar_subquery(Arc::clone(&sq.plan), Arc::clone(&context)).await?; - let _ = results[sq.index].set(value.clone()); - } + // Evaluate subqueries in parallel; wait for them all to finish evaluation + // before returning. + let futures = subqueries.iter().map(|sq| { + let plan = Arc::clone(&sq.plan); + let ctx = Arc::clone(&context); + let results = Arc::clone(&results); + let index = sq.index; + async move { + let value = execute_scalar_subquery(plan, ctx).await?; + let _ = results[index].set(value); + Ok(()) as Result<()> + } + }); + futures::future::try_join_all(futures).await?; Ok(()) } From b02abf80e54614b8c04612285016b846cb773277 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Thu, 2 Apr 2026 10:30:33 -0400 Subject: [PATCH 19/28] Don't try to use subquery filters for partition pruning --- datafusion/expr/src/expr.rs | 6 ++++ datafusion/optimizer/src/push_down_filter.rs | 14 ++++++-- .../proto/src/physical_plan/from_proto.rs | 11 +++--- .../sqllogictest/test_files/subquery.slt | 35 +++++++++++++++++++ 4 files changed, 59 insertions(+), 7 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7dba4f0579b97..c3dffe553dc70 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2059,6 +2059,12 @@ impl Expr { .expect("exists closure is infallible") } + /// Returns true if the expression contains a scalar subquery. + pub fn contains_scalar_subquery(&self) -> bool { + self.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_)))) + .expect("exists closure is infallible") + } + /// Returns true if the expression node is volatile, i.e. 
whether it can return /// different results when evaluated multiple times with the same input. /// Note: unlike [`Self::is_volatile`], this function does not consider inputs: diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index a1a636cfef9af..5f18bec82ddd1 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1144,8 +1144,16 @@ impl OptimizerRule for PushDownFilter { LogicalPlan::TableScan(scan) => { let filter_predicates = split_conjunction(&filter.predicate); - let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) = + // Filters containing scalar subqueries cannot be pushed to + // providers because the subquery result is not available + // until execution time (via ScalarSubqueryExec). + let (subquery_filters, pushdown_candidates): (Vec<&Expr>, Vec<&Expr>) = filter_predicates + .into_iter() + .partition(|pred| pred.contains_scalar_subquery()); + + let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) = + pushdown_candidates .into_iter() .partition(|pred| pred.is_volatile()); @@ -1178,11 +1186,13 @@ impl OptimizerRule for PushDownFilter { .cloned() .collect(); - // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters + // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, + // and also include volatile and subquery-containing filters let new_predicate: Vec = zip .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact) .map(|(pred, _)| pred) .chain(volatile_filters) + .chain(subquery_filters) .cloned() .collect(); diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index c01714e765e29..cb969e861ca0c 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -499,13 +499,14 @@ pub fn parse_physical_expr_with_converter( 
proto_error("Missing data_type in PhysicalScalarSubqueryExprNode") })? .try_into()?; - // Use the results container from the converter if available - // (set by ScalarSubqueryExec deserialization), otherwise fall - // back to an empty placeholder for standalone expression - // deserialization. let results = proto_converter .scalar_subquery_results() - .unwrap_or_else(|| Arc::new(vec![])); + .ok_or_else(|| { + proto_error( + "ScalarSubqueryExpr can only be deserialized as part \ + of a surrounding ScalarSubqueryExec", + ) + })?; Arc::new(ScalarSubqueryExpr::new( data_type, sq.nullable, diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index a827e71ffc582..033de42f871df 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -1984,3 +1984,38 @@ DROP TABLE sq_name_t1; statement ok DROP TABLE sq_name_t2; + +# Test: scalar subquery in filter on a partition column of a partitioned table. +# This exercises the code path where filters are pushed down to the table +# provider for partition pruning. Scalar subqueries must not be pushed to the +# provider because the subquery result is not available at partition listing +# time. 
+ +query I +COPY (VALUES(1, 'a'), (2, 'b'), (3, 'c')) +TO 'test_files/scratch/subquery/partition_pruning/part=1/file1.parquet'; +---- +3 + +query I +COPY (VALUES(4, 'd'), (5, 'e')) +TO 'test_files/scratch/subquery/partition_pruning/part=2/file1.parquet'; +---- +2 + +statement ok +CREATE EXTERNAL TABLE subquery_partitioned +STORED AS PARQUET +LOCATION 'test_files/scratch/subquery/partition_pruning/'; + +query IT +SELECT column1, column2 FROM subquery_partitioned +WHERE part = (SELECT 1) +ORDER BY column1; +---- +1 a +2 b +3 c + +statement ok +DROP TABLE subquery_partitioned; From 3971312d7a4a15dd723ea1c4eccf97612580bafa Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Thu, 2 Apr 2026 10:32:31 -0400 Subject: [PATCH 20/28] Raise an error if duplicate subquery eval is detected --- datafusion/physical-plan/src/scalar_subquery.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/scalar_subquery.rs b/datafusion/physical-plan/src/scalar_subquery.rs index d311a4535b7f1..d960ac326727b 100644 --- a/datafusion/physical-plan/src/scalar_subquery.rs +++ b/datafusion/physical-plan/src/scalar_subquery.rs @@ -263,7 +263,11 @@ async fn execute_subqueries( let index = sq.index; async move { let value = execute_scalar_subquery(plan, ctx).await?; - let _ = results[index].set(value); + if results[index].set(value).is_err() { + return internal_err!( + "ScalarSubqueryExec: result for index {index} was already populated" + ); + } Ok(()) as Result<()> } }); From 64e9f34ccf9ea065fa30974fb586508643cef20a Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Thu, 2 Apr 2026 10:35:52 -0400 Subject: [PATCH 21/28] cargo fmt --- datafusion/proto/src/physical_plan/from_proto.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index cb969e861ca0c..770e6a289800b 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ 
b/datafusion/proto/src/physical_plan/from_proto.rs @@ -499,14 +499,12 @@ pub fn parse_physical_expr_with_converter( proto_error("Missing data_type in PhysicalScalarSubqueryExprNode") })? .try_into()?; - let results = proto_converter - .scalar_subquery_results() - .ok_or_else(|| { - proto_error( - "ScalarSubqueryExpr can only be deserialized as part \ + let results = proto_converter.scalar_subquery_results().ok_or_else(|| { + proto_error( + "ScalarSubqueryExpr can only be deserialized as part \ of a surrounding ScalarSubqueryExec", - ) - })?; + ) + })?; Arc::new(ScalarSubqueryExpr::new( data_type, sq.nullable, From 26d8acb37dab46d05e9104ecf225aefa433bbb06 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Thu, 2 Apr 2026 11:18:59 -0400 Subject: [PATCH 22/28] Update expected plan --- .../sqllogictest/test_files/tpch/plans/q22.slt.part | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part index fb63f9ede2b81..3240cbfb697d5 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part @@ -68,14 +68,9 @@ logical_plan 10)------------------Projection: customer.c_acctbal 11)--------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) 12)----------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] -13)--------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN 
([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]), CAST(customer.c_acctbal AS Decimal128(19, 6)) > ()] -14)----------------Subquery: -15)------------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] -16)--------------------Projection: customer.c_acctbal -17)----------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) -18)------------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] -19)------------SubqueryAlias: __correlated_sq_1 -20)--------------TableScan: orders projection=[o_custkey] +13)--------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] +14)------------SubqueryAlias: __correlated_sq_1 +15)--------------TableScan: orders projection=[o_custkey] physical_plan 01)ScalarSubqueryExec: subqueries=1 02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true From f9c9d5d5e81e692bd143e9ff55a5ea8b799b90f7 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Fri, 3 Apr 2026 10:21:21 -0400 Subject: [PATCH 23/28] Remove unnecessary IN/EXISTS serialization code --- datafusion/optimizer/src/push_down_filter.rs | 2 +- datafusion/proto/proto/datafusion.proto | 13 - datafusion/proto/src/generated/pbjson.rs | 261 ------------------ datafusion/proto/src/generated/prost.rs | 22 +- .../proto/src/logical_plan/from_proto.rs | 34 +-- 
datafusion/proto/src/logical_plan/to_proto.rs | 28 +- 6 files changed, 9 insertions(+), 351 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 5f18bec82ddd1..3be9c97cff8c4 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1146,7 +1146,7 @@ impl OptimizerRule for PushDownFilter { // Filters containing scalar subqueries cannot be pushed to // providers because the subquery result is not available - // until execution time (via ScalarSubqueryExec). + // until execution time. let (subquery_filters, pushdown_candidates): (Vec<&Expr>, Vec<&Expr>) = filter_predicates .into_iter() diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 36f5f882a4ba3..bf4aa65937a01 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -428,8 +428,6 @@ message LogicalExprNode { // Subquery expressions ScalarSubqueryExprNode scalar_subquery_expr = 36; - InSubqueryExprNode in_subquery_expr = 37; - ExistsExprNode exists_expr = 38; } } @@ -446,17 +444,6 @@ message ScalarSubqueryExprNode { SubqueryNode subquery = 1; } -message InSubqueryExprNode { - LogicalExprNode expr = 1; - SubqueryNode subquery = 2; - bool negated = 3; -} - -message ExistsExprNode { - SubqueryNode subquery = 1; - bool negated = 2; -} - message PlaceholderNode { string id = 1; // We serialize the data type, metadata, and nullability separately to maintain diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index b32b7f9a63f94..a0715885dcfb4 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -5914,114 +5914,6 @@ impl<'de> serde::Deserialize<'de> for EmptyRelationNode { deserializer.deserialize_struct("datafusion.EmptyRelationNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ExistsExprNode { - 
#[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.subquery.is_some() { - len += 1; - } - if self.negated { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ExistsExprNode", len)?; - if let Some(v) = self.subquery.as_ref() { - struct_ser.serialize_field("subquery", v)?; - } - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ExistsExprNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "subquery", - "negated", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Subquery, - Negated, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl serde::de::Visitor<'_> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "subquery" => Ok(GeneratedField::Subquery), - "negated" => Ok(GeneratedField::Negated), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ExistsExprNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ExistsExprNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: 
serde::de::MapAccess<'de>, - { - let mut subquery__ = None; - let mut negated__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Subquery => { - if subquery__.is_some() { - return Err(serde::de::Error::duplicate_field("subquery")); - } - subquery__ = map_.next_value()?; - } - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); - } - negated__ = Some(map_.next_value()?); - } - } - } - Ok(ExistsExprNode { - subquery: subquery__, - negated: negated__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.ExistsExprNode", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for ExplainExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -9369,131 +9261,6 @@ impl<'de> serde::Deserialize<'de> for InListNode { deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for InSubqueryExprNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr.is_some() { - len += 1; - } - if self.subquery.is_some() { - len += 1; - } - if self.negated { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.InSubqueryExprNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - if let Some(v) = self.subquery.as_ref() { - struct_ser.serialize_field("subquery", v)?; - } - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for InSubqueryExprNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "expr", - "subquery", - "negated", - ]; - - #[allow(clippy::enum_variant_names)] - enum 
GeneratedField { - Expr, - Subquery, - Negated, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl serde::de::Visitor<'_> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "expr" => Ok(GeneratedField::Expr), - "subquery" => Ok(GeneratedField::Subquery), - "negated" => Ok(GeneratedField::Negated), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = InSubqueryExprNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.InSubqueryExprNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - let mut subquery__ = None; - let mut negated__ = None; - while let Some(k) = map_.next_key()? 
{ - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - GeneratedField::Subquery => { - if subquery__.is_some() { - return Err(serde::de::Error::duplicate_field("subquery")); - } - subquery__ = map_.next_value()?; - } - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); - } - negated__ = Some(map_.next_value()?); - } - } - } - Ok(InSubqueryExprNode { - expr: expr__, - subquery: subquery__, - negated: negated__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.InSubqueryExprNode", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for InsertOp { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -12429,12 +12196,6 @@ impl serde::Serialize for LogicalExprNode { logical_expr_node::ExprType::ScalarSubqueryExpr(v) => { struct_ser.serialize_field("scalarSubqueryExpr", v)?; } - logical_expr_node::ExprType::InSubqueryExpr(v) => { - struct_ser.serialize_field("inSubqueryExpr", v)?; - } - logical_expr_node::ExprType::ExistsExpr(v) => { - struct_ser.serialize_field("existsExpr", v)?; - } } } struct_ser.end() @@ -12498,10 +12259,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "unnest", "scalar_subquery_expr", "scalarSubqueryExpr", - "in_subquery_expr", - "inSubqueryExpr", - "exists_expr", - "existsExpr", ]; #[allow(clippy::enum_variant_names)] @@ -12538,8 +12295,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { Placeholder, Unnest, ScalarSubqueryExpr, - InSubqueryExpr, - ExistsExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12593,8 +12348,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "placeholder" => Ok(GeneratedField::Placeholder), "unnest" => Ok(GeneratedField::Unnest), "scalarSubqueryExpr" | "scalar_subquery_expr" => 
Ok(GeneratedField::ScalarSubqueryExpr), - "inSubqueryExpr" | "in_subquery_expr" => Ok(GeneratedField::InSubqueryExpr), - "existsExpr" | "exists_expr" => Ok(GeneratedField::ExistsExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12839,20 +12592,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { return Err(serde::de::Error::duplicate_field("scalarSubqueryExpr")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarSubqueryExpr) -; - } - GeneratedField::InSubqueryExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("inSubqueryExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::InSubqueryExpr) -; - } - GeneratedField::ExistsExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("existsExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ExistsExpr) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 6c4f898909fb1..e58ef5b0145be 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -578,7 +578,7 @@ pub struct SubqueryAliasNode { pub struct LogicalExprNode { #[prost( oneof = "logical_expr_node::ExprType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36" )] pub expr_type: ::core::option::Option, } @@ -659,10 +659,6 @@ pub mod logical_expr_node { /// Subquery expressions #[prost(message, tag = "36")] ScalarSubqueryExpr(::prost::alloc::boxed::Box), - #[prost(message, tag = "37")] - InSubqueryExpr(::prost::alloc::boxed::Box), - #[prost(message, tag = "38")] - ExistsExpr(::prost::alloc::boxed::Box), } } 
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] @@ -683,22 +679,6 @@ pub struct ScalarSubqueryExprNode { pub subquery: ::core::option::Option<::prost::alloc::boxed::Box>, } #[derive(Clone, PartialEq, ::prost::Message)] -pub struct InSubqueryExprNode { - #[prost(message, optional, boxed, tag = "1")] - pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, boxed, tag = "2")] - pub subquery: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(bool, tag = "3")] - pub negated: bool, -} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ExistsExprNode { - #[prost(message, optional, boxed, tag = "1")] - pub subquery: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(bool, tag = "2")] - pub negated: bool, -} -#[derive(Clone, PartialEq, ::prost::Message)] pub struct PlaceholderNode { #[prost(string, tag = "1")] pub id: ::prost::alloc::string::String, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 6254507509724..78ffd362c8e48 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -26,9 +26,7 @@ use datafusion_common::{ use datafusion_execution::TaskContext; use datafusion_execution::registry::FunctionRegistry; use datafusion_expr::dml::InsertOp; -use datafusion_expr::expr::{ - Alias, Exists, InSubquery, NullTreatment, Placeholder, Sort, -}; +use datafusion_expr::expr::{Alias, NullTreatment, Placeholder, Sort}; use datafusion_expr::expr::{Unnest, WildcardOptions}; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::{ @@ -669,36 +667,6 @@ pub fn parse_expr( )?; Ok(Expr::ScalarSubquery(subquery)) } - ExprType::InSubqueryExpr(sq) => { - let expr = parse_required_expr( - sq.expr.as_deref(), - ctx, - "InSubqueryExprNode.expr", - codec, - )?; - let subquery = parse_subquery( - sq.subquery - .as_deref() - .ok_or_else(|| 
Error::required("InSubqueryExprNode.subquery"))?, - ctx, - codec, - )?; - Ok(Expr::InSubquery(InSubquery::new( - Box::new(expr), - subquery, - sq.negated, - ))) - } - ExprType::ExistsExpr(sq) => { - let subquery = parse_subquery( - sq.subquery - .as_deref() - .ok_or_else(|| Error::required("ExistsExprNode.subquery"))?, - ctx, - codec, - )?; - Ok(Expr::Exists(Exists::new(subquery, sq.negated))) - } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 1a5bd7b5ae41e..9ee0d73293a39 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -25,8 +25,8 @@ use datafusion_common::{NullEquality, TableReference, UnnestOptions}; use datafusion_expr::WriteOp; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ - self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, Exists, GroupingSet, - InList, InSubquery, Like, NullTreatment, Placeholder, ScalarFunction, Unnest, + self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, + InList, Like, NullTreatment, Placeholder, ScalarFunction, Unnest, }; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::{ @@ -588,26 +588,10 @@ pub fn serialize_expr( }, ))), }, - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => protobuf::LogicalExprNode { - expr_type: Some(ExprType::InSubqueryExpr(Box::new( - protobuf::InSubqueryExprNode { - expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - subquery: Some(Box::new(serialize_subquery(subquery, codec)?)), - negated: *negated, - }, - ))), - }, - Expr::Exists(Exists { subquery, negated }) => protobuf::LogicalExprNode { - expr_type: Some(ExprType::ExistsExpr(Box::new(protobuf::ExistsExprNode { - subquery: Some(Box::new(serialize_subquery(subquery, codec)?)), - negated: *negated, - }))), - }, - Expr::OuterReferenceColumn(_, _) | Expr::SetComparison(_) => { + Expr::InSubquery(_) + | Expr::Exists(_) + | 
Expr::OuterReferenceColumn(_, _) + | Expr::SetComparison(_) => { return Err(Error::General(format!( "Proto serialization error: {expr} is not yet supported" ))); From 92e60543770498fb606c24ca894e0587039df0a5 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Fri, 3 Apr 2026 10:34:08 -0400 Subject: [PATCH 24/28] Code cleanup --- datafusion/core/src/physical_planner.rs | 10 ++++---- datafusion/expr/src/execution_props.rs | 5 ++++ .../physical-expr/src/scalar_subquery.rs | 8 +++++++ datafusion/proto/src/physical_plan/mod.rs | 24 ++++++++++++++----- .../proto/src/physical_plan/to_proto.rs | 4 ++-- 5 files changed, 39 insertions(+), 12 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 2c09fbaae0b22..94e74ecee058a 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -19,7 +19,7 @@ use std::borrow::Cow; use std::collections::{HashMap, HashSet}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use crate::datasource::file_format::file_type_to_format; use crate::datasource::listing::ListingTableUrl; @@ -79,6 +79,9 @@ use datafusion_common::{ use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::memory::MemorySourceConfig; use datafusion_expr::dml::{CopyTo, InsertOp}; +use datafusion_expr::execution_props::{ + ScalarSubqueryResults, new_scalar_subquery_results, +}; use datafusion_expr::expr::{ AggregateFunction, AggregateFunctionParams, Alias, GroupingSet, NullTreatment, WindowFunction, WindowFunctionParams, physical_name, @@ -449,8 +452,7 @@ impl DefaultPhysicalPlanner { // context rather than on ExecutionProps. It's here because // `create_physical_expr` only receives `&ExecutionProps`, and // changing that signature would be a breaking public API change. 
- let results: Arc>> = - Arc::new((0..links.len()).map(|_| OnceLock::new()).collect()); + let results = new_scalar_subquery_results(links.len()); let session_state = if links.is_empty() { Cow::Borrowed(session_state) } else { @@ -2960,7 +2962,7 @@ impl DefaultPhysicalPlanner { fn wrap_scalar_subquery_exec_if_needed( input: Arc, subqueries: Vec, - results: Arc>>, + results: ScalarSubqueryResults, ) -> Arc { if subqueries.is_empty() { input diff --git a/datafusion/expr/src/execution_props.rs b/datafusion/expr/src/execution_props.rs index 0c9cfe2e06d1a..284472107ec9a 100644 --- a/datafusion/expr/src/execution_props.rs +++ b/datafusion/expr/src/execution_props.rs @@ -30,6 +30,11 @@ use std::sync::{Arc, OnceLock}; /// read by `ScalarSubqueryExpr` instances that share this container. pub type ScalarSubqueryResults = Arc>>; +/// Creates a [`ScalarSubqueryResults`] container with `n` empty slots. +pub fn new_scalar_subquery_results(n: usize) -> ScalarSubqueryResults { + Arc::new((0..n).map(|_| OnceLock::new()).collect()) +} + /// Holds per-query execution properties and data (such as statement /// starting timestamps). /// diff --git a/datafusion/physical-expr/src/scalar_subquery.rs b/datafusion/physical-expr/src/scalar_subquery.rs index 86548fac882c7..0a39091c73f0e 100644 --- a/datafusion/physical-expr/src/scalar_subquery.rs +++ b/datafusion/physical-expr/src/scalar_subquery.rs @@ -65,6 +65,14 @@ impl ScalarSubqueryExpr { } } + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + pub fn nullable(&self) -> bool { + self.nullable + } + /// Returns the index of this subquery in the shared results container. 
pub fn index(&self) -> usize { self.index diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index a28fbb463ddc7..421e5002f1a24 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -51,7 +51,9 @@ use datafusion_datasource_parquet::source::ParquetSource; #[cfg(feature = "parquet")] use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{FunctionRegistry, TaskContext}; -use datafusion_expr::execution_props::ScalarSubqueryResults; +use datafusion_expr::execution_props::{ + ScalarSubqueryResults, new_scalar_subquery_results, +}; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use datafusion_functions_table::generate_series::{ Empty, GenSeriesArgs, GenerateSeriesTable, GenericSeriesState, TimestampValue, @@ -2274,11 +2276,7 @@ impl protobuf::PhysicalPlanNode { // proto. Making it available on the converter before deserializing // the input plan allows ScalarSubqueryExpr nodes to pick up the // shared container at construction time. - let results = Arc::new( - (0..sq.subqueries.len()) - .map(|_| std::sync::OnceLock::new()) - .collect::>(), - ); + let results = new_scalar_subquery_results(sq.subqueries.len()); let prev = proto_converter.set_scalar_subquery_results(Some(Arc::clone(&results))); let input = into_physical_plan(&sq.input, ctx, codec, proto_converter); @@ -2300,6 +2298,7 @@ impl protobuf::PhysicalPlanNode { Ok(ScalarSubqueryLink { plan, index }) }) .collect::>>()?; + Ok(Arc::new(ScalarSubqueryExec::new( input, subqueries, results, ))) @@ -3870,6 +3869,16 @@ pub trait PhysicalProtoConverterExtension { /// During `ScalarSubqueryExec` deserialization, this is called before /// deserializing the input plan so that `ScalarSubqueryExpr` nodes can /// pick up the shared results container at construction time. + /// + /// The default implementation discards the value and always returns `None`. 
+ /// This means `ScalarSubqueryExpr` deserialization will fail with an error. + /// Implementations that need scalar subquery support should override both + /// this method and `scalar_subquery_results`. + /// + /// NOTE: These methods use interior mutability (`RefCell`) to thread state + /// through `&self`, because the deserialization call stack does not have an + /// explicit context parameter. A dedicated deserialization context struct + /// would be a cleaner long-term solution. fn set_scalar_subquery_results( &self, _results: Option, ) { @@ -3878,6 +3887,9 @@ } /// Returns the current scalar subquery results container, if any. + /// + /// See [`set_scalar_subquery_results`](Self::set_scalar_subquery_results) + /// for details on the default behavior and design trade-offs. fn scalar_subquery_results(&self) -> Option { None } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index e41ea85b93dcd..a1fa57ea5bdb6 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -510,8 +510,8 @@ pub fn serialize_physical_expr_with_converter( expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarSubquery( protobuf::PhysicalScalarSubqueryExprNode { - data_type: Some((&expr.data_type(&Schema::empty())?).try_into()?), - nullable: expr.nullable(&Schema::empty())?, + data_type: Some(expr.data_type().try_into()?), + nullable: expr.nullable(), index: expr.index() as u32, }, )), From 6857966c8977c701710ac000559fc965cc42abc9 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Fri, 3 Apr 2026 10:39:08 -0400 Subject: [PATCH 25/28] Code cleanup --- datafusion/proto/src/physical_plan/mod.rs | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 421e5002f1a24..4cc7b1234380b 100644 ---
a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -2272,19 +2272,16 @@ impl protobuf::PhysicalPlanNode { codec: &dyn PhysicalExtensionCodec, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - // Create the results container upfront — we know the count from the - // proto. Making it available on the converter before deserializing - // the input plan allows ScalarSubqueryExpr nodes to pick up the - // shared container at construction time. - let results = new_scalar_subquery_results(sq.subqueries.len()); - let prev = - proto_converter.set_scalar_subquery_results(Some(Arc::clone(&results))); + // First, deserialize the main input plan. We set up the subquery results + // container first, so that ScalarSubqueryExpr nodes can reference it. + let subquery_results = new_scalar_subquery_results(sq.subqueries.len()); + let prev = proto_converter + .set_scalar_subquery_results(Some(Arc::clone(&subquery_results))); let input = into_physical_plan(&sq.input, ctx, codec, proto_converter); - // Restore previous state before propagating errors, so nested - // ScalarSubqueryExec deserialization doesn't see stale state. proto_converter.set_scalar_subquery_results(prev); - let input: Arc = input?; + let input = input?; + // Now deserialize the subquery children. 
let subqueries: Vec = sq .subqueries .iter() @@ -2300,7 +2297,9 @@ impl protobuf::PhysicalPlanNode { .collect::>>()?; Ok(Arc::new(ScalarSubqueryExec::new( - input, subqueries, results, + input, + subqueries, + subquery_results, ))) } From 6a4f524a5535b6e0d700b044cf4de5cb20037fd7 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Fri, 3 Apr 2026 10:51:34 -0400 Subject: [PATCH 26/28] Code cleanup and refactoring --- datafusion/expr/src/logical_plan/tree_node.rs | 20 ++++++++++++--- .../optimizer/src/common_subexpr_eliminate.rs | 9 +------ .../optimizer/src/eliminate_cross_join.rs | 11 ++------ .../optimizer/src/optimize_projections/mod.rs | 13 +++------- .../optimizer/src/scalar_subquery_to_join.rs | 6 ++--- .../physical-plan/src/scalar_subquery.rs | 25 +++++++++---------- 6 files changed, 38 insertions(+), 46 deletions(-) diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index a1285510da569..2b1015d4802b5 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -808,7 +808,7 @@ impl LogicalPlan { transform_down_up_with_subqueries_impl(self, &mut f_down, &mut f_up) } - /// Similarly to [`Self::apply`], calls `f` on this node and its inputs + /// Similarly to [`Self::apply`], calls `f` on this node and its inputs, /// including subqueries that may appear in expressions such as `IN (SELECT /// ...)`. pub fn apply_subqueries Result>( @@ -821,9 +821,7 @@ impl LogicalPlan { | Expr::InSubquery(InSubquery { subquery, .. }) | Expr::SetComparison(SetComparison { subquery, .. 
}) | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the collector sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) + // Wrap in LogicalPlan::Subquery to match f's signature f(&LogicalPlan::Subquery(subquery.clone())) } _ => Ok(TreeNodeRecursion::Continue), @@ -888,4 +886,18 @@ impl LogicalPlan { }) }) } + + /// Similar to [`Self::map_subqueries`], but only applies `f` to + /// uncorrelated subqueries (those with no outer reference columns). + pub fn map_uncorrelated_subqueries Result>>( + self, + mut f: F, + ) -> Result> { + self.map_subqueries(|subquery_plan| match &subquery_plan { + LogicalPlan::Subquery(sq) if sq.outer_ref_columns.is_empty() => { + f(subquery_plan) + } + _ => Ok(Transformed::no(subquery_plan)), + }) + } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 116aafdd30075..b495191f36c75 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -588,14 +588,7 @@ impl OptimizerRule for CommonSubexprEliminate { // This rule handles recursion itself in a `ApplyOrder::TopDown` like // manner. Process uncorrelated subqueries in expressions // (e.g., Expr::ScalarSubquery), then direct children. - // Correlated subqueries are skipped to avoid interfering - // with decorrelation in subsequent optimizer passes. - plan.map_subqueries(|c| match &c { - LogicalPlan::Subquery(sq) if sq.outer_ref_columns.is_empty() => { - self.rewrite(c, config) - } - _ => Ok(Transformed::no(c)), - })? + plan.map_uncorrelated_subqueries(|c| self.rewrite(c, config))? .transform_sibling(|plan| { plan.map_children(|c| self.rewrite(c, config)) })? 
diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index f5ca2e4779bad..8306d4b54c256 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -212,16 +212,9 @@ fn rewrite_children( plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - // Process uncorrelated subqueries in expressions, then direct - // children. Correlated subqueries are skipped to avoid interfering - // with decorrelation in subsequent optimizer passes. + // Process uncorrelated subqueries in expressions, then direct children. let transformed_plan = plan - .map_subqueries(|input| match &input { - LogicalPlan::Subquery(sq) if sq.outer_ref_columns.is_empty() => { - optimizer.rewrite(input, config) - } - _ => Ok(Transformed::no(input)), - })? + .map_uncorrelated_subqueries(|input| optimizer.rewrite(input, config))? .transform_sibling(|plan| { plan.map_children(|input| optimizer.rewrite(input, config)) })?; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index d6adee95e10e7..6bac3abeb8376 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -478,19 +478,14 @@ fn optimize_projections( /// /// Optimizes uncorrelated subquery plans embedded in expressions of the given /// plan node (e.g., `Expr::ScalarSubquery`). `map_children` only visits direct -/// plan inputs, so subqueries must be handled separately. Correlated subqueries -/// are skipped to avoid interfering with decorrelation in subsequent optimizer -/// passes. +/// plan inputs, so subqueries must be handled separately. 
fn optimize_subqueries( plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - plan.map_subqueries(|subquery_plan| match &subquery_plan { - LogicalPlan::Subquery(sq) if sq.outer_ref_columns.is_empty() => { - let indices = RequiredIndices::new_for_all_exprs(&subquery_plan); - optimize_projections(subquery_plan, config, indices) - } - _ => Ok(Transformed::no(subquery_plan)), + plan.map_uncorrelated_subqueries(|subquery_plan| { + let indices = RequiredIndices::new_for_all_exprs(&subquery_plan); + optimize_projections(subquery_plan, config, indices) }) } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index ad7f6311429b9..66b34f24dedfb 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -90,7 +90,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { LogicalPlan::Filter(filter) => { // Optimization: skip the rest of the rule and its copies if // there are no scalar subqueries - if !contains_scalar_subquery(&filter.predicate) { + if !contains_correlated_scalar_subquery(&filter.predicate) { return Ok(Transformed::no(LogicalPlan::Filter(filter))); } @@ -144,7 +144,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { LogicalPlan::Projection(projection) => { // Optimization: skip the rest of the rule and its copies if there // are no correlated scalar subqueries - if !projection.expr.iter().any(contains_scalar_subquery) { + if !projection.expr.iter().any(contains_correlated_scalar_subquery) { return Ok(Transformed::no(LogicalPlan::Projection(projection))); } @@ -230,7 +230,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { /// Returns true if the expression contains a correlated scalar subquery, false /// otherwise. Uncorrelated scalar subqueries are handled by the physical /// planner via `ScalarSubqueryExec` and do not need to be converted to joins. 
-fn contains_scalar_subquery(expr: &Expr) -> bool { +fn contains_correlated_scalar_subquery(expr: &Expr) -> bool { expr.exists(|expr| { Ok(matches!(expr, Expr::ScalarSubquery(sq) if !sq.outer_ref_columns.is_empty())) }) diff --git a/datafusion/physical-plan/src/scalar_subquery.rs b/datafusion/physical-plan/src/scalar_subquery.rs index d960ac326727b..f0cb8b0cce7ef 100644 --- a/datafusion/physical-plan/src/scalar_subquery.rs +++ b/datafusion/physical-plan/src/scalar_subquery.rs @@ -284,6 +284,7 @@ async fn execute_scalar_subquery( ) -> Result { let schema = plan.schema(); if schema.fields().len() != 1 { + // Should be enforced by the physical planner. return internal_err!( "Scalar subquery must return exactly one column, got {}", schema.fields().len() @@ -291,25 +292,23 @@ async fn execute_scalar_subquery( } let mut stream = crate::execute_stream(plan, context)?; + let mut result: Option = None; - let mut total_rows = 0usize; - let mut result_value: Option = None; while let Some(batch) = stream.next().await.transpose()? 
{ - total_rows += batch.num_rows(); - if total_rows > 1 { - return exec_err!( - "Scalar subquery returned more than one row (got at least {total_rows})" - ); + if batch.num_rows() == 0 { + continue; } - if batch.num_rows() == 1 { - result_value = Some(ScalarValue::try_from_array(batch.column(0), 0)?); + if result.is_some() || batch.num_rows() > 1 { + return exec_err!("Scalar subquery returned more than one row"); } + result = Some(ScalarValue::try_from_array(batch.column(0), 0)?); } - // 0 rows → NULL of the appropriate type - Ok(result_value.unwrap_or_else(|| { - ScalarValue::try_from(schema.field(0).data_type()).unwrap_or(ScalarValue::Null) - })) + // 0 rows → typed NULL per SQL semantics + match result { + Some(v) => Ok(v), + None => ScalarValue::try_from(schema.field(0).data_type()), + } } #[cfg(test)] From 670139c43ce01432daf48f555f4fdc313eb8dcd5 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Fri, 3 Apr 2026 11:32:28 -0400 Subject: [PATCH 27/28] Updates for plan API changes --- datafusion/optimizer/src/common_subexpr_eliminate.rs | 6 +++--- datafusion/optimizer/src/scalar_subquery_to_join.rs | 6 +++++- datafusion/physical-plan/src/scalar_subquery.rs | 9 --------- datafusion/proto/src/logical_plan/to_proto.rs | 4 ++-- datafusion/proto/tests/cases/roundtrip_physical_plan.rs | 2 -- 5 files changed, 10 insertions(+), 17 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index b495191f36c75..c02ba602475f3 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -589,9 +589,9 @@ impl OptimizerRule for CommonSubexprEliminate { // manner. Process uncorrelated subqueries in expressions // (e.g., Expr::ScalarSubquery), then direct children. plan.map_uncorrelated_subqueries(|c| self.rewrite(c, config))? - .transform_sibling(|plan| { - plan.map_children(|c| self.rewrite(c, config)) - })? 
+ .transform_sibling(|plan| { + plan.map_children(|c| self.rewrite(c, config)) + })? } }; diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 66b34f24dedfb..9ed6afe481baa 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -144,7 +144,11 @@ impl OptimizerRule for ScalarSubqueryToJoin { LogicalPlan::Projection(projection) => { // Optimization: skip the rest of the rule and its copies if there // are no correlated scalar subqueries - if !projection.expr.iter().any(contains_correlated_scalar_subquery) { + if !projection + .expr + .iter() + .any(contains_correlated_scalar_subquery) + { return Ok(Transformed::no(LogicalPlan::Projection(projection))); } diff --git a/datafusion/physical-plan/src/scalar_subquery.rs b/datafusion/physical-plan/src/scalar_subquery.rs index f0cb8b0cce7ef..8be4c5189cff9 100644 --- a/datafusion/physical-plan/src/scalar_subquery.rs +++ b/datafusion/physical-plan/src/scalar_subquery.rs @@ -24,7 +24,6 @@ //! //! 
[`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr -use std::any::Any; use std::fmt; use std::sync::Arc; @@ -147,10 +146,6 @@ impl ExecutionPlan for ScalarSubqueryExec { "ScalarSubqueryExec" } - fn as_any(&self) -> &dyn Any { - self - } - fn properties(&self) -> &Arc { &self.cache } @@ -355,10 +350,6 @@ mod tests { "CountingExec" } - fn as_any(&self) -> &dyn Any { - self - } - fn properties(&self) -> &Arc { self.inner.properties() } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 9ee0d73293a39..bd5c4b585c24f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -25,8 +25,8 @@ use datafusion_common::{NullEquality, TableReference, UnnestOptions}; use datafusion_expr::WriteOp; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ - self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, - InList, Like, NullTreatment, Placeholder, ScalarFunction, Unnest, + self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, + Like, NullTreatment, Placeholder, ScalarFunction, Unnest, }; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::{ diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 675fed9ed6fbd..b24b19a873420 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -3250,7 +3250,6 @@ fn roundtrip_scalar_subquery_exec() -> Result<()> { // Verify the deserialized ScalarSubqueryExec's results container is // shared with the ScalarSubqueryExpr in the input plan. 
let sq_exec = deserialized - .as_any() .downcast_ref::() .expect("expected ScalarSubqueryExec"); let exec_results = sq_exec.results(); @@ -3259,7 +3258,6 @@ fn roundtrip_scalar_subquery_exec() -> Result<()> { // points to the same results container. let filter_exec = sq_exec .input() - .as_any() .downcast_ref::() .expect("expected FilterExec"); let binary_expr = filter_exec From 1239e3a28ce309ca50d68f055fd5474b1958077d Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Fri, 3 Apr 2026 11:45:10 -0400 Subject: [PATCH 28/28] Fix doc build --- datafusion/core/src/physical_planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index fb7e08835a1ea..f94078bdbece2 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -414,7 +414,7 @@ impl DefaultPhysicalPlanner { /// /// Uncorrelated scalar subqueries in the plan's own expressions are /// collected, planned as separate physical plans, and each assigned a - /// shared [`OnceLock`] slot that will hold its result at execution time. + /// shared `OnceLock` slot that will hold its result at execution time. /// These slots are registered in [`ExecutionProps`] so that /// [`create_physical_expr`] can convert `Expr::ScalarSubquery` into /// [`ScalarSubqueryExpr`] nodes that read from the slots.