[enhancement](Nereids) add new distributed cost model (#17556)

Add a new distributed cost model in Nereids. The new cost model models the cost of the pipeline execute engine by dividing cost into run and start costs. They are:
* START COST: the cost from starting to emitting the first tuple
* RUN COST: the cost from emitting the first tuple to emitting all tuples

For the parent operator and child operator, we assume the timeline of them is:
  ```
  child start ---> child run --------------------> finish
             |---> parent start ---> parent run -> finish
  ```

Therefore, in the parallel model, we can get:
  ```
  start_cost(parent) = start_cost(child) + start_cost(parent)
  run_cost(parent) = max(run_cost(child), start_cost(parent) + run_cost(parent))
  ```
This commit is contained in:
谢健
2023-03-15 11:22:31 +08:00
committed by GitHub
parent 66f3ef568e
commit 97bf07fe26
22 changed files with 651 additions and 47 deletions

View File

@ -35,6 +35,7 @@ public class PlanContext {
private List<Statistics> childrenStats = new ArrayList<>();
private Statistics planStats;
private int arity = 0;
private boolean isBroadcastJoin = false;
/**
* Constructor for PlanContext.
@ -57,6 +58,14 @@ public class PlanContext {
this.arity = this.childrenStats.size();
}
public void setBroadcastJoin() {
isBroadcastJoin = true;
}
public boolean isBroadcastJoin() {
return isBroadcastJoin;
}
public int arity() {
return arity;
}
@ -71,4 +80,8 @@ public class PlanContext {
public Statistics getChildStatistics(int index) {
return childrenStats.get(index);
}
public List<Statistics> getChildrenStatistics() {
return childrenStats;
}
}

View File

@ -17,6 +17,8 @@
package org.apache.doris.nereids.cost;
import org.apache.doris.qe.ConnectContext;
/**
* Cost encapsulate the real cost with double type.
* We do this because we want to customize the operation of adding child cost
@ -25,19 +27,33 @@ package org.apache.doris.nereids.cost;
public interface Cost {
public double getValue();
public int compare(Cost other);
public Cost minus(Cost other);
/**
* This is for calculating the cost in simplifier
*/
public static Cost withRowCount(double rowCount) {
return new CostV1(rowCount, 0, 0, 0);
if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) {
return new CostV2(0, rowCount, 0);
}
return new CostV1(rowCount);
}
/**
* return zero cost
*/
public static Cost zero() {
if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) {
return CostV2.zero();
}
return CostV1.zero();
}
/**
* return infinite cost
*/
public static Cost infinite() {
if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) {
return CostV2.infinite();
}
return CostV1.infinite();
}
}

View File

@ -20,30 +20,54 @@ package org.apache.doris.nereids.cost;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.qe.ConnectContext;
import java.util.List;
/**
* Calculate the cost of a plan.
* Inspired by Presto.
*/
@Developing
//TODO: memory cost and network cost should be estimated by byte size.
public class CostCalculator {
/**
* Constructor.
* Calculate cost for groupExpression
*/
public static Cost calculateCost(GroupExpression groupExpression) {
public static Cost calculateCost(GroupExpression groupExpression, List<PhysicalProperties> childrenProperties) {
PlanContext planContext = new PlanContext(groupExpression);
CostModelV1 costModel = new CostModelV1();
return groupExpression.getPlan().accept(costModel, planContext);
if (childrenProperties.size() >= 2
&& childrenProperties.get(1).getDistributionSpec() instanceof DistributionSpecReplicated) {
planContext.setBroadcastJoin();
}
if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) {
CostModelV2 costModelV2 = new CostModelV2();
return groupExpression.getPlan().accept(costModelV2, planContext);
} else {
CostModelV1 costModelV1 = new CostModelV1();
return groupExpression.getPlan().accept(costModelV1, planContext);
}
}
/**
* Calculate cost without groupExpression
*/
public static Cost calculateCost(Plan plan, PlanContext planContext) {
CostModelV1 costModel = new CostModelV1();
return plan.accept(costModel, planContext);
if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) {
CostModelV2 costModel = new CostModelV2();
return plan.accept(costModel, planContext);
} else {
CostModelV1 costModel = new CostModelV1();
return plan.accept(costModel, planContext);
}
}
public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) {
return CostModelV1.addChildCost(plan, planCost, childCost, index);
if (!ConnectContext.get().getSessionVariable().getEnableNewCostModel()) {
return CostModelV1.addChildCost(plan, planCost, childCost, index);
}
return CostModelV2.addChildCost(plan, planCost, childCost, index);
}
}

