Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 112 additions & 30 deletions datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
//! [`SortPreservingMergeExec`] merges multiple sorted streams into one sorted stream.

use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::common::spawn_buffered;
use crate::limit::LimitStream;
Expand All @@ -27,16 +29,18 @@ use crate::projection::{ProjectionExec, make_with_child, update_ordering};
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::{
DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
check_if_same_properties,
Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream,
Statistics, check_if_same_properties,
};

use arrow::datatypes::SchemaRef;
use datafusion_common::tree_node::TreeNodeRecursion;
use datafusion_common::{Result, assert_eq_or_internal_err, internal_err};
use datafusion_execution::TaskContext;
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
use futures::{Stream, StreamExt};

use crate::execution_plan::{EvaluationType, SchedulingType};
use log::{debug, trace};
Expand Down Expand Up @@ -362,34 +366,19 @@ impl ExecutionPlan for SortPreservingMergeExec {
}
},
_ => {
let receivers = (0..input_partitions)
.map(|partition| {
let stream =
self.input.execute(partition, Arc::clone(&context))?;
Ok(spawn_buffered(stream, 1))
})
.collect::<Result<_>>()?;

debug!(
"Done setting up sender-receiver for SortPreservingMergeExec::execute"
);

let result = StreamingMergeBuilder::new()
.with_streams(receivers)
.with_schema(schema)
.with_expressions(&self.expr)
.with_metrics(BaselineMetrics::new(&self.metrics, partition))
.with_batch_size(context.session_config().batch_size())
.with_fetch(self.fetch)
.with_reservation(reservation)
.with_round_robin_tie_breaker(self.enable_round_robin_repartition)
.build()?;

debug!(
"Got stream result from SortPreservingMergeStream::new_from_receivers"
);

Ok(result)
let batch_size = context.session_config().batch_size();
Ok(Box::pin(SortPreservingMergeExecStream {
schema,
input: Arc::clone(&self.input),
context,
expr: self.expr.clone(),
metrics: BaselineMetrics::new(&self.metrics, partition),
batch_size,
fetch: self.fetch,
reservation,
enable_round_robin_repartition: self.enable_round_robin_repartition,
state: SPMStreamState::Pending,
}))
}
}
}
Expand Down Expand Up @@ -433,6 +422,99 @@ impl ExecutionPlan for SortPreservingMergeExec {
}
}

/// A stream that lazily spawns input partition tasks and builds the streaming
/// merge on first poll, rather than eagerly in `execute()`.
///
/// Deferring the work to the first `poll_next` keeps `execute()` cheap and
/// avoids spawning input tasks for a stream that is never polled.
struct SortPreservingMergeExecStream {
    /// Output schema of the merged stream.
    schema: SchemaRef,
    /// Input plan whose partitions are executed lazily on first poll.
    input: Arc<dyn ExecutionPlan>,
    /// Task context used when executing the input partitions.
    context: Arc<TaskContext>,
    /// Sort ordering that the merge preserves across inputs.
    expr: LexOrdering,
    /// Baseline metrics recorded for this output partition.
    metrics: BaselineMetrics,
    /// Target output batch size for the streaming merge.
    batch_size: usize,
    /// Optional row limit (`fetch`) applied by the merge.
    fetch: Option<usize>,
    /// Memory reservation handed off to the streaming merge once started.
    reservation: datafusion_execution::memory_pool::MemoryReservation,
    /// Whether round-robin tie breaking between equal rows is enabled.
    enable_round_robin_repartition: bool,
    /// Current initialization state; see [`SPMStreamState`].
    state: SPMStreamState,
}

/// Initialization state of [`SortPreservingMergeExecStream`].
enum SPMStreamState {
    /// Tasks have not been spawned yet; work is deferred until first poll.
    Pending,
    /// The streaming merge has been built and is running.
    Running(SendableRecordBatchStream),
    /// Initialization failed; the error was already returned to the caller.
    Failed,
}

impl SortPreservingMergeExecStream {
    /// Spawn one buffered task per input partition, build the streaming
    /// merge over them, and transition `self.state` to `Running`.
    ///
    /// Returns a mutable reference to the newly built merge stream, or an
    /// error if executing an input partition or building the merge fails.
    fn start(&mut self) -> Result<&mut SendableRecordBatchStream> {
        let partition_count = self.input.output_partitioning().partition_count();

        // Execute each input partition and buffer it on its own task,
        // stopping at the first partition that fails to execute.
        let mut streams = Vec::with_capacity(partition_count);
        for part in 0..partition_count {
            let input_stream = self.input.execute(part, Arc::clone(&self.context))?;
            streams.push(spawn_buffered(input_stream, 1));
        }

        debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute");

        // The merge takes ownership of the reservation; swap a fresh (empty)
        // placeholder into `self` so we can move the real one out.
        let reservation = std::mem::replace(
            &mut self.reservation,
            MemoryConsumer::new("empty")
                .register(&self.context.runtime_env().memory_pool),
        );

        let merged = StreamingMergeBuilder::new()
            .with_streams(streams)
            .with_schema(Arc::clone(&self.schema))
            .with_expressions(&self.expr)
            .with_metrics(self.metrics.clone())
            .with_batch_size(self.batch_size)
            .with_fetch(self.fetch)
            .with_reservation(reservation)
            .with_round_robin_tie_breaker(self.enable_round_robin_repartition)
            .build()?;

        debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");

        // Store the stream in `self.state`, then hand back a reference to it.
        self.state = SPMStreamState::Running(merged);
        let SPMStreamState::Running(stream) = &mut self.state else {
            unreachable!()
        };
        Ok(stream)
    }
}

impl RecordBatchStream for SortPreservingMergeExecStream {
    /// Schema of the batches this stream yields (the merged output schema).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

impl Stream for SortPreservingMergeExecStream {
    type Item = Result<arrow::array::RecordBatch>;

    /// Lazily initializes the streaming merge on the first poll, then
    /// delegates every subsequent poll to the inner merged stream.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let stream = match &mut self.state {
            SPMStreamState::Running(s) => s,
            // Initialization already failed and the error was surfaced on a
            // previous poll; behave as a fused, exhausted stream.
            SPMStreamState::Failed => return Poll::Ready(None),
            // First poll: spawn the input tasks and build the merge now.
            SPMStreamState::Pending => match self.start() {
                Ok(s) => s,
                Err(e) => {
                    // Record the failure so later polls do not retry
                    // initialization, then surface the error exactly once.
                    self.state = SPMStreamState::Failed;
                    return Poll::Ready(Some(Err(e)));
                }
            },
        };
        stream.poll_next_unpin(cx)
    }
}

#[cfg(test)]
mod tests {
use std::collections::HashSet;
Expand Down
Loading