[feature](nereids) Add stats derive framework for new optimizer (#11179)

Implement a visitor that derives statistics for each operator; the derived statistics will be used by the cost-based optimizer (CBO).
This commit is contained in:
Kikyou1997
2022-07-28 14:38:19 +08:00
committed by GitHub
parent ef25459fa1
commit 75451ab0ed
37 changed files with 1229 additions and 112 deletions

View File

@ -29,6 +29,7 @@ import org.apache.doris.nereids.jobs.batch.DisassembleRulesJob;
import org.apache.doris.nereids.jobs.batch.JoinReorderRulesJob;
import org.apache.doris.nereids.jobs.batch.OptimizeRulesJob;
import org.apache.doris.nereids.jobs.batch.PredicatePushDownRulesJob;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.PhysicalProperties;
@ -105,6 +106,10 @@ public class NereidsPlanner extends Planner {
.setJobContext(outputProperties);
rewrite();
// TODO: remove this condition, when stats collector is fully developed.
if (ConnectContext.get().getSessionVariable().isEnableNereidsCBO()) {
deriveStats();
}
optimize();
// Get plan directly. Just for SSB.
@ -120,6 +125,10 @@ public class NereidsPlanner extends Planner {
new DisassembleRulesJob(plannerContext).execute();
}
/**
 * Runs a {@link DeriveStatsJob} on the logical expression of the root group,
 * using the planner context's current job context.
 */
private void deriveStats() {
    DeriveStatsJob statsJob = new DeriveStatsJob(
            getRoot().getLogicalExpression(), plannerContext.getCurrentJobContext());
    statsJob.execute();
}
/**
* Cascades style optimize: perform equivalent logical plan exploration and physical implementation enumeration,
* try to find best plan under the guidance of statistic information and cost model.

View File

@ -51,6 +51,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.planner.AggregationNode;
import org.apache.doris.planner.DataPartition;
import org.apache.doris.planner.ExchangeNode;
@ -244,7 +245,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
} catch (Exception e) {
throw new AnalysisException(e.getMessage());
}
exec(olapScanNode::init);
Utils.execWithUncheckedException(olapScanNode::init);
olapScanNode.addConjuncts(execConjunctsList);
context.addScanNode(olapScanNode);
// Create PlanFragment
@ -491,22 +492,4 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
return fragment;
}
/**
* Helper function to eliminate unnecessary checked exception caught requirement from the main logic of translator.
*
* @param f function which would invoke the logic of
* stale code from old optimizer that could throw
* a checked exception
*/
public void exec(FuncWrapper f) {
try {
f.exec();
} catch (Exception e) {
throw new RuntimeException(e.getMessage(), e);
}
}
private interface FuncWrapper {
void exec() throws Exception;
}
}

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.stats.StatsCalculator;
/**
* Job to derive stats for {@link GroupExpression} in {@link org.apache.doris.nereids.memo.Memo}.
@ -64,8 +65,8 @@ public class DeriveStatsJob extends Job {
}
}
} else {
// TODO: derive stat here
groupExpression.setStatDerived(true);
StatsCalculator statsCalculator = new StatsCalculator(groupExpression);
statsCalculator.estimate();
}
}
}

View File

@ -69,6 +69,13 @@ public class Group {
groupExpression.setOwnerGroup(this);
}
/**
 * Creates a group without a group id.
 *
 * <p>For unit test only.
 */
public Group() {
    this.groupId = null;
}
public GroupId getGroupId() {
return groupId;
}

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.statistics.StatsDeriveResult;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
@ -174,4 +175,8 @@ public class GroupExpression {
public int hashCode() {
return Objects.hash(children, plan);
}
/**
 * Returns a copy of the statistics attached to the child at the given position.
 *
 * @param idx position of the child whose statistics are wanted
 * @return the result of {@code copy()} on that child's statistics
 */
public StatsDeriveResult getCopyOfChildStats(int idx) {
    StatsDeriveResult childStats = child(idx).getStatistics();
    return childStats.copy();
}
}

View File

@ -0,0 +1,110 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.stats;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.Literal;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.statistics.ColumnStats;
import com.google.common.base.Preconditions;
import java.util.Map;
/**
 * Calculate selectivity of the filter.
 *
 * <p>Visits a predicate expression and returns the estimated fraction of input rows that
 * satisfy it. Per-slot {@link ColumnStats} are used where a visitor method supports them;
 * otherwise a fixed default selectivity is returned.
 */
public class FilterSelectivityCalculator extends DefaultExpressionVisitor<Double, Void> {

    /** Selectivity assumed whenever no better estimate can be derived. */
    private static final double DEFAULT_SELECTIVITY = 0.1;

    /** Column statistics keyed by the slot they describe. */
    private final Map<Slot, ColumnStats> slotRefToStats;

    /**
     * Creates a calculator backed by the given column statistics.
     *
     * @param slotRefToStats column statistics keyed by slot; must not be null
     * @throws IllegalStateException if {@code slotRefToStats} is null
     */
    public FilterSelectivityCalculator(Map<Slot, ColumnStats> slotRefToStats) {
        Preconditions.checkState(slotRefToStats != null);
        this.slotRefToStats = slotRefToStats;
    }

    /**
     * Do estimate.
     *
     * @param expression the predicate to estimate
     * @return the estimated selectivity of {@code expression}
     */
    public double estimate(Expression expression) {
        // A comparison predicate is only considered valid when its left side is a slot and its
        // right side is a literal; any other shape is assumed not to filter anything.
        if (expression instanceof ComparisonPredicate
                && !(expression.child(0) instanceof SlotReference
                && expression.child(1) instanceof Literal)) {
            return 1.0;
        }
        // The visitor returns a boxed Double. For expression types this class does not override,
        // the inherited fallback may return null (NOTE(review): confirm against
        // DefaultExpressionVisitor); unboxing that directly would throw a NullPointerException,
        // so fall back to the default selectivity instead.
        Double selectivity = expression.accept(this, null);
        return selectivity == null ? DEFAULT_SELECTIVITY : selectivity;
    }

    /**
     * Combines the selectivities of both children: {@code l + r - l * r} for {@link Or},
     * and {@code l * r} for every other compound predicate.
     */
    @Override
    public Double visitCompoundPredicate(CompoundPredicate compoundPredicate, Void context) {
        double leftSel = estimate(compoundPredicate.child(0));
        double rightSel = estimate(compoundPredicate.child(1));
        if (compoundPredicate instanceof Or) {
            return leftSel + rightSel - leftSel * rightSel;
        }
        return leftSel * rightSel;
    }

    @Override
    public Double visitComparisonPredicate(ComparisonPredicate cp, Void context) {
        return super.visitComparisonPredicate(cp, context);
    }

    // TODO: If right value greater than the max value or less than min value in column stats, return 0.0 .
    /**
     * Estimates {@code slot = literal} as {@code 1 / ndv} of the slot's column.
     *
     * <p>Returns the default selectivity when the slot has no statistics or its NDV is negative,
     * and {@code 0} when the NDV is zero. The left child is cast to {@link SlotReference};
     * {@link #estimate(Expression)} checks that shape before dispatching here.
     */
    @Override
    public Double visitEqualTo(EqualTo equalTo, Void context) {
        SlotReference left = (SlotReference) equalTo.left();
        ColumnStats columnStats = slotRefToStats.get(left);
        if (columnStats == null) {
            return DEFAULT_SELECTIVITY;
        }
        long ndv = columnStats.getNdv();
        if (ndv < 0) {
            return DEFAULT_SELECTIVITY;
        }
        if (ndv == 0) {
            return 0.0;
        }
        return 1.0 / ndv;
    }

    // TODO: Should consider the distribution of data.
    @Override
    public Double visitGreaterThan(GreaterThan greaterThan, Void context) {
        return DEFAULT_SELECTIVITY;
    }

    @Override
    public Double visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, Void context) {
        return DEFAULT_SELECTIVITY;
    }

    @Override
    public Double visitLessThan(LessThan lessThan, Void context) {
        return DEFAULT_SELECTIVITY;
    }
}