View File

@ -0,0 +1,322 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.cost;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.properties.DistributionSpec;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalSort;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFileScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.trees.plans.physical.PhysicalGenerate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalJdbcScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalRepeat;
import org.apache.doris.nereids.trees.plans.physical.PhysicalSchemaScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalSetOperation;
import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalUnion;
import org.apache.doris.nereids.trees.plans.physical.PhysicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Preconditions;
/**
* This is a cost model to calculate the runCost and startCost of each operator
*/
class CostModelV2 extends PlanVisitor<Cost, PlanContext> {
static double HASH_COST = 1.0;
static double PROBE_COST = 1.2;
static double CMP_COST = 1.5;
static double PUSH_DOWN_AGG_COST = 0.1;
public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) {
Preconditions.checkArgument(childCost instanceof CostV2 && planCost instanceof CostV2);
CostV2 planCostV2 = (CostV2) planCost;
CostV2 childCostV2 = (CostV2) childCost;
if (plan instanceof PhysicalLimit) {
planCostV2 = new CostV2(childCostV2.getStartCost(), childCostV2.getRunCost() * planCostV2.getLimitRation(),
childCostV2.getMemory());
} else if (plan instanceof AbstractPhysicalJoin) {
if (index == 0) {
planCostV2.updateChildCost(childCostV2.getStartCost(), childCostV2.getRunCost(),
childCostV2.getMemory());
} else {
planCostV2.updateChildCost(childCostV2.getRunCost(), 0, childCostV2.getMemory());
}
} else {
planCostV2.updateChildCost(childCostV2.getStartCost(), childCostV2.getRunCost(), childCostV2.getMemory());
}
if (index == plan.arity() - 1) {
planCostV2.finish();
}
return planCostV2;
}
@Override
public Cost visit(Plan plan, PlanContext context) {
return CostV2.zero();
}
@Override
public Cost visitPhysicalOlapScan(PhysicalOlapScan physicalOlapScan, PlanContext context) {
return calculateScanWithoutRF(context.getStatisticsWithCheck());
}
public Cost visitPhysicalSchemaScan(PhysicalSchemaScan physicalSchemaScan, PlanContext context) {
return calculateScanWithoutRF(context.getStatisticsWithCheck());
}
@Override
public Cost visitPhysicalStorageLayerAggregate(PhysicalStorageLayerAggregate storageLayerAggregate,
PlanContext context) {
Statistics stats = context.getStatisticsWithCheck();
double ioCost = stats.computeSize();
double runCost1 = CostWeight.get().weightSum(0, ioCost, 0) / stats.getBENumber();
// Note the stats of this operator is the stats of relation.
// We need add a plenty for this cost. Maybe changing rowCount of storageLayer is better
double startCost = runCost1 / 2;
double totalCost = startCost;
double runCost = totalCost - startCost;
return new CostV2(startCost, runCost, 0);
}
@Override
public Cost visitPhysicalFileScan(PhysicalFileScan physicalFileScan, PlanContext context) {
return calculateScanWithoutRF(context.getStatisticsWithCheck());
}
@Override
public Cost visitPhysicalProject(PhysicalProject<? extends Plan> physicalProject, PlanContext context) {
Statistics statistics = context.getStatisticsWithCheck();
double cpuCost = statistics.getRowCount() * ExprCostModel.calculateExprCost(physicalProject.getProjects());
double startCost = 0;
double runCost = CostWeight.get().weightSum(cpuCost, 0, 0) / statistics.getBENumber();
return new CostV2(startCost, runCost, 0);
}
@Override
public Cost visitPhysicalJdbcScan(PhysicalJdbcScan physicalJdbcScan, PlanContext context) {
return calculateScanWithoutRF(context.getStatisticsWithCheck());
}
@Override
public Cost visitPhysicalEsScan(PhysicalEsScan physicalEsScan, PlanContext context) {
return calculateScanWithoutRF(context.getStatisticsWithCheck());
}
@Override
public Cost visitAbstractPhysicalSort(AbstractPhysicalSort<? extends Plan> sort, PlanContext context) {
Statistics statistics = context.getStatisticsWithCheck();
Statistics childStatistics = context.getChildStatistics(0);
double runCost;
if (sort.getSortPhase().isMerge()) {
runCost = statistics.getRowCount() * CMP_COST * Math.log(childStatistics.getBENumber());
} else {
runCost = childStatistics.getRowCount() * CMP_COST * Math.log(statistics.getRowCount())
/ statistics.getBENumber();
}
double startCost = runCost;
return new CostV2(startCost, runCost, statistics.computeSize());
}
@Override
public Cost visitPhysicalDistribute(PhysicalDistribute<? extends Plan> distribute, PlanContext context) {
Statistics childStatistics = context.getChildStatistics(0);
double size = childStatistics.computeSize();
DistributionSpec spec = distribute.getDistributionSpec();
double netCost;
if (spec instanceof DistributionSpecReplicated) {
netCost = getNetCost(size * childStatistics.getBENumber());
} else {
netCost = getNetCost(size);
}
double startCost = 0;
double runCost = CostWeight.get().weightSum(0, 0, netCost) / childStatistics.getBENumber();
return new CostV2(startCost, runCost, 0);
}
@Override
public Cost visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan> aggregate, PlanContext context) {
Statistics stats = context.getStatisticsWithCheck();
Statistics childStats = context.getChildStatistics(0);
double exprCost = ExprCostModel.calculateExprCost(aggregate.getExpressions());
return calculateAggregate(stats, childStats, exprCost);
}
@Override
public Cost visitPhysicalHashJoin(PhysicalHashJoin<? extends Plan, ? extends Plan> physicalHashJoin,
PlanContext context) {
Statistics stats = context.getStatisticsWithCheck();
Statistics leftStats = context.getChildStatistics(0);
Statistics rightStats = context.getChildStatistics(1);
double otherExprCost = ExprCostModel.calculateExprCost(physicalHashJoin.getOtherJoinConjuncts());
double buildTableCost = rightStats.getRowCount() * HASH_COST;
if (context.isBroadcastJoin()) {
buildTableCost *= stats.getBENumber();
}
double probeCost = leftStats.getRowCount() * PROBE_COST + stats.getRowCount() * otherExprCost;
double startCost = CostWeight.get().weightSum(buildTableCost, 0, 0);
double runCost = CostWeight.get().weightSum(probeCost, 0, 0) / stats.getBENumber();
return new CostV2(startCost, runCost, rightStats.computeSize());
}
@Override
public Cost visitPhysicalNestedLoopJoin(PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> nestedLoopJoin,
PlanContext context) {
Statistics stats = context.getStatisticsWithCheck();
Statistics leftStats = context.getChildStatistics(0);
Statistics rightStats = context.getChildStatistics(1);
double otherExprCost = ExprCostModel.calculateExprCost(nestedLoopJoin.getOtherJoinConjuncts());
//NSL materialized right child
double probeCost = leftStats.getRowCount() * rightStats.getRowCount() * otherExprCost;
if (!context.isBroadcastJoin()) {
probeCost /= stats.getBENumber();
}
double startCost = 0;
double runCost = CostWeight.get().weightSum(probeCost, 0, 0) / stats.getBENumber();
return new CostV2(startCost, runCost, rightStats.computeSize());
}
@Override
public Cost visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows, PlanContext context) {
return new CostV2(0, 0, 0);
}
@Override
public Cost visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, PlanContext context) {
CostV2 cost = new CostV2(0, 0, 0);
long rows = limit.getLimit() + limit.getOffset();
cost.setLimitRation(rows / context.getChildStatistics(0).getRowCount());
return cost;
}
@Override
public Cost visitPhysicalGenerate(PhysicalGenerate<? extends Plan> generate, PlanContext context) {
Statistics statistics = context.getStatisticsWithCheck();
double exprCost = ExprCostModel.calculateExprCost(generate.getGenerators());
double cpuCost = exprCost * statistics.getRowCount();
double startCost = 0;
double runCost = CostWeight.get().weightSum(cpuCost, 0, 0) / statistics.getBENumber();
return new CostV2(startCost, runCost, 0);
}
@Override
public Cost visitPhysicalRepeat(PhysicalRepeat<? extends Plan> repeat, PlanContext context) {
//Repeat expand the tuple according the groupSet
return new CostV2(0, 0, 0);
}
@Override
public Cost visitPhysicalWindow(PhysicalWindow<? extends Plan> window, PlanContext context) {
Statistics stats = context.getStatisticsWithCheck();
double exprCost = ExprCostModel.calculateExprCost(window.getWindowExpressions());
double cpuCost = stats.getRowCount() * exprCost;
double startCost = 0;
double runCost = CostWeight.get().weightSum(cpuCost, 0, 0) / stats.getBENumber();
return new CostV2(startCost, runCost, 0);
}
@Override
public Cost visitPhysicalUnion(PhysicalUnion union, PlanContext context) {
//Union all operation just concat all tuples
return new CostV2(0, 0, 0);
}
@Override
public Cost visitPhysicalSetOperation(PhysicalSetOperation intersect, PlanContext context) {
int rowCount = 0;
double size = 0;
for (Statistics childStats : context.getChildrenStatistics()) {
rowCount += childStats.getRowCount();
size += childStats.computeSize();
}
double startCost = CostWeight.get().weightSum(rowCount * HASH_COST, 0, 0);
double runCost = 0;
return new CostV2(startCost, runCost, size);
}
@Override
public Cost visitPhysicalFilter(PhysicalFilter physicalFilter, PlanContext context) {
Statistics stats = context.getStatisticsWithCheck();
double exprCost = ExprCostModel.calculateExprCost(physicalFilter.getExpressions());
double cpuCost = exprCost * stats.getRowCount();
double startCost = 0;
double runCost = CostWeight.get().weightSum(cpuCost, 0, 0) / stats.getBENumber();
return new CostV2(startCost, runCost, 0);
}
private CostV2 calculateScanWithoutRF(Statistics stats) {
//TODO: consider runtimeFilter
double io = stats.computeSize();
double startCost = 0;
double runCost = CostWeight.get().weightSum(0, io, 0) / stats.getBENumber();
return new CostV2(startCost, runCost, 0);
}
private CostV2 calculateAggregate(Statistics stats, Statistics childStats, double exprCost) {
// Build HashTable
double startCost = CostWeight.get()
.weightSum(HASH_COST * childStats.getRowCount() + exprCost * childStats.getRowCount(), 0, 0);
double runCost = 0;
return new CostV2(startCost, runCost, stats.computeSize());
}
private double getNetCost(double size) {
// we assume the transferRate is 4MB/s.
// TODO: setting in session variable
int transferRate = 4096 * 1024;
return Math.ceil(size / transferRate);
}
}

