[fix](Nereids): fix sum func in eager agg (#18675)

This commit is contained in:
jakevin
2023-04-17 15:06:28 +08:00
committed by GitHub
parent 1e06763366
commit d61f52d277
26 changed files with 120 additions and 114 deletions

View File

@ -35,10 +35,6 @@ import java.util.stream.Collectors;
* Common
*/
public class CBOUtils {
/**
 * Check whether every projection of the given project is a {@link Slot}.
 *
 * @param project the logical project whose projection list is inspected
 * @return true when no projection is anything other than a Slot
 *         (trivially true for an empty projection list)
 */
public static boolean isAllSlotProject(LogicalProject<? extends Plan> project) {
    for (NamedExpression projection : project.getProjects()) {
        if (!(projection instanceof Slot)) {
            return false;
        }
    }
    return true;
}
/**
* Split project according to whether each namedExpr is contained in splitChildExprIds.
* Notice: projects must all be Slot.
@ -56,14 +52,6 @@ public class CBOUtils {
* If projects is empty or the project output equals the plan output, return the original plan.
*/
public static Plan projectOrSelf(List<NamedExpression> projects, Plan plan) {
    Set<Slot> planOutput = plan.getOutputSet();
    // Nothing requested: hand back the plan untouched.
    if (projects.isEmpty()) {
        return plan;
    }
    // Same cardinality and every requested projection is already in the plan's output set,
    // so an extra project node would add nothing.
    boolean coversPlanOutput = planOutput.size() == projects.size()
            && planOutput.containsAll(projects);
    if (coversPlanOutput) {
        return plan;
    }
    return new LogicalProject<>(projects, plan);
}
public static Plan projectOrSelfInOrder(List<NamedExpression> projects, Plan plan) {
if (projects.isEmpty() || projects.equals(plan.getOutput())) {
return plan;
}

View File

@ -26,7 +26,6 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
@ -50,7 +49,7 @@ import java.util.Set;
* | *
* (x)
* ->
* aggregate: SUM(x) * cnt
* aggregate: SUM(x * cnt)
* |
* join
* | \
@ -73,7 +72,7 @@ public class EagerCount implements ExplorationRuleFactory {
.then(agg -> eagerCount(agg, agg.child(), ImmutableList.of()))
.toRule(RuleType.EAGER_COUNT),
logicalAggregate(logicalProject(innerLogicalJoin()))
.when(agg -> CBOUtils.isAllSlotProject(agg.child()))
.when(agg -> agg.child().isAllSlots())
.when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
.when(agg -> agg.getAggregateFunctions().stream()
@ -98,7 +97,7 @@ public class EagerCount implements ExplorationRuleFactory {
cntAggGroupBy.add(slot);
}
}));
Alias cnt = new Alias(new Count(Literal.of(1)), "cnt");
Alias cnt = new Alias(new Count(), "cnt");
List<NamedExpression> cntAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(cntAggGroupBy).add(cnt).build();
LogicalAggregate<GroupPlan> cntAgg = new LogicalAggregate<>(
@ -116,7 +115,8 @@ public class EagerCount implements ExplorationRuleFactory {
}
for (Alias oldSum : sumOutputExprs) {
Sum oldSumFunc = (Sum) oldSum.child();
newOutputExprs.add(new Alias(oldSum.getExprId(), new Multiply(oldSumFunc, cnt.toSlot()),
Slot slot = (Slot) oldSumFunc.child();
newOutputExprs.add(new Alias(oldSum.getExprId(), new Sum(new Multiply(slot, cnt.toSlot())),
oldSum.getName()));
}
Plan child = PlanUtils.projectOrSelf(projects, newJoin);

View File

@ -72,7 +72,7 @@ public class EagerGroupBy implements ExplorationRuleFactory {
.then(agg -> eagerGroupBy(agg, agg.child(), ImmutableList.of()))
.toRule(RuleType.EAGER_GROUP_BY),
logicalAggregate(logicalProject(innerLogicalJoin()))
.when(agg -> CBOUtils.isAllSlotProject(agg.child()))
.when(agg -> agg.child().isAllSlots())
.when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
.when(agg -> agg.getAggregateFunctions().stream()

View File

@ -26,7 +26,6 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
@ -49,7 +48,7 @@ import java.util.Set;
* | (y)
* (x)
* ->
* aggregate: SUM(sum1), SUM(y) * cnt
* aggregate: SUM(sum1), SUM(y * cnt)
* |
* join
* | \
@ -96,7 +95,7 @@ public class EagerGroupByCount extends OneExplorationRuleFactory {
for (int i = 0; i < leftSums.size(); i++) {
bottomSums.add(new Alias(new Sum(leftSums.get(i).child()), "sum" + i));
}
Alias cnt = new Alias(new Count(Literal.of(1)), "cnt");
Alias cnt = new Alias(new Count(), "cnt");
List<NamedExpression> bottomAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(bottomAggGroupBy).addAll(bottomSums).add(cnt).build();
LogicalAggregate<GroupPlan> bottomAgg = new LogicalAggregate<>(
@ -129,7 +128,8 @@ public class EagerGroupByCount extends OneExplorationRuleFactory {
}
for (Alias oldSum : rightSumOutputExprs) {
Sum oldSumFunc = (Sum) oldSum.child();
newOutputExprs.add(new Alias(oldSum.getExprId(), new Multiply(oldSumFunc, cnt.toSlot()),
Slot slot = (Slot) oldSumFunc.child();
newOutputExprs.add(new Alias(oldSum.getExprId(), new Sum(new Multiply(slot, cnt.toSlot())),
oldSum.getName()));
}
return agg.withAggOutput(newOutputExprs).withChildren(newJoin);

View File

@ -27,7 +27,6 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
@ -51,7 +50,7 @@ import java.util.Set;
* | (y)
* (x)
* ->
* aggregate: SUM(sum1) * cnt2, SUM(sum2) * cnt1
* aggregate: SUM(sum1 * cnt2), SUM(sum2 * cnt1)
* |
* join
* | \
@ -98,7 +97,7 @@ public class EagerSplit extends OneExplorationRuleFactory {
for (int i = 0; i < leftSums.size(); i++) {
leftBottomSums.add(new Alias(new Sum(leftSums.get(i).child()), "left_sum" + i));
}
Alias leftCnt = new Alias(new Count(Literal.of(1)), "left_cnt");
Alias leftCnt = new Alias(new Count(), "left_cnt");
List<NamedExpression> leftBottomAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(leftBottomAggGroupBy).addAll(leftBottomSums).add(leftCnt).build();
LogicalAggregate<GroupPlan> leftBottomAgg = new LogicalAggregate<>(
@ -117,7 +116,7 @@ public class EagerSplit extends OneExplorationRuleFactory {
for (int i = 0; i < rightSums.size(); i++) {
rightBottomSums.add(new Alias(new Sum(rightSums.get(i).child()), "right_sum" + i));
}
Alias rightCnt = new Alias(new Count(Literal.of(1)), "right_cnt");
Alias rightCnt = new Alias(new Count(), "right_cnt");
List<NamedExpression> rightBottomAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(rightBottomAggGroupBy).addAll(rightBottomSums).add(rightCnt).build();
LogicalAggregate<GroupPlan> rightBottomAgg = new LogicalAggregate<>(
@ -146,16 +145,15 @@ public class EagerSplit extends OneExplorationRuleFactory {
Preconditions.checkState(rightSumOutputExprs.size() == rightBottomSums.size());
for (int i = 0; i < leftSumOutputExprs.size(); i++) {
Alias oldSum = leftSumOutputExprs.get(i);
Slot bottomSum = leftBottomSums.get(i).toSlot();
Alias newSum = new Alias(oldSum.getExprId(),
new Multiply(new Sum(bottomSum), rightCnt.toSlot()), oldSum.getName());
newOutputExprs.add(newSum);
Slot slot = leftBottomSums.get(i).toSlot();
newOutputExprs.add(new Alias(oldSum.getExprId(), new Sum(new Multiply(slot, rightCnt.toSlot())),
oldSum.getName()));
}
for (int i = 0; i < rightSumOutputExprs.size(); i++) {
Alias oldSum = rightSumOutputExprs.get(i);
Slot bottomSum = rightBottomSums.get(i).toSlot();
Alias newSum = new Alias(oldSum.getExprId(),
new Multiply(new Sum(bottomSum), leftCnt.toSlot()), oldSum.getName());
Alias newSum = new Alias(oldSum.getExprId(), new Sum(new Multiply(bottomSum, leftCnt.toSlot())),
oldSum.getName());
newOutputExprs.add(newSum);
}
return agg.withAggOutput(newOutputExprs).withChildren(newJoin);

View File

@ -56,7 +56,7 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory {
.when(topJoin -> InnerJoinLAsscom.checkReorder(topJoin, topJoin.left().child()))
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> CBOUtils.isAllSlotProject(join.left()))
.when(join -> join.left().isAllSlots())
.then(topJoin -> {
/* ********** init ********** */
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();

View File

@ -52,7 +52,7 @@ public class InnerJoinLeftAssociateProject extends OneExplorationRuleFactory {
.when(InnerJoinLeftAssociate::checkReorder)
.whenNot(join -> join.hasJoinHint() || join.right().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.right().child().isMarkJoin())
.when(join -> CBOUtils.isAllSlotProject(join.right()))
.when(join -> join.right().isAllSlots())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.right().child();
GroupPlan a = topJoin.left();

View File

@ -50,7 +50,7 @@ public class InnerJoinRightAssociateProject extends OneExplorationRuleFactory {
.when(InnerJoinRightAssociate::checkReorder)
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> CBOUtils.isAllSlotProject(join.left()))
.when(join -> join.left().isAllSlots())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();
GroupPlan a = bottomJoin.left();

View File

@ -54,8 +54,7 @@ public class JoinExchangeBothProject extends OneExplorationRuleFactory {
public Rule build() {
return innerLogicalJoin(logicalProject(innerLogicalJoin()), logicalProject(innerLogicalJoin()))
.when(JoinExchange::checkReorder)
.when(join -> CBOUtils.isAllSlotProject(join.left())
&& CBOUtils.isAllSlotProject(join.right()))
.when(join -> join.left().isAllSlots() && join.right().isAllSlots())
.whenNot(join -> join.hasJoinHint()
|| join.left().child().hasJoinHint() || join.right().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin() || join.right().child().isMarkJoin())

View File

@ -54,7 +54,7 @@ public class JoinExchangeLeftProject extends OneExplorationRuleFactory {
public Rule build() {
return innerLogicalJoin(logicalProject(innerLogicalJoin()), innerLogicalJoin())
.when(JoinExchange::checkReorder)
.when(join -> CBOUtils.isAllSlotProject(join.left()))
.when(join -> join.left().isAllSlots())
.whenNot(join -> join.hasJoinHint()
|| join.left().child().hasJoinHint() || join.right().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin() || join.right().isMarkJoin())

View File

@ -54,7 +54,7 @@ public class JoinExchangeRightProject extends OneExplorationRuleFactory {
public Rule build() {
return innerLogicalJoin(innerLogicalJoin(), logicalProject(innerLogicalJoin()))
.when(JoinExchange::checkReorder)
.when(join -> CBOUtils.isAllSlotProject(join.right()))
.when(join -> join.right().isAllSlots())
.whenNot(join -> join.hasJoinHint()
|| join.left().hasJoinHint() || join.right().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin() || join.right().child().isMarkJoin())

View File

@ -46,7 +46,7 @@ public class LogicalJoinSemiJoinTransposeProject implements ExplorationRuleFacto
|| topJoin.getJoinType().isLeftOuterJoin())))
.whenNot(topJoin -> topJoin.hasJoinHint() || topJoin.left().child().hasJoinHint())
.whenNot(LogicalJoin::isMarkJoin)
.when(join -> CBOUtils.isAllSlotProject(join.left()))
.when(join -> join.left().isAllSlots())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();
GroupPlan a = bottomJoin.left();
@ -64,7 +64,7 @@ public class LogicalJoinSemiJoinTransposeProject implements ExplorationRuleFacto
.when(topJoin -> (topJoin.right().child().getJoinType().isLeftSemiOrAntiJoin()
&& (topJoin.getJoinType().isInnerJoin()
|| topJoin.getJoinType().isRightOuterJoin())))
.when(join -> CBOUtils.isAllSlotProject(join.right()))
.when(join -> join.right().isAllSlots())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.right().child();
GroupPlan a = topJoin.left();

View File

@ -59,7 +59,7 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory {
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> OuterJoinAssoc.checkCondition(join, join.left().child().left().getOutputSet()))
.when(join -> CBOUtils.isAllSlotProject(join.left()))
.when(join -> join.left().isAllSlots())
.then(topJoin -> {
/* ********** init ********** */
List<NamedExpression> projects = topJoin.left().getProjects();

View File

@ -61,7 +61,7 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child()))
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> CBOUtils.isAllSlotProject(join.left()))
.when(join -> join.left().isAllSlots())
.then(topJoin -> {
/* ********** init ********** */
List<NamedExpression> projects = topJoin.left().getProjects();

View File

@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
@ -52,7 +53,7 @@ public class PushdownProjectThroughInnerJoin extends OneExplorationRuleFactory {
@Override
public Rule build() {
return logicalProject(logicalJoin())
.whenNot(CBOUtils::isAllSlotProject)
.whenNot(LogicalProject::isAllSlots)
.when(project -> project.child().getJoinType().isInnerJoin())
.whenNot(project -> project.child().hasJoinHint())
.then(project -> {
@ -105,7 +106,7 @@ public class PushdownProjectThroughInnerJoin extends OneExplorationRuleFactory {
Plan newRight = CBOUtils.projectOrSelf(newBProject.build(), join.right());
Plan newJoin = join.withChildrenNoContext(newLeft, newRight);
return CBOUtils.projectOrSelfInOrder(new ArrayList<>(project.getOutput()), newJoin);
return CBOUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
}).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN);
}
}

View File

@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import java.util.ArrayList;
import java.util.List;
@ -52,7 +53,7 @@ public class PushdownProjectThroughSemiJoin extends OneExplorationRuleFactory {
return logicalProject(logicalJoin())
.when(project -> project.child().getJoinType().isLeftSemiOrAntiJoin())
// Just pushdown project with non-column expr like (t.id + 1)
.whenNot(CBOUtils::isAllSlotProject)
.whenNot(LogicalProject::isAllSlots)
.whenNot(project -> project.child().hasJoinHint())
.then(project -> {
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
@ -65,7 +66,7 @@ public class PushdownProjectThroughSemiJoin extends OneExplorationRuleFactory {
Plan newLeft = CBOUtils.projectOrSelf(newProject, join.left());
Plan newJoin = join.withChildrenNoContext(newLeft, join.right());
return CBOUtils.projectOrSelfInOrder(new ArrayList<>(project.getOutput()), newJoin);
return CBOUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
}).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN);
}

View File

@ -56,7 +56,7 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory
.when(topSemi -> InnerJoinLAsscom.checkReorder(topSemi, topSemi.left().child()))
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> CBOUtils.isAllSlotProject(join.left()))
.when(join -> join.left().isAllSlots())
.then(topSemi -> {
LogicalJoin<GroupPlan, GroupPlan> bottomSemi = topSemi.left().child();
LogicalProject abProject = topSemi.left();

View File

@ -102,6 +102,10 @@ public class LogicalProject<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_
return excepts;
}
/**
 * Tell whether every expression in {@code projects} reports itself as a slot.
 *
 * @return true if {@code isSlot()} holds for each projection
 *         (trivially true when there are no projections)
 */
public boolean isAllSlots() {
    for (NamedExpression projection : projects) {
        if (!projection.isSlot()) {
            return false;
        }
    }
    return true;
}
@Override
public List<Slot> computeOutput() {
return projects.stream()