diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
index 1e60c391f50d1..b24ff57e7bbcd 100644
--- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
+++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
@@ -18,7 +18,9 @@
 //! [`SortPreservingMergeExec`] merges multiple sorted streams into one sorted stream.
 
 use std::any::Any;
+use std::pin::Pin;
 use std::sync::Arc;
+use std::task::{Context, Poll};
 
 use crate::common::spawn_buffered;
 use crate::limit::LimitStream;
@@ -27,16 +29,18 @@ use crate::projection::{ProjectionExec, make_with_child, update_ordering};
 use crate::sorts::streaming_merge::StreamingMergeBuilder;
 use crate::{
     DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
-    Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
-    check_if_same_properties,
+    Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream,
+    Statistics, check_if_same_properties,
 };
 
+use arrow::datatypes::SchemaRef;
 use datafusion_common::tree_node::TreeNodeRecursion;
 use datafusion_common::{Result, assert_eq_or_internal_err, internal_err};
 use datafusion_execution::TaskContext;
 use datafusion_execution::memory_pool::MemoryConsumer;
 use datafusion_physical_expr::PhysicalExpr;
 use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
+use futures::{Stream, StreamExt};
 
 use crate::execution_plan::{EvaluationType, SchedulingType};
 use log::{debug, trace};
@@ -362,34 +366,19 @@ impl ExecutionPlan for SortPreservingMergeExec {
                 }
             },
             _ => {
-                let receivers = (0..input_partitions)
-                    .map(|partition| {
-                        let stream =
-                            self.input.execute(partition, Arc::clone(&context))?;
-                        Ok(spawn_buffered(stream, 1))
-                    })
-                    .collect::<Result<_>>()?;
-
-                debug!(
-                    "Done setting up sender-receiver for SortPreservingMergeExec::execute"
-                );
-
-                let result = StreamingMergeBuilder::new()
-                    .with_streams(receivers)
-                    .with_schema(schema)
-                    .with_expressions(&self.expr)
-                    .with_metrics(BaselineMetrics::new(&self.metrics, partition))
-                    .with_batch_size(context.session_config().batch_size())
-                    .with_fetch(self.fetch)
-                    .with_reservation(reservation)
-                    .with_round_robin_tie_breaker(self.enable_round_robin_repartition)
-                    .build()?;
-
-                debug!(
-                    "Got stream result from SortPreservingMergeStream::new_from_receivers"
-                );
-
-                Ok(result)
+                let batch_size = context.session_config().batch_size();
+                Ok(Box::pin(SortPreservingMergeExecStream {
+                    schema,
+                    input: Arc::clone(&self.input),
+                    context,
+                    expr: self.expr.clone(),
+                    metrics: BaselineMetrics::new(&self.metrics, partition),
+                    batch_size,
+                    fetch: self.fetch,
+                    reservation: Some(reservation),
+                    enable_round_robin_repartition: self.enable_round_robin_repartition,
+                    state: SPMStreamState::Pending,
+                }))
             }
         }
     }
@@ -433,6 +422,122 @@ impl ExecutionPlan for SortPreservingMergeExec {
     }
 }
 
+/// A stream that lazily spawns input partition tasks and builds the streaming
+/// merge on first poll, rather than eagerly in `execute()`.
+struct SortPreservingMergeExecStream {
+    /// Schema of the produced batches.
+    schema: SchemaRef,
+    /// Input plan whose partitions are executed and merged.
+    input: Arc<dyn ExecutionPlan>,
+    /// Task context handed to every input partition's `execute` call.
+    context: Arc<TaskContext>,
+    /// Sort expressions the merge orders by.
+    expr: LexOrdering,
+    /// Metrics handed to the merge.
+    metrics: BaselineMetrics,
+    /// Batch size handed to the merge.
+    batch_size: usize,
+    /// Optional fetch handed to the merge.
+    fetch: Option<usize>,
+    /// Memory reservation for the merge; `None` once `start` has moved it out.
+    reservation: Option<datafusion_execution::memory_pool::MemoryReservation>,
+    /// Forwarded to the merge builder's `with_round_robin_tie_breaker`.
+    enable_round_robin_repartition: bool,
+    /// Whether the merge is still pending, running, or failed to start.
+    state: SPMStreamState,
+}
+
+/// Lifecycle of a `SortPreservingMergeExecStream`.
+enum SPMStreamState {
+    /// Tasks have not been spawned yet.
+    Pending,
+    /// The streaming merge has been built and is running.
+    Running(SendableRecordBatchStream),
+    /// Initialization failed.
+    Failed,
+}
+
+impl SortPreservingMergeExecStream {
+    /// Executes every input partition, builds the streaming merge over the
+    /// resulting streams, and moves `self` into the `Running` state.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if executing an input partition or building the merge
+    /// fails, or if the reservation has already been moved out.
+    fn start(&mut self) -> Result<&mut SendableRecordBatchStream> {
+        let input_partitions = self.input.output_partitioning().partition_count();
+
+        let receivers = (0..input_partitions)
+            .map(|partition| {
+                let stream = self.input.execute(partition, Arc::clone(&self.context))?;
+                Ok(spawn_buffered(stream, 1))
+            })
+            .collect::<Result<_>>()?;
+
+        debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute");
+
+        // The merge takes ownership of the reservation. Taking it out of the
+        // `Option` avoids registering a placeholder consumer with the pool.
+        let Some(reservation) = self.reservation.take() else {
+            return internal_err!(
+                "SortPreservingMergeExecStream reservation was already taken"
+            );
+        };
+
+        let result = StreamingMergeBuilder::new()
+            .with_streams(receivers)
+            .with_schema(Arc::clone(&self.schema))
+            .with_expressions(&self.expr)
+            .with_metrics(self.metrics.clone())
+            .with_batch_size(self.batch_size)
+            .with_fetch(self.fetch)
+            .with_reservation(reservation)
+            .with_round_robin_tie_breaker(self.enable_round_robin_repartition)
+            .build()?;
+
+        debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");
+
+        self.state = SPMStreamState::Running(result);
+        match &mut self.state {
+            SPMStreamState::Running(s) => Ok(s),
+            _ => unreachable!("state was set to Running just above"),
+        }
+    }
+}
+
+impl RecordBatchStream for SortPreservingMergeExecStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
+
+impl Stream for SortPreservingMergeExecStream {
+    type Item = Result<arrow::record_batch::RecordBatch>;
+
+    /// Builds the merge on the first poll, then forwards every poll to it.
+    ///
+    /// If building fails, the error is yielded once and later polls return
+    /// `None`.
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let stream = match &mut self.state {
+            SPMStreamState::Running(s) => s,
+            SPMStreamState::Failed => return Poll::Ready(None),
+            SPMStreamState::Pending => match self.start() {
+                Ok(s) => s,
+                Err(e) => {
+                    self.state = SPMStreamState::Failed;
+                    return Poll::Ready(Some(Err(e)));
+                }
+            },
+        };
+        stream.poll_next_unpin(cx)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::collections::HashSet;