diff --git a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs index e8d06d69df414..b30092159f43a 100644 --- a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs @@ -26,7 +26,7 @@ use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; -use datafusion_expr::Operator; +use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::expressions::{BinaryExpr, col, lit}; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -34,6 +34,7 @@ use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_optimizer::limit_pushdown::LimitPushdown; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::joins::NestedLoopJoinExec; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::{ExecutionPlan, get_plan_string}; @@ -87,6 +88,16 @@ fn empty_exec(schema: SchemaRef) -> Arc { Arc::new(EmptyExec::new(schema)) } +fn nested_loop_join_exec( + left: Arc, + right: Arc, + join_type: JoinType, +) -> Result> { + Ok(Arc::new(NestedLoopJoinExec::try_new( + left, right, None, &join_type, None, + )?)) +} + #[test] fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() -> Result<()> { @@ -343,3 +354,104 @@ fn merges_local_limit_with_global_limit() -> Result<()> { Ok(()) } + +#[test] +fn preserves_nested_global_limit() -> Result<()> { + // If there are multiple limits in an execution plan, they all need to be + // preserved in the optimized plan. + // + // Plan structure: + // GlobalLimitExec: skip=1, fetch=1 + // NestedLoopJoinExec (Left) + // EmptyExec (left side) + // GlobalLimitExec: skip=2, fetch=1 + // NestedLoopJoinExec (Right) + // EmptyExec (left side) + // EmptyExec (right side) + let schema = create_schema(); + + // Build inner join: NestedLoopJoin(Empty, Empty) + let inner_left = empty_exec(Arc::clone(&schema)); + let inner_right = empty_exec(Arc::clone(&schema)); + let inner_join = nested_loop_join_exec(inner_left, inner_right, JoinType::Right)?; + + // Add inner limit: GlobalLimitExec: skip=2, fetch=1 + let inner_limit = global_limit_exec(inner_join, 2, Some(1)); + + // Build outer join: NestedLoopJoin(Empty, GlobalLimit) + let outer_left = empty_exec(Arc::clone(&schema)); + let outer_join = nested_loop_join_exec(outer_left, inner_limit, JoinType::Left)?; + + // Add outer limit: GlobalLimitExec: skip=1, fetch=1 + let outer_limit = global_limit_exec(outer_join, 1, Some(1)); + + let initial = get_plan_string(&outer_limit); + let expected_initial = [ + "GlobalLimitExec: skip=1, fetch=1", + " NestedLoopJoinExec: join_type=Left", + " EmptyExec", + " GlobalLimitExec: skip=2, fetch=1", + " NestedLoopJoinExec: join_type=Right", + " EmptyExec", + " EmptyExec", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(outer_limit, &ConfigOptions::new())?; + let expected = [ + "GlobalLimitExec: skip=1, fetch=1", + " NestedLoopJoinExec: join_type=Left", + " EmptyExec", + " GlobalLimitExec: skip=2, fetch=1", + " NestedLoopJoinExec: join_type=Right", + " EmptyExec", + " EmptyExec", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) +} + +#[test] +fn preserves_skip_before_sort() -> Result<()> { + // If there's a limit with skip before a node that (1) supports fetch but + // (2) does not support limit pushdown, that limit should not be removed. + // + // Plan structure: + // GlobalLimitExec: skip=1, fetch=None + // SortExec: TopK(fetch=4) + // EmptyExec + let schema = create_schema(); + + let empty = empty_exec(Arc::clone(&schema)); + + let ordering = [PhysicalSortExpr { + expr: col("c1", &schema)?, + options: SortOptions::default(), + }]; + let sort = sort_exec(ordering.into(), empty) + .with_fetch(Some(4)) + .unwrap(); + + let outer_limit = global_limit_exec(sort, 1, None); + + let initial = get_plan_string(&outer_limit); + let expected_initial = [ + "GlobalLimitExec: skip=1, fetch=None", + " SortExec: TopK(fetch=4), expr=[c1@0 ASC], preserve_partitioning=[false]", + " EmptyExec", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(outer_limit, &ConfigOptions::new())?; + let expected = [ + "GlobalLimitExec: skip=1, fetch=3", + " SortExec: TopK(fetch=4), expr=[c1@0 ASC], preserve_partitioning=[false]", + " EmptyExec", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) +} diff --git a/datafusion/physical-optimizer/src/limit_pushdown.rs b/datafusion/physical-optimizer/src/limit_pushdown.rs index a4dac81dbacf8..e7bede494da99 100644 --- a/datafusion/physical-optimizer/src/limit_pushdown.rs +++ b/datafusion/physical-optimizer/src/limit_pushdown.rs @@ -155,6 +155,7 @@ pub fn pushdown_limit_helper( global_state.skip = skip; global_state.fetch = fetch; global_state.preserve_order = limit_exec.preserve_order(); + global_state.satisfied = false; // Now the global state has the most recent information, we can remove // the `LimitExec` plan. We will decide later if we should add it again @@ -172,7 +173,7 @@ pub fn pushdown_limit_helper( // If we have a non-limit operator with fetch capability, update global // state as necessary: if pushdown_plan.fetch().is_some() { - if global_state.fetch.is_none() { + if global_state.skip == 0 { global_state.satisfied = true; } (global_state.skip, global_state.fetch) = combine_limit( diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 96471411e0f95..429181a2d385b 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -706,8 +706,8 @@ ON t1.b = t2.b ORDER BY t1.b desc, c desc, c2 desc OFFSET 3 LIMIT 2; ---- -3 99 82 -3 99 79 +3 98 79 +3 97 96 statement ok drop table ordered_table; diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index b79b6d2fe5e9e..d858d0ae3ea4e 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -494,22 +494,25 @@ physical_plan 01)CoalescePartitionsExec: fetch=3 02)--UnionExec 03)----ProjectionExec: expr=[count(Int64(1))@0 as cnt] -04)------AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] -05)--------CoalescePartitionsExec -06)----------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -07)------------ProjectionExec: expr=[] -08)--------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] -09)----------------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 -10)------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] -11)--------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434, projection=[c1@0] -12)----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -13)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], file_type=csv, has_header=true -14)----ProjectionExec: expr=[1 as cnt] -15)------PlaceholderRowExec -16)----ProjectionExec: expr=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt] -17)------BoundedWindowAggExec: wdw=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { "lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING": nullable Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] -18)--------ProjectionExec: expr=[1 as c1] -19)----------PlaceholderRowExec +04)------GlobalLimitExec: skip=0, fetch=3 +05)--------AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] +06)----------CoalescePartitionsExec +07)------------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] +08)--------------ProjectionExec: expr=[] +09)----------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] +10)------------------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 +11)--------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] +12)----------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434, projection=[c1@0] +13)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)--------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], file_type=csv, has_header=true +15)----ProjectionExec: expr=[1 as cnt] +16)------GlobalLimitExec: skip=0, fetch=3 +17)--------PlaceholderRowExec +18)----ProjectionExec: expr=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt] +19)------GlobalLimitExec: skip=0, fetch=3 +20)--------BoundedWindowAggExec: wdw=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { "lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING": nullable Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] +21)----------ProjectionExec: expr=[1 as c1] +22)------------PlaceholderRowExec ########