View File

@ -0,0 +1,160 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.stats;
import org.apache.doris.common.CheckedMath;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.statistics.ColumnStats;
import org.apache.doris.statistics.StatsDeriveResult;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Map;
/**
 * Estimate hash join stats.
 *
 * <p>Throughout this class a row count or NDV of {@code -1} is treated as "unknown".
 * TODO: Update other props in the ColumnStats properly.
 */
public class JoinEstimation {

    /**
     * Do estimate.
     *
     * <p>The result starts as a {@link StatsDeriveResult} built from {@code leftStats} with
     * {@code rightStats} merged into it; only its row count is then replaced by the estimate.
     *
     * @param leftStats statistics of the left input
     * @param rightStats statistics of the right input
     * @param eqCondition the join condition; only read for semi, anti, inner and outer joins
     * @param joinType the type of the join
     * @return the merged statistics carrying the estimated row count ({@code -1} if unknown)
     * @throws RuntimeException if {@code joinType} is none of semi/anti, inner, outer or cross
     */
    public static StatsDeriveResult estimate(StatsDeriveResult leftStats, StatsDeriveResult rightStats,
            Expression eqCondition, JoinType joinType) {
        StatsDeriveResult statsDeriveResult = new StatsDeriveResult(leftStats);
        statsDeriveResult.merge(rightStats);
        long rowCount;
        // The conjuncts are extracted only in the branches that use them, so a cross join
        // never touches eqCondition.
        if (joinType.isSemiOrAntiJoin()) {
            rowCount = getSemiJoinRowCount(leftStats, rightStats,
                    ExpressionUtils.extractConjunctive(eqCondition), joinType);
        } else if (joinType.isInnerJoin() || joinType.isOuterJoin()) {
            rowCount = getJoinRowCount(leftStats, rightStats,
                    ExpressionUtils.extractConjunctive(eqCondition), joinType);
        } else if (joinType.isCrossJoin()) {
            rowCount = getCrossJoinRowCount(leftStats.getRowCount(), rightStats.getRowCount());
        } else {
            throw new RuntimeException("joinType is not supported");
        }
        statsDeriveResult.setRowCount(rowCount);
        return statsDeriveResult;
    }

    /**
     * Row count of a cross join: the product of both sides, or {@code -1} when either side is
     * unknown. Multiplying the {@code -1} sentinels directly would yield a bogus count.
     */
    private static long getCrossJoinRowCount(long lhsCard, long rhsCard) {
        if (lhsCard == -1 || rhsCard == -1) {
            return -1;
        }
        return CheckedMath.checkedMultiply(lhsCard, rhsCard);
    }

    // TODO: If the condition of Join Plan could any expression in addition to EqualTo type,
    // we should handle that properly.
    /**
     * Row count of a semi or anti join: the row count of the preserved side scaled by the
     * smallest selectivity found among the equality conjuncts.
     *
     * @return the estimated row count, or {@code -1} if the preserved side's row count is unknown
     */
    private static long getSemiJoinRowCount(StatsDeriveResult leftStats, StatsDeriveResult rightStats,
            List<Expression> eqConjunctList, JoinType joinType) {
        boolean keepsRightSide = JoinType.RIGHT_SEMI_JOIN.equals(joinType)
                || JoinType.RIGHT_ANTI_JOIN.equals(joinType);
        long rowCount = keepsRightSide ? rightStats.getRowCount() : leftStats.getRowCount();
        if (rowCount == -1) {
            return -1;
        }
        Map<Slot, ColumnStats> leftSlotToColStats = leftStats.getSlotToColumnStats();
        Map<Slot, ColumnStats> rightSlotToColStats = rightStats.getSlotToColumnStats();
        double minSelectivity = 1.0;
        for (Expression eqJoinPredicate : eqConjunctList) {
            ColumnStats leftColStats = leftSlotToColStats.get(eqJoinPredicate.child(0));
            ColumnStats rightColStats = rightSlotToColStats.get(eqJoinPredicate.child(1));
            // Skip conjuncts whose operands have no column statistics, as the inner/outer
            // estimation does, instead of failing with a NullPointerException.
            if (leftColStats == null || rightColStats == null) {
                continue;
            }
            long lhsNdv = Math.min(leftColStats.getNdv(), leftStats.getRowCount());
            long rhsNdv = Math.min(rightColStats.getNdv(), rightStats.getRowCount());
            // Skip conjuncts with unknown NDV on either side.
            if (lhsNdv == -1 || rhsNdv == -1) {
                continue;
            }
            minSelectivity = Math.min(minSelectivity, semiJoinSelectivity(joinType, lhsNdv, rhsNdv));
        }
        Preconditions.checkState(rowCount != -1);
        return Math.round(rowCount * minSelectivity);
    }

    // TODO: Do we need NULL_AWARE_LEFT_ANTI_JOIN type as stale optimizer?
    /**
     * Selectivity of a single equality conjunct for a semi or anti join.
     *
     * <p>A zero NDV on the preserved side yields a selectivity of {@code 0}. This gives the same
     * final row count (zero) that the bare {@code 0 / 0} division produced through NaN, but
     * without depending on NaN propagation.
     *
     * @throws RuntimeException if {@code joinType} is not a semi or anti join
     */
    private static double semiJoinSelectivity(JoinType joinType, long lhsNdv, long rhsNdv) {
        switch (joinType) {
            case LEFT_SEMI_JOIN:
                return lhsNdv == 0 ? 0.0 : (double) Math.min(lhsNdv, rhsNdv) / (double) lhsNdv;
            case RIGHT_SEMI_JOIN:
                return rhsNdv == 0 ? 0.0 : (double) Math.min(lhsNdv, rhsNdv) / (double) rhsNdv;
            case LEFT_ANTI_JOIN:
                return lhsNdv == 0 ? 0.0
                        : (double) (lhsNdv > rhsNdv ? (lhsNdv - rhsNdv) : lhsNdv) / (double) lhsNdv;
            case RIGHT_ANTI_JOIN:
                return rhsNdv == 0 ? 0.0
                        : (double) (rhsNdv > lhsNdv ? (rhsNdv - lhsNdv) : rhsNdv) / (double) rhsNdv;
            default:
                throw new RuntimeException("joinType is not supported");
        }
    }

    /**
     * Row count of an inner or outer join.
     *
     * <p>For every conjunct whose two operands are slot references with column statistics, the
     * candidate cardinality is {@code lhsCard / max(1, leftNdv, rightNdv) * rhsCard}, or just
     * {@code lhsCard} when that maximum NDV equals {@code rhsCard}. The smallest candidate wins.
     *
     * @return the estimated row count; {@code lhsCard} if either side's row count is unknown;
     *         {@code -1} if no conjunct could be used
     */
    private static long getJoinRowCount(StatsDeriveResult leftStats, StatsDeriveResult rightStats,
            List<Expression> eqConjunctList, JoinType joinType) {
        long lhsCard = leftStats.getRowCount();
        long rhsCard = rightStats.getRowCount();
        // With an unknown row count on either side the left row count is returned as is
        // (which is itself -1 when the left side is the unknown one).
        if (lhsCard == -1 || rhsCard == -1) {
            return lhsCard;
        }
        Map<Slot, ColumnStats> leftSlotToColumnStats = leftStats.getSlotToColumnStats();
        Map<Slot, ColumnStats> rightSlotToColumnStats = rightStats.getSlotToColumnStats();
        long result = -1;
        for (Expression eqJoinConjunct : eqConjunctList) {
            Expression left = eqJoinConjunct.child(0);
            if (!(left instanceof SlotReference)) {
                continue;
            }
            Expression right = eqJoinConjunct.child(1);
            if (!(right instanceof SlotReference)) {
                continue;
            }
            ColumnStats leftColStats = leftSlotToColumnStats.get((SlotReference) left);
            ColumnStats rightColStats = rightSlotToColumnStats.get((SlotReference) right);
            if (leftColStats == null || rightColStats == null) {
                continue;
            }
            double leftSideNdv = leftColStats.getNdv();
            double rightSideNdv = rightColStats.getNdv();
            // Clamped to at least 1, so an unknown (-1) or zero NDV never ends up as a divisor.
            double maxNdv = Math.max(1, Math.max(leftSideNdv, rightSideNdv));
            long joinCard = (long) maxNdv == rhsCard
                    ? lhsCard
                    : CheckedMath.checkedMultiply(Math.round(lhsCard / maxNdv), rhsCard);
            result = result == -1 ? joinCard : Math.min(result, joinCard);
        }
        return result;
    }
}

