diff --git a/.github/workflows/large_files.yml b/.github/workflows/large_files.yml index 12a5599189216..63cf18027c80f 100644 --- a/.github/workflows/large_files.yml +++ b/.github/workflows/large_files.yml @@ -34,9 +34,9 @@ jobs: fetch-depth: 0 - name: Check size of new Git objects env: - # 1 MB ought to be enough for anybody. + # 2 MB ought to be enough for anybody. # TODO in case we may want to consciously commit a bigger file to the repo without using Git LFS we may disable the check e.g. with a label - MAX_FILE_SIZE_BYTES: 1048576 + MAX_FILE_SIZE_BYTES: 2097152 shell: bash run: | if [ "${{ github.event_name }}" = "merge_group" ]; then diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index c1230f7d5daa6..f193acebc0b29 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -297,3 +297,7 @@ required-features = ["parquet"] [[bench]] harness = false name = "reset_plan_states" + +[[bench]] +harness = false +name = "scalar_subquery_sql" diff --git a/datafusion/core/benches/scalar_subquery_sql.rs b/datafusion/core/benches/scalar_subquery_sql.rs new file mode 100644 index 0000000000000..3642bdd1a0150 --- /dev/null +++ b/datafusion/core/benches/scalar_subquery_sql.rs @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for uncorrelated scalar subquery evaluation. +//! +//! Measures the overhead of subquery execution machinery by using simple +//! arithmetic and comparison operators that don't have specialized scalar +//! fast paths, keeping the comparison between the old (join-based) and new +//! (ScalarSubqueryExec-based) approaches apples-to-apples. + +use arrow::array::Int64Array; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion::datasource::MemTable; +use datafusion::error::Result; +use datafusion::prelude::SessionContext; +use std::hint::black_box; +use std::sync::Arc; +use tokio::runtime::Runtime; + +fn query(ctx: &SessionContext, rt: &Runtime, sql: &str) { + let df = rt.block_on(ctx.sql(sql)).unwrap(); + black_box(rt.block_on(df.collect()).unwrap()); +} + +fn create_context(num_rows: usize) -> Result { + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)])); + + let batch_size = 4096; + let batches = (0..num_rows / batch_size) + .map(|i| { + let values: Vec = + ((i * batch_size) as i64..((i + 1) * batch_size) as i64).collect(); + RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(values))]) + .unwrap() + }) + .collect::>(); + + // Small lookup table for the subquery to read from. 
+ let sq_schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)])); + let sq_batch = RecordBatch::try_new( + sq_schema.clone(), + vec![Arc::new(Int64Array::from(vec![10, 20, 30]))], + )?; + + let ctx = SessionContext::new(); + ctx.register_table( + "main_t", + Arc::new(MemTable::try_new(schema, vec![batches])?), + )?; + ctx.register_table( + "lookup", + Arc::new(MemTable::try_new(sq_schema, vec![vec![sq_batch]])?), + )?; + + Ok(ctx) +} + +fn criterion_benchmark(c: &mut Criterion) { + let num_rows = 1_048_576; // 2^20 + let rt = Runtime::new().unwrap(); + + // Scalar subquery in a filter (WHERE clause). + c.bench_function("scalar_subquery_filter", |b| { + let ctx = create_context(num_rows).unwrap(); + b.iter(|| { + query( + &ctx, + &rt, + "SELECT x FROM main_t WHERE x > (SELECT max(v) FROM lookup)", + ) + }) + }); + + // Scalar subquery in a projection (SELECT expression). + c.bench_function("scalar_subquery_projection", |b| { + let ctx = create_context(num_rows).unwrap(); + b.iter(|| { + query( + &ctx, + &rt, + "SELECT x + (SELECT max(v) FROM lookup) AS y FROM main_t", + ) + }) + }); + + // Two scalar subqueries in one query. 
+ c.bench_function("scalar_subquery_two_subqueries", |b| { + let ctx = create_context(num_rows).unwrap(); + b.iter(|| { + query( + &ctx, + &rt, + "SELECT x FROM main_t \ + WHERE x > (SELECT min(v) FROM lookup) \ + AND x < (SELECT max(v) FROM lookup) + 1000000", + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 82f3c4d80c9ec..f94078bdbece2 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -63,6 +63,7 @@ use arrow::datatypes::Schema; use arrow_schema::Field; use datafusion_catalog::ScanArgs; use datafusion_common::Column; +use datafusion_common::HashMap as DFHashMap; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::format::ExplainAnalyzeCategories; use datafusion_common::tree_node::{ @@ -78,11 +79,15 @@ use datafusion_common::{ use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::memory::MemorySourceConfig; use datafusion_expr::dml::{CopyTo, InsertOp}; +use datafusion_expr::execution_props::{ + ScalarSubqueryResults, new_scalar_subquery_results, +}; use datafusion_expr::expr::{ AggregateFunction, AggregateFunctionParams, Alias, GroupingSet, NullTreatment, WindowFunction, WindowFunctionParams, physical_name, }; use datafusion_expr::expr_rewriter::unnormalize_cols; +use datafusion_expr::logical_plan::Subquery; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::utils::{expr_to_columns, split_conjunction}; use datafusion_expr::{ @@ -101,6 +106,7 @@ use datafusion_physical_plan::execution_plan::InvariantLevel; use datafusion_physical_plan::joins::PiecewiseMergeJoinExec; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::recursive_query::RecursiveQueryExec; +use datafusion_physical_plan::scalar_subquery::{ScalarSubqueryExec, 
ScalarSubqueryLink}; use datafusion_physical_plan::unnest::ListUnnest; use async_trait::async_trait; @@ -379,8 +385,95 @@ impl DefaultPhysicalPlanner { Ok(()) } - /// Create a physical plan from a logical plan - async fn create_initial_plan( + /// Collect uncorrelated scalar subqueries. We don't descend into nested + /// subqueries here: each call to `create_initial_plan` handles subqueries + /// at its level and then recurses in order to handle nested subqueries. + fn collect_scalar_subqueries(plan: &LogicalPlan) -> Vec { + let mut subqueries = Vec::new(); + let mut seen = HashSet::new(); + plan.apply(|node| { + for expr in node.expressions() { + expr.apply(|e| { + if let Expr::ScalarSubquery(sq) = e + && sq.outer_ref_columns.is_empty() + && seen.insert(sq.clone()) + { + subqueries.push(sq.clone()); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("infallible"); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("infallible"); + subqueries + } + + /// Create a physical plan from a logical plan. + /// + /// Uncorrelated scalar subqueries in the plan's own expressions are + /// collected, planned as separate physical plans, and each assigned a + /// shared `OnceLock` slot that will hold its result at execution time. + /// These slots are registered in [`ExecutionProps`] so that + /// [`create_physical_expr`] can convert `Expr::ScalarSubquery` into + /// [`ScalarSubqueryExpr`] nodes that read from the slots. + /// + /// The resulting physical plan is wrapped in a [`ScalarSubqueryExec`] node + /// that executes those subquery plans before any data flows through the + /// main plan. If a subquery itself contains nested uncorrelated subqueries, + /// the recursive call produces its own [`ScalarSubqueryExec`] inside the + /// subquery plan — each level manages only its own subqueries. + /// + /// Returns a [`BoxFuture`] rather than using `async fn` because of + /// this recursion. 
+ /// + /// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr + /// [`BoxFuture`]: futures::future::BoxFuture + fn create_initial_plan<'a>( + &'a self, + logical_plan: &'a LogicalPlan, + session_state: &'a SessionState, + ) -> futures::future::BoxFuture<'a, Result>> { + Box::pin(async move { + let all_subqueries = Self::collect_scalar_subqueries(logical_plan); + let all_sq_refs: Vec<&_> = all_subqueries.iter().collect(); + let (links, index_map) = self + .plan_scalar_subqueries(&all_sq_refs, session_state) + .await?; + + // Create the shared results container and register it in + // ExecutionProps so that `create_physical_expr` can resolve + // `Expr::ScalarSubquery` into `ScalarSubqueryExpr` nodes. We clone + // the SessionState so these are available throughout physical + // planning without mutating the caller's state. + // + // Ideally, the subquery state would live in a dedicated planning + // context rather than on ExecutionProps. It's here because + // `create_physical_expr` only receives `&ExecutionProps`, and + // changing that signature would be a breaking public API change. + let results = new_scalar_subquery_results(links.len()); + let session_state = if links.is_empty() { + Cow::Borrowed(session_state) + } else { + let mut owned = session_state.clone(); + owned.execution_props_mut().subquery_indexes = index_map; + owned.execution_props_mut().subquery_results = Arc::clone(&results); + Cow::Owned(owned) + }; + + let plan = self + .create_initial_plan_inner(logical_plan, &session_state) + .await?; + Ok(Self::wrap_scalar_subquery_exec_if_needed( + plan, links, results, + )) + }) + } + + /// Inner physical planning that converts a logical plan tree into an + /// execution plan tree without collecting scalar subqueries. 
+ async fn create_initial_plan_inner( &self, logical_plan: &LogicalPlan, session_state: &SessionState, @@ -545,6 +638,7 @@ impl DefaultPhysicalPlanner { session_state: &SessionState, children: ChildrenContainer, ) -> Result> { + let execution_props = session_state.execution_props(); let exec_node: Arc = match node { // Leaves (no children) LogicalPlan::TableScan(scan) => { @@ -601,7 +695,7 @@ impl DefaultPhysicalPlanner { .map(|row| { row.iter() .map(|expr| { - self.create_physical_expr(expr, schema, session_state) + create_physical_expr(expr, schema, execution_props) }) .collect::>>>() }) @@ -868,13 +962,7 @@ impl DefaultPhysicalPlanner { let logical_schema = node.schema(); let window_expr = window_expr .iter() - .map(|e| { - create_window_expr( - e, - logical_schema, - session_state.execution_props(), - ) - }) + .map(|e| create_window_expr(e, logical_schema, execution_props)) .collect::>>()?; let can_repartition = session_state.config().target_partitions() > 1 @@ -979,7 +1067,7 @@ impl DefaultPhysicalPlanner { group_expr, logical_input_schema, &physical_input_schema, - session_state, + execution_props, )?; let agg_filter = aggr_expr @@ -989,7 +1077,7 @@ impl DefaultPhysicalPlanner { e, logical_input_schema, &physical_input_schema, - session_state.execution_props(), + execution_props, ) }) .collect::>>()?; @@ -1083,8 +1171,8 @@ impl DefaultPhysicalPlanner { )?) } LogicalPlan::Projection(Projection { input, expr, .. 
}) => self - .create_project_physical_exec( - session_state, + .create_project_physical_exec_with_props( + execution_props, children.one()?, input, expr, @@ -1094,9 +1182,8 @@ impl DefaultPhysicalPlanner { }) => { let physical_input = children.one()?; let input_dfschema = input.schema(); - let runtime_expr = - self.create_physical_expr(predicate, input_dfschema, session_state)?; + create_physical_expr(predicate, input_dfschema, execution_props)?; let input_schema = input.schema(); let filter = match self.try_plan_async_exprs( @@ -1144,7 +1231,9 @@ impl DefaultPhysicalPlanner { .options() .optimizer .default_filter_selectivity; - Arc::new(filter.with_default_selectivity(selectivity)?) + let filter_exec: Arc = + Arc::new(filter.with_default_selectivity(selectivity)?); + filter_exec } LogicalPlan::Repartition(Repartition { input, @@ -1160,11 +1249,7 @@ impl DefaultPhysicalPlanner { let runtime_expr = expr .iter() .map(|e| { - self.create_physical_expr( - e, - input_dfschema, - session_state, - ) + create_physical_expr(e, input_dfschema, execution_props) }) .collect::>>()?; Partitioning::Hash(runtime_expr, *n) @@ -1185,11 +1270,8 @@ impl DefaultPhysicalPlanner { }) => { let physical_input = children.one()?; let input_dfschema = input.as_ref().schema(); - let sort_exprs = create_physical_sort_exprs( - expr, - input_dfschema, - session_state.execution_props(), - )?; + let sort_exprs = + create_physical_sort_exprs(expr, input_dfschema, execution_props)?; let Some(ordering) = LexOrdering::new(sort_exprs) else { return internal_err!( "SortExec requires at least one sort expression" @@ -1316,8 +1398,8 @@ impl DefaultPhysicalPlanner { ( true, LogicalPlan::Projection(Projection { input, expr, .. }), - ) => self.create_project_physical_exec( - session_state, + ) => self.create_project_physical_exec_with_props( + execution_props, physical_left, input, expr, @@ -1329,8 +1411,8 @@ impl DefaultPhysicalPlanner { ( true, LogicalPlan::Projection(Projection { input, expr, .. 
}), - ) => self.create_project_physical_exec( - session_state, + ) => self.create_project_physical_exec_with_props( + execution_props, physical_right, input, expr, @@ -1395,7 +1477,6 @@ impl DefaultPhysicalPlanner { // All equi-join keys are columns now, create physical join plan let left_df_schema = left.schema(); let right_df_schema = right.schema(); - let execution_props = session_state.execution_props(); let join_on = keys .iter() .map(|(l, r)| { @@ -1501,7 +1582,7 @@ impl DefaultPhysicalPlanner { let filter_expr = create_physical_expr( expr, &filter_df_schema, - session_state.execution_props(), + execution_props, )?; let column_indices = join_utils::JoinFilter::build_column_indices( left_field_indices, @@ -1621,12 +1702,12 @@ impl DefaultPhysicalPlanner { let on_left = create_physical_expr( lhs_logical, left_df_schema, - session_state.execution_props(), + execution_props, )?; let on_right = create_physical_expr( rhs_logical, right_df_schema, - session_state.execution_props(), + execution_props, )?; Arc::new(PiecewiseMergeJoinExec::try_new( @@ -1696,7 +1777,12 @@ impl DefaultPhysicalPlanner { // If plan was mutated previously then need to create the ExecutionPlan // for the new Projection that was applied on top. if let Some((input, expr)) = new_project { - self.create_project_physical_exec(session_state, join, input, expr)? + self.create_project_physical_exec_with_props( + execution_props, + join, + input, + expr, + )? 
} else { join } @@ -1790,7 +1876,7 @@ impl DefaultPhysicalPlanner { group_expr: &[Expr], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { if group_expr.len() == 1 { match &group_expr[0] { @@ -1799,25 +1885,25 @@ impl DefaultPhysicalPlanner { grouping_sets, input_dfschema, input_schema, - session_state, + execution_props, ) } Expr::GroupingSet(GroupingSet::Cube(exprs)) => create_cube_physical_expr( exprs, input_dfschema, input_schema, - session_state, + execution_props, ), Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { create_rollup_physical_expr( exprs, input_dfschema, input_schema, - session_state, + execution_props, ) } expr => Ok(PhysicalGroupBy::new_single(vec![tuple_err(( - self.create_physical_expr(expr, input_dfschema, session_state), + create_physical_expr(expr, input_dfschema, execution_props), physical_name(expr), ))?])), } @@ -1831,7 +1917,7 @@ impl DefaultPhysicalPlanner { .iter() .map(|e| { tuple_err(( - self.create_physical_expr(e, input_dfschema, session_state), + create_physical_expr(e, input_dfschema, execution_props), physical_name(e), )) }) @@ -1855,7 +1941,7 @@ fn merge_grouping_set_physical_expr( grouping_sets: &[Vec], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { let num_groups = grouping_sets.len(); let mut all_exprs: Vec = vec![]; @@ -1869,14 +1955,14 @@ fn merge_grouping_set_physical_expr( grouping_set_expr.push(get_physical_expr_pair( expr, input_dfschema, - session_state, + execution_props, )?); null_exprs.push(get_null_physical_expr_pair( expr, input_dfschema, input_schema, - session_state, + execution_props, )?); } } @@ -1906,7 +1992,7 @@ fn create_cube_physical_expr( exprs: &[Expr], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { let num_of_exprs = exprs.len(); let num_groups = 
num_of_exprs * num_of_exprs; @@ -1921,10 +2007,14 @@ fn create_cube_physical_expr( expr, input_dfschema, input_schema, - session_state, + execution_props, )?); - all_exprs.push(get_physical_expr_pair(expr, input_dfschema, session_state)?) + all_exprs.push(get_physical_expr_pair( + expr, + input_dfschema, + execution_props, + )?) } let mut groups: Vec> = Vec::with_capacity(num_groups); @@ -1948,7 +2038,7 @@ fn create_rollup_physical_expr( exprs: &[Expr], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { let num_of_exprs = exprs.len(); @@ -1964,10 +2054,14 @@ fn create_rollup_physical_expr( expr, input_dfschema, input_schema, - session_state, + execution_props, )?); - all_exprs.push(get_physical_expr_pair(expr, input_dfschema, session_state)?) + all_exprs.push(get_physical_expr_pair( + expr, + input_dfschema, + execution_props, + )?) } for total in 0..=num_of_exprs { @@ -1992,10 +2086,9 @@ fn get_null_physical_expr_pair( expr: &Expr, input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result<(Arc, String)> { - let physical_expr = - create_physical_expr(expr, input_dfschema, session_state.execution_props())?; + let physical_expr = create_physical_expr(expr, input_dfschema, execution_props)?; let physical_name = physical_name(&expr.clone())?; let data_type = physical_expr.data_type(input_schema)?; @@ -2064,10 +2157,9 @@ fn qualify_join_schema_sides( fn get_physical_expr_pair( expr: &Expr, input_dfschema: &DFSchema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result<(Arc, String)> { - let physical_expr = - create_physical_expr(expr, input_dfschema, session_state.execution_props())?; + let physical_expr = create_physical_expr(expr, input_dfschema, execution_props)?; let physical_name = physical_name(expr)?; Ok((physical_expr, physical_name)) } @@ -2846,9 +2938,49 @@ impl 
DefaultPhysicalPlanner { Ok(mem_exec) } - fn create_project_physical_exec( + /// Build physical plans for scalar subqueries and assign each an ordinal + /// index. Returns the links (plan + index) and a map from logical + /// `Subquery` to its index. + async fn plan_scalar_subqueries( &self, + subqueries: &[&Subquery], session_state: &SessionState, + ) -> Result<(Vec, DFHashMap)> { + let mut links = Vec::with_capacity(subqueries.len()); + let mut index_map = DFHashMap::with_capacity(subqueries.len()); + for &sq in subqueries { + // Callers deduplicate, but guard against accidental double-planning. + if index_map.contains_key(sq) { + continue; + } + let physical_plan = self + .create_initial_plan(&sq.subquery, session_state) + .await?; + let index = links.len(); + links.push(ScalarSubqueryLink { + plan: physical_plan, + index, + }); + index_map.insert(sq.clone(), index); + } + Ok((links, index_map)) + } + + fn wrap_scalar_subquery_exec_if_needed( + input: Arc, + subqueries: Vec, + results: ScalarSubqueryResults, + ) -> Arc { + if subqueries.is_empty() { + input + } else { + Arc::new(ScalarSubqueryExec::new(input, subqueries, results)) + } + } + + fn create_project_physical_exec_with_props( + &self, + execution_props: &ExecutionProps, input_exec: Arc, input: &Arc, expr: &[Expr], @@ -2887,7 +3019,7 @@ impl DefaultPhysicalPlanner { }; let physical_expr = - self.create_physical_expr(e, input_logical_schema, session_state); + create_physical_expr(e, input_logical_schema, execution_props); tuple_err((physical_expr, physical_name)) }) @@ -3132,7 +3264,7 @@ mod tests { use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; use datafusion_common::{ - DFSchemaRef, TableReference, ToDFSchema as _, assert_contains, + DFSchemaRef, TableReference, ToDFSchema as _, assert_batches_eq, assert_contains, }; use datafusion_execution::TaskContext; use datafusion_execution::runtime_env::RuntimeEnv; @@ -3166,6 +3298,16 @@ mod tests { .await } + async fn plan_sql(query: 
&str) -> Result> { + let ctx = SessionContext::new(); + ctx.sql(query).await?.create_physical_plan().await + } + + async fn collect_sql(query: &str) -> Result> { + let ctx = SessionContext::new(); + ctx.sql(query).await?.collect().await + } + #[tokio::test] async fn test_all_operators() -> Result<()> { let logical_plan = test_csv_scan() @@ -3189,6 +3331,99 @@ mod tests { Ok(()) } + #[tokio::test] + async fn scalar_subquery_in_sort_expr_plans() -> Result<()> { + let plan = plan_sql( + "SELECT x \ + FROM (VALUES (2), (1)) AS t(x) \ + ORDER BY x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y))", + ) + .await?; + + assert_contains!(format!("{plan:?}"), "ScalarSubqueryExec"); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_sort_expr_executes() -> Result<()> { + let batches = collect_sql( + "SELECT x \ + FROM (VALUES (2), (1), (3)) AS t(x) \ + ORDER BY x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y)) DESC", + ) + .await?; + + assert_batches_eq!( + &[ + "+---+", "| x |", "+---+", "| 3 |", "| 2 |", "| 1 |", "+---+", + ], + &batches + ); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_aggregate_arg_plans() -> Result<()> { + let plan = plan_sql( + "SELECT sum(x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y))) \ + FROM (VALUES (2), (1)) AS t(x)", + ) + .await?; + + assert_contains!(format!("{plan:?}"), "ScalarSubqueryExec"); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_aggregate_arg_executes() -> Result<()> { + let batches = collect_sql( + "SELECT sum(x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y))) AS s \ + FROM (VALUES (2), (1)) AS t(x)", + ) + .await?; + + assert_batches_eq!( + &["+----+", "| s |", "+----+", "| 43 |", "+----+",], + &batches + ); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_join_on_plans() -> Result<()> { + let plan = plan_sql( + "SELECT l.x, r.y \ + FROM (VALUES (1), (2)) AS l(x) \ + JOIN (VALUES (11), (12)) AS r(y) \ + ON l.x + (SELECT 10) = r.y", + ) + .await?; + + let formatted 
= format!("{plan:?}"); + assert_contains!(&formatted, "ScalarSubqueryExec"); + assert!( + formatted.contains("HashJoinExec") + || formatted.contains("SortMergeJoinExec") + || formatted.contains("NestedLoopJoinExec") + ); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_projection_and_filter_plans() -> Result<()> { + let plan = plan_sql( + "SELECT x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y)) \ + FROM (VALUES (2), (1)) AS t(x) \ + WHERE x > (SELECT min(y) FROM (VALUES (0), (1)) AS v(y))", + ) + .await?; + + let formatted = format!("{plan:?}"); + // All uncorrelated scalar subqueries are hoisted to a single root node. + assert_eq!(formatted.matches("ScalarSubqueryExec").count(), 1); + Ok(()) + } + #[tokio::test] async fn test_create_cube_expr() -> Result<()> { let logical_plan = test_csv_scan().await?.build()?; @@ -3206,7 +3441,7 @@ mod tests { &exprs, logical_input_schema, physical_input_schema, - &session_state, + session_state.execution_props(), ); insta::assert_debug_snapshot!(cube, @r#" @@ -3337,7 +3572,7 @@ mod tests { &exprs, logical_input_schema, physical_input_schema, - &session_state, + session_state.execution_props(), ); insta::assert_debug_snapshot!(rollup, @r#" diff --git a/datafusion/expr/src/execution_props.rs b/datafusion/expr/src/execution_props.rs index 3bf6978eb60ee..284472107ec9a 100644 --- a/datafusion/expr/src/execution_props.rs +++ b/datafusion/expr/src/execution_props.rs @@ -18,9 +18,22 @@ use crate::var_provider::{VarProvider, VarType}; use chrono::{DateTime, Utc}; use datafusion_common::HashMap; +use datafusion_common::ScalarValue; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; + +/// Shared results container for uncorrelated scalar subqueries. +/// +/// Each entry corresponds to one scalar subquery, identified by its index. 
+/// The `OnceLock` is populated at execution time by `ScalarSubqueryExec` and +/// read by `ScalarSubqueryExpr` instances that share this container. +pub type ScalarSubqueryResults = Arc>>; + +/// Creates a [`ScalarSubqueryResults`] container with `n` empty slots. +pub fn new_scalar_subquery_results(n: usize) -> ScalarSubqueryResults { + Arc::new((0..n).map(|_| OnceLock::new()).collect()) +} /// Holds per-query execution properties and data (such as statement /// starting timestamps). @@ -42,6 +55,12 @@ pub struct ExecutionProps { pub config_options: Option>, /// Providers for scalar variables pub var_providers: Option>>, + /// Maps each logical `Subquery` to its index in `subquery_results`. + /// Populated by the physical planner before calling `create_physical_expr`. + pub subquery_indexes: HashMap, + /// Shared results container for uncorrelated scalar subquery values. + /// Populated at execution time by `ScalarSubqueryExec`. + pub subquery_results: ScalarSubqueryResults, } impl Default for ExecutionProps { @@ -58,6 +77,8 @@ impl ExecutionProps { alias_generator: Arc::new(AliasGenerator::new()), config_options: None, var_providers: None, + subquery_indexes: HashMap::new(), + subquery_results: Arc::new(vec![]), } } @@ -126,7 +147,7 @@ mod test { fn debug() { let props = ExecutionProps::new(); assert_eq!( - "ExecutionProps { query_execution_start_time: None, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None }", + "ExecutionProps { query_execution_start_time: None, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None, subquery_indexes: {}, subquery_results: [] }", format!("{props:?}") ); } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7dba4f0579b97..c3dffe553dc70 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2059,6 +2059,12 @@ impl Expr { .expect("exists closure is infallible") } + /// Returns true if the expression contains 
a scalar subquery. + pub fn contains_scalar_subquery(&self) -> bool { + self.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_)))) + .expect("exists closure is infallible") + } + /// Returns true if the expression node is volatile, i.e. whether it can return /// different results when evaluated multiple times with the same input. /// Note: unlike [`Self::is_volatile`], this function does not consider inputs: diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index a1285510da569..2b1015d4802b5 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -808,7 +808,7 @@ impl LogicalPlan { transform_down_up_with_subqueries_impl(self, &mut f_down, &mut f_up) } - /// Similarly to [`Self::apply`], calls `f` on this node and its inputs + /// Similarly to [`Self::apply`], calls `f` on this node and its inputs, /// including subqueries that may appear in expressions such as `IN (SELECT /// ...)`. pub fn apply_subqueries Result>( @@ -821,9 +821,7 @@ impl LogicalPlan { | Expr::InSubquery(InSubquery { subquery, .. }) | Expr::SetComparison(SetComparison { subquery, .. }) | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the collector sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) + // Wrap in LogicalPlan::Subquery to match f's signature f(&LogicalPlan::Subquery(subquery.clone())) } _ => Ok(TreeNodeRecursion::Continue), @@ -888,4 +886,18 @@ impl LogicalPlan { }) }) } + + /// Similar to [`Self::map_subqueries`], but only applies `f` to + /// uncorrelated subqueries (those with no outer reference columns). 
+ pub fn map_uncorrelated_subqueries Result>>( + self, + mut f: F, + ) -> Result> { + self.map_subqueries(|subquery_plan| match &subquery_plan { + LogicalPlan::Subquery(sq) if sq.outer_ref_columns.is_empty() => { + f(subquery_plan) + } + _ => Ok(Transformed::no(subquery_plan)), + }) + } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 4213c23ccc897..c02ba602475f3 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -586,8 +586,12 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) => { // This rule handles recursion itself in a `ApplyOrder::TopDown` like - // manner. - plan.map_children(|c| self.rewrite(c, config))? + // manner. Process uncorrelated subqueries in expressions + // (e.g., Expr::ScalarSubquery), then direct children. + plan.map_uncorrelated_subqueries(|c| self.rewrite(c, config))? + .transform_sibling(|plan| { + plan.map_children(|c| self.rewrite(c, config)) + })? } }; diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 3cb0516a6d296..8306d4b54c256 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -212,7 +212,12 @@ fn rewrite_children( plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?; + // Process uncorrelated subqueries in expressions, then direct children. + let transformed_plan = plan + .map_uncorrelated_subqueries(|input| optimizer.rewrite(input, config))? 
+ .transform_sibling(|plan| { + plan.map_children(|input| optimizer.rewrite(input, config)) + })?; // recompute schema if the plan was transformed if transformed_plan.transformed { diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 93df300bb50b4..6bac3abeb8376 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -136,9 +136,11 @@ fn optimize_projections( // their parents' required indices. match plan { LogicalPlan::Projection(proj) => { - return merge_consecutive_projections(proj)?.transform_data(|proj| { - rewrite_projection_given_requirements(proj, config, &indices) - }); + return merge_consecutive_projections(proj)? + .transform_data(|proj| { + rewrite_projection_given_requirements(proj, config, &indices) + })? + .transform_data(|plan| optimize_subqueries(plan, config)); } LogicalPlan::Aggregate(aggregate) => { // Split parent requirements to GROUP BY and aggregate sections: @@ -210,7 +212,8 @@ fn optimize_projections( new_aggr_expr, ) .map(LogicalPlan::Aggregate) - }); + })? + .transform_data(|plan| optimize_subqueries(plan, config)); } LogicalPlan::Window(window) => { let input_schema = Arc::clone(window.input.schema()); @@ -250,7 +253,8 @@ fn optimize_projections( .map(LogicalPlan::Window) .map(Transformed::yes) } - }); + })? 
+ .transform_data(|plan| optimize_subqueries(plan, config)); } LogicalPlan::TableScan(table_scan) => { let TableScan { @@ -271,7 +275,8 @@ fn optimize_projections( let new_scan = TableScan::try_new(table_name, source, Some(projection), filters, fetch)?; - return Ok(Transformed::yes(LogicalPlan::TableScan(new_scan))); + return Transformed::yes(LogicalPlan::TableScan(new_scan)) + .transform_data(|plan| optimize_subqueries(plan, config)); } // Other node types are handled below _ => {} @@ -458,6 +463,9 @@ fn optimize_projections( ) })?; + let transformed_plan = + transformed_plan.transform_data(|plan| optimize_subqueries(plan, config))?; + // If any of the children are transformed, we need to potentially update the plan's schema if transformed_plan.transformed { transformed_plan.map_data(|plan| plan.recompute_schema()) @@ -468,6 +476,19 @@ fn optimize_projections( /// Merges consecutive projections. /// +/// Optimizes uncorrelated subquery plans embedded in expressions of the given +/// plan node (e.g., `Expr::ScalarSubquery`). `map_children` only visits direct +/// plan inputs, so subqueries must be handled separately. +fn optimize_subqueries( + plan: LogicalPlan, + config: &dyn OptimizerConfig, +) -> Result> { + plan.map_uncorrelated_subqueries(|subquery_plan| { + let indices = RequiredIndices::new_for_all_exprs(&subquery_plan); + optimize_projections(subquery_plan, config, indices) + }) +} + /// Given a projection `proj`, this function attempts to merge it with a previous /// projection if it exists and if merging is beneficial. 
Merging is considered /// beneficial when expressions in the current projection are non-trivial and diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index a1a636cfef9af..3be9c97cff8c4 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1144,8 +1144,16 @@ impl OptimizerRule for PushDownFilter { LogicalPlan::TableScan(scan) => { let filter_predicates = split_conjunction(&filter.predicate); - let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) = + // Filters containing scalar subqueries cannot be pushed to + // providers because the subquery result is not available + // until execution time. + let (subquery_filters, pushdown_candidates): (Vec<&Expr>, Vec<&Expr>) = filter_predicates + .into_iter() + .partition(|pred| pred.contains_scalar_subquery()); + + let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) = + pushdown_candidates .into_iter() .partition(|pred| pred.is_volatile()); @@ -1178,11 +1186,13 @@ impl OptimizerRule for PushDownFilter { .cloned() .collect(); - // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters + // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, + // and also include volatile and subquery-containing filters let new_predicate: Vec = zip .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact) .map(|(pred, _)| pred) .chain(volatile_filters) + .chain(subquery_filters) .cloned() .collect(); diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 590b00098bd46..9ed6afe481baa 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! 
[`ScalarSubqueryToJoin`] rewriting scalar subquery filters to `JOIN`s +//! [`ScalarSubqueryToJoin`] rewriting correlated scalar subquery filters to `JOIN`s use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; @@ -36,9 +36,9 @@ use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::conjunction; use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, expr}; -/// Optimizer rule for rewriting subquery filters to joins -/// and places additional projection on top of the filter, to preserve -/// original schema. +/// Optimizer rule that rewrites correlated scalar subquery filters to joins and +/// places an additional projection on top of the filter, to preserve original +/// schema. #[derive(Default, Debug)] pub struct ScalarSubqueryToJoin {} @@ -48,10 +48,15 @@ impl ScalarSubqueryToJoin { Self::default() } - /// Finds expressions that have a scalar subquery in them (and recurses when found) + /// Finds expressions that contain correlated scalar subqueries (and + /// recurses when found) /// /// # Arguments - /// * `predicate` - A conjunction to split and search + /// * `predicate` - A conjunction to split and search. + /// * `alias_gen` - Generator used to produce unique aliases for each + /// extracted scalar subquery (e.g. `__scalar_sq_1`, `__scalar_sq_2`). + /// Each subquery is replaced by a column reference using the generated + /// alias, and the same alias is later used to construct the join. 
/// /// Returns a tuple (subqueries, alias) fn extract_subquery_exprs( @@ -85,7 +90,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { LogicalPlan::Filter(filter) => { // Optimization: skip the rest of the rule and its copies if // there are no scalar subqueries - if !contains_scalar_subquery(&filter.predicate) { + if !contains_correlated_scalar_subquery(&filter.predicate) { return Ok(Transformed::no(LogicalPlan::Filter(filter))); } @@ -137,9 +142,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { Ok(Transformed::yes(new_plan)) } LogicalPlan::Projection(projection) => { - // Optimization: skip the rest of the rule and its copies if - // there are no scalar subqueries - if !projection.expr.iter().any(contains_scalar_subquery) { + // Optimization: skip the rest of the rule and its copies if there + // are no correlated scalar subqueries + if !projection + .expr + .iter() + .any(contains_correlated_scalar_subquery) + { return Ok(Transformed::no(LogicalPlan::Projection(projection))); } @@ -222,11 +231,14 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } -/// Returns true if the expression has a scalar subquery somewhere in it -/// false otherwise -fn contains_scalar_subquery(expr: &Expr) -> bool { - expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_)))) - .expect("Inner is always Ok") +/// Returns true if the expression contains a correlated scalar subquery, false +/// otherwise. Uncorrelated scalar subqueries are handled by the physical +/// planner via `ScalarSubqueryExec` and do not need to be converted to joins. 
+fn contains_correlated_scalar_subquery(expr: &Expr) -> bool { + expr.exists(|expr| { + Ok(matches!(expr, Expr::ScalarSubquery(sq) if !sq.outer_ref_columns.is_empty())) + }) + .expect("Inner is always Ok") } struct ExtractScalarSubQuery<'a> { @@ -239,7 +251,11 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { fn f_down(&mut self, expr: Expr) -> Result> { match expr { - Expr::ScalarSubquery(subquery) => { + // Skip uncorrelated scalar subqueries + Expr::ScalarSubquery(ref subquery) + if !subquery.outer_ref_columns.is_empty() => + { + let subquery = subquery.clone(); let subqry_alias = self.alias_gen.next("__scalar_sq"); self.sub_query_info .push((subquery.clone(), subqry_alias.clone())); @@ -623,15 +639,13 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + 
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } @@ -1028,14 +1042,12 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey < () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } @@ -1058,14 +1070,12 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], 
aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } @@ -1157,19 +1167,16 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N] - Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey BETWEEN () AND () [c_custkey:Int64, c_name:Utf8] + Subquery: 
[min(orders.o_custkey):Int64;N] + Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index fd4991c24413f..391e78e8d3b0c 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -128,15 +128,13 @@ fn subquery_filter_with_cast() -> Result<()> { assert_snapshot!( format!("{plan}"), @r#" - Projection: test.col_int32 - Inner Join: Filter: CAST(test.col_int32 AS Float64) > __scalar_sq_1.avg(test.col_int32) - TableScan: test projection=[col_int32] - SubqueryAlias: __scalar_sq_1 - Aggregate: groupBy=[[]], aggr=[[avg(CAST(test.col_int32 AS Float64))]] - Projection: test.col_int32 - Filter: __common_expr_4 >= Date32("2002-05-08") AND __common_expr_4 <= Date32("2002-05-13") - Projection: CAST(test.col_utf8 AS Date32) AS __common_expr_4, test.col_int32 - TableScan: test projection=[col_int32, col_utf8] + Filter: CAST(test.col_int32 AS Float64) > () + Subquery: + Aggregate: groupBy=[[]], aggr=[[avg(CAST(test.col_int32 AS Float64))]] + Projection: test.col_int32 + Filter: CAST(test.col_utf8 AS Date32) >= Date32("2002-05-08") AND CAST(test.col_utf8 AS Date32) <= Date32("2002-05-13") + TableScan: test projection=[col_int32, col_utf8] + TableScan: test projection=[col_int32] "# ); Ok(()) diff --git a/datafusion/physical-expr/src/lib.rs 
b/datafusion/physical-expr/src/lib.rs index bedd348dab92f..9c567ce862149 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -40,6 +40,7 @@ mod physical_expr; pub mod planner; pub mod projection; mod scalar_function; +pub mod scalar_subquery; pub mod simplifier; pub mod statistics; pub mod utils; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index fd2de812e4664..1f1a2000a0bd3 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use crate::ScalarFunctionExpr; +use crate::scalar_subquery::ScalarSubqueryExpr; use crate::{ PhysicalExpr, expressions::{self, Column, Literal, binary, like, similar_to}, @@ -409,6 +410,29 @@ pub fn create_physical_expr( expressions::in_list(value_expr, list_exprs, negated, input_schema) } }, + Expr::ScalarSubquery(sq) => { + match execution_props.subquery_indexes.get(sq) { + Some(&index) => { + let schema = sq.subquery.schema(); + let dt = schema.field(0).data_type().clone(); + let nullable = schema.field(0).is_nullable(); + Ok(Arc::new(ScalarSubqueryExpr::new( + dt, + nullable, + index, + Arc::clone(&execution_props.subquery_results), + ))) + } + None => { + // Not found: either a correlated subquery that wasn't + // rewritten to a join, or an uncorrelated one that wasn't + // registered by the physical planner. + not_impl_err!( + "Physical plan does not support logical expression {e:?}" + ) + } + } + } Expr::Placeholder(Placeholder { id, .. 
}) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } diff --git a/datafusion/physical-expr/src/scalar_subquery.rs b/datafusion/physical-expr/src/scalar_subquery.rs new file mode 100644 index 0000000000000..0a39091c73f0e --- /dev/null +++ b/datafusion/physical-expr/src/scalar_subquery.rs @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical expression for uncorrelated scalar subqueries. +//! +//! [`ScalarSubqueryExpr`] reads a cached [`ScalarValue`] that is populated +//! at execution time by `ScalarSubqueryExec`. + +use std::any::Any; +use std::fmt; +use std::hash::Hash; +use std::sync::{Arc, OnceLock}; + +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion_common::{Result, ScalarValue, internal_datafusion_err}; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +/// A physical expression whose value is provided by a scalar subquery. 
+/// +/// Subquery execution is handled by `ScalarSubqueryExec`, which stores the +/// result in a shared results container. This expression simply reads from the +/// shared results container at the appropriate index. +/// +/// If the same subquery appears multiple times in a query, there will be +/// multiple `ScalarSubqueryExpr` with the same result index. +#[derive(Debug)] +pub struct ScalarSubqueryExpr { + data_type: DataType, + nullable: bool, + /// Index of this subquery in the shared results container. + index: usize, + /// Shared results container populated by `ScalarSubqueryExec`. + results: Arc>>, +} + +impl ScalarSubqueryExpr { + pub fn new( + data_type: DataType, + nullable: bool, + index: usize, + results: Arc>>, + ) -> Self { + Self { + data_type, + nullable, + index, + results, + } + } + + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + pub fn nullable(&self) -> bool { + self.nullable + } + + /// Returns the index of this subquery in the shared results container. + pub fn index(&self) -> usize { + self.index + } + + pub fn results(&self) -> &Arc>> { + &self.results + } +} + +impl fmt::Display for ScalarSubqueryExpr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.results.get(self.index).and_then(|slot| slot.get()) { + Some(v) => write!(f, "scalar_subquery({v})"), + None => write!(f, "scalar_subquery()"), + } + } +} + +// Two ScalarSubqueryExprs are the "same" if they share the same results +// container and have the same index. 
+impl Hash for ScalarSubqueryExpr { + fn hash(&self, state: &mut H) { + Arc::as_ptr(&self.results).hash(state); + self.index.hash(state); + } +} + +impl PartialEq for ScalarSubqueryExpr { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.results, &other.results) && self.index == other.index + } +} + +impl Eq for ScalarSubqueryExpr {} + +impl PhysicalExpr for ScalarSubqueryExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::new(Field::new( + format!("{self}"), + self.data_type.clone(), + self.nullable, + ))) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + let value = self + .results + .get(self.index) + .and_then(|slot| slot.get()) + .ok_or_else(|| { + internal_datafusion_err!( + "ScalarSubqueryExpr evaluated before the subquery was executed" + ) + })?; + Ok(ColumnarValue::Scalar(value.clone())) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn get_properties(&self, _children: &[ExprProperties]) -> Result { + Ok(ExprProperties::new_unknown().with_order(SortProperties::Singleton)) + } + + fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "(scalar subquery)") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use arrow::array::Int32Array; + use arrow::datatypes::Field; + + fn make_results(values: Vec>) -> Arc>> { + Arc::new( + values + .into_iter() + .map(|v| { + let lock = OnceLock::new(); + if let Some(val) = v { + lock.set(val).unwrap(); + } + lock + }) + .collect(), + ) + } + + #[test] + fn test_evaluate_with_value() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![1, 2, 3]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + + let results = make_results(vec![Some(ScalarValue::Int32(Some(42)))]); + let expr = 
ScalarSubqueryExpr::new(DataType::Int32, false, 0, results); + + let result = expr.evaluate(&batch)?; + match result { + ColumnarValue::Scalar(ScalarValue::Int32(Some(42))) => {} + other => panic!("Expected Scalar(Int32(42)), got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_evaluate_before_populated() { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![1]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap(); + + let results = Arc::new(vec![OnceLock::new()]); + let expr = ScalarSubqueryExpr::new(DataType::Int32, false, 0, results); + + let result = expr.evaluate(&batch); + assert!(result.is_err()); + } + + #[test] + fn test_identity_equality() { + let results = make_results(vec![None, None]); + + let e1a = + ScalarSubqueryExpr::new(DataType::Int32, false, 0, Arc::clone(&results)); + let e1b = + ScalarSubqueryExpr::new(DataType::Int32, false, 0, Arc::clone(&results)); + let e2 = ScalarSubqueryExpr::new(DataType::Int32, false, 1, Arc::clone(&results)); + + // Same container + same index → equal + assert_eq!(e1a, e1b); + // Same container, different index → not equal + assert_ne!(e1a, e2); + + // Different container, same index → not equal + let other_results = make_results(vec![None]); + let e3 = ScalarSubqueryExpr::new(DataType::Int32, false, 0, other_results); + assert_ne!(e1a, e3); + } +} diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 6467d7a2e389d..94b7e2a1e0cae 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -84,6 +84,7 @@ pub mod placeholder_row; pub mod projection; pub mod recursive_query; pub mod repartition; +pub mod scalar_subquery; pub mod sort_pushdown; pub mod sorts; pub mod spill; diff --git a/datafusion/physical-plan/src/scalar_subquery.rs b/datafusion/physical-plan/src/scalar_subquery.rs new file mode 100644 index 0000000000000..8be4c5189cff9 --- /dev/null +++ 
b/datafusion/physical-plan/src/scalar_subquery.rs @@ -0,0 +1,514 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution plan for uncorrelated scalar subqueries. +//! +//! [`ScalarSubqueryExec`] wraps a main input plan and a set of subquery plans. +//! At execution time, it runs each subquery exactly once, extracts the scalar +//! result, and populates the shared results container that +//! [`ScalarSubqueryExpr`] instances read from by index. +//! +//! 
[`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr + +use std::fmt; +use std::sync::Arc; + +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{Result, ScalarValue, Statistics, exec_err, internal_err}; +use datafusion_execution::TaskContext; +use datafusion_expr::execution_props::ScalarSubqueryResults; +use datafusion_physical_expr::PhysicalExpr; + +use crate::execution_plan::{CardinalityEffect, ExecutionPlan, PlanProperties}; +use crate::joins::utils::{OnceAsync, OnceFut}; +use crate::stream::RecordBatchStreamAdapter; +use crate::{DisplayAs, DisplayFormatType, SendableRecordBatchStream}; + +use futures::StreamExt; +use futures::TryStreamExt; + +/// Links a scalar subquery's execution plan to its index in the shared results +/// container. The [`ScalarSubqueryExec`] that owns these links populates +/// `results[index]` at execution time, and [`ScalarSubqueryExpr`] instances +/// with the same index read from it. +/// +/// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr +#[derive(Debug, Clone)] +pub struct ScalarSubqueryLink { + /// The physical plan for the subquery. + pub plan: Arc, + /// Index into the shared results container. + pub index: usize, +} + +/// Manages execution of uncorrelated scalar subqueries for a single plan +/// level. +/// +/// The first child node is the **main input plan**, whose batches are passed +/// through unchanged. The remaining children are **subquery plans**, each of +/// which must produce exactly zero or one row. Before any batches from the main +/// input are yielded, all subquery plans are executed and their scalar results +/// are stored in a shared results container ([`ScalarSubqueryResults`]). +/// [`ScalarSubqueryExpr`] nodes embedded in the main input's expressions read +/// from this container by index. 
+/// +/// All subqueries are evaluated eagerly when the first output partition is +/// requested, before any rows from the main input are produced. +/// +/// TODO: Consider overlapping computation of the subqueries with evaluating the +/// main query. +/// +/// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr +#[derive(Debug)] +pub struct ScalarSubqueryExec { + /// The main input plan whose output is passed through. + input: Arc, + /// Subquery plans and their result indexes. + subqueries: Vec, + /// Shared one-time async computation of subquery results. + subquery_future: Arc>, + /// Shared results container; same instance held by ScalarSubqueryExpr nodes. + results: ScalarSubqueryResults, + /// Cached plan properties (copied from input). + cache: Arc, +} + +impl ScalarSubqueryExec { + pub fn new( + input: Arc, + subqueries: Vec, + results: ScalarSubqueryResults, + ) -> Self { + let cache = Arc::clone(input.properties()); + Self { + input, + subqueries, + subquery_future: Arc::default(), + results, + cache, + } + } + + pub fn input(&self) -> &Arc { + &self.input + } + + pub fn subqueries(&self) -> &[ScalarSubqueryLink] { + &self.subqueries + } + + pub fn results(&self) -> &ScalarSubqueryResults { + &self.results + } + + /// Returns a per-child bool vec that is `true` for the main input + /// (child 0) and `false` for every subquery child. 
+ fn true_for_input_only(&self) -> Vec { + std::iter::once(true) + .chain(std::iter::repeat_n(false, self.subqueries.len())) + .collect() + } +} + +impl DisplayAs for ScalarSubqueryExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "ScalarSubqueryExec: subqueries={}", + self.subqueries.len() + ) + } + DisplayFormatType::TreeRender => { + write!(f, "") + } + } + } +} + +impl ExecutionPlan for ScalarSubqueryExec { + fn name(&self) -> &'static str { + "ScalarSubqueryExec" + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + let mut children = vec![&self.input]; + for sq in &self.subqueries { + children.push(&sq.plan); + } + children + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + // First child is the main input, the rest are subquery plans. + let input = children.remove(0); + let subqueries = self + .subqueries + .iter() + .zip(children) + .map(|(sq, new_plan)| ScalarSubqueryLink { + plan: new_plan, + index: sq.index, + }) + .collect(); + Ok(Arc::new(ScalarSubqueryExec::new( + input, + subqueries, + Arc::clone(&self.results), + ))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let subqueries = self.subqueries.clone(); + let results = Arc::clone(&self.results); + let subquery_ctx = Arc::clone(&context); + let mut subquery_future = self.subquery_future.try_once(move || { + Ok(async move { execute_subqueries(subqueries, results, subquery_ctx).await }) + })?; + let input = Arc::clone(&self.input); + let schema = self.schema(); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + schema, + futures::stream::once(async move { + // Execute all subqueries exactly once, even when multiple + // partitions call execute() concurrently. 
+ wait_for_subqueries(&mut subquery_future).await?; + + // Now that the subqueries have finished execution, we can + // safely execute the main input + input.execute(partition, context) + }) + .try_flatten(), + ))) + } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } + + fn maintains_input_order(&self) -> Vec { + // Only the main input (first child); subquery children don't + // contribute to ordering. + self.true_for_input_only() + } + + fn benefits_from_input_partitioning(&self) -> Vec { + // Only the main input; subquery children produce at most one + // row, so repartitioning them adds overhead with no benefit. + self.true_for_input_only() + } + + fn partition_statistics(&self, partition: Option) -> Result> { + self.input.partition_statistics(partition) + } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } +} + +/// Wait for the shared one-time future that runs all subqueries to complete. +async fn wait_for_subqueries(fut: &mut OnceFut<()>) -> Result<()> { + std::future::poll_fn(|cx| fut.get_shared(cx)).await?; + Ok(()) +} + +/// Execute all subquery plans, extract their scalar results, and populate the shared results container. +async fn execute_subqueries( + subqueries: Vec, + results: ScalarSubqueryResults, + context: Arc, +) -> Result<()> { + // Evaluate subqueries in parallel; wait for them all to finish evaluation + // before returning. + let futures = subqueries.iter().map(|sq| { + let plan = Arc::clone(&sq.plan); + let ctx = Arc::clone(&context); + let results = Arc::clone(&results); + let index = sq.index; + async move { + let value = execute_scalar_subquery(plan, ctx).await?; + if results[index].set(value).is_err() { + return internal_err!( + "ScalarSubqueryExec: result for index {index} was already populated" + ); + } + Ok(()) as Result<()> + } + }); + futures::future::try_join_all(futures).await?; + Ok(()) +} + +/// Execute a single subquery plan and extract the scalar value.
+/// Returns NULL for 0 rows, the scalar value for exactly 1 row, +/// or an error for >1 rows. +async fn execute_scalar_subquery( + plan: Arc, + context: Arc, +) -> Result { + let schema = plan.schema(); + if schema.fields().len() != 1 { + // Should be enforced by the physical planner. + return internal_err!( + "Scalar subquery must return exactly one column, got {}", + schema.fields().len() + ); + } + + let mut stream = crate::execute_stream(plan, context)?; + let mut result: Option = None; + + while let Some(batch) = stream.next().await.transpose()? { + if batch.num_rows() == 0 { + continue; + } + if result.is_some() || batch.num_rows() > 1 { + return exec_err!("Scalar subquery returned more than one row"); + } + result = Some(ScalarValue::try_from_array(batch.column(0), 0)?); + } + + // 0 rows → typed NULL per SQL semantics + match result { + Some(v) => Ok(v), + None => ScalarValue::try_from(schema.field(0).data_type()), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::{self, TestMemoryExec}; + + use std::sync::OnceLock; + use std::sync::atomic::{AtomicUsize, Ordering}; + + use crate::test::exec::ErrorExec; + use arrow::array::{Int32Array, Int64Array}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + #[derive(Debug)] + struct CountingExec { + inner: Arc, + execute_calls: Arc, + } + + impl CountingExec { + fn new(inner: Arc, execute_calls: Arc) -> Self { + Self { + inner, + execute_calls, + } + } + } + + impl DisplayAs for CountingExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "CountingExec") + } + DisplayFormatType::TreeRender => write!(f, ""), + } + } + } + + impl ExecutionPlan for CountingExec { + fn name(&self) -> &'static str { + "CountingExec" + } + + fn properties(&self) -> &Arc { + self.inner.properties() + } + + fn children(&self) -> Vec<&Arc> { + 
vec![&self.inner] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::new( + children.remove(0), + Arc::clone(&self.execute_calls), + ))) + } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + self.execute_calls.fetch_add(1, Ordering::SeqCst); + self.inner.execute(partition, context) + } + } + + fn make_subquery_plan(batches: Vec) -> Arc { + let schema = batches[0].schema(); + TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap() + } + + fn make_results(n: usize) -> ScalarSubqueryResults { + Arc::new((0..n).map(|_| OnceLock::new()).collect()) + } + + #[tokio::test] + async fn test_single_row_subquery() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![42]))], + )?; + + let results = make_results(1); + let subquery_plan = make_subquery_plan(vec![batch]); + let sq = ScalarSubqueryLink { + plan: subquery_plan, + index: 0, + }; + + let main_input = Arc::new(crate::placeholder_row::PlaceholderRowExec::new( + test::aggr_test_schema(), + )); + let exec = ScalarSubqueryExec::new(main_input, vec![sq], Arc::clone(&results)); + + let ctx = Arc::new(TaskContext::default()); + let stream = exec.execute(0, ctx)?; + let _batches = crate::common::collect(stream).await?; + + assert_eq!(results[0].get(), Some(&ScalarValue::Int32(Some(42)))); + Ok(()) + } + + #[tokio::test] + async fn test_zero_row_subquery_returns_null() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int64Array::from(vec![] as Vec))], + )?; + + let results = make_results(1); + let subquery_plan = 
make_subquery_plan(vec![batch]); + let sq = ScalarSubqueryLink { + plan: subquery_plan, + index: 0, + }; + + let main_input = Arc::new(crate::placeholder_row::PlaceholderRowExec::new( + test::aggr_test_schema(), + )); + let exec = ScalarSubqueryExec::new(main_input, vec![sq], Arc::clone(&results)); + + let ctx = Arc::new(TaskContext::default()); + let stream = exec.execute(0, ctx)?; + let _batches = crate::common::collect(stream).await?; + + assert_eq!(results[0].get(), Some(&ScalarValue::Int64(None))); + Ok(()) + } + + #[tokio::test] + async fn test_multi_row_subquery_errors() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; + + let results = make_results(1); + let subquery_plan = make_subquery_plan(vec![batch]); + let sq = ScalarSubqueryLink { + plan: subquery_plan, + index: 0, + }; + + let main_input = Arc::new(crate::placeholder_row::PlaceholderRowExec::new( + test::aggr_test_schema(), + )); + let exec = ScalarSubqueryExec::new(main_input, vec![sq], Arc::clone(&results)); + + let ctx = Arc::new(TaskContext::default()); + let stream = exec.execute(0, ctx)?; + let result = crate::common::collect(stream).await; + + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("more than one row"), + "Expected 'more than one row' error, got: {err_msg}" + ); + Ok(()) + } + + #[tokio::test] + async fn test_failed_subquery_is_not_retried() -> Result<()> { + let execute_calls = Arc::new(AtomicUsize::new(0)); + let subquery_plan = Arc::new(CountingExec::new( + Arc::new(ErrorExec::new()), + Arc::clone(&execute_calls), + )); + let results = make_results(1); + let sq = ScalarSubqueryLink { + plan: subquery_plan, + index: 0, + }; + + let main_input = Arc::new(crate::placeholder_row::PlaceholderRowExec::new( + test::aggr_test_schema(), + )); + let exec = 
ScalarSubqueryExec::new(main_input, vec![sq], results); + + let ctx = Arc::new(TaskContext::default()); + let stream = exec.execute(0, Arc::clone(&ctx))?; + assert!(crate::common::collect(stream).await.is_err()); + + let stream = exec.execute(0, ctx)?; + assert!(crate::common::collect(stream).await.is_err()); + + assert_eq!(execute_calls.load(Ordering::SeqCst), 1); + Ok(()) + } +} diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9464e85727d4f..bf4aa65937a01 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -426,6 +426,8 @@ message LogicalExprNode { Unnest unnest = 35; + // Subquery expressions + ScalarSubqueryExprNode scalar_subquery_expr = 36; } } @@ -433,6 +435,15 @@ message Wildcard { TableReference qualifier = 1; } +message SubqueryNode { + LogicalPlanNode subquery = 1; + repeated LogicalExprNode outer_ref_columns = 2; +} + +message ScalarSubqueryExprNode { + SubqueryNode subquery = 1; +} + message PlaceholderNode { string id = 1; // We serialize the data type, metadata, and nullability separately to maintain @@ -775,6 +786,7 @@ message PhysicalPlanNode { AsyncFuncExecNode async_func = 36; BufferExecNode buffer = 37; ArrowScanExecNode arrow_scan = 38; + ScalarSubqueryExecNode scalar_subquery = 39; } } @@ -920,6 +932,8 @@ message PhysicalExprNode { UnknownColumn unknown_column = 20; PhysicalHashExprNode hash_expr = 21; + + PhysicalScalarSubqueryExprNode scalar_subquery = 22; } } @@ -1474,4 +1488,15 @@ message AsyncFuncExecNode { message BufferExecNode { PhysicalPlanNode input = 1; uint64 capacity = 2; +} + +message ScalarSubqueryExecNode { + PhysicalPlanNode input = 1; + repeated PhysicalPlanNode subqueries = 2; +} + +message PhysicalScalarSubqueryExprNode { + datafusion_common.ArrowType data_type = 1; + bool nullable = 2; + uint32 index = 3; } \ No newline at end of file diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 
84b15ea9a8920..820bbc20af46c 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -27,23 +27,15 @@ use crate::physical_plan::{ use crate::protobuf; use datafusion_common::{Result, plan_datafusion_err}; use datafusion_execution::TaskContext; -use datafusion_expr::{ - AggregateUDF, Expr, LogicalPlan, Volatility, WindowUDF, create_udaf, create_udf, - create_udwf, -}; +use datafusion_expr::{Expr, LogicalPlan}; use prost::{ Message, bytes::{Bytes, BytesMut}, }; use std::sync::Arc; -// Reexport Bytes which appears in the API -use datafusion_execution::registry::FunctionRegistry; -use datafusion_expr::planner::ExprPlanner; use datafusion_physical_plan::ExecutionPlan; -mod registry; - /// Encodes something (such as [`Expr`]) to/from a stream of /// bytes. /// @@ -65,26 +57,21 @@ pub trait Serializeable: Sized { /// Convert `self` to an opaque byte stream fn to_bytes(&self) -> Result; - /// Convert `bytes` (the output of [`to_bytes`]) back into an - /// object. This will error if the serialized bytes contain any - /// user defined functions, in which case use - /// [`from_bytes_with_registry`] + /// Convert `bytes` (the output of [`to_bytes`]) back into an object. 
This + /// will error if the serialized bytes contain any user defined functions, + /// in which case use [`from_bytes_with_ctx`] /// /// [`to_bytes`]: Self::to_bytes - /// [`from_bytes_with_registry`]: Self::from_bytes_with_registry + /// [`from_bytes_with_ctx`]: Self::from_bytes_with_ctx fn from_bytes(bytes: &[u8]) -> Result { - Self::from_bytes_with_registry(bytes, ®istry::NoRegistry {}) + Self::from_bytes_with_ctx(bytes, &TaskContext::default()) } - /// Convert `bytes` (the output of [`to_bytes`]) back into an - /// object resolving user defined functions with the specified - /// `registry` + /// Convert `bytes` (the output of [`to_bytes`]) back into an object + /// resolving user defined functions with the specified `ctx` /// /// [`to_bytes`]: Self::to_bytes - fn from_bytes_with_registry( - bytes: &[u8], - registry: &dyn FunctionRegistry, - ) -> Result; + fn from_bytes_with_ctx(bytes: &[u8], ctx: &TaskContext) -> Result; } impl Serializeable for Expr { @@ -100,100 +87,22 @@ impl Serializeable for Expr { let bytes: Bytes = buffer.into(); - // the produced byte stream may lead to "recursion limit" errors, see + // The produced byte stream may lead to "recursion limit" errors, see // https://github.com/apache/datafusion/issues/3968 - // Until the underlying prost issue ( https://github.com/tokio-rs/prost/issues/736 ) is fixed, we try to - // deserialize the data here and check for errors. 
- // - // Need to provide some placeholder registry because the stream may contain UDFs - struct PlaceHolderRegistry; - - impl FunctionRegistry for PlaceHolderRegistry { - fn udfs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() - } - - fn udf(&self, name: &str) -> Result> { - Ok(Arc::new(create_udf( - name, - vec![], - arrow::datatypes::DataType::Null, - Volatility::Immutable, - Arc::new(|_| unimplemented!()), - ))) - } - - fn udaf(&self, name: &str) -> Result> { - Ok(Arc::new(create_udaf( - name, - vec![arrow::datatypes::DataType::Null], - Arc::new(arrow::datatypes::DataType::Null), - Volatility::Immutable, - Arc::new(|_| unimplemented!()), - Arc::new(vec![]), - ))) - } - - fn udwf(&self, name: &str) -> Result> { - Ok(Arc::new(create_udwf( - name, - arrow::datatypes::DataType::Null, - Arc::new(arrow::datatypes::DataType::Null), - Volatility::Immutable, - Arc::new(|| unimplemented!()), - ))) - } - fn register_udaf( - &mut self, - _udaf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udaf called in Placeholder Registry!" - ) - } - fn register_udf( - &mut self, - _udf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udf called in Placeholder Registry!" - ) - } - fn register_udwf( - &mut self, - _udaf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udwf called in Placeholder Registry!" - ) - } - - fn expr_planners(&self) -> Vec> { - vec![] - } - - fn udafs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() - } - - fn udwfs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() - } - } - Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?; + // Until the underlying prost issue ( https://github.com/tokio-rs/prost/issues/736 ) + // is fixed, verify the bytes can be decoded without hitting that limit. 
+ protobuf::LogicalExprNode::decode(bytes.as_ref()) + .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; Ok(bytes) } - fn from_bytes_with_registry( - bytes: &[u8], - registry: &dyn FunctionRegistry, - ) -> Result { + fn from_bytes_with_ctx(bytes: &[u8], ctx: &TaskContext) -> Result { let protobuf = protobuf::LogicalExprNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; let extension_codec = DefaultLogicalExtensionCodec {}; - logical_plan::from_proto::parse_expr(&protobuf, registry, &extension_codec) + logical_plan::from_proto::parse_expr(&protobuf, ctx, &extension_codec) .map_err(|e| plan_datafusion_err!("Error parsing protobuf into Expr: {e}")) } } @@ -277,7 +186,7 @@ pub fn logical_plan_from_json_with_extension_codec( /// Serialize a PhysicalPlan as bytes pub fn physical_plan_to_bytes(plan: Arc) -> Result { let extension_codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); physical_plan_to_bytes_with_proto_converter(plan, &extension_codec, &proto_converter) } @@ -285,7 +194,7 @@ pub fn physical_plan_to_bytes(plan: Arc) -> Result { #[cfg(feature = "json")] pub fn physical_plan_to_json(plan: Arc) -> Result { let extension_codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); let protobuf = proto_converter .execution_plan_to_proto(&plan, &extension_codec) .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; @@ -298,7 +207,7 @@ pub fn physical_plan_to_bytes_with_extension_codec( plan: Arc, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result { - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); physical_plan_to_bytes_with_proto_converter(plan, extension_codec, &proto_converter) } @@ 
-326,7 +235,7 @@ pub fn physical_plan_from_json( let back: protobuf::PhysicalPlanNode = serde_json::from_str(json) .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; let extension_codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); proto_converter.proto_to_execution_plan(ctx, &extension_codec, &back) } @@ -336,7 +245,7 @@ pub fn physical_plan_from_bytes( ctx: &TaskContext, ) -> Result> { let extension_codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); physical_plan_from_bytes_with_proto_converter( bytes, ctx, @@ -351,7 +260,7 @@ pub fn physical_plan_from_bytes_with_extension_codec( ctx: &TaskContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); physical_plan_from_bytes_with_proto_converter( bytes, ctx, diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs deleted file mode 100644 index a3f74787e2b50..0000000000000 --- a/datafusion/proto/src/bytes/registry.rs +++ /dev/null @@ -1,85 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::{collections::HashSet, sync::Arc}; - -use datafusion_common::Result; -use datafusion_common::plan_err; -use datafusion_execution::registry::FunctionRegistry; -use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; - -/// A default [`FunctionRegistry`] registry that does not resolve any -/// user defined functions -pub(crate) struct NoRegistry {} - -impl FunctionRegistry for NoRegistry { - fn udfs(&self) -> HashSet { - HashSet::new() - } - - fn udf(&self, name: &str) -> Result> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize User Defined Function '{name}'" - ) - } - - fn udaf(&self, name: &str) -> Result> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize User Defined Aggregate Function '{name}'" - ) - } - - fn udwf(&self, name: &str) -> Result> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize User Defined Window Function '{name}'" - ) - } - fn register_udaf( - &mut self, - udaf: Arc, - ) -> Result>> { - plan_err!( - "No function registry provided to deserialize, so can not register User Defined Aggregate Function '{}'", - udaf.inner().name() - ) - } - fn register_udf(&mut self, udf: Arc) -> Result>> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize User Defined Function '{}'", - udf.inner().name() - ) - } - fn register_udwf(&mut self, udwf: Arc) -> Result>> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize 
User Defined Window Function '{}'", - udwf.inner().name() - ) - } - - fn expr_planners(&self) -> Vec> { - vec![] - } - - fn udafs(&self) -> HashSet { - HashSet::new() - } - - fn udwfs(&self) -> HashSet { - HashSet::new() - } -} diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index c81e2fabe1185..a0715885dcfb4 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -12193,6 +12193,9 @@ impl serde::Serialize for LogicalExprNode { logical_expr_node::ExprType::Unnest(v) => { struct_ser.serialize_field("unnest", v)?; } + logical_expr_node::ExprType::ScalarSubqueryExpr(v) => { + struct_ser.serialize_field("scalarSubqueryExpr", v)?; + } } } struct_ser.end() @@ -12254,6 +12257,8 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "similarTo", "placeholder", "unnest", + "scalar_subquery_expr", + "scalarSubqueryExpr", ]; #[allow(clippy::enum_variant_names)] @@ -12289,6 +12294,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { SimilarTo, Placeholder, Unnest, + ScalarSubqueryExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12341,6 +12347,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "similarTo" | "similar_to" => Ok(GeneratedField::SimilarTo), "placeholder" => Ok(GeneratedField::Placeholder), "unnest" => Ok(GeneratedField::Unnest), + "scalarSubqueryExpr" | "scalar_subquery_expr" => Ok(GeneratedField::ScalarSubqueryExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12578,6 +12585,13 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { return Err(serde::de::Error::duplicate_field("unnest")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Unnest) +; + } + GeneratedField::ScalarSubqueryExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarSubqueryExpr")); + } + 
expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarSubqueryExpr) ; } } @@ -16593,6 +16607,9 @@ impl serde::Serialize for PhysicalExprNode { physical_expr_node::ExprType::HashExpr(v) => { struct_ser.serialize_field("hashExpr", v)?; } + physical_expr_node::ExprType::ScalarSubquery(v) => { + struct_ser.serialize_field("scalarSubquery", v)?; + } } } struct_ser.end() @@ -16639,6 +16656,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "unknownColumn", "hash_expr", "hashExpr", + "scalar_subquery", + "scalarSubquery", ]; #[allow(clippy::enum_variant_names)] @@ -16663,6 +16682,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { Extension, UnknownColumn, HashExpr, + ScalarSubquery, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16704,6 +16724,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "extension" => Ok(GeneratedField::Extension), "unknownColumn" | "unknown_column" => Ok(GeneratedField::UnknownColumn), "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr), + "scalarSubquery" | "scalar_subquery" => Ok(GeneratedField::ScalarSubquery), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16866,6 +16887,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { return Err(serde::de::Error::duplicate_field("hashExpr")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::HashExpr) +; + } + GeneratedField::ScalarSubquery => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarSubquery")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarSubquery) ; } } @@ -18104,6 +18132,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::ArrowScan(v) => { struct_ser.serialize_field("arrowScan", v)?; } + 
physical_plan_node::PhysicalPlanType::ScalarSubquery(v) => { + struct_ser.serialize_field("scalarSubquery", v)?; + } } } struct_ser.end() @@ -18174,6 +18205,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "buffer", "arrow_scan", "arrowScan", + "scalar_subquery", + "scalarSubquery", ]; #[allow(clippy::enum_variant_names)] @@ -18215,6 +18248,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { AsyncFunc, Buffer, ArrowScan, + ScalarSubquery, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18273,6 +18307,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "asyncFunc" | "async_func" => Ok(GeneratedField::AsyncFunc), "buffer" => Ok(GeneratedField::Buffer), "arrowScan" | "arrow_scan" => Ok(GeneratedField::ArrowScan), + "scalarSubquery" | "scalar_subquery" => Ok(GeneratedField::ScalarSubquery), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18552,6 +18587,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("arrowScan")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ArrowScan) +; + } + GeneratedField::ScalarSubquery => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarSubquery")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ScalarSubquery) ; } } @@ -18564,6 +18606,134 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { deserializer.deserialize_struct("datafusion.PhysicalPlanNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PhysicalScalarSubqueryExprNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.data_type.is_some() { + len += 1; + } + if self.nullable { + 
len += 1; + } + if self.index != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalScalarSubqueryExprNode", len)?; + if let Some(v) = self.data_type.as_ref() { + struct_ser.serialize_field("dataType", v)?; + } + if self.nullable { + struct_ser.serialize_field("nullable", &self.nullable)?; + } + if self.index != 0 { + struct_ser.serialize_field("index", &self.index)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PhysicalScalarSubqueryExprNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "data_type", + "dataType", + "nullable", + "index", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + DataType, + Nullable, + Index, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "dataType" | "data_type" => Ok(GeneratedField::DataType), + "nullable" => Ok(GeneratedField::Nullable), + "index" => Ok(GeneratedField::Index), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PhysicalScalarSubqueryExprNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PhysicalScalarSubqueryExprNode") + } + + fn visit_map(self, mut map_: V) -> 
std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut data_type__ = None; + let mut nullable__ = None; + let mut index__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::DataType => { + if data_type__.is_some() { + return Err(serde::de::Error::duplicate_field("dataType")); + } + data_type__ = map_.next_value()?; + } + GeneratedField::Nullable => { + if nullable__.is_some() { + return Err(serde::de::Error::duplicate_field("nullable")); + } + nullable__ = Some(map_.next_value()?); + } + GeneratedField::Index => { + if index__.is_some() { + return Err(serde::de::Error::duplicate_field("index")); + } + index__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(PhysicalScalarSubqueryExprNode { + data_type: data_type__, + nullable: nullable__.unwrap_or_default(), + index: index__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.PhysicalScalarSubqueryExprNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PhysicalScalarUdfNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -21233,6 +21403,205 @@ impl<'de> serde::Deserialize<'de> for RollupNode { deserializer.deserialize_struct("datafusion.RollupNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ScalarSubqueryExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if !self.subqueries.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ScalarSubqueryExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if !self.subqueries.is_empty() { + struct_ser.serialize_field("subqueries", &self.subqueries)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for 
ScalarSubqueryExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "subqueries", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Subqueries, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "subqueries" => Ok(GeneratedField::Subqueries), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarSubqueryExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ScalarSubqueryExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut subqueries__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Subqueries => { + if subqueries__.is_some() { + return Err(serde::de::Error::duplicate_field("subqueries")); + } + subqueries__ = Some(map_.next_value()?); + } + } + } + Ok(ScalarSubqueryExecNode { + input: input__, + subqueries: subqueries__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.ScalarSubqueryExecNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ScalarSubqueryExprNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.subquery.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ScalarSubqueryExprNode", len)?; + if let Some(v) = self.subquery.as_ref() { + struct_ser.serialize_field("subquery", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarSubqueryExprNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "subquery", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Subquery, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "subquery" => Ok(GeneratedField::Subquery), + _ => 
Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarSubqueryExprNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ScalarSubqueryExprNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut subquery__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Subquery => { + if subquery__.is_some() { + return Err(serde::de::Error::duplicate_field("subquery")); + } + subquery__ = map_.next_value()?; + } + } + } + Ok(ScalarSubqueryExprNode { + subquery: subquery__, + }) + } + } + deserializer.deserialize_struct("datafusion.ScalarSubqueryExprNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarUdfExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -22910,6 +23279,115 @@ impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { deserializer.deserialize_struct("datafusion.SubqueryAliasNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for SubqueryNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.subquery.is_some() { + len += 1; + } + if !self.outer_ref_columns.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SubqueryNode", len)?; + if let Some(v) = self.subquery.as_ref() { + struct_ser.serialize_field("subquery", v)?; + } + if !self.outer_ref_columns.is_empty() { + struct_ser.serialize_field("outerRefColumns", &self.outer_ref_columns)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SubqueryNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) 
-> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "subquery", + "outer_ref_columns", + "outerRefColumns", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Subquery, + OuterRefColumns, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "subquery" => Ok(GeneratedField::Subquery), + "outerRefColumns" | "outer_ref_columns" => Ok(GeneratedField::OuterRefColumns), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SubqueryNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SubqueryNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut subquery__ = None; + let mut outer_ref_columns__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Subquery => { + if subquery__.is_some() { + return Err(serde::de::Error::duplicate_field("subquery")); + } + subquery__ = map_.next_value()?; + } + GeneratedField::OuterRefColumns => { + if outer_ref_columns__.is_some() { + return Err(serde::de::Error::duplicate_field("outerRefColumns")); + } + outer_ref_columns__ = Some(map_.next_value()?); + } + } + } + Ok(SubqueryNode { + subquery: subquery__, + outer_ref_columns: outer_ref_columns__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SubqueryNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for SymmetricHashJoinExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ff9133b1ced5b..e58ef5b0145be 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -195,8 +195,8 @@ pub mod projection_node { pub struct SelectionNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, tag = "2")] - pub expr: ::core::option::Option, + #[prost(message, optional, boxed, tag = "2")] + pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortNode { @@ -382,8 +382,8 @@ pub struct JoinNode { pub right_join_key: ::prost::alloc::vec::Vec, #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] pub null_equality: i32, - #[prost(message, optional, tag = "8")] - pub filter: ::core::option::Option, + #[prost(message, optional, boxed, tag = "8")] + pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct DistinctNode { @@ -578,7 +578,7 @@ pub struct SubqueryAliasNode { pub struct LogicalExprNode { #[prost( oneof = "logical_expr_node::ExprType", - tags = "1, 
2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36" )] pub expr_type: ::core::option::Option, } @@ -656,6 +656,9 @@ pub mod logical_expr_node { Placeholder(super::PlaceholderNode), #[prost(message, tag = "35")] Unnest(super::Unnest), + /// Subquery expressions + #[prost(message, tag = "36")] + ScalarSubqueryExpr(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] @@ -664,6 +667,18 @@ pub struct Wildcard { pub qualifier: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct SubqueryNode { + #[prost(message, optional, boxed, tag = "1")] + pub subquery: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "2")] + pub outer_ref_columns: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarSubqueryExprNode { + #[prost(message, optional, boxed, tag = "1")] + pub subquery: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PlaceholderNode { #[prost(string, tag = "1")] pub id: ::prost::alloc::string::String, @@ -1102,7 +1117,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39" )] pub physical_plan_type: ::core::option::Option, } @@ -1186,6 +1201,8 @@ pub mod physical_plan_node { Buffer(::prost::alloc::boxed::Box), #[prost(message, tag = "38")] ArrowScan(super::ArrowScanExecNode), + #[prost(message, tag = "39")] + 
ScalarSubquery(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1317,7 +1334,7 @@ pub struct PhysicalExprNode { pub expr_id: ::core::option::Option, #[prost( oneof = "physical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21, 22" )] pub expr_type: ::core::option::Option, } @@ -1370,6 +1387,8 @@ pub mod physical_expr_node { UnknownColumn(super::UnknownColumn), #[prost(message, tag = "21")] HashExpr(super::PhysicalHashExprNode), + #[prost(message, tag = "22")] + ScalarSubquery(super::PhysicalScalarSubqueryExprNode), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -2198,6 +2217,22 @@ pub struct BufferExecNode { #[prost(uint64, tag = "2")] pub capacity: u64, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarSubqueryExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "2")] + pub subqueries: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PhysicalScalarSubqueryExprNode { + #[prost(message, optional, tag = "1")] + pub data_type: ::core::option::Option, + #[prost(bool, tag = "2")] + pub nullable: bool, + #[prost(uint32, tag = "3")] + pub index: u32, +} /// Identifies a built-in file format supported by DataFusion. /// Used by DefaultLogicalExtensionCodec to serialize/deserialize /// FileFormatFactory instances (e.g. in CopyTo plans). 
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ed33d9fab1820..78ffd362c8e48 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -23,10 +23,12 @@ use datafusion_common::{ NullEquality, RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, exec_datafusion_err, internal_err, plan_datafusion_err, }; +use datafusion_execution::TaskContext; use datafusion_execution::registry::FunctionRegistry; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{Alias, NullTreatment, Placeholder, Sort}; use datafusion_expr::expr::{Unnest, WildcardOptions}; +use datafusion_expr::logical_plan::Subquery; use datafusion_expr::{ Between, BinaryExpr, Case, Cast, Expr, GroupingSet, GroupingSet::GroupingSets, @@ -52,7 +54,7 @@ use crate::protobuf::{ }, }; -use super::LogicalExtensionCodec; +use super::{AsLogicalPlan, LogicalExtensionCodec}; impl From<&protobuf::UnnestOptions> for UnnestOptions { fn from(opts: &protobuf::UnnestOptions) -> Self { @@ -256,7 +258,7 @@ impl From for NullTreatment { pub fn parse_expr( proto: &protobuf::LogicalExprNode, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result { use protobuf::{logical_expr_node::ExprType, window_expr_node}; @@ -269,7 +271,7 @@ pub fn parse_expr( match expr_type { ExprType::BinaryExpr(binary_expr) => { let op = from_proto_binary_op(&binary_expr.op)?; - let operands = parse_exprs(&binary_expr.operands, registry, codec)?; + let operands = parse_exprs(&binary_expr.operands, ctx, codec)?; if operands.len() < 2 { return Err(proto_error( @@ -296,8 +298,8 @@ pub fn parse_expr( .window_function .as_ref() .ok_or_else(|| Error::required("window_function"))?; - let partition_by = parse_exprs(&expr.partition_by, registry, codec)?; - let mut order_by = parse_sorts(&expr.order_by, registry, codec)?; + let partition_by = 
parse_exprs(&expr.partition_by, ctx, codec)?; + let mut order_by = parse_sorts(&expr.order_by, ctx, codec)?; let window_frame = expr .window_frame .as_ref() @@ -329,7 +331,7 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, - None => registry + None => ctx .udaf(udaf_name) .or_else(|_| codec.try_decode_udaf(udaf_name, &[]))?, }; @@ -338,7 +340,7 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, - None => registry + None => ctx .udwf(udwf_name) .or_else(|_| codec.try_decode_udwf(udwf_name, &[]))?, }; @@ -346,7 +348,7 @@ pub fn parse_expr( } }; - let args = parse_exprs(&expr.exprs, registry, codec)?; + let args = parse_exprs(&expr.exprs, ctx, codec)?; let mut builder = Expr::from(WindowFunction::new(agg_fn, args)) .partition_by(partition_by) .order_by(order_by) @@ -357,8 +359,7 @@ pub fn parse_expr( builder = builder.distinct(); }; - if let Some(filter) = - parse_optional_expr(expr.filter.as_deref(), registry, codec)? + if let Some(filter) = parse_optional_expr(expr.filter.as_deref(), ctx, codec)? 
{ builder = builder.filter(filter); } @@ -366,7 +367,7 @@ pub fn parse_expr( builder.build().map_err(Error::DataFusionError) } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( - parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(alias.expr.as_deref(), ctx, "expr", codec)?, alias .relation .first() @@ -376,69 +377,69 @@ pub fn parse_expr( ))), ExprType::IsNullExpr(is_null) => Ok(Expr::IsNull(Box::new(parse_required_expr( is_null.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::IsNotNull(Box::new( - parse_required_expr(is_not_null.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(is_not_null.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::NotExpr(not) => Ok(Expr::Not(Box::new(parse_required_expr( not.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsTrue(msg) => Ok(Expr::IsTrue(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsFalse(msg) => Ok(Expr::IsFalse(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsUnknown(msg) => Ok(Expr::IsUnknown(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotTrue(msg) => Ok(Expr::IsNotTrue(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotFalse(msg) => Ok(Expr::IsNotFalse(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotUnknown(msg) => Ok(Expr::IsNotUnknown(Box::new( - parse_required_expr(msg.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(msg.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::Between(between) => Ok(Expr::Between(Between::new( Box::new(parse_required_expr( between.expr.as_deref(), - registry, + ctx, "expr", codec, )?), between.negated, Box::new(parse_required_expr( 
between.low.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( between.high.as_deref(), - registry, + ctx, "expr", codec, )?), @@ -447,13 +448,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -464,13 +465,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -481,13 +482,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -501,13 +502,13 @@ pub fn parse_expr( .map(|e| { let when_expr = parse_required_expr( e.when_expr.as_ref(), - registry, + ctx, "when_expr", codec, )?; let then_expr = parse_required_expr( e.then_expr.as_ref(), - registry, + ctx, "then_expr", codec, )?; @@ -515,16 +516,15 @@ pub fn parse_expr( }) .collect::, Box)>, Error>>()?; Ok(Expr::Case(Case::new( - parse_optional_expr(case.expr.as_deref(), registry, codec)?.map(Box::new), + parse_optional_expr(case.expr.as_deref(), ctx, codec)?.map(Box::new), when_then_expr, - parse_optional_expr(case.else_expr.as_deref(), registry, codec)? 
- .map(Box::new), + parse_optional_expr(case.else_expr.as_deref(), ctx, codec)?.map(Box::new), ))) } ExprType::Cast(cast) => { let expr = Box::new(parse_required_expr( cast.expr.as_deref(), - registry, + ctx, "expr", codec, )?); @@ -537,7 +537,7 @@ pub fn parse_expr( ExprType::TryCast(cast) => { let expr = Box::new(parse_required_expr( cast.expr.as_deref(), - registry, + ctx, "expr", codec, )?); @@ -551,10 +551,10 @@ pub fn parse_expr( ))) } ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( - parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(negative.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::Unnest(unnest) => { - let mut exprs = parse_exprs(&unnest.exprs, registry, codec)?; + let mut exprs = parse_exprs(&unnest.exprs, ctx, codec)?; if exprs.len() != 1 { return Err(proto_error("Unnest must have exactly one expression")); } @@ -563,11 +563,11 @@ pub fn parse_expr( ExprType::InList(in_list) => Ok(Expr::InList(InList::new( Box::new(parse_required_expr( in_list.expr.as_deref(), - registry, + ctx, "expr", codec, )?), - parse_exprs(&in_list.list, registry, codec)?, + parse_exprs(&in_list.list, ctx, codec)?, in_list.negated, ))), ExprType::Wildcard(protobuf::Wildcard { qualifier }) => { @@ -585,19 +585,19 @@ pub fn parse_expr( }) => { let scalar_fn = match fun_definition { Some(buf) => codec.try_decode_udf(fun_name, buf)?, - None => registry + None => ctx .udf(fun_name.as_str()) .or_else(|_| codec.try_decode_udf(fun_name, &[]))?, }; Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, - parse_exprs(args, registry, codec)?, + parse_exprs(args, ctx, codec)?, ))) } ExprType::AggregateUdfExpr(pb) => { let agg_fn = match &pb.fun_definition { Some(buf) => codec.try_decode_udaf(&pb.fun_name, buf)?, - None => registry + None => ctx .udaf(&pb.fun_name) .or_else(|_| codec.try_decode_udaf(&pb.fun_name, &[]))?, }; @@ -616,10 +616,10 @@ pub fn parse_expr( 
Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, - parse_exprs(&pb.args, registry, codec)?, + parse_exprs(&pb.args, ctx, codec)?, pb.distinct, - parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), - parse_sorts(&pb.order_by, registry, codec)?, + parse_optional_expr(pb.filter.as_deref(), ctx, codec)?.map(Box::new), + parse_sorts(&pb.order_by, ctx, codec)?, null_treatment, ))) } @@ -627,15 +627,15 @@ pub fn parse_expr( ExprType::GroupingSet(GroupingSetNode { expr }) => { Ok(Expr::GroupingSet(GroupingSets( expr.iter() - .map(|expr_list| parse_exprs(&expr_list.expr, registry, codec)) + .map(|expr_list| parse_exprs(&expr_list.expr, ctx, codec)) .collect::, Error>>()?, ))) } ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube( - parse_exprs(expr, registry, codec)?, + parse_exprs(expr, ctx, codec)?, ))), ExprType::Rollup(RollupNode { expr }) => Ok(Expr::GroupingSet( - GroupingSet::Rollup(parse_exprs(expr, registry, codec)?), + GroupingSet::Rollup(parse_exprs(expr, ctx, codec)?), )), ExprType::Placeholder(PlaceholderNode { id, @@ -657,13 +657,41 @@ pub fn parse_expr( ))) } }, + ExprType::ScalarSubqueryExpr(sq) => { + let subquery = parse_subquery( + sq.subquery + .as_deref() + .ok_or_else(|| Error::required("ScalarSubqueryExprNode.subquery"))?, + ctx, + codec, + )?; + Ok(Expr::ScalarSubquery(subquery)) + } } } +fn parse_subquery( + proto: &protobuf::SubqueryNode, + ctx: &TaskContext, + codec: &dyn LogicalExtensionCodec, +) -> Result { + let plan_node = proto + .subquery + .as_ref() + .ok_or_else(|| Error::required("SubqueryNode.subquery"))?; + let plan = plan_node.try_into_logical_plan(ctx, codec)?; + let outer_ref_columns = parse_exprs(&proto.outer_ref_columns, ctx, codec)?; + Ok(Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: Default::default(), + }) +} + /// Parse a vector of `protobuf::LogicalExprNode`s. 
pub fn parse_exprs<'a, I>( protos: I, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> where @@ -672,7 +700,7 @@ where let res = protos .into_iter() .map(|elem| { - parse_expr(elem, registry, codec).map_err(|e| plan_datafusion_err!("{}", e)) + parse_expr(elem, ctx, codec).map_err(|e| plan_datafusion_err!("{}", e)) }) .collect::>>()?; Ok(res) @@ -680,7 +708,7 @@ where pub fn parse_sorts<'a, I>( protos: I, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> where @@ -688,17 +716,17 @@ where { protos .into_iter() - .map(|sort| parse_sort(sort, registry, codec)) + .map(|sort| parse_sort(sort, ctx, codec)) .collect::, Error>>() } pub fn parse_sort( sort: &protobuf::SortExprNode, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result { Ok(Sort::new( - parse_required_expr(sort.expr.as_ref(), registry, "expr", codec)?, + parse_required_expr(sort.expr.as_ref(), ctx, "expr", codec)?, sort.asc, sort.nulls_first, )) @@ -754,23 +782,23 @@ pub fn from_proto_binary_op(op: &str) -> Result { fn parse_optional_expr( p: Option<&protobuf::LogicalExprNode>, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> { match p { - Some(expr) => parse_expr(expr, registry, codec).map(Some), + Some(expr) => parse_expr(expr, ctx, codec).map(Some), None => Ok(None), } } fn parse_required_expr( p: Option<&protobuf::LogicalExprNode>, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, field: impl Into, codec: &dyn LogicalExtensionCodec, ) -> Result { match p { - Some(expr) => parse_expr(expr, registry, codec), + Some(expr) => parse_expr(expr, ctx, codec), None => Err(Error::required(field)), } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 1e0264690d4fb..57826c31f695e 100644 --- 
a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1325,10 +1325,10 @@ impl AsLogicalPlan for LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Selection(Box::new( protobuf::SelectionNode { input: Some(Box::new(input)), - expr: Some(serialize_expr( + expr: Some(Box::new(serialize_expr( &filter.predicate, extension_codec, - )?), + )?)), }, ))), }) @@ -1444,7 +1444,7 @@ impl AsLogicalPlan for LogicalPlanNode { null_equality.to_owned().into(); let filter = filter .as_ref() - .map(|e| serialize_expr(e, extension_codec)) + .map(|e| serialize_expr(e, extension_codec).map(Box::new)) .map_or(Ok(None), |v| v.map(Some))?; Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( @@ -1461,8 +1461,14 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Subquery(_) => { - not_impl_err!("LogicalPlan serde is not yet implemented for subqueries") + LogicalPlan::Subquery(subquery) => { + // Serialize the inner subquery plan directly — the + // LogicalPlan::Subquery wrapper is reconstructed during + // expression deserialization. + LogicalPlanNode::try_from_logical_plan( + &subquery.subquery, + extension_codec, + ) } LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. 
}) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6fcb7389922ad..bd5c4b585c24f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -28,6 +28,7 @@ use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like, NullTreatment, Placeholder, ScalarFunction, Unnest, }; +use datafusion_expr::logical_plan::Subquery; use datafusion_expr::{ Expr, JoinConstraint, JoinType, SortExpr, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, logical_plan::PlanType, @@ -48,7 +49,8 @@ use crate::protobuf::{ }, }; -use super::LogicalExtensionCodec; +use super::{AsLogicalPlan, LogicalExtensionCodec}; +use crate::protobuf::LogicalPlanNode; impl From<&UnnestOptions> for protobuf::UnnestOptions { fn from(opts: &UnnestOptions) -> Self { @@ -579,14 +581,20 @@ pub fn serialize_expr( qualifier: qualifier.to_owned().map(|x| x.into()), })), }, - Expr::ScalarSubquery(_) - | Expr::InSubquery(_) - | Expr::Exists { .. } - | Expr::SetComparison(_) - | Expr::OuterReferenceColumn { .. } => { - // we would need to add logical plan operators to datafusion.proto to support this - // see discussion in https://github.com/apache/datafusion/issues/2565 - return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. 
} | Exp:OuterReferenceColumn not supported".to_string())); + Expr::ScalarSubquery(subquery) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::ScalarSubqueryExpr(Box::new( + protobuf::ScalarSubqueryExprNode { + subquery: Some(Box::new(serialize_subquery(subquery, codec)?)), + }, + ))), + }, + Expr::InSubquery(_) + | Expr::Exists(_) + | Expr::OuterReferenceColumn(_, _) + | Expr::SetComparison(_) => { + return Err(Error::General(format!( + "Proto serialization error: {expr} is not yet supported" + ))); } Expr::GroupingSet(GroupingSet::Cube(exprs)) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Cube(CubeNode { @@ -631,6 +639,19 @@ pub fn serialize_expr( Ok(expr_node) } +fn serialize_subquery( + subquery: &Subquery, + codec: &dyn LogicalExtensionCodec, +) -> Result { + let plan = LogicalPlanNode::try_from_logical_plan(&subquery.subquery, codec) + .map_err(|e| Error::General(e.to_string()))?; + let outer_ref_columns = serialize_exprs(&subquery.outer_ref_columns, codec)?; + Ok(protobuf::SubqueryNode { + subquery: Some(Box::new(plan)), + outer_ref_columns, + }) +} + pub fn serialize_sorts<'a, I>( sorts: I, codec: &dyn LogicalExtensionCodec, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 73a347331cb6b..770e6a289800b 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -39,6 +39,7 @@ use datafusion_execution::{FunctionRegistry, TaskContext}; use datafusion_expr::WindowFunctionDefinition; use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::projection::{ProjectionExpr, ProjectionExprs}; +use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion_physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, Literal, @@ -240,7 +241,7 @@ pub fn 
parse_physical_expr( ctx, input_schema, codec, - &DefaultPhysicalProtoConverter {}, + &DefaultPhysicalProtoConverter::default(), ) } @@ -490,6 +491,27 @@ pub fn parse_physical_expr_with_converter( hash_expr.description.clone(), )) } + ExprType::ScalarSubquery(sq) => { + let data_type: arrow::datatypes::DataType = sq + .data_type + .as_ref() + .ok_or_else(|| { + proto_error("Missing data_type in PhysicalScalarSubqueryExprNode") + })? + .try_into()?; + let results = proto_converter.scalar_subquery_results().ok_or_else(|| { + proto_error( + "ScalarSubqueryExpr can only be deserialized as part \ + of a surrounding ScalarSubqueryExec", + ) + })?; + Arc::new(ScalarSubqueryExpr::new( + data_type, + sq.nullable, + sq.index as usize, + results, + )) + } ExprType::Extension(extension) => { let inputs: Vec> = extension .inputs diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 4e37e9f9528e5..cd0f3723484d5 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -51,6 +51,9 @@ use datafusion_datasource_parquet::source::ParquetSource; #[cfg(feature = "parquet")] use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{FunctionRegistry, TaskContext}; +use datafusion_expr::execution_props::{ + ScalarSubqueryResults, new_scalar_subquery_results, +}; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use datafusion_functions_table::generate_series::{ Empty, GenSeriesArgs, GenerateSeriesTable, GenericSeriesState, TimestampValue, @@ -83,6 +86,7 @@ use datafusion_physical_plan::metrics::{MetricCategory, MetricType}; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::scalar_subquery::{ScalarSubqueryExec, ScalarSubqueryLink}; use 
datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::{InterleaveExec, UnionExec}; @@ -145,7 +149,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { self.try_into_physical_plan_with_converter( ctx, codec, - &DefaultPhysicalProtoConverter {}, + &DefaultPhysicalProtoConverter::default(), ) } @@ -159,7 +163,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { Self::try_from_physical_plan_with_converter( plan, codec, - &DefaultPhysicalProtoConverter {}, + &DefaultPhysicalProtoConverter::default(), ) } } @@ -321,6 +325,8 @@ impl protobuf::PhysicalPlanNode { PhysicalPlanType::Buffer(buffer) => { self.try_into_buffer_physical_plan(buffer, ctx, codec, proto_converter) } + PhysicalPlanType::ScalarSubquery(sq) => self + .try_into_scalar_subquery_physical_plan(sq, ctx, codec, proto_converter), } } @@ -569,6 +575,14 @@ impl protobuf::PhysicalPlanNode { ); } + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_scalar_subquery_exec( + exec, + codec, + proto_converter, + ); + } + let mut buf: Vec = vec![]; match codec.try_encode(Arc::clone(&plan_clone), &mut buf) { Ok(_) => { @@ -2251,6 +2265,44 @@ impl protobuf::PhysicalPlanNode { Ok(Arc::new(BufferExec::new(input, buffer.capacity as usize))) } + fn try_into_scalar_subquery_physical_plan( + &self, + sq: &protobuf::ScalarSubqueryExecNode, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + // First, deserialize the main input plan. We set up the subquery results + // container first, so that ScalarSubqueryExpr nodes can reference it. 
+ let subquery_results = new_scalar_subquery_results(sq.subqueries.len()); + let prev = proto_converter + .set_scalar_subquery_results(Some(Arc::clone(&subquery_results))); + let input = into_physical_plan(&sq.input, ctx, codec, proto_converter); + proto_converter.set_scalar_subquery_results(prev); + let input = input?; + + // Now deserialize the subquery children. + let subqueries: Vec = sq + .subqueries + .iter() + .enumerate() + .map(|(index, sq_plan)| { + let plan = sq_plan.try_into_physical_plan_with_converter( + ctx, + codec, + proto_converter, + )?; + Ok(ScalarSubqueryLink { plan, index }) + }) + .collect::>>()?; + + Ok(Arc::new(ScalarSubqueryExec::new( + input, + subqueries, + subquery_results, + ))) + } + fn try_from_explain_exec( exec: &ExplainExec, _codec: &dyn PhysicalExtensionCodec, @@ -3645,6 +3697,38 @@ impl protobuf::PhysicalPlanNode { ))), }) } + + fn try_from_scalar_subquery_exec( + exec: &ScalarSubqueryExec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(exec.input()), + codec, + proto_converter, + )?; + let subqueries = exec + .subqueries() + .iter() + .map(|sq| { + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(&sq.plan), + codec, + proto_converter, + ) + }) + .collect::>>()?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::ScalarSubquery(Box::new( + protobuf::ScalarSubqueryExecNode { + input: Some(Box::new(input)), + subqueries, + }, + ))), + }) + } } pub trait AsExecutionPlan: Debug + Send + Sync + Clone { @@ -3777,6 +3861,37 @@ pub trait PhysicalProtoConverterExtension { expr: &Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result; + + /// Sets the scalar subquery results container for expression deserialization. + /// Returns the previous value (for save/restore around nested subqueries). 
+ /// + /// During `ScalarSubqueryExec` deserialization, this is called before + /// deserializing the input plan so that `ScalarSubqueryExpr` nodes can + /// pick up the shared results container at construction time. + /// + /// The default implementation discards the value and always returns `None`. + /// This means `ScalarSubqueryExpr` deserialization will fail with an error. + /// Implementations that need scalar subquery support should override both + /// this method and `scalar_subquery_results`. + /// + /// NOTE: These methods use interior mutability (`RefCell`) to thread state + /// through `&self`, because the deserialization call stack does not have an + /// explicit context parameter. A dedicated deserialization context struct + /// would be a cleaner long-term solution. + fn set_scalar_subquery_results( + &self, + _results: Option, + ) -> Option { + None + } + + /// Returns the current scalar subquery results container, if any. + /// + /// See [`set_scalar_subquery_results`](Self::set_scalar_subquery_results) + /// for details on the default behavior and design trade-offs. 
+ fn scalar_subquery_results(&self) -> Option { + None + } } /// DataEncoderTuple captures the position of the encoder @@ -3792,7 +3907,11 @@ struct DataEncoderTuple { pub blob: Vec, } -pub struct DefaultPhysicalProtoConverter; +#[derive(Default)] +pub struct DefaultPhysicalProtoConverter { + scalar_subquery_results: RefCell>, +} + impl PhysicalProtoConverterExtension for DefaultPhysicalProtoConverter { fn proto_to_execution_plan( &self, @@ -3839,6 +3958,17 @@ impl PhysicalProtoConverterExtension for DefaultPhysicalProtoConverter { ) -> Result { serialize_physical_expr_with_converter(expr, codec, self) } + + fn set_scalar_subquery_results( + &self, + results: Option, + ) -> Option { + self.scalar_subquery_results.replace(results) + } + + fn scalar_subquery_results(&self) -> Option { + self.scalar_subquery_results.borrow().clone() + } } /// Internal serializer that adds expr_id to expressions. @@ -3921,6 +4051,10 @@ impl PhysicalProtoConverterExtension for DeduplicatingSerializer { struct DeduplicatingDeserializer { /// Cache mapping expr_id to deserialized expressions. cache: RefCell>>, + /// Scalar subquery results container for the current deserialization scope. + /// Set by `ScalarSubqueryExec` deserialization so that `ScalarSubqueryExpr` + /// nodes in the input plan can pick up the shared results container. 
+ scalar_subquery_results: RefCell>, } impl PhysicalProtoConverterExtension for DeduplicatingDeserializer { @@ -3981,6 +4115,17 @@ impl PhysicalProtoConverterExtension for DeduplicatingDeserializer { ) -> Result { internal_err!("DeduplicatingDeserializer cannot serialize physical expressions") } + + fn set_scalar_subquery_results( + &self, + results: Option, + ) -> Option { + self.scalar_subquery_results.replace(results) + } + + fn scalar_subquery_results(&self) -> Option { + self.scalar_subquery_results.borrow().clone() + } } /// A proto converter that adds expression deduplication during serialization diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 990a54cf94c7a..a1fa57ea5bdb6 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -32,6 +32,7 @@ use datafusion_datasource_json::file_format::JsonSink; use datafusion_datasource_parquet::file_format::ParquetSink; use datafusion_expr::WindowFrame; use datafusion_physical_expr::ScalarFunctionExpr; +use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; @@ -242,7 +243,7 @@ pub fn serialize_physical_expr( serialize_physical_expr_with_converter( value, codec, - &DefaultPhysicalProtoConverter {}, + &DefaultPhysicalProtoConverter::default(), ) } @@ -504,6 +505,17 @@ pub fn serialize_physical_expr_with_converter( }, )), }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarSubquery( + protobuf::PhysicalScalarSubqueryExprNode { + data_type: Some(expr.data_type().try_into()?), + nullable: expr.nullable(), + index: expr.index() as u32, + }, + )), + 
}) } else { let mut buf: Vec = vec![]; match codec.try_encode_expr(&value, &mut buf) { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index d29d6fc7cd08d..4464b511a7328 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -133,7 +133,8 @@ fn roundtrip_expr_test_with_codec( ) { let proto: protobuf::LogicalExprNode = serialize_expr(&initial_struct, codec) .unwrap_or_else(|e| panic!("Error serializing expression: {e:?}")); - let round_trip: Expr = from_proto::parse_expr(&proto, &ctx, codec).unwrap(); + let round_trip: Expr = + from_proto::parse_expr(&proto, ctx.task_ctx().as_ref(), codec).unwrap(); assert_eq!(format!("{:?}", &initial_struct), format!("{round_trip:?}")); @@ -2564,7 +2565,8 @@ fn roundtrip_scalar_udf_extension_codec() { let ctx = SessionContext::new(); let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr"); let round_trip = - from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr"); + from_proto::parse_expr(&proto, ctx.task_ctx().as_ref(), &UDFExtensionCodec) + .expect("parse expr"); assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); roundtrip_json_test(&proto); @@ -2577,7 +2579,8 @@ fn roundtrip_aggregate_udf_extension_codec() { let ctx = SessionContext::new(); let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr"); let round_trip = - from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr"); + from_proto::parse_expr(&proto, ctx.task_ctx().as_ref(), &UDFExtensionCodec) + .expect("parse expr"); assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); roundtrip_json_test(&proto); diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 15639bcd25bdd..b24b19a873420 100644 --- 
a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::collections::HashMap; use std::fmt::{Display, Formatter}; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, OnceLock, RwLock}; use std::vec; use arrow::array::RecordBatch; @@ -75,6 +75,9 @@ use datafusion::physical_plan::metrics::MetricType; use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion::physical_plan::repartition::RepartitionExec; +use datafusion::physical_plan::scalar_subquery::{ + ScalarSubqueryExec, ScalarSubqueryLink, +}; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; @@ -110,6 +113,7 @@ use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::min_max::max_udaf; use datafusion_functions_aggregate::nth_value::nth_value_udaf; use datafusion_functions_aggregate::string_agg::string_agg_udaf; +use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion_proto::bytes::{ physical_plan_from_bytes_with_proto_converter, physical_plan_to_bytes_with_proto_converter, @@ -136,7 +140,7 @@ use crate::cases::{ fn roundtrip_test(exec_plan: Arc) -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); roundtrip_test_and_return(exec_plan, &ctx, &codec, &proto_converter)?; Ok(()) } @@ -183,7 +187,7 @@ fn roundtrip_test_with_context( ctx: &SessionContext, ) -> Result<()> { let codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); 
roundtrip_test_and_return(exec_plan, ctx, &codec, &proto_converter)?; Ok(()) } @@ -192,7 +196,7 @@ fn roundtrip_test_with_context( /// query results are identical. async fn roundtrip_test_sql_with_context(sql: &str, ctx: &SessionContext) -> Result<()> { let codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); let initial_plan = ctx.sql(sql).await?.create_physical_plan().await?; roundtrip_test_and_return(initial_plan, ctx, &codec, &proto_converter)?; @@ -964,7 +968,7 @@ fn roundtrip_parquet_exec_attaches_cached_reader_factory_after_roundtrip() -> Re let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); let roundtripped = roundtrip_test_and_return(exec_plan, &ctx, &codec, &proto_converter)?; @@ -1191,7 +1195,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { exec_plan, &ctx, &CustomPhysicalExtensionCodec {}, - &DefaultPhysicalProtoConverter {}, + &DefaultPhysicalProtoConverter::default(), )?; Ok(()) } @@ -1398,7 +1402,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1446,7 +1450,7 @@ fn roundtrip_udwf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); roundtrip_test_and_return(window, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1518,7 +1522,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - let 
proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1646,7 +1650,7 @@ fn roundtrip_csv_sink() -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); let roundtrip_plan = roundtrip_test_and_return( Arc::new(DataSinkExec::new(input, data_sink, Some(sort_order))), @@ -2851,7 +2855,7 @@ fn test_backward_compatibility_no_expr_id() -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - let proto_converter = DefaultPhysicalProtoConverter {}; + let proto_converter = DefaultPhysicalProtoConverter::default(); // Should deserialize without error let result = proto_converter.proto_to_physical_expr( @@ -3186,3 +3190,90 @@ fn roundtrip_lead_with_default_value() -> Result<()> { true, )?)) } + +/// Verify that ScalarSubqueryExpr nodes in the input plan are connected to the +/// same shared results container as ScalarSubqueryExec after a proto round-trip. +#[test] +fn roundtrip_scalar_subquery_exec() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + + // Create a shared results container with one slot. + let results = Arc::new(vec![OnceLock::new()]); + + // Build the input plan: a filter whose predicate references the + // scalar subquery result via ScalarSubqueryExpr. + let sq_expr = Arc::new(ScalarSubqueryExpr::new( + DataType::Int64, + true, + 0, + Arc::clone(&results), + )); + let predicate = binary(col("a", &schema)?, Operator::Eq, sq_expr, &schema)?; + let filter = + FilterExec::try_new(predicate, Arc::new(EmptyExec::new(schema.clone())))?; + + // Build a trivial subquery plan. 
+ let subquery_plan = + Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Int64, + true, + )])))); + + let exec: Arc = Arc::new(ScalarSubqueryExec::new( + Arc::new(filter), + vec![ScalarSubqueryLink { + plan: subquery_plan, + index: 0, + }], + results, + )); + + // Perform the round-trip using DeduplicatingProtoConverter, which + // creates a DeduplicatingDeserializer that threads scalar subquery + // results through expression deserialization. + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec), + &codec, + &converter, + )?; + let ctx = SessionContext::new(); + let deserialized = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &converter, + )?; + + // Verify the deserialized ScalarSubqueryExec's results container is + // shared with the ScalarSubqueryExpr in the input plan. + let sq_exec = deserialized + .downcast_ref::() + .expect("expected ScalarSubqueryExec"); + let exec_results = sq_exec.results(); + + // Walk the input plan to find the ScalarSubqueryExpr and verify it + // points to the same results container. 
+ let filter_exec = sq_exec + .input() + .downcast_ref::() + .expect("expected FilterExec"); + let binary_expr = filter_exec + .predicate() + .as_any() + .downcast_ref::() + .expect("expected BinaryExpr"); + let deserialized_sq_expr = binary_expr + .right() + .as_any() + .downcast_ref::() + .expect("expected ScalarSubqueryExpr"); + + assert!( + Arc::ptr_eq(exec_results, deserialized_sq_expr.results()), + "ScalarSubqueryExpr should share the same results container as ScalarSubqueryExec" + ); + Ok(()) +} diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index bb955a426ca78..850fd42ce131b 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -77,7 +77,8 @@ fn udf_roundtrip_with_registry() { .call(vec![lit("")]); let bytes = expr.to_bytes().unwrap(); - let deserialized_expr = Expr::from_bytes_with_registry(&bytes, &ctx).unwrap(); + let deserialized_expr = + Expr::from_bytes_with_ctx(&bytes, ctx.task_ctx().as_ref()).unwrap(); assert_eq!(expr, deserialized_expr); } @@ -281,7 +282,8 @@ fn test_expression_serialization_roundtrip() { let extension_codec = DefaultLogicalExtensionCodec {}; let proto = serialize_expr(&expr, &extension_codec).unwrap(); - let deserialize = parse_expr(&proto, &ctx, &extension_codec).unwrap(); + let deserialize = + parse_expr(&proto, ctx.task_ctx().as_ref(), &extension_codec).unwrap(); let serialize_name = extract_function_name(&expr); let deserialize_name = extract_function_name(&deserialize); diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt index f3836b23ec321..fc090102c4066 100644 --- a/datafusion/sqllogictest/test_files/metadata.slt +++ b/datafusion/sqllogictest/test_files/metadata.slt @@ -47,16 +47,12 @@ select id, name from table_with_metadata; NULL bar 3 baz -query I rowsort +query error DataFusion error: Execution error: Scalar subquery returned more than one row SELECT ( SELECT 
id FROM table_with_metadata ) UNION ( SELECT id FROM table_with_metadata ); ----- -1 -3 -NULL query I rowsort SELECT "data"."id" diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 7f88199b3c0ef..033de42f871df 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -689,10 +689,10 @@ query TT explain SELECT t1_id, (SELECT t2_id FROM t2 limit 0) FROM t1 ---- logical_plan -01)Projection: t1.t1_id, __scalar_sq_1.t2_id AS t2_id -02)--Left Join: -03)----TableScan: t1 projection=[t1_id] -04)----EmptyRelation: rows=0 +01)Projection: t1.t1_id, () +02)--Subquery: +03)----EmptyRelation: rows=0 +04)--TableScan: t1 projection=[t1_id] query II rowsort SELECT t1_id, (SELECT t2_id FROM t2 limit 0) FROM t1 @@ -723,26 +723,79 @@ query TT explain select (select count(*) from t1) as b ---- logical_plan -01)Projection: __scalar_sq_1.count(*) AS b -02)--SubqueryAlias: __scalar_sq_1 +01)Projection: () AS b +02)--Subquery: 03)----Projection: count(Int64(1)) AS count(*) 04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 05)--------TableScan: t1 projection=[] +06)--EmptyRelation: rows=1 #simple_uncorrelated_scalar_subquery2 query TT explain select (select count(*) from t1) as b, (select count(1) from t2) ---- logical_plan -01)Projection: __scalar_sq_1.count(*) AS b, __scalar_sq_2.count(Int64(1)) AS count(Int64(1)) -02)--Left Join: -03)----SubqueryAlias: __scalar_sq_1 -04)------Projection: count(Int64(1)) AS count(*) -05)--------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -06)----------TableScan: t1 projection=[] -07)----SubqueryAlias: __scalar_sq_2 -08)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -09)--------TableScan: t2 projection=[] +01)Projection: () AS b, () +02)--Subquery: +03)----Projection: count(Int64(1)) AS count(*) +04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +05)--------TableScan: t1 projection=[] +06)--Subquery: 
+07)----Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +08)------TableScan: t2 projection=[] +09)--EmptyRelation: rows=1 + +# Verify projection pushdown works inside uncorrelated scalar subqueries. +# Each test targets a different early-return path in OptimizeProjections +# to ensure subquery plans are optimized regardless of where the subquery +# expression appears. + +# Subquery in a Filter predicate: the TableScan inside the subquery should +# only read t2_id, not all columns. +query TT +explain select t1_id from t1 where t1_id > (select max(t2_id) from t2) +---- +logical_plan +01)Filter: t1.t1_id > () +02)--Subquery: +03)----Aggregate: groupBy=[[]], aggr=[[max(t2.t2_id)]] +04)------TableScan: t2 projection=[t2_id] +05)--TableScan: t1 projection=[t1_id] + +# Subquery in a Projection expression +query TT +explain select t1_id, (select max(t2_id) from t2) as max_t2 from t1 +---- +logical_plan +01)Projection: t1.t1_id, () AS max_t2 +02)--Subquery: +03)----Aggregate: groupBy=[[]], aggr=[[max(t2.t2_id)]] +04)------TableScan: t2 projection=[t2_id] +05)--TableScan: t1 projection=[t1_id] + +# Subquery in an Aggregate expression +query TT +explain select sum(t1_int + (select min(t2_int) from t2)) as s from t1 +---- +logical_plan +01)Projection: sum(t1.t1_int + min(t2.t2_int)) AS s +02)--Aggregate: groupBy=[[]], aggr=[[sum(CAST(t1.t1_int + () AS Int64))]] +03)----Subquery: +04)------Aggregate: groupBy=[[]], aggr=[[min(t2.t2_int)]] +05)--------TableScan: t2 projection=[t2_int] +06)----TableScan: t1 projection=[t1_int] + +# Subquery in a Window expression +query TT +explain select t1_id, sum(t1_int + (select min(t2_int) from t2)) over () as win from t1 +---- +logical_plan +01)Projection: t1.t1_id, sum(t1.t1_int + min(t2.t2_int)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS win +02)--WindowAggr: windowExpr=[[sum(CAST(t1.t1_int + () AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +03)----Subquery: +04)------Aggregate: groupBy=[[]], 
aggr=[[min(t2.t2_int)]] +05)--------TableScan: t2 projection=[t2_int] +06)----TableScan: t1 projection=[t1_id, t1_int] statement ok set datafusion.explain.logical_plan_only = false; @@ -751,22 +804,23 @@ query TT explain select (select count(*) from t1) as b, (select count(1) from t2) ---- logical_plan -01)Projection: __scalar_sq_1.count(*) AS b, __scalar_sq_2.count(Int64(1)) AS count(Int64(1)) -02)--Left Join: -03)----SubqueryAlias: __scalar_sq_1 -04)------Projection: count(Int64(1)) AS count(*) -05)--------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -06)----------TableScan: t1 projection=[] -07)----SubqueryAlias: __scalar_sq_2 -08)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -09)--------TableScan: t2 projection=[] +01)Projection: () AS b, () +02)--Subquery: +03)----Projection: count(Int64(1)) AS count(*) +04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +05)--------TableScan: t1 projection=[] +06)--Subquery: +07)----Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +08)------TableScan: t2 projection=[] +09)--EmptyRelation: rows=1 physical_plan -01)ProjectionExec: expr=[count(*)@0 as b, count(Int64(1))@1 as count(Int64(1))] -02)--NestedLoopJoinExec: join_type=Left -03)----ProjectionExec: expr=[4 as count(*)] -04)------PlaceholderRowExec -05)----ProjectionExec: expr=[4 as count(Int64(1))] -06)------PlaceholderRowExec +01)ScalarSubqueryExec: subqueries=2 +02)--ProjectionExec: expr=[scalar_subquery() as b, scalar_subquery() as count(Int64(1))] +03)----PlaceholderRowExec +04)--ProjectionExec: expr=[4 as count(*)] +05)----PlaceholderRowExec +06)--ProjectionExec: expr=[4 as count(Int64(1))] +07)----PlaceholderRowExec statement ok set datafusion.explain.logical_plan_only = true; @@ -1672,6 +1726,197 @@ drop table employees; statement count 0 drop table project_assignments; +############# +## Uncorrelated scalar subquery row-count semantics +## A scalar subquery must return at most one row; returning more is an error. 
+############# + +statement ok +CREATE TABLE sq_values(v INT) AS VALUES (1), (2), (3); + +statement ok +CREATE TABLE sq_main(x INT) AS VALUES (10), (20); + +statement ok +CREATE TABLE sq_empty(v INT) AS VALUES (1) LIMIT 0; + +# Scalar subquery returning multiple rows in SELECT position → error +query error DataFusion error: Execution error: Scalar subquery returned more than one row +SELECT (SELECT v FROM sq_values); + +# Scalar subquery returning multiple rows in WHERE position → error +query error DataFusion error: Execution error: Scalar subquery returned more than one row +SELECT x FROM sq_main WHERE x > (SELECT v FROM sq_values); + +# Scalar subquery returning multiple rows as a function argument → error +query error DataFusion error: Execution error: Scalar subquery returned more than one row +SELECT x + (SELECT v FROM sq_values) FROM sq_main; + +# Scalar subquery returning exactly one row → success +query I +SELECT (SELECT v FROM sq_values LIMIT 1); +---- +1 + +# Scalar subquery returning exactly one row in WHERE → success +query I rowsort +SELECT x FROM sq_main WHERE x > (SELECT v FROM sq_values LIMIT 1); +---- +10 +20 + +# Scalar subquery returning zero rows → NULL +query I +SELECT (SELECT v FROM sq_empty); +---- +NULL + +# Scalar subquery returning zero rows in arithmetic → NULL propagation +query I +SELECT x + (SELECT v FROM sq_empty) FROM sq_main; +---- +NULL +NULL + +# Scalar subquery returning zero rows in WHERE comparison → no matching rows +query I +SELECT x FROM sq_main WHERE x > (SELECT v FROM sq_empty); +---- + +# Aggregated subquery always returns one row, even on empty input → success +query I +SELECT (SELECT count(*) FROM sq_empty); +---- +0 + +# Aggregated subquery on multi-row table → success (aggregation reduces to 1 row) +query I +SELECT (SELECT max(v) FROM sq_values); +---- +3 + +# Multiple scalar subqueries, one returns multiple rows → error +query error DataFusion error: Execution error: Scalar subquery returned more than one row 
+SELECT (SELECT count(*) FROM sq_empty), (SELECT v FROM sq_values); + +############# +## Uncorrelated scalar subqueries in various expression contexts +############# + +# HAVING clause with uncorrelated scalar subquery +query II rowsort +SELECT x, count(*) AS cnt FROM sq_main GROUP BY x +HAVING count(*) > (SELECT min(v) FROM sq_values); +---- + +# CASE WHEN with uncorrelated scalar subquery as condition +query T rowsort +SELECT CASE WHEN x > (SELECT min(v) FROM sq_values) + THEN 'big' ELSE 'small' END AS label +FROM sq_main; +---- +big +big + +# ORDER BY with uncorrelated scalar subquery +query I +SELECT x FROM sq_main ORDER BY x + (SELECT max(v) FROM sq_values); +---- +10 +20 + +# Aggregate function argument containing uncorrelated scalar subquery +query I +SELECT sum(x + (SELECT max(v) FROM sq_values)) AS s FROM sq_main; +---- +36 + +# JOIN ON condition with uncorrelated scalar subquery +query II rowsort +SELECT l.x, r.x AS rx +FROM sq_main AS l JOIN sq_main AS r +ON l.x = r.x + (SELECT min(v) FROM sq_values); +---- + +# Nested uncorrelated-in-uncorrelated scalar subquery. +query I +SELECT (SELECT max(v) + (SELECT min(v) FROM sq_values) FROM sq_values); +---- +4 + +# Verify nested subqueries are not hoisted: the root ScalarSubqueryExec +# should manage only the outer subquery (subqueries=1), not both. 
+query TT +EXPLAIN SELECT (SELECT max(v) + (SELECT min(v) FROM sq_values) FROM sq_values); +---- +logical_plan +01)Projection: () +02)--Subquery: +03)----Projection: max(sq_values.v) + () +04)------Subquery: +05)--------Aggregate: groupBy=[[]], aggr=[[min(sq_values.v)]] +06)----------TableScan: sq_values projection=[v] +07)------Aggregate: groupBy=[[]], aggr=[[max(sq_values.v)]] +08)--------TableScan: sq_values projection=[v] +09)--EmptyRelation: rows=1 +physical_plan +01)ScalarSubqueryExec: subqueries=1 +02)--ProjectionExec: expr=[scalar_subquery() as max(sq_values.v) + min(sq_values.v)] +03)----PlaceholderRowExec +04)--ScalarSubqueryExec: subqueries=1 +05)----ProjectionExec: expr=[max(sq_values.v)@0 + scalar_subquery() as max(sq_values.v) + min(sq_values.v)] +06)------AggregateExec: mode=Single, gby=[], aggr=[max(sq_values.v)] +07)--------DataSourceExec: partitions=1, partition_sizes=[1] +08)----AggregateExec: mode=Single, gby=[], aggr=[min(sq_values.v)] +09)------DataSourceExec: partitions=1, partition_sizes=[1] + +# CTE as source inside uncorrelated scalar subquery +query I +SELECT (SELECT s FROM (WITH cte AS (SELECT max(v) AS s FROM sq_values) SELECT s FROM cte)); +---- +3 + +# Window function with uncorrelated scalar subquery +query II rowsort +SELECT x, sum(x + (SELECT max(v) FROM sq_values)) OVER () AS win_sum FROM sq_main; +---- +10 36 +20 36 + +# Duplicate uncorrelated scalar subqueries only appear in the query plan once +statement ok +set datafusion.explain.logical_plan_only = false; + +query TT +explain SELECT (SELECT max(v) FROM sq_values) + (SELECT max(v) FROM sq_values) AS doubled; +---- +logical_plan +01)Projection: __common_expr_1 + __common_expr_1 AS doubled +02)--Projection: () AS __common_expr_1 +03)----Subquery: +04)------Aggregate: groupBy=[[]], aggr=[[max(sq_values.v)]] +05)--------TableScan: sq_values projection=[v] +06)----EmptyRelation: rows=1 +physical_plan +01)ScalarSubqueryExec: subqueries=1 +02)--ProjectionExec: expr=[__common_expr_1@0 
+ __common_expr_1@0 as doubled] +03)----ProjectionExec: expr=[scalar_subquery() as __common_expr_1] +04)------PlaceholderRowExec +05)--AggregateExec: mode=Single, gby=[], aggr=[max(sq_values.v)] +06)----DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +RESET datafusion.explain.logical_plan_only; + +statement count 0 +DROP TABLE sq_values; + +statement count 0 +DROP TABLE sq_main; + +statement count 0 +DROP TABLE sq_empty; + # https://github.com/apache/datafusion/issues/21205 statement ok CREATE TABLE dup_filter_t1(id INTEGER) AS VALUES (1), (2), (3); @@ -1739,3 +1984,38 @@ DROP TABLE sq_name_t1; statement ok DROP TABLE sq_name_t2; + +# Test: scalar subquery in filter on a partition column of a partitioned table. +# This exercises the code path where filters are pushed down to the table +# provider for partition pruning. Scalar subqueries must not be pushed to the +# provider because the subquery result is not available at partition listing +# time. + +query I +COPY (VALUES(1, 'a'), (2, 'b'), (3, 'c')) +TO 'test_files/scratch/subquery/partition_pruning/part=1/file1.parquet'; +---- +3 + +query I +COPY (VALUES(4, 'd'), (5, 'e')) +TO 'test_files/scratch/subquery/partition_pruning/part=2/file1.parquet'; +---- +2 + +statement ok +CREATE EXTERNAL TABLE subquery_partitioned +STORED AS PARQUET +LOCATION 'test_files/scratch/subquery/partition_pruning/'; + +query IT +SELECT column1, column2 FROM subquery_partitioned +WHERE part = (SELECT 1) +ORDER BY column1; +---- +1 a +2 b +3 c + +statement ok +DROP TABLE subquery_partitioned; diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part index a31579eb1e09d..0c5b6d76dc1e1 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part @@ -49,61 +49,62 @@ limit 10; logical_plan 01)Sort: value DESC NULLS FIRST, fetch=10 02)--Projection: partsupp.ps_partkey, 
sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS value -03)----Inner Join: Filter: CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > __scalar_sq_1.sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001) -04)------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] -05)--------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost -06)----------Inner Join: supplier.s_nationkey = nation.n_nationkey -07)------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey -08)--------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey -09)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], partial_filters=[Boolean(true)] -10)----------------TableScan: supplier projection=[s_suppkey, s_nationkey] -11)------------Projection: nation.n_nationkey -12)--------------Filter: nation.n_name = Utf8View("GERMANY") -13)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] -14)------SubqueryAlias: __scalar_sq_1 -15)--------Projection: CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15)) -16)----------Aggregate: groupBy=[[]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] -17)------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost -18)--------------Inner Join: supplier.s_nationkey = nation.n_nationkey -19)----------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey -20)------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey -21)--------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] -22)--------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 
-23)----------------Projection: nation.n_nationkey -24)------------------Filter: nation.n_name = Utf8View("GERMANY") -25)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] +03)----Filter: CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > () +04)------Subquery: +05)--------Projection: CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15)) +06)----------Aggregate: groupBy=[[]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] +07)------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost +08)--------------Inner Join: supplier.s_nationkey = nation.n_nationkey +09)----------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey +10)------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey +11)--------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] +12)--------------------TableScan: supplier projection=[s_suppkey, s_nationkey] +13)----------------Projection: nation.n_nationkey +14)------------------Filter: nation.n_name = Utf8View("GERMANY") +15)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] +16)------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] +17)--------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost +18)----------Inner Join: supplier.s_nationkey = nation.n_nationkey +19)------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey +20)--------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey +21)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] +22)----------------TableScan: 
supplier projection=[s_suppkey, s_nationkey] +23)------------Projection: nation.n_nationkey +24)--------------Filter: nation.n_name = Utf8View("GERMANY") +25)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] physical_plan -01)SortExec: TopK(fetch=10), expr=[value@1 DESC], preserve_partitioning=[false] -02)--ProjectionExec: expr=[ps_partkey@0 as ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value] -03)----NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@1 > sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)@0, projection=[ps_partkey@0, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1, sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)@3] -04)------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as sum(partsupp.ps_supplycost * partsupp.ps_availqty), CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 AS Decimal128(38, 15)) as join_proj_push_down_1] -05)--------CoalescePartitionsExec -06)----------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] -07)------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 -08)--------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] -09)----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[ps_partkey@0, ps_availqty@1, ps_supplycost@2] -10)------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4 -11)--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)], projection=[ps_partkey@0, ps_availqty@2, ps_supplycost@3, s_nationkey@5] -12)----------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 
4), input_partitions=4 -13)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], file_type=csv, has_header=false -14)----------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 -15)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], file_type=csv, has_header=false -16)------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -17)--------------------FilterExec: n_name@1 = GERMANY, projection=[n_nationkey@0] -18)----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -19)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], file_type=csv, has_header=false -20)------ProjectionExec: expr=[CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Float64) * 0.0001 AS Decimal128(38, 15)) as sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)] -21)--------AggregateExec: mode=Final, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] -22)----------CoalescePartitionsExec -23)------------AggregateExec: mode=Partial, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] -24)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_availqty@0, ps_supplycost@1] 
-25)----------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4 -26)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@0, s_suppkey@0)], projection=[ps_availqty@1, ps_supplycost@2, s_nationkey@4] -27)--------------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 -28)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_suppkey, ps_availqty, ps_supplycost], file_type=csv, has_header=false -29)--------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 -30)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], file_type=csv, has_header=false -31)----------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -32)------------------FilterExec: n_name@1 = GERMANY, projection=[n_nationkey@0] -33)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -34)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], file_type=csv, has_header=false +01)ScalarSubqueryExec: subqueries=1 +02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true +03)----SortPreservingMergeExec: [value@1 DESC], fetch=10 +04)------SortExec: TopK(fetch=10), expr=[value@1 DESC], preserve_partitioning=[true] +05)--------ProjectionExec: expr=[ps_partkey@0 as 
ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value] +06)----------FilterExec: CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 AS Decimal128(38, 15)) > scalar_subquery() +07)------------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +08)--------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +09)----------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[ps_partkey@0, ps_availqty@1, ps_supplycost@2] +11)--------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4 +12)----------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)], projection=[ps_partkey@0, ps_availqty@2, ps_supplycost@3, s_nationkey@5] +13)------------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4 +14)--------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], file_type=csv, has_header=false +15)------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 +16)--------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], file_type=csv, has_header=false 
+17)--------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +18)----------------------FilterExec: n_name@1 = GERMANY, projection=[n_nationkey@0] +19)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +20)--------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], file_type=csv, has_header=false +21)--ProjectionExec: expr=[CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Float64) * 0.0001 AS Decimal128(38, 15)) as sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)] +22)----AggregateExec: mode=Final, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +23)------CoalescePartitionsExec +24)--------AggregateExec: mode=Partial, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +25)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_availqty@0, ps_supplycost@1] +26)------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4 +27)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@0, s_suppkey@0)], projection=[ps_availqty@1, ps_supplycost@2, s_nationkey@4] +28)----------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 +29)------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_suppkey, ps_availqty, ps_supplycost], file_type=csv, has_header=false +30)----------------RepartitionExec: 
partitioning=Hash([s_suppkey@0], 4), input_partitions=1 +31)------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], file_type=csv, has_header=false +32)------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +33)--------------FilterExec: n_name@1 = GERMANY, projection=[n_nationkey@0] +34)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +35)------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part index ae0c0a93a3552..3e1aca318b5c7 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part @@ -52,43 +52,44 @@ order by logical_plan 01)Sort: supplier.s_suppkey ASC NULLS LAST 02)--Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_phone, revenue0.total_revenue -03)----Inner Join: revenue0.total_revenue = __scalar_sq_1.max(revenue0.total_revenue) -04)------Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_phone, revenue0.total_revenue -05)--------Inner Join: supplier.s_suppkey = revenue0.supplier_no -06)----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_phone], partial_filters=[Boolean(true)] -07)----------SubqueryAlias: revenue0 -08)------------Projection: lineitem.l_suppkey AS supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue -09)--------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - 
lineitem.l_discount)]] -10)----------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount -11)------------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") -12)--------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] -13)------SubqueryAlias: __scalar_sq_1 -14)--------Aggregate: groupBy=[[]], aggr=[[max(revenue0.total_revenue)]] -15)----------SubqueryAlias: revenue0 -16)------------Projection: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue -17)--------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] -18)----------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount -19)------------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") -20)--------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] +03)----Inner Join: supplier.s_suppkey = revenue0.supplier_no +04)------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_phone] +05)------SubqueryAlias: revenue0 +06)--------Projection: lineitem.l_suppkey AS supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue +07)----------Filter: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) = () +08)------------Subquery: +09)--------------Aggregate: groupBy=[[]], aggr=[[max(revenue0.total_revenue)]] +10)----------------SubqueryAlias: revenue0 +11)------------------Projection: sum(lineitem.l_extendedprice * Int64(1) - 
lineitem.l_discount) AS total_revenue +12)--------------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +13)----------------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount +14)------------------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") +15)--------------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] +16)------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +17)--------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount +18)----------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") +19)------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] physical_plan -01)SortPreservingMergeExec: [s_suppkey@0 ASC NULLS LAST] -02)--SortExec: expr=[s_suppkey@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(max(revenue0.total_revenue)@0, total_revenue@4)], projection=[s_suppkey@1, s_name@2, s_address@3, s_phone@4, total_revenue@5] -04)------AggregateExec: mode=Final, gby=[], aggr=[max(revenue0.total_revenue)] -05)--------CoalescePartitionsExec -06)----------AggregateExec: mode=Partial, gby=[], aggr=[max(revenue0.total_revenue)] -07)------------ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 
as total_revenue] -08)--------------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -09)----------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 -10)------------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -11)--------------------FilterExec: l_shipdate@3 >= 1996-01-01 AND l_shipdate@3 < 1996-04-01, projection=[l_suppkey@0, l_extendedprice@1, l_discount@2] -12)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], file_type=csv, has_header=false -13)------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_suppkey@0, supplier_no@0)], projection=[s_suppkey@0, s_name@1, s_address@2, s_phone@3, total_revenue@5] -14)--------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 -15)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_phone], file_type=csv, has_header=false -16)--------ProjectionExec: expr=[l_suppkey@0 as supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] -17)----------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -18)------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 
-19)--------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -20)----------------FilterExec: l_shipdate@3 >= 1996-01-01 AND l_shipdate@3 < 1996-04-01, projection=[l_suppkey@0, l_extendedprice@1, l_discount@2] -21)------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], file_type=csv, has_header=false +01)ScalarSubqueryExec: subqueries=1 +02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true +03)----SortPreservingMergeExec: [s_suppkey@0 ASC NULLS LAST] +04)------SortExec: expr=[s_suppkey@0 ASC NULLS LAST], preserve_partitioning=[true] +05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_suppkey@0, supplier_no@0)], projection=[s_suppkey@0, s_name@1, s_address@2, s_phone@3, total_revenue@5] +06)----------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 +07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_phone], file_type=csv, has_header=false +08)----------ProjectionExec: expr=[l_suppkey@0 as supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] +09)------------FilterExec: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 = scalar_subquery() +10)--------------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - 
lineitem.l_discount)] +11)----------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 +12)------------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +13)--------------------FilterExec: l_shipdate@3 >= 1996-01-01 AND l_shipdate@3 < 1996-04-01, projection=[l_suppkey@0, l_extendedprice@1, l_discount@2] +14)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], file_type=csv, has_header=false +15)--AggregateExec: mode=Final, gby=[], aggr=[max(revenue0.total_revenue)] +16)----CoalescePartitionsExec +17)------AggregateExec: mode=Partial, gby=[], aggr=[max(revenue0.total_revenue)] +18)--------ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] +19)----------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +20)------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 +21)--------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +22)----------------FilterExec: l_shipdate@3 >= 1996-01-01 AND l_shipdate@3 < 1996-04-01, projection=[l_suppkey@0, l_extendedprice@1, l_discount@2] +23)------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part index add578c3b079d..3240cbfb697d5 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part @@ -61,40 +61,36 @@ logical_plan 03)----Aggregate: groupBy=[[custsale.cntrycode]], aggr=[[count(Int64(1)), sum(custsale.c_acctbal)]] 04)------SubqueryAlias: custsale 05)--------Projection: substr(customer.c_phone, Int64(1), Int64(2)) AS cntrycode, customer.c_acctbal -06)----------Inner Join: Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_2.avg(customer.c_acctbal) -07)------------Projection: customer.c_phone, customer.c_acctbal -08)--------------LeftAnti Join: customer.c_custkey = __correlated_sq_1.o_custkey -09)----------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) -10)------------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]), Boolean(true)] -11)----------------SubqueryAlias: __correlated_sq_1 -12)------------------TableScan: orders projection=[o_custkey] -13)------------SubqueryAlias: __scalar_sq_2 -14)--------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] -15)----------------Projection: customer.c_acctbal 
-16)------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) -17)--------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] +06)----------LeftAnti Join: customer.c_custkey = __correlated_sq_1.o_custkey +07)------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) AND CAST(customer.c_acctbal AS Decimal128(19, 6)) > () +08)--------------Subquery: +09)----------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] +10)------------------Projection: customer.c_acctbal +11)--------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) +12)----------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] +13)--------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])] +14)------------SubqueryAlias: __correlated_sq_1 +15)--------------TableScan: orders projection=[o_custkey] physical_plan -01)SortPreservingMergeExec: 
[cntrycode@0 ASC NULLS LAST] -02)--SortExec: expr=[cntrycode@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[cntrycode@0 as cntrycode, count(Int64(1))@1 as numcust, sum(custsale.c_acctbal)@2 as totacctbal] -04)------AggregateExec: mode=FinalPartitioned, gby=[cntrycode@0 as cntrycode], aggr=[count(Int64(1)), sum(custsale.c_acctbal)] -05)--------RepartitionExec: partitioning=Hash([cntrycode@0], 4), input_partitions=4 -06)----------AggregateExec: mode=Partial, gby=[cntrycode@0 as cntrycode], aggr=[count(Int64(1)), sum(custsale.c_acctbal)] -07)------------ProjectionExec: expr=[substr(c_phone@0, 1, 2) as cntrycode, c_acctbal@1 as c_acctbal] -08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -09)----------------NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@1 > avg(customer.c_acctbal)@0, projection=[c_phone@0, c_acctbal@1, avg(customer.c_acctbal)@3] -10)------------------ProjectionExec: expr=[c_phone@0 as c_phone, c_acctbal@1 as c_acctbal, CAST(c_acctbal@1 AS Decimal128(19, 6)) as join_proj_push_down_1] -11)--------------------CoalescePartitionsExec -12)----------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(c_custkey@0, o_custkey@0)], projection=[c_phone@1, c_acctbal@2] -13)------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 -14)--------------------------FilterExec: substr(c_phone@1, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]) -15)----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -16)------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false -17)------------------------RepartitionExec: partitioning=Hash([o_custkey@0], 4), input_partitions=4 -18)--------------------------DataSourceExec: file_groups={4 groups: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_custkey], file_type=csv, has_header=false -19)------------------AggregateExec: mode=Final, gby=[], aggr=[avg(customer.c_acctbal)] -20)--------------------CoalescePartitionsExec -21)----------------------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] -22)------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]), projection=[c_acctbal@1] -23)--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -24)----------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false +01)ScalarSubqueryExec: subqueries=1 +02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true +03)----SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST] +04)------SortExec: expr=[cntrycode@0 ASC NULLS LAST], preserve_partitioning=[true] +05)--------ProjectionExec: expr=[cntrycode@0 as cntrycode, count(Int64(1))@1 as numcust, sum(custsale.c_acctbal)@2 as totacctbal] +06)----------AggregateExec: mode=FinalPartitioned, gby=[cntrycode@0 as cntrycode], aggr=[count(Int64(1)), sum(custsale.c_acctbal)] +07)------------RepartitionExec: partitioning=Hash([cntrycode@0], 4), input_partitions=4 +08)--------------AggregateExec: mode=Partial, gby=[cntrycode@0 as cntrycode], aggr=[count(Int64(1)), sum(custsale.c_acctbal)] +09)----------------ProjectionExec: expr=[substr(c_phone@0, 1, 2) as cntrycode, c_acctbal@1 as c_acctbal] 
+10)------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(c_custkey@0, o_custkey@0)], projection=[c_phone@1, c_acctbal@2] +11)--------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 +12)----------------------FilterExec: substr(c_phone@1, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]) AND CAST(c_acctbal@2 AS Decimal128(19, 6)) > scalar_subquery() +13)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)--------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false +15)--------------------RepartitionExec: partitioning=Hash([o_custkey@0], 4), input_partitions=4 +16)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_custkey], file_type=csv, has_header=false +17)--AggregateExec: mode=Final, gby=[], aggr=[avg(customer.c_acctbal)] +18)----CoalescePartitionsExec +19)------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] +20)--------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]), projection=[c_acctbal@1] +21)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +22)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false