[enhancement](Nereids) refactor costModel framework (#17339)

refactor cost-model framework:
1. Use Cost class to encapsulate double cost
2. Use the `addChildCost` function to calculate the cost with children rather than add directly

Note we use the `Cost` class because we want to customize the operation of adding a child cost. Therefore, we use `Cost` only when the cost will be combined with a child cost or added to a parent's cost. Otherwise, we use a plain double, such as `upperbound`.
This commit is contained in:
谢健
2023-03-09 13:58:44 +08:00
committed by GitHub
parent e1ea2e1f2c
commit aaedcf34cf
15 changed files with 597 additions and 502 deletions

View File

@ -267,7 +267,7 @@ public class NereidsPlanner extends Planner {
if (nthPlan <= 1) {
cost = rootGroup.getLowestCostPlan(physicalProperties).orElseThrow(
() -> new AnalysisException("lowestCostPlans with physicalProperties("
+ physicalProperties + ") doesn't exist in root group")).first;
+ physicalProperties + ") doesn't exist in root group")).first.getValue();
return chooseBestPlan(rootGroup, physicalProperties);
}
Memo memo = cascadesContext.getMemo();

View File

@ -0,0 +1,44 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.cost;
/**
 * Cost encapsulates the real cost with double type.
 * We do this because we want to customize the operation of adding a child cost
 * according to different operators.
 */
public interface Cost {
    /** Returns the scalar cost value used by the optimizer to compare plans. */
    double getValue();

    /**
     * Compares this cost with another cost.
     *
     * @param other the cost to compare against
     * @return a negative, zero, or positive int, Comparator-style
     */
    int compare(Cost other);

    /** Returns a cost representing this cost minus {@code other}. */
    Cost minus(Cost other);

    /** Builds a cost driven purely by row count (CPU component only). */
    static Cost withRowCount(double rowCount) {
        return new CostV1(rowCount, 0, 0, 0);
    }

    /** Returns the zero cost. */
    static Cost zero() {
        return CostV1.zero();
    }

    /** Returns an infinite cost, used to forbid a plan alternative. */
    static Cost infinite() {
        return CostV1.infinite();
    }
}

View File

