diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs
index 5ea3589f22b3e..c887c29083617 100644
--- a/datafusion/physical-plan/src/coalesce_partitions.rs
+++ b/datafusion/physical-plan/src/coalesce_partitions.rs
@@ -19,13 +19,15 @@
 //! into a single partition
 
 use std::any::Any;
+use std::pin::Pin;
 use std::sync::Arc;
+use std::task::{Context, Poll};
 
 use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
 use super::stream::{ObservedStream, RecordBatchReceiverStream};
 use super::{
-    DisplayAs, ExecutionPlanProperties, PlanProperties, SendableRecordBatchStream,
-    Statistics,
+    DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
+    SendableRecordBatchStream, Statistics,
 };
 use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
 use crate::filter_pushdown::{FilterDescription, FilterPushdownPhase};
@@ -34,11 +36,14 @@ use crate::sort_pushdown::SortOrderPushdownResult;
 use crate::{DisplayFormatType, ExecutionPlan, Partitioning, check_if_same_properties};
 use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
 
+use arrow::datatypes::SchemaRef;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::TreeNodeRecursion;
 use datafusion_common::{Result, assert_eq_or_internal_err, internal_err};
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::PhysicalExpr;
+use futures::Stream;
+use futures::StreamExt;
 
 /// Merge execution plan executes partitions in parallel and combines them into a single
 /// partition. No guarantees are made about the order of the resulting partition.
@@ -209,33 +214,14 @@ impl ExecutionPlan for CoalescePartitionsExec {
             }
             _ => {
                 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
-                // record the (very) minimal work done so that
-                // elapsed_compute is not reported as 0
-                let elapsed_compute = baseline_metrics.elapsed_compute().clone();
-                let _timer = elapsed_compute.timer();
-
-                // use a stream that allows each sender to put in at
-                // least one result in an attempt to maximize
-                // parallelism.
-                let mut builder =
-                    RecordBatchReceiverStream::builder(self.schema(), input_partitions);
-
-                // spawn independent tasks whose resulting streams (of batches)
-                // are sent to the channel for consumption.
-                for part_i in 0..input_partitions {
-                    builder.run_input(
-                        Arc::clone(&self.input),
-                        part_i,
-                        Arc::clone(&context),
-                    );
-                }
-
-                let stream = builder.build();
-                Ok(Box::pin(ObservedStream::new(
-                    stream,
+                Ok(Box::pin(CoalescePartitionsStream {
+                    schema: self.schema(),
+                    input: Arc::clone(&self.input),
+                    context,
                     baseline_metrics,
-                    self.fetch,
-                )))
+                    fetch: self.fetch,
+                    state: CoalescePartitionsStreamState::Pending,
+                }))
             }
         }
     }
@@ -352,6 +338,90 @@ impl ExecutionPlan for CoalescePartitionsExec {
     }
 }
 
+/// A stream that lazily spawns input partition tasks on first poll,
+/// rather than eagerly in `execute()`.
+struct CoalescePartitionsStream {
+    schema: SchemaRef,
+    input: Arc<dyn ExecutionPlan>,
+    context: Arc<TaskContext>,
+    baseline_metrics: BaselineMetrics,
+    fetch: Option<usize>,
+    state: CoalescePartitionsStreamState,
+}
+
+enum CoalescePartitionsStreamState {
+    /// Tasks have not been spawned yet.
+    Pending,
+    /// Tasks have been spawned, polling the merged stream.
+    Running(Pin<Box<ObservedStream>>),
+}
+
+impl CoalescePartitionsStream {
+    /// Spawns one task per input partition, stores the merged stream in
+    /// `self.state` as `Running`, and returns a reference to it for polling.
+    ///
+    /// It replaces `self.state` unconditionally, so it is meant to be called
+    /// only from the `Pending` state (i.e. on the first poll).
+    fn start(&mut self) -> &mut Pin<Box<ObservedStream>> {
+        let input_partitions = self.input.output_partitioning().partition_count();
+
+        // record the (very) minimal work done so that
+        // elapsed_compute is not reported as 0
+        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+        let _timer = elapsed_compute.timer();
+
+        // use a stream that allows each sender to put in at
+        // least one result in an attempt to maximize
+        // parallelism.
+        let mut builder = RecordBatchReceiverStream::builder(
+            Arc::clone(&self.schema),
+            input_partitions,
+        );
+
+        // spawn independent tasks whose resulting streams (of batches)
+        // are sent to the channel for consumption.
+        for part_i in 0..input_partitions {
+            builder.run_input(
+                Arc::clone(&self.input),
+                part_i,
+                Arc::clone(&self.context),
+            );
+        }
+
+        let stream = builder.build();
+        self.state = CoalescePartitionsStreamState::Running(Box::pin(
+            ObservedStream::new(stream, self.baseline_metrics.clone(), self.fetch),
+        ));
+        match &mut self.state {
+            CoalescePartitionsStreamState::Running(s) => s,
+            _ => unreachable!(),
+        }
+    }
+}
+
+impl RecordBatchStream for CoalescePartitionsStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
+
+impl Stream for CoalescePartitionsStream {
+    type Item = Result<arrow::record_batch::RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        // The first poll spawns the input tasks; every later poll goes straight
+        // to the merged stream created by that first poll.
+        let stream = match &mut self.state {
+            CoalescePartitionsStreamState::Running(s) => s,
+            CoalescePartitionsStreamState::Pending => self.start(),
+        };
+        stream.poll_next_unpin(cx)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;