View File

@ -0,0 +1,235 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.stats;
import org.apache.doris.catalog.Table;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Aggregate;
import org.apache.doris.nereids.trees.plans.Filter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.Project;
import org.apache.doris.nereids.trees.plans.Scan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribution;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHeapSort;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ColumnStats;
import org.apache.doris.statistics.StatsDeriveResult;
import org.apache.doris.statistics.TableStats;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Used to calculate the stats for each operator.
 *
 * <p>One calculator is created per {@link GroupExpression}. The statistics of the expression's
 * plan are derived from copies of its children's statistics and stored on the owner group.
 */
public class StatsCalculator extends DefaultPlanVisitor<StatsDeriveResult, Void> {

    private final GroupExpression groupExpression;

    public StatsCalculator(GroupExpression groupExpression) {
        this.groupExpression = groupExpression;
    }

    /**
     * Do estimate.
     *
     * <p>Derives the statistics of the group expression's plan, caps the row count by the plan's
     * limit when one is set (a limit of {@code -1} is treated as "no limit"), stores the result on
     * the owner group and marks the group expression as derived.
     */
    public void estimate() {
        Plan plan = groupExpression.getPlan();
        StatsDeriveResult stats = plan.accept(this, null);
        long limit = plan.getLimit();
        // Apply the limit before publishing, so the group never holds a value that is still
        // going to change. The null check covers plan types this visitor does not handle, for
        // which the inherited fallback may return null (NOTE(review): confirm).
        if (stats != null && limit != -1) {
            stats.setRowCount(Math.min(limit, stats.getRowCount()));
        }
        groupExpression.getOwnerGroup().setStatistics(stats);
        groupExpression.setStatDerived(true);
    }

    @Override
    public StatsDeriveResult visitLogicalAggregate(LogicalAggregate<Plan> aggregate, Void context) {
        return computeAggregate(aggregate);
    }

    @Override
    public StatsDeriveResult visitLogicalFilter(LogicalFilter<Plan> filter, Void context) {
        return computeFilter(filter);
    }

    @Override
    public StatsDeriveResult visitLogicalOlapScan(LogicalOlapScan olapScan, Void context) {
        return computeScan(olapScan);
    }

    @Override
    public StatsDeriveResult visitLogicalProject(LogicalProject<Plan> project, Void context) {
        return computeProject(project);
    }

    @Override
    public StatsDeriveResult visitLogicalSort(LogicalSort<Plan> sort, Void context) {
        return groupExpression.getCopyOfChildStats(0);
    }

    @Override
    public StatsDeriveResult visitLogicalJoin(LogicalJoin<Plan, Plan> join, Void context) {
        // NOTE(review): Optional.get() throws NoSuchElementException when the join carries no
        // condition - confirm that every join reaching this point has one.
        return JoinEstimation.estimate(groupExpression.getCopyOfChildStats(0),
                groupExpression.getCopyOfChildStats(1),
                join.getCondition().get(), join.getJoinType());
    }

    @Override
    public StatsDeriveResult visitPhysicalAggregate(PhysicalAggregate<Plan> agg, Void context) {
        return computeAggregate(agg);
    }

    @Override
    public StatsDeriveResult visitPhysicalOlapScan(PhysicalOlapScan olapScan, Void context) {
        return computeScan(olapScan);
    }

    @Override
    public StatsDeriveResult visitPhysicalHeapSort(PhysicalHeapSort<Plan> sort, Void context) {
        return groupExpression.getCopyOfChildStats(0);
    }

    @Override
    public StatsDeriveResult visitPhysicalHashJoin(PhysicalHashJoin<Plan, Plan> hashJoin, Void context) {
        // NOTE(review): Optional.get() throws NoSuchElementException when the join carries no
        // condition - confirm that every join reaching this point has one.
        return JoinEstimation.estimate(groupExpression.getCopyOfChildStats(0),
                groupExpression.getCopyOfChildStats(1),
                hashJoin.getCondition().get(), hashJoin.getJoinType());
    }

    // TODO: We should subtract those pruned column, and consider the expression transformations in the node.
    @Override
    public StatsDeriveResult visitPhysicalProject(PhysicalProject<Plan> project, Void context) {
        return computeProject(project);
    }

    @Override
    public StatsDeriveResult visitPhysicalFilter(PhysicalFilter<Plan> filter, Void context) {
        return computeFilter(filter);
    }

    @Override
    public StatsDeriveResult visitPhysicalDistribution(PhysicalDistribution<Plan> distribution,
            Void context) {
        return groupExpression.getCopyOfChildStats(0);
    }

    /**
     * Scales a copy of the child's statistics by the estimated selectivity of the filter's
     * predicates.
     */
    private StatsDeriveResult computeFilter(Filter filter) {
        StatsDeriveResult stats = groupExpression.getCopyOfChildStats(0);
        FilterSelectivityCalculator selectivityCalculator =
                new FilterSelectivityCalculator(stats.getSlotToColumnStats());
        double selectivity = selectivityCalculator.estimate(filter.getPredicates());
        stats.multiplyDouble(selectivity);
        return stats;
    }

    // TODO: 1. Subtract the pruned partition
    //       2. Consider the influence of runtime filter
    //       3. Get NDV and column data size from StatisticManger, StatisticManager doesn't support it now.
    /**
     * Builds statistics for a scan from the table statistics kept by the statistics manager:
     * the table's row count, plus the column statistics looked up by name for each output slot
     * that is a {@link SlotReference}.
     *
     * @throws RuntimeException if an output slot reference has no name
     */
    private StatsDeriveResult computeScan(Scan scan) {
        Table table = scan.getTable();
        TableStats tableStats = Utils.execWithReturnVal(() ->
                ConnectContext.get().getEnv().getStatisticsManager().getStatistics().getTableStats(table.getId())
        );
        Map<Slot, ColumnStats> slotToColumnStats = new HashMap<>();
        Set<SlotReference> slotSet = scan.getOutput().stream().filter(SlotReference.class::isInstance)
                .map(s -> (SlotReference) s).collect(Collectors.toSet());
        for (SlotReference slotReference : slotSet) {
            String colName = slotReference.getName();
            if (colName == null) {
                throw new RuntimeException("Column name of SlotReference shouldn't be null here");
            }
            ColumnStats columnStats = tableStats.getColumnStats(colName);
            slotToColumnStats.put(slotReference, columnStats);
        }
        long rowCount = tableStats.getRowCount();
        StatsDeriveResult stats = new StatsDeriveResult(rowCount,
                new HashMap<>(), new HashMap<>());
        stats.setSlotToColumnStats(slotToColumnStats);
        return stats;
    }