View File

@ -17,8 +17,6 @@
package org.apache.doris.nereids.cost;
import com.google.common.base.Preconditions;
class CostV1 implements Cost {
private static final CostV1 INFINITE = new CostV1(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY,
Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY);
@ -60,11 +58,6 @@ class CostV1 implements Cost {
this.penalty = 0;
}
@Override
public Cost minus(Cost other) {
return new CostV1(cost - other.getValue());
}
public static CostV1 infinite() {
return INFINITE;
}
@ -116,11 +109,5 @@ class CostV1 implements Cost {
.append("/").append((long) penalty);
return sb.toString();
}
@Override
public int compare(Cost other) {
Preconditions.checkArgument(other instanceof CostV1, "costValueV1 can only compare with costValueV1");
return Double.compare(cost, ((CostV1) other).cost);
}
}

View File

@ -0,0 +1,135 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.cost;
import org.apache.doris.qe.ConnectContext;
/**
 * This cost model calculates the distributed cost by dividing it into two components: startCost and runCost.
 * The startCost represents the cost of an operator starting to emit the first tuple, while the runCost represents
 * the cost of the operator emitting all tuples.
 * <p>
 * If all operators run in parallel, the child and parent operators can be represented as follows:
 * childStart ---> childRun
 *            |---> operatorStart ---> operatorRun
 * <p>
 * If all operators run serially, the order would be:
 * childStart ---> childRun ---> operatorStart ---> operatorRun
 * <p>
 * The degree of parallelism is controlled by the decay parameter, with a value of 1 indicating perfect serial execution
 * and a value of 0 indicating perfect parallel execution.
 */