@ -20,31 +20,7 @@ package org.apache.doris.nereids.cost;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.DistributionSpec;
import org.apache.doris.nereids.properties.DistributionSpecGather;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFileScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalGenerate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalJdbcScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort;
import org.apache.doris.nereids.trees.plans.physical.PhysicalSchemaScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.StatsDeriveResult;
import com.google.common.base.Preconditions;
/**
* Calculate the cost of a plan.
@ -53,264 +29,21 @@ import com.google.common.base.Preconditions;
@Developing
//TODO: memory cost and network cost should be estimated by byte size.
public class CostCalculator {
static final double CPU_WEIGHT = 1;
static final double MEMORY_WEIGHT = 1;
static final double NETWORK_WEIGHT = 1.5;
/**
* The intuition behind `HEAVY_OPERATOR_PUNISH_FACTOR` is we need to avoid this form of join patterns:
* Plan1: L join ( AGG1(A) join AGG2(B))
* But
* Plan2: L join AGG1(A) join AGG2(B) is welcomed.
* AGG is time-consuming operator. From the perspective of rowCount, nereids may choose Plan1,
* because `Agg1 join Agg2` generates few tuples. But in Plan1, Agg1 and Agg2 are done in serial, in Plan2, Agg1 and
* Agg2 are done in parallel. And hence, Plan1 should be punished.
* <p>
* An example is tpch q15.
*/
static final double HEAVY_OPERATOR_PUNISH_FACTOR = 6.0;
/**
* Constructor.
*/
public static double calculateCost(GroupExpression groupExpression) {
public static Cost calculateCost(GroupExpression groupExpression) {
PlanContext planContext = new PlanContext(groupExpression);
CostEstimator costCalculator = new CostEstimator();
CostEstimate costEstimate = groupExpression.getPlan().accept(costCalculator, planContext);
groupExpression.setCostEstimate(costEstimate);
/*
* About PENALTY:
* Except stats information, there are some special criteria in doris.
* For example, in hash join cluster, BE could build hash tables
* in parallel for left deep tree. And hence, we need to punish right deep tree.
* penalyWeight is the factor of punishment.
* The punishment is denoted by stats.penalty.
*/
CostWeight costWeight = new CostWeight(CPU_WEIGHT, MEMORY_WEIGHT, NETWORK_WEIGHT,
ConnectContext.get().getSessionVariable().getNereidsCboPenaltyFactor());
return costWeight.calculate(costEstimate);
CostModelV1 costModel = new CostModelV1();
return groupExpression.getPlan().accept(costModel, planContext);
}
public static double calculateCost(Plan plan, PlanContext planContext) {
CostEstimator costCalculator = new CostEstimator();
CostEstimate costEstimate = plan.accept(costCalculator, planContext);
CostWeight costWeight = new CostWeight(CPU_WEIGHT, MEMORY_WEIGHT, NETWORK_WEIGHT,
ConnectContext.get().getSessionVariable().getNereidsCboPenaltyFactor());
return costWeight.calculate(costEstimate);
public static Cost calculateCost(Plan plan, PlanContext planContext) {
CostModelV1 costModel = new CostModelV1();
return plan.accept(costModel, planContext);
}
private static class CostEstimator extends PlanVisitor<CostEstimate, PlanContext> {
@Override
public CostEstimate visit(Plan plan, PlanContext context) {
return CostEstimate.zero();
}
@Override
public CostEstimate visitPhysicalOlapScan(PhysicalOlapScan physicalOlapScan, PlanContext context) {
StatsDeriveResult statistics = context.getStatisticsWithCheck();
return CostEstimate.ofCpu(statistics.getRowCount());
}
public CostEstimate visitPhysicalSchemaScan(PhysicalSchemaScan physicalSchemaScan, PlanContext context) {
StatsDeriveResult statistics = context.getStatisticsWithCheck();
return CostEstimate.ofCpu(statistics.getRowCount());
}
@Override
public CostEstimate visitPhysicalStorageLayerAggregate(
PhysicalStorageLayerAggregate storageLayerAggregate, PlanContext context) {
CostEstimate costEstimate = storageLayerAggregate.getRelation().accept(this, context);
// multiply a factor less than 1, so we can select PhysicalStorageLayerAggregate as far as possible
return new CostEstimate(costEstimate.getCpuCost() * 0.7, costEstimate.getMemoryCost(),
costEstimate.getNetworkCost(), costEstimate.getPenalty());
}
@Override
public CostEstimate visitPhysicalFileScan(PhysicalFileScan physicalFileScan, PlanContext context) {
StatsDeriveResult statistics = context.getStatisticsWithCheck();
return CostEstimate.ofCpu(statistics.getRowCount());
}
@Override
public CostEstimate visitPhysicalProject(PhysicalProject<? extends Plan> physicalProject, PlanContext context) {
return CostEstimate.ofCpu(1);
}
@Override
public CostEstimate visitPhysicalJdbcScan(PhysicalJdbcScan physicalJdbcScan, PlanContext context) {
StatsDeriveResult statistics = context.getStatisticsWithCheck();
return CostEstimate.ofCpu(statistics.getRowCount());
}
@Override
public CostEstimate visitPhysicalEsScan(PhysicalEsScan physicalEsScan, PlanContext context) {
StatsDeriveResult statistics = context.getStatisticsWithCheck();
return CostEstimate.ofCpu(statistics.getRowCount());
}
@Override
public CostEstimate visitPhysicalQuickSort(
PhysicalQuickSort<? extends Plan> physicalQuickSort, PlanContext context) {
// TODO: consider two-phase sort and enforcer.
StatsDeriveResult statistics = context.getStatisticsWithCheck();
StatsDeriveResult childStatistics = context.getChildStatistics(0);
if (physicalQuickSort.getSortPhase().isGather()) {
// Now we do more like two-phase sort, so penalise one-phase sort
statistics.updateRowCount(statistics.getRowCount() * 100);
}
return CostEstimate.of(
childStatistics.getRowCount(),
statistics.getRowCount(),
childStatistics.getRowCount());
}
@Override
public CostEstimate visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, PlanContext context) {
// TODO: consider two-phase sort and enforcer.
StatsDeriveResult statistics = context.getStatisticsWithCheck();
StatsDeriveResult childStatistics = context.getChildStatistics(0);
if (topN.getSortPhase().isGather()) {
// Now we do more like two-phase sort, so penalise one-phase sort
statistics.updateRowCount(statistics.getRowCount() * 100);
}
return CostEstimate.of(
childStatistics.getRowCount(),
statistics.getRowCount(),
childStatistics.getRowCount());
}
@Override
public CostEstimate visitPhysicalDistribute(
PhysicalDistribute<? extends Plan> distribute, PlanContext context) {
StatsDeriveResult childStatistics = context.getChildStatistics(0);
DistributionSpec spec = distribute.getDistributionSpec();
// shuffle
if (spec instanceof DistributionSpecHash) {
return CostEstimate.of(
childStatistics.getRowCount(),
0,
childStatistics.getRowCount());
}
// replicate
if (spec instanceof DistributionSpecReplicated) {
int beNumber = ConnectContext.get().getEnv().getClusterInfo().getBackendIds(true).size();
int instanceNumber = ConnectContext.get().getSessionVariable().getParallelExecInstanceNum();
beNumber = Math.max(1, beNumber);
double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte();
//if build side is big, avoid use broadcast join
double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit();
double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage();
double buildSize = childStatistics.computeSize();
if (buildSize * instanceNumber > memLimit * brMemlimit
|| childStatistics.getRowCount() > rowsLimit) {
return CostEstimate.of(Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE);
}
return CostEstimate.of(
childStatistics.getRowCount() * beNumber,
childStatistics.getRowCount() * beNumber * instanceNumber,
childStatistics.getRowCount() * beNumber * instanceNumber);
}
// gather
if (spec instanceof DistributionSpecGather) {
return CostEstimate.of(
childStatistics.getRowCount(),
0,
childStatistics.getRowCount());
}
// any
return CostEstimate.of(
childStatistics.getRowCount(),
0,
0);
}
@Override
public CostEstimate visitPhysicalHashAggregate(
PhysicalHashAggregate<? extends Plan> aggregate, PlanContext context) {
// TODO: stage.....
StatsDeriveResult statistics = context.getStatisticsWithCheck();
StatsDeriveResult inputStatistics = context.getChildStatistics(0);
return CostEstimate.of(inputStatistics.getRowCount(), statistics.getRowCount(), 0);
}
@Override
public CostEstimate visitPhysicalHashJoin(
PhysicalHashJoin<? extends Plan, ? extends Plan> physicalHashJoin, PlanContext context) {
Preconditions.checkState(context.arity() == 2);
StatsDeriveResult outputStats = context.getStatisticsWithCheck();
double outputRowCount = outputStats.getRowCount();
StatsDeriveResult probeStats = context.getChildStatistics(0);
StatsDeriveResult buildStats = context.getChildStatistics(1);
double leftRowCount = probeStats.getRowCount();
double rightRowCount = buildStats.getRowCount();
/*
pattern1: L join1 (Agg1() join2 Agg2())
result number of join2 may much less than Agg1.
but Agg1 and Agg2 are slow. so we need to punish this pattern1.
pattern2: (L join1 Agg1) join2 agg2
in pattern2, join1 and join2 takes more time, but Agg1 and agg2 can be processed in parallel.
*/
double penalty = CostCalculator.HEAVY_OPERATOR_PUNISH_FACTOR
* Math.min(probeStats.getPenalty(), buildStats.getPenalty());
if (buildStats.getWidth() >= 2) {
//penalty for right deep tree
penalty += rightRowCount;
}
if (physicalHashJoin.getJoinType().isCrossJoin()) {
return CostEstimate.of(leftRowCount + rightRowCount + outputRowCount,
0,
leftRowCount + rightRowCount,
penalty);
}
return CostEstimate.of(leftRowCount + rightRowCount + outputRowCount,
rightRowCount,
0,
penalty
);
}
@Override
public CostEstimate visitPhysicalNestedLoopJoin(
PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> nestedLoopJoin,
PlanContext context) {
// TODO: copy from physicalHashJoin, should update according to physical nested loop join properties.
Preconditions.checkState(context.arity() == 2);
StatsDeriveResult leftStatistics = context.getChildStatistics(0);
StatsDeriveResult rightStatistics = context.getChildStatistics(1);
return CostEstimate.of(
leftStatistics.getRowCount() * rightStatistics.getRowCount(),
rightStatistics.getRowCount(),
0);
}
@Override
public CostEstimate visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows,
PlanContext context) {
return CostEstimate.of(
assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(),
assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(),
0
);
}
@Override
public CostEstimate visitPhysicalGenerate(PhysicalGenerate<? extends Plan> generate, PlanContext context) {
StatsDeriveResult statistics = context.getStatisticsWithCheck();
return CostEstimate.of(
statistics.getRowCount(),
statistics.getRowCount(),
0
);
}
public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) {
return CostModelV1.addChildCost(plan, planCost, childCost, index);
}
}

View File

@ -1,128 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.cost;
import com.google.common.base.Preconditions;
/**
 * Immutable estimate of a plan's cost, decomposed into cpu / memory / network
 * components plus a structural penalty (e.g. for right-deep join trees).
 */
public final class CostEstimate {
    private static final CostEstimate INFINITE =
            new CostEstimate(Double.POSITIVE_INFINITY,
                    Double.POSITIVE_INFINITY,
                    Double.POSITIVE_INFINITY,
                    Double.POSITIVE_INFINITY);
    private static final CostEstimate ZERO = new CostEstimate(0, 0, 0, 0);

    private final double cpuCost;
    private final double memoryCost;
    private final double networkCost;
    // penalty for
    //  1. right deep tree
    //  2. right XXX join
    private final double penalty;

    /**
     * Constructor of CostEstimate.
     * Negative component estimates are clamped to zero; stats derivation may
     * produce them (TODO: fix stats upstream). The previous explicit
     * Preconditions checks were unreachable after clamping and were removed.
     *
     * @param cpuCost     cpu component (clamped to >= 0)
     * @param memoryCost  memory component (clamped to >= 0)
     * @param networkCost network component (clamped to >= 0)
     * @param penalty     structural penalty, passed through unchanged
     */
    public CostEstimate(double cpuCost, double memoryCost, double networkCost, double penalty) {
        // TODO: fix stats so negative estimates cannot occur upstream
        this.cpuCost = Math.max(0, cpuCost);
        this.memoryCost = Math.max(0, memoryCost);
        this.networkCost = Math.max(0, networkCost);
        this.penalty = penalty;
    }