    /**
     * Estimates the output of an aggregate as the product of the NDVs of its group-by keys.
     *
     * <p>Only group-by expressions containing exactly one slot reference contribute. Keys whose
     * column statistics are missing, or whose NDV is negative (unknown), are skipped: multiplying
     * by them would either throw a NullPointerException or produce a negative row count. The
     * product saturates at {@link Long#MAX_VALUE} instead of overflowing. With no usable key the
     * estimate is one row.
     */
    private StatsDeriveResult computeAggregate(Aggregate aggregate) {
        List<Expression> groupByExprList = aggregate.getGroupByExpressions();
        StatsDeriveResult childStats = groupExpression.getCopyOfChildStats(0);
        Map<Slot, ColumnStats> childSlotColumnStatsMap = childStats.getSlotToColumnStats();
        long resultSetCount = 1;
        for (Expression expression : groupByExprList) {
            List<SlotReference> slotRefList = expression.collect(SlotReference.class::isInstance);
            // TODO: Support more complex group expr.
            //       For example:
            //              select max(col1+col3) from t1 group by col1+col3;
            if (slotRefList.size() != 1) {
                continue;
            }
            SlotReference slotRef = slotRefList.get(0);
            ColumnStats columnStats = childSlotColumnStatsMap.get(slotRef);
            if (columnStats == null) {
                continue;
            }
            long ndv = columnStats.getNdv();
            if (ndv < 0) {
                continue;
            }
            resultSetCount = saturatedMultiply(resultSetCount, ndv);
        }
        Map<Slot, ColumnStats> slotColumnStatsMap = new HashMap<>();
        List<NamedExpression> namedExpressionList = aggregate.getOutputExpressions();
        // TODO: 1. Estimate the output unit size by the type of corresponding AggregateFunction
        //       2. Handle alias, literal in the output expression list
        for (NamedExpression namedExpression : namedExpressionList) {
            if (namedExpression instanceof SlotReference) {
                slotColumnStatsMap.put((SlotReference) namedExpression, new ColumnStats());
            }
        }
        StatsDeriveResult statsDeriveResult = new StatsDeriveResult(resultSetCount, slotColumnStatsMap);
        // TODO: Update ColumnStats properly, add new mapping from output slot to ColumnStats
        return statsDeriveResult;
    }

    /** Multiplies two non-negative values, returning {@link Long#MAX_VALUE} on overflow. */
    private static long saturatedMultiply(long a, long b) {
        if (a == 0 || b == 0) {
            return 0;
        }
        return a > Long.MAX_VALUE / b ? Long.MAX_VALUE : a * b;
    }

    // TODO: Update data size and min/max value.
    /**
     * Keeps, from a copy of the child's statistics, only the column statistics of slots that are
     * referenced somewhere in the project list.
     */
    private StatsDeriveResult computeProject(Project project) {
        List<NamedExpression> namedExpressionList = project.getProjects();
        Set<Slot> slotSet = new HashSet<>();
        for (NamedExpression namedExpression : namedExpressionList) {
            List<SlotReference> slotReferenceList = namedExpression.collect(SlotReference.class::isInstance);
            slotSet.addAll(slotReferenceList);
        }
        StatsDeriveResult stat = groupExpression.getCopyOfChildStats(0);
        Map<Slot, ColumnStats> slotColumnStatsMap = stat.getSlotToColumnStats();
        slotColumnStatsMap.entrySet().removeIf(entry -> !slotSet.contains(entry.getKey()));
        return stat;
    }
}

View File

@ -110,7 +110,6 @@ public class SlotReference extends Slot {
return nullable == that.nullable
&& dataType.equals(that.dataType)
&& exprId.equals(that.exprId)
&& dataType.equals(that.dataType)
&& name.equals(that.name)
&& qualifier.equals(that.qualifier);
}

View File

