[Performance](Nereids): pass ConnectContext to avoid ThreadLocal.get() (#26165)

This commit is contained in:
jakevin
2023-11-01 16:16:42 +08:00
committed by GitHub
parent 1770224322
commit 683832230c
24 changed files with 274 additions and 233 deletions

View File

@ -199,7 +199,7 @@ public class CascadesContext implements ScheduleContext {
}
public void toMemo() {
this.memo = new Memo(plan);
this.memo = new Memo(getConnectContext(), plan);
}
public Analyzer newAnalyzer() {
@ -358,7 +358,7 @@ public class CascadesContext implements ScheduleContext {
return table;
}
}
if (ConnectContext.get().getSessionVariable().isPlayNereidsDump()) {
if (getConnectContext().getSessionVariable().isPlayNereidsDump()) {
throw new AnalysisException("Minidump cache can not find table:" + tableName);
}
return null;

View File

@ -19,6 +19,8 @@ package org.apache.doris.nereids;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.Statistics;
import java.util.ArrayList;
@ -31,7 +33,7 @@ import java.util.List;
* Inspired by GPORCA-CExpressionHandle.
*/
public class PlanContext {
private final ConnectContext connectContext;
private final List<Statistics> childrenStats;
private final Statistics planStats;
private final int arity;
@ -41,7 +43,8 @@ public class PlanContext {
/**
* Constructor for PlanContext.
*/
public PlanContext(GroupExpression groupExpression) {
public PlanContext(ConnectContext connectContext, GroupExpression groupExpression) {
this.connectContext = connectContext;
this.arity = groupExpression.arity();
this.planStats = groupExpression.getOwnerGroup().getStatistics();
this.isStatsReliable = groupExpression.getOwnerGroup().isStatsReliable();
@ -51,12 +54,8 @@ public class PlanContext {
}
}
// This is used in GraphSimplifier
public PlanContext(Statistics planStats, List<Statistics> childrenStats) {
this.planStats = planStats;
this.childrenStats = childrenStats;
this.isStatsReliable = false;
this.arity = this.childrenStats.size();
public SessionVariable getSessionVariable() {
return connectContext.getSessionVariable();
}
public void setBroadcastJoin() {

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.cost;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
/**
* Cost encapsulate the real cost with double type.
@ -27,21 +27,22 @@ import org.apache.doris.qe.ConnectContext;
public interface Cost {
double getValue();
/**
* return zero cost
*/
static Cost zero() {
if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) {
static Cost zero(SessionVariable sessionVariable) {
if (sessionVariable.getEnableNewCostModel()) {
return CostV2.zero();
}
return CostV1.zero();
}
static Cost infinite() {
if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) {
static Cost infinite(SessionVariable sessionVariable) {
if (sessionVariable.getEnableNewCostModel()) {
return CostV2.infinite();
}
return CostV1.infinite();
}
static Cost zeroV1() {
return CostV1.zero();
}
}

View File

@ -18,53 +18,42 @@
package org.apache.doris.nereids.cost;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import java.util.List;
/**
* Calculate the cost of a plan.
*/
@Developing
//TODO: memory cost and network cost should be estimated by byte size.
// TODO: memory cost and network cost should be estimated by byte size.
public class CostCalculator {
/**
* Calculate cost for groupExpression
*/
public static Cost calculateCost(GroupExpression groupExpression, List<PhysicalProperties> childrenProperties) {
PlanContext planContext = new PlanContext(groupExpression);
public static Cost calculateCost(ConnectContext connectContext, GroupExpression groupExpression,
List<PhysicalProperties> childrenProperties) {
PlanContext planContext = new PlanContext(connectContext, groupExpression);
if (childrenProperties.size() >= 2
&& childrenProperties.get(1).getDistributionSpec() instanceof DistributionSpecReplicated) {
planContext.setBroadcastJoin();
}
CostModelV1 costModelV1 = new CostModelV1();
CostModelV1 costModelV1 = new CostModelV1(connectContext);
return groupExpression.getPlan().accept(costModelV1, planContext);
}
/**
* Calculate cost without groupExpression
*/
public static Cost calculateCost(Plan plan, PlanContext planContext) {
if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) {
CostModelV2 costModel = new CostModelV2();
return plan.accept(costModel, planContext);
} else {
CostModelV1 costModel = new CostModelV1();
return plan.accept(costModel, planContext);
}
}
public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) {
if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) {
public static Cost addChildCost(ConnectContext connectContext, Plan plan, Cost planCost, Cost childCost,
int index) {
SessionVariable sessionVariable = connectContext.getSessionVariable();
if (sessionVariable.getEnableNewCostModel()) {
return CostModelV2.addChildCost(plan, planCost, childCost, index);
}
return CostModelV1.addChildCost(plan, planCost, childCost, index);
return CostModelV1.addChildCost(sessionVariable, planCost, childCost);
}
}

View File

@ -43,23 +43,12 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggrega
import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Preconditions;
class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
/**
* The intuition behind `HEAVY_OPERATOR_PUNISH_FACTOR` is we need to avoid this form of join patterns:
* Plan1: L join ( AGG1(A) join AGG2(B))
* But
* Plan2: L join AGG1(A) join AGG2(B) is welcomed.
* AGG is time-consuming operator. From the perspective of rowCount, nereids may choose Plan1,
* because `Agg1 join Agg2` generates few tuples. But in Plan1, Agg1 and Agg2 are done in serial, in Plan2, Agg1 and
* Agg2 are done in parallel. And hence, Plan1 should be punished.
* <p>
* An example is tpch q15.
*/
static final double HEAVY_OPERATOR_PUNISH_FACTOR = 0.0;
// for a join, skew = leftRowCount/rightRowCount
// the higher skew is, the more we prefer broadcast join than shuffle join
@ -69,22 +58,24 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
static final double BROADCAST_JOIN_SKEW_PENALTY_LIMIT = 2.0;
private final int beNumber;
public CostModelV1() {
if (ConnectContext.get().getSessionVariable().isPlayNereidsDump()) {
public CostModelV1(ConnectContext connectContext) {
SessionVariable sessionVariable = connectContext.getSessionVariable();
if (sessionVariable.isPlayNereidsDump()) {
// TODO: @bingfeng refine minidump setting, and pass testMinidumpUt
beNumber = 1;
} else if (ConnectContext.get().getSessionVariable().getBeNumberForTest() != -1) {
beNumber = ConnectContext.get().getSessionVariable().getBeNumberForTest();
} else if (sessionVariable.getBeNumberForTest() != -1) {
beNumber = sessionVariable.getBeNumberForTest();
} else {
beNumber = Math.max(1, ConnectContext.get().getEnv().getClusterInfo().getBackendsNumber(true));
}
}
public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) {
public static Cost addChildCost(SessionVariable sessionVariable, Cost planCost, Cost childCost) {
Preconditions.checkArgument(childCost instanceof CostV1 && planCost instanceof CostV1);
CostV1 childCostV1 = (CostV1) childCost;
CostV1 planCostV1 = (CostV1) planCost;
return new CostV1(childCostV1.getCpuCost() + planCostV1.getCpuCost(),
return new CostV1(sessionVariable,
childCostV1.getCpuCost() + planCostV1.getCpuCost(),
childCostV1.getMemoryCost() + planCostV1.getMemoryCost(),
childCostV1.getNetworkCost() + planCostV1.getNetworkCost());
}
@ -97,7 +88,7 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
@Override
public Cost visitPhysicalOlapScan(PhysicalOlapScan physicalOlapScan, PlanContext context) {
Statistics statistics = context.getStatisticsWithCheck();
return CostV1.ofCpu(statistics.getRowCount());
return CostV1.ofCpu(context.getSessionVariable(), statistics.getRowCount());
}
@Override
@ -108,7 +99,7 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
public Cost visitPhysicalSchemaScan(PhysicalSchemaScan physicalSchemaScan, PlanContext context) {
Statistics statistics = context.getStatisticsWithCheck();
return CostV1.ofCpu(statistics.getRowCount());
return CostV1.ofCpu(context.getSessionVariable(), statistics.getRowCount());
}
@Override
@ -116,31 +107,31 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
PhysicalStorageLayerAggregate storageLayerAggregate, PlanContext context) {
CostV1 costValue = (CostV1) storageLayerAggregate.getRelation().accept(this, context);
// multiply a factor less than 1, so we can select PhysicalStorageLayerAggregate as far as possible
return new CostV1(costValue.getCpuCost() * 0.7, costValue.getMemoryCost(),
return new CostV1(context.getSessionVariable(), costValue.getCpuCost() * 0.7, costValue.getMemoryCost(),
costValue.getNetworkCost());
}
@Override
public Cost visitPhysicalFileScan(PhysicalFileScan physicalFileScan, PlanContext context) {
Statistics statistics = context.getStatisticsWithCheck();
return CostV1.ofCpu(statistics.getRowCount());
return CostV1.ofCpu(context.getSessionVariable(), statistics.getRowCount());
}
@Override
public Cost visitPhysicalProject(PhysicalProject<? extends Plan> physicalProject, PlanContext context) {
return CostV1.ofCpu(1);
return CostV1.ofCpu(context.getSessionVariable(), 1);
}
@Override
public Cost visitPhysicalJdbcScan(PhysicalJdbcScan physicalJdbcScan, PlanContext context) {
Statistics statistics = context.getStatisticsWithCheck();
return CostV1.ofCpu(statistics.getRowCount());
return CostV1.ofCpu(context.getSessionVariable(), statistics.getRowCount());
}
@Override
public Cost visitPhysicalEsScan(PhysicalEsScan physicalEsScan, PlanContext context) {
Statistics statistics = context.getStatisticsWithCheck();
return CostV1.ofCpu(statistics.getRowCount());
return CostV1.ofCpu(context.getSessionVariable(), statistics.getRowCount());
}
@Override
@ -156,7 +147,7 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
// Now we do more like two-phase sort, so penalise one-phase sort
rowCount *= 100;
}
return CostV1.of(childRowCount, rowCount, childRowCount);
return CostV1.of(context.getSessionVariable(), childRowCount, rowCount, childRowCount);
}
@Override
@ -171,7 +162,7 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
// Now we do more like two-phase sort, so penalise one-phase sort
rowCount *= 100;
}
return CostV1.of(childRowCount, rowCount, childRowCount);
return CostV1.of(context.getSessionVariable(), childRowCount, rowCount, childRowCount);
}
@Override
@ -184,10 +175,10 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
public Cost visitPhysicalPartitionTopN(PhysicalPartitionTopN<? extends Plan> partitionTopN, PlanContext context) {
Statistics statistics = context.getStatisticsWithCheck();
Statistics childStatistics = context.getChildStatistics(0);
return CostV1.of(
childStatistics.getRowCount(),
statistics.getRowCount(),
childStatistics.getRowCount());
return CostV1.of(context.getSessionVariable(),
childStatistics.getRowCount(),
statistics.getRowCount(),
childStatistics.getRowCount());
}
@Override
@ -199,7 +190,7 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
// shuffle
if (spec instanceof DistributionSpecHash) {
return CostV1.of(
return CostV1.of(context.getSessionVariable(),
0,
0,
intputRowCount * childStatistics.dataSizeFactor() / beNumber);
@ -210,7 +201,7 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
// estimate broadcast cost by an empirical formula: beNumber^0.5 * rowCount
// - sender number and receiver number are not available at RBO stage now, so we use beNumber
// - senders and receivers work in parallel, that's why we use the square root of beNumber
return CostV1.of(
return CostV1.of(context.getSessionVariable(),
0,
0,
intputRowCount * childStatistics.dataSizeFactor());
@ -219,14 +210,14 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
// gather
if (spec instanceof DistributionSpecGather) {
return CostV1.of(
return CostV1.of(context.getSessionVariable(),
0,
0,
intputRowCount * childStatistics.dataSizeFactor() / beNumber);
}
// any
return CostV1.of(
return CostV1.of(context.getSessionVariable(),
intputRowCount,
0,
0);
@ -237,11 +228,11 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
PhysicalHashAggregate<? extends Plan> aggregate, PlanContext context) {
Statistics inputStatistics = context.getChildStatistics(0);
if (aggregate.getAggPhase().isLocal()) {
return CostV1.of(inputStatistics.getRowCount() / beNumber,
return CostV1.of(context.getSessionVariable(), inputStatistics.getRowCount() / beNumber,
inputStatistics.getRowCount() / beNumber, 0);
} else {
// global
return CostV1.of(inputStatistics.getRowCount(),
return CostV1.of(context.getSessionVariable(), inputStatistics.getRowCount(),
inputStatistics.getRowCount(), 0);
}
}
@ -278,7 +269,7 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
in pattern2, join1 and join2 takes more time, but Agg1 and agg2 can be processed in parallel.
*/
if (physicalHashJoin.getJoinType().isCrossJoin()) {
return CostV1.of(leftRowCount + rightRowCount + outputRowCount,
return CostV1.of(context.getSessionVariable(), leftRowCount + rightRowCount + outputRowCount,
0,
leftRowCount + rightRowCount
);
@ -293,8 +284,8 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
// bigger cost for ProbeWhenBuildSideOutput effort and ProbeWhenSearchHashTableTime
// on the output rows, taken on outputRowCount()
double probeSideFactor = 1.0;
double buildSideFactor = ConnectContext.get().getSessionVariable().getBroadcastRightTableScaleFactor();
int parallelInstance = Math.max(1, ConnectContext.get().getSessionVariable().getParallelExecInstanceNum());
double buildSideFactor = context.getSessionVariable().getBroadcastRightTableScaleFactor();
int parallelInstance = Math.max(1, context.getSessionVariable().getParallelExecInstanceNum());
int totalInstanceNumber = parallelInstance * beNumber;
if (buildSideFactor <= 1.0) {
// use totalInstanceNumber to the power of 2 as the default factor value
@ -304,22 +295,24 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
// will refine this in next generation cost model.
if (!context.isStatsReliable()) {
// forbid broadcast join when stats is unknown
return CostV1.of(rightRowCount * buildSideFactor + 1 / leftRowCount,
return CostV1.of(context.getSessionVariable(), rightRowCount * buildSideFactor + 1 / leftRowCount,
rightRowCount,
0
);
}
return CostV1.of(leftRowCount + rightRowCount * buildSideFactor + outputRowCount * probeSideFactor,
return CostV1.of(context.getSessionVariable(),
leftRowCount + rightRowCount * buildSideFactor + outputRowCount * probeSideFactor,
rightRowCount,
0
);
}
if (!context.isStatsReliable()) {
return CostV1.of(rightRowCount + 1 / leftRowCount,
return CostV1.of(context.getSessionVariable(),
rightRowCount + 1 / leftRowCount,
rightRowCount,
0);
}
return CostV1.of(leftRowCount + rightRowCount + outputRowCount,
return CostV1.of(context.getSessionVariable(), leftRowCount + rightRowCount + outputRowCount,
rightRowCount,
0
);
@ -334,11 +327,12 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
Statistics leftStatistics = context.getChildStatistics(0);
Statistics rightStatistics = context.getChildStatistics(1);
if (!context.isStatsReliable()) {
return CostV1.of(rightStatistics.getRowCount() + 1 / leftStatistics.getRowCount(),
return CostV1.of(context.getSessionVariable(),
rightStatistics.getRowCount() + 1 / leftStatistics.getRowCount(),
rightStatistics.getRowCount(),
0);
}
return CostV1.of(
return CostV1.of(context.getSessionVariable(),
leftStatistics.getRowCount() * rightStatistics.getRowCount(),
rightStatistics.getRowCount(),
0);
@ -347,7 +341,7 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
@Override
public Cost visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows,
PlanContext context) {
return CostV1.of(
return CostV1.of(context.getSessionVariable(),
assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(),
assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(),
0
@ -357,7 +351,7 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
@Override
public Cost visitPhysicalGenerate(PhysicalGenerate<? extends Plan> generate, PlanContext context) {
Statistics statistics = context.getStatisticsWithCheck();
return CostV1.of(
return CostV1.of(context.getSessionVariable(),
statistics.getRowCount(),
statistics.getRowCount(),
0

View File

@ -44,6 +44,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggrega
import org.apache.doris.nereids.trees.plans.physical.PhysicalUnion;
import org.apache.doris.nereids.trees.plans.physical.PhysicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Preconditions;
@ -58,6 +59,12 @@ class CostModelV2 extends PlanVisitor<Cost, PlanContext> {
static double CMP_COST = 1.5;
static double PUSH_DOWN_AGG_COST = 0.1;
private final SessionVariable sessionVariable;
CostModelV2(SessionVariable sessionVariable) {
this.sessionVariable = sessionVariable;
}
public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) {
Preconditions.checkArgument(childCost instanceof CostV2 && planCost instanceof CostV2);
CostV2 planCostV2 = (CostV2) planCost;
@ -103,7 +110,7 @@ class CostModelV2 extends PlanVisitor<Cost, PlanContext> {
double ioCost = stats.computeSize();
double runCost1 = CostWeight.get().weightSum(0, ioCost, 0) / stats.getBENumber();
double runCost1 = CostWeight.get(sessionVariable).weightSum(0, ioCost, 0) / stats.getBENumber();
// Note that the stats of this operator are the stats of the relation.
// We need to add a penalty for this cost. Maybe changing rowCount of storageLayer is better
@ -125,7 +132,7 @@ class CostModelV2 extends PlanVisitor<Cost, PlanContext> {
double cpuCost = statistics.getRowCount() * ExprCostModel.calculateExprCost(physicalProject.getProjects());
double startCost = 0;
double runCost = CostWeight.get().weightSum(cpuCost, 0, 0) / statistics.getBENumber();
double runCost = CostWeight.get(sessionVariable).weightSum(cpuCost, 0, 0) / statistics.getBENumber();
return new CostV2(startCost, runCost, 0);
}
@ -185,7 +192,7 @@ class CostModelV2 extends PlanVisitor<Cost, PlanContext> {
}
double startCost = 0;
double runCost = CostWeight.get().weightSum(0, 0, netCost) / childStatistics.getBENumber();
double runCost = CostWeight.get(sessionVariable).weightSum(0, 0, netCost) / childStatistics.getBENumber();
return new CostV2(startCost, runCost, 0);
}
@ -212,8 +219,8 @@ class CostModelV2 extends PlanVisitor<Cost, PlanContext> {
}
double probeCost = leftStats.getRowCount() * PROBE_COST + stats.getRowCount() * otherExprCost;
double startCost = CostWeight.get().weightSum(buildTableCost, 0, 0);
double runCost = CostWeight.get().weightSum(probeCost, 0, 0) / stats.getBENumber();
double startCost = CostWeight.get(sessionVariable).weightSum(buildTableCost, 0, 0);
double runCost = CostWeight.get(sessionVariable).weightSum(probeCost, 0, 0) / stats.getBENumber();
return new CostV2(startCost, runCost, rightStats.computeSize());
}
@ -232,7 +239,7 @@ class CostModelV2 extends PlanVisitor<Cost, PlanContext> {
}
double startCost = 0;
double runCost = CostWeight.get().weightSum(probeCost, 0, 0) / stats.getBENumber();
double runCost = CostWeight.get(sessionVariable).weightSum(probeCost, 0, 0) / stats.getBENumber();
return new CostV2(startCost, runCost, rightStats.computeSize());
}
@ -257,7 +264,7 @@ class CostModelV2 extends PlanVisitor<Cost, PlanContext> {
double cpuCost = exprCost * statistics.getRowCount();
double startCost = 0;
double runCost = CostWeight.get().weightSum(cpuCost, 0, 0) / statistics.getBENumber();
double runCost = CostWeight.get(sessionVariable).weightSum(cpuCost, 0, 0) / statistics.getBENumber();
return new CostV2(startCost, runCost, 0);
}
@ -274,7 +281,7 @@ class CostModelV2 extends PlanVisitor<Cost, PlanContext> {
double cpuCost = stats.getRowCount() * exprCost;
double startCost = 0;
double runCost = CostWeight.get().weightSum(cpuCost, 0, 0) / stats.getBENumber();
double runCost = CostWeight.get(sessionVariable).weightSum(cpuCost, 0, 0) / stats.getBENumber();
return new CostV2(startCost, runCost, 0);
}
@ -293,7 +300,7 @@ class CostModelV2 extends PlanVisitor<Cost, PlanContext> {
size += childStats.computeSize();
}
double startCost = CostWeight.get().weightSum(rowCount * HASH_COST, 0, 0);
double startCost = CostWeight.get(sessionVariable).weightSum(rowCount * HASH_COST, 0, 0);
double runCost = 0;
return new CostV2(startCost, runCost, size);
@ -307,7 +314,7 @@ class CostModelV2 extends PlanVisitor<Cost, PlanContext> {
double cpuCost = exprCost * stats.getRowCount();
double startCost = 0;
double runCost = CostWeight.get().weightSum(cpuCost, 0, 0) / stats.getBENumber();
double runCost = CostWeight.get(sessionVariable).weightSum(cpuCost, 0, 0) / stats.getBENumber();
return new CostV2(startCost, runCost, 0);
}
@ -316,13 +323,13 @@ class CostModelV2 extends PlanVisitor<Cost, PlanContext> {
//TODO: consider runtimeFilter
double io = stats.computeSize();
double startCost = 0;
double runCost = CostWeight.get().weightSum(0, io, 0) / stats.getBENumber();
double runCost = CostWeight.get(sessionVariable).weightSum(0, io, 0) / stats.getBENumber();
return new CostV2(startCost, runCost, 0);
}
private CostV2 calculateAggregate(Statistics stats, Statistics childStats, double exprCost) {
// Build HashTable
double startCost = CostWeight.get()
double startCost = CostWeight.get(sessionVariable)
.weightSum(HASH_COST * childStats.getRowCount() + exprCost * childStats.getRowCount(), 0, 0);
double runCost = 0;
return new CostV2(startCost, runCost, stats.computeSize());

View File

@ -17,10 +17,13 @@
package org.apache.doris.nereids.cost;
import org.apache.doris.qe.SessionVariable;
class CostV1 implements Cost {
private static final CostV1 INFINITE = new CostV1(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY,
Double.POSITIVE_INFINITY,
Double.POSITIVE_INFINITY);
private static final CostV1 ZERO = new CostV1(0, 0, 0);
private static final CostV1 ZERO = new CostV1(0, 0, 0, 0);
private final double cpuCost;
private final double memoryCost;
@ -29,9 +32,9 @@ class CostV1 implements Cost {
private final double cost;
/**
* Constructor of CostEstimate.
* Constructor of CostV1.
*/
public CostV1(double cpuCost, double memoryCost, double networkCost) {
public CostV1(SessionVariable sessionVariable, double cpuCost, double memoryCost, double networkCost) {
// TODO: fix stats
cpuCost = Double.max(0, cpuCost);
memoryCost = Double.max(0, memoryCost);
@ -40,11 +43,18 @@ class CostV1 implements Cost {
this.memoryCost = memoryCost;
this.networkCost = networkCost;
CostWeight costWeight = CostWeight.get();
CostWeight costWeight = CostWeight.get(sessionVariable);
this.cost = costWeight.cpuWeight * cpuCost + costWeight.memoryWeight * memoryCost
+ costWeight.networkWeight * networkCost;
}
private CostV1(double cost, double cpuCost, double memoryCost, double networkCost) {
this.cost = cost;
this.cpuCost = cpuCost;
this.memoryCost = memoryCost;
this.networkCost = networkCost;
}
public static CostV1 infinite() {
return INFINITE;
}
@ -69,16 +79,12 @@ class CostV1 implements Cost {
return cost;
}
public static CostV1 of(double cpuCost, double maxMemory, double networkCost) {
return new CostV1(cpuCost, maxMemory, networkCost);
public static CostV1 of(SessionVariable sessionVariable, double cpuCost, double maxMemory, double networkCost) {
return new CostV1(sessionVariable, cpuCost, maxMemory, networkCost);
}
public static CostV1 ofCpu(double cpuCost) {
return new CostV1(cpuCost, 0, 0);
}
public static CostV1 ofMemory(double memoryCost) {
return new CostV1(0, memoryCost, 0);
public static CostV1 ofCpu(SessionVariable sessionVariable, double cpuCost) {
return new CostV1(sessionVariable, cpuCost, 0, 0);
}
@Override

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.cost;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import com.google.common.base.Preconditions;
@ -66,11 +67,18 @@ public class CostWeight {
}
public static CostWeight get() {
double cpuWeight = ConnectContext.get().getSessionVariable().getCboCpuWeight();
double memWeight = ConnectContext.get().getSessionVariable().getCboMemWeight();
double netWeight = ConnectContext.get().getSessionVariable().getCboNetWeight();
return new CostWeight(cpuWeight, memWeight, netWeight,
ConnectContext.get().getSessionVariable().getNereidsCboPenaltyFactor());
SessionVariable sessionVariable = ConnectContext.get().getSessionVariable();
double cpuWeight = sessionVariable.getCboCpuWeight();
double memWeight = sessionVariable.getCboMemWeight();
double netWeight = sessionVariable.getCboNetWeight();
return new CostWeight(cpuWeight, memWeight, netWeight, sessionVariable.getNereidsCboPenaltyFactor());
}
public static CostWeight get(SessionVariable sessionVariable) {
double cpuWeight = sessionVariable.getCboCpuWeight();
double memWeight = sessionVariable.getCboMemWeight();
double netWeight = sessionVariable.getCboNetWeight();
return new CostWeight(cpuWeight, memWeight, netWeight, sessionVariable.getNereidsCboPenaltyFactor());
}
//TODO: add it in session variable

View File

@ -244,11 +244,11 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
Collections.reverse(context.getPlanFragments());
// TODO: maybe we need to trans nullable directly? and then we could remove call computeMemLayout
context.getDescTable().computeMemLayout();
if (ConnectContext.get() != null && ConnectContext.get().getSessionVariable().forbidUnknownColStats) {
if (context.getSessionVariable() != null && context.getSessionVariable().forbidUnknownColStats) {
Set<ScanNode> scans = context.getScanNodeWithUnknownColumnStats();
if (!scans.isEmpty()) {
StringBuilder builder = new StringBuilder();
scans.forEach(scanNode -> builder.append(scanNode));
scans.forEach(builder::append);
throw new AnalysisException("tables with unknown column stats: " + builder);
}
}
@ -607,7 +607,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
// TODO: move all node set cardinality into one place
if (olapScan.getStats() != null) {
olapScanNode.setCardinality((long) olapScan.getStats().getRowCount());
if (ConnectContext.get().getSessionVariable().forbidUnknownColStats) {
if (context.getSessionVariable() != null && context.getSessionVariable().forbidUnknownColStats) {
for (int i = 0; i < slots.size(); i++) {
Slot slot = slots.get(i);
if (olapScan.getStats().findColumnStatistics(slot).isUnKnown()
@ -1027,7 +1027,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
updateLegacyPlanIdToPhysicalPlan(inputFragment.getPlanRoot(), filter);
}
}
//in ut, filter.stats may be null
// in ut, filter.stats may be null
if (filter.getStats() != null) {
inputFragment.getPlanRoot().setCardinalityAfterFilter((long) filter.getStats().getRowCount());
}
@ -1226,7 +1226,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
sd = context.getDescTable().copySlotDescriptor(intermediateDescriptor, leftSlotDescriptor);
} else {
sd = context.createSlotDesc(intermediateDescriptor, sf, leftSlotDescriptor.getParent().getTable());
//sd = context.createSlotDesc(intermediateDescriptor, sf);
// sd = context.createSlotDesc(intermediateDescriptor, sf);
if (hashOutputSlotReferenceMap.get(sf.getExprId()) != null) {
hashJoinNode.addSlotIdToHashOutputSlotIds(leftSlotDescriptor.getId());
hashJoinNode.getHashOutputExprSlotIdMap().put(sf.getExprId(), leftSlotDescriptor.getId());
@ -1247,7 +1247,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
sd = context.getDescTable().copySlotDescriptor(intermediateDescriptor, rightSlotDescriptor);
} else {
sd = context.createSlotDesc(intermediateDescriptor, sf, rightSlotDescriptor.getParent().getTable());
//sd = context.createSlotDesc(intermediateDescriptor, sf);
// sd = context.createSlotDesc(intermediateDescriptor, sf);
if (hashOutputSlotReferenceMap.get(sf.getExprId()) != null) {
hashJoinNode.addSlotIdToHashOutputSlotIds(rightSlotDescriptor.getId());
hashJoinNode.getHashOutputExprSlotIdMap().put(sf.getExprId(), rightSlotDescriptor.getId());
@ -1267,7 +1267,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
sd = context.getDescTable().copySlotDescriptor(intermediateDescriptor, leftSlotDescriptor);
} else {
sd = context.createSlotDesc(intermediateDescriptor, sf, leftSlotDescriptor.getParent().getTable());
//sd = context.createSlotDesc(intermediateDescriptor, sf);
// sd = context.createSlotDesc(intermediateDescriptor, sf);
if (hashOutputSlotReferenceMap.get(sf.getExprId()) != null) {
hashJoinNode.addSlotIdToHashOutputSlotIds(leftSlotDescriptor.getId());
hashJoinNode.getHashOutputExprSlotIdMap().put(sf.getExprId(), leftSlotDescriptor.getId());
@ -1286,7 +1286,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
sd = context.getDescTable().copySlotDescriptor(intermediateDescriptor, rightSlotDescriptor);
} else {
sd = context.createSlotDesc(intermediateDescriptor, sf, rightSlotDescriptor.getParent().getTable());
//sd = context.createSlotDesc(intermediateDescriptor, sf);
// sd = context.createSlotDesc(intermediateDescriptor, sf);
if (hashOutputSlotReferenceMap.get(sf.getExprId()) != null) {
hashJoinNode.addSlotIdToHashOutputSlotIds(rightSlotDescriptor.getId());
hashJoinNode.getHashOutputExprSlotIdMap().put(sf.getExprId(), rightSlotDescriptor.getId());
@ -1454,7 +1454,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
sd = context.getDescTable().copySlotDescriptor(intermediateDescriptor, leftSlotDescriptor);
} else {
sd = context.createSlotDesc(intermediateDescriptor, sf, leftSlotDescriptor.getParent().getTable());
//sd = context.createSlotDesc(intermediateDescriptor, sf);
// sd = context.createSlotDesc(intermediateDescriptor, sf);
}
leftIntermediateSlotDescriptor.add(sd);
}
@ -1469,7 +1469,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
sd = context.getDescTable().copySlotDescriptor(intermediateDescriptor, rightSlotDescriptor);
} else {
sd = context.createSlotDesc(intermediateDescriptor, sf, rightSlotDescriptor.getParent().getTable());
//sd = context.createSlotDesc(intermediateDescriptor, sf);
// sd = context.createSlotDesc(intermediateDescriptor, sf);
}
rightIntermediateSlotDescriptor.add(sd);
}
@ -1494,7 +1494,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
.map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toList());
if (!nestedLoopJoin.isBitMapRuntimeFilterConditionsEmpty() && joinConjuncts.isEmpty()) {
//left semi join need at least one conjunct. otherwise left-semi-join fallback to cross-join
// left semi join needs at least one conjunct, otherwise left-semi-join falls back to cross-join
joinConjuncts.add(new BoolLiteral(true));
}
@ -2099,7 +2099,8 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
scanNode.getTupleDesc().getSlots().add(smallest);
}
try {
if (ConnectContext.get() != null && ConnectContext.get().getSessionVariable().forbidUnknownColStats
if (context.getSessionVariable() != null
&& context.getSessionVariable().forbidUnknownColStats
&& !StatisticConstants.isSystemTable(scanNode.getTupleDesc().getTable())) {
for (SlotId slotId : requiredByProjectSlotIdSet) {
if (context.isColumnStatsUnknown(scanNode, slotId)) {
@ -2316,7 +2317,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
private boolean checkPushSort(SortNode sortNode, OlapTable olapTable) {
// Ensure limit is less than threshold
if (sortNode.getLimit() <= 0
|| sortNode.getLimit() > ConnectContext.get().getSessionVariable().topnOptLimitThreshold) {
|| sortNode.getLimit() > context.getSessionVariable().topnOptLimitThreshold) {
return false;
}
@ -2365,7 +2366,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
List<Expr> outputExprs = Lists.newArrayList();
if (conjuncts != null) {
conjuncts.stream()
.map(e -> ExpressionTranslator.translate(e, context))
.map(e -> ExpressionTranslator.translate(e, context))
.forEach(outputExprs::add);
}
return outputExprs;

View File

@ -43,6 +43,8 @@ import org.apache.doris.planner.PlanFragmentId;
import org.apache.doris.planner.PlanNode;
import org.apache.doris.planner.PlanNodeId;
import org.apache.doris.planner.ScanNode;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.thrift.TPushAggOp;
import com.google.common.annotations.VisibleForTesting;
@ -62,6 +64,7 @@ import javax.annotation.Nullable;
* Context of physical plan.
*/
public class PlanTranslatorContext {
private final ConnectContext connectContext;
private final List<PlanFragment> planFragments = Lists.newArrayList();
private final DescriptorTable descTable = new DescriptorTable();
@ -110,12 +113,14 @@ public class PlanTranslatorContext {
private final Map<ScanNode, Set<SlotId>> statsUnknownColumnsMap = Maps.newHashMap();
/**
 * Creates a translator context from a cascades context.
 *
 * <p>Stores the {@code ConnectContext} obtained from {@code ctx} and builds a
 * {@code RuntimeFilterTranslator} from {@code ctx}'s runtime filter context.
 *
 * @param ctx cascades context to read the connect context and runtime filter context from;
 *            it is dereferenced without a null check
 */
public PlanTranslatorContext(CascadesContext ctx) {
this.connectContext = ctx.getConnectContext();
this.translator = new RuntimeFilterTranslator(ctx.getRuntimeFilterContext());
}
@VisibleForTesting
public PlanTranslatorContext() {
translator = null;
this.connectContext = null;
this.translator = null;
}
/**
@ -142,6 +147,10 @@ public class PlanTranslatorContext {
statsUnknownColumnsMap.remove(scan);
}
/**
 * Returns the session variables of the connect context held by this translator context.
 *
 * @return the session variables, or {@code null} when this context holds no connect
 *         context (the no-arg test constructor sets it to {@code null}); callers must
 *         null-check before dereferencing
 */
public SessionVariable getSessionVariable() {
    if (connectContext == null) {
        return null;
    }
    return connectContext.getSessionVariable();
}
/**
 * Returns the scan nodes that currently have an entry in {@code statsUnknownColumnsMap}.
 *
 * @return the key set of {@code statsUnknownColumnsMap}; this is the map's own key-set
 *         view, not a copy
 */
public Set<ScanNode> getScanNodeWithUnknownColumnStats() {
    Set<ScanNode> scansWithUnknownStats = statsUnknownColumnsMap.keySet();
    return scansWithUnknownStats;
}

View File

@ -31,6 +31,8 @@ import org.apache.doris.nereids.properties.ChildrenPropertiesRegulator;
import org.apache.doris.nereids.properties.EnforceMissingPropertiesHelper;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.properties.RequestPropertyDeriver;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import com.google.common.collect.Lists;
import org.apache.logging.log4j.LogManager;
@ -79,6 +81,14 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
this.groupExpression = groupExpression;
}
/**
 * Returns the connect context reached through {@code context.getCascadesContext()}.
 *
 * <p>Used to hand the connect context to callees explicitly. No null checks are
 * performed on the intermediate objects.
 */
private ConnectContext getConnectContext() {
return context.getCascadesContext().getConnectContext();
}
/**
 * Returns the session variables of the connect context this job runs under.
 *
 * <p>Resolves the connect context through {@link #getConnectContext()} and returns its
 * session variables; no null checks are performed.
 */
private SessionVariable getSessionVariable() {
    return getConnectContext().getSessionVariable();
}
/*-
* Please read the ORCA paper
* - 4.1.4 Optimization.
@ -113,17 +123,19 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
return;
}
SessionVariable sessionVariable = getSessionVariable();
countJobExecutionTimesOfGroupExpressions(groupExpression);
// Do init logic of root plan/groupExpr of `subplan`, only run once per task.
if (curChildIndex == -1) {
curNodeCost = Cost.zero();
curTotalCost = Cost.zero();
curNodeCost = Cost.zero(sessionVariable);
curTotalCost = Cost.zero(sessionVariable);
curChildIndex = 0;
// List<request property to children>
// [ child item: [leftProperties, rightProperties]]
// like :[ [Properties {"", ANY}, Properties {"", BROADCAST}],
// [Properties {"", SHUFFLE_JOIN}, Properties {"", SHUFFLE_JOIN}] ]
RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(context);
RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(getConnectContext(), context);
requestChildrenPropertiesList = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression);
for (List<PhysicalProperties> requestChildrenProperties : requestChildrenPropertiesList) {
outputChildrenPropertiesList.add(new ArrayList<>(requestChildrenProperties));
@ -139,7 +151,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
= outputChildrenPropertiesList.get(requestPropertiesIndex);
// Calculate cost
if (curChildIndex == 0 && prevChildIndex == -1) {
curNodeCost = CostCalculator.calculateCost(groupExpression, requestChildrenProperties);
curNodeCost = CostCalculator.calculateCost(getConnectContext(), groupExpression,
requestChildrenProperties);
groupExpression.setCost(curNodeCost);
curTotalCost = curNodeCost;
}
@ -184,7 +197,9 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// plan's requestChildProperty).getOutputProperties(current plan's requestChildProperty) == child
// plan's outputProperties`, the outputProperties must satisfy the origin requestChildProperty
outputChildrenProperties.set(curChildIndex, outputProperties);
curTotalCost = CostCalculator.addChildCost(groupExpression.getPlan(),
curTotalCost = CostCalculator.addChildCost(
getConnectContext(),
groupExpression.getPlan(),
curNodeCost,
lowestCostExpr.getCostValueByProperties(requestChildProperty),
curChildIndex);
@ -194,7 +209,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// Group1 : betterExpr, currentExpr(child: Group2), otherExpr(child: Group)
// steps
// 1. CostAndEnforce(currentExpr) with upperBound betterExpr.cost
// 2. OptimzeGroup(Group2) with upperBound bestExpr.cost - currentExpr.nodeCost
// 2. OptimizeGroup(Group2) with upperBound bestExpr.cost - currentExpr.nodeCost
// 3. CostAndEnforce(Expr in Group2) trigger here and exit
// ...
// n. CostAndEnforce(otherExpr) can trigger optimize group2 again for the same requireProp
@ -240,7 +255,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
ChildOutputPropertyDeriver childOutputPropertyDeriver
= new ChildOutputPropertyDeriver(outputChildrenProperties);
// the physical properties the group expression support for its parent.
PhysicalProperties outputProperty = childOutputPropertyDeriver.getOutputProperties(groupExpression);
PhysicalProperties outputProperty = childOutputPropertyDeriver.getOutputProperties(getConnectContext(),
groupExpression);
// update current group statistics and re-compute costs.
if (groupExpression.children().stream().anyMatch(group -> group.getStatistics() == null)
@ -251,12 +267,14 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
}
// recompute cost after adjusting property
curNodeCost = CostCalculator.calculateCost(groupExpression, requestChildrenProperties);
curNodeCost = CostCalculator.calculateCost(getConnectContext(), groupExpression, requestChildrenProperties);
groupExpression.setCost(curNodeCost);
curTotalCost = curNodeCost;
for (int i = 0; i < outputChildrenProperties.size(); i++) {
PhysicalProperties childProperties = outputChildrenProperties.get(i);
curTotalCost = CostCalculator.addChildCost(groupExpression.getPlan(),
curTotalCost = CostCalculator.addChildCost(
getConnectContext(),
groupExpression.getPlan(),
curTotalCost,
groupExpression.child(i).getLowestCostPlan(childProperties).get().first,
i);
@ -301,7 +319,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
}
EnforceMissingPropertiesHelper enforceMissingPropertiesHelper
= new EnforceMissingPropertiesHelper(context, groupExpression, curTotalCost);
= new EnforceMissingPropertiesHelper(getConnectContext(), groupExpression, curTotalCost);
PhysicalProperties addEnforcedProperty = enforceMissingPropertiesHelper
.enforceProperty(outputProperty, requiredProperties);
curTotalCost = enforceMissingPropertiesHelper.getCurTotalCost();
@ -338,8 +356,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
lowestCostChildren.clear();
prevChildIndex = -1;
curChildIndex = 0;
curTotalCost = Cost.zero();
curNodeCost = Cost.zero();
curTotalCost = Cost.zero(getSessionVariable());
curNodeCost = Cost.zero(getSessionVariable());
}
/**

View File

@ -30,6 +30,7 @@ import org.apache.doris.nereids.stats.StatsCalculator;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.plans.algebra.Project;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.Statistics;
import java.util.HashMap;
@ -102,20 +103,19 @@ public class DeriveStatsJob extends Job {
}
}
} else {
ConnectContext connectContext = context.getCascadesContext().getConnectContext();
SessionVariable sessionVariable = connectContext.getSessionVariable();
StatsCalculator statsCalculator = StatsCalculator.estimate(groupExpression,
context.getCascadesContext().getConnectContext().getSessionVariable().getForbidUnknownColStats(),
context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap(),
context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump(),
sessionVariable.getForbidUnknownColStats(),
connectContext.getTotalColumnStatisticMap(),
sessionVariable.isPlayNereidsDump(),
cteIdToStats,
context.getCascadesContext());
STATS_STATE_TRACER.log(StatsStateEvent.of(groupExpression,
groupExpression.getOwnerGroup().getStatistics()));
if (ConnectContext.get().getSessionVariable().isEnableMinidump()
&& !ConnectContext.get().getSessionVariable().isPlayNereidsDump()) {
context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap()
.putAll(statsCalculator.getTotalColumnStatisticMap());
context.getCascadesContext().getConnectContext().getTotalHistogramMap()
.putAll(statsCalculator.getTotalHistogramMap());
if (sessionVariable.isEnableMinidump() && !sessionVariable.isPlayNereidsDump()) {
connectContext.getTotalColumnStatisticMap().putAll(statsCalculator.getTotalColumnStatisticMap());
connectContext.getTotalHistogramMap().putAll(statsCalculator.getTotalHistogramMap());
}
if (groupExpression.getPlan() instanceof Project) {

View File

@ -68,6 +68,7 @@ public class Memo {
private static final EventProducer GROUP_MERGE_TRACER = new EventProducer(GroupMergeEvent.class,
EventChannel.getDefaultChannel().addConsumers(new LogConsumer(GroupMergeEvent.class, EventChannel.LOG)));
private static long stateId = 0;
private final ConnectContext connectContext;
private final IdGenerator<GroupId> groupIdGenerator = GroupId.createGenerator();
private final Map<GroupId, Group> groups = Maps.newLinkedHashMap();
// we could not use Set, because Set does not have get method.
@ -76,11 +77,13 @@ public class Memo {
// FOR TEST ONLY
public Memo() {
root = null;
this.root = null;
this.connectContext = null;
}
public Memo(Plan plan) {
root = init(plan);
public Memo(ConnectContext connectContext, Plan plan) {
this.root = init(plan);
this.connectContext = connectContext;
}
public static long getStateId() {
@ -214,8 +217,7 @@ public class Memo {
}
private void maybeAddStateId(CopyInResult result) {
if (ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable().isEnableNereidsTrace()
if (connectContext != null && connectContext.getSessionVariable().isEnableNereidsTrace()
&& result.generateNewExpression) {
stateId++;
}
@ -850,11 +852,12 @@ public class Memo {
List<Pair<Long, List<Integer>>> childrenId = new ArrayList<>();
permute(children, 0, childrenId, new ArrayList<>());
Cost cost = CostCalculator.calculateCost(groupExpression, inputProperties);
Cost cost = CostCalculator.calculateCost(connectContext, groupExpression, inputProperties);
for (Pair<Long, List<Integer>> c : childrenId) {
Cost totalCost = cost;
for (int i = 0; i < children.size(); i++) {
totalCost = CostCalculator.addChildCost(groupExpression.getPlan(),
totalCost = CostCalculator.addChildCost(connectContext,
groupExpression.getPlan(),
totalCost,
children.get(i).get(c.second.get(i)).second,
i);
@ -942,7 +945,7 @@ public class Memo {
// return any if exits except RequirePropertiesSupplier and SetOperators
// Because PropRegulator could change their input properties
RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(prop);
RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(connectContext, prop);
List<List<PhysicalProperties>> requestList = requestPropertyDeriver
.getRequestChildrenPropertyList(groupExpression);
Optional<List<PhysicalProperties>> any = requestList.stream()

View File

@ -286,8 +286,6 @@ public class MinidumpUtils {
/**
* Serialize the output plan to the dump file and persist it to disk.
* @param resultPlan
*
*/
public static void serializeOutputToDumpFile(Plan resultPlan) {
if (ConnectContext.get().getSessionVariable().isPlayNereidsDump()
@ -401,22 +399,22 @@ public class MinidumpUtils {
}
switch (field.getType().getSimpleName()) {
case "boolean":
root.put(attr.name(), (Boolean) field.get(sessionVariable));
root.put(attr.name(), field.get(sessionVariable));
break;
case "int":
root.put(attr.name(), (Integer) field.get(sessionVariable));
root.put(attr.name(), field.get(sessionVariable));
break;
case "long":
root.put(attr.name(), (Long) field.get(sessionVariable));
root.put(attr.name(), field.get(sessionVariable));
break;
case "float":
root.put(attr.name(), (Float) field.get(sessionVariable));
root.put(attr.name(), field.get(sessionVariable));
break;
case "double":
root.put(attr.name(), (Double) field.get(sessionVariable));
root.put(attr.name(), field.get(sessionVariable));
break;
case "String":
root.put(attr.name(), (String) field.get(sessionVariable));
root.put(attr.name(), field.get(sessionVariable));
break;
default:
// Unsupported type variable.

View File

@ -60,6 +60,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
@ -89,8 +90,8 @@ public class ChildOutputPropertyDeriver extends PlanVisitor<PhysicalProperties,
this.childrenOutputProperties = Objects.requireNonNull(childrenOutputProperties);
}
public PhysicalProperties getOutputProperties(GroupExpression groupExpression) {
return groupExpression.getPlan().accept(this, new PlanContext(groupExpression));
public PhysicalProperties getOutputProperties(ConnectContext connectContext, GroupExpression groupExpression) {
return groupExpression.getPlan().accept(this, new PlanContext(connectContext, groupExpression));
}
@Override

View File

@ -43,6 +43,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalSetOperation;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -453,7 +454,7 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> {
*
* @param shuffleType real output shuffle type
* @param notShuffleSideOutput not shuffle side real output used hash spec
* @param shuffleSideOutput shuffle side real output used hash spec
* @param shuffleSideOutput shuffle side real output used hash spec
* @param notShuffleSideRequired not shuffle side required used hash spec
* @param shuffleSideRequired shuffle side required hash spec
* @return shuffle side new required hash spec
@ -481,7 +482,7 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> {
private void updateChildEnforceAndCost(GroupExpression child, PhysicalProperties childOutput,
DistributionSpec target, Cost currentCost) {
if (child.getPlan() instanceof PhysicalDistribute) {
//To avoid continuous distribute operator, we just enforce the child's child
// To avoid continuous distribute operator, we just enforce the child's child
childOutput = child.getInputPropertiesList(childOutput).get(0);
Pair<Cost, GroupExpression> newChildAndCost = child.getOwnerGroup().getLowestCostPlan(childOutput).get();
child = newChildAndCost.second;
@ -491,8 +492,9 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> {
PhysicalProperties newOutputProperty = new PhysicalProperties(target);
GroupExpression enforcer = target.addEnforcer(child.getOwnerGroup());
child.getOwnerGroup().addEnforcer(enforcer);
Cost totalCost = CostCalculator.addChildCost(enforcer.getPlan(),
CostCalculator.calculateCost(enforcer, Lists.newArrayList(childOutput)),
ConnectContext connectContext = jobContext.getCascadesContext().getConnectContext();
Cost totalCost = CostCalculator.addChildCost(connectContext, enforcer.getPlan(),
CostCalculator.calculateCost(connectContext, enforcer, Lists.newArrayList(childOutput)),
currentCost,
0);

View File

@ -19,7 +19,6 @@ package org.apache.doris.nereids.properties;
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.metrics.EventChannel;
import org.apache.doris.nereids.metrics.EventProducer;
@ -28,6 +27,7 @@ import org.apache.doris.nereids.metrics.event.EnforcerEvent;
import org.apache.doris.nereids.minidump.NereidsTracer;
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.Lists;
@ -38,13 +38,13 @@ import com.google.common.collect.Lists;
public class EnforceMissingPropertiesHelper {
private static final EventProducer ENFORCER_TRACER = new EventProducer(EnforcerEvent.class,
EventChannel.getDefaultChannel().addConsumers(new LogConsumer(EnforcerEvent.class, EventChannel.LOG)));
private final JobContext context;
private final ConnectContext connectContext;
private final GroupExpression groupExpression;
private Cost curTotalCost;
public EnforceMissingPropertiesHelper(JobContext context, GroupExpression groupExpression,
public EnforceMissingPropertiesHelper(ConnectContext connectContext, GroupExpression groupExpression,
Cost curTotalCost) {
this.context = context;
this.connectContext = connectContext;
this.groupExpression = groupExpression;
this.curTotalCost = curTotalCost;
}
@ -155,12 +155,15 @@ public class EnforceMissingPropertiesHelper {
ENFORCER_TRACER.log(EnforcerEvent.of(groupExpression, ((PhysicalPlan) enforcer.getPlan()),
oldOutputProperty, newOutputProperty));
enforcer.setEstOutputRowCount(enforcer.getOwnerGroup().getStatistics().getRowCount());
Cost enforcerCost = CostCalculator.calculateCost(enforcer, Lists.newArrayList(oldOutputProperty));
Cost enforcerCost = CostCalculator.calculateCost(connectContext, enforcer,
Lists.newArrayList(oldOutputProperty));
enforcer.setCost(enforcerCost);
curTotalCost = CostCalculator.addChildCost(enforcer.getPlan(),
enforcerCost,
curTotalCost,
0);
curTotalCost = CostCalculator.addChildCost(
connectContext,
enforcer.getPlan(),
enforcerCost,
curTotalCost,
0);
if (enforcer.updateLowestCostTable(newOutputProperty,
Lists.newArrayList(oldOutputProperty), curTotalCost)) {
enforcer.putOutputPropertiesMap(newOutputProperty, newOutputProperty);

View File

@ -65,14 +65,17 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> {
* ▼
* requestPropertyToChildren
*/
private final ConnectContext connectContext;
private final PhysicalProperties requestPropertyFromParent;
private List<List<PhysicalProperties>> requestPropertyToChildren;
public RequestPropertyDeriver(JobContext context) {
public RequestPropertyDeriver(ConnectContext connectContext, JobContext context) {
this.connectContext = connectContext;
this.requestPropertyFromParent = context.getRequiredProperties();
}
public RequestPropertyDeriver(PhysicalProperties requestPropertyFromParent) {
public RequestPropertyDeriver(ConnectContext connectContext, PhysicalProperties requestPropertyFromParent) {
this.connectContext = connectContext;
this.requestPropertyFromParent = requestPropertyFromParent;
}
@ -81,7 +84,7 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> {
*/
public List<List<PhysicalProperties>> getRequestChildrenPropertyList(GroupExpression groupExpression) {
requestPropertyToChildren = Lists.newArrayList();
groupExpression.getPlan().accept(this, new PlanContext(groupExpression));
groupExpression.getPlan().accept(this, new PlanContext(connectContext, groupExpression));
return requestPropertyToChildren;
}
@ -110,8 +113,7 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> {
@Override
public Void visitPhysicalOlapTableSink(PhysicalOlapTableSink<? extends Plan> olapTableSink, PlanContext context) {
if (ConnectContext.get() != null && ConnectContext.get().getSessionVariable() != null
&& !ConnectContext.get().getSessionVariable().enableStrictConsistencyDml) {
if (connectContext != null && !connectContext.getSessionVariable().enableStrictConsistencyDml) {
addRequestPropertyToChildren(PhysicalProperties.ANY);
} else {
addRequestPropertyToChildren(olapTableSink.getRequirePhysicalProperties());