    public static CostEstimate infinite() {
        return INFINITE;
    }

    public static CostEstimate zero() {
        return ZERO;
    }

    public double getCpuCost() {
        return cpuCost;
    }

    public double getMemoryCost() {
        return memoryCost;
    }

    public double getNetworkCost() {
        return networkCost;
    }

    public double getPenalty() {
        return penalty;
    }

    public static CostEstimate of(double cpuCost, double maxMemory, double networkCost, double rightDeepPenalty) {
        return new CostEstimate(cpuCost, maxMemory, networkCost, rightDeepPenalty);
    }

    public static CostEstimate of(double cpuCost, double maxMemory, double networkCost) {
        return new CostEstimate(cpuCost, maxMemory, networkCost, 0);
    }

    public static CostEstimate ofCpu(double cpuCost) {
        return new CostEstimate(cpuCost, 0, 0, 0);
    }

    public static CostEstimate ofMemory(double memoryCost) {
        return new CostEstimate(0, memoryCost, 0, 0);
    }

    /**
     * Component-wise sum of cost estimates.
     * Note: penalty is intentionally not summed (result penalty is 0),
     * matching the original behavior of this class.
     */
    public static CostEstimate sum(CostEstimate one, CostEstimate two, CostEstimate... more) {
        double cpuCostSum = one.cpuCost + two.cpuCost;
        double memoryCostSum = one.memoryCost + two.memoryCost;
        // BUG FIX: previously computed one.networkCost + one.networkCost,
        // double-counting `one` and ignoring `two`'s network cost.
        double networkCostSum = one.networkCost + two.networkCost;
        for (CostEstimate costEstimate : more) {
            cpuCostSum += costEstimate.cpuCost;
            memoryCostSum += costEstimate.memoryCost;
            networkCostSum += costEstimate.networkCost;
        }
        return CostEstimate.of(cpuCostSum, memoryCostSum, networkCostSum);
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append((long) cpuCost)
                .append("/").append((long) memoryCost)
                .append("/").append((long) networkCost)
                .append("/").append((long) penalty);
        return sb.toString();
    }
}

View File

@ -0,0 +1,279 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.cost;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.properties.DistributionSpec;
import org.apache.doris.nereids.properties.DistributionSpecGather;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFileScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalGenerate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalJdbcScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort;
import org.apache.doris.nereids.trees.plans.physical.PhysicalSchemaScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.StatsDeriveResult;
import com.google.common.base.Preconditions;
/**
 * CostModelV1 derives a {@link Cost} for each physical operator from the
 * statistics carried by the {@link PlanContext}. Each visit method returns the
 * operator's own (cpu, memory, network, penalty) components as a {@link CostV1};
 * child costs are folded in separately via {@link #addChildCost}.
 */