@ -20,14 +20,11 @@ package org.apache.doris.nereids.trees.plans;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.AbstractTreeNode;
import org.apache.doris.statistics.ExprStats;
import org.apache.doris.statistics.StatisticalType;
import org.apache.doris.statistics.StatsDeriveResult;
import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@ -38,7 +35,7 @@ import java.util.Optional;
public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Plan {
protected StatsDeriveResult statsDeriveResult;
protected long limit;
protected long limit = -1;
protected final PlanType type;
protected final LogicalProperties logicalProperties;
@ -101,36 +98,6 @@ public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Pla
}
}
@Override
public List<StatsDeriveResult> getChildrenStats() {
return Collections.emptyList();
}
@Override
public StatsDeriveResult getStatsDeriveResult() {
return statsDeriveResult;
}
@Override
public StatisticalType getStatisticalType() {
return null;
}
@Override
public void setStatsDeriveResult(StatsDeriveResult result) {
this.statsDeriveResult = result;
}
@Override
public long getLimit() {
return limit;
}
@Override
public List<? extends ExprStats> getConjuncts() {
return Collections.emptyList();
}
@Override
public boolean equals(Object o) {
if (this == o) {
@ -149,4 +116,8 @@ public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Pla
public int hashCode() {
return Objects.hash(statsDeriveResult, limit, logicalProperties);
}
/**
 * Returns the row limit attached to this plan.
 *
 * @return the limit; the field is initialised to {@code -1}, which the stats calculator
 *         treats as "no limit"
 */
@Override
public long getLimit() {
    return limit;
}
}

View File

@ -0,0 +1,33 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import java.util.List;
/**
 * Common interface for logical/physical Aggregate.
 *
 * <p>Gives callers such as the stats calculator one view over both aggregate plan kinds.
 */
public interface Aggregate {

    /** Returns the named expressions this aggregate outputs. */
    List<NamedExpression> getOutputExpressions();

    /** Returns the expressions this aggregate groups by. */
    List<Expression> getGroupByExpressions();
}

View File

@ -0,0 +1,27 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans;
import org.apache.doris.nereids.trees.expressions.Expression;
/**
 * Common interface for logical/physical filter.
 *
 * <p>Gives callers such as the stats calculator one view over both filter plan kinds.
 */
public interface Filter {
// Returns the predicate expression of this filter as a single Expression.
Expression getPredicates();
}

View File

@ -24,9 +24,6 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.logical.LogicalLeaf;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.statistics.ExprStats;
import org.apache.doris.statistics.StatisticalType;
import org.apache.doris.statistics.StatsDeriveResult;
import com.google.common.collect.ImmutableList;
@ -70,36 +67,11 @@ public class GroupPlan extends LogicalLeaf {
throw new IllegalStateException("GroupPlan can not invoke withChildren()");
}
@Override
public List<StatsDeriveResult> getChildrenStats() {
throw new IllegalStateException("GroupPlan can not invoke getChildrenStats()");
}
@Override
public StatsDeriveResult getStatsDeriveResult() {
throw new IllegalStateException("GroupPlan can not invoke getStatsDeriveResult()");
}
@Override
public StatisticalType getStatisticalType() {
throw new IllegalStateException("GroupPlan can not invoke getStatisticalType()");
}
@Override
public void setStatsDeriveResult(StatsDeriveResult result) {
throw new IllegalStateException("GroupPlan can not invoke setStatsDeriveResult()");
}
@Override
public long getLimit() {
throw new IllegalStateException("GroupPlan can not invoke getLimit()");
}
@Override
public List<? extends ExprStats> getConjuncts() {
throw new IllegalStateException("GroupPlan can not invoke getConjuncts()");
}
@Override
public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
throw new IllegalStateException("GroupPlan can not invoke withGroupExpression()");

View File

@ -99,4 +99,16 @@ public enum JoinType {
/**
 * Returns the join type that {@code joinSwapMap} associates with this one.
 *
 * <p>NOTE(review): the map's contents are not visible here; presumably it maps each join type
 * to its counterpart with the two inputs exchanged, and yields null for types without an entry
 * - confirm against the map's definition.
 */
public JoinType swap() {
return joinSwapMap.get(this);
}
/** Returns true for the left/right semi joins and the left/right anti joins. */
public boolean isSemiOrAntiJoin() {
    switch (this) {
        case LEFT_SEMI_JOIN:
        case RIGHT_SEMI_JOIN:
        case LEFT_ANTI_JOIN:
        case RIGHT_ANTI_JOIN:
            return true;
        default:
            return false;
    }
}
/** Returns true only for {@code INNER_JOIN}. */
public boolean isInnerJoin() {
    return INNER_JOIN == this;
}
/** Returns true for the left, right and full outer joins. */
public boolean isOuterJoin() {
    switch (this) {
        case LEFT_OUTER_JOIN:
        case RIGHT_OUTER_JOIN:
        case FULL_OUTER_JOIN:
            return true;
        default:
            return false;
    }
}
}

View File

@ -23,7 +23,6 @@ import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.statistics.PlanStats;
import java.util.List;
import java.util.Optional;
@ -31,7 +30,7 @@ import java.util.Optional;
/**
* Abstract class for all plan node.
*/
public interface Plan extends TreeNode<Plan>, PlanStats {
public interface Plan extends TreeNode<Plan> {
PlanType getType();
@ -62,4 +61,6 @@ public interface Plan extends TreeNode<Plan>, PlanStats {
Plan withGroupExpression(Optional<GroupExpression> groupExpression);
Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties);
long getLimit();
}

View File

@ -0,0 +1,29 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import java.util.List;
/**
 * Common interface for logical/physical project.
 *
 * <p>Gives callers such as the stats calculator one view over both project plan kinds.
 */
public interface Project {
// Returns the named expressions that make up the project list.
List<NamedExpression> getProjects();
}

View File

@ -0,0 +1,38 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans;
import org.apache.doris.catalog.Table;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import java.util.Collections;
import java.util.List;
/**
 * Common interface for logical/physical scan.
 *
 * <p>Exposes the scanned table, the node's expressions and its output slots so that
 * code can handle both kinds of scan node through one type.
 */
public interface Scan {
/** Returns the expressions of this node; their meaning is defined by the implementing class. */
List<Expression> getExpressions();
/** Returns the table associated with this scan node. */
Table getTable();
/**
 * Returns the output slots of this node.
 *
 * <p>This default implementation returns an empty list.
 */
default List<Slot> getOutput() {
return Collections.emptyList();
}
}

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.AggPhase;
import org.apache.doris.nereids.trees.plans.Aggregate;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
@ -48,7 +49,7 @@ import java.util.Optional;
* Note: In general, the output of agg is a subset of the group by column plus aggregate column.
* In special cases, this relationship does not hold. For example: select k1+1, sum(v1) from table group by k1.
*/
public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYPE> {
public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYPE> implements Aggregate {
private final boolean disassembled;
private final List<Expression> groupByExpressions;

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Filter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
@ -35,7 +36,7 @@ import java.util.Optional;
/**
* Logical filter plan.
*/
public class LogicalFilter<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYPE> {
public class LogicalFilter<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYPE> implements Filter {
private final Expression predicates;

View File

@ -75,7 +75,7 @@ public class LogicalOlapScan extends LogicalRelation {
}
@Override
public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
public LogicalOlapScan withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
return new LogicalOlapScan(table, qualifier, Optional.empty(), logicalProperties);
}

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.Project;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import com.google.common.base.Preconditions;
@ -37,7 +38,7 @@ import java.util.Optional;
/**
* Logical project plan.
*/
public class LogicalProject<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYPE> {
public class LogicalProject<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYPE> implements Project {
private final List<NamedExpression> projects;

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.Scan;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.Utils;
@ -36,7 +37,7 @@ import java.util.Optional;
/**
* Logical relation plan.
*/
public abstract class LogicalRelation extends LogicalLeaf {
public abstract class LogicalRelation extends LogicalLeaf implements Scan {
protected final Table table;
protected final List<String> qualifier;

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.AggPhase;
import org.apache.doris.nereids.trees.plans.Aggregate;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
@ -37,7 +38,7 @@ import java.util.Optional;
/**
* Physical aggregation plan.
*/
public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD_TYPE> {
public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD_TYPE> implements Aggregate {
private final List<Expression> groupByExpressions;

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.plans.physical;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.Filter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
@ -34,7 +35,7 @@ import java.util.Optional;
/**
* Physical filter plan.
*/
public class PhysicalFilter<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD_TYPE> {
public class PhysicalFilter<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD_TYPE> implements Filter {
private final Expression predicates;

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.Project;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import com.google.common.base.Preconditions;
@ -35,7 +36,7 @@ import java.util.Optional;
/**
* Physical project plan.
*/
public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD_TYPE> {
public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD_TYPE> implements Project {
private final List<NamedExpression> projects;

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.Scan;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import com.google.common.collect.ImmutableList;
@ -32,7 +33,7 @@ import java.util.Optional;
/**
* Abstract class for all physical scan plan.
*/
public abstract class PhysicalRelation extends PhysicalLeaf {
public abstract class PhysicalRelation extends PhysicalLeaf implements Scan {
protected final List<String> qualifier;

View File

@ -59,8 +59,8 @@ public abstract class PlanVisitor<R, C> {
}
public R visitLogicalAggregate(LogicalAggregate<Plan> relation, C context) {
return visit(relation, context);
public R visitLogicalAggregate(LogicalAggregate<Plan> aggregate, C context) {
return visit(aggregate, context);
}
public R visitLogicalFilter(LogicalFilter<Plan> filter, C context) {

View File

@ -38,6 +38,50 @@ public class Utils {
? part : part.replace("`", "``");
}
/**
 * Runs the given action and converts any exception it throws into a
 * {@link RuntimeException}, so that callers do not need a {@code try}/{@code catch}
 * around code that declares checked exceptions.
 *
 * @param f action to run; its {@code exec()} may throw any {@link Exception}
 * @throws RuntimeException carrying the message of, and caused by, whatever {@code f} threw
 */
public static void execWithUncheckedException(FuncWrapper f) {
    try {
        f.exec();
    } catch (Exception cause) {
        final String message = cause.getMessage();
        throw new RuntimeException(message, cause);
    }
}
/**
 * Invokes the given supplier and hands back its value, converting any exception it
 * throws into a {@link RuntimeException}, so that callers do not need a
 * {@code try}/{@code catch} around code that declares checked exceptions.
 *
 * @param f supplier to invoke; its {@code get()} may throw any {@link Exception}
 * @param <R> type of the produced value
 * @return whatever {@code f.get()} returns
 * @throws RuntimeException carrying the message of, and caused by, whatever {@code f} threw
 */
public static <R> R execWithReturnVal(Supplier<R> f) {
    try {
        return f.get();
    } catch (Exception cause) {
        throw new RuntimeException(cause.getMessage(), cause);
    }
}
/**
 * Wrapper for a function without a return value.
 *
 * <p>Unlike {@link Runnable}, {@code exec()} is allowed to throw any {@link Exception}.
 */
public interface FuncWrapper {
/** Runs the wrapped logic. */
void exec() throws Exception;
}
/**
 * Wrapper for a function with a return value.
 *
 * <p>{@code get()} is allowed to throw any {@link Exception}.
 *
 * @param <R> type of the produced value
 */
public interface Supplier<R> {
/** Runs the wrapped logic and returns its result. */
R get() throws Exception;
}
/**
* Fully qualified identifier name parts, i.e., concat qualifier and name into a list.
*/

View File

@ -92,6 +92,8 @@ public class SessionVariable implements Serializable, Writable {
public static final String ENABLE_COST_BASED_JOIN_REORDER = "enable_cost_based_join_reorder";
public static final String ENABLE_NEREIDS_CBO = "enable_nereids_cbo";
public static final int MIN_EXEC_INSTANCE_NUM = 1;
public static final int MAX_EXEC_INSTANCE_NUM = 32;
// if set to true, some of stmt will be forwarded to master FE to get result
@ -431,6 +433,9 @@ public class SessionVariable implements Serializable, Writable {
@VariableMgr.VarAttr(name = ENABLE_COST_BASED_JOIN_REORDER)
private boolean enableJoinReorderBasedCost = false;
@VariableMgr.VarAttr(name = ENABLE_NEREIDS_CBO)
private boolean enableNereidsCBO = false;
@VariableMgr.VarAttr(name = ENABLE_FOLD_CONSTANT_BY_BE)
private boolean enableFoldConstantByBe = false;
@ -979,6 +984,14 @@ public class SessionVariable implements Serializable, Writable {
this.enableJoinReorderBasedCost = enableJoinReorderBasedCost;
}
/** Returns the current value of the {@code enable_nereids_cbo} session variable. */
public boolean isEnableNereidsCBO() {
return enableNereidsCBO;
}
/** Sets the value of the {@code enable_nereids_cbo} session variable. */
public void setEnableNereidsCBO(boolean enableNereidsCBO) {
this.enableNereidsCBO = enableNereidsCBO;
}
/** Sets the {@code disableJoinReorder} flag of this session. */
public void setDisableJoinReorder(boolean disableJoinReorder) {
this.disableJoinReorder = disableJoinReorder;
}

View File

@ -155,7 +155,7 @@ public class BaseStatsDerive {
protected HashMap<Id, Float> deriveColumnToDataSize() {
HashMap<Id, Float> columnToDataSize = new HashMap<>();
for (StatsDeriveResult child : childrenStatsResult) {
columnToDataSize.putAll(child.getColumnToDataSize());
columnToDataSize.putAll(child.getColumnIdToDataSize());
}
return columnToDataSize;
}
@ -163,7 +163,7 @@ public class BaseStatsDerive {
protected HashMap<Id, Long> deriveColumnToNdv() {
HashMap<Id, Long> columnToNdv = new HashMap<>();
for (StatsDeriveResult child : childrenStatsResult) {
columnToNdv.putAll(child.getColumnToNdv());
columnToNdv.putAll(child.getColumnIdToNdv());
}
return columnToNdv;
}

View File

@ -76,6 +76,22 @@ public class ColumnStats {
private LiteralExpr minValue;
private LiteralExpr maxValue;
public ColumnStats(ColumnStats other) {
this.ndv = other.ndv;
this.avgSize = other.avgSize;
this.maxSize = other.maxSize;
this.numNulls = other.numNulls;
if (other.minValue != null) {
this.minValue = (LiteralExpr) other.minValue.clone();
}
if (other.maxValue != null) {
this.maxValue = (LiteralExpr) other.maxValue.clone();
}
}
/** Creates an instance in which every field keeps its declared initial value. */
public ColumnStats() {
}
/** Returns the {@code ndv} value (presumably the number of distinct values) of this column. */
public long getNdv() {
return ndv;
}
@ -224,4 +240,14 @@ public class ColumnStats {
throw new AnalysisException("Unsupported setting this type: " + type + " of min max value");
}
}
/** Returns a new instance created from this one through the copy constructor. */
public ColumnStats copy() {
return new ColumnStats(this);
}
/**
 * Scales {@code ndv} and {@code numNulls} of this instance by {@code selectivity}, in place.
 * No other field is touched.
 *
 * @param selectivity factor both values are multiplied by
 * @return this instance, after modification
 */
public ColumnStats multiplyDouble(double selectivity) {
// Compound assignment implicitly casts the double product back to the field's declared
// type, so any fractional part is dropped if the field is an integer type.
ndv *= selectivity;
numNulls *= selectivity;
return this;
}
}

View File

@ -174,4 +174,9 @@ public class Statistics {
tableStats.setRowCount(rowCount);
}
}
// Used for unit tests: registers tableStats under the given table id,
// replacing any entry already stored for that id.
public void putTableStats(long id, TableStats tableStats) {
this.idToTableStats.put(id, tableStats);
}
}

View File

@ -18,9 +18,11 @@
package org.apache.doris.statistics;
import org.apache.doris.common.Id;
import org.apache.doris.nereids.trees.expressions.Slot;
import com.google.common.collect.Maps;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
@ -32,19 +34,36 @@ public class StatsDeriveResult {
private long rowCount = -1;
// The data size of the corresponding column in the operator
// The actual key is slotId
private final Map<Id, Float> columnToDataSize = Maps.newHashMap();
private final Map<Id, Float> columnIdToDataSize = Maps.newHashMap();
// The ndv of the corresponding column in the operator
// The actual key is slotId
private final Map<Id, Long> columnToNdv = Maps.newHashMap();
private final Map<Id, Long> columnIdToNdv = Maps.newHashMap();
public StatsDeriveResult(long rowCount, Map<Id, Float> columnToDataSize, Map<Id, Long> columnToNdv) {
private Map<Slot, ColumnStats> slotToColumnStats;
public StatsDeriveResult(long rowCount, Map<Slot, ColumnStats> slotToColumnStats) {
this.rowCount = rowCount;
this.columnToDataSize.putAll(columnToDataSize);
this.columnToNdv.putAll(columnToNdv);
this.slotToColumnStats = slotToColumnStats;
}
/**
 * Creates a result from a row count and the Id-keyed per-column maps.
 *
 * <p>The entries of both maps are copied into this instance's own maps.
 * NOTE: this constructor does not assign {@code slotToColumnStats}.
 *
 * @param rowCount row count of the operator
 * @param columnIdToDataSize data size per column id
 * @param columnIdToNdv ndv per column id
 */
public StatsDeriveResult(long rowCount, Map<Id, Float> columnIdToDataSize, Map<Id, Long> columnIdToNdv) {
this.rowCount = rowCount;
this.columnIdToDataSize.putAll(columnIdToDataSize);
this.columnIdToNdv.putAll(columnIdToNdv);
}
/**
 * Copy constructor.
 *
 * <p>Copies the row count and the entries of both Id-keyed maps, and builds a fresh
 * {@code slotToColumnStats} map whose values are copies of the source's values, so later
 * changes to the copy's column stats do not affect {@code another}.
 *
 * <p>If {@code another} has no {@code slotToColumnStats} (it is left unassigned by the
 * Id-keyed constructor), the copy gets an empty map instead of failing with a
 * {@link NullPointerException}.
 *
 * @param another instance to copy from
 */
public StatsDeriveResult(StatsDeriveResult another) {
    this.rowCount = another.rowCount;
    this.columnIdToDataSize.putAll(another.columnIdToDataSize);
    this.columnIdToNdv.putAll(another.columnIdToNdv);
    this.slotToColumnStats = new HashMap<>();
    // The source map is null when the source was built from the Id-keyed maps only.
    if (another.slotToColumnStats != null) {
        for (Entry<Slot, ColumnStats> entry : another.slotToColumnStats.entrySet()) {
            this.slotToColumnStats.put(entry.getKey(), entry.getValue().copy());
        }
    }
}
public float computeSize() {
return Math.max(1, columnToDataSize.values().stream().reduce((float) 0, Float::sum)) * rowCount;
return Math.max(1, columnIdToDataSize.values().stream().reduce((float) 0, Float::sum)) * rowCount;
}
/**
@ -57,7 +76,7 @@ public class StatsDeriveResult {
float count = 0;
boolean exist = false;
for (Entry<Id, Float> entry : columnToDataSize.entrySet()) {
for (Entry<Id, Float> entry : columnIdToDataSize.entrySet()) {
if (slotIds.contains(entry.getKey())) {
count += entry.getValue();
exist = true;
@ -77,11 +96,38 @@ public class StatsDeriveResult {
return rowCount;
}
public Map<Id, Long> getColumnToNdv() {
return columnToNdv;
public Map<Id, Long> getColumnIdToNdv() {
return columnIdToNdv;
}
public Map<Id, Float> getColumnToDataSize() {
return columnToDataSize;
public Map<Id, Float> getColumnIdToDataSize() {
return columnIdToDataSize;
}
/**
 * Returns the slot-keyed column statistics map itself (not a copy).
 * The field has no initializer, so this is {@code null} until a map has been assigned.
 */
public Map<Slot, ColumnStats> getSlotToColumnStats() {
return slotToColumnStats;
}
/** Replaces the slot-keyed column statistics map; the given map is stored as-is, not copied. */
public void setSlotToColumnStats(Map<Slot, ColumnStats> slotToColumnStats) {
this.slotToColumnStats = slotToColumnStats;
}
/**
 * Scales this result by {@code selectivity}, in place: the row count is multiplied by it
 * and every column statistics entry is scaled through {@link ColumnStats#multiplyDouble}.
 * The Id-keyed maps are not touched.
 *
 * <p>When no slot-keyed statistics have been assigned, only the row count is scaled.
 *
 * @param selectivity factor to multiply by
 * @return this instance, after modification
 */
public StatsDeriveResult multiplyDouble(double selectivity) {
    // Compound assignment casts the double product back to long, dropping the fraction.
    rowCount *= selectivity;
    // slotToColumnStats is null for instances built from the Id-keyed maps only.
    if (slotToColumnStats != null) {
        for (ColumnStats columnStats : slotToColumnStats.values()) {
            columnStats.multiplyDouble(selectivity);
        }
    }
    return this;
}
/**
 * Adds copies of all column statistics of {@code other} to this result, in place.
 * For a slot present on both sides, the entry from {@code other} replaces the existing one.
 * The row count and the Id-keyed maps of this result are left unchanged.
 *
 * <p>If {@code other} carries no slot-keyed statistics this is a no-op; if this result
 * carries none yet, an empty map is created first.
 *
 * @param other result whose column statistics are merged in
 * @return this instance, after modification
 */
public StatsDeriveResult merge(StatsDeriveResult other) {
    Map<Slot, ColumnStats> incoming = other.getSlotToColumnStats();
    if (incoming == null) {
        return this;
    }
    if (slotToColumnStats == null) {
        slotToColumnStats = new HashMap<>();
    }
    for (Entry<Slot, ColumnStats> entry : incoming.entrySet()) {
        slotToColumnStats.put(entry.getKey(), entry.getValue().copy());
    }
    return this;
}
/** Returns a new instance created from this one through the copy constructor. */
public StatsDeriveResult copy() {
return new StatsDeriveResult(this);
}
}

View File

@ -116,7 +116,7 @@ public class TableStats {
public void updateColumnStats(String columnName, Type columnType, Map<StatsType, String> statsTypeToValue)
throws AnalysisException {
ColumnStats columnStats = getNotNullColumnStats(columnName);
ColumnStats columnStats = getColumnStats(columnName);
columnStats.updateStats(columnType, statsTypeToValue);
}
@ -141,7 +141,7 @@ public class TableStats {
* @param columnName column name
* @return @ColumnStats
*/
public ColumnStats getNotNullColumnStats(String columnName) {
public ColumnStats getColumnStats(String columnName) {
ColumnStats columnStats = nameToColumnStats.get(columnName);
if (columnStats == null) {
columnStats = new ColumnStats();
@ -247,4 +247,11 @@ public class TableStats {
}
}
}
/**
 * This method is for unit tests.
 *
 * <p>Stores {@code columnStats} under the given column name, replacing any entry
 * already stored for that name.
 *
 * @param name column name used as the key
 * @param columnStats statistics to store
 */
public void putColumnStats(String name, ColumnStats columnStats) {
nameToColumnStats.put(name, columnStats);
}
}

View File

@ -0,0 +1,127 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.cascades;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.TableIf.TableType;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.Sum;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ColumnStats;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.StatisticsManager;
import org.apache.doris.statistics.StatsDeriveResult;
import org.apache.doris.statistics.TableStats;
import com.google.common.base.Supplier;
import mockit.Expectations;
import mockit.Mocked;
import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/**
 * Test for {@code DeriveStatsJob}: derives statistics for an aggregate on top of an olap scan
 * and checks the row count stored on the memo's root group.
 */
public class DeriveStatsJobTest {

    @Mocked
    ConnectContext context;
    @Mocked
    Env env;
    @Mocked
    StatisticsManager statisticsManager;

    // Assigned by buildOlapScan() and read by buildAggregate(); the scan must be built first.
    SlotReference slot1;

    @Test
    public void testExecute() throws Exception {
        // Order matters: building the scan initializes slot1, which the aggregate groups by.
        LogicalOlapScan scan = buildOlapScan();
        LogicalAggregate aggregate = buildAggregate(scan);

        Memo memo = new Memo(aggregate);
        PlannerContext plannerContext = new PlannerContext(memo, context);
        JobContext jobContext = new JobContext(plannerContext, null, Double.MAX_VALUE);
        new DeriveStatsJob(memo.getRoot().getLogicalExpression(), jobContext).execute();

        // Run every job that is still in the pool.
        while (!plannerContext.getJobPool().isEmpty()) {
            plannerContext.getJobPool().pop().execute();
        }

        StatsDeriveResult rootStats = memo.getRoot().getStatistics();
        Assert.assertNotNull(rootStats);
        Assert.assertEquals(10, rootStats.getRowCount());
    }

    /**
     * Builds an olap scan over table "t1" (id 0) whose only output slot is {@link #slot1},
     * and records the mock expectations that expose the table's statistics.
     */
    private LogicalOlapScan buildOlapScan() {
        ColumnStats c1Stats = new ColumnStats();
        c1Stats.setNdv(10);
        c1Stats.setNumNulls(5);

        long tableId = 0;
        String tableName = "t1";
        TableStats tableStats = new TableStats();
        tableStats.putColumnStats("c1", c1Stats);
        Statistics statistics = new Statistics();
        statistics.putTableStats(tableId, tableStats);

        List<String> qualifier = new ArrayList<>(Arrays.asList("test", "t"));
        slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier);

        new Expectations() {{
                ConnectContext.get();
                result = context;
                context.getEnv();
                result = env;
                env.getStatisticsManager();
                result = statisticsManager;
                statisticsManager.getStatistics();
                result = statistics;
            }};

        Table table = new Table(tableId, tableName, TableType.OLAP, Collections.emptyList());
        LogicalProperties properties = new LogicalProperties(new Supplier<List<Slot>>() {
            @Override
            public List<Slot> get() {
                return Arrays.asList(slot1);
            }
        });
        return new LogicalOlapScan(table, Collections.emptyList())
                .withLogicalProperties(Optional.of(properties));
    }

    /** Builds an aggregate over {@code child} that groups by {@link #slot1} and outputs sum(slot1) as "a". */
    private LogicalAggregate buildAggregate(Plan child) {
        List<Expression> groupByKeys = new ArrayList<>();
        groupByKeys.add(slot1);
        AggregateFunction sumOfSlot1 = new Sum(slot1);
        Alias output = new Alias(sumOfSlot1, "a");
        return new LogicalAggregate(groupByKeys, Arrays.asList(output), child);
    }
}

View File

@ -104,6 +104,11 @@ public class TestPlanOutput {
public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
return null;
}
// Test stub: always returns null.
@Override
public Table getTable() {
return null;
}
};
}
}

View File

@ -0,0 +1,244 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.stats;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.TableIf.TableType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.Sum;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ColumnStats;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.StatisticsManager;
import org.apache.doris.statistics.StatsDeriveResult;
import org.apache.doris.statistics.TableStats;
import com.google.common.base.Supplier;
import mockit.Expectations;
import mockit.Mocked;
import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
/**
 * Tests for the statistics estimation of single operators: aggregate, filter,
 * hash join and olap scan.
 */
public class StatsCalculatorTest {

    @Mocked
    ConnectContext context;
    @Mocked
    Env env;
    @Mocked
    StatisticsManager statisticsManager;

    /** Aggregate grouped by c1: the estimated row count is expected to be 10. */
    @Test
    public void testAgg() {
        List<String> qualifier = new ArrayList<>();
        qualifier.add("test");
        qualifier.add("t");
        SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier);
        SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier);
        ColumnStats columnStats1 = new ColumnStats();
        columnStats1.setNdv(10);
        columnStats1.setNumNulls(5);
        ColumnStats columnStats2 = new ColumnStats();
        columnStats2.setNdv(20);
        // Previously this call targeted columnStats1 again (copy-paste slip).
        // NOTE(review): the assertion below appears to depend only on ndv; confirm.
        columnStats2.setNumNulls(10);
        Map<Slot, ColumnStats> slotColumnStatsMap = new HashMap<>();
        slotColumnStatsMap.put(slot1, columnStats1);
        slotColumnStatsMap.put(slot2, columnStats2);
        List<Expression> groupByExprList = new ArrayList<>();
        groupByExprList.add(slot1);
        AggregateFunction sum = new Sum(slot2);
        StatsDeriveResult childStats = new StatsDeriveResult(20, slotColumnStatsMap);
        Alias alias = new Alias(sum, "a");
        Group childGroup = new Group();
        childGroup.setLogicalProperties(new LogicalProperties(new Supplier<List<Slot>>() {
            @Override
            public List<Slot> get() {
                return Collections.emptyList();
            }
        }));
        GroupPlan groupPlan = new GroupPlan(childGroup);
        childGroup.setStatistics(childStats);
        LogicalAggregate logicalAggregate = new LogicalAggregate(groupByExprList, Arrays.asList(alias), groupPlan);
        GroupExpression groupExpression = new GroupExpression(logicalAggregate, Arrays.asList(childGroup));
        Group ownerGroup = new Group();
        groupExpression.setOwnerGroup(ownerGroup);
        StatsCalculator statsCalculator = new StatsCalculator(groupExpression);
        statsCalculator.estimate();
        // JUnit convention: expected value first, actual value second.
        Assert.assertEquals(10, groupExpression.getOwnerGroup().getStatistics().getRowCount());
    }

    /** Filter with "c1 = 1 AND c2 = 2" and with "c1 = 1 OR c2 = 2" over 10000 child rows. */
    @Test
    public void testFilter() {
        List<String> qualifier = new ArrayList<>();
        qualifier.add("test");
        qualifier.add("t");
        SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier);
        SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier);
        ColumnStats columnStats1 = new ColumnStats();
        columnStats1.setNdv(10);
        columnStats1.setNumNulls(5);
        ColumnStats columnStats2 = new ColumnStats();
        columnStats2.setNdv(20);
        // Previously this call targeted columnStats1 again (copy-paste slip).
        // NOTE(review): the assertions below appear to depend only on ndv; confirm.
        columnStats2.setNumNulls(10);
        Map<Slot, ColumnStats> slotColumnStatsMap = new HashMap<>();
        slotColumnStatsMap.put(slot1, columnStats1);
        slotColumnStatsMap.put(slot2, columnStats2);
        StatsDeriveResult childStats = new StatsDeriveResult(10000, slotColumnStatsMap);
        EqualTo eq1 = new EqualTo(slot1, new IntegerLiteral(1));
        EqualTo eq2 = new EqualTo(slot2, new IntegerLiteral(2));
        And and = new And(eq1, eq2);
        Or or = new Or(eq1, eq2);
        Group childGroup = new Group();
        childGroup.setLogicalProperties(new LogicalProperties(new Supplier<List<Slot>>() {
            @Override
            public List<Slot> get() {
                return Collections.emptyList();
            }
        }));
        GroupPlan groupPlan = new GroupPlan(childGroup);
        childGroup.setStatistics(childStats);

        // Conjunction: the expected value multiplies the two selectivities (0.1 and 0.05).
        LogicalFilter logicalFilter = new LogicalFilter(and, groupPlan);
        GroupExpression groupExpression = new GroupExpression(logicalFilter);
        groupExpression.addChild(childGroup);
        Group ownerGroup = new Group();
        groupExpression.setOwnerGroup(ownerGroup);
        StatsCalculator statsCalculator = new StatsCalculator(groupExpression);
        statsCalculator.estimate();
        Assert.assertEquals((long) (10000 * 0.1 * 0.05), ownerGroup.getStatistics().getRowCount(), 0.001);

        // Disjunction: the expected value is s1 + s2 - s1 * s2.
        LogicalFilter logicalFilterOr = new LogicalFilter(or, groupPlan);
        GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr);
        groupExpressionOr.addChild(childGroup);
        Group ownerGroupOr = new Group();
        groupExpressionOr.setOwnerGroup(ownerGroupOr);
        StatsCalculator statsCalculator2 = new StatsCalculator(groupExpressionOr);
        statsCalculator2.estimate();
        Assert.assertEquals((long) (10000 * (0.1 + 0.05 - 0.1 * 0.05)),
                ownerGroupOr.getStatistics().getRowCount(), 0.001);
    }

    /** Join estimation on "c1 = c2" for a left semi join and an inner join. */
    @Test
    public void testHashJoin() {
        List<String> qualifier = new ArrayList<>();
        qualifier.add("test");
        qualifier.add("t");
        SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier);
        SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier);
        ColumnStats columnStats1 = new ColumnStats();
        columnStats1.setNdv(10);
        columnStats1.setNumNulls(5);
        ColumnStats columnStats2 = new ColumnStats();
        columnStats2.setNdv(20);
        // Previously this call targeted columnStats1 again (copy-paste slip).
        // NOTE(review): the assertions below appear to depend only on ndv; confirm.
        columnStats2.setNumNulls(10);
        Map<Slot, ColumnStats> slotColumnStatsMap1 = new HashMap<>();
        slotColumnStatsMap1.put(slot1, columnStats1);
        Map<Slot, ColumnStats> slotColumnStatsMap2 = new HashMap<>();
        slotColumnStatsMap2.put(slot2, columnStats2);
        final long leftRowCount = 5000;
        StatsDeriveResult leftStats = new StatsDeriveResult(leftRowCount, slotColumnStatsMap1);
        final long rightRowCount = 10000;
        StatsDeriveResult rightStats = new StatsDeriveResult(rightRowCount, slotColumnStatsMap2);
        EqualTo equalTo = new EqualTo(slot1, slot2);
        StatsDeriveResult semiJoinStats = JoinEstimation.estimate(leftStats,
                rightStats, equalTo, JoinType.LEFT_SEMI_JOIN);
        Assert.assertEquals(leftRowCount, semiJoinStats.getRowCount());
        StatsDeriveResult innerJoinStats = JoinEstimation.estimate(leftStats,
                rightStats, equalTo, JoinType.INNER_JOIN);
        Assert.assertEquals(2500000, innerJoinStats.getRowCount());
    }

    /** Olap scan: the derived result is expected to hold column stats for the single output slot. */
    @Test
    public void testOlapScan() {
        ColumnStats columnStats1 = new ColumnStats();
        columnStats1.setNdv(10);
        columnStats1.setNumNulls(5);
        long tableId1 = 0;
        String tableName1 = "t1";
        TableStats tableStats1 = new TableStats();
        tableStats1.putColumnStats("c1", columnStats1);
        Statistics statistics = new Statistics();
        statistics.putTableStats(tableId1, tableStats1);
        List<String> qualifier = new ArrayList<>();
        qualifier.add("test");
        qualifier.add("t");
        SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier);
        // Route the statistics lookup through the mocked context to the Statistics built above.
        new Expectations() {{
                ConnectContext.get();
                result = context;
                context.getEnv();
                result = env;
                env.getStatisticsManager();
                result = statisticsManager;
                statisticsManager.getStatistics();
                result = statistics;
            }};
        Table table1 = new Table(tableId1, tableName1, TableType.OLAP, Collections.emptyList());
        LogicalOlapScan logicalOlapScan1 = new LogicalOlapScan(table1, Collections.emptyList()).withLogicalProperties(
                Optional.of(new LogicalProperties(new Supplier<List<Slot>>() {
                    @Override
                    public List<Slot> get() {
                        return Arrays.asList(slot1);
                    }
                })));
        Group childGroup = new Group();
        GroupExpression groupExpression = new GroupExpression(logicalOlapScan1, Arrays.asList(childGroup));
        Group ownerGroup = new Group();
        groupExpression.setOwnerGroup(ownerGroup);
        StatsCalculator statsCalculator = new StatsCalculator(groupExpression);
        statsCalculator.estimate();
        StatsDeriveResult stats = ownerGroup.getStatistics();
        Assert.assertEquals(1, stats.getSlotToColumnStats().size());
        Assert.assertNotNull(stats.getSlotToColumnStats().get(slot1));
    }
}