From aaedcf34cf09c7a7cf6df0754e39108de3053a7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E5=81=A5?= Date: Thu, 9 Mar 2023 13:58:44 +0800 Subject: [PATCH] [enhancement](Nereids) refactor costModel framework (#17339) refactor cost-model frameWork: 1. Use Cost class to encapsulate double cost 2. Use the `addChildCost` function to calculate the cost with children rather than add directly Note we use the `Cost` class because we hope to customize the operator of adding a child host. Therefore, only when the cost would add the child Cost or be added by the parent we use `Cost`. Otherwise, we use double such as `upperbound` --- .../apache/doris/nereids/NereidsPlanner.java | 2 +- .../org/apache/doris/nereids/cost/Cost.java | 44 +++ .../doris/nereids/cost/CostCalculator.java | 283 +----------------- .../doris/nereids/cost/CostEstimate.java | 128 -------- .../doris/nereids/cost/CostModelV1.java | 279 +++++++++++++++++ .../org/apache/doris/nereids/cost/CostV1.java | 126 ++++++++ .../apache/doris/nereids/cost/CostWeight.java | 37 ++- .../jobs/cascades/CostAndEnforcerJob.java | 54 ++-- .../joinorder/hypergraph/GraphSimplifier.java | 30 +- .../org/apache/doris/nereids/memo/Group.java | 25 +- .../doris/nereids/memo/GroupExpression.java | 31 +- .../org/apache/doris/nereids/memo/Memo.java | 15 +- .../ChildrenPropertiesRegulator.java | 26 +- .../EnforceMissingPropertiesHelper.java | 16 +- .../apache/doris/nereids/memo/MemoTest.java | 3 +- 15 files changed, 597 insertions(+), 502 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java delete mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostEstimate.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java index 33025b2559..6f35193797 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java @@ -267,7 +267,7 @@ public class NereidsPlanner extends Planner { if (nthPlan <= 1) { cost = rootGroup.getLowestCostPlan(physicalProperties).orElseThrow( () -> new AnalysisException("lowestCostPlans with physicalProperties(" - + physicalProperties + ") doesn't exist in root group")).first; + + physicalProperties + ") doesn't exist in root group")).first.getValue(); return chooseBestPlan(rootGroup, physicalProperties); } Memo memo = cascadesContext.getMemo(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java new file mode 100644 index 0000000000..c616b5a858 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.cost; + +/** + * Cost encapsulate the real cost with double type. + * We do this because we want to customize the operation of adding child cost + * according different operator + */ +public interface Cost { + public double getValue(); + + public int compare(Cost other); + + public Cost minus(Cost other); + + public static Cost withRowCount(double rowCount) { + return new CostV1(rowCount, 0, 0, 0); + } + + public static Cost zero() { + return CostV1.zero(); + } + + public static Cost infinite() { + return CostV1.infinite(); + } +} + diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java index 6a805cd665..36c4dcfec7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java @@ -20,31 +20,7 @@ package org.apache.doris.nereids.cost; import org.apache.doris.nereids.PlanContext; import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.memo.GroupExpression; -import org.apache.doris.nereids.properties.DistributionSpec; -import org.apache.doris.nereids.properties.DistributionSpecGather; -import org.apache.doris.nereids.properties.DistributionSpecHash; -import org.apache.doris.nereids.properties.DistributionSpecReplicated; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows; -import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; -import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan; -import org.apache.doris.nereids.trees.plans.physical.PhysicalFileScan; -import org.apache.doris.nereids.trees.plans.physical.PhysicalGenerate; -import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; -import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; -import org.apache.doris.nereids.trees.plans.physical.PhysicalJdbcScan; -import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; -import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; -import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; -import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort; -import org.apache.doris.nereids.trees.plans.physical.PhysicalSchemaScan; -import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate; -import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN; -import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; -import org.apache.doris.qe.ConnectContext; -import org.apache.doris.statistics.StatsDeriveResult; - -import com.google.common.base.Preconditions; /** * Calculate the cost of a plan. @@ -53,264 +29,21 @@ import com.google.common.base.Preconditions; @Developing //TODO: memory cost and network cost should be estimated by byte size. public class CostCalculator { - static final double CPU_WEIGHT = 1; - static final double MEMORY_WEIGHT = 1; - static final double NETWORK_WEIGHT = 1.5; - - /** - * The intuition behind `HEAVY_OPERATOR_PUNISH_FACTOR` is we need to avoid this form of join patterns: - * Plan1: L join ( AGG1(A) join AGG2(B)) - * But - * Plan2: L join AGG1(A) join AGG2(B) is welcomed. - * AGG is time-consuming operator. From the perspective of rowCount, nereids may choose Plan1, - * because `Agg1 join Agg2` generates few tuples. But in Plan1, Agg1 and Agg2 are done in serial, in Plan2, Agg1 and - * Agg2 are done in parallel. And hence, Plan1 should be punished. - *

- * An example is tpch q15. - */ - static final double HEAVY_OPERATOR_PUNISH_FACTOR = 6.0; - /** * Constructor. */ - public static double calculateCost(GroupExpression groupExpression) { + public static Cost calculateCost(GroupExpression groupExpression) { PlanContext planContext = new PlanContext(groupExpression); - CostEstimator costCalculator = new CostEstimator(); - CostEstimate costEstimate = groupExpression.getPlan().accept(costCalculator, planContext); - groupExpression.setCostEstimate(costEstimate); - /* - * About PENALTY: - * Except stats information, there are some special criteria in doris. - * For example, in hash join cluster, BE could build hash tables - * in parallel for left deep tree. And hence, we need to punish right deep tree. - * penalyWeight is the factor of punishment. - * The punishment is denoted by stats.penalty. - */ - CostWeight costWeight = new CostWeight(CPU_WEIGHT, MEMORY_WEIGHT, NETWORK_WEIGHT, - ConnectContext.get().getSessionVariable().getNereidsCboPenaltyFactor()); - return costWeight.calculate(costEstimate); + CostModelV1 costModel = new CostModelV1(); + return groupExpression.getPlan().accept(costModel, planContext); } - public static double calculateCost(Plan plan, PlanContext planContext) { - CostEstimator costCalculator = new CostEstimator(); - CostEstimate costEstimate = plan.accept(costCalculator, planContext); - CostWeight costWeight = new CostWeight(CPU_WEIGHT, MEMORY_WEIGHT, NETWORK_WEIGHT, - ConnectContext.get().getSessionVariable().getNereidsCboPenaltyFactor()); - return costWeight.calculate(costEstimate); + public static Cost calculateCost(Plan plan, PlanContext planContext) { + CostModelV1 costModel = new CostModelV1(); + return plan.accept(costModel, planContext); } - private static class CostEstimator extends PlanVisitor { - @Override - public CostEstimate visit(Plan plan, PlanContext context) { - return CostEstimate.zero(); - } - - @Override - public CostEstimate visitPhysicalOlapScan(PhysicalOlapScan physicalOlapScan, PlanContext context) { - StatsDeriveResult statistics = context.getStatisticsWithCheck(); - return CostEstimate.ofCpu(statistics.getRowCount()); - } - - public CostEstimate visitPhysicalSchemaScan(PhysicalSchemaScan physicalSchemaScan, PlanContext context) { - StatsDeriveResult statistics = context.getStatisticsWithCheck(); - return CostEstimate.ofCpu(statistics.getRowCount()); - } - - @Override - public CostEstimate visitPhysicalStorageLayerAggregate( - PhysicalStorageLayerAggregate storageLayerAggregate, PlanContext context) { - CostEstimate costEstimate = storageLayerAggregate.getRelation().accept(this, context); - // multiply a factor less than 1, so we can select PhysicalStorageLayerAggregate as far as possible - return new CostEstimate(costEstimate.getCpuCost() * 0.7, costEstimate.getMemoryCost(), - costEstimate.getNetworkCost(), costEstimate.getPenalty()); - } - - @Override - public CostEstimate visitPhysicalFileScan(PhysicalFileScan physicalFileScan, PlanContext context) { - StatsDeriveResult statistics = context.getStatisticsWithCheck(); - return CostEstimate.ofCpu(statistics.getRowCount()); - } - - @Override - public CostEstimate visitPhysicalProject(PhysicalProject physicalProject, PlanContext context) { - return CostEstimate.ofCpu(1); - } - - @Override - public CostEstimate visitPhysicalJdbcScan(PhysicalJdbcScan physicalJdbcScan, PlanContext context) { - StatsDeriveResult statistics = context.getStatisticsWithCheck(); - return CostEstimate.ofCpu(statistics.getRowCount()); - } - - @Override - public CostEstimate visitPhysicalEsScan(PhysicalEsScan physicalEsScan, PlanContext context) { - StatsDeriveResult statistics = context.getStatisticsWithCheck(); - return CostEstimate.ofCpu(statistics.getRowCount()); - } - - @Override - public CostEstimate visitPhysicalQuickSort( - PhysicalQuickSort physicalQuickSort, PlanContext context) { - // TODO: consider two-phase sort and enforcer. - StatsDeriveResult statistics = context.getStatisticsWithCheck(); - StatsDeriveResult childStatistics = context.getChildStatistics(0); - if (physicalQuickSort.getSortPhase().isGather()) { - // Now we do more like two-phase sort, so penalise one-phase sort - statistics.updateRowCount(statistics.getRowCount() * 100); - } - return CostEstimate.of( - childStatistics.getRowCount(), - statistics.getRowCount(), - childStatistics.getRowCount()); - } - - @Override - public CostEstimate visitPhysicalTopN(PhysicalTopN topN, PlanContext context) { - // TODO: consider two-phase sort and enforcer. - StatsDeriveResult statistics = context.getStatisticsWithCheck(); - StatsDeriveResult childStatistics = context.getChildStatistics(0); - if (topN.getSortPhase().isGather()) { - // Now we do more like two-phase sort, so penalise one-phase sort - statistics.updateRowCount(statistics.getRowCount() * 100); - } - return CostEstimate.of( - childStatistics.getRowCount(), - statistics.getRowCount(), - childStatistics.getRowCount()); - } - - @Override - public CostEstimate visitPhysicalDistribute( - PhysicalDistribute distribute, PlanContext context) { - StatsDeriveResult childStatistics = context.getChildStatistics(0); - DistributionSpec spec = distribute.getDistributionSpec(); - // shuffle - if (spec instanceof DistributionSpecHash) { - return CostEstimate.of( - childStatistics.getRowCount(), - 0, - childStatistics.getRowCount()); - } - - // replicate - if (spec instanceof DistributionSpecReplicated) { - int beNumber = ConnectContext.get().getEnv().getClusterInfo().getBackendIds(true).size(); - int instanceNumber = ConnectContext.get().getSessionVariable().getParallelExecInstanceNum(); - beNumber = Math.max(1, beNumber); - double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte(); - //if build side is big, avoid use broadcast join - double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit(); - double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage(); - double buildSize = childStatistics.computeSize(); - if (buildSize * instanceNumber > memLimit * brMemlimit - || childStatistics.getRowCount() > rowsLimit) { - return CostEstimate.of(Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE); - } - return CostEstimate.of( - childStatistics.getRowCount() * beNumber, - childStatistics.getRowCount() * beNumber * instanceNumber, - childStatistics.getRowCount() * beNumber * instanceNumber); - } - - // gather - if (spec instanceof DistributionSpecGather) { - return CostEstimate.of( - childStatistics.getRowCount(), - 0, - childStatistics.getRowCount()); - } - - // any - return CostEstimate.of( - childStatistics.getRowCount(), - 0, - 0); - } - - @Override - public CostEstimate visitPhysicalHashAggregate( - PhysicalHashAggregate aggregate, PlanContext context) { - // TODO: stage..... - - StatsDeriveResult statistics = context.getStatisticsWithCheck(); - StatsDeriveResult inputStatistics = context.getChildStatistics(0); - return CostEstimate.of(inputStatistics.getRowCount(), statistics.getRowCount(), 0); - } - - @Override - public CostEstimate visitPhysicalHashJoin( - PhysicalHashJoin physicalHashJoin, PlanContext context) { - Preconditions.checkState(context.arity() == 2); - StatsDeriveResult outputStats = context.getStatisticsWithCheck(); - double outputRowCount = outputStats.getRowCount(); - - StatsDeriveResult probeStats = context.getChildStatistics(0); - StatsDeriveResult buildStats = context.getChildStatistics(1); - - double leftRowCount = probeStats.getRowCount(); - double rightRowCount = buildStats.getRowCount(); - /* - pattern1: L join1 (Agg1() join2 Agg2()) - result number of join2 may much less than Agg1. - but Agg1 and Agg2 are slow. so we need to punish this pattern1. - - pattern2: (L join1 Agg1) join2 agg2 - in pattern2, join1 and join2 takes more time, but Agg1 and agg2 can be processed in parallel. - */ - double penalty = CostCalculator.HEAVY_OPERATOR_PUNISH_FACTOR - * Math.min(probeStats.getPenalty(), buildStats.getPenalty()); - if (buildStats.getWidth() >= 2) { - //penalty for right deep tree - penalty += rightRowCount; - } - - if (physicalHashJoin.getJoinType().isCrossJoin()) { - return CostEstimate.of(leftRowCount + rightRowCount + outputRowCount, - 0, - leftRowCount + rightRowCount, - penalty); - } - return CostEstimate.of(leftRowCount + rightRowCount + outputRowCount, - rightRowCount, - 0, - penalty - ); - } - - @Override - public CostEstimate visitPhysicalNestedLoopJoin( - PhysicalNestedLoopJoin nestedLoopJoin, - PlanContext context) { - // TODO: copy from physicalHashJoin, should update according to physical nested loop join properties. - Preconditions.checkState(context.arity() == 2); - - StatsDeriveResult leftStatistics = context.getChildStatistics(0); - StatsDeriveResult rightStatistics = context.getChildStatistics(1); - - return CostEstimate.of( - leftStatistics.getRowCount() * rightStatistics.getRowCount(), - rightStatistics.getRowCount(), - 0); - } - - @Override - public CostEstimate visitPhysicalAssertNumRows(PhysicalAssertNumRows assertNumRows, - PlanContext context) { - return CostEstimate.of( - assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(), - assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(), - 0 - ); - } - - @Override - public CostEstimate visitPhysicalGenerate(PhysicalGenerate generate, PlanContext context) { - StatsDeriveResult statistics = context.getStatisticsWithCheck(); - return CostEstimate.of( - statistics.getRowCount(), - statistics.getRowCount(), - 0 - ); - } + public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) { + return CostModelV1.addChildCost(plan, planCost, childCost, index); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostEstimate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostEstimate.java deleted file mode 100644 index 63600f4466..0000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostEstimate.java +++ /dev/null @@ -1,128 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.cost; - -import com.google.common.base.Preconditions; - -/** - * Use for estimating the cost of plan. - */ -public final class CostEstimate { - private static final CostEstimate INFINITE = - new CostEstimate(Double.POSITIVE_INFINITY, - Double.POSITIVE_INFINITY, - Double.POSITIVE_INFINITY, - Double.POSITIVE_INFINITY); - private static final CostEstimate ZERO = new CostEstimate(0, 0, 0, 0); - - private final double cpuCost; - private final double memoryCost; - private final double networkCost; - //penalty for - // 1. right deep tree - // 2. right XXX join - private final double penalty; - - /** - * Constructor of CostEstimate. - */ - public CostEstimate(double cpuCost, double memoryCost, double networkCost, double penaltiy) { - // TODO: fix stats - if (cpuCost < 0) { - cpuCost = 0; - } - if (memoryCost < 0) { - memoryCost = 0; - } - if (networkCost < 0) { - networkCost = 0; - } - Preconditions.checkArgument(!(cpuCost < 0), "cpuCost cannot be negative: %s", cpuCost); - Preconditions.checkArgument(!(memoryCost < 0), "memoryCost cannot be negative: %s", memoryCost); - Preconditions.checkArgument(!(networkCost < 0), "networkCost cannot be negative: %s", networkCost); - this.cpuCost = cpuCost; - this.memoryCost = memoryCost; - this.networkCost = networkCost; - this.penalty = penaltiy; - } - - public static CostEstimate infinite() { - return INFINITE; - } - - public static CostEstimate zero() { - return ZERO; - } - - public double getCpuCost() { - return cpuCost; - } - - public double getMemoryCost() { - return memoryCost; - } - - public double getNetworkCost() { - return networkCost; - } - - public double getPenalty() { - return penalty; - } - - public static CostEstimate of(double cpuCost, double maxMemory, double networkCost, double rightDeepPenaltiy) { - return new CostEstimate(cpuCost, maxMemory, networkCost, rightDeepPenaltiy); - } - - public static CostEstimate of(double cpuCost, double maxMemory, double networkCost) { - return new CostEstimate(cpuCost, maxMemory, networkCost, 0); - } - - public static CostEstimate ofCpu(double cpuCost) { - return new CostEstimate(cpuCost, 0, 0, 0); - } - - public static CostEstimate ofMemory(double memoryCost) { - return new CostEstimate(0, memoryCost, 0, 0); - } - - /** - * sum of cost estimate - */ - public static CostEstimate sum(CostEstimate one, CostEstimate two, CostEstimate... more) { - double cpuCostSum = one.cpuCost + two.cpuCost; - double memoryCostSum = one.memoryCost + two.memoryCost; - double networkCostSum = one.networkCost + one.networkCost; - for (CostEstimate costEstimate : more) { - cpuCostSum += costEstimate.cpuCost; - memoryCostSum += costEstimate.memoryCost; - networkCostSum += costEstimate.networkCost; - } - return CostEstimate.of(cpuCostSum, memoryCostSum, networkCostSum); - } - - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append((long) cpuCost) - .append("/").append((long) memoryCost) - .append("/").append((long) networkCost) - .append("/").append((long) penalty); - return sb.toString(); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java new file mode 100644 index 0000000000..d290ee38b3 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java @@ -0,0 +1,279 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.cost; + +import org.apache.doris.nereids.PlanContext; +import org.apache.doris.nereids.properties.DistributionSpec; +import org.apache.doris.nereids.properties.DistributionSpecGather; +import org.apache.doris.nereids.properties.DistributionSpecHash; +import org.apache.doris.nereids.properties.DistributionSpecReplicated; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows; +import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; +import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalFileScan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalGenerate; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; +import org.apache.doris.nereids.trees.plans.physical.PhysicalJdbcScan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; +import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; +import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort; +import org.apache.doris.nereids.trees.plans.physical.PhysicalSchemaScan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate; +import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN; +import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.statistics.StatsDeriveResult; + +import com.google.common.base.Preconditions; + +class CostModelV1 extends PlanVisitor { + /** + * The intuition behind `HEAVY_OPERATOR_PUNISH_FACTOR` is we need to avoid this form of join patterns: + * Plan1: L join ( AGG1(A) join AGG2(B)) + * But + * Plan2: L join AGG1(A) join AGG2(B) is welcomed. + * AGG is time-consuming operator. From the perspective of rowCount, nereids may choose Plan1, + * because `Agg1 join Agg2` generates few tuples. But in Plan1, Agg1 and Agg2 are done in serial, in Plan2, Agg1 and + * Agg2 are done in parallel. And hence, Plan1 should be punished. + *

+ * An example is tpch q15. + */ + static final double HEAVY_OPERATOR_PUNISH_FACTOR = 6.0; + + public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) { + Preconditions.checkArgument(childCost instanceof CostV1 && planCost instanceof CostV1); + double cost = planCost.getValue() + childCost.getValue(); + return new CostV1(cost); + } + + @Override + public Cost visit(Plan plan, PlanContext context) { + return CostV1.zero(); + } + + @Override + public Cost visitPhysicalOlapScan(PhysicalOlapScan physicalOlapScan, PlanContext context) { + StatsDeriveResult statistics = context.getStatisticsWithCheck(); + return CostV1.ofCpu(statistics.getRowCount()); + } + + public Cost visitPhysicalSchemaScan(PhysicalSchemaScan physicalSchemaScan, PlanContext context) { + StatsDeriveResult statistics = context.getStatisticsWithCheck(); + return CostV1.ofCpu(statistics.getRowCount()); + } + + @Override + public Cost visitPhysicalStorageLayerAggregate( + PhysicalStorageLayerAggregate storageLayerAggregate, PlanContext context) { + CostV1 costValue = (CostV1) storageLayerAggregate.getRelation().accept(this, context); + // multiply a factor less than 1, so we can select PhysicalStorageLayerAggregate as far as possible + return new CostV1(costValue.getCpuCost() * 0.7, costValue.getMemoryCost(), + costValue.getNetworkCost(), costValue.getPenalty()); + } + + @Override + public Cost visitPhysicalFileScan(PhysicalFileScan physicalFileScan, PlanContext context) { + StatsDeriveResult statistics = context.getStatisticsWithCheck(); + return CostV1.ofCpu(statistics.getRowCount()); + } + + @Override + public Cost visitPhysicalProject(PhysicalProject physicalProject, PlanContext context) { + return CostV1.ofCpu(1); + } + + @Override + public Cost visitPhysicalJdbcScan(PhysicalJdbcScan physicalJdbcScan, PlanContext context) { + StatsDeriveResult statistics = context.getStatisticsWithCheck(); + return CostV1.ofCpu(statistics.getRowCount()); + } + + @Override + public Cost visitPhysicalEsScan(PhysicalEsScan physicalEsScan, PlanContext context) { + StatsDeriveResult statistics = context.getStatisticsWithCheck(); + return CostV1.ofCpu(statistics.getRowCount()); + } + + @Override + public Cost visitPhysicalQuickSort( + PhysicalQuickSort physicalQuickSort, PlanContext context) { + // TODO: consider two-phase sort and enforcer. + StatsDeriveResult statistics = context.getStatisticsWithCheck(); + StatsDeriveResult childStatistics = context.getChildStatistics(0); + if (physicalQuickSort.getSortPhase().isGather()) { + // Now we do more like two-phase sort, so penalise one-phase sort + statistics.updateRowCount(statistics.getRowCount() * 100); + } + return CostV1.of( + childStatistics.getRowCount(), + statistics.getRowCount(), + childStatistics.getRowCount()); + } + + @Override + public Cost visitPhysicalTopN(PhysicalTopN topN, PlanContext context) { + // TODO: consider two-phase sort and enforcer. + StatsDeriveResult statistics = context.getStatisticsWithCheck(); + StatsDeriveResult childStatistics = context.getChildStatistics(0); + if (topN.getSortPhase().isGather()) { + // Now we do more like two-phase sort, so penalise one-phase sort + statistics.updateRowCount(statistics.getRowCount() * 100); + } + return CostV1.of( + childStatistics.getRowCount(), + statistics.getRowCount(), + childStatistics.getRowCount()); + } + + @Override + public Cost visitPhysicalDistribute( + PhysicalDistribute distribute, PlanContext context) { + StatsDeriveResult childStatistics = context.getChildStatistics(0); + DistributionSpec spec = distribute.getDistributionSpec(); + // shuffle + if (spec instanceof DistributionSpecHash) { + return CostV1.of( + childStatistics.getRowCount(), + 0, + childStatistics.getRowCount()); + } + + // replicate + if (spec instanceof DistributionSpecReplicated) { + int beNumber = ConnectContext.get().getEnv().getClusterInfo().getBackendIds(true).size(); + int instanceNumber = ConnectContext.get().getSessionVariable().getParallelExecInstanceNum(); + beNumber = Math.max(1, beNumber); + double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte(); + //if build side is big, avoid use broadcast join + double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit(); + double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage(); + double buildSize = childStatistics.computeSize(); + if (buildSize * instanceNumber > memLimit * brMemlimit + || childStatistics.getRowCount() > rowsLimit) { + return CostV1.of(Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE); + } + return CostV1.of( + childStatistics.getRowCount() * beNumber, + childStatistics.getRowCount() * beNumber * instanceNumber, + childStatistics.getRowCount() * beNumber * instanceNumber); + } + + // gather + if (spec instanceof DistributionSpecGather) { + return CostV1.of( + childStatistics.getRowCount(), + 0, + childStatistics.getRowCount()); + } + + // any + return CostV1.of( + childStatistics.getRowCount(), + 0, + 0); + } + + @Override + public Cost visitPhysicalHashAggregate( + PhysicalHashAggregate aggregate, PlanContext context) { + // TODO: stage..... + + StatsDeriveResult statistics = context.getStatisticsWithCheck(); + StatsDeriveResult inputStatistics = context.getChildStatistics(0); + return CostV1.of(inputStatistics.getRowCount(), statistics.getRowCount(), 0); + } + + @Override + public Cost visitPhysicalHashJoin( + PhysicalHashJoin physicalHashJoin, PlanContext context) { + Preconditions.checkState(context.arity() == 2); + StatsDeriveResult outputStats = context.getStatisticsWithCheck(); + double outputRowCount = outputStats.getRowCount(); + + StatsDeriveResult probeStats = context.getChildStatistics(0); + StatsDeriveResult buildStats = context.getChildStatistics(1); + + double leftRowCount = probeStats.getRowCount(); + double rightRowCount = buildStats.getRowCount(); + /* + pattern1: L join1 (Agg1() join2 Agg2()) + result number of join2 may much less than Agg1. + but Agg1 and Agg2 are slow. so we need to punish this pattern1. + + pattern2: (L join1 Agg1) join2 agg2 + in pattern2, join1 and join2 takes more time, but Agg1 and agg2 can be processed in parallel. + */ + double penalty = HEAVY_OPERATOR_PUNISH_FACTOR + * Math.min(probeStats.getPenalty(), buildStats.getPenalty()); + if (buildStats.getWidth() >= 2) { + //penalty for right deep tree + penalty += rightRowCount; + } + + if (physicalHashJoin.getJoinType().isCrossJoin()) { + return CostV1.of(leftRowCount + rightRowCount + outputRowCount, + 0, + leftRowCount + rightRowCount, + penalty); + } + return CostV1.of(leftRowCount + rightRowCount + outputRowCount, + rightRowCount, + 0, + penalty + ); + } + + @Override + public Cost visitPhysicalNestedLoopJoin( + PhysicalNestedLoopJoin nestedLoopJoin, + PlanContext context) { + // TODO: copy from physicalHashJoin, should update according to physical nested loop join properties. + Preconditions.checkState(context.arity() == 2); + + StatsDeriveResult leftStatistics = context.getChildStatistics(0); + StatsDeriveResult rightStatistics = context.getChildStatistics(1); + + return CostV1.of( + leftStatistics.getRowCount() * rightStatistics.getRowCount(), + rightStatistics.getRowCount(), + 0); + } + + @Override + public Cost visitPhysicalAssertNumRows(PhysicalAssertNumRows assertNumRows, + PlanContext context) { + return CostV1.of( + assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(), + assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(), + 0 + ); + } + + @Override + public Cost visitPhysicalGenerate(PhysicalGenerate generate, PlanContext context) { + StatsDeriveResult statistics = context.getStatisticsWithCheck(); + return CostV1.of( + statistics.getRowCount(), + statistics.getRowCount(), + 0 + ); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java new file mode 100644 index 0000000000..6941bc03e6 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.cost; + +import com.google.common.base.Preconditions; + +class CostV1 implements Cost { + private static final CostV1 INFINITE = new CostV1(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, + Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY); + private static final CostV1 ZERO = new CostV1(0, 0, 0, 0); + + private final double cpuCost; + private final double memoryCost; + private final double networkCost; + //penalty for + // 1. right deep tree + // 2. right XXX join + private final double penalty; + + private final double cost; + + /** + * Constructor of CostEstimate. + */ + public CostV1(double cpuCost, double memoryCost, double networkCost, double penaltiy) { + // TODO: fix stats + cpuCost = Double.max(0, cpuCost); + memoryCost = Double.max(0, memoryCost); + networkCost = Double.max(0, networkCost); + this.cpuCost = cpuCost; + this.memoryCost = memoryCost; + this.networkCost = networkCost; + this.penalty = penaltiy; + + CostWeight costWeight = CostWeight.get(); + this.cost = costWeight.cpuWeight * cpuCost + costWeight.memoryWeight * memoryCost + + costWeight.networkWeight * networkCost + costWeight.penaltyWeight * penalty; + } + + public CostV1(double cost) { + this.cost = cost; + this.cpuCost = 0; + this.networkCost = 0; + this.memoryCost = 0; + this.penalty = 0; + } + + @Override + public Cost minus(Cost other) { + return new CostV1(cost - other.getValue()); + } + + public static CostV1 infinite() { + return INFINITE; + } + + public static CostV1 zero() { + return ZERO; + } + + public double getCpuCost() { + return cpuCost; + } + + public double getMemoryCost() { + return memoryCost; + } + + public double getNetworkCost() { + return networkCost; + } + + public double getPenalty() { + return penalty; + } + + public double getValue() { + return cost; + } + + public static CostV1 of(double cpuCost, double maxMemory, double networkCost, double rightDeepPenaltiy) { + return new CostV1(cpuCost, maxMemory, networkCost, rightDeepPenaltiy); + } + + public static CostV1 of(double cpuCost, double maxMemory, double networkCost) { + return new CostV1(cpuCost, maxMemory, networkCost, 0); + } + + public static CostV1 ofCpu(double cpuCost) { + return new CostV1(cpuCost, 0, 0, 0); + } + + public static CostV1 ofMemory(double memoryCost) { + return new CostV1(0, memoryCost, 0, 0); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append((long) cpuCost).append("/").append((long) memoryCost).append("/").append((long) networkCost) + .append("/").append((long) penalty); + return sb.toString(); + } + + @Override + public int compare(Cost other) { + Preconditions.checkArgument(other instanceof CostV1, "costValueV1 can only compare with costValueV1"); + return Double.compare(cost, ((CostV1) other).cost); + } +} + diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostWeight.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostWeight.java index c154a9bef0..ae23a22234 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostWeight.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostWeight.java @@ -17,16 +17,38 @@ package org.apache.doris.nereids.cost; +import org.apache.doris.qe.ConnectContext; + import com.google.common.base.Preconditions; /** * cost weight. + * The intuition behind `HEAVY_OPERATOR_PUNISH_FACTOR` is we need to avoid this form of join patterns: + * Plan1: L join ( AGG1(A) join AGG2(B)) + * But + * Plan2: L join AGG1(A) join AGG2(B) is welcomed. + * AGG is time-consuming operator. From the perspective of rowCount, nereids may choose Plan1, + * because `Agg1 join Agg2` generates few tuples. But in Plan1, Agg1 and Agg2 are done in serial, in Plan2, Agg1 and + * Agg2 are done in parallel. And hence, Plan1 should be punished. + *

+ * An example is tpch q15. */ public class CostWeight { - private final double cpuWeight; - private final double memoryWeight; - private final double networkWeight; - private final double penaltyWeight; + static final double CPU_WEIGHT = 1; + static final double MEMORY_WEIGHT = 1; + static final double NETWORK_WEIGHT = 1.5; + final double cpuWeight; + final double memoryWeight; + final double networkWeight; + /* + * About PENALTY: + * Except stats information, there are some special criteria in doris. + * For example, in hash join cluster, BE could build hash tables + * in parallel for left deep tree. And hence, we need to punish right deep tree. + * penalyWeight is the factor of punishment. + * The punishment is denoted by stats.penalty. + */ + final double penaltyWeight; /** * Constructor @@ -42,9 +64,8 @@ public class CostWeight { this.penaltyWeight = penaltyWeight; } - public double calculate(CostEstimate costEstimate) { - return costEstimate.getCpuCost() * cpuWeight + costEstimate.getMemoryCost() * memoryWeight - + costEstimate.getNetworkCost() * networkWeight - + costEstimate.getPenalty() * penaltyWeight; + public static CostWeight get() { + return new CostWeight(CPU_WEIGHT, MEMORY_WEIGHT, NETWORK_WEIGHT, + ConnectContext.get().getSessionVariable().getNereidsCboPenaltyFactor()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java index 9a5d410d20..769dab4296 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.jobs.cascades; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.cost.Cost; import org.apache.doris.nereids.cost.CostCalculator; import org.apache.doris.nereids.jobs.Job; import org.apache.doris.nereids.jobs.JobContext; @@ -50,9 +51,9 @@ public class CostAndEnforcerJob extends Job implements Cloneable { private final GroupExpression groupExpression; // cost of current plan tree - private double curTotalCost; + private Cost curTotalCost; // cost of current plan node - private double curNodeCost; + private Cost curNodeCost; // List of request property to children // Example: Physical Hash Join @@ -114,8 +115,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable { countJobExecutionTimesOfGroupExpressions(groupExpression); // Do init logic of root plan/groupExpr of `subplan`, only run once per task. if (curChildIndex == -1) { - curNodeCost = 0; - curTotalCost = 0; + curNodeCost = Cost.zero(); + curTotalCost = Cost.zero(); curChildIndex = 0; // List // [ child item: [leftProperties, rightProperties]] @@ -138,8 +139,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable { // Calculate cost if (curChildIndex == 0 && prevChildIndex == -1) { curNodeCost = CostCalculator.calculateCost(groupExpression); - groupExpression.setCost(curNodeCost); - curTotalCost += curNodeCost; + groupExpression.setCost(curNodeCost.getValue()); + curTotalCost = curNodeCost; } // Handle all child plan node. @@ -149,7 +150,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable { // Whether the child group was optimized for this requestChildProperty according to // the result of returning. - Optional> lowestCostPlanOpt + Optional> lowestCostPlanOpt = childGroup.getLowestCostPlan(requestChildProperty); if (!lowestCostPlanOpt.isPresent()) { @@ -166,7 +167,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable { // Meaning that optimize recursively by derive job. prevChildIndex = curChildIndex; pushJob(clone()); - double newCostUpperBound = context.getCostUpperBound() - curTotalCost; + double newCostUpperBound = context.getCostUpperBound() - curTotalCost.getValue(); JobContext jobContext = new JobContext(context.getCascadesContext(), requestChildProperty, newCostUpperBound); pushJob(new OptimizeGroupJob(childGroup, jobContext)); @@ -182,10 +183,12 @@ public class CostAndEnforcerJob extends Job implements Cloneable { // plan's requestChildProperty).getOutputProperties(current plan's requestChildProperty) == child // plan's outputProperties`, the outputProperties must satisfy the origin requestChildProperty outputChildrenProperties.set(curChildIndex, outputProperties); - - curTotalCost += lowestCostExpr.getLowestCostTable().get(requestChildProperty).first; - if (curTotalCost > context.getCostUpperBound()) { - curTotalCost = Double.POSITIVE_INFINITY; + curTotalCost = CostCalculator.addChildCost(groupExpression.getPlan(), + curNodeCost, + lowestCostExpr.getCostValueByProperties(requestChildProperty), + curChildIndex); + if (curTotalCost.getValue() > context.getCostUpperBound()) { + curTotalCost = Cost.infinite(); } // the request child properties will be covered by the output properties // that corresponding to the request properties. so if we run a costAndEnforceJob of the same @@ -198,8 +201,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable { if (!calculateEnforce(requestChildrenProperties, outputChildrenProperties)) { return; // if error exists, return } - if (curTotalCost < context.getCostUpperBound()) { - context.setCostUpperBound(curTotalCost); + if (curTotalCost.getValue() < context.getCostUpperBound()) { + context.setCostUpperBound(curTotalCost.getValue()); } } clear(); @@ -222,7 +225,6 @@ public class CostAndEnforcerJob extends Job implements Cloneable { // invalid enforce, return. return false; } - curTotalCost += enforceCost; // Not need to do pruning here because it has been done when we get the // best expr from the child group @@ -237,14 +239,18 @@ public class CostAndEnforcerJob extends Job implements Cloneable { return false; } StatsCalculator.estimate(groupExpression); - // previous curTotalCost exclude the exists best cost of current node - curTotalCost -= curNodeCost; + + // recompute cost after adjusting property curNodeCost = CostCalculator.calculateCost(groupExpression); // recompute current node's cost in current context - groupExpression.setCost(curNodeCost); - // (previous curTotalCost) - (previous curNodeCost) + (current curNodeCost) = (current curTotalCost). - // if current curTotalCost maybe less than previous curTotalCost, we will update the lowest cost and plan - // to the grouping expression and the owner group - curTotalCost += curNodeCost; + groupExpression.setCost(curNodeCost.getValue()); + curTotalCost = curNodeCost; + for (int i = 0; i < outputChildrenProperties.size(); i++) { + PhysicalProperties childProperties = outputChildrenProperties.get(i); + curTotalCost = CostCalculator.addChildCost(groupExpression.getPlan(), + curTotalCost, + groupExpression.child(i).getLowestCostPlan(childProperties).get().first, + i); + } // record map { outputProperty -> outputProperty }, { ANY -> outputProperty }, recordPropertyAndCost(groupExpression, outputProperty, PhysicalProperties.ANY, outputChildrenProperties); @@ -303,8 +309,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable { lowestCostChildren.clear(); prevChildIndex = -1; curChildIndex = 0; - curTotalCost = 0; - curNodeCost = 0; + curTotalCost = Cost.zero(); + curNodeCost = Cost.zero(); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java index db64127bf3..e80e1472c4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph; import org.apache.doris.common.Pair; import org.apache.doris.nereids.PlanContext; +import org.apache.doris.nereids.cost.Cost; import org.apache.doris.nereids.cost.CostCalculator; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter; @@ -62,7 +63,7 @@ public class GraphSimplifier { // because it's just used for simulating join. In fact, the graph simplifier // just generate the partial order of join operator. private final HashMap cacheStats = new HashMap<>(); - private final HashMap cacheCost = new HashMap<>(); + private final HashMap cacheCost = new HashMap<>(); private final Stack appliedSteps = new Stack<>(); private final Stack unAppliedSteps = new Stack<>(); @@ -81,7 +82,7 @@ public class GraphSimplifier { } for (Node node : graph.getNodes()) { cacheStats.put(node.getNodeMap(), node.getGroup().getStatistics()); - cacheCost.put(node.getNodeMap(), node.getGroup().getStatistics().getRowCount()); + cacheCost.put(node.getNodeMap(), Cost.withRowCount(node.getRowCount())); } circleDetector = new CircleDetector(edgeSize); @@ -382,26 +383,26 @@ public class GraphSimplifier { private SimplificationStep orderJoin(Pair edge1Before2, Pair edge2Before1, int edgeIndex1, int edgeIndex2) { - double cost1Before2 = calCost(edge1Before2.second, edge1Before2.first, + Cost cost1Before2 = calCost(edge1Before2.second, edge1Before2.first, cacheStats.get(edge1Before2.second.getLeft()), cacheStats.get(edge1Before2.second.getRight())); - double cost2Before1 = calCost(edge2Before1.second, edge1Before2.first, + Cost cost2Before1 = calCost(edge2Before1.second, edge1Before2.first, cacheStats.get(edge1Before2.second.getLeft()), cacheStats.get(edge1Before2.second.getRight())); double benefit = Double.MAX_VALUE; SimplificationStep step; // Choose the plan with smaller cost and make the simplification step to replace the old edge by it. - if (cost1Before2 < cost2Before1) { - if (cost1Before2 != 0) { - benefit = cost2Before1 / cost1Before2; + if (cost1Before2.getValue() < cost2Before1.getValue()) { + if (cost1Before2.getValue() != 0) { + benefit = cost2Before1.getValue() / cost1Before2.getValue(); } // choose edge1Before2 step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2, edge1Before2.second.getLeft(), edge1Before2.second.getRight(), graph.getEdge(edgeIndex2).getLeft(), graph.getEdge(edgeIndex2).getRight()); } else { - if (cost2Before1 != 0) { - benefit = cost1Before2 / cost2Before1; + if (cost2Before1.getValue() != 0) { + benefit = cost1Before2.getValue() / cost2Before1.getValue(); } // choose edge2Before1 step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.second.getLeft(), @@ -422,11 +423,11 @@ public class GraphSimplifier { return false; } - private double calCost(Edge edge, StatsDeriveResult stats, + private Cost calCost(Edge edge, StatsDeriveResult stats, StatsDeriveResult leftStats, StatsDeriveResult rightStats) { LogicalJoin join = edge.getJoin(); PlanContext planContext = new PlanContext(stats, leftStats, rightStats); - double cost = 0; + Cost cost = Cost.zero(); if (JoinUtils.shouldNestedLoopJoin(join)) { PhysicalNestedLoopJoin nestedLoopJoin = new PhysicalNestedLoopJoin<>( join.getJoinType(), @@ -437,6 +438,8 @@ public class GraphSimplifier { join.left(), join.right()); cost = CostCalculator.calculateCost(nestedLoopJoin, planContext); + cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getLeft()), 0); + cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getRight()), 1); } else { PhysicalHashJoin hashJoin = new PhysicalHashJoin<>( join.getJoinType(), @@ -448,8 +451,11 @@ public class GraphSimplifier { join.left(), join.right()); cost = CostCalculator.calculateCost(hashJoin, planContext); + cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getLeft()), 0); + cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getRight()), 1); } - return cost + cacheCost.get(edge.getLeft()) + cacheCost.get(edge.getRight()); + + return cost; } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java index 88e20f34e3..36fe488e6c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.memo; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.cost.Cost; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.trees.plans.JoinType; @@ -60,7 +61,7 @@ public class Group { // Map of cost lower bounds // Map required plan props to cost lower bound of corresponding plan - private final Map> lowestCostPlans = Maps.newHashMap(); + private final Map> lowestCostPlans = Maps.newHashMap(); private boolean isExplored = false; @@ -189,7 +190,7 @@ public class Group { * @param physicalProperties the physical property constraints * @return {@link Optional} of cost and {@link GroupExpression} of physical plan pair. */ - public Optional> getLowestCostPlan(PhysicalProperties physicalProperties) { + public Optional> getLowestCostPlan(PhysicalProperties physicalProperties) { if (physicalProperties == null || lowestCostPlans.isEmpty()) { return Optional.empty(); } @@ -206,9 +207,9 @@ public class Group { /** * Set or update lowestCostPlans: properties --> Pair.of(cost, expression) */ - public void setBestPlan(GroupExpression expression, double cost, PhysicalProperties properties) { + public void setBestPlan(GroupExpression expression, Cost cost, PhysicalProperties properties) { if (lowestCostPlans.containsKey(properties)) { - if (lowestCostPlans.get(properties).first > cost) { + if (lowestCostPlans.get(properties).first.getValue() > cost.getValue()) { lowestCostPlans.put(properties, Pair.of(cost, expression)); } } else { @@ -219,8 +220,9 @@ public class Group { /** * replace best plan with new properties */ - public void replaceBestPlanProperty(PhysicalProperties oldProperty, PhysicalProperties newProperty, double cost) { - Pair pair = lowestCostPlans.get(oldProperty); + public void replaceBestPlanProperty(PhysicalProperties oldProperty, + PhysicalProperties newProperty, Cost cost) { + Pair pair = lowestCostPlans.get(oldProperty); GroupExpression lowestGroupExpr = pair.second; lowestGroupExpr.updateLowestCostTable(newProperty, lowestGroupExpr.getInputPropertiesList(oldProperty), cost); @@ -232,11 +234,11 @@ public class Group { * replace oldGroupExpression with newGroupExpression in lowestCostPlans. */ public void replaceBestPlanGroupExpr(GroupExpression oldGroupExpression, GroupExpression newGroupExpression) { - Map> needReplaceBestExpressions = Maps.newHashMap(); - for (Iterator>> iterator = + Map> needReplaceBestExpressions = Maps.newHashMap(); + for (Iterator>> iterator = lowestCostPlans.entrySet().iterator(); iterator.hasNext(); ) { - Map.Entry> entry = iterator.next(); - Pair pair = entry.getValue(); + Map.Entry> entry = iterator.next(); + Pair pair = entry.getValue(); if (pair.second.equals(oldGroupExpression)) { needReplaceBestExpressions.put(entry.getKey(), Pair.of(pair.first, newGroupExpression)); iterator.remove(); @@ -345,7 +347,8 @@ public class Group { if (!target.lowestCostPlans.containsKey(physicalProperties)) { target.lowestCostPlans.put(physicalProperties, costAndGroupExpr); } else { - if (costAndGroupExpr.first < target.lowestCostPlans.get(physicalProperties).first) { + if (costAndGroupExpr.first.getValue() + < target.lowestCostPlans.get(physicalProperties).first.getValue()) { target.lowestCostPlans.put(physicalProperties, costAndGroupExpr); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java index f1374e2d80..4555935f25 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java @@ -18,7 +18,7 @@ package org.apache.doris.nereids.memo; import org.apache.doris.common.Pair; -import org.apache.doris.nereids.cost.CostEstimate; +import org.apache.doris.nereids.cost.Cost; import org.apache.doris.nereids.metrics.EventChannel; import org.apache.doris.nereids.metrics.EventProducer; import org.apache.doris.nereids.metrics.consumer.LogConsumer; @@ -48,7 +48,6 @@ public class GroupExpression { EventChannel.getDefaultChannel().addConsumers(new LogConsumer(CostStateUpdateEvent.class, EventChannel.LOG))); private double cost = 0.0; - private CostEstimate costEstimate = null; private Group ownerGroup; private final List children; private final Plan plan; @@ -60,7 +59,7 @@ public class GroupExpression { // Mapping from output properties to the corresponding best cost, statistics, and child properties. // key is the physical properties the group expression support for its parent // and value is cost and request physical properties to its children. - private final Map>> lowestCostTable; + private final Map>> lowestCostTable; // Each physical group expression maintains mapping incoming requests to the corresponding child requests. // key is the output physical properties satisfying the incoming request properties // value is the request physical properties @@ -182,7 +181,7 @@ public class GroupExpression { this.isUnused = isUnused; } - public Map>> getLowestCostTable() { + public Map>> getLowestCostTable() { return lowestCostTable; } @@ -198,10 +197,10 @@ public class GroupExpression { * @return true if lowest cost table change. */ public boolean updateLowestCostTable(PhysicalProperties outputProperties, - List childrenInputProperties, double cost) { - COST_STATE_TRACER.log(CostStateUpdateEvent.of(this, cost, outputProperties)); + List childrenInputProperties, Cost cost) { + COST_STATE_TRACER.log(CostStateUpdateEvent.of(this, cost.getValue(), outputProperties)); if (lowestCostTable.containsKey(outputProperties)) { - if (lowestCostTable.get(outputProperties).first > cost) { + if (lowestCostTable.get(outputProperties).first.getValue() > cost.getValue()) { lowestCostTable.put(outputProperties, Pair.of(cost, childrenInputProperties)); return true; } else { @@ -220,6 +219,11 @@ public class GroupExpression { * @return Lowest cost to satisfy that property */ public double getCostByProperties(PhysicalProperties property) { + Preconditions.checkState(lowestCostTable.containsKey(property)); + return lowestCostTable.get(property).first.getValue(); + } + + public Cost getCostValueByProperties(PhysicalProperties property) { Preconditions.checkState(lowestCostTable.containsKey(property)); return lowestCostTable.get(property).first; } @@ -263,14 +267,6 @@ public class GroupExpression { this.cost = cost; } - public CostEstimate getCostEstimate() { - return costEstimate; - } - - public void setCostEstimate(CostEstimate costEstimate) { - this.costEstimate = costEstimate; - } - @Override public boolean equals(Object o) { if (this == o) { @@ -310,9 +306,8 @@ public class GroupExpression { builder.append("#").append(ownerGroup.getGroupId().asInt()); } - if (costEstimate != null) { - builder.append(" cost=").append((long) cost).append(" (").append(costEstimate).append(")"); - } + builder.append(" cost=").append((long) cost); + builder.append(" estRows=").append(estOutputRowCount); builder.append(" (plan=").append(plan.toString()).append(") children=["); for (Group group : children) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java index b27d9a1f59..acb2cf1296 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java @@ -22,6 +22,8 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.analyzer.CTEContext; +import org.apache.doris.nereids.cost.Cost; +import org.apache.doris.nereids.cost.CostCalculator; import org.apache.doris.nereids.metrics.EventChannel; import org.apache.doris.nereids.metrics.EventProducer; import org.apache.doris.nereids.metrics.consumer.LogConsumer; @@ -765,26 +767,27 @@ public class Memo { return res; } - double bestChildrenCost = 0; List>> children = new ArrayList<>(); for (int i = 0; i < inputProperties.size(); i++) { // To avoid reach a circle, we don't allow ranking the same group with the same physical properties. Preconditions.checkArgument(!groupExpression.child(i).equals(groupExpression.getOwnerGroup()) || !prop.equals(inputProperties.get(i))); - bestChildrenCost += groupExpression.children().get(i).getLowestCostPlan(inputProperties.get(i)).get().first; List> idCostPair = rankGroup(groupExpression.child(i), inputProperties.get(i)); children.add(idCostPair); } List>> childrenId = new ArrayList<>(); permute(children, 0, childrenId, new ArrayList<>()); + Cost cost = CostCalculator.calculateCost(groupExpression); for (Pair> c : childrenId) { - double childCost = 0; + Cost totalCost = cost; for (int i = 0; i < children.size(); i++) { - childCost += children.get(i).get(c.second.get(i)).second; + totalCost = CostCalculator.addChildCost(groupExpression.getPlan(), + totalCost, + groupExpression.child(i).getLowestCostPlan(inputProperties.get(i)).get().first, + i); } - res.add(Pair.of(c.first, - childCost + groupExpression.getCostByProperties(prop) - bestChildrenCost)); + res.add(Pair.of(c.first, totalCost.getValue())); } return res; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java index 47753e389e..8aed9504e1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.properties; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.cost.Cost; import org.apache.doris.nereids.cost.CostCalculator; import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.memo.GroupExpression; @@ -103,12 +104,12 @@ public class ChildrenPropertiesRegulator extends PlanVisitor { DistributionSpecHash rightHashSpec = (DistributionSpecHash) rightDistributionSpec; GroupExpression leftChild = children.get(0); - final Pair> leftLowest + final Pair> leftLowest = leftChild.getLowestCostTable().get(childrenProperties.get(0)); PhysicalProperties leftOutput = leftChild.getOutputProperties(childrenProperties.get(0)); GroupExpression rightChild = children.get(1); - Pair> rightLowest + Pair> rightLowest = rightChild.getLowestCostTable().get(childrenProperties.get(1)); PhysicalProperties rightOutput = rightChild.getOutputProperties(childrenProperties.get(1)); @@ -128,20 +129,20 @@ public class ChildrenPropertiesRegulator extends PlanVisitor { PhysicalProperties rightRequireProperties = calRightRequiredOfBucketShuffleJoin(leftHashSpec, rightHashSpec); if (!rightOutput.equals(rightRequireProperties)) { - enforceCost += updateChildEnforceAndCost(rightChild, rightOutput, + updateChildEnforceAndCost(rightChild, rightOutput, (DistributionSpecHash) rightRequireProperties.getDistributionSpec(), rightLowest.first); } childrenProperties.set(1, rightRequireProperties); return enforceCost; } - enforceCost += updateChildEnforceAndCost(leftChild, leftOutput, + updateChildEnforceAndCost(leftChild, leftOutput, (DistributionSpecHash) requiredProperties.get(0).getDistributionSpec(), leftLowest.first); childrenProperties.set(0, requiredProperties.get(0)); } // check right hand must distribute. if (rightHashSpec.getShuffleType() != ShuffleType.ENFORCED) { - enforceCost += updateChildEnforceAndCost(rightChild, rightOutput, + updateChildEnforceAndCost(rightChild, rightOutput, (DistributionSpecHash) requiredProperties.get(1).getDistributionSpec(), rightLowest.first); childrenProperties.set(1, requiredProperties.get(1)); } @@ -175,15 +176,13 @@ public class ChildrenPropertiesRegulator extends PlanVisitor { } private double updateChildEnforceAndCost(GroupExpression child, PhysicalProperties childOutput, - DistributionSpecHash required, double currentCost) { - double enforceCost = 0; + DistributionSpecHash required, Cost currentCost) { if (child.getPlan() instanceof PhysicalDistribute) { //To avoid continuous distribute operator, we just enforce the child's child childOutput = child.getInputPropertiesList(childOutput).get(0); - Pair newChildAndCost + Pair newChildAndCost = child.getOwnerGroup().getLowestCostPlan(childOutput).get(); child = newChildAndCost.second; - enforceCost = newChildAndCost.first - currentCost; currentCost = newChildAndCost.first; } @@ -193,13 +192,16 @@ public class ChildrenPropertiesRegulator extends PlanVisitor { PhysicalProperties newOutputProperty = new PhysicalProperties(outputDistributionSpec); GroupExpression enforcer = outputDistributionSpec.addEnforcer(child.getOwnerGroup()); jobContext.getCascadesContext().getMemo().addEnforcerPlan(enforcer, child.getOwnerGroup()); - enforceCost = Double.sum(enforceCost, CostCalculator.calculateCost(enforcer)); + Cost totalCost = CostCalculator.addChildCost(enforcer.getPlan(), + currentCost, + CostCalculator.calculateCost(enforcer), + 0); if (enforcer.updateLowestCostTable(newOutputProperty, - Lists.newArrayList(childOutput), enforceCost + currentCost)) { + Lists.newArrayList(childOutput), totalCost)) { enforcer.putOutputPropertiesMap(newOutputProperty, newOutputProperty); } - child.getOwnerGroup().setBestPlan(enforcer, enforceCost + currentCost, newOutputProperty); + child.getOwnerGroup().setBestPlan(enforcer, totalCost, newOutputProperty); return enforceCost; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/EnforceMissingPropertiesHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/EnforceMissingPropertiesHelper.java index 7b9ba19ccb..d38e9b8c2d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/EnforceMissingPropertiesHelper.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/EnforceMissingPropertiesHelper.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.properties; +import org.apache.doris.nereids.cost.Cost; import org.apache.doris.nereids.cost.CostCalculator; import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.memo.GroupExpression; @@ -38,16 +39,16 @@ public class EnforceMissingPropertiesHelper { EventChannel.getDefaultChannel().addConsumers(new LogConsumer(EnforcerEvent.class, EventChannel.LOG))); private final JobContext context; private final GroupExpression groupExpression; - private double curTotalCost; + private Cost curTotalCost; public EnforceMissingPropertiesHelper(JobContext context, GroupExpression groupExpression, - double curTotalCost) { + Cost curTotalCost) { this.context = context; this.groupExpression = groupExpression; this.curTotalCost = curTotalCost; } - public double getCurTotalCost() { + public Cost getCurTotalCost() { return curTotalCost; } @@ -83,7 +84,8 @@ public class EnforceMissingPropertiesHelper { */ private PhysicalProperties enforceDistributionButMeetSort(PhysicalProperties output, PhysicalProperties request) { groupExpression.getOwnerGroup() - .replaceBestPlanProperty(output, PhysicalProperties.ANY, groupExpression.getCostByProperties(output)); + .replaceBestPlanProperty( + output, PhysicalProperties.ANY, groupExpression.getCostValueByProperties(output)); return enforceSortAndDistribution(output, request); } @@ -149,8 +151,10 @@ public class EnforceMissingPropertiesHelper { context.getCascadesContext().getMemo().addEnforcerPlan(enforcer, groupExpression.getOwnerGroup()); ENFORCER_TRACER.log(EnforcerEvent.of(groupExpression, ((PhysicalPlan) enforcer.getPlan()), oldOutputProperty, newOutputProperty)); - curTotalCost += CostCalculator.calculateCost(enforcer); - + curTotalCost = CostCalculator.addChildCost(enforcer.getPlan(), + CostCalculator.calculateCost(enforcer), + curTotalCost, + 0); if (enforcer.updateLowestCostTable(newOutputProperty, Lists.newArrayList(oldOutputProperty), curTotalCost)) { enforcer.putOutputPropertiesMap(newOutputProperty, newOutputProperty); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java index 40d9d6289a..ee270eccac 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java @@ -21,6 +21,7 @@ import org.apache.doris.catalog.OlapTable; import org.apache.doris.common.jmockit.Deencapsulation; import org.apache.doris.nereids.analyzer.UnboundRelation; import org.apache.doris.nereids.analyzer.UnboundSlot; +import org.apache.doris.nereids.cost.Cost; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.properties.UnboundLogicalProperties; @@ -91,7 +92,7 @@ class MemoTest implements MemoPatternMatchSupported { FakePlan fakePlan = new FakePlan(); GroupExpression srcParentExpression = new GroupExpression(fakePlan, Lists.newArrayList(srcGroup)); Group srcParentGroup = new Group(new GroupId(0), srcParentExpression, new LogicalProperties(ArrayList::new)); - srcParentGroup.setBestPlan(srcParentExpression, Double.MIN_VALUE, PhysicalProperties.ANY); + srcParentGroup.setBestPlan(srcParentExpression, Cost.zero(), PhysicalProperties.ANY); GroupExpression dstParentExpression = new GroupExpression(fakePlan, Lists.newArrayList(dstGroup)); Group dstParentGroup = new Group(new GroupId(1), dstParentExpression, new LogicalProperties(ArrayList::new));