class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
    /**
     * The intuition behind `HEAVY_OPERATOR_PUNISH_FACTOR` is we need to avoid this form of join patterns:
     * Plan1: L join ( AGG1(A) join AGG2(B))
     * But
     * Plan2: L join AGG1(A) join AGG2(B) is welcomed.
     * AGG is time-consuming operator. From the perspective of rowCount, nereids may choose Plan1,
     * because `Agg1 join Agg2` generates few tuples. But in Plan1, Agg1 and Agg2 are done in serial, in Plan2, Agg1 and
     * Agg2 are done in parallel. And hence, Plan1 should be punished.
     * <p>
     * An example is tpch q15.
     */
    static final double HEAVY_OPERATOR_PUNISH_FACTOR = 6.0;

    /**
     * Folds a child's cost into its parent's cost. V1 simply adds the scalar
     * values; {@code plan} and {@code index} are unused here but kept so other
     * cost models can combine per-child.
     */
    public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) {
        Preconditions.checkArgument(childCost instanceof CostV1 && planCost instanceof CostV1);
        double cost = planCost.getValue() + childCost.getValue();
        return new CostV1(cost);
    }

    @Override
    public Cost visit(Plan plan, PlanContext context) {
        // Operators without a dedicated rule are treated as free.
        return CostV1.zero();
    }

    @Override
    public Cost visitPhysicalOlapScan(PhysicalOlapScan physicalOlapScan, PlanContext context) {
        StatsDeriveResult statistics = context.getStatisticsWithCheck();
        return CostV1.ofCpu(statistics.getRowCount());
    }

    // FIX: @Override was missing here; without it, a signature drift in
    // PlanVisitor would silently stop this method from being dispatched.
    @Override
    public Cost visitPhysicalSchemaScan(PhysicalSchemaScan physicalSchemaScan, PlanContext context) {
        StatsDeriveResult statistics = context.getStatisticsWithCheck();
        return CostV1.ofCpu(statistics.getRowCount());
    }

    @Override
    public Cost visitPhysicalStorageLayerAggregate(
            PhysicalStorageLayerAggregate storageLayerAggregate, PlanContext context) {
        CostV1 costValue = (CostV1) storageLayerAggregate.getRelation().accept(this, context);
        // multiply a factor less than 1, so we can select PhysicalStorageLayerAggregate as far as possible
        return new CostV1(costValue.getCpuCost() * 0.7, costValue.getMemoryCost(),
                costValue.getNetworkCost(), costValue.getPenalty());
    }

    @Override
    public Cost visitPhysicalFileScan(PhysicalFileScan physicalFileScan, PlanContext context) {
        StatsDeriveResult statistics = context.getStatisticsWithCheck();
        return CostV1.ofCpu(statistics.getRowCount());
    }

    @Override
    public Cost visitPhysicalProject(PhysicalProject<? extends Plan> physicalProject, PlanContext context) {
        // Projection is treated as (almost) free: constant cpu cost of 1.
        return CostV1.ofCpu(1);
    }

    @Override
    public Cost visitPhysicalJdbcScan(PhysicalJdbcScan physicalJdbcScan, PlanContext context) {
        StatsDeriveResult statistics = context.getStatisticsWithCheck();
        return CostV1.ofCpu(statistics.getRowCount());
    }

    @Override
    public Cost visitPhysicalEsScan(PhysicalEsScan physicalEsScan, PlanContext context) {
        StatsDeriveResult statistics = context.getStatisticsWithCheck();
        return CostV1.ofCpu(statistics.getRowCount());
    }

    @Override
    public Cost visitPhysicalQuickSort(
            PhysicalQuickSort<? extends Plan> physicalQuickSort, PlanContext context) {
        // TODO: consider two-phase sort and enforcer.
        StatsDeriveResult statistics = context.getStatisticsWithCheck();
        StatsDeriveResult childStatistics = context.getChildStatistics(0);
        if (physicalQuickSort.getSortPhase().isGather()) {
            // Now we do more like two-phase sort, so penalise one-phase sort.
            // NOTE(review): this mutates the shared statistics object in place —
            // confirm the caller expects the row count to stay inflated.
            statistics.updateRowCount(statistics.getRowCount() * 100);
        }
        return CostV1.of(
                childStatistics.getRowCount(),
                statistics.getRowCount(),
                childStatistics.getRowCount());
    }

    @Override
    public Cost visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, PlanContext context) {
        // TODO: consider two-phase sort and enforcer.
        StatsDeriveResult statistics = context.getStatisticsWithCheck();
        StatsDeriveResult childStatistics = context.getChildStatistics(0);
        if (topN.getSortPhase().isGather()) {
            // Now we do more like two-phase sort, so penalise one-phase sort.
            // NOTE(review): same in-place statistics mutation as in visitPhysicalQuickSort.
            statistics.updateRowCount(statistics.getRowCount() * 100);
        }
        return CostV1.of(
                childStatistics.getRowCount(),
                statistics.getRowCount(),
                childStatistics.getRowCount());
    }

    @Override
    public Cost visitPhysicalDistribute(
            PhysicalDistribute<? extends Plan> distribute, PlanContext context) {
        StatsDeriveResult childStatistics = context.getChildStatistics(0);
        DistributionSpec spec = distribute.getDistributionSpec();
        // shuffle
        if (spec instanceof DistributionSpecHash) {
            return CostV1.of(
                    childStatistics.getRowCount(),
                    0,
                    childStatistics.getRowCount());
        }
        // replicate (broadcast)
        if (spec instanceof DistributionSpecReplicated) {
            int beNumber = ConnectContext.get().getEnv().getClusterInfo().getBackendIds(true).size();
            int instanceNumber = ConnectContext.get().getSessionVariable().getParallelExecInstanceNum();
            beNumber = Math.max(1, beNumber);
            double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte();
            // if build side is big, avoid use broadcast join
            double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit();
            double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage();
            double buildSize = childStatistics.computeSize();
            if (buildSize * instanceNumber > memLimit * brMemlimit
                    || childStatistics.getRowCount() > rowsLimit) {
                // Effectively forbid broadcast when it would blow the memory/row limits.
                return CostV1.of(Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE);
            }
            return CostV1.of(
                    childStatistics.getRowCount() * beNumber,
                    childStatistics.getRowCount() * beNumber * instanceNumber,
                    childStatistics.getRowCount() * beNumber * instanceNumber);
        }
        // gather
        if (spec instanceof DistributionSpecGather) {
            return CostV1.of(
                    childStatistics.getRowCount(),
                    0,
                    childStatistics.getRowCount());
        }
        // any
        return CostV1.of(
                childStatistics.getRowCount(),
                0,
                0);
    }

    @Override
    public Cost visitPhysicalHashAggregate(
            PhysicalHashAggregate<? extends Plan> aggregate, PlanContext context) {
        // TODO: stage.....
        StatsDeriveResult statistics = context.getStatisticsWithCheck();
        StatsDeriveResult inputStatistics = context.getChildStatistics(0);
        return CostV1.of(inputStatistics.getRowCount(), statistics.getRowCount(), 0);
    }

    @Override
    public Cost visitPhysicalHashJoin(
            PhysicalHashJoin<? extends Plan, ? extends Plan> physicalHashJoin, PlanContext context) {
        Preconditions.checkState(context.arity() == 2);
        StatsDeriveResult outputStats = context.getStatisticsWithCheck();
        double outputRowCount = outputStats.getRowCount();
        StatsDeriveResult probeStats = context.getChildStatistics(0);
        StatsDeriveResult buildStats = context.getChildStatistics(1);
        double leftRowCount = probeStats.getRowCount();
        double rightRowCount = buildStats.getRowCount();
        /*
        pattern1: L join1 (Agg1() join2 Agg2())
        result number of join2 may much less than Agg1.
        but Agg1 and Agg2 are slow. so we need to punish this pattern1.
        pattern2: (L join1 Agg1) join2 agg2
        in pattern2, join1 and join2 takes more time, but Agg1 and agg2 can be processed in parallel.
        */
        double penalty = HEAVY_OPERATOR_PUNISH_FACTOR
                * Math.min(probeStats.getPenalty(), buildStats.getPenalty());
        if (buildStats.getWidth() >= 2) {
            // penalty for right deep tree
            penalty += rightRowCount;
        }
        if (physicalHashJoin.getJoinType().isCrossJoin()) {
            return CostV1.of(leftRowCount + rightRowCount + outputRowCount,
                    0,
                    leftRowCount + rightRowCount,
                    penalty);
        }
        return CostV1.of(leftRowCount + rightRowCount + outputRowCount,
                rightRowCount,
                0,
                penalty
        );
    }

    @Override
    public Cost visitPhysicalNestedLoopJoin(
            PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> nestedLoopJoin,
            PlanContext context) {
        // TODO: copy from physicalHashJoin, should update according to physical nested loop join properties.
        Preconditions.checkState(context.arity() == 2);
        StatsDeriveResult leftStatistics = context.getChildStatistics(0);
        StatsDeriveResult rightStatistics = context.getChildStatistics(1);
        return CostV1.of(
                leftStatistics.getRowCount() * rightStatistics.getRowCount(),
                rightStatistics.getRowCount(),
                0);
    }

    @Override
    public Cost visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows,
            PlanContext context) {
        return CostV1.of(
                assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(),
                assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(),
                0
        );
    }

    @Override
    public Cost visitPhysicalGenerate(PhysicalGenerate<? extends Plan> generate, PlanContext context) {
        StatsDeriveResult statistics = context.getStatisticsWithCheck();
        return CostV1.of(
                statistics.getRowCount(),
                statistics.getRowCount(),
                0
        );
    }
}

View File

@ -0,0 +1,126 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.cost;
import com.google.common.base.Preconditions;
class CostV1 implements Cost {
private static final CostV1 INFINITE = new CostV1(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY,
Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY);
private static final CostV1 ZERO = new CostV1(0, 0, 0, 0);
private final double cpuCost;
private final double memoryCost;
private final double networkCost;
//penalty for
// 1. right deep tree
// 2. right XXX join
private final double penalty;
private final double cost;
/**
* Constructor of CostEstimate.
*/
public CostV1(double cpuCost, double memoryCost, double networkCost, double penaltiy) {
// TODO: fix stats
cpuCost = Double.max(0, cpuCost);
memoryCost = Double.max(0, memoryCost);
networkCost = Double.max(0, networkCost);
this.cpuCost = cpuCost;
this.memoryCost = memoryCost;
this.networkCost = networkCost;
this.penalty = penaltiy;
CostWeight costWeight = CostWeight.get();
this.cost = costWeight.cpuWeight * cpuCost + costWeight.memoryWeight * memoryCost
+ costWeight.networkWeight * networkCost + costWeight.penaltyWeight * penalty;
}
public CostV1(double cost) {
this.cost = cost;
this.cpuCost = 0;
this.networkCost = 0;
this.memoryCost = 0;
this.penalty = 0;
}
@Override
public Cost minus(Cost other) {
return new CostV1(cost - other.getValue());
}
public static CostV1 infinite() {
return INFINITE;
}
public static CostV1 zero() {
return ZERO;
}
public double getCpuCost() {
return cpuCost;
}
public double getMemoryCost() {
return memoryCost;
}
public double getNetworkCost() {
return networkCost;
}
public double getPenalty() {
return penalty;
}
public double getValue() {
return cost;
}
public static CostV1 of(double cpuCost, double maxMemory, double networkCost, double rightDeepPenaltiy) {
return new CostV1(cpuCost, maxMemory, networkCost, rightDeepPenaltiy);
}
public static CostV1 of(double cpuCost, double maxMemory, double networkCost) {
return new CostV1(cpuCost, maxMemory, networkCost, 0);
}
public static CostV1 ofCpu(double cpuCost) {
return new CostV1(cpuCost, 0, 0, 0);
}
public static CostV1 ofMemory(double memoryCost) {
return new CostV1(0, memoryCost, 0, 0);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append((long) cpuCost).append("/").append((long) memoryCost).append("/").append((long) networkCost)
.append("/").append((long) penalty);
return sb.toString();
}
/**
 * Compares two costs by their aggregated value. Only a CostV1 may be
 * compared with a CostV1.
 */
@Override
public int compare(Cost other) {
    Preconditions.checkArgument(other instanceof CostV1, "costValueV1 can only compare with costValueV1");
    CostV1 that = (CostV1) other;
    return Double.compare(this.cost, that.cost);
}
}

