From 59efebce3b5193eea5a7ff97f149ae43d0624f1a Mon Sep 17 00:00:00 2001 From: minghong Date: Fri, 10 Nov 2023 16:13:53 +0800 Subject: [PATCH] [opt](nereids) estimate join cost when col stats are not available (#26086) no stats left zigzag --- .../doris/nereids/cost/CostModelV1.java | 21 -------------- .../cascades/OptimizeGroupExpressionJob.java | 5 ++++ .../apache/doris/nereids/rules/RuleSet.java | 12 ++++++++ .../exploration/join/InnerJoinLAsscom.java | 29 +++++++++++++++---- .../join/InnerJoinLAsscomProject.java | 10 +++++-- .../rules/exploration/join/JoinCommute.java | 14 ++++++++- .../SemiJoinSemiJoinTransposeProject.java | 2 +- .../doris/nereids/stats/JoinEstimation.java | 2 +- .../org/apache/doris/qe/SessionVariable.java | 12 ++++++++ .../hypergraph/GraphSimplifierTest.java | 8 ++--- .../apache/doris/nereids/memo/RankTest.java | 2 +- 11 files changed, 80 insertions(+), 37 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java index 72baa7deaf..44af612281 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java @@ -293,27 +293,12 @@ class CostModelV1 extends PlanVisitor { // use totalInstanceNumber to the power of 2 as the default factor value buildSideFactor = Math.pow(totalInstanceNumber, 0.5); } - // TODO: since the outputs rows may expand a lot, penalty on it will cause bc never be chosen. - // will refine this in next generation cost model. - if (!context.isStatsReliable()) { - // forbid broadcast join when stats is unknown - return CostV1.of(context.getSessionVariable(), rightRowCount * buildSideFactor + 1 / leftRowCount, - rightRowCount, - 0 - ); - } return CostV1.of(context.getSessionVariable(), leftRowCount + rightRowCount * buildSideFactor + outputRowCount * probeSideFactor, rightRowCount, 0 ); } - if (!context.isStatsReliable()) { - return CostV1.of(context.getSessionVariable(), - rightRowCount + 1 / leftRowCount, - rightRowCount, - 0); - } return CostV1.of(context.getSessionVariable(), leftRowCount + rightRowCount + outputRowCount, rightRowCount, 0 @@ -328,12 +313,6 @@ class CostModelV1 extends PlanVisitor { Preconditions.checkState(context.arity() == 2); Statistics leftStatistics = context.getChildStatistics(0); Statistics rightStatistics = context.getChildStatistics(1); - if (!context.isStatsReliable()) { - return CostV1.of(context.getSessionVariable(), - rightStatistics.getRowCount() + 1 / leftStatistics.getRowCount(), - rightStatistics.getRowCount(), - 0); - } return CostV1.of(context.getSessionVariable(), leftStatistics.getRowCount() * rightStatistics.getRowCount(), rightStatistics.getRowCount(), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/OptimizeGroupExpressionJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/OptimizeGroupExpressionJob.java index 178818e660..38c0c4484b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/OptimizeGroupExpressionJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/OptimizeGroupExpressionJob.java @@ -71,6 +71,9 @@ public class OptimizeGroupExpressionJob extends Job { boolean isOtherJoinReorder = context.getCascadesContext().getStatementContext().isOtherJoinReorder(); boolean isEnableBushyTree = context.getCascadesContext().getConnectContext().getSessionVariable() .isEnableBushyTree(); + boolean isLeftZigZagTree = context.getCascadesContext().getConnectContext() + .getSessionVariable().isEnableLeftZigZag() + || (groupExpression.getOwnerGroup() != null && !groupExpression.getOwnerGroup().isStatsReliable()); int joinNumBushyTree = context.getCascadesContext().getConnectContext() .getSessionVariable().getMaxJoinNumBushyTree(); if (isDisableJoinReorder) { @@ -81,6 +84,8 @@ public class OptimizeGroupExpressionJob extends Job { } else { return Collections.emptyList(); } + } else if (isLeftZigZagTree) { + return getRuleSet().getLeftZigZagTreeJoinReorder(); } else if (isEnableBushyTree) { return getRuleSet().getBushyTreeJoinReorder(); } else if (context.getCascadesContext().getStatementContext().getMaxNAryInnerJoin() <= joinNumBushyTree) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index b8bb9e8c46..8016366318 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -180,6 +180,14 @@ public class RuleSet { .add(new LogicalDeferMaterializeResultSinkToPhysicalDeferMaterializeResultSink()) .build(); + // left-zig-zag tree is used when column stats are not available. + public static final List LEFT_ZIG_ZAG_TREE_JOIN_REORDER = planRuleFactories() + .add(JoinCommute.LEFT_ZIG_ZAG) + .add(InnerJoinLAsscom.LEFT_ZIG_ZAG) + .add(InnerJoinLAsscomProject.LEFT_ZIG_ZAG) + .addAll(OTHER_REORDER_RULES) + .build(); + public static final List ZIG_ZAG_TREE_JOIN_REORDER = planRuleFactories() .add(JoinCommute.ZIG_ZAG) .add(InnerJoinLAsscom.INSTANCE) @@ -220,6 +228,10 @@ public class RuleSet { return ZIG_ZAG_TREE_JOIN_REORDER_RULES; } + public List getLeftZigZagTreeJoinReorder() { + return LEFT_ZIG_ZAG_TREE_JOIN_REORDER; + } + public List getBushyTreeJoinReorder() { return BUSHY_TREE_JOIN_REORDER_RULES; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscom.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscom.java index d92aede5bb..f256bff004 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscom.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscom.java @@ -36,7 +36,13 @@ import java.util.stream.Collectors; * Rule for change inner join LAsscom (associative and commutive). */ public class InnerJoinLAsscom extends OneExplorationRuleFactory { - public static final InnerJoinLAsscom INSTANCE = new InnerJoinLAsscom(); + public static final InnerJoinLAsscom INSTANCE = new InnerJoinLAsscom(false); + public static final InnerJoinLAsscom LEFT_ZIG_ZAG = new InnerJoinLAsscom(true); + private boolean leftZigZag = false; + + public InnerJoinLAsscom(boolean leftZigZag) { + this.leftZigZag = leftZigZag; + } /* * topJoin newTopJoin @@ -48,7 +54,7 @@ public class InnerJoinLAsscom extends OneExplorationRuleFactory { @Override public Rule build() { return innerLogicalJoin(innerLogicalJoin(), group()) - .when(topJoin -> checkReorder(topJoin, topJoin.left())) + .when(topJoin -> checkReorder(topJoin, topJoin.left(), leftZigZag)) .whenNot(join -> join.hasJoinHint() || join.left().hasJoinHint()) .whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin()) .then(topJoin -> { @@ -85,11 +91,22 @@ public class InnerJoinLAsscom extends OneExplorationRuleFactory { }).toRule(RuleType.LOGICAL_INNER_JOIN_LASSCOM); } + /** + * trigger rule condition + */ public static boolean checkReorder(LogicalJoin topJoin, - LogicalJoin bottomJoin) { - return !bottomJoin.getJoinReorderContext().hasCommuteZigZag() - && !topJoin.getJoinReorderContext().hasLAsscom() - && (!bottomJoin.isMarkJoin() && !topJoin.isMarkJoin()); + LogicalJoin bottomJoin, boolean leftZigZag) { + if (leftZigZag) { + double bRows = bottomJoin.right().getGroup().getStatistics().getRowCount(); + double cRows = topJoin.right().getGroup().getStatistics().getRowCount(); + return bRows < cRows && !bottomJoin.getJoinReorderContext().hasCommuteZigZag() + && !topJoin.getJoinReorderContext().hasLAsscom() + && (!bottomJoin.isMarkJoin() && !topJoin.isMarkJoin()); + } else { + return !bottomJoin.getJoinReorderContext().hasCommuteZigZag() + && !topJoin.getJoinReorderContext().hasLAsscom() + && (!bottomJoin.isMarkJoin() && !topJoin.isMarkJoin()); + } } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java index 2ca9c81591..297ec9d76e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java @@ -40,8 +40,13 @@ import java.util.stream.Collectors; * Rule for change inner join LAsscom (associative and commutive). */ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { - public static final InnerJoinLAsscomProject INSTANCE = new InnerJoinLAsscomProject(); + public static final InnerJoinLAsscomProject INSTANCE = new InnerJoinLAsscomProject(false); + public static final InnerJoinLAsscomProject LEFT_ZIG_ZAG = new InnerJoinLAsscomProject(true); + private final boolean enableLeftZigZag; + public InnerJoinLAsscomProject(boolean enableLeftZigZag) { + this.enableLeftZigZag = enableLeftZigZag; + } /* * topJoin newTopJoin * / \ / \ @@ -51,10 +56,11 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { * / \ / \ * A B A C */ + @Override public Rule build() { return innerLogicalJoin(logicalProject(innerLogicalJoin()), group()) - .when(topJoin -> InnerJoinLAsscom.checkReorder(topJoin, topJoin.left().child())) + .when(topJoin -> InnerJoinLAsscom.checkReorder(topJoin, topJoin.left().child(), enableLeftZigZag)) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) .whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin()) .when(join -> join.left().isAllSlots()) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java index f4c56fabd5..d6df03e1c0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java @@ -38,6 +38,7 @@ import java.util.List; public class JoinCommute extends OneExplorationRuleFactory { public static final JoinCommute LEFT_DEEP = new JoinCommute(SwapType.LEFT_DEEP, false); + public static final JoinCommute LEFT_ZIG_ZAG = new JoinCommute(SwapType.LEFT_ZIG_ZAG, false); public static final JoinCommute ZIG_ZAG = new JoinCommute(SwapType.ZIG_ZAG, false); public static final JoinCommute BUSHY = new JoinCommute(SwapType.BUSHY, false); public static final JoinCommute NON_INNER = new JoinCommute(SwapType.BUSHY, true); @@ -73,7 +74,8 @@ public class JoinCommute extends OneExplorationRuleFactory { } enum SwapType { - LEFT_DEEP, ZIG_ZAG, BUSHY + LEFT_DEEP, ZIG_ZAG, BUSHY, + LEFT_ZIG_ZAG } /** @@ -88,6 +90,12 @@ public class JoinCommute extends OneExplorationRuleFactory { return false; } + if (swapType == SwapType.LEFT_ZIG_ZAG) { + double leftRows = join.left().getGroup().getStatistics().getRowCount(); + double rightRows = join.right().getGroup().getStatistics().getRowCount(); + return leftRows < rightRows && isZigZagJoin(join); + } + return true; } @@ -101,6 +109,10 @@ public class JoinCommute extends OneExplorationRuleFactory { return containJoin(join.left()) || containJoin(join.right()); } + public static boolean isZigZagJoin(LogicalJoin join) { + return !containJoin(join.left()) || !containJoin(join.right()); + } + private static boolean containJoin(GroupPlan groupPlan) { // TODO: tmp way to judge containJoin List output = groupPlan.getOutput(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java index 0db1c80087..13e9c7cd46 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java @@ -53,7 +53,7 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory public Rule build() { return logicalJoin(logicalProject(logicalJoin()), group()) .when(this::typeChecker) - .when(topSemi -> InnerJoinLAsscom.checkReorder(topSemi, topSemi.left().child())) + .when(topSemi -> InnerJoinLAsscom.checkReorder(topSemi, topSemi.left().child(), false)) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) .whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin()) .when(join -> join.left().isAllSlots()) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java index 261d4385ce..3b7797439f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java @@ -171,7 +171,7 @@ public class JoinEstimation { private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rightStats, Join join) { if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) { - double rowCount = leftStats.getRowCount() + rightStats.getRowCount(); + double rowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount()); rowCount = Math.max(1, rowCount); return new StatisticsBuilder() .setRowCount(rowCount) diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index 7a4cef6e40..b4c117874f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -245,6 +245,7 @@ public class SessionVariable implements Serializable, Writable { public static final String ENABLE_DPHYP_OPTIMIZER = "enable_dphyp_optimizer"; + public static final String ENABLE_LEFT_ZIG_ZAG = "enable_left_zig_zag"; public static final String NTH_OPTIMIZED_PLAN = "nth_optimized_plan"; public static final String ENABLE_NEREIDS_PLANNER = "enable_nereids_planner"; @@ -909,6 +910,17 @@ public class SessionVariable implements Serializable, Writable { @VariableMgr.VarAttr(name = NTH_OPTIMIZED_PLAN) private int nthOptimizedPlan = 1; + public boolean isEnableLeftZigZag() { + return enableLeftZigZag; + } + + public void setEnableLeftZigZag(boolean enableLeftZigZag) { + this.enableLeftZigZag = enableLeftZigZag; + } + + @VariableMgr.VarAttr(name = ENABLE_LEFT_ZIG_ZAG) + private boolean enableLeftZigZag = false; + /** * as the new optimizer is not mature yet, use this var * to control whether to use new optimizer, remove it when diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java index 402cc1b64d..a1d44658db 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java @@ -89,10 +89,10 @@ class GraphSimplifierTest { .add(Pair.of(17L, 2L)) // 04 - 1 .add(Pair.of(17L, 4L)) // 04 - 2 .add(Pair.of(17L, 8L)) // 04 - 3 - .add(Pair.of(25L, 2L)) // 034 - 1 - .add(Pair.of(25L, 4L)) // 034 - 2 - .add(Pair.of(29L, 2L)) // 0234 - 1 - .build(); // 0-4-3-2-1 : big left deep tree + .add(Pair.of(19L, 8L)) // 041 - 2 + .add(Pair.of(21L, 2L)) // 042 - 1 + .add(Pair.of(23L, 8L)) // 0134 - 2 + .build(); // 0-4-3-1-2 : big left deep tree for (Pair step : steps) { if (!graphSimplifier.applySimplificationStep()) { break; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/RankTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/RankTest.java index ba0f4bd26c..e3571395e4 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/RankTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/RankTest.java @@ -55,7 +55,7 @@ class RankTest extends TestWithFeService { shape.add(memo.unrank(memo.rank(i + 1).first).shape("")); } System.out.println(shape); - Assertions.assertEquals(4, shape.size()); + Assertions.assertEquals(1, shape.size()); Assertions.assertEquals(bestPlan.shape(""), memo.unrank(memo.rank(1).first).shape("")); } }