class CostV2 implements Cost {
    // Peak memory in bytes attributed to this subtree; used as a penalty in getValue().
    double memory;
    // Cost from first tuple to last tuple (this operator plus merged children).
    double runCost;
    // Cost until the first tuple is emitted.
    double startCost;
    // Max start cost seen over the children merged so far.
    double childStartCost;
    // Max run cost seen over the children merged so far.
    double childRunCost;
    // Combined total; kept consistent by the constructor and updateChildCost().
    double cost;
    // NOTE(review): appears unused in this class -- confirm before removing.
    double leftStartCost = 0;
    // Fraction (0, 1] of the child's output a limit actually consumes.
    double limitRatio = 1;

    /**
     * Constructor of CostV2.
     *
     * @param startCost cost until the first tuple is emitted
     * @param runCost cost from the first tuple to the last
     * @param memory estimated memory in bytes
     */
    CostV2(double startCost, double runCost, double memory) {
        this.memory = memory;
        this.runCost = makeValidDouble(runCost);
        this.startCost = makeValidDouble(startCost);
        this.childRunCost = 0;
        this.childStartCost = 0;
        this.cost = this.startCost + this.runCost;
    }

    // "Ration" is a historical misspelling of "ratio"; kept for caller compatibility.
    public void setLimitRation(double ratio) {
        // A limit can never consume more than the whole child output.
        this.limitRatio = Double.min(1, ratio);
    }