View File

@ -17,16 +17,38 @@
package org.apache.doris.nereids.cost;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
/**
* cost weight.
* The intuition behind `HEAVY_OPERATOR_PUNISH_FACTOR` is we need to avoid this form of join patterns:
* Plan1: L join ( AGG1(A) join AGG2(B))
* But
* Plan2: L join AGG1(A) join AGG2(B) is welcomed.
* AGG is time-consuming operator. From the perspective of rowCount, nereids may choose Plan1,
* because `Agg1 join Agg2` generates few tuples. But in Plan1, Agg1 and Agg2 are done in serial, in Plan2, Agg1 and
* Agg2 are done in parallel. And hence, Plan1 should be punished.
* <p>
* An example is tpch q15.
*/
public class CostWeight {
private final double cpuWeight;
private final double memoryWeight;
private final double networkWeight;
private final double penaltyWeight;
static final double CPU_WEIGHT = 1;
static final double MEMORY_WEIGHT = 1;
static final double NETWORK_WEIGHT = 1.5;
final double cpuWeight;
final double memoryWeight;
final double networkWeight;
/*
* About PENALTY:
* Except stats information, there are some special criteria in doris.
* For example, in hash join cluster, BE could build hash tables
* in parallel for left deep tree. And hence, we need to punish right deep tree.
 * penaltyWeight is the factor of punishment.
* The punishment is denoted by stats.penalty.
*/
final double penaltyWeight;
/**
* Constructor
@ -42,9 +64,8 @@ public class CostWeight {
this.penaltyWeight = penaltyWeight;
}
public double calculate(CostEstimate costEstimate) {
return costEstimate.getCpuCost() * cpuWeight + costEstimate.getMemoryCost() * memoryWeight
+ costEstimate.getNetworkCost() * networkWeight
+ costEstimate.getPenalty() * penaltyWeight;
public static CostWeight get() {
return new CostWeight(CPU_WEIGHT, MEMORY_WEIGHT, NETWORK_WEIGHT,
ConnectContext.get().getSessionVariable().getNereidsCboPenaltyFactor());
}
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.jobs.cascades;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
@ -50,9 +51,9 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
private final GroupExpression groupExpression;
// cost of current plan tree
private double curTotalCost;
private Cost curTotalCost;
// cost of current plan node
private double curNodeCost;
private Cost curNodeCost;
// List of request property to children
// Example: Physical Hash Join
@ -114,8 +115,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
countJobExecutionTimesOfGroupExpressions(groupExpression);
// Do init logic of root plan/groupExpr of `subplan`, only run once per task.
if (curChildIndex == -1) {
curNodeCost = 0;
curTotalCost = 0;
curNodeCost = Cost.zero();
curTotalCost = Cost.zero();
curChildIndex = 0;
// List<request property to children>
// [ child item: [leftProperties, rightProperties]]
@ -138,8 +139,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// Calculate cost
if (curChildIndex == 0 && prevChildIndex == -1) {
curNodeCost = CostCalculator.calculateCost(groupExpression);
groupExpression.setCost(curNodeCost);
curTotalCost += curNodeCost;
groupExpression.setCost(curNodeCost.getValue());
curTotalCost = curNodeCost;
}
// Handle all child plan node.
@ -149,7 +150,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// Whether the child group was optimized for this requestChildProperty according to
// the result of returning.
Optional<Pair<Double, GroupExpression>> lowestCostPlanOpt
Optional<Pair<Cost, GroupExpression>> lowestCostPlanOpt
= childGroup.getLowestCostPlan(requestChildProperty);
if (!lowestCostPlanOpt.isPresent()) {
@ -166,7 +167,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// Meaning that optimize recursively by derive job.
prevChildIndex = curChildIndex;
pushJob(clone());
double newCostUpperBound = context.getCostUpperBound() - curTotalCost;
double newCostUpperBound = context.getCostUpperBound() - curTotalCost.getValue();
JobContext jobContext = new JobContext(context.getCascadesContext(),
requestChildProperty, newCostUpperBound);
pushJob(new OptimizeGroupJob(childGroup, jobContext));
@ -182,10 +183,12 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// plan's requestChildProperty).getOutputProperties(current plan's requestChildProperty) == child
// plan's outputProperties`, the outputProperties must satisfy the origin requestChildProperty
outputChildrenProperties.set(curChildIndex, outputProperties);
curTotalCost += lowestCostExpr.getLowestCostTable().get(requestChildProperty).first;
if (curTotalCost > context.getCostUpperBound()) {
curTotalCost = Double.POSITIVE_INFINITY;
curTotalCost = CostCalculator.addChildCost(groupExpression.getPlan(),
curNodeCost,
lowestCostExpr.getCostValueByProperties(requestChildProperty),
curChildIndex);
if (curTotalCost.getValue() > context.getCostUpperBound()) {
curTotalCost = Cost.infinite();
}
// the request child properties will be covered by the output properties
// that corresponding to the request properties. so if we run a costAndEnforceJob of the same
@ -198,8 +201,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
if (!calculateEnforce(requestChildrenProperties, outputChildrenProperties)) {
return; // if error exists, return
}
if (curTotalCost < context.getCostUpperBound()) {
context.setCostUpperBound(curTotalCost);
if (curTotalCost.getValue() < context.getCostUpperBound()) {
context.setCostUpperBound(curTotalCost.getValue());
}
}
clear();
@ -222,7 +225,6 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// invalid enforce, return.
return false;
}
curTotalCost += enforceCost;
// Not need to do pruning here because it has been done when we get the
// best expr from the child group
@ -237,14 +239,18 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
return false;
}
StatsCalculator.estimate(groupExpression);
// previous curTotalCost exclude the exists best cost of current node
curTotalCost -= curNodeCost;
// recompute cost after adjusting property
curNodeCost = CostCalculator.calculateCost(groupExpression); // recompute current node's cost in current context
groupExpression.setCost(curNodeCost);
// (previous curTotalCost) - (previous curNodeCost) + (current curNodeCost) = (current curTotalCost).
// if current curTotalCost maybe less than previous curTotalCost, we will update the lowest cost and plan
// to the grouping expression and the owner group
curTotalCost += curNodeCost;
groupExpression.setCost(curNodeCost.getValue());
curTotalCost = curNodeCost;
for (int i = 0; i < outputChildrenProperties.size(); i++) {
PhysicalProperties childProperties = outputChildrenProperties.get(i);
curTotalCost = CostCalculator.addChildCost(groupExpression.getPlan(),
curTotalCost,
groupExpression.child(i).getLowestCostPlan(childProperties).get().first,
i);
}
// record map { outputProperty -> outputProperty }, { ANY -> outputProperty },
recordPropertyAndCost(groupExpression, outputProperty, PhysicalProperties.ANY, outputChildrenProperties);
@ -303,8 +309,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
lowestCostChildren.clear();
prevChildIndex = -1;
curChildIndex = 0;
curTotalCost = 0;
curNodeCost = 0;
curTotalCost = Cost.zero();
curNodeCost = Cost.zero();
}
/**

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter;
@ -62,7 +63,7 @@ public class GraphSimplifier {
// because it's just used for simulating join. In fact, the graph simplifier
// just generate the partial order of join operator.
private final HashMap<Long, StatsDeriveResult> cacheStats = new HashMap<>();
private final HashMap<Long, Double> cacheCost = new HashMap<>();
private final HashMap<Long, Cost> cacheCost = new HashMap<>();
private final Stack<SimplificationStep> appliedSteps = new Stack<>();
private final Stack<SimplificationStep> unAppliedSteps = new Stack<>();
@ -81,7 +82,7 @@ public class GraphSimplifier {
}
for (Node node : graph.getNodes()) {
cacheStats.put(node.getNodeMap(), node.getGroup().getStatistics());
cacheCost.put(node.getNodeMap(), node.getGroup().getStatistics().getRowCount());
cacheCost.put(node.getNodeMap(), Cost.withRowCount(node.getRowCount()));
}
circleDetector = new CircleDetector(edgeSize);
@ -382,26 +383,26 @@ public class GraphSimplifier {
private SimplificationStep orderJoin(Pair<StatsDeriveResult, Edge> edge1Before2,
Pair<StatsDeriveResult, Edge> edge2Before1, int edgeIndex1, int edgeIndex2) {
double cost1Before2 = calCost(edge1Before2.second, edge1Before2.first,
Cost cost1Before2 = calCost(edge1Before2.second, edge1Before2.first,
cacheStats.get(edge1Before2.second.getLeft()),
cacheStats.get(edge1Before2.second.getRight()));
double cost2Before1 = calCost(edge2Before1.second, edge1Before2.first,
Cost cost2Before1 = calCost(edge2Before1.second, edge1Before2.first,
cacheStats.get(edge1Before2.second.getLeft()),
cacheStats.get(edge1Before2.second.getRight()));
double benefit = Double.MAX_VALUE;
SimplificationStep step;
// Choose the plan with smaller cost and make the simplification step to replace the old edge by it.
if (cost1Before2 < cost2Before1) {
if (cost1Before2 != 0) {
benefit = cost2Before1 / cost1Before2;
if (cost1Before2.getValue() < cost2Before1.getValue()) {
if (cost1Before2.getValue() != 0) {
benefit = cost2Before1.getValue() / cost1Before2.getValue();
}
// choose edge1Before2
step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2, edge1Before2.second.getLeft(),
edge1Before2.second.getRight(), graph.getEdge(edgeIndex2).getLeft(),
graph.getEdge(edgeIndex2).getRight());
} else {
if (cost2Before1 != 0) {
benefit = cost1Before2 / cost2Before1;
if (cost2Before1.getValue() != 0) {
benefit = cost1Before2.getValue() / cost2Before1.getValue();
}
// choose edge2Before1
step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.second.getLeft(),
@ -422,11 +423,11 @@ public class GraphSimplifier {
return false;
}
private double calCost(Edge edge, StatsDeriveResult stats,
private Cost calCost(Edge edge, StatsDeriveResult stats,
StatsDeriveResult leftStats, StatsDeriveResult rightStats) {
LogicalJoin join = edge.getJoin();
PlanContext planContext = new PlanContext(stats, leftStats, rightStats);
double cost = 0;
Cost cost = Cost.zero();
if (JoinUtils.shouldNestedLoopJoin(join)) {
PhysicalNestedLoopJoin nestedLoopJoin = new PhysicalNestedLoopJoin<>(
join.getJoinType(),
@ -437,6 +438,8 @@ public class GraphSimplifier {
join.left(),
join.right());
cost = CostCalculator.calculateCost(nestedLoopJoin, planContext);
cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getLeft()), 0);
cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getRight()), 1);
} else {
PhysicalHashJoin hashJoin = new PhysicalHashJoin<>(
join.getJoinType(),
@ -448,8 +451,11 @@ public class GraphSimplifier {
join.left(),
join.right());
cost = CostCalculator.calculateCost(hashJoin, planContext);
cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getLeft()), 0);
cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getRight()), 1);
}
return cost + cacheCost.get(edge.getLeft()) + cacheCost.get(edge.getRight());
return cost;
}
/**

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.memo;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.plans.JoinType;
@ -60,7 +61,7 @@ public class Group {
// Map of cost lower bounds
// Map required plan props to cost lower bound of corresponding plan
private final Map<PhysicalProperties, Pair<Double, GroupExpression>> lowestCostPlans = Maps.newHashMap();
private final Map<PhysicalProperties, Pair<Cost, GroupExpression>> lowestCostPlans = Maps.newHashMap();
private boolean isExplored = false;
@ -189,7 +190,7 @@ public class Group {
* @param physicalProperties the physical property constraints
* @return {@link Optional} of cost and {@link GroupExpression} of physical plan pair.
*/
public Optional<Pair<Double, GroupExpression>> getLowestCostPlan(PhysicalProperties physicalProperties) {
public Optional<Pair<Cost, GroupExpression>> getLowestCostPlan(PhysicalProperties physicalProperties) {
if (physicalProperties == null || lowestCostPlans.isEmpty()) {
return Optional.empty();
}
@ -206,9 +207,9 @@ public class Group {
/**
* Set or update lowestCostPlans: properties --> Pair.of(cost, expression)
*/
public void setBestPlan(GroupExpression expression, double cost, PhysicalProperties properties) {
public void setBestPlan(GroupExpression expression, Cost cost, PhysicalProperties properties) {
if (lowestCostPlans.containsKey(properties)) {
if (lowestCostPlans.get(properties).first > cost) {
if (lowestCostPlans.get(properties).first.getValue() > cost.getValue()) {
lowestCostPlans.put(properties, Pair.of(cost, expression));
}
} else {
@ -219,8 +220,9 @@ public class Group {
/**
* replace best plan with new properties
*/
public void replaceBestPlanProperty(PhysicalProperties oldProperty, PhysicalProperties newProperty, double cost) {
Pair<Double, GroupExpression> pair = lowestCostPlans.get(oldProperty);
public void replaceBestPlanProperty(PhysicalProperties oldProperty,
PhysicalProperties newProperty, Cost cost) {
Pair<Cost, GroupExpression> pair = lowestCostPlans.get(oldProperty);
GroupExpression lowestGroupExpr = pair.second;
lowestGroupExpr.updateLowestCostTable(newProperty,
lowestGroupExpr.getInputPropertiesList(oldProperty), cost);
@ -232,11 +234,11 @@ public class Group {
* replace oldGroupExpression with newGroupExpression in lowestCostPlans.
*/
public void replaceBestPlanGroupExpr(GroupExpression oldGroupExpression, GroupExpression newGroupExpression) {
Map<PhysicalProperties, Pair<Double, GroupExpression>> needReplaceBestExpressions = Maps.newHashMap();
for (Iterator<Entry<PhysicalProperties, Pair<Double, GroupExpression>>> iterator =
Map<PhysicalProperties, Pair<Cost, GroupExpression>> needReplaceBestExpressions = Maps.newHashMap();
for (Iterator<Entry<PhysicalProperties, Pair<Cost, GroupExpression>>> iterator =
lowestCostPlans.entrySet().iterator(); iterator.hasNext(); ) {
Map.Entry<PhysicalProperties, Pair<Double, GroupExpression>> entry = iterator.next();
Pair<Double, GroupExpression> pair = entry.getValue();
Map.Entry<PhysicalProperties, Pair<Cost, GroupExpression>> entry = iterator.next();
Pair<Cost, GroupExpression> pair = entry.getValue();
if (pair.second.equals(oldGroupExpression)) {
needReplaceBestExpressions.put(entry.getKey(), Pair.of(pair.first, newGroupExpression));
iterator.remove();
@ -345,7 +347,8 @@ public class Group {
if (!target.lowestCostPlans.containsKey(physicalProperties)) {
target.lowestCostPlans.put(physicalProperties, costAndGroupExpr);
} else {
if (costAndGroupExpr.first < target.lowestCostPlans.get(physicalProperties).first) {
if (costAndGroupExpr.first.getValue()
< target.lowestCostPlans.get(physicalProperties).first.getValue()) {
target.lowestCostPlans.put(physicalProperties, costAndGroupExpr);
}
}

View File

@ -18,7 +18,7 @@
package org.apache.doris.nereids.memo;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.cost.CostEstimate;
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.metrics.EventChannel;
import org.apache.doris.nereids.metrics.EventProducer;
import org.apache.doris.nereids.metrics.consumer.LogConsumer;
@ -48,7 +48,6 @@ public class GroupExpression {
EventChannel.getDefaultChannel().addConsumers(new LogConsumer(CostStateUpdateEvent.class,
EventChannel.LOG)));
private double cost = 0.0;
private CostEstimate costEstimate = null;
private Group ownerGroup;
private final List<Group> children;
private final Plan plan;
@ -60,7 +59,7 @@ public class GroupExpression {
// Mapping from output properties to the corresponding best cost, statistics, and child properties.
// key is the physical properties the group expression support for its parent
// and value is cost and request physical properties to its children.
private final Map<PhysicalProperties, Pair<Double, List<PhysicalProperties>>> lowestCostTable;
private final Map<PhysicalProperties, Pair<Cost, List<PhysicalProperties>>> lowestCostTable;
// Each physical group expression maintains mapping incoming requests to the corresponding child requests.
// key is the output physical properties satisfying the incoming request properties
// value is the request physical properties
@ -182,7 +181,7 @@ public class GroupExpression {
this.isUnused = isUnused;
}
public Map<PhysicalProperties, Pair<Double, List<PhysicalProperties>>> getLowestCostTable() {
public Map<PhysicalProperties, Pair<Cost, List<PhysicalProperties>>> getLowestCostTable() {
return lowestCostTable;
}
@ -198,10 +197,10 @@ public class GroupExpression {
* @return true if lowest cost table change.
*/
public boolean updateLowestCostTable(PhysicalProperties outputProperties,
List<PhysicalProperties> childrenInputProperties, double cost) {
COST_STATE_TRACER.log(CostStateUpdateEvent.of(this, cost, outputProperties));
List<PhysicalProperties> childrenInputProperties, Cost cost) {
COST_STATE_TRACER.log(CostStateUpdateEvent.of(this, cost.getValue(), outputProperties));
if (lowestCostTable.containsKey(outputProperties)) {
if (lowestCostTable.get(outputProperties).first > cost) {
if (lowestCostTable.get(outputProperties).first.getValue() > cost.getValue()) {
lowestCostTable.put(outputProperties, Pair.of(cost, childrenInputProperties));
return true;
} else {
@ -220,6 +219,11 @@ public class GroupExpression {
* @return Lowest cost to satisfy that property
*/
public double getCostByProperties(PhysicalProperties property) {
Preconditions.checkState(lowestCostTable.containsKey(property));
return lowestCostTable.get(property).first.getValue();
}
public Cost getCostValueByProperties(PhysicalProperties property) {
Preconditions.checkState(lowestCostTable.containsKey(property));
return lowestCostTable.get(property).first;
}
@ -263,14 +267,6 @@ public class GroupExpression {
this.cost = cost;
}
public CostEstimate getCostEstimate() {
return costEstimate;
}
public void setCostEstimate(CostEstimate costEstimate) {
this.costEstimate = costEstimate;
}
@Override
public boolean equals(Object o) {
if (this == o) {
@ -310,9 +306,8 @@ public class GroupExpression {
builder.append("#").append(ownerGroup.getGroupId().asInt());
}
if (costEstimate != null) {
builder.append(" cost=").append((long) cost).append(" (").append(costEstimate).append(")");
}
builder.append(" cost=").append((long) cost);
builder.append(" estRows=").append(estOutputRowCount);
builder.append(" (plan=").append(plan.toString()).append(") children=[");
for (Group group : children) {

View File

@ -22,6 +22,8 @@ import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.CTEContext;
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.metrics.EventChannel;
import org.apache.doris.nereids.metrics.EventProducer;
import org.apache.doris.nereids.metrics.consumer.LogConsumer;
@ -765,26 +767,27 @@ public class Memo {
return res;
}
double bestChildrenCost = 0;
List<List<Pair<Long, Double>>> children = new ArrayList<>();
for (int i = 0; i < inputProperties.size(); i++) {
// To avoid reach a circle, we don't allow ranking the same group with the same physical properties.
Preconditions.checkArgument(!groupExpression.child(i).equals(groupExpression.getOwnerGroup())
|| !prop.equals(inputProperties.get(i)));
bestChildrenCost += groupExpression.children().get(i).getLowestCostPlan(inputProperties.get(i)).get().first;
List<Pair<Long, Double>> idCostPair
= rankGroup(groupExpression.child(i), inputProperties.get(i));
children.add(idCostPair);
}
List<Pair<Long, List<Integer>>> childrenId = new ArrayList<>();
permute(children, 0, childrenId, new ArrayList<>());
Cost cost = CostCalculator.calculateCost(groupExpression);
for (Pair<Long, List<Integer>> c : childrenId) {
double childCost = 0;
Cost totalCost = cost;
for (int i = 0; i < children.size(); i++) {
childCost += children.get(i).get(c.second.get(i)).second;
totalCost = CostCalculator.addChildCost(groupExpression.getPlan(),
totalCost,
groupExpression.child(i).getLowestCostPlan(inputProperties.get(i)).get().first,
i);
}
res.add(Pair.of(c.first,
childCost + groupExpression.getCostByProperties(prop) - bestChildrenCost));
res.add(Pair.of(c.first, totalCost.getValue()));
}
return res;
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.properties;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.memo.GroupExpression;
@ -103,12 +104,12 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Double, Void> {
DistributionSpecHash rightHashSpec = (DistributionSpecHash) rightDistributionSpec;
GroupExpression leftChild = children.get(0);
final Pair<Double, List<PhysicalProperties>> leftLowest
final Pair<Cost, List<PhysicalProperties>> leftLowest
= leftChild.getLowestCostTable().get(childrenProperties.get(0));
PhysicalProperties leftOutput = leftChild.getOutputProperties(childrenProperties.get(0));
GroupExpression rightChild = children.get(1);
Pair<Double, List<PhysicalProperties>> rightLowest
Pair<Cost, List<PhysicalProperties>> rightLowest
= rightChild.getLowestCostTable().get(childrenProperties.get(1));
PhysicalProperties rightOutput = rightChild.getOutputProperties(childrenProperties.get(1));
@ -128,20 +129,20 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Double, Void> {
PhysicalProperties rightRequireProperties = calRightRequiredOfBucketShuffleJoin(leftHashSpec,
rightHashSpec);
if (!rightOutput.equals(rightRequireProperties)) {
enforceCost += updateChildEnforceAndCost(rightChild, rightOutput,
updateChildEnforceAndCost(rightChild, rightOutput,
(DistributionSpecHash) rightRequireProperties.getDistributionSpec(), rightLowest.first);
}
childrenProperties.set(1, rightRequireProperties);
return enforceCost;
}
enforceCost += updateChildEnforceAndCost(leftChild, leftOutput,
updateChildEnforceAndCost(leftChild, leftOutput,
(DistributionSpecHash) requiredProperties.get(0).getDistributionSpec(), leftLowest.first);
childrenProperties.set(0, requiredProperties.get(0));
}
// check right hand must distribute.
if (rightHashSpec.getShuffleType() != ShuffleType.ENFORCED) {
enforceCost += updateChildEnforceAndCost(rightChild, rightOutput,
updateChildEnforceAndCost(rightChild, rightOutput,
(DistributionSpecHash) requiredProperties.get(1).getDistributionSpec(), rightLowest.first);
childrenProperties.set(1, requiredProperties.get(1));
}
@ -175,15 +176,13 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Double, Void> {
}
private double updateChildEnforceAndCost(GroupExpression child, PhysicalProperties childOutput,
DistributionSpecHash required, double currentCost) {
double enforceCost = 0;
DistributionSpecHash required, Cost currentCost) {
if (child.getPlan() instanceof PhysicalDistribute) {
//To avoid continuous distribute operator, we just enforce the child's child
childOutput = child.getInputPropertiesList(childOutput).get(0);
Pair<Double, GroupExpression> newChildAndCost
Pair<Cost, GroupExpression> newChildAndCost
= child.getOwnerGroup().getLowestCostPlan(childOutput).get();
child = newChildAndCost.second;
enforceCost = newChildAndCost.first - currentCost;
currentCost = newChildAndCost.first;
}
@ -193,13 +192,16 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Double, Void> {
PhysicalProperties newOutputProperty = new PhysicalProperties(outputDistributionSpec);
GroupExpression enforcer = outputDistributionSpec.addEnforcer(child.getOwnerGroup());
jobContext.getCascadesContext().getMemo().addEnforcerPlan(enforcer, child.getOwnerGroup());
enforceCost = Double.sum(enforceCost, CostCalculator.calculateCost(enforcer));
Cost totalCost = CostCalculator.addChildCost(enforcer.getPlan(),
currentCost,
CostCalculator.calculateCost(enforcer),
0);
if (enforcer.updateLowestCostTable(newOutputProperty,
Lists.newArrayList(childOutput), enforceCost + currentCost)) {
Lists.newArrayList(childOutput), totalCost)) {
enforcer.putOutputPropertiesMap(newOutputProperty, newOutputProperty);
}
child.getOwnerGroup().setBestPlan(enforcer, enforceCost + currentCost, newOutputProperty);
child.getOwnerGroup().setBestPlan(enforcer, totalCost, newOutputProperty);
return enforceCost;
}
}

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.properties;
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.memo.GroupExpression;
@ -38,16 +39,16 @@ public class EnforceMissingPropertiesHelper {
EventChannel.getDefaultChannel().addConsumers(new LogConsumer(EnforcerEvent.class, EventChannel.LOG)));
private final JobContext context;
private final GroupExpression groupExpression;
private double curTotalCost;
private Cost curTotalCost;
public EnforceMissingPropertiesHelper(JobContext context, GroupExpression groupExpression,
double curTotalCost) {
Cost curTotalCost) {
this.context = context;
this.groupExpression = groupExpression;
this.curTotalCost = curTotalCost;
}
public double getCurTotalCost() {
public Cost getCurTotalCost() {
return curTotalCost;
}
@ -83,7 +84,8 @@ public class EnforceMissingPropertiesHelper {
*/
private PhysicalProperties enforceDistributionButMeetSort(PhysicalProperties output, PhysicalProperties request) {
groupExpression.getOwnerGroup()
.replaceBestPlanProperty(output, PhysicalProperties.ANY, groupExpression.getCostByProperties(output));
.replaceBestPlanProperty(
output, PhysicalProperties.ANY, groupExpression.getCostValueByProperties(output));
return enforceSortAndDistribution(output, request);
}
@ -149,8 +151,10 @@ public class EnforceMissingPropertiesHelper {
context.getCascadesContext().getMemo().addEnforcerPlan(enforcer, groupExpression.getOwnerGroup());
ENFORCER_TRACER.log(EnforcerEvent.of(groupExpression, ((PhysicalPlan) enforcer.getPlan()),
oldOutputProperty, newOutputProperty));
curTotalCost += CostCalculator.calculateCost(enforcer);
curTotalCost = CostCalculator.addChildCost(enforcer.getPlan(),
CostCalculator.calculateCost(enforcer),
curTotalCost,
0);
if (enforcer.updateLowestCostTable(newOutputProperty,
Lists.newArrayList(oldOutputProperty), curTotalCost)) {
enforcer.putOutputPropertiesMap(newOutputProperty, newOutputProperty);

View File

@ -21,6 +21,7 @@ import org.apache.doris.catalog.OlapTable;
import org.apache.doris.common.jmockit.Deencapsulation;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.properties.UnboundLogicalProperties;
@ -91,7 +92,7 @@ class MemoTest implements MemoPatternMatchSupported {
FakePlan fakePlan = new FakePlan();
GroupExpression srcParentExpression = new GroupExpression(fakePlan, Lists.newArrayList(srcGroup));
Group srcParentGroup = new Group(new GroupId(0), srcParentExpression, new LogicalProperties(ArrayList::new));
srcParentGroup.setBestPlan(srcParentExpression, Double.MIN_VALUE, PhysicalProperties.ANY);
srcParentGroup.setBestPlan(srcParentExpression, Cost.zero(), PhysicalProperties.ANY);
GroupExpression dstParentExpression = new GroupExpression(fakePlan, Lists.newArrayList(dstGroup));
Group dstParentGroup = new Group(new GroupId(1), dstParentExpression, new LogicalProperties(ArrayList::new));