diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
index 33025b2559..6f35193797 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java
@@ -267,7 +267,7 @@ public class NereidsPlanner extends Planner {
if (nthPlan <= 1) {
cost = rootGroup.getLowestCostPlan(physicalProperties).orElseThrow(
() -> new AnalysisException("lowestCostPlans with physicalProperties("
- + physicalProperties + ") doesn't exist in root group")).first;
+ + physicalProperties + ") doesn't exist in root group")).first.getValue();
return chooseBestPlan(rootGroup, physicalProperties);
}
Memo memo = cascadesContext.getMemo();
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java
new file mode 100644
index 0000000000..c616b5a858
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.cost;
+
+/**
+ * Cost encapsulates the real cost, which is exposed as a double value.
+ * We do this because we want to customize the operation of adding child cost
+ * according to different operators.
+ */
+public interface Cost {
+    /**
+     * Returns this cost as a single double value.
+     */
+    double getValue();
+
+    /**
+     * Compares this cost with {@code other}.
+     * NOTE(review): the sign convention of the result is defined by the implementing class; confirm there.
+     */
+    int compare(Cost other);
+
+    /**
+     * Returns a cost derived from this cost and {@code other}.
+     * NOTE(review): presumably {@code this - other}; verify against the implementing class.
+     */
+    Cost minus(Cost other);
+
+    /**
+     * Creates a cost from a row count. The row count is passed as the first (cpu cost) argument of CostV1,
+     * the remaining three arguments are zero.
+     */
+    static Cost withRowCount(double rowCount) {
+        return new CostV1(rowCount, 0, 0, 0);
+    }
+
+    /**
+     * Returns the zero cost provided by CostV1.
+     */
+    static Cost zero() {
+        return CostV1.zero();
+    }
+
+    /**
+     * Returns the infinite cost provided by CostV1.
+     */
+    static Cost infinite() {
+        return CostV1.infinite();
+    }
+}
+
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java
index 6a805cd665..36c4dcfec7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java
@@ -20,31 +20,7 @@ package org.apache.doris.nereids.cost;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.memo.GroupExpression;
-import org.apache.doris.nereids.properties.DistributionSpec;
-import org.apache.doris.nereids.properties.DistributionSpecGather;
-import org.apache.doris.nereids.properties.DistributionSpecHash;
-import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalFileScan;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalGenerate;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalJdbcScan;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalSchemaScan;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
-import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
-import org.apache.doris.qe.ConnectContext;
-import org.apache.doris.statistics.StatsDeriveResult;
-
-import com.google.common.base.Preconditions;
/**
* Calculate the cost of a plan.
@@ -53,264 +29,21 @@ import com.google.common.base.Preconditions;
@Developing
//TODO: memory cost and network cost should be estimated by byte size.
public class CostCalculator {
- static final double CPU_WEIGHT = 1;
- static final double MEMORY_WEIGHT = 1;
- static final double NETWORK_WEIGHT = 1.5;
-
- /**
- * The intuition behind `HEAVY_OPERATOR_PUNISH_FACTOR` is we need to avoid this form of join patterns:
- * Plan1: L join ( AGG1(A) join AGG2(B))
- * But
- * Plan2: L join AGG1(A) join AGG2(B) is welcomed.
- * AGG is time-consuming operator. From the perspective of rowCount, nereids may choose Plan1,
- * because `Agg1 join Agg2` generates few tuples. But in Plan1, Agg1 and Agg2 are done in serial, in Plan2, Agg1 and
- * Agg2 are done in parallel. And hence, Plan1 should be punished.
- *
- * An example is tpch q15.
- */
- static final double HEAVY_OPERATOR_PUNISH_FACTOR = 6.0;
-
- /**
- * Constructor.
- */
+ /**
+ * Calculate the cost of the plan held by the given group expression.
+ */
- public static double calculateCost(GroupExpression groupExpression) {
+ public static Cost calculateCost(GroupExpression groupExpression) {
PlanContext planContext = new PlanContext(groupExpression);
- CostEstimator costCalculator = new CostEstimator();
- CostEstimate costEstimate = groupExpression.getPlan().accept(costCalculator, planContext);
- groupExpression.setCostEstimate(costEstimate);
- /*
- * About PENALTY:
- * Except stats information, there are some special criteria in doris.
- * For example, in hash join cluster, BE could build hash tables
- * in parallel for left deep tree. And hence, we need to punish right deep tree.
- * penalyWeight is the factor of punishment.
- * The punishment is denoted by stats.penalty.
- */
- CostWeight costWeight = new CostWeight(CPU_WEIGHT, MEMORY_WEIGHT, NETWORK_WEIGHT,
- ConnectContext.get().getSessionVariable().getNereidsCboPenaltyFactor());
- return costWeight.calculate(costEstimate);
+ // A new CostModelV1 visitor is created for every call and handed to the plan together with the
+ // PlanContext built from this group expression; the visitor's result is returned unchanged.
+ // NOTE(review): the removed code also called groupExpression.setCostEstimate(...); this version has no
+ // replacement for that side effect - confirm nothing still depends on it.
+ CostModelV1 costModel = new CostModelV1();
+ return groupExpression.getPlan().accept(costModel, planContext);
}
- public static double calculateCost(Plan plan, PlanContext planContext) {
- CostEstimator costCalculator = new CostEstimator();
- CostEstimate costEstimate = plan.accept(costCalculator, planContext);
- CostWeight costWeight = new CostWeight(CPU_WEIGHT, MEMORY_WEIGHT, NETWORK_WEIGHT,
- ConnectContext.get().getSessionVariable().getNereidsCboPenaltyFactor());
- return costWeight.calculate(costEstimate);
+ /**
+ * Calculate the cost of {@code plan} by letting it accept a new CostModelV1 visitor,
+ * with {@code planContext} passed through as the visitor context.
+ */
+ public static Cost calculateCost(Plan plan, PlanContext planContext) {
+ CostModelV1 costModel = new CostModelV1();
+ return plan.accept(costModel, planContext);
}
- private static class CostEstimator extends PlanVisitor {
- @Override
- public CostEstimate visit(Plan plan, PlanContext context) {
- return CostEstimate.zero();
- }
-
- @Override
- public CostEstimate visitPhysicalOlapScan(PhysicalOlapScan physicalOlapScan, PlanContext context) {
- StatsDeriveResult statistics = context.getStatisticsWithCheck();
- return CostEstimate.ofCpu(statistics.getRowCount());
- }
-
- public CostEstimate visitPhysicalSchemaScan(PhysicalSchemaScan physicalSchemaScan, PlanContext context) {
- StatsDeriveResult statistics = context.getStatisticsWithCheck();
- return CostEstimate.ofCpu(statistics.getRowCount());
- }
-
- @Override
- public CostEstimate visitPhysicalStorageLayerAggregate(
- PhysicalStorageLayerAggregate storageLayerAggregate, PlanContext context) {
- CostEstimate costEstimate = storageLayerAggregate.getRelation().accept(this, context);
- // multiply a factor less than 1, so we can select PhysicalStorageLayerAggregate as far as possible
- return new CostEstimate(costEstimate.getCpuCost() * 0.7, costEstimate.getMemoryCost(),
- costEstimate.getNetworkCost(), costEstimate.getPenalty());
- }
-
- @Override
- public CostEstimate visitPhysicalFileScan(PhysicalFileScan physicalFileScan, PlanContext context) {
- StatsDeriveResult statistics = context.getStatisticsWithCheck();
- return CostEstimate.ofCpu(statistics.getRowCount());
- }
-
- @Override
- public CostEstimate visitPhysicalProject(PhysicalProject extends Plan> physicalProject, PlanContext context) {
- return CostEstimate.ofCpu(1);
- }
-
- @Override
- public CostEstimate visitPhysicalJdbcScan(PhysicalJdbcScan physicalJdbcScan, PlanContext context) {
- StatsDeriveResult statistics = context.getStatisticsWithCheck();
- return CostEstimate.ofCpu(statistics.getRowCount());
- }
-
- @Override
- public CostEstimate visitPhysicalEsScan(PhysicalEsScan physicalEsScan, PlanContext context) {
- StatsDeriveResult statistics = context.getStatisticsWithCheck();
- return CostEstimate.ofCpu(statistics.getRowCount());
- }
-
- @Override
- public CostEstimate visitPhysicalQuickSort(
- PhysicalQuickSort extends Plan> physicalQuickSort, PlanContext context) {
- // TODO: consider two-phase sort and enforcer.
- StatsDeriveResult statistics = context.getStatisticsWithCheck();
- StatsDeriveResult childStatistics = context.getChildStatistics(0);
- if (physicalQuickSort.getSortPhase().isGather()) {
- // Now we do more like two-phase sort, so penalise one-phase sort
- statistics.updateRowCount(statistics.getRowCount() * 100);
- }
- return CostEstimate.of(
- childStatistics.getRowCount(),
- statistics.getRowCount(),
- childStatistics.getRowCount());
- }
-
- @Override
- public CostEstimate visitPhysicalTopN(PhysicalTopN extends Plan> topN, PlanContext context) {
- // TODO: consider two-phase sort and enforcer.
- StatsDeriveResult statistics = context.getStatisticsWithCheck();
- StatsDeriveResult childStatistics = context.getChildStatistics(0);
- if (topN.getSortPhase().isGather()) {
- // Now we do more like two-phase sort, so penalise one-phase sort
- statistics.updateRowCount(statistics.getRowCount() * 100);
- }
- return CostEstimate.of(
- childStatistics.getRowCount(),
- statistics.getRowCount(),
- childStatistics.getRowCount());
- }
-
- @Override
- public CostEstimate visitPhysicalDistribute(
- PhysicalDistribute extends Plan> distribute, PlanContext context) {
- StatsDeriveResult childStatistics = context.getChildStatistics(0);
- DistributionSpec spec = distribute.getDistributionSpec();
- // shuffle
- if (spec instanceof DistributionSpecHash) {
- return CostEstimate.of(
- childStatistics.getRowCount(),
- 0,
- childStatistics.getRowCount());
- }
-
- // replicate
- if (spec instanceof DistributionSpecReplicated) {
- int beNumber = ConnectContext.get().getEnv().getClusterInfo().getBackendIds(true).size();
- int instanceNumber = ConnectContext.get().getSessionVariable().getParallelExecInstanceNum();
- beNumber = Math.max(1, beNumber);
- double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte();
- //if build side is big, avoid use broadcast join
- double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit();
- double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage();
- double buildSize = childStatistics.computeSize();
- if (buildSize * instanceNumber > memLimit * brMemlimit
- || childStatistics.getRowCount() > rowsLimit) {
- return CostEstimate.of(Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE);
- }
- return CostEstimate.of(
- childStatistics.getRowCount() * beNumber,
- childStatistics.getRowCount() * beNumber * instanceNumber,
- childStatistics.getRowCount() * beNumber * instanceNumber);
- }
-
- // gather
- if (spec instanceof DistributionSpecGather) {
- return CostEstimate.of(
- childStatistics.getRowCount(),
- 0,
- childStatistics.getRowCount());
- }
-
- // any
- return CostEstimate.of(
- childStatistics.getRowCount(),
- 0,
- 0);
- }
-
- @Override
- public CostEstimate visitPhysicalHashAggregate(
- PhysicalHashAggregate extends Plan> aggregate, PlanContext context) {
- // TODO: stage.....
-
- StatsDeriveResult statistics = context.getStatisticsWithCheck();
- StatsDeriveResult inputStatistics = context.getChildStatistics(0);
- return CostEstimate.of(inputStatistics.getRowCount(), statistics.getRowCount(), 0);
- }
-
- @Override
- public CostEstimate visitPhysicalHashJoin(
- PhysicalHashJoin extends Plan, ? extends Plan> physicalHashJoin, PlanContext context) {
- Preconditions.checkState(context.arity() == 2);
- StatsDeriveResult outputStats = context.getStatisticsWithCheck();
- double outputRowCount = outputStats.getRowCount();
-
- StatsDeriveResult probeStats = context.getChildStatistics(0);
- StatsDeriveResult buildStats = context.getChildStatistics(1);
-
- double leftRowCount = probeStats.getRowCount();
- double rightRowCount = buildStats.getRowCount();
- /*
- pattern1: L join1 (Agg1() join2 Agg2())
- result number of join2 may much less than Agg1.
- but Agg1 and Agg2 are slow. so we need to punish this pattern1.
-
- pattern2: (L join1 Agg1) join2 agg2
- in pattern2, join1 and join2 takes more time, but Agg1 and agg2 can be processed in parallel.
- */
- double penalty = CostCalculator.HEAVY_OPERATOR_PUNISH_FACTOR
- * Math.min(probeStats.getPenalty(), buildStats.getPenalty());
- if (buildStats.getWidth() >= 2) {
- //penalty for right deep tree
- penalty += rightRowCount;
- }
-
- if (physicalHashJoin.getJoinType().isCrossJoin()) {
- return CostEstimate.of(leftRowCount + rightRowCount + outputRowCount,
- 0,
- leftRowCount + rightRowCount,
- penalty);
- }
- return CostEstimate.of(leftRowCount + rightRowCount + outputRowCount,
- rightRowCount,
- 0,
- penalty
- );
- }
-
- @Override
- public CostEstimate visitPhysicalNestedLoopJoin(
- PhysicalNestedLoopJoin extends Plan, ? extends Plan> nestedLoopJoin,
- PlanContext context) {
- // TODO: copy from physicalHashJoin, should update according to physical nested loop join properties.
- Preconditions.checkState(context.arity() == 2);
-
- StatsDeriveResult leftStatistics = context.getChildStatistics(0);
- StatsDeriveResult rightStatistics = context.getChildStatistics(1);
-
- return CostEstimate.of(
- leftStatistics.getRowCount() * rightStatistics.getRowCount(),
- rightStatistics.getRowCount(),
- 0);
- }
-
- @Override
- public CostEstimate visitPhysicalAssertNumRows(PhysicalAssertNumRows extends Plan> assertNumRows,
- PlanContext context) {
- return CostEstimate.of(
- assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(),
- assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(),
- 0
- );
- }
-
- @Override
- public CostEstimate visitPhysicalGenerate(PhysicalGenerate extends Plan> generate, PlanContext context) {
- StatsDeriveResult statistics = context.getStatisticsWithCheck();
- return CostEstimate.of(
- statistics.getRowCount(),
- statistics.getRowCount(),
- 0
- );
- }
+ /**
+ * Combine the cost of a plan with the cost of one of its children.
+ * Delegates to CostModelV1.addChildCost with the same arguments.
+ * NOTE(review): {@code index} is presumably the position of the child under {@code plan} - confirm with callers.
+ */
+ public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) {
+ return CostModelV1.addChildCost(plan, planCost, childCost, index);
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostEstimate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostEstimate.java
deleted file mode 100644
index 63600f4466..0000000000
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostEstimate.java
+++ /dev/null
@@ -1,128 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.cost;
-
-import com.google.common.base.Preconditions;
-
-/**
- * Use for estimating the cost of plan.
- */
-public final class CostEstimate {
- private static final CostEstimate INFINITE =
- new CostEstimate(Double.POSITIVE_INFINITY,
- Double.POSITIVE_INFINITY,
- Double.POSITIVE_INFINITY,
- Double.POSITIVE_INFINITY);
- private static final CostEstimate ZERO = new CostEstimate(0, 0, 0, 0);
-
- private final double cpuCost;
- private final double memoryCost;
- private final double networkCost;
- //penalty for
- // 1. right deep tree
- // 2. right XXX join
- private final double penalty;
-
- /**
- * Constructor of CostEstimate.
- */
- public CostEstimate(double cpuCost, double memoryCost, double networkCost, double penaltiy) {
- // TODO: fix stats
- if (cpuCost < 0) {
- cpuCost = 0;
- }
- if (memoryCost < 0) {
- memoryCost = 0;
- }
- if (networkCost < 0) {
- networkCost = 0;
- }
- Preconditions.checkArgument(!(cpuCost < 0), "cpuCost cannot be negative: %s", cpuCost);
- Preconditions.checkArgument(!(memoryCost < 0), "memoryCost cannot be negative: %s", memoryCost);
- Preconditions.checkArgument(!(networkCost < 0), "networkCost cannot be negative: %s", networkCost);
- this.cpuCost = cpuCost;
- this.memoryCost = memoryCost;
- this.networkCost = networkCost;
- this.penalty = penaltiy;
- }
-
- public static CostEstimate infinite() {
- return INFINITE;
- }
-
- public static CostEstimate zero() {
- return ZERO;
- }
-
- public double getCpuCost() {
- return cpuCost;
- }
-
- public double getMemoryCost() {
- return memoryCost;
- }
-
- public double getNetworkCost() {
- return networkCost;
- }
-
- public double getPenalty() {
- return penalty;
- }
-
- public static CostEstimate of(double cpuCost, double maxMemory, double networkCost, double rightDeepPenaltiy) {
- return new CostEstimate(cpuCost, maxMemory, networkCost, rightDeepPenaltiy);
- }
-
- public static CostEstimate of(double cpuCost, double maxMemory, double networkCost) {
- return new CostEstimate(cpuCost, maxMemory, networkCost, 0);
- }
-
- public static CostEstimate ofCpu(double cpuCost) {
- return new CostEstimate(cpuCost, 0, 0, 0);
- }
-
- public static CostEstimate ofMemory(double memoryCost) {
- return new CostEstimate(0, memoryCost, 0, 0);
- }
-
- /**
- * sum of cost estimate
- */
- public static CostEstimate sum(CostEstimate one, CostEstimate two, CostEstimate... more) {
- double cpuCostSum = one.cpuCost + two.cpuCost;
- double memoryCostSum = one.memoryCost + two.memoryCost;
- double networkCostSum = one.networkCost + one.networkCost;
- for (CostEstimate costEstimate : more) {
- cpuCostSum += costEstimate.cpuCost;
- memoryCostSum += costEstimate.memoryCost;
- networkCostSum += costEstimate.networkCost;
- }
- return CostEstimate.of(cpuCostSum, memoryCostSum, networkCostSum);
- }
-
- @Override
- public String toString() {
- StringBuilder sb = new StringBuilder();
- sb.append((long) cpuCost)
- .append("/").append((long) memoryCost)
- .append("/").append((long) networkCost)
- .append("/").append((long) penalty);
- return sb.toString();
- }
-}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java
new file mode 100644
index 0000000000..d290ee38b3
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java
@@ -0,0 +1,279 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.cost;
+
+import org.apache.doris.nereids.PlanContext;
+import org.apache.doris.nereids.properties.DistributionSpec;
+import org.apache.doris.nereids.properties.DistributionSpecGather;
+import org.apache.doris.nereids.properties.DistributionSpecHash;
+import org.apache.doris.nereids.properties.DistributionSpecReplicated;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalFileScan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalGenerate;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalJdbcScan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalSchemaScan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
+import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
+import org.apache.doris.qe.ConnectContext;
+import org.apache.doris.statistics.StatsDeriveResult;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Cost model (version 1): a plan visitor that returns a {@link Cost} for a physical plan node,
+ * computed from the statistics carried by the {@link PlanContext}.
+ * The generic {@link #visit(Plan, PlanContext)} fallback returns a zero cost.
+ */
+class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
+    /**
+     * The intuition behind `HEAVY_OPERATOR_PUNISH_FACTOR` is we need to avoid this form of join patterns:
+     * Plan1: L join ( AGG1(A) join AGG2(B))
+     * But
+     * Plan2: L join AGG1(A) join AGG2(B) is welcomed.
+     * AGG is time-consuming operator. From the perspective of rowCount, nereids may choose Plan1,
+     * because `Agg1 join Agg2` generates few tuples. But in Plan1, Agg1 and Agg2 are done in serial, in Plan2, Agg1 and
+     * Agg2 are done in parallel. And hence, Plan1 should be punished.
+     *
+     * An example is tpch q15.
+     */
+    static final double HEAVY_OPERATOR_PUNISH_FACTOR = 6.0;
+
+    /**
+     * Add the cost of one child to the cost of its parent plan.
+     *
+     * @param plan the parent plan; not used by this cost model
+     * @param planCost cost of the parent so far, must be a CostV1
+     * @param childCost cost of the child, must be a CostV1
+     * @param index not used by this cost model; presumably the position of the child under the parent
+     * @return a new CostV1 built from the sum of both cost values
+     * @throws IllegalArgumentException if either cost is not a CostV1
+     */
+    public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) {
+        Preconditions.checkArgument(childCost instanceof CostV1 && planCost instanceof CostV1,
+                "CostModelV1 only supports CostV1, but got planCost=%s, childCost=%s", planCost, childCost);
+        double cost = planCost.getValue() + childCost.getValue();
+        return new CostV1(cost);
+    }
+
+    /**
+     * Fallback for plan nodes that have no dedicated visit method in this class: zero cost.
+     */
+    @Override
+    public Cost visit(Plan plan, PlanContext context) {
+        return CostV1.zero();
+    }
+
+    /**
+     * Cost of an olap scan: the output row count as cpu cost.
+     */
+    @Override
+    public Cost visitPhysicalOlapScan(PhysicalOlapScan physicalOlapScan, PlanContext context) {
+        return scanCost(context);
+    }
+
+    /**
+     * Cost of a schema scan: the output row count as cpu cost.
+     * NOTE(review): this method carries no @Override in the original; add it if PlanVisitor declares it.
+     */
+    public Cost visitPhysicalSchemaScan(PhysicalSchemaScan physicalSchemaScan, PlanContext context) {
+        return scanCost(context);
+    }
+
+    /**
+     * Cost of a storage layer aggregate: the cost of the wrapped relation with its cpu part scaled by 0.7.
+     */
+    @Override
+    public Cost visitPhysicalStorageLayerAggregate(
+            PhysicalStorageLayerAggregate storageLayerAggregate, PlanContext context) {
+        CostV1 costValue = (CostV1) storageLayerAggregate.getRelation().accept(this, context);
+        // multiply a factor less than 1, so we can select PhysicalStorageLayerAggregate as far as possible
+        return new CostV1(costValue.getCpuCost() * 0.7, costValue.getMemoryCost(),
+                costValue.getNetworkCost(), costValue.getPenalty());
+    }
+
+    /**
+     * Cost of a file scan: the output row count as cpu cost.
+     */
+    @Override
+    public Cost visitPhysicalFileScan(PhysicalFileScan physicalFileScan, PlanContext context) {
+        return scanCost(context);
+    }
+
+    /**
+     * Cost of a project: a constant cpu cost of 1.
+     */
+    @Override
+    public Cost visitPhysicalProject(PhysicalProject<? extends Plan> physicalProject, PlanContext context) {
+        return CostV1.ofCpu(1);
+    }
+
+    /**
+     * Cost of a jdbc scan: the output row count as cpu cost.
+     */
+    @Override
+    public Cost visitPhysicalJdbcScan(PhysicalJdbcScan physicalJdbcScan, PlanContext context) {
+        return scanCost(context);
+    }
+
+    /**
+     * Cost of an es scan: the output row count as cpu cost.
+     */
+    @Override
+    public Cost visitPhysicalEsScan(PhysicalEsScan physicalEsScan, PlanContext context) {
+        return scanCost(context);
+    }
+
+    /**
+     * Cost of a quick sort, see {@link #sortCost(boolean, PlanContext)}.
+     */
+    @Override
+    public Cost visitPhysicalQuickSort(
+            PhysicalQuickSort<? extends Plan> physicalQuickSort, PlanContext context) {
+        return sortCost(physicalQuickSort.getSortPhase().isGather(), context);
+    }
+
+    /**
+     * Cost of a topN, see {@link #sortCost(boolean, PlanContext)}.
+     */
+    @Override
+    public Cost visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, PlanContext context) {
+        return sortCost(topN.getSortPhase().isGather(), context);
+    }
+
+    /**
+     * Cost of a distribute node, chosen by its distribution spec:
+     * hash and gather use the child row count for cpu and network; replicate scales the child row count
+     * by the backend number and instance number, or returns Double.MAX_VALUE costs when the build side
+     * exceeds the broadcast limits; any other spec only uses the child row count as cpu cost.
+     */
+    @Override
+    public Cost visitPhysicalDistribute(
+            PhysicalDistribute<? extends Plan> distribute, PlanContext context) {
+        StatsDeriveResult childStatistics = context.getChildStatistics(0);
+        DistributionSpec spec = distribute.getDistributionSpec();
+        // shuffle
+        if (spec instanceof DistributionSpecHash) {
+            return CostV1.of(
+                    childStatistics.getRowCount(),
+                    0,
+                    childStatistics.getRowCount());
+        }
+
+        // replicate
+        if (spec instanceof DistributionSpecReplicated) {
+            ConnectContext connectContext = ConnectContext.get();
+            int beNumber = connectContext.getEnv().getClusterInfo().getBackendIds(true).size();
+            int instanceNumber = connectContext.getSessionVariable().getParallelExecInstanceNum();
+            // guard against an empty backend list so the cost is never scaled by zero
+            beNumber = Math.max(1, beNumber);
+            double memLimit = connectContext.getSessionVariable().getMaxExecMemByte();
+            //if build side is big, avoid use broadcast join
+            double rowsLimit = connectContext.getSessionVariable().getBroadcastRowCountLimit();
+            double brMemlimit = connectContext.getSessionVariable().getBroadcastHashtableMemLimitPercentage();
+            double buildSize = childStatistics.computeSize();
+            if (buildSize * instanceNumber > memLimit * brMemlimit
+                    || childStatistics.getRowCount() > rowsLimit) {
+                return CostV1.of(Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE);
+            }
+            return CostV1.of(
+                    childStatistics.getRowCount() * beNumber,
+                    childStatistics.getRowCount() * beNumber * instanceNumber,
+                    childStatistics.getRowCount() * beNumber * instanceNumber);
+        }
+
+        // gather
+        if (spec instanceof DistributionSpecGather) {
+            return CostV1.of(
+                    childStatistics.getRowCount(),
+                    0,
+                    childStatistics.getRowCount());
+        }
+
+        // any
+        return CostV1.of(
+                childStatistics.getRowCount(),
+                0,
+                0);
+    }
+
+    /**
+     * Cost of a hash aggregate: input row count as cpu cost, output row count as memory cost.
+     */
+    @Override
+    public Cost visitPhysicalHashAggregate(
+            PhysicalHashAggregate<? extends Plan> aggregate, PlanContext context) {
+        // TODO: stage.....
+
+        StatsDeriveResult statistics = context.getStatisticsWithCheck();
+        StatsDeriveResult inputStatistics = context.getChildStatistics(0);
+        return CostV1.of(inputStatistics.getRowCount(), statistics.getRowCount(), 0);
+    }
+
+    /**
+     * Cost of a hash join. Cpu cost is left + right + output row count. A cross join puts
+     * left + right row count on the network, every other join type uses the right row count as memory cost.
+     * The penalty combines the children's penalties with an extra penalty for right deep trees.
+     */
+    @Override
+    public Cost visitPhysicalHashJoin(
+            PhysicalHashJoin<? extends Plan, ? extends Plan> physicalHashJoin, PlanContext context) {
+        Preconditions.checkState(context.arity() == 2);
+        StatsDeriveResult outputStats = context.getStatisticsWithCheck();
+        double outputRowCount = outputStats.getRowCount();
+
+        StatsDeriveResult probeStats = context.getChildStatistics(0);
+        StatsDeriveResult buildStats = context.getChildStatistics(1);
+
+        double leftRowCount = probeStats.getRowCount();
+        double rightRowCount = buildStats.getRowCount();
+        /*
+        pattern1: L join1 (Agg1() join2 Agg2())
+        result number of join2 may much less than Agg1.
+        but Agg1 and Agg2 are slow. so we need to punish this pattern1.
+
+        pattern2: (L join1 Agg1) join2 agg2
+        in pattern2, join1 and join2 takes more time, but Agg1 and agg2 can be processed in parallel.
+        */
+        double penalty = HEAVY_OPERATOR_PUNISH_FACTOR
+                * Math.min(probeStats.getPenalty(), buildStats.getPenalty());
+        if (buildStats.getWidth() >= 2) {
+            //penalty for right deep tree
+            penalty += rightRowCount;
+        }
+
+        if (physicalHashJoin.getJoinType().isCrossJoin()) {
+            return CostV1.of(leftRowCount + rightRowCount + outputRowCount,
+                    0,
+                    leftRowCount + rightRowCount,
+                    penalty);
+        }
+        return CostV1.of(leftRowCount + rightRowCount + outputRowCount,
+                rightRowCount,
+                0,
+                penalty
+        );
+    }
+
+    /**
+     * Cost of a nested loop join: left * right row count as cpu cost, right row count as memory cost.
+     */
+    @Override
+    public Cost visitPhysicalNestedLoopJoin(
+            PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> nestedLoopJoin,
+            PlanContext context) {
+        // TODO: copy from physicalHashJoin, should update according to physical nested loop join properties.
+        Preconditions.checkState(context.arity() == 2);
+
+        StatsDeriveResult leftStatistics = context.getChildStatistics(0);
+        StatsDeriveResult rightStatistics = context.getChildStatistics(1);
+
+        return CostV1.of(
+                leftStatistics.getRowCount() * rightStatistics.getRowCount(),
+                rightStatistics.getRowCount(),
+                0);
+    }
+
+    /**
+     * Cost of an assert-num-rows node: the desired number of rows as both cpu and memory cost.
+     */
+    @Override
+    public Cost visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows,
+            PlanContext context) {
+        return CostV1.of(
+                assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(),
+                assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(),
+                0
+        );
+    }
+
+    /**
+     * Cost of a generate node: the output row count as both cpu and memory cost.
+     */
+    @Override
+    public Cost visitPhysicalGenerate(PhysicalGenerate<? extends Plan> generate, PlanContext context) {
+        StatsDeriveResult statistics = context.getStatisticsWithCheck();
+        return CostV1.of(
+                statistics.getRowCount(),
+                statistics.getRowCount(),
+                0
+        );
+    }
+
+    /** Shared scan cost: the output row count of the current node as cpu cost. */
+    private static Cost scanCost(PlanContext context) {
+        StatsDeriveResult statistics = context.getStatisticsWithCheck();
+        return CostV1.ofCpu(statistics.getRowCount());
+    }
+
+    /** Shared cost of QuickSort and TopN: child row count for cpu and network, own row count for memory. */
+    private static Cost sortCost(boolean gatherPhase, PlanContext context) {
+        // TODO: consider two-phase sort and enforcer.
+        StatsDeriveResult statistics = context.getStatisticsWithCheck();
+        StatsDeriveResult childStatistics = context.getChildStatistics(0);
+        if (gatherPhase) {
+            // Now we do more like two-phase sort, so penalise one-phase sort
+            // NOTE(review): updateRowCount is called on the statistics object obtained from the context and its
+            // return value is ignored; if it mutates in place, every cost calculation multiplies the stored row
+            // count by 100 again - confirm, and consider using a local variable for the penalised value.
+            statistics.updateRowCount(statistics.getRowCount() * 100);
+        }
+        return CostV1.of(
+                childStatistics.getRowCount(),
+                statistics.getRowCount(),
+                childStatistics.getRowCount());
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java
new file mode 100644
index 0000000000..6941bc03e6
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java
@@ -0,0 +1,126 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.cost;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Cost made of cpu, memory, network and penalty components. The components are combined
+ * into one comparable total with the weights returned by {@link CostWeight#get()}.
+ */
+class CostV1 implements Cost {
+    // The two constants use the explicit-total constructor rather than the weighted one.
+    // With the weighted sum, an infinite component multiplied by a weight of 0 yields NaN
+    // (0 * Infinity is NaN), so the "infinite" cost would stop comparing as the largest value.
+    // The weighted constructor also calls CostWeight.get(), which reads a session variable
+    // through ConnectContext.get(), while this class is being initialized.
+    // NOTE(review): presumably ConnectContext.get() can be null outside a session - confirm.
+    private static final CostV1 INFINITE = new CostV1(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY,
+            Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY);
+    private static final CostV1 ZERO = new CostV1(0, 0, 0, 0, 0);
+
+    private final double cpuCost;
+    private final double memoryCost;
+    private final double networkCost;
+    // penalty for
+    // 1. right deep tree
+    // 2. right XXX join
+    private final double penalty;
+
+    // total value; the only field read by compare() and minus()
+    private final double cost;
+
+    /**
+     * Constructor of CostV1.
+     *
+     * <p>Negative cpu, memory and network costs are clamped to 0; the penalty is stored as
+     * given. The total is the sum of the four components, each multiplied by its weight
+     * from {@link CostWeight#get()}.
+     */
+    public CostV1(double cpuCost, double memoryCost, double networkCost, double penalty) {
+        // TODO: fix stats
+        this.cpuCost = Double.max(0, cpuCost);
+        this.memoryCost = Double.max(0, memoryCost);
+        this.networkCost = Double.max(0, networkCost);
+        this.penalty = penalty;
+
+        CostWeight costWeight = CostWeight.get();
+        this.cost = costWeight.cpuWeight * this.cpuCost + costWeight.memoryWeight * this.memoryCost
+                + costWeight.networkWeight * this.networkCost + costWeight.penaltyWeight * this.penalty;
+    }
+
+    /**
+     * Creates a cost that only carries a total value; every component is 0.
+     */
+    public CostV1(double cost) {
+        this.cost = cost;
+        this.cpuCost = 0;
+        this.networkCost = 0;
+        this.memoryCost = 0;
+        this.penalty = 0;
+    }
+
+    /**
+     * Stores the components and the total exactly as given, without clamping or weighting.
+     */
+    private CostV1(double cpuCost, double memoryCost, double networkCost, double penalty, double cost) {
+        this.cpuCost = cpuCost;
+        this.memoryCost = memoryCost;
+        this.networkCost = networkCost;
+        this.penalty = penalty;
+        this.cost = cost;
+    }
+
+    /**
+     * Returns a cost whose total is this total minus {@code other.getValue()}.
+     * The result carries no component breakdown (all components are 0).
+     */
+    @Override
+    public Cost minus(Cost other) {
+        return new CostV1(cost - other.getValue());
+    }
+
+    public static CostV1 infinite() {
+        return INFINITE;
+    }
+
+    public static CostV1 zero() {
+        return ZERO;
+    }
+
+    public double getCpuCost() {
+        return cpuCost;
+    }
+
+    public double getMemoryCost() {
+        return memoryCost;
+    }
+
+    public double getNetworkCost() {
+        return networkCost;
+    }
+
+    public double getPenalty() {
+        return penalty;
+    }
+
+    @Override
+    public double getValue() {
+        return cost;
+    }
+
+    public static CostV1 of(double cpuCost, double maxMemory, double networkCost, double rightDeepPenalty) {
+        return new CostV1(cpuCost, maxMemory, networkCost, rightDeepPenalty);
+    }
+
+    public static CostV1 of(double cpuCost, double maxMemory, double networkCost) {
+        return new CostV1(cpuCost, maxMemory, networkCost, 0);
+    }
+
+    public static CostV1 ofCpu(double cpuCost) {
+        return new CostV1(cpuCost, 0, 0, 0);
+    }
+
+    public static CostV1 ofMemory(double memoryCost) {
+        return new CostV1(0, memoryCost, 0, 0);
+    }
+
+    /**
+     * Renders the components, each truncated to long, as {@code cpu/memory/network/penalty}.
+     */
+    @Override
+    public String toString() {
+        StringBuilder sb = new StringBuilder();
+        sb.append((long) cpuCost).append("/").append((long) memoryCost).append("/").append((long) networkCost)
+                .append("/").append((long) penalty);
+        return sb.toString();
+    }
+
+    /**
+     * Compares the total values only; {@code other} must be a {@code CostV1}.
+     */
+    @Override
+    public int compare(Cost other) {
+        Preconditions.checkArgument(other instanceof CostV1, "costValueV1 can only compare with costValueV1");
+        return Double.compare(cost, ((CostV1) other).cost);
+    }
+}
+
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostWeight.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostWeight.java
index c154a9bef0..ae23a22234 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostWeight.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostWeight.java
@@ -17,16 +17,38 @@
package org.apache.doris.nereids.cost;
+import org.apache.doris.qe.ConnectContext;
+
import com.google.common.base.Preconditions;
/**
* cost weight.
+ * The intuition behind `HEAVY_OPERATOR_PUNISH_FACTOR` is we need to avoid this form of join patterns:
+ * Plan1: L join ( AGG1(A) join AGG2(B))
+ * But
+ * Plan2: L join AGG1(A) join AGG2(B) is welcomed.
+ * AGG is time-consuming operator. From the perspective of rowCount, nereids may choose Plan1,
+ * because `Agg1 join Agg2` generates few tuples. But in Plan1, Agg1 and Agg2 are done in serial, in Plan2, Agg1 and
+ * Agg2 are done in parallel. And hence, Plan1 should be punished.
+ *
+ * An example is tpch q15.
*/
public class CostWeight {
- private final double cpuWeight;
- private final double memoryWeight;
- private final double networkWeight;
- private final double penaltyWeight;
+ static final double CPU_WEIGHT = 1;
+ static final double MEMORY_WEIGHT = 1;
+ static final double NETWORK_WEIGHT = 1.5;
+ final double cpuWeight;
+ final double memoryWeight;
+ final double networkWeight;
+ /*
+ * About PENALTY:
+ * Besides stats information, there are some special criteria in doris.
+ * For example, in hash join cluster, BE could build hash tables
+ * in parallel for left deep tree. And hence, we need to punish right deep tree.
+ * penaltyWeight is the factor of punishment.
+ * The punishment is denoted by stats.penalty.
+ */
+ final double penaltyWeight;
/**
* Constructor
@@ -42,9 +64,8 @@ public class CostWeight {
this.penaltyWeight = penaltyWeight;
}
- public double calculate(CostEstimate costEstimate) {
- return costEstimate.getCpuCost() * cpuWeight + costEstimate.getMemoryCost() * memoryWeight
- + costEstimate.getNetworkCost() * networkWeight
- + costEstimate.getPenalty() * penaltyWeight;
+ public static CostWeight get() {
+ return new CostWeight(CPU_WEIGHT, MEMORY_WEIGHT, NETWORK_WEIGHT,
+ ConnectContext.get().getSessionVariable().getNereidsCboPenaltyFactor());
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java
index 9a5d410d20..769dab4296 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java
@@ -18,6 +18,7 @@
package org.apache.doris.nereids.jobs.cascades;
import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
@@ -50,9 +51,9 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
private final GroupExpression groupExpression;
// cost of current plan tree
- private double curTotalCost;
+ private Cost curTotalCost;
// cost of current plan node
- private double curNodeCost;
+ private Cost curNodeCost;
// List of request property to children
// Example: Physical Hash Join
@@ -114,8 +115,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
countJobExecutionTimesOfGroupExpressions(groupExpression);
// Do init logic of root plan/groupExpr of `subplan`, only run once per task.
if (curChildIndex == -1) {
- curNodeCost = 0;
- curTotalCost = 0;
+ curNodeCost = Cost.zero();
+ curTotalCost = Cost.zero();
curChildIndex = 0;
// List
// [ child item: [leftProperties, rightProperties]]
@@ -138,8 +139,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// Calculate cost
if (curChildIndex == 0 && prevChildIndex == -1) {
curNodeCost = CostCalculator.calculateCost(groupExpression);
- groupExpression.setCost(curNodeCost);
- curTotalCost += curNodeCost;
+ groupExpression.setCost(curNodeCost.getValue());
+ curTotalCost = curNodeCost;
}
// Handle all child plan node.
@@ -149,7 +150,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// Whether the child group was optimized for this requestChildProperty according to
// the result of returning.
- Optional> lowestCostPlanOpt
+ Optional> lowestCostPlanOpt
= childGroup.getLowestCostPlan(requestChildProperty);
if (!lowestCostPlanOpt.isPresent()) {
@@ -166,7 +167,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// Meaning that optimize recursively by derive job.
prevChildIndex = curChildIndex;
pushJob(clone());
- double newCostUpperBound = context.getCostUpperBound() - curTotalCost;
+ double newCostUpperBound = context.getCostUpperBound() - curTotalCost.getValue();
JobContext jobContext = new JobContext(context.getCascadesContext(),
requestChildProperty, newCostUpperBound);
pushJob(new OptimizeGroupJob(childGroup, jobContext));
@@ -182,10 +183,12 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// plan's requestChildProperty).getOutputProperties(current plan's requestChildProperty) == child
// plan's outputProperties`, the outputProperties must satisfy the origin requestChildProperty
outputChildrenProperties.set(curChildIndex, outputProperties);
-
- curTotalCost += lowestCostExpr.getLowestCostTable().get(requestChildProperty).first;
- if (curTotalCost > context.getCostUpperBound()) {
- curTotalCost = Double.POSITIVE_INFINITY;
+ curTotalCost = CostCalculator.addChildCost(groupExpression.getPlan(),
+ curNodeCost,
+ lowestCostExpr.getCostValueByProperties(requestChildProperty),
+ curChildIndex);
+ if (curTotalCost.getValue() > context.getCostUpperBound()) {
+ curTotalCost = Cost.infinite();
}
// the request child properties will be covered by the output properties
// that corresponding to the request properties. so if we run a costAndEnforceJob of the same
@@ -198,8 +201,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
if (!calculateEnforce(requestChildrenProperties, outputChildrenProperties)) {
return; // if error exists, return
}
- if (curTotalCost < context.getCostUpperBound()) {
- context.setCostUpperBound(curTotalCost);
+ if (curTotalCost.getValue() < context.getCostUpperBound()) {
+ context.setCostUpperBound(curTotalCost.getValue());
}
}
clear();
@@ -222,7 +225,6 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// invalid enforce, return.
return false;
}
- curTotalCost += enforceCost;
// Not need to do pruning here because it has been done when we get the
// best expr from the child group
@@ -237,14 +239,18 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
return false;
}
StatsCalculator.estimate(groupExpression);
- // previous curTotalCost exclude the exists best cost of current node
- curTotalCost -= curNodeCost;
+
+ // recompute cost after adjusting property
curNodeCost = CostCalculator.calculateCost(groupExpression); // recompute current node's cost in current context
- groupExpression.setCost(curNodeCost);
- // (previous curTotalCost) - (previous curNodeCost) + (current curNodeCost) = (current curTotalCost).
- // if current curTotalCost maybe less than previous curTotalCost, we will update the lowest cost and plan
- // to the grouping expression and the owner group
- curTotalCost += curNodeCost;
+ groupExpression.setCost(curNodeCost.getValue());
+ curTotalCost = curNodeCost;
+ for (int i = 0; i < outputChildrenProperties.size(); i++) {
+ PhysicalProperties childProperties = outputChildrenProperties.get(i);
+ curTotalCost = CostCalculator.addChildCost(groupExpression.getPlan(),
+ curTotalCost,
+ groupExpression.child(i).getLowestCostPlan(childProperties).get().first,
+ i);
+ }
// record map { outputProperty -> outputProperty }, { ANY -> outputProperty },
recordPropertyAndCost(groupExpression, outputProperty, PhysicalProperties.ANY, outputChildrenProperties);
@@ -303,8 +309,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
lowestCostChildren.clear();
prevChildIndex = -1;
curChildIndex = 0;
- curTotalCost = 0;
- curNodeCost = 0;
+ curTotalCost = Cost.zero();
+ curNodeCost = Cost.zero();
}
/**
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java
index db64127bf3..e80e1472c4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.PlanContext;
+import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter;
@@ -62,7 +63,7 @@ public class GraphSimplifier {
// because it's just used for simulating join. In fact, the graph simplifier
// just generate the partial order of join operator.
private final HashMap cacheStats = new HashMap<>();
- private final HashMap cacheCost = new HashMap<>();
+ private final HashMap cacheCost = new HashMap<>();
private final Stack appliedSteps = new Stack<>();
private final Stack unAppliedSteps = new Stack<>();
@@ -81,7 +82,7 @@ public class GraphSimplifier {
}
for (Node node : graph.getNodes()) {
cacheStats.put(node.getNodeMap(), node.getGroup().getStatistics());
- cacheCost.put(node.getNodeMap(), node.getGroup().getStatistics().getRowCount());
+ cacheCost.put(node.getNodeMap(), Cost.withRowCount(node.getRowCount()));
}
circleDetector = new CircleDetector(edgeSize);
@@ -382,26 +383,26 @@ public class GraphSimplifier {
private SimplificationStep orderJoin(Pair edge1Before2,
Pair edge2Before1, int edgeIndex1, int edgeIndex2) {
- double cost1Before2 = calCost(edge1Before2.second, edge1Before2.first,
+ Cost cost1Before2 = calCost(edge1Before2.second, edge1Before2.first,
cacheStats.get(edge1Before2.second.getLeft()),
cacheStats.get(edge1Before2.second.getRight()));
- double cost2Before1 = calCost(edge2Before1.second, edge1Before2.first,
+ Cost cost2Before1 = calCost(edge2Before1.second, edge1Before2.first,
cacheStats.get(edge1Before2.second.getLeft()),
cacheStats.get(edge1Before2.second.getRight()));
double benefit = Double.MAX_VALUE;
SimplificationStep step;
// Choose the plan with smaller cost and make the simplification step to replace the old edge by it.
- if (cost1Before2 < cost2Before1) {
- if (cost1Before2 != 0) {
- benefit = cost2Before1 / cost1Before2;
+ if (cost1Before2.getValue() < cost2Before1.getValue()) {
+ if (cost1Before2.getValue() != 0) {
+ benefit = cost2Before1.getValue() / cost1Before2.getValue();
}
// choose edge1Before2
step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2, edge1Before2.second.getLeft(),
edge1Before2.second.getRight(), graph.getEdge(edgeIndex2).getLeft(),
graph.getEdge(edgeIndex2).getRight());
} else {
- if (cost2Before1 != 0) {
- benefit = cost1Before2 / cost2Before1;
+ if (cost2Before1.getValue() != 0) {
+ benefit = cost1Before2.getValue() / cost2Before1.getValue();
}
// choose edge2Before1
step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.second.getLeft(),
@@ -422,11 +423,11 @@ public class GraphSimplifier {
return false;
}
- private double calCost(Edge edge, StatsDeriveResult stats,
+ private Cost calCost(Edge edge, StatsDeriveResult stats,
StatsDeriveResult leftStats, StatsDeriveResult rightStats) {
LogicalJoin join = edge.getJoin();
PlanContext planContext = new PlanContext(stats, leftStats, rightStats);
- double cost = 0;
+ Cost cost = Cost.zero();
if (JoinUtils.shouldNestedLoopJoin(join)) {
PhysicalNestedLoopJoin nestedLoopJoin = new PhysicalNestedLoopJoin<>(
join.getJoinType(),
@@ -437,6 +438,8 @@ public class GraphSimplifier {
join.left(),
join.right());
cost = CostCalculator.calculateCost(nestedLoopJoin, planContext);
+ cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getLeft()), 0);
+ cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getRight()), 1);
} else {
PhysicalHashJoin hashJoin = new PhysicalHashJoin<>(
join.getJoinType(),
@@ -448,8 +451,11 @@ public class GraphSimplifier {
join.left(),
join.right());
cost = CostCalculator.calculateCost(hashJoin, planContext);
+ cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getLeft()), 0);
+ cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getRight()), 1);
}
- return cost + cacheCost.get(edge.getLeft()) + cacheCost.get(edge.getRight());
+
+ return cost;
}
/**
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
index 88e20f34e3..36fe488e6c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
@@ -18,6 +18,7 @@
package org.apache.doris.nereids.memo;
import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.plans.JoinType;
@@ -60,7 +61,7 @@ public class Group {
// Map of cost lower bounds
// Map required plan props to cost lower bound of corresponding plan
- private final Map> lowestCostPlans = Maps.newHashMap();
+ private final Map> lowestCostPlans = Maps.newHashMap();
private boolean isExplored = false;
@@ -189,7 +190,7 @@ public class Group {
* @param physicalProperties the physical property constraints
* @return {@link Optional} of cost and {@link GroupExpression} of physical plan pair.
*/
- public Optional> getLowestCostPlan(PhysicalProperties physicalProperties) {
+ public Optional> getLowestCostPlan(PhysicalProperties physicalProperties) {
if (physicalProperties == null || lowestCostPlans.isEmpty()) {
return Optional.empty();
}
@@ -206,9 +207,9 @@ public class Group {
/**
* Set or update lowestCostPlans: properties --> Pair.of(cost, expression)
*/
- public void setBestPlan(GroupExpression expression, double cost, PhysicalProperties properties) {
+ public void setBestPlan(GroupExpression expression, Cost cost, PhysicalProperties properties) {
if (lowestCostPlans.containsKey(properties)) {
- if (lowestCostPlans.get(properties).first > cost) {
+ if (lowestCostPlans.get(properties).first.getValue() > cost.getValue()) {
lowestCostPlans.put(properties, Pair.of(cost, expression));
}
} else {
@@ -219,8 +220,9 @@ public class Group {
/**
* replace best plan with new properties
*/
- public void replaceBestPlanProperty(PhysicalProperties oldProperty, PhysicalProperties newProperty, double cost) {
- Pair pair = lowestCostPlans.get(oldProperty);
+ public void replaceBestPlanProperty(PhysicalProperties oldProperty,
+ PhysicalProperties newProperty, Cost cost) {
+ Pair pair = lowestCostPlans.get(oldProperty);
GroupExpression lowestGroupExpr = pair.second;
lowestGroupExpr.updateLowestCostTable(newProperty,
lowestGroupExpr.getInputPropertiesList(oldProperty), cost);
@@ -232,11 +234,11 @@ public class Group {
* replace oldGroupExpression with newGroupExpression in lowestCostPlans.
*/
public void replaceBestPlanGroupExpr(GroupExpression oldGroupExpression, GroupExpression newGroupExpression) {
- Map> needReplaceBestExpressions = Maps.newHashMap();
- for (Iterator>> iterator =
+ Map> needReplaceBestExpressions = Maps.newHashMap();
+ for (Iterator>> iterator =
lowestCostPlans.entrySet().iterator(); iterator.hasNext(); ) {
- Map.Entry> entry = iterator.next();
- Pair pair = entry.getValue();
+ Map.Entry> entry = iterator.next();
+ Pair pair = entry.getValue();
if (pair.second.equals(oldGroupExpression)) {
needReplaceBestExpressions.put(entry.getKey(), Pair.of(pair.first, newGroupExpression));
iterator.remove();
@@ -345,7 +347,8 @@ public class Group {
if (!target.lowestCostPlans.containsKey(physicalProperties)) {
target.lowestCostPlans.put(physicalProperties, costAndGroupExpr);
} else {
- if (costAndGroupExpr.first < target.lowestCostPlans.get(physicalProperties).first) {
+ if (costAndGroupExpr.first.getValue()
+ < target.lowestCostPlans.get(physicalProperties).first.getValue()) {
target.lowestCostPlans.put(physicalProperties, costAndGroupExpr);
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
index f1374e2d80..4555935f25 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
@@ -18,7 +18,7 @@
package org.apache.doris.nereids.memo;
import org.apache.doris.common.Pair;
-import org.apache.doris.nereids.cost.CostEstimate;
+import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.metrics.EventChannel;
import org.apache.doris.nereids.metrics.EventProducer;
import org.apache.doris.nereids.metrics.consumer.LogConsumer;
@@ -48,7 +48,6 @@ public class GroupExpression {
EventChannel.getDefaultChannel().addConsumers(new LogConsumer(CostStateUpdateEvent.class,
EventChannel.LOG)));
private double cost = 0.0;
- private CostEstimate costEstimate = null;
private Group ownerGroup;
private final List children;
private final Plan plan;
@@ -60,7 +59,7 @@ public class GroupExpression {
// Mapping from output properties to the corresponding best cost, statistics, and child properties.
// key is the physical properties the group expression support for its parent
// and value is cost and request physical properties to its children.
- private final Map>> lowestCostTable;
+ private final Map>> lowestCostTable;
// Each physical group expression maintains mapping incoming requests to the corresponding child requests.
// key is the output physical properties satisfying the incoming request properties
// value is the request physical properties
@@ -182,7 +181,7 @@ public class GroupExpression {
this.isUnused = isUnused;
}
- public Map>> getLowestCostTable() {
+ public Map>> getLowestCostTable() {
return lowestCostTable;
}
@@ -198,10 +197,10 @@ public class GroupExpression {
* @return true if lowest cost table change.
*/
public boolean updateLowestCostTable(PhysicalProperties outputProperties,
- List childrenInputProperties, double cost) {
- COST_STATE_TRACER.log(CostStateUpdateEvent.of(this, cost, outputProperties));
+ List childrenInputProperties, Cost cost) {
+ COST_STATE_TRACER.log(CostStateUpdateEvent.of(this, cost.getValue(), outputProperties));
if (lowestCostTable.containsKey(outputProperties)) {
- if (lowestCostTable.get(outputProperties).first > cost) {
+ if (lowestCostTable.get(outputProperties).first.getValue() > cost.getValue()) {
lowestCostTable.put(outputProperties, Pair.of(cost, childrenInputProperties));
return true;
} else {
@@ -220,6 +219,11 @@ public class GroupExpression {
* @return Lowest cost to satisfy that property
*/
public double getCostByProperties(PhysicalProperties property) {
+ Preconditions.checkState(lowestCostTable.containsKey(property));
+ return lowestCostTable.get(property).first.getValue();
+ }
+
+ public Cost getCostValueByProperties(PhysicalProperties property) {
Preconditions.checkState(lowestCostTable.containsKey(property));
return lowestCostTable.get(property).first;
}
@@ -263,14 +267,6 @@ public class GroupExpression {
this.cost = cost;
}
- public CostEstimate getCostEstimate() {
- return costEstimate;
- }
-
- public void setCostEstimate(CostEstimate costEstimate) {
- this.costEstimate = costEstimate;
- }
-
@Override
public boolean equals(Object o) {
if (this == o) {
@@ -310,9 +306,8 @@ public class GroupExpression {
builder.append("#").append(ownerGroup.getGroupId().asInt());
}
- if (costEstimate != null) {
- builder.append(" cost=").append((long) cost).append(" (").append(costEstimate).append(")");
- }
+ builder.append(" cost=").append((long) cost);
+
builder.append(" estRows=").append(estOutputRowCount);
builder.append(" (plan=").append(plan.toString()).append(") children=[");
for (Group group : children) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
index b27d9a1f59..acb2cf1296 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
@@ -22,6 +22,8 @@ import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.CTEContext;
+import org.apache.doris.nereids.cost.Cost;
+import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.metrics.EventChannel;
import org.apache.doris.nereids.metrics.EventProducer;
import org.apache.doris.nereids.metrics.consumer.LogConsumer;
@@ -765,26 +767,27 @@ public class Memo {
return res;
}
- double bestChildrenCost = 0;
List>> children = new ArrayList<>();
for (int i = 0; i < inputProperties.size(); i++) {
// To avoid reach a circle, we don't allow ranking the same group with the same physical properties.
Preconditions.checkArgument(!groupExpression.child(i).equals(groupExpression.getOwnerGroup())
|| !prop.equals(inputProperties.get(i)));
- bestChildrenCost += groupExpression.children().get(i).getLowestCostPlan(inputProperties.get(i)).get().first;
List> idCostPair
= rankGroup(groupExpression.child(i), inputProperties.get(i));
children.add(idCostPair);
}
List>> childrenId = new ArrayList<>();
permute(children, 0, childrenId, new ArrayList<>());
+ Cost cost = CostCalculator.calculateCost(groupExpression);
for (Pair> c : childrenId) {
- double childCost = 0;
+ Cost totalCost = cost;
for (int i = 0; i < children.size(); i++) {
- childCost += children.get(i).get(c.second.get(i)).second;
+ totalCost = CostCalculator.addChildCost(groupExpression.getPlan(),
+ totalCost,
+ groupExpression.child(i).getLowestCostPlan(inputProperties.get(i)).get().first,
+ i);
}
- res.add(Pair.of(c.first,
- childCost + groupExpression.getCostByProperties(prop) - bestChildrenCost));
+ res.add(Pair.of(c.first, totalCost.getValue()));
}
return res;
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
index 47753e389e..8aed9504e1 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
@@ -18,6 +18,7 @@
package org.apache.doris.nereids.properties;
import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.memo.GroupExpression;
@@ -103,12 +104,12 @@ public class ChildrenPropertiesRegulator extends PlanVisitor {
DistributionSpecHash rightHashSpec = (DistributionSpecHash) rightDistributionSpec;
GroupExpression leftChild = children.get(0);
- final Pair> leftLowest
+ final Pair> leftLowest
= leftChild.getLowestCostTable().get(childrenProperties.get(0));
PhysicalProperties leftOutput = leftChild.getOutputProperties(childrenProperties.get(0));
GroupExpression rightChild = children.get(1);
- Pair> rightLowest
+ Pair> rightLowest
= rightChild.getLowestCostTable().get(childrenProperties.get(1));
PhysicalProperties rightOutput = rightChild.getOutputProperties(childrenProperties.get(1));
@@ -128,20 +129,20 @@ public class ChildrenPropertiesRegulator extends PlanVisitor {
PhysicalProperties rightRequireProperties = calRightRequiredOfBucketShuffleJoin(leftHashSpec,
rightHashSpec);
if (!rightOutput.equals(rightRequireProperties)) {
- enforceCost += updateChildEnforceAndCost(rightChild, rightOutput,
+ updateChildEnforceAndCost(rightChild, rightOutput,
(DistributionSpecHash) rightRequireProperties.getDistributionSpec(), rightLowest.first);
}
childrenProperties.set(1, rightRequireProperties);
return enforceCost;
}
- enforceCost += updateChildEnforceAndCost(leftChild, leftOutput,
+ updateChildEnforceAndCost(leftChild, leftOutput,
(DistributionSpecHash) requiredProperties.get(0).getDistributionSpec(), leftLowest.first);
childrenProperties.set(0, requiredProperties.get(0));
}
// check right hand must distribute.
if (rightHashSpec.getShuffleType() != ShuffleType.ENFORCED) {
- enforceCost += updateChildEnforceAndCost(rightChild, rightOutput,
+ updateChildEnforceAndCost(rightChild, rightOutput,
(DistributionSpecHash) requiredProperties.get(1).getDistributionSpec(), rightLowest.first);
childrenProperties.set(1, requiredProperties.get(1));
}
@@ -175,15 +176,13 @@ public class ChildrenPropertiesRegulator extends PlanVisitor {
}
private double updateChildEnforceAndCost(GroupExpression child, PhysicalProperties childOutput,
- DistributionSpecHash required, double currentCost) {
- double enforceCost = 0;
+ DistributionSpecHash required, Cost currentCost) {
if (child.getPlan() instanceof PhysicalDistribute) {
//To avoid continuous distribute operator, we just enforce the child's child
childOutput = child.getInputPropertiesList(childOutput).get(0);
- Pair newChildAndCost
+ Pair newChildAndCost
= child.getOwnerGroup().getLowestCostPlan(childOutput).get();
child = newChildAndCost.second;
- enforceCost = newChildAndCost.first - currentCost;
currentCost = newChildAndCost.first;
}
@@ -193,13 +192,16 @@ public class ChildrenPropertiesRegulator extends PlanVisitor {
PhysicalProperties newOutputProperty = new PhysicalProperties(outputDistributionSpec);
GroupExpression enforcer = outputDistributionSpec.addEnforcer(child.getOwnerGroup());
jobContext.getCascadesContext().getMemo().addEnforcerPlan(enforcer, child.getOwnerGroup());
- enforceCost = Double.sum(enforceCost, CostCalculator.calculateCost(enforcer));
+ Cost totalCost = CostCalculator.addChildCost(enforcer.getPlan(),
+ currentCost,
+ CostCalculator.calculateCost(enforcer),
+ 0);
if (enforcer.updateLowestCostTable(newOutputProperty,
- Lists.newArrayList(childOutput), enforceCost + currentCost)) {
+ Lists.newArrayList(childOutput), totalCost)) {
enforcer.putOutputPropertiesMap(newOutputProperty, newOutputProperty);
}
- child.getOwnerGroup().setBestPlan(enforcer, enforceCost + currentCost, newOutputProperty);
+ child.getOwnerGroup().setBestPlan(enforcer, totalCost, newOutputProperty);
return enforceCost;
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/EnforceMissingPropertiesHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/EnforceMissingPropertiesHelper.java
index 7b9ba19ccb..d38e9b8c2d 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/EnforceMissingPropertiesHelper.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/EnforceMissingPropertiesHelper.java
@@ -17,6 +17,7 @@
package org.apache.doris.nereids.properties;
+import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.memo.GroupExpression;
@@ -38,16 +39,16 @@ public class EnforceMissingPropertiesHelper {
EventChannel.getDefaultChannel().addConsumers(new LogConsumer(EnforcerEvent.class, EventChannel.LOG)));
private final JobContext context;
private final GroupExpression groupExpression;
- private double curTotalCost;
+ private Cost curTotalCost;
public EnforceMissingPropertiesHelper(JobContext context, GroupExpression groupExpression,
- double curTotalCost) {
+ Cost curTotalCost) {
this.context = context;
this.groupExpression = groupExpression;
this.curTotalCost = curTotalCost;
}
- public double getCurTotalCost() {
+ public Cost getCurTotalCost() {
return curTotalCost;
}
@@ -83,7 +84,8 @@ public class EnforceMissingPropertiesHelper {
*/
private PhysicalProperties enforceDistributionButMeetSort(PhysicalProperties output, PhysicalProperties request) {
groupExpression.getOwnerGroup()
- .replaceBestPlanProperty(output, PhysicalProperties.ANY, groupExpression.getCostByProperties(output));
+ .replaceBestPlanProperty(
+ output, PhysicalProperties.ANY, groupExpression.getCostValueByProperties(output));
return enforceSortAndDistribution(output, request);
}
@@ -149,8 +151,10 @@ public class EnforceMissingPropertiesHelper {
context.getCascadesContext().getMemo().addEnforcerPlan(enforcer, groupExpression.getOwnerGroup());
ENFORCER_TRACER.log(EnforcerEvent.of(groupExpression, ((PhysicalPlan) enforcer.getPlan()),
oldOutputProperty, newOutputProperty));
- curTotalCost += CostCalculator.calculateCost(enforcer);
-
+ curTotalCost = CostCalculator.addChildCost(enforcer.getPlan(),
+ CostCalculator.calculateCost(enforcer),
+ curTotalCost,
+ 0);
if (enforcer.updateLowestCostTable(newOutputProperty,
Lists.newArrayList(oldOutputProperty), curTotalCost)) {
enforcer.putOutputPropertiesMap(newOutputProperty, newOutputProperty);
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
index 40d9d6289a..ee270eccac 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
@@ -21,6 +21,7 @@ import org.apache.doris.catalog.OlapTable;
import org.apache.doris.common.jmockit.Deencapsulation;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.analyzer.UnboundSlot;
+import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.properties.UnboundLogicalProperties;
@@ -91,7 +92,7 @@ class MemoTest implements MemoPatternMatchSupported {
FakePlan fakePlan = new FakePlan();
GroupExpression srcParentExpression = new GroupExpression(fakePlan, Lists.newArrayList(srcGroup));
Group srcParentGroup = new Group(new GroupId(0), srcParentExpression, new LogicalProperties(ArrayList::new));
- srcParentGroup.setBestPlan(srcParentExpression, Double.MIN_VALUE, PhysicalProperties.ANY);
+ srcParentGroup.setBestPlan(srcParentExpression, Cost.zero(), PhysicalProperties.ANY);
GroupExpression dstParentExpression = new GroupExpression(fakePlan, Lists.newArrayList(dstGroup));
Group dstParentGroup = new Group(new GroupId(1), dstParentExpression, new LogicalProperties(ArrayList::new));