    public double getLimitRation() {
        return limitRatio;
    }

    /**
     * Merge one child's cost into this operator's cost.
     * Children are assumed to run in parallel with each other, so only the maximum
     * child start/run cost is kept; the serial/parallel mix between child and parent
     * is controlled by CostWeight.getDelay() (1 = serial, 0 = fully parallel).
     */
    public void updateChildCost(double childStartCost, double childRunCost, double memory) {
        childStartCost = makeValidDouble(childStartCost);
        childRunCost = makeValidDouble(childRunCost);
        this.childStartCost = Double.max(childStartCost, this.childStartCost);
        this.childRunCost = Double.max(childRunCost, this.childRunCost);
        this.cost = startCost + this.childStartCost + Double.max(
                this.childRunCost + this.runCost * CostWeight.getDelay(), this.runCost);
        this.memory += memory;
    }

    @Override
    public String toString() {
        // Format: runCost/startCost/cost
        StringBuilder sb = new StringBuilder();
        sb.append(runCost)
                .append("/").append(startCost)
                .append("/").append(cost);
        return sb.toString();
    }

    /**
     * Clamp NaN to 0 and infinity to a large finite value so that cost
     * arithmetic never propagates NaN/Infinity through the memo.
     */
    private double makeValidDouble(double value) {
        if (Double.isNaN(value)) {
            return 0;
        }
        if (Double.isInfinite(value)) {
            return Double.MAX_VALUE / 1000;
        }
        return value;
    }

    @Override
    public double getValue() {
        // Penalize plans whose memory footprint exceeds the session's exec-memory limit.
        double maxExecMemByte = ConnectContext.get().getSessionVariable().getMaxExecMemByte();
        if (memory > maxExecMemByte) {
            cost *= Math.ceil(memory / maxExecMemByte);
        }
        return cost;
    }

    /**
     * Seal the cost after the last child has been merged: fold the children's start
     * cost into this operator's start cost and derive the remaining run cost.
     */
    public void finish() {
        startCost = startCost + childStartCost;
        runCost = cost - startCost;
    }

    public double getRunCost() {
        return runCost;
    }

    public double getStartCost() {
        return startCost;
    }

    public double getCost() {
        return cost;
    }

    public double getMemory() {
        return memory;
    }

