From 26737dddff6b2311457a35169cfde77caaf379ff Mon Sep 17 00:00:00 2001 From: jakevin Date: Tue, 1 Aug 2023 13:23:55 +0800 Subject: [PATCH] [feature](Nereids): pushdown MIN/MAX/SUM through join (#22264) * [minor](Nereids): add more comment to explain code * [feature](Nereids): pushdown MIN/MAX/SUM through join --- .../jobs/joinorder/hypergraph/Node.java | 2 +- .../hypergraph/receiver/PlanReceiver.java | 4 + .../apache/doris/nereids/rules/RuleType.java | 9 +- .../rules/exploration/EagerGroupByCount.java | 2 +- .../nereids/rules/exploration/EagerSplit.java | 2 +- .../PushdownProjectThroughInnerOuterJoin.java | 4 +- .../join/PushdownProjectThroughSemiJoin.java | 4 +- .../rewrite/PushdownMinMaxThroughJoin.java | 181 ++++++++++++++++ .../rules/rewrite/PushdownSumThroughJoin.java | 194 ++++++++++++++++++ .../apache/doris/statistics/Statistics.java | 5 +- .../PushdownMinMaxThroughJoinTest.java | 114 ++++++++++ .../rewrite/PushdownSumThroughJoinTest.java | 104 ++++++++++ .../doris/nereids/util/PlanChecker.java | 5 + 13 files changed, 617 insertions(+), 13 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoinTest.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoinTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Node.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Node.java index 40c5171769..fe8a12ed52 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Node.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Node.java @@ -29,7 +29,7 @@ import java.util.List; */ public class Node { private final int index; - // TODO + // Due to group in Node is base group, so mergeGroup() don't need to consider it. private final Group group; private final List edges = new ArrayList<>(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java index 51c16e24f7..5106d71a4b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java @@ -319,6 +319,10 @@ public class PlanReceiver implements AbstractReceiver { hasGenerated.add(groupExpression); // process child first, plan's child may be changed due to mergeGroup + // due to mergeGroup, the children Group of groupExpression may be replaced, so we need to use lambda to + // get the child to make we can get child at the time we use child. + // If we use for child: groupExpression.children(), it means that we take it in advance. It may cause NPE, + // work flow: get children() to get left, right -> copyIn left() -> mergeGroup -> right is merged -> NPE Plan physicalPlan = groupExpression.getPlan(); for (int i = 0; i < groupExpression.children().size(); i++) { int childIdx = i; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 2d8f913e88..114c0529fa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -149,12 +149,15 @@ public enum RuleType { PUSHDOWN_FILTER_THROUGH_CTE(RuleTypeClass.REWRITE), PUSHDOWN_FILTER_THROUGH_CTE_ANCHOR(RuleTypeClass.REWRITE), - PUSH_DOWN_DISTINCT_THROUGH_JOIN(RuleTypeClass.REWRITE), + PUSHDOWN_DISTINCT_THROUGH_JOIN(RuleTypeClass.REWRITE), COLUMN_PRUNING(RuleTypeClass.REWRITE), PUSHDOWN_TOP_N_THROUGH_PROJECTION_WINDOW(RuleTypeClass.REWRITE), PUSHDOWN_TOP_N_THROUGH_WINDOW(RuleTypeClass.REWRITE), + PUSHDOWN_MIN_MAX_THROUGH_JOIN(RuleTypeClass.REWRITE), + PUSHDOWN_SUM_THROUGH_JOIN(RuleTypeClass.REWRITE), + PUSHDOWN_COUNT_THROUGH_JOIN(RuleTypeClass.REWRITE), TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN(RuleTypeClass.REWRITE), TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN_PROJECT(RuleTypeClass.REWRITE), @@ -269,8 +272,8 @@ public enum RuleType { TRANSPOSE_LOGICAL_AGG_SEMI_JOIN(RuleTypeClass.EXPLORATION), TRANSPOSE_LOGICAL_AGG_SEMI_JOIN_PROJECT(RuleTypeClass.EXPLORATION), TRANSPOSE_LOGICAL_JOIN_UNION(RuleTypeClass.EXPLORATION), - PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION), - PUSH_DOWN_PROJECT_THROUGH_INNER_OUTER_JOIN(RuleTypeClass.EXPLORATION), + PUSHDOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION), + PUSHDOWN_PROJECT_THROUGH_INNER_OUTER_JOIN(RuleTypeClass.EXPLORATION), EAGER_COUNT(RuleTypeClass.EXPLORATION), EAGER_GROUP_BY(RuleTypeClass.EXPLORATION), EAGER_GROUP_BY_COUNT(RuleTypeClass.EXPLORATION), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCount.java index 582e84f6b5..b993d7cb94 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCount.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCount.java @@ -78,7 +78,7 @@ public class EagerGroupByCount extends OneExplorationRuleFactory { rightSums.add(sum); } } - if (leftSums.size() == 0 || rightSums.size() == 0) { + if (leftSums.size() == 0 && rightSums.size() == 0) { return null; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerSplit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerSplit.java index 89023ca69f..fc0159f26c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerSplit.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerSplit.java @@ -80,7 +80,7 @@ public class EagerSplit extends OneExplorationRuleFactory { rightSums.add(sum); } } - if (leftSums.size() == 0 || rightSums.size() == 0) { + if (leftSums.size() == 0 && rightSums.size() == 0) { return null; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerOuterJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerOuterJoin.java index 8a480c67dd..5c39097108 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerOuterJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerOuterJoin.java @@ -69,7 +69,7 @@ public class PushdownProjectThroughInnerOuterJoin implements ExplorationRuleFact return null; } return topJoin.withChildren(newLeft, topJoin.right()); - }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_OUTER_JOIN), + }).toRule(RuleType.PUSHDOWN_PROJECT_THROUGH_INNER_OUTER_JOIN), logicalJoin(group(), logicalProject(logicalJoin())) .when(j -> j.right().child().getJoinType().isOuterJoin() || j.right().child().getJoinType().isInnerJoin()) @@ -83,7 +83,7 @@ public class PushdownProjectThroughInnerOuterJoin implements ExplorationRuleFact return null; } return topJoin.withChildren(topJoin.left(), newRight); - }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_OUTER_JOIN) + }).toRule(RuleType.PUSHDOWN_PROJECT_THROUGH_INNER_OUTER_JOIN) ); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java index aa27774b8b..485631cb11 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java @@ -63,7 +63,7 @@ public class PushdownProjectThroughSemiJoin implements ExplorationRuleFactory { LogicalProject> project = topJoin.left(); Plan newLeft = pushdownProject(project); return topJoin.withChildren(newLeft, topJoin.right()); - }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN), + }).toRule(RuleType.PUSHDOWN_PROJECT_THROUGH_SEMI_JOIN), logicalJoin(group(), logicalProject(logicalJoin())) .when(j -> j.right().child().getJoinType().isLeftSemiOrAntiJoin()) @@ -74,7 +74,7 @@ public class PushdownProjectThroughSemiJoin implements ExplorationRuleFactory { LogicalProject> project = topJoin.right(); Plan newRight = pushdownProject(project); return topJoin.withChildren(topJoin.left(), newRight); - }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN) + }).toRule(RuleType.PUSHDOWN_PROJECT_THROUGH_SEMI_JOIN) ); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java new file mode 100644 index 0000000000..bd61f4f0ac --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java @@ -0,0 +1,181 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Related paper "Eager aggregation and lazy aggregation". + *
+ * aggregate: Min/Max(x)
+ * |
+ * join
+ * |   \
+ * |    *
+ * (x)
+ * ->
+ * aggregate: Min/Max(min1)
+ * |
+ * join
+ * |   \
+ * |    *
+ * aggregate: Min/Max(x) as min1
+ * 
+ */ +public class PushdownMinMaxThroughJoin implements RewriteRuleFactory { + @Override + public List buildRules() { + return ImmutableList.of( + logicalAggregate(innerLogicalJoin()) + .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0) + .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) + .when(agg -> { + Set funcs = agg.getAggregateFunctions(); + return !funcs.isEmpty() && funcs.stream() + .allMatch(f -> (f instanceof Min || f instanceof Max) && f.child(0) instanceof Slot); + }) + .then(agg -> pushMinMax(agg, agg.child(), ImmutableList.of())) + .toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN), + logicalAggregate(logicalProject(innerLogicalJoin())) + .when(agg -> agg.child().isAllSlots()) + .when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0) + .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) + .when(agg -> { + Set funcs = agg.getAggregateFunctions(); + return !funcs.isEmpty() && funcs.stream() + .allMatch(f -> (f instanceof Min || f instanceof Max) && f.child(0) instanceof Slot); + }) + .then(agg -> pushMinMax(agg, agg.child().child(), agg.child().getProjects())) + .toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN) + ); + } + + private LogicalAggregate pushMinMax(LogicalAggregate agg, + LogicalJoin join, List projects) { + List leftOutput = join.left().getOutput(); + List rightOutput = join.right().getOutput(); + + List leftFuncs = new ArrayList<>(); + List rightFuncs = new ArrayList<>(); + for (AggregateFunction func : agg.getAggregateFunctions()) { + Slot slot = (Slot) func.child(0); + if (leftOutput.contains(slot)) { + leftFuncs.add(func); + } else if (rightOutput.contains(slot)) { + rightFuncs.add(func); + } else { + throw new IllegalStateException("Slot " + slot + " not found in join output"); + } + } + + Set leftGroupBy = new HashSet<>(); + Set rightGroupBy = new HashSet<>(); + for (Expression e : agg.getGroupByExpressions()) { + Slot slot = (Slot) e; + if (leftOutput.contains(slot)) { + leftGroupBy.add(slot); + } else if (rightOutput.contains(slot)) { + rightGroupBy.add(slot); + } else { + return null; + } + } + join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { + if (leftOutput.contains(slot)) { + leftGroupBy.add(slot); + } else if (rightOutput.contains(slot)) { + rightGroupBy.add(slot); + } else { + throw new IllegalStateException("Slot " + slot + " not found in join output"); + } + })); + + Plan left = join.left(); + Plan right = join.right(); + Map leftSlotToOutput = new HashMap<>(); + Map rightSlotToOutput = new HashMap<>(); + if (!leftFuncs.isEmpty()) { + Builder leftAggOutputBuilder = ImmutableList.builder() + .addAll(leftGroupBy); + leftFuncs.forEach(func -> { + Alias alias = func.alias(func.getName()); + leftSlotToOutput.put((Slot) func.child(0), alias); + leftAggOutputBuilder.add(alias); + }); + left = new LogicalAggregate<>(ImmutableList.copyOf(leftGroupBy), leftAggOutputBuilder.build(), join.left()); + } + if (!rightFuncs.isEmpty()) { + Builder rightAggOutputBuilder = ImmutableList.builder() + .addAll(rightGroupBy); + rightFuncs.forEach(func -> { + Alias alias = func.alias(func.getName()); + rightSlotToOutput.put((Slot) func.child(0), alias); + rightAggOutputBuilder.add(alias); + }); + right = new LogicalAggregate<>(ImmutableList.copyOf(rightGroupBy), rightAggOutputBuilder.build(), + join.right()); + } + + Preconditions.checkState(left != join.left() || right != join.right()); + Plan newJoin = join.withChildren(left, right); + + List newOutputExprs = new ArrayList<>(); + for (NamedExpression ne : agg.getOutputExpressions()) { + if (ne instanceof Alias && ((Alias) ne).child() instanceof AggregateFunction) { + AggregateFunction func = (AggregateFunction) ((Alias) ne).child(); + Slot slot = (Slot) func.child(0); + if (leftSlotToOutput.containsKey(slot)) { + Expression newFunc = func.withChildren(leftSlotToOutput.get(slot).toSlot()); + newOutputExprs.add((NamedExpression) ne.withChildren(newFunc)); + } else if (rightSlotToOutput.containsKey(slot)) { + Expression newFunc = func.withChildren(rightSlotToOutput.get(slot).toSlot()); + newOutputExprs.add((NamedExpression) ne.withChildren(newFunc)); + } else { + throw new IllegalStateException("Slot " + slot + " not found in join output"); + } + } else { + newOutputExprs.add(ne); + } + } + + // TODO: column prune project + return agg.withAggOutputChild(newOutputExprs, newJoin); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java new file mode 100644 index 0000000000..1319200220 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java @@ -0,0 +1,194 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Multiply; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Related paper "Eager aggregation and lazy aggregation". + *
+ * aggregate: Sum(x)
+ * |
+ * join
+ * |   \
+ * |    *
+ * (x)
+ * ->
+ * aggregate: Sum(min1)
+ * |
+ * join
+ * |   \
+ * |    *
+ * aggregate: Sum(x) as min1
+ * 
+ */ +public class PushdownSumThroughJoin implements RewriteRuleFactory { + @Override + public List buildRules() { + return ImmutableList.of( + logicalAggregate(innerLogicalJoin()) + .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0) + .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) + .when(agg -> { + Set funcs = agg.getAggregateFunctions(); + return !funcs.isEmpty() && funcs.stream() + .allMatch(f -> f instanceof Sum && f.child(0) instanceof Slot); + }) + .then(agg -> pushSum(agg, agg.child(), ImmutableList.of())) + .toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN), + logicalAggregate(logicalProject(innerLogicalJoin())) + .when(agg -> agg.child().isAllSlots()) + .when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0) + .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) + .when(agg -> { + Set funcs = agg.getAggregateFunctions(); + return !funcs.isEmpty() && funcs.stream() + .allMatch(f -> f instanceof Sum && f.child(0) instanceof Slot); + }) + .then(agg -> pushSum(agg, agg.child().child(), agg.child().getProjects())) + .toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN) + ); + } + + private LogicalAggregate pushSum(LogicalAggregate agg, + LogicalJoin join, List projects) { + List leftOutput = join.left().getOutput(); + List rightOutput = join.right().getOutput(); + + List leftSums = new ArrayList<>(); + List rightSums = new ArrayList<>(); + for (AggregateFunction f : agg.getAggregateFunctions()) { + Sum sum = (Sum) f; + Slot slot = (Slot) sum.child(); + if (leftOutput.contains(slot)) { + leftSums.add(sum); + } else if (rightOutput.contains(slot)) { + rightSums.add(sum); + } else { + throw new IllegalStateException("Slot " + slot + " not found in join output"); + } + } + if (leftSums.isEmpty() && rightSums.isEmpty() + || (!leftSums.isEmpty() && !rightSums.isEmpty())) { + return null; + } + + Set leftGroupBy = new HashSet<>(); + Set rightGroupBy = new HashSet<>(); + for (Expression e : agg.getGroupByExpressions()) { + Slot slot = (Slot) e; + if (leftOutput.contains(slot)) { + leftGroupBy.add(slot); + } else if (rightOutput.contains(slot)) { + rightGroupBy.add(slot); + } else { + return null; + } + } + join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { + if (leftOutput.contains(slot)) { + leftGroupBy.add(slot); + } else if (rightOutput.contains(slot)) { + rightGroupBy.add(slot); + } else { + throw new IllegalStateException("Slot " + slot + " not found in join output"); + } + })); + + List sums; + Set sumGroupBy; + Set cntGroupBy; + Plan sumChild; + Plan cntChild; + if (!leftSums.isEmpty()) { + sums = leftSums; + sumGroupBy = leftGroupBy; + cntGroupBy = rightGroupBy; + sumChild = join.left(); + cntChild = join.right(); + } else { + sums = rightSums; + sumGroupBy = rightGroupBy; + cntGroupBy = leftGroupBy; + sumChild = join.right(); + cntChild = join.left(); + } + + // Sum agg + Map sumSlotToOutput = new HashMap<>(); + Builder sumAggOutputBuilder = ImmutableList.builder().addAll(sumGroupBy); + sums.forEach(func -> { + Alias alias = func.alias(func.getName()); + sumSlotToOutput.put((Slot) func.child(0), alias); + sumAggOutputBuilder.add(alias); + }); + LogicalAggregate sumAgg = new LogicalAggregate<>( + ImmutableList.copyOf(sumGroupBy), sumAggOutputBuilder.build(), sumChild); + + // Count agg + Alias cnt = new Count().alias("cnt"); + List cntAggOutput = ImmutableList.builder() + .addAll(cntGroupBy).add(cnt) + .build(); + LogicalAggregate cntAgg = new LogicalAggregate<>( + ImmutableList.copyOf(cntGroupBy), cntAggOutput, cntChild); + + Plan newJoin = !leftSums.isEmpty() ? join.withChildren(sumAgg, cntAgg) : join.withChildren(cntAgg, sumAgg); + + // top Sum agg + // replace sum(x) -> sum(sum# * cnt) + List newOutputExprs = new ArrayList<>(); + for (NamedExpression ne : agg.getOutputExpressions()) { + if (ne instanceof Alias && ((Alias) ne).child() instanceof AggregateFunction) { + AggregateFunction func = (AggregateFunction) ((Alias) ne).child(); + Slot slot = (Slot) func.child(0); + if (sumSlotToOutput.containsKey(slot)) { + Expression expr = func.withChildren(new Multiply(sumSlotToOutput.get(slot).toSlot(), cnt.toSlot())); + newOutputExprs.add((NamedExpression) ne.withChildren(expr)); + } else { + throw new IllegalStateException("Slot " + slot + " not found in join output"); + } + } else { + newOutputExprs.add(ne); + } + } + return agg.withAggOutputChild(newOutputExprs, newJoin); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java index 5c628aaba3..7c54ed8669 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java +++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java @@ -111,10 +111,9 @@ public class Statistics { ColumnStatistic columnStatistic = entry.getValue(); ColumnStatisticBuilder columnStatisticBuilder = new ColumnStatisticBuilder(columnStatistic); columnStatisticBuilder.setNdv(Math.min(columnStatistic.ndv, rowCount)); - double nullFactor = (rowCount - columnStatistic.numNulls) / rowCount; - columnStatisticBuilder.setNumNulls(nullFactor * rowCount); + columnStatisticBuilder.setNumNulls(rowCount - columnStatistic.numNulls); columnStatisticBuilder.setCount(rowCount); - statistics.addColumnStats(entry.getKey(), columnStatisticBuilder.build()); + expressionToColumnStats.put(entry.getKey(), columnStatisticBuilder.build()); } return statistics; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoinTest.java new file mode 100644 index 0000000000..83c297f7a0 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoinTest.java @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +class PushdownMinMaxThroughJoinTest implements MemoPatternMatchSupported { + private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); + private static final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(3, "t4", 0); + + @Test + void testSingleJoin() { + Alias min = new Min(scan1.getOutput().get(0)).alias("min"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushdownMinMaxThroughJoin()) + .printlnTree(); + } + + @Test + void testMultiJoin() { + Alias min = new Min(scan1.getOutput().get(0)).alias("min"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) + .join(scan4, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushdownMinMaxThroughJoin()) + .printlnTree(); + } + + @Test + void testAggNotOutputGroupBy() { + // agg don't output group by + Alias min = new Min(scan1.getOutput().get(0)).alias("min"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(min)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushdownMinMaxThroughJoin()) + .printlnTree(); + } + + @Test + void testBothSideSingleJoin() { + Alias min = new Min(scan1.getOutput().get(1)).alias("min"); + Alias max = new Max(scan2.getOutput().get(1)).alias("max"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min, max)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .printlnTree() + .applyTopDown(new PushdownMinMaxThroughJoin()) + .printlnTree(); + } + + @Test + void testBothSide() { + Alias min = new Min(scan1.getOutput().get(1)).alias("min"); + Alias max = new Max(scan3.getOutput().get(1)).alias("max"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(min, max)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushdownMinMaxThroughJoin()) + .printlnTree(); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoinTest.java new file mode 100644 index 0000000000..c6d65e784c --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoinTest.java @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +class PushdownSumThroughJoinTest implements MemoPatternMatchSupported { + private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); + private static final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(3, "t4", 0); + + @Test + void testSingleJoinLeftSum() { + Alias sum = new Sum(scan1.getOutput().get(1)).alias("sum"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), sum)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .printlnTree() + .applyTopDown(new PushdownSumThroughJoin()) + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalAggregate() + ) + ) + ); + } + + @Test + void testSingleJoinRightSum() { + Alias sum = new Sum(scan2.getOutput().get(1)).alias("sum"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), sum)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .printlnTree() + .applyTopDown(new PushdownSumThroughJoin()) + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalAggregate() + ) + ) + ); + } + + @Test + void testAggNotOutputGroupBy() { + // agg don't output group by + Alias sum = new Sum(scan1.getOutput().get(1)).alias("sum"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(sum)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .printlnTree() + .applyTopDown(new PushdownSumThroughJoin()) + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalAggregate() + ) + ) + ); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java index 77d0db9195..3f09e00657 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java @@ -346,7 +346,12 @@ public class PlanChecker { } private PlanChecker applyExploration(Group group, Rule rule) { + // copy children expression, because group may be changed after apply rule. List logicalExpressions = Lists.newArrayList(group.getLogicalExpressions()); + // due to mergeGroup, the children Group of groupExpression may be replaced, so we need to use lambda to + // get the child to make we can get child at the time we use child. + // If we use for child: groupExpression.children(), it means that we take it in advance. It may cause NPE, + // work flow: get children() to get left, right -> copyIn left() -> mergeGroup -> right is merged -> NPE for (int i = 0; i < logicalExpressions.size(); i++) { final int childIdx = i; applyExploration(() -> logicalExpressions.get(childIdx), rule);