[fix](Nereids): fix and enable stats derive job (#11755)

fix and enable statistics derive job
Add mock for statistics in computeScan().
This commit is contained in:
jakevin
2022-08-18 21:26:35 +08:00
committed by GitHub
parent 124b4f7694
commit b15e2ddeaa
11 changed files with 62 additions and 176 deletions

View File

@ -43,7 +43,6 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.planner.Planner;
import org.apache.doris.planner.ScanNode;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.Lists;
@ -112,10 +111,7 @@ public class NereidsPlanner extends Planner {
// rule-based optimize
rewrite();
// TODO: remove this condition, when stats collector is fully developed.
if (ConnectContext.get().getSessionVariable().isEnableNereidsCBO()) {
deriveStats();
}
deriveStats();
// TODO: What is the appropriate time to set physical properties? Maybe before enter.
// cascades style optimize phase.
@ -154,7 +150,9 @@ public class NereidsPlanner extends Planner {
}
private void deriveStats() {
new DeriveStatsJob(getRoot().getLogicalExpression(), cascadesContext.getCurrentJobContext()).execute();
cascadesContext
.pushJob(new DeriveStatsJob(getRoot().getLogicalExpression(), cascadesContext.getCurrentJobContext()));
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
}
/**

View File

@ -40,17 +40,12 @@ import java.util.List;
public class PlanContext {
// array of children's derived stats
private final List<StatsDeriveResult> childrenStats = Lists.newArrayList();
// statistics of attached plan/gexpr
private StatsDeriveResult statistics;
// attached plan
private Plan plan;
// attached group expression
private GroupExpression groupExpression;
public PlanContext(Plan plan) {
this.plan = plan;
}
private final GroupExpression groupExpression;
/**
* Constructor for PlanContext.
*/
public PlanContext(GroupExpression groupExpression) {
this.groupExpression = groupExpression;
@ -59,10 +54,6 @@ public class PlanContext {
}
}
public Plan getPlan() {
return plan;
}
public GroupExpression getGroupExpression() {
return groupExpression;
}
@ -71,21 +62,14 @@ public class PlanContext {
return childrenStats;
}
public StatsDeriveResult getStatistics() {
return statistics;
}
public void setStatistics(StatsDeriveResult stats) {
this.statistics = stats;
}
public StatsDeriveResult getStatisticsWithCheck() {
StatsDeriveResult statistics = groupExpression.getOwnerGroup().getStatistics();
Preconditions.checkNotNull(statistics);
return statistics;
}
public LogicalProperties childLogicalPropertyAt(int index) {
return plan.child(index).getLogicalProperties();
return groupExpression.child(index).getLogicalProperties();
}
public List<Slot> getChildOutputSlots(int index) {

View File

@ -45,12 +45,10 @@ public class CostCalculator {
* Constructor.
*/
public static double calculateCost(GroupExpression groupExpression) {
// TODO: Enable following code after enable stats derive.
// PlanContext planContext = new PlanContext(groupExpression);
// CostEstimator costCalculator = new CostEstimator();
// CostEstimate costEstimate = groupExpression.getPlan().accept(costCalculator, planContext);
// return costFormula(costEstimate);
return 0;
PlanContext planContext = new PlanContext(groupExpression);
CostEstimator costCalculator = new CostEstimator();
CostEstimate costEstimate = groupExpression.getPlan().accept(costCalculator, planContext);
return costFormula(costEstimate);
}
private static double costFormula(CostEstimate costEstimate) {

View File

@ -35,6 +35,16 @@ public final class CostEstimate {
* Constructor of CostEstimate.
*/
public CostEstimate(double cpuCost, double memoryCost, double networkCost) {
// TODO: remove them after finish statistics.
if (cpuCost < 0) {
cpuCost = 0;
}
if (memoryCost < 0) {
memoryCost = 0;
}
if (networkCost < 0) {
networkCost = 0;
}
Preconditions.checkArgument(!(cpuCost < 0), "cpuCost cannot be negative: %s", cpuCost);
Preconditions.checkArgument(!(memoryCost < 0), "memoryCost cannot be negative: %s", memoryCost);
Preconditions.checkArgument(!(networkCost < 0), "networkCost cannot be negative: %s", networkCost);

View File

@ -71,12 +71,13 @@ public class ApplyRuleJob extends Job {
GroupExpression newGroupExpression = pair.second;
if (newPlan instanceof LogicalPlan) {
pushTask(new DeriveStatsJob(newGroupExpression, context));
if (exploredOnly) {
pushTask(new ExploreGroupExpressionJob(newGroupExpression, context));
pushTask(new DeriveStatsJob(newGroupExpression, context));
continue;
}
pushTask(new OptimizeGroupExpressionJob(newGroupExpression, context));
pushTask(new DeriveStatsJob(newGroupExpression, context));
} else {
pushTask(new CostAndEnforcerJob(newGroupExpression, context));
}

View File

@ -172,12 +172,12 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
}
/* update current group statistics and re-compute costs. */
if (groupExpression.children().stream().anyMatch(group -> group.getStatistics() != null)) {
if (groupExpression.children().stream().anyMatch(group -> group.getStatistics() == null)) {
return;
}
PlanContext planContext = new PlanContext(groupExpression);
// TODO: calculate stats. ??????
groupExpression.getOwnerGroup().setStatistics(planContext.getStatistics());
groupExpression.getOwnerGroup().setStatistics(planContext.getStatisticsWithCheck());
enforce(outputProperty, requestChildrenProperty);

View File

@ -21,14 +21,11 @@ import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.Literal;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.statistics.ColumnStats;
import com.google.common.base.Preconditions;
@ -38,7 +35,7 @@ import java.util.Map;
/**
* Calculate selectivity of the filter.
*/
public class FilterSelectivityCalculator extends DefaultExpressionVisitor<Double, Void> {
public class FilterSelectivityCalculator extends ExpressionVisitor<Double, Void> {
private static double DEFAULT_SELECTIVITY = 0.1;
@ -63,13 +60,19 @@ public class FilterSelectivityCalculator extends DefaultExpressionVisitor<Double
return expression.accept(this, null);
}
@Override
public Double visit(Expression expr, Void context) {
return DEFAULT_SELECTIVITY;
}
@Override
public Double visitCompoundPredicate(CompoundPredicate compoundPredicate, Void context) {
Expression leftExpr = compoundPredicate.child(0);
Expression rightExpr = compoundPredicate.child(1);
double leftSel = 1;
double rightSel = 1;
leftSel = estimate(leftExpr);
leftSel = estimate(leftExpr);
rightSel = estimate(rightExpr);
return compoundPredicate instanceof Or ? leftSel + rightSel - leftSel * rightSel : leftSel * rightSel;
}
@ -92,19 +95,4 @@ public class FilterSelectivityCalculator extends DefaultExpressionVisitor<Double
}
// TODO: Should consider the distribution of data.
@Override
public Double visitGreaterThan(GreaterThan greaterThan, Void context) {
return DEFAULT_SELECTIVITY;
}
@Override
public Double visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, Void context) {
return DEFAULT_SELECTIVITY;
}
@Override
public Double visitLessThan(LessThan lessThan, Void context) {
return DEFAULT_SELECTIVITY;
}
}

View File

@ -129,6 +129,7 @@ public class JoinEstimation {
if (lhsCard == -1 || rhsCard == -1) {
return lhsCard;
}
long result = -1;
for (Expression eqJoinConjunct : eqConjunctList) {
Expression left = eqJoinConjunct.child(0);
@ -160,6 +161,7 @@ public class JoinEstimation {
result = Math.min(result, joinCard);
}
}
return result;
}
}

View File

@ -17,7 +17,9 @@
package org.apache.doris.nereids.stats;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.MaterializedIndex;
import org.apache.doris.catalog.Partition;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@ -52,6 +54,7 @@ import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ColumnStats;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.StatsDeriveResult;
import org.apache.doris.statistics.TableStats;
@ -194,9 +197,9 @@ public class StatsCalculator extends DefaultPlanVisitor<StatsDeriveResult, Void>
// 2. Consider the influence of runtime filter
// 3. Get NDV and column data size from StatisticManger, StatisticManager doesn't support it now.
private StatsDeriveResult computeScan(Scan scan) {
Table table = scan.getTable();
TableStats tableStats = Utils.execWithReturnVal(() ->
ConnectContext.get().getEnv().getStatisticsManager().getStatistics().getTableStats(table.getId())
// TODO: tmp mock the table stats, after we support the table stats, we should remove this mock.
mockRowCountInStatistic(scan)
);
Map<Slot, ColumnStats> slotToColumnStats = new HashMap<>();
Set<SlotReference> slotSet = scan.getOutput().stream().filter(SlotReference.class::isInstance)
@ -216,6 +219,23 @@ public class StatsCalculator extends DefaultPlanVisitor<StatsDeriveResult, Void>
return stats;
}
// TODO: tmp mock the table stats, after we support the table stats, we should remove this mock.
private TableStats mockRowCountInStatistic(Scan scan) throws AnalysisException {
long cardinality = 0;
if (scan instanceof PhysicalOlapScan) {
PhysicalOlapScan olapScan = (PhysicalOlapScan) scan;
for (long selectedPartitionId : olapScan.getSelectedPartitionId()) {
final Partition partition = olapScan.getTable().getPartition(selectedPartitionId);
final MaterializedIndex baseIndex = partition.getBaseIndex();
cardinality += baseIndex.getRowCount();
}
}
Statistics statistics = ConnectContext.get().getEnv().getStatisticsManager().getStatistics();
statistics.mockTableStatsWithRowCount(scan.getTable().getId(), cardinality);
return statistics.getTableStats(scan.getTable().getId());
}
private StatsDeriveResult computeTopN(TopN topN) {
StatsDeriveResult stats = groupExpression.getCopyOfChildStats(0);
return stats.updateRowCountByLimit(topN.getLimit());

View File

@ -93,8 +93,6 @@ public class SessionVariable implements Serializable, Writable {
public static final String ENABLE_COST_BASED_JOIN_REORDER = "enable_cost_based_join_reorder";
public static final String ENABLE_NEREIDS_CBO = "enable_nereids_cbo";
public static final int MIN_EXEC_INSTANCE_NUM = 1;
public static final int MAX_EXEC_INSTANCE_NUM = 32;
// if set to true, some of stmt will be forwarded to master FE to get result
@ -448,9 +446,6 @@ public class SessionVariable implements Serializable, Writable {
@VariableMgr.VarAttr(name = ENABLE_COST_BASED_JOIN_REORDER)
private boolean enableJoinReorderBasedCost = false;
@VariableMgr.VarAttr(name = ENABLE_NEREIDS_CBO)
private boolean enableNereidsCBO = false;
@VariableMgr.VarAttr(name = ENABLE_FOLD_CONSTANT_BY_BE)
private boolean enableFoldConstantByBe = false;
@ -1019,14 +1014,6 @@ public class SessionVariable implements Serializable, Writable {
this.enableJoinReorderBasedCost = enableJoinReorderBasedCost;
}
public boolean isEnableNereidsCBO() {
return enableNereidsCBO;
}
public void setEnableNereidsCBO(boolean enableNereidsCBO) {
this.enableNereidsCBO = enableNereidsCBO;
}
public void setDisableJoinReorder(boolean disableJoinReorder) {
this.disableJoinReorder = disableJoinReorder;
}

View File

@ -1,102 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.cascades.OptimizeGroupJob;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.Lists;
import mockit.Mock;
import mockit.MockUp;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
public class CostAndEnforcerJobTest {
/*
* topJoin
* / \
* C bottomJoin
* / \
* A B
*/
private static List<LogicalOlapScan> scans = Lists.newArrayList();
private static List<List<SlotReference>> outputs = Lists.newArrayList();
@BeforeAll
public static void init() {
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "a", 0);
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "b", 1);
LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "c", 0);
scans.add(scan1);
scans.add(scan2);
scans.add(scan3);
List<SlotReference> t1Output = scan1.getOutput().stream().map(slot -> (SlotReference) slot)
.collect(Collectors.toList());
List<SlotReference> t2Output = scan2.getOutput().stream().map(slot -> (SlotReference) slot)
.collect(Collectors.toList());
List<SlotReference> t3Output = scan3.getOutput().stream().map(slot -> (SlotReference) slot)
.collect(Collectors.toList());
outputs.add(t1Output);
outputs.add(t2Output);
outputs.add(t3Output);
}
@Test
public void testExecute() {
new MockUp<CostCalculator>() {
@Mock
public double calculateCost(GroupExpression groupExpression) {
return 0;
}
};
/*
* bottomJoin
* / \
* A B
*/
Expression bottomJoinOnCondition = new EqualTo(outputs.get(0).get(0), outputs.get(1).get(0));
LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
Optional.of(bottomJoinOnCondition), scans.get(0), scans.get(1));
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(bottomJoin);
cascadesContext.pushJob(
new OptimizeGroupJob(
cascadesContext.getMemo().getRoot(),
cascadesContext.getCurrentJobContext()));
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
}
}