    public static Cost zero() {
        return new CostV2(0, 0, 0);
    }

    public static Cost infinite() {
        return new CostV2(0, Double.MAX_VALUE, Double.MAX_VALUE);
    }
}

View File

@ -37,9 +37,11 @@ public class CostWeight {
static final double CPU_WEIGHT = 1;
static final double MEMORY_WEIGHT = 1;
static final double NETWORK_WEIGHT = 1.5;
static final double DELAY = 0.5;
final double cpuWeight;
final double memoryWeight;
final double networkWeight;
final double ioWeight;
/*
* About PENALTY:
* Except stats information, there are some special criteria in doris.
@ -62,10 +64,20 @@ public class CostWeight {
this.memoryWeight = memoryWeight;
this.networkWeight = networkWeight;
this.penaltyWeight = penaltyWeight;
this.ioWeight = 1;
}
public static CostWeight get() {
return new CostWeight(CPU_WEIGHT, MEMORY_WEIGHT, NETWORK_WEIGHT,
ConnectContext.get().getSessionVariable().getNereidsCboPenaltyFactor());
}
//TODO: add it in session variable
public static double getDelay() {
return DELAY;
}
public double weightSum(double cpuCost, double ioCost, double netCost) {
return cpuCost * cpuWeight + ioCost * ioWeight + netCost * networkWeight;
}
}

View File

@ -0,0 +1,73 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.cost;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import java.util.List;
/**
 * Assigns a relative evaluation cost to expressions.
 * Every expression currently costs 1.0, except bare slots, slot references and
 * literals, which are considered free; an alias costs whatever its children cost.
 */
public class ExprCostModel extends ExpressionVisitor<Double, Void> {
    /** Sum the cost of every expression in the list. */
    public static double calculateExprCost(List<? extends Expression> expressionList) {
        ExprCostModel model = new ExprCostModel();
        double total = 0.0;
        for (Expression expression : expressionList) {
            total += expression.accept(model, null);
        }
        return total;
    }

    /** Cost of a single expression tree. */
    public static double calculateExprCost(Expression expression) {
        return expression.accept(new ExprCostModel(), null);
    }

    @Override
    public Double visit(Expression expr, Void context) {
        // Default: any non-trivial expression costs one unit.
        return 1.0;
    }

    @Override
    public Double visitAlias(Alias alias, Void context) {
        // An alias adds no cost of its own; charge only for its children.
        double childrenCost = 0.0;
        for (Expression child : alias.children()) {
            childrenCost += child.accept(this, context);
        }
        return childrenCost;
    }

    @Override
    public Double visitSlot(Slot slot, Void context) {
        // Reading a slot is free.
        return 0.0;
    }

    @Override
    public Double visitSlotReference(SlotReference slotReference, Void context) {
        // Referencing a slot is free.
        return 0.0;
    }

    @Override
    public Double visitLiteral(Literal literal, Void context) {
        // Constants are free.
        return 0.0;
    }
}

View File

@ -204,7 +204,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
if (physicalPlan instanceof PhysicalDistribute) {
PhysicalDistribute distribute = (PhysicalDistribute) physicalPlan;
DataPartition dataPartition;
if (distribute.getPhysicalProperties().equals(PhysicalProperties.GATHER)) {
if (distribute.getDistributionSpec().equals(PhysicalProperties.GATHER.getDistributionSpec())) {
dataPartition = DataPartition.UNPARTITIONED;
} else {
throw new AnalysisException("Unsupported PhysicalDistribute in the root plan: " + distribute);

View File

@ -138,7 +138,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
= outputChildrenPropertiesList.get(requestPropertiesIndex);
// Calculate cost
if (curChildIndex == 0 && prevChildIndex == -1) {
curNodeCost = CostCalculator.calculateCost(groupExpression);
curNodeCost = CostCalculator.calculateCost(groupExpression, requestChildrenProperties);
groupExpression.setCost(curNodeCost.getValue());
curTotalCost = curNodeCost;
}
@ -241,7 +241,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
StatsCalculator.estimate(groupExpression);
// recompute cost after adjusting property
curNodeCost = CostCalculator.calculateCost(groupExpression); // recompute current node's cost in current context
curNodeCost = CostCalculator.calculateCost(groupExpression, requestChildrenProperties);
groupExpression.setCost(curNodeCost.getValue());
curTotalCost = curNodeCost;
for (int i = 0; i < outputChildrenProperties.size(); i++) {

View File

@ -784,7 +784,7 @@ public class Memo {
}
List<Pair<Long, List<Integer>>> childrenId = new ArrayList<>();
permute(children, 0, childrenId, new ArrayList<>());
Cost cost = CostCalculator.calculateCost(groupExpression);
Cost cost = CostCalculator.calculateCost(groupExpression, inputProperties);
for (Pair<Long, List<Integer>> c : childrenId) {
Cost totalCost = cost;
for (int i = 0; i < children.size(); i++) {

View File

@ -194,7 +194,7 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Double, Void> {
jobContext.getCascadesContext().getMemo().addEnforcerPlan(enforcer, child.getOwnerGroup());
Cost totalCost = CostCalculator.addChildCost(enforcer.getPlan(),
currentCost,
CostCalculator.calculateCost(enforcer),
CostCalculator.calculateCost(enforcer, Lists.newArrayList(childOutput)),
0);
if (enforcer.updateLowestCostTable(newOutputProperty,

View File

@ -152,7 +152,7 @@ public class EnforceMissingPropertiesHelper {
ENFORCER_TRACER.log(EnforcerEvent.of(groupExpression, ((PhysicalPlan) enforcer.getPlan()),
oldOutputProperty, newOutputProperty));
curTotalCost = CostCalculator.addChildCost(enforcer.getPlan(),
CostCalculator.calculateCost(enforcer),
CostCalculator.calculateCost(enforcer, Lists.newArrayList(oldOutputProperty)),
curTotalCost,
0);
if (enforcer.updateLowestCostTable(newOutputProperty,

View File

@ -153,7 +153,7 @@ public abstract class ExpressionVisitor<R, C>
return visit(binaryOperator, context);
}
public R visitUinaryOperator(UnaryOperator unaryOperator, C context) {
public R visitUnaryOperator(UnaryOperator unaryOperator, C context) {
return visit(unaryOperator, context);
}
@ -306,7 +306,7 @@ public abstract class ExpressionVisitor<R, C>
}
public R visitUnaryArithmetic(UnaryArithmetic unaryArithmetic, C context) {
return visitUinaryOperator(unaryArithmetic, context);
return visitUnaryOperator(unaryArithmetic, context);
}
public R visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, C context) {

View File

@ -85,7 +85,8 @@ public class PhysicalDistribute<CHILD_TYPE extends Plan> extends PhysicalUnary<C
public PhysicalDistribute<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new PhysicalDistribute<>(distributionSpec, Optional.empty(),
getLogicalProperties(), children.get(0));
getLogicalProperties(), physicalProperties, statistics, children.get(0));
}
@Override

View File

@ -123,7 +123,8 @@ public class PhysicalHashJoin<
public PhysicalHashJoin<Plan, Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
getLogicalProperties(), children.get(0), children.get(1));
Optional.empty(), getLogicalProperties(), physicalProperties, statistics,
children.get(0), children.get(1));
}
@Override

View File

@ -127,8 +127,8 @@ public class PhysicalNestedLoopJoin<
public PhysicalNestedLoopJoin<Plan, Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
return new PhysicalNestedLoopJoin<>(joinType,
hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference,
getLogicalProperties(), children.get(0), children.get(1));
hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference, Optional.empty(),
getLogicalProperties(), physicalProperties, statistics, children.get(0), children.get(1));
}
@Override

View File

@ -201,7 +201,7 @@ public class SessionVariable implements Serializable, Writable {
public static final String ENABLE_NEREIDS_PLANNER = "enable_nereids_planner";
public static final String DISABLE_NEREIDS_RULES = "disable_nereids_rules";
public static final String ENABLE_NEW_COST_MODEL = "enable_new_cost_model";
public static final String ENABLE_FALLBACK_TO_ORIGINAL_PLANNER = "enable_fallback_to_original_planner";
public static final String ENABLE_NEREIDS_RUNTIME_FILTER = "enable_nereids_runtime_filter";
@ -305,6 +305,8 @@ public class SessionVariable implements Serializable, Writable {
// if it is setStmt, we needn't collect session origin value
public boolean isSingleSetVar = false;
@VariableMgr.VarAttr(name = INSERT_VISIBLE_TIMEOUT_MS, needForward = true)
public long insertVisibleTimeoutMs = DEFAULT_INSERT_VISIBLE_TIMEOUT_MS;
@ -616,6 +618,9 @@ public class SessionVariable implements Serializable, Writable {
@VariableMgr.VarAttr(name = DISABLE_NEREIDS_RULES)
private String disableNereidsRules = "";
@VariableMgr.VarAttr(name = ENABLE_NEW_COST_MODEL)
private boolean enableNewCostModel = true;
@VariableMgr.VarAttr(name = NEREIDS_STAR_SCHEMA_SUPPORT)
private boolean nereidsStarSchemaSupport = true;
@ -778,7 +783,6 @@ public class SessionVariable implements Serializable, Writable {
@VariableMgr.VarAttr(name = DRY_RUN_QUERY, needForward = true)
public boolean dryRunQuery = false;
// If this fe is in fuzzy mode, then will use initFuzzyModeVariables to generate some variables,
// not the default value set in the code.
public void initFuzzyModeVariables() {
@ -1496,6 +1500,14 @@ public class SessionVariable implements Serializable, Writable {
.collect(ImmutableSet.toImmutableSet());
}
public void setEnableNewCostModel(boolean enable) {
this.enableNewCostModel = enable;
}
public boolean getEnableNewCostModel() {
return this.enableNewCostModel;
}
public void setDisableNereidsRules(String disableNereidsRules) {
this.disableNereidsRules = disableNereidsRules;
}

View File

@ -103,7 +103,7 @@ public class Statistics {
}
public double computeSize() {
if (computeSize < 0) {
if (computeSize <= 0) {
computeSize = Math.max(1, expressionToColumnStats.values().stream()
.map(s -> s.dataSize).reduce(0D, Double::sum)
) * rowCount;
@ -131,4 +131,8 @@ public class Statistics {
public double getPenalty() {
return penalty;
}
public int getBENumber() {
return 1;
}
}

View File

@ -204,4 +204,8 @@ public class StatsDeriveResult {
public void setPenalty(double penalty) {
this.penalty = penalty;
}
public int getBENumber() {
return 1;
}
}

View File

@ -84,8 +84,8 @@ public class JoinTest extends SqlTestBase {
@Test
void testBucketJoinWithAgg() {
String sql = "select * from "
+ "(select count(id) as cnt from T2 group by id) T1 inner join"
+ "(select count(id) as cnt from T2 group by id) T2 "
+ "(select distinct id as cnt from T2) T1 inner join"
+ "(select distinct id as cnt from T2) T2 "
+ "on T1.cnt = T2.cnt";
PhysicalPlan plan = PlanChecker.from(connectContext)
.analyze(sql)

View File

@ -223,12 +223,12 @@ suite("join") {
logger.info(explainStr)
assertTrue(
//if analyze finished
explainStr.contains("4:VAGGREGATE (update serialize)") && explainStr.contains("6:VAGGREGATE (merge finalize)")
&& explainStr.contains("wtid[#8] = CAST(wtid[#3] AS CHARACTER)") && explainStr.contains("projections: wtid[#5], wfid[#6]")
explainStr.contains("VAGGREGATE (update serialize)") && explainStr.contains("VAGGREGATE (merge finalize)")
&& explainStr.contains("wtid[#8] = wtid[#3]") && explainStr.contains("projections: wtid[#5], wfid[#6]")
||
//analyze not finished
explainStr.contains("4:VAGGREGATE (update serialize)") && explainStr.contains("8:VAGGREGATE (update finalize)")
&& explainStr.contains("7:VEXCHANGE") && explainStr.contains("3:VHASH JOIN")
explainStr.contains("VAGGREGATE (update finalize)") && explainStr.contains("VAGGREGATE (update finalize)")
&& explainStr.contains("VEXCHANGE") && explainStr.contains("VHASH JOIN")
)
test {