[feature](nereids) estimate plan cost by column ndv and table row count (#13375)

In this version, we use column NDV (number of distinct values) and table row count information to estimate plan cost.

This is the first version; it covers the TPC-H queries.
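As a rough illustration of how column NDV drives the estimates (hypothetical numbers; the formula mirrors the new estimateInnerJoinV2 in JoinEstimation below): for an inner equi-join, the output row count is estimated as leftRowCount * rightRowCount * selectivity(buildSideColumn) / ndv(buildSideColumn). With 1,500,000 probe rows, 150,000 build rows, a build-column selectivity of 1.0 and an NDV of 150,000, the estimate is 1,500,000 * 150,000 * 1.0 / 150,000 = 1,500,000 rows.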
minghong
2022-10-26 20:35:10 +08:00
committed by GitHub
parent bed759b3f5
commit f4c8d4ce85
25 changed files with 686 additions and 156 deletions


@ -26,6 +26,7 @@ import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalLocalQuickSort;
@ -47,6 +48,29 @@ import java.util.List;
* Inspired by Presto.
*/
public class CostCalculator {
static final double CPU_WEIGHT = 1;
static final double MEMORY_WEIGHT = 1;
static final double NETWORK_WEIGHT = 1.5;
/**
* Besides the stats information, there are some special criteria in Doris.
* For example, in a hash-join cluster, BEs can build hash tables
* in parallel for a left-deep tree, and hence we need to punish right-deep trees.
* PENALTY_WEIGHT is the factor of that punishment.
* The punishment is denoted by stats.penalty.
*/
static final double PENALTY_WEIGHT = 0.5;
/**
* The intuition behind `HEAVY_OPERATOR_PUNISH_FACTOR` is that we need to avoid join patterns of this form:
* Plan1: L join (AGG1(A) join AGG2(B))
* while
* Plan2: (L join AGG1(A)) join AGG2(B) is welcome.
* AGG is a time-consuming operator. From the perspective of row count, Nereids may choose Plan1,
* because `Agg1 join Agg2` generates few tuples. But in Plan1, Agg1 and Agg2 run serially, whereas in Plan2
* they can run in parallel. Hence, Plan1 should be punished.
*
* An example is TPC-H q15.
*/
static final double HEAVY_OPERATOR_PUNISH_FACTOR = 6.0;
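To make the punishment concrete (a reading of the changes in this diff, not new behavior): StatsCalculatorV2.visitAggregate sets a plan's penalty to childPenalty + childRowCount, while JoinEstimation resets a join output's penalty to 0. In Plan1 the inner join has two aggregates as children, so its cost picks up HEAVY_OPERATOR_PUNISH_FACTOR * min(penalty(Agg1), penalty(Agg2)); in Plan2 every join has a penalty-free child (the raw table L or a join output), so min(...) is 0 and no heavy-operator punishment is added.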
/**
* Constructor.
@ -55,8 +79,8 @@ public class CostCalculator {
PlanContext planContext = new PlanContext(groupExpression);
CostEstimator costCalculator = new CostEstimator();
CostEstimate costEstimate = groupExpression.getPlan().accept(costCalculator, planContext);
CostWeight costWeight = new CostWeight(0.5, 2, 1.5);
groupExpression.setCostEstimate(costEstimate);
CostWeight costWeight = new CostWeight(CPU_WEIGHT, MEMORY_WEIGHT, NETWORK_WEIGHT, PENALTY_WEIGHT);
return costWeight.calculate(costEstimate);
}
@ -84,7 +108,7 @@ public class CostCalculator {
StatsDeriveResult statistics = context.getStatisticsWithCheck();
StatsDeriveResult childStatistics = context.getChildStatistics(0);
return new CostEstimate(
return CostEstimate.of(
childStatistics.computeSize(),
statistics.computeSize(),
childStatistics.computeSize());
@ -96,7 +120,7 @@ public class CostCalculator {
StatsDeriveResult statistics = context.getStatisticsWithCheck();
StatsDeriveResult childStatistics = context.getChildStatistics(0);
return new CostEstimate(
return CostEstimate.of(
childStatistics.computeSize(),
statistics.computeSize(),
childStatistics.computeSize());
@ -109,7 +133,7 @@ public class CostCalculator {
StatsDeriveResult statistics = context.getStatisticsWithCheck();
StatsDeriveResult childStatistics = context.getChildStatistics(0);
return new CostEstimate(
return CostEstimate.of(
childStatistics.computeSize(),
statistics.computeSize(),
0);
@ -122,7 +146,7 @@ public class CostCalculator {
DistributionSpec spec = distribute.getDistributionSpec();
// shuffle
if (spec instanceof DistributionSpecHash) {
return new CostEstimate(
return CostEstimate.of(
childStatistics.computeSize(),
0,
childStatistics.computeSize());
@ -134,7 +158,7 @@ public class CostCalculator {
int instanceNumber = ConnectContext.get().getSessionVariable().getParallelExecInstanceNum();
beNumber = Math.max(1, beNumber);
return new CostEstimate(
return CostEstimate.of(
childStatistics.computeSize() * beNumber,
childStatistics.computeSize() * beNumber * instanceNumber,
childStatistics.computeSize() * beNumber * instanceNumber);
@ -142,14 +166,14 @@ public class CostCalculator {
// gather
if (spec instanceof DistributionSpecGather) {
return new CostEstimate(
return CostEstimate.of(
childStatistics.computeSize(),
0,
childStatistics.computeSize());
}
// any
return new CostEstimate(
return CostEstimate.of(
childStatistics.computeSize(),
0,
0);
@ -161,7 +185,7 @@ public class CostCalculator {
StatsDeriveResult statistics = context.getStatisticsWithCheck();
StatsDeriveResult inputStatistics = context.getChildStatistics(0);
return new CostEstimate(inputStatistics.computeSize(), statistics.computeSize(), 0);
return CostEstimate.of(inputStatistics.computeSize(), statistics.computeSize(), 0);
}
@Override
@ -170,16 +194,42 @@ public class CostCalculator {
Preconditions.checkState(context.getGroupExpression().arity() == 2);
Preconditions.checkState(context.getChildrenStats().size() == 2);
CostEstimate inputCost = calculateJoinInputCost(context);
CostEstimate outputCost = calculateJoinOutputCost(physicalHashJoin);
StatsDeriveResult outputStats = physicalHashJoin.getGroupExpression().get().getOwnerGroup().getStatistics();
double outputRowCount = outputStats.computeSize();
// TODO: handle some case
// handle cross join, onClause is empty .....
if (physicalHashJoin.getJoinType().isCrossJoin()) {
return CostEstimate.sum(inputCost, outputCost, outputCost);
StatsDeriveResult probeStats = context.getChildStatistics(0);
StatsDeriveResult buildStats = context.getChildStatistics(1);
List<Id> leftIds = context.getChildOutputIds(0);
List<Id> rightIds = context.getChildOutputIds(1);
double leftRowCount = probeStats.computeColumnSize(leftIds);
double rightRowCount = buildStats.computeColumnSize(rightIds);
/*
pattern1: L join1 (Agg1() join2 Agg2())
the result row count of join2 may be much smaller than that of Agg1,
but Agg1 and Agg2 are slow, so we need to punish pattern1.
pattern2: (L join1 Agg1) join2 Agg2
in pattern2, join1 and join2 take more time, but Agg1 and Agg2 can be processed in parallel.
*/
double penalty = CostCalculator.HEAVY_OPERATOR_PUNISH_FACTOR
* Math.min(probeStats.getPenalty(), buildStats.getPenalty());
if (buildStats.getWidth() >= 2) {
//penalty for right deep tree
penalty += Math.abs(leftRowCount - rightRowCount);
}
return CostEstimate.sum(inputCost, outputCost);
if (physicalHashJoin.getJoinType().isCrossJoin()) {
return CostEstimate.of(leftRowCount + rightRowCount + outputRowCount,
0,
leftRowCount + rightRowCount,
penalty);
}
return CostEstimate.of(leftRowCount + rightRowCount + outputRowCount,
rightRowCount,
0,
penalty
);
}
@Override
@ -193,30 +243,20 @@ public class CostCalculator {
StatsDeriveResult leftStatistics = context.getChildStatistics(0);
StatsDeriveResult rightStatistics = context.getChildStatistics(1);
return new CostEstimate(
return CostEstimate.of(
leftStatistics.computeSize() * rightStatistics.computeSize(),
rightStatistics.computeSize(),
0);
}
}
private static CostEstimate calculateJoinInputCost(PlanContext context) {
StatsDeriveResult probeStats = context.getChildStatistics(0);
StatsDeriveResult buildStats = context.getChildStatistics(1);
List<Id> leftIds = context.getChildOutputIds(0);
List<Id> rightIds = context.getChildOutputIds(1);
double cpuCost = probeStats.computeColumnSize(leftIds) + buildStats.computeColumnSize(rightIds);
double memoryCost = buildStats.computeColumnSize(rightIds);
return CostEstimate.of(cpuCost, memoryCost, 0);
}
private static CostEstimate calculateJoinOutputCost(
PhysicalHashJoin<? extends Plan, ? extends Plan> physicalHashJoin) {
StatsDeriveResult outputStats = physicalHashJoin.getGroupExpression().get().getOwnerGroup().getStatistics();
double size = outputStats.computeSize();
return CostEstimate.ofCpu(size);
@Override
public CostEstimate visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows,
PlanContext context) {
return CostEstimate.of(
assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(),
assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(),
0
);
}
}
}


@ -24,17 +24,24 @@ import com.google.common.base.Preconditions;
*/
public final class CostEstimate {
private static final CostEstimate INFINITE =
new CostEstimate(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY);
private static final CostEstimate ZERO = new CostEstimate(0, 0, 0);
new CostEstimate(Double.POSITIVE_INFINITY,
Double.POSITIVE_INFINITY,
Double.POSITIVE_INFINITY,
Double.POSITIVE_INFINITY);
private static final CostEstimate ZERO = new CostEstimate(0, 0, 0, 0);
private final double cpuCost;
private final double memoryCost;
private final double networkCost;
//penalty for
// 1. right deep tree
// 2. right XXX join
private final double penalty;
/**
* Constructor of CostEstimate.
*/
public CostEstimate(double cpuCost, double memoryCost, double networkCost) {
public CostEstimate(double cpuCost, double memoryCost, double networkCost, double penalty) {
// TODO: fix stats
if (cpuCost < 0) {
cpuCost = 0;
@ -51,6 +58,7 @@ public final class CostEstimate {
this.cpuCost = cpuCost;
this.memoryCost = memoryCost;
this.networkCost = networkCost;
this.penalty = penalty;
}
public static CostEstimate infinite() {
@ -73,16 +81,24 @@ public final class CostEstimate {
return networkCost;
}
public double getPenalty() {
return penalty;
}
public static CostEstimate of(double cpuCost, double maxMemory, double networkCost, double rightDeepPenalty) {
return new CostEstimate(cpuCost, maxMemory, networkCost, rightDeepPenalty);
}
public static CostEstimate of(double cpuCost, double maxMemory, double networkCost) {
return new CostEstimate(cpuCost, maxMemory, networkCost);
return new CostEstimate(cpuCost, maxMemory, networkCost, 0);
}
public static CostEstimate ofCpu(double cpuCost) {
return new CostEstimate(cpuCost, 0, 0);
return new CostEstimate(cpuCost, 0, 0, 0);
}
public static CostEstimate ofMemory(double memoryCost) {
return new CostEstimate(0, memoryCost, 0);
return new CostEstimate(0, memoryCost, 0, 0);
}
/**
@ -97,6 +113,16 @@ public final class CostEstimate {
memoryCostSum += costEstimate.memoryCost;
networkCostSum += costEstimate.networkCost;
}
return new CostEstimate(cpuCostSum, memoryCostSum, networkCostSum);
return CostEstimate.of(cpuCostSum, memoryCostSum, networkCostSum);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append((long) cpuCost)
.append("/").append((long) memoryCost)
.append("/").append((long) networkCost)
.append("/").append((long) penalty);
return sb.toString();
}
}


@ -26,21 +26,25 @@ public class CostWeight {
private final double cpuWeight;
private final double memoryWeight;
private final double networkWeight;
private final double penaltyWeight;
/**
* Constructor
*/
public CostWeight(double cpuWeight, double memoryWeight, double networkWeight) {
public CostWeight(double cpuWeight, double memoryWeight, double networkWeight, double penaltyWeight) {
Preconditions.checkArgument(cpuWeight >= 0, "cpuWeight cannot be negative");
Preconditions.checkArgument(memoryWeight >= 0, "memoryWeight cannot be negative");
Preconditions.checkArgument(networkWeight >= 0, "networkWeight cannot be negative");
this.cpuWeight = cpuWeight;
this.memoryWeight = memoryWeight;
this.networkWeight = networkWeight;
this.penaltyWeight = penaltyWeight;
}
public double calculate(CostEstimate costEstimate) {
return costEstimate.getCpuCost() * cpuWeight + costEstimate.getMemoryCost() * memoryWeight
+ costEstimate.getNetworkCost() * networkWeight;
+ costEstimate.getNetworkCost() * networkWeight
+ costEstimate.getPenalty() * penaltyWeight;
}
}
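For illustration (hypothetical CostEstimate values; the weights are the CostCalculator constants above: cpu 1, memory 1, network 1.5, penalty 0.5): a plan with cpuCost = 1000, memoryCost = 200, networkCost = 100 and penalty = 50 gets a weighted cost of 1000 * 1 + 200 * 1 + 100 * 1.5 + 50 * 0.5 = 1375.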


@ -122,6 +122,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// Calculate cost
if (curChildIndex == 0 && prevChildIndex == -1) {
curNodeCost = CostCalculator.calculateCost(groupExpression);
groupExpression.setCost(curNodeCost);
curTotalCost += curNodeCost;
}
@ -192,6 +193,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
StatsCalculator.estimate(groupExpression);
curTotalCost -= curNodeCost;
curNodeCost = CostCalculator.calculateCost(groupExpression);
groupExpression.setCost(curNodeCost);
curTotalCost += curNodeCost;
// record map { outputProperty -> outputProperty }, { ANY -> outputProperty },


@ -19,6 +19,7 @@ package org.apache.doris.nereids.memo;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.cost.CostEstimate;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
@ -39,6 +40,8 @@ import java.util.Optional;
* Representation for group expression in cascades optimizer.
*/
public class GroupExpression {
private double cost = 0.0;
private CostEstimate costEstimate = null;
private Group ownerGroup;
private List<Group> children;
private final Plan plan;
@ -189,7 +192,7 @@ public class GroupExpression {
* @param property property that needs to be satisfied
* @return Lowest cost to satisfy that property
*/
public double getCost(PhysicalProperties property) {
public double getCostByProperties(PhysicalProperties property) {
Preconditions.checkState(lowestCostTable.containsKey(property));
return lowestCostTable.get(property).first;
}
@ -199,6 +202,22 @@ public class GroupExpression {
this.requestPropertiesMap.put(requiredPropertySet, outputPropertySet);
}
public double getCost() {
return cost;
}
public void setCost(double cost) {
this.cost = cost;
}
public CostEstimate getCostEstimate() {
return costEstimate;
}
public void setCostEstimate(CostEstimate costEstimate) {
this.costEstimate = costEstimate;
}
@Override
public boolean equals(Object o) {
if (this == o) {
@ -228,20 +247,23 @@ public class GroupExpression {
@Override
public String toString() {
StringBuilder builder = new StringBuilder();
builder.append(ownerGroup.getGroupId()).append("(plan=").append(plan).append(") children=[");
if (ownerGroup == null) {
builder.append("OWNER GROUP IS NULL[]");
} else {
builder.append(ownerGroup.getGroupId()).append("(plan=").append(plan.toString()).append(") children=[");
builder.append(ownerGroup.getGroupId()).append(" cost=").append((long) cost);
}
if (costEstimate != null) {
builder.append(" est=").append(costEstimate);
}
builder.append(" (plan=" + plan.toString() + ") children=[");
for (Group group : children) {
builder.append(group.getGroupId()).append(" ");
}
builder.append("] stats=");
builder.append("]");
if (ownerGroup != null) {
builder.append(ownerGroup.getStatistics());
} else {
builder.append("NULL");
builder.append(" stats=").append(ownerGroup.getStatistics());
}
return builder.toString();
}


@ -214,7 +214,10 @@ public class DistributionSpecHash extends DistributionSpec {
return false;
}
DistributionSpecHash that = (DistributionSpecHash) o;
return shuffleType == that.shuffleType && orderedShuffledColumns.equals(that.orderedShuffledColumns);
//TODO: that.orderedShuffledColumns may have equivalent slots. This will be done later
return shuffleType == that.shuffleType
&& orderedShuffledColumns.size() == that.orderedShuffledColumns.size()
&& equalsSatisfy(that.orderedShuffledColumns);
}
@Override
@ -248,4 +251,5 @@ public class DistributionSpecHash extends DistributionSpec {
// output, all distribute enforce
ENFORCED,
}
}


@ -77,7 +77,7 @@ public class EnforceMissingPropertiesHelper {
*/
private PhysicalProperties enforceDistributionButMeetSort(PhysicalProperties output, PhysicalProperties request) {
groupExpression.getOwnerGroup()
.replaceBestPlan(output, PhysicalProperties.ANY, groupExpression.getCost(output));
.replaceBestPlan(output, PhysicalProperties.ANY, groupExpression.getCostByProperties(output));
return enforceSortAndDistribution(output, request);
}


@ -105,4 +105,5 @@ public class OrderSpec {
public int hashCode() {
return Objects.hash(orderKeys);
}
}


@ -96,4 +96,10 @@ public class PhysicalProperties {
}
return hashCode;
}
@Override
public String toString() {
return distributionSpec.toString() + " " + orderSpec.toString();
}
}


@ -18,7 +18,10 @@
package org.apache.doris.nereids.stats;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
@ -47,6 +50,9 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStat, StatsDer
private static ExpressionEstimation INSTANCE = new ExpressionEstimation();
/**
* The returned ColumnStat is newly created or a copy of the input stats.
*/
public static ColumnStat estimate(Expression expression, StatsDeriveResult stats) {
return INSTANCE.visit(expression, stats);
}
@ -56,12 +62,30 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStat, StatsDer
return expr.accept(this, context);
}
//TODO: case-when needs to be re-implemented
@Override
public ColumnStat visitCaseWhen(CaseWhen caseWhen, StatsDeriveResult context) {
ColumnStat columnStat = new ColumnStat();
columnStat.setNdv(caseWhen.getWhenClauses().size() + 1);
columnStat.setSelectivity(1.0);
columnStat.setMinValue(0);
columnStat.setMaxValue(Double.MAX_VALUE);
columnStat.setAvgSizeByte(8);
columnStat.setNumNulls(0);
columnStat.setMaxSizeByte(8);
return columnStat;
}
public ColumnStat visitCast(Cast cast, StatsDeriveResult context) {
return cast.child().accept(this, context);
}
@Override
public ColumnStat visitLiteral(Literal literal, StatsDeriveResult context) {
if (literal.isStringLiteral()) {
if (ColumnStat.MAX_MIN_UNSUPPORTED_TYPE.contains(literal.getDataType().toCatalogDataType())) {
return ColumnStat.UNKNOWN;
}
double literalVal = Double.parseDouble(literal.getValue().toString());
double literalVal = literal.getDouble();
ColumnStat columnStat = new ColumnStat();
columnStat.setMaxValue(literalVal);
columnStat.setMinValue(literalVal);
@ -75,7 +99,7 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStat, StatsDer
public ColumnStat visitSlotReference(SlotReference slotReference, StatsDeriveResult context) {
ColumnStat columnStat = context.getColumnStatsBySlot(slotReference);
Preconditions.checkState(columnStat != null);
return columnStat;
return columnStat.copy();
}
@Override
@ -163,14 +187,9 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStat, StatsDer
@Override
public ColumnStat visitCount(Count count, StatsDeriveResult context) {
Expression child = count.child(0);
ColumnStat columnStat = child.accept(this, context);
if (columnStat == ColumnStat.UNKNOWN) {
return ColumnStat.UNKNOWN;
}
double expectedValue = context.getRowCount() - columnStat.getNumNulls();
return new ColumnStat(1,
count.getDataType().width(), count.getDataType().width(), 0, expectedValue, expectedValue);
//count() returns long type
return new ColumnStat(1.0, 8.0, 8.0, 0.0,
Double.MIN_VALUE, Double.MAX_VALUE);
}
// TODO: return a proper estimated stat after supports histogram
@ -207,4 +226,9 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStat, StatsDer
public ColumnStat visitSubstring(Substring substring, StatsDeriveResult context) {
return substring.child(0).accept(this, context);
}
@Override
public ColumnStat visitAlias(Alias alias, StatsDeriveResult context) {
return alias.child().accept(this, context);
}
}


@ -47,18 +47,15 @@ import java.util.Map;
* TODO: Should consider the distribution of data.
*/
public class FilterEstimation extends ExpressionVisitor<StatsDeriveResult, EstimationContext> {
public static final double DEFAULT_INEQUALITY_COMPARISON_SELECTIVITY = 0.8;
private static final double DEFAULT_SELECTIVITY = 0.1;
public static final double DEFAULT_EQUALITY_COMPARISON_SELECTIVITY = 0.1;
private static final double DEFAULT_INEQUALITY_COMPARISON_SELECTIVITY = 1.0 / 3.0;
private final StatsDeriveResult inputStats;
private static final double DEFAULT_EQUALITY_COMPARISON_SELECTIVITY = 0.1;
private final StatsDeriveResult stats;
public FilterEstimation(StatsDeriveResult stats) {
Preconditions.checkNotNull(stats);
this.stats = stats;
public FilterEstimation(StatsDeriveResult inputStats) {
Preconditions.checkNotNull(inputStats);
this.inputStats = inputStats;
}
/**
@ -67,14 +64,7 @@ public class FilterEstimation extends ExpressionVisitor<StatsDeriveResult, Estim
public StatsDeriveResult estimate(Expression expression) {
// For a comparison predicate, only when its left side is a slot and its right side is a literal
// do we consider it a valid predicate.
StatsDeriveResult stats = calculate(expression);
double expectedRowCount = stats.getRowCount();
for (ColumnStat columnStat : stats.getSlotToColumnStats().values()) {
if (columnStat.getNdv() > expectedRowCount) {
columnStat.setNdv(expectedRowCount);
}
}
return stats;
return calculate(expression);
}
private StatsDeriveResult calculate(Expression expression) {
@ -83,7 +73,7 @@ public class FilterEstimation extends ExpressionVisitor<StatsDeriveResult, Estim
@Override
public StatsDeriveResult visit(Expression expr, EstimationContext context) {
return new StatsDeriveResult(stats).updateRowCountBySelectivity(DEFAULT_SELECTIVITY);
return new StatsDeriveResult(inputStats).updateBySelectivity(DEFAULT_INEQUALITY_COMPARISON_SELECTIVITY);
}
@Override
@ -97,7 +87,7 @@ public class FilterEstimation extends ExpressionVisitor<StatsDeriveResult, Estim
StatsDeriveResult rightStats = rightExpr.accept(this, null);
StatsDeriveResult andStats = rightExpr.accept(new FilterEstimation(leftStats), null);
double rowCount = leftStats.getRowCount() + rightStats.getRowCount() - andStats.getRowCount();
StatsDeriveResult orStats = new StatsDeriveResult(stats).setRowCount(rowCount);
StatsDeriveResult orStats = new StatsDeriveResult(inputStats).setRowCount(rowCount);
for (Map.Entry<Slot, ColumnStat> entry : leftStats.getSlotToColumnStats().entrySet()) {
Slot keySlot = entry.getKey();
ColumnStat leftColStats = entry.getValue();
@ -115,19 +105,166 @@ public class FilterEstimation extends ExpressionVisitor<StatsDeriveResult, Estim
@Override
public StatsDeriveResult visitComparisonPredicate(ComparisonPredicate cp, EstimationContext context) {
boolean isNot = (context != null) && context.isNot;
Expression left = cp.left();
Expression right = cp.right();
ColumnStat statsForLeft = ExpressionEstimation.estimate(left, stats);
ColumnStat statsForRight = ExpressionEstimation.estimate(right, stats);
ColumnStat statsForLeft = ExpressionEstimation.estimate(left, inputStats);
ColumnStat statsForRight = ExpressionEstimation.estimate(right, inputStats);
double selectivity;
if (!(left instanceof Literal) && !(right instanceof Literal)) {
selectivity = calculateWhenBothChildIsColumn(cp, statsForLeft, statsForRight);
} else {
// For literal, it's max min is same value.
selectivity = calculateWhenRightChildIsLiteral(cp, statsForLeft, statsForRight.getMaxValue());
selectivity = updateLeftStatsWhenRightChildIsLiteral(cp,
statsForLeft,
statsForRight.getMaxValue(),
isNot);
}
return new StatsDeriveResult(stats).updateRowCountOnCopy(selectivity);
StatsDeriveResult outputStats = new StatsDeriveResult(inputStats);
//TODO: we assume that func(A) and A have the same stats.
outputStats.updateBySelectivity(selectivity, cp.getInputSlots());
if (left.getInputSlots().size() == 1) {
Slot leftSlot = left.getInputSlots().iterator().next();
outputStats.updateColumnStatsForSlot(leftSlot, statsForLeft);
}
return outputStats;
}
private double updateLessThan(ColumnStat statsForLeft, double val,
double min, double max, double ndv) {
double selectivity = 1.0;
if (val <= min) {
statsForLeft.setMaxValue(val);
statsForLeft.setMinValue(0);
statsForLeft.setNdv(0);
selectivity = 0.0;
} else if (val > max) {
selectivity = 1.0;
} else if (val == max) {
selectivity = 1.0 - 1.0 / ndv;
} else {
statsForLeft.setMaxValue(val);
selectivity = (val - min) / (max - min);
statsForLeft.setNdv(selectivity * statsForLeft.getNdv());
}
return selectivity;
}
private double updateLessThanEqual(ColumnStat statsForLeft, double val,
double min, double max, double ndv) {
double selectivity = 1.0;
if (val < min) {
statsForLeft.setMaxValue(val);
statsForLeft.setMinValue(val);
selectivity = 0.0;
} else if (val == min) {
statsForLeft.setMaxValue(val);
selectivity = 1.0 / ndv;
} else if (val >= max) {
selectivity = 1.0;
} else {
statsForLeft.setMaxValue(val);
selectivity = (val - min) / (max - min);
statsForLeft.setNdv(selectivity * statsForLeft.getNdv());
}
return selectivity;
}
private double updateGreaterThan(ColumnStat statsForLeft, double val,
double min, double max, double ndv) {
double selectivity = 1.0;
if (val >= max) {
statsForLeft.setMaxValue(val);
statsForLeft.setMinValue(val);
statsForLeft.setNdv(0);
selectivity = 0.0;
} else if (val == min) {
selectivity = 1.0 - 1.0 / ndv;
} else if (val < min) {
selectivity = 1.0;
} else {
statsForLeft.setMinValue(val);
selectivity = (max - val) / (max - min);
statsForLeft.setNdv(selectivity * statsForLeft.getNdv());
}
return selectivity;
}
private double updateGreaterThanEqual(ColumnStat statsForLeft, double val,
double min, double max, double ndv) {
double selectivity = 1.0;
if (val > max) {
statsForLeft.setMinValue(val);
statsForLeft.setMaxValue(val);
selectivity = 0.0;
} else if (val == max) {
statsForLeft.setMinValue(val);
statsForLeft.setMaxValue(val);
selectivity = 1.0 / ndv;
} else if (val <= min) {
selectivity = 1.0;
} else {
statsForLeft.setMinValue(val);
selectivity = (max - val) / (max - min);
statsForLeft.setNdv(selectivity * statsForLeft.getNdv());
}
return selectivity;
}
private double updateLeftStatsWhenRightChildIsLiteral(ComparisonPredicate cp,
ColumnStat statsForLeft, double val, boolean isNot) {
if (statsForLeft == ColumnStat.UNKNOWN) {
return 1.0;
}
double selectivity = 1.0;
double ndv = statsForLeft.getNdv();
double max = statsForLeft.getMaxValue();
double min = statsForLeft.getMinValue();
if (cp instanceof EqualTo) {
if (!isNot) {
statsForLeft.setNdv(1);
statsForLeft.setMaxValue(val);
statsForLeft.setMinValue(val);
if (val > max || val < min) {
selectivity = 0.0;
} else {
selectivity = 1.0 / ndv;
}
} else {
if (val <= max && val >= min) {
selectivity = 1 - DEFAULT_EQUALITY_COMPARISON_SELECTIVITY;
}
}
} else if (cp instanceof LessThan) {
if (isNot) {
selectivity = updateGreaterThanEqual(statsForLeft, val, min, max, ndv);
} else {
selectivity = updateLessThan(statsForLeft, val, min, max, ndv);
}
} else if (cp instanceof LessThanEqual) {
if (isNot) {
selectivity = updateGreaterThan(statsForLeft, val, min, max, ndv);
} else {
selectivity = updateLessThanEqual(statsForLeft, val, min, max, ndv);
}
} else if (cp instanceof GreaterThan) {
if (isNot) {
selectivity = updateLessThanEqual(statsForLeft, val, min, max, ndv);
} else {
selectivity = updateGreaterThan(statsForLeft, val, min, max, ndv);
}
} else if (cp instanceof GreaterThanEqual) {
if (isNot) {
selectivity = updateLessThan(statsForLeft, val, min, max, ndv);
} else {
selectivity = updateGreaterThanEqual(statsForLeft, val, min, max, ndv);
}
} else {
throw new RuntimeException(String.format("Unexpected expression : %s", cp.toSql()));
}
return selectivity;
}
private double calculateWhenRightChildIsLiteral(ComparisonPredicate cp,
@ -247,41 +384,86 @@ public class FilterEstimation extends ExpressionVisitor<StatsDeriveResult, Estim
public StatsDeriveResult visitInPredicate(InPredicate inPredicate, EstimationContext context) {
boolean isNotIn = context != null && context.isNot;
Expression compareExpr = inPredicate.getCompareExpr();
ColumnStat compareExprStats = ExpressionEstimation.estimate(compareExpr, stats);
if (ColumnStat.isInvalid(compareExprStats)) {
return stats;
ColumnStat compareExprStats = ExpressionEstimation.estimate(compareExpr, inputStats);
if (ColumnStat.isUnKnown(compareExprStats)) {
return inputStats;
}
List<Expression> options = inPredicate.getOptions();
double maxOption = 0;
double minOption = Double.MAX_VALUE;
double optionDistinctCount = 0;
for (Expression option : options) {
ColumnStat optionStats = ExpressionEstimation.estimate(option, stats);
if (ColumnStat.isInvalid(optionStats)) {
return stats;
/* suppose A.(min, max) = (0, 10), A.ndv=10
A in ( 1, 2, 5, 100):
validInOptCount = 3, that is (1, 2, 5)
table selectivity = 3/10
A.min = 1, A.max=5
A.selectivity = 3/5
A.ndv = 3
A not in (1, 2, 3, 100):
validInOptCount = 10 - 3
we assume that 1, 2, 3 exist in A
A.ndv = 10 - 3 = 7
table selectivity = 7/10
A.(min, max) not changed
A.selectivity = 7/10
*/
double validInOptCount = 0;
double columnSelectivity = 1.0;
double selectivity = 1.0;
if (isNotIn) {
for (Expression option : options) {
ColumnStat optionStats = ExpressionEstimation.estimate(option, inputStats);
if (ColumnStat.isUnKnown(optionStats)) {
continue;
}
double validOptionNdv = compareExprStats.ndvIntersection(optionStats);
if (validOptionNdv > 0.0) {
validInOptCount += validOptionNdv;
}
}
optionDistinctCount += optionStats.getNdv();
maxOption = Math.max(optionStats.getMaxValue(), maxOption);
minOption = Math.min(optionStats.getMinValue(), minOption);
validInOptCount = Math.max(1, compareExprStats.getNdv() - validInOptCount);
columnSelectivity = Math.max(1, validInOptCount)
/ compareExprStats.getNdv();
} else {
for (Expression option : options) {
ColumnStat optionStats = ExpressionEstimation.estimate(option, inputStats);
if (ColumnStat.isUnKnown(optionStats)) {
validInOptCount++;
continue;
}
double validOptionNdv = compareExprStats.ndvIntersection(optionStats);
if (validOptionNdv > 0.0) {
validInOptCount += validOptionNdv;
maxOption = Math.max(optionStats.getMaxValue(), maxOption);
minOption = Math.min(optionStats.getMinValue(), minOption);
}
}
maxOption = Math.min(maxOption, compareExprStats.getMaxValue());
minOption = Math.max(minOption, compareExprStats.getMinValue());
if (maxOption == minOption) {
columnSelectivity = 1.0;
} else {
double outputRange = maxOption - minOption;
double originRange = Math.max(1, compareExprStats.getMaxValue() - compareExprStats.getMinValue());
double originDensity = compareExprStats.getNdv() / originRange;
double outputDensity = validInOptCount / outputRange;
columnSelectivity = Math.min(1, outputDensity / originDensity);
}
compareExprStats.setMaxValue(maxOption);
compareExprStats.setMinValue(minOption);
}
double selectivity = DEFAULT_SELECTIVITY;
double cmpExprMax = compareExprStats.getMaxValue();
double cmpExprMin = compareExprStats.getMinValue();
boolean hasOverlap = Math.max(cmpExprMin, minOption) <= Math.min(cmpExprMax, maxOption);
if (!hasOverlap) {
selectivity = 0.0;
}
double cmpDistinctCount = compareExprStats.getNdv();
selectivity = Math.min(1.0, optionDistinctCount / cmpDistinctCount);
double expectedMax = Math.min(cmpExprMax, maxOption);
double expectedMin = Math.max(cmpExprMin, minOption);
compareExprStats.setMaxValue(expectedMax);
compareExprStats.setMinValue(expectedMin);
StatsDeriveResult estimated = new StatsDeriveResult(stats);
if (compareExpr instanceof SlotReference && !isNotIn) {
selectivity = Math.min(1.0, validInOptCount / compareExprStats.getNdv());
compareExprStats.setSelectivity(compareExprStats.getSelectivity() * columnSelectivity);
compareExprStats.setNdv(validInOptCount);
StatsDeriveResult estimated = new StatsDeriveResult(inputStats);
estimated = estimated.updateRowCountOnCopy(selectivity);
if (compareExpr instanceof SlotReference) {
estimated.addColumnStats((SlotReference) compareExpr, compareExprStats);
}
return estimated.updateRowCountOnCopy(isNotIn ? 1.0 - selectivity : selectivity);
return estimated;
}
@Override


@ -60,7 +60,8 @@ public class FilterSelectivityCalculator extends ExpressionVisitor<Double, Void>
//only calculate comparison in form of `slot comp literal`,
//otherwise, use DEFAULT_RANGE_SELECTIVITY
if (expression instanceof ComparisonPredicate
&& !(expression.child(0) instanceof SlotReference
&& !(!expression.child(0).getInputSlots().isEmpty()
&& expression.child(0).getInputSlots().toArray()[0] instanceof SlotReference
&& expression.child(1) instanceof Literal)) {
return DEFAULT_RANGE_SELECTIVITY;
}
@ -91,11 +92,17 @@ public class FilterSelectivityCalculator extends ExpressionVisitor<Double, Void>
// TODO: If right value greater than the max value or less than min value in column stats, return 0.0 .
@Override
public Double visitEqualTo(EqualTo equalTo, Void context) {
SlotReference left = (SlotReference) equalTo.left();
//TODO: remove the assumption that func(A) and A have the same stats
SlotReference left = (SlotReference) equalTo.left().getInputSlots().toArray()[0];
Literal literal = (Literal) equalTo.right();
ColumnStat columnStats = slotRefToStats.get(left);
if (columnStats == null) {
return DEFAULT_EQUAL_SELECTIVITY;
throw new RuntimeException("FilterSelectivityCalculator - col stats not found: " + left);
}
ColumnStat newStats = new ColumnStat(1, columnStats.getAvgSizeByte(), columnStats.getMaxSizeByte(), 0,
Double.parseDouble(literal.getValue().toString()), Double.parseDouble(literal.getValue().toString()));
newStats.setSelectivity(1.0 / columnStats.getNdv());
slotRefToStats.put(left, newStats);
double ndv = columnStats.getNdv();
return ndv < 0 ? DEFAULT_EQUAL_SELECTIVITY : ndv == 0 ? 0 : 1.0 / columnStats.getNdv();
}


@ -69,6 +69,29 @@ public class JoinEstimation {
public double rowCount = 0;
}
private static double estimateInnerJoinV2(Join join, EqualTo equalto,
StatsDeriveResult leftStats, StatsDeriveResult rightStats) {
SlotReference eqRight = (SlotReference) equalto.child(1).getInputSlots().toArray()[0];
ColumnStat rColumnStats = rightStats.getSlotToColumnStats().get(eqRight);
SlotReference eqLeft = (SlotReference) equalto.child(0).getInputSlots().toArray()[0];
if (rColumnStats == null) {
rColumnStats = rightStats.getSlotToColumnStats().get(eqLeft);
}
if (rColumnStats == null) {
throw new RuntimeException("estimateInnerJoinV2 cannot find columnStats: " + eqRight);
}
double rowCount =
(leftStats.getRowCount()
* rightStats.getRowCount()
* rColumnStats.getSelectivity()
/ rColumnStats.getNdv());
rowCount = Math.ceil(rowCount);
return rowCount;
}
/**
* the basic idea of star-schema is:
* 1. fact_table JOIN dimension_table, if the dimension table is filtered, the result can be regarded as
@ -83,12 +106,12 @@ public class JoinEstimation {
JoinEstimationResult result = new JoinEstimationResult();
SlotReference eqLeft = (SlotReference) equalto.child(0);
SlotReference eqRight = (SlotReference) equalto.child(1);
if ((rightStats.width == LIMIT_RIGHT_WIDTH && !rightStats.isReduced)
|| rightStats.width > LIMIT_RIGHT_WIDTH + 1) {
if ((rightStats.getWidth() == LIMIT_RIGHT_WIDTH && !rightStats.isReduced)
|| rightStats.getWidth() > LIMIT_RIGHT_WIDTH + 1) {
//if the right side is too wide, ignore the filter effect.
result.forbiddenReducePropagation = true;
//penalize a too-right-deep tree by multiplying by its level
result.rowCount = rightStats.width * (leftStats.getRowCount()
result.rowCount = rightStats.getWidth() * (leftStats.getRowCount()
+ AVG_DIM_FACT_RATIO * rightStats.getRowCount());
} else if (eqLeft.getColumn().isPresent() || eqRight.getColumn().isPresent()) {
Set<Slot> rightSlots = ((PhysicalHashJoin<?, ?>) join).child(1).getOutputSet();
@ -124,11 +147,56 @@ public class JoinEstimation {
return result;
}
/**
* estimate join
*/
public static StatsDeriveResult estimate(StatsDeriveResult leftStats, StatsDeriveResult rightStats, Join join) {
JoinType joinType = join.getJoinType();
double rowCount = Double.MAX_VALUE;
if (joinType == JoinType.LEFT_SEMI_JOIN || joinType == JoinType.LEFT_ANTI_JOIN) {
rowCount = leftStats.getRowCount();
} else if (joinType == JoinType.RIGHT_SEMI_JOIN || joinType == JoinType.RIGHT_ANTI_JOIN) {
rowCount = rightStats.getRowCount();
} else if (joinType == JoinType.INNER_JOIN) {
if (join.getHashJoinConjuncts().isEmpty()) {
//TODO: consider other join conjuncts
rowCount = leftStats.getRowCount() * rightStats.getRowCount();
} else {
for (Expression joinConjunct : join.getHashJoinConjuncts()) {
double tmpRowCount = estimateInnerJoinV2(join,
(EqualTo) joinConjunct, leftStats, rightStats);
rowCount = Math.min(rowCount, tmpRowCount);
}
}
} else if (joinType == JoinType.LEFT_OUTER_JOIN) {
rowCount = leftStats.getRowCount();
} else if (joinType == JoinType.RIGHT_OUTER_JOIN) {
rowCount = rightStats.getRowCount();
} else if (joinType == JoinType.CROSS_JOIN) {
rowCount = CheckedMath.checkedMultiply(leftStats.getRowCount(),
rightStats.getRowCount());
} else {
throw new RuntimeException("joinType is not supported");
}
StatsDeriveResult statsDeriveResult = new StatsDeriveResult(rowCount, Maps.newHashMap());
if (joinType.isRemainLeftJoin()) {
statsDeriveResult.merge(leftStats);
}
if (joinType.isRemainRightJoin()) {
statsDeriveResult.merge(rightStats);
}
statsDeriveResult.setRowCount(rowCount);
statsDeriveResult.setWidth(rightStats.getWidth() + leftStats.getWidth());
statsDeriveResult.setPenalty(0.0);
return statsDeriveResult;
}
/**
* Do estimate.
* // TODO: since we have no column stats here, just use a fixed ratio to compute the row count.
*/
public static StatsDeriveResult estimate(StatsDeriveResult leftStats, StatsDeriveResult rightStats, Join join) {
public static StatsDeriveResult estimate2(StatsDeriveResult leftStats, StatsDeriveResult rightStats, Join join) {
JoinType joinType = join.getJoinType();
// TODO: normalize join hashConjuncts.
// List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts();
@ -139,7 +207,7 @@ public class JoinEstimation {
boolean forbiddenReducePropagation = false;
double rowCount;
if (joinType == JoinType.LEFT_SEMI_JOIN || joinType == JoinType.LEFT_ANTI_JOIN) {
if (rightStats.isReduced && rightStats.width <= LIMIT_RIGHT_WIDTH) {
if (rightStats.isReduced && rightStats.getWidth() <= LIMIT_RIGHT_WIDTH) {
rowCount = leftStats.getRowCount() / REDUCE_TIMES;
} else {
rowCount = leftStats.getRowCount() + 1;
@ -202,7 +270,7 @@ public class JoinEstimation {
}
statsDeriveResult.setRowCount(rowCount);
statsDeriveResult.isReduced = !forbiddenReducePropagation && (isReducedByHashJoin || leftStats.isReduced);
statsDeriveResult.width = rightStats.width + leftStats.width;
statsDeriveResult.setWidth(rightStats.getWidth() + leftStats.getWidth());
return statsDeriveResult;
}


@ -104,7 +104,14 @@ public class StatsCalculator extends DefaultPlanVisitor<StatsDeriveResult, Void>
private void estimate() {
StatsDeriveResult stats = groupExpression.getPlan().accept(this, null);
groupExpression.getOwnerGroup().setStatistics(stats);
/*
in an ideal cost model, every group expression in a group is equivalent, but in fact their costs differ.
we record the lowest expression cost as the group cost so that this group is not missed.
*/
if (groupExpression.getOwnerGroup().getStatistics() == null
|| (stats.getRowCount() < groupExpression.getOwnerGroup().getStatistics().getRowCount())) {
groupExpression.getOwnerGroup().setStatistics(stats);
}
groupExpression.setStatDerived(true);
}
@ -254,7 +261,7 @@ public class StatsCalculator extends DefaultPlanVisitor<StatsDeriveResult, Void>
FilterSelectivityCalculator selectivityCalculator =
new FilterSelectivityCalculator(stats.getSlotToColumnStats());
double selectivity = selectivityCalculator.estimate(filter.getPredicates());
stats.updateRowCountBySelectivity(selectivity);
stats.updateBySelectivity(selectivity, filter.getPredicates().getInputSlots());
stats.isReduced = selectivity < 1.0;
return stats;
}


@ -99,7 +99,14 @@ public class StatsCalculatorV2 extends DefaultPlanVisitor<StatsDeriveResult, Voi
private void estimate() {
StatsDeriveResult stats = groupExpression.getPlan().accept(this, null);
groupExpression.getOwnerGroup().setStatistics(stats);
StatsDeriveResult originStats = groupExpression.getOwnerGroup().getStatistics();
/*
in an ideal cost model, every group expression in a group is equivalent, but in fact their costs differ.
we record the lowest expression cost as the group cost so that this group is not missed.
*/
if (originStats == null || originStats.getRowCount() > stats.getRowCount()) {
groupExpression.getOwnerGroup().setStatistics(stats);
}
groupExpression.setStatDerived(true);
}
@ -321,9 +328,13 @@ public class StatsCalculatorV2 extends DefaultPlanVisitor<StatsDeriveResult, Voi
// TODO: 1. Estimate the output unit size by the type of corresponding AggregateFunction
// 2. Handle alias, literal in the output expression list
for (NamedExpression outputExpression : outputExpressions) {
slotToColumnStats.put(outputExpression.toSlot(), new ColumnStat());
ColumnStat columnStat = ExpressionEstimation.estimate(outputExpression, childStats);
columnStat.setNdv(Math.min(columnStat.getNdv(), resultSetCount));
slotToColumnStats.put(outputExpression.toSlot(), columnStat);
}
StatsDeriveResult statsDeriveResult = new StatsDeriveResult(resultSetCount, slotToColumnStats);
statsDeriveResult.setWidth(childStats.getWidth());
statsDeriveResult.setPenalty(childStats.getPenalty() + childStats.getRowCount());
// TODO: Update ColumnStats properly, add new mapping from output slot to ColumnStats
return statsDeriveResult;
}


@ -53,4 +53,16 @@ public class CharLiteral extends Literal {
public LiteralExpr toLegacyLiteral() {
return new StringLiteral(value);
}
@Override
public double getDouble() {
long v = 0;
int pos = 0;
int len = Math.min(value.length(), 8);
while (pos < len) {
v += ((long) value.charAt(pos)) << ((7 - pos) * 8);
pos++;
}
return (double) v;
}
}


@ -82,6 +82,22 @@ public abstract class Literal extends Expression implements LeafExpression {
public abstract Object getValue();
/**
* Map a literal to a double, keeping the "<=" order.
* For numeric literals (int/long/double/float), convert directly to double.
* For char/varchar/string, we take the first 8 chars as an int64 and convert it to double.
* For other literals, getDouble() is not used.
*
* Hence, we can express the range of a data type and use it in stats derivation.
* For example:
* 'abcxxxxxxxxxxx' is between ('abb', 'zzz')
*
* @return double representation of the literal.
*/
public double getDouble() {
return Double.parseDouble(getValue().toString());
}
public String getStringValue() {
return String.valueOf(getValue());
}
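To see how this mapping supports range estimation on strings, here is a minimal standalone sketch (LiteralOrderSketch and packFirst8Chars are hypothetical names; the body mirrors the getDouble() overrides added to CharLiteral, StringLiteral and VarcharLiteral in this commit). It packs the first 8 characters big-endian into a long and widens it to double, which keeps lexicographic order as long as the strings differ within their first few ASCII characters:

public class LiteralOrderSketch {
    // Mirrors StringLiteral.getDouble(): pack the first 8 chars big-endian, then widen to double.
    static double packFirst8Chars(String value) {
        long v = 0;
        int len = Math.min(value.length(), 8);
        for (int pos = 0; pos < len; pos++) {
            v += ((long) value.charAt(pos)) << ((7 - pos) * 8);
        }
        return (double) v;
    }

    public static void main(String[] args) {
        // 'abb' < 'abcxxxxxxxxxxx' < 'zzz' lexicographically, and the mapping preserves that order,
        // so a predicate such as col BETWEEN 'abb' AND 'zzz' can be estimated numerically.
        System.out.println(packFirst8Chars("abb") < packFirst8Chars("abcxxxxxxxxxxx")); // true
        System.out.println(packFirst8Chars("abcxxxxxxxxxxx") < packFirst8Chars("zzz")); // true
    }
}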


@ -43,6 +43,18 @@ public class StringLiteral extends Literal {
return value;
}
@Override
public double getDouble() {
long v = 0;
int pos = 0;
int len = Math.min(value.length(), 8);
while (pos < len) {
v += ((long) value.charAt(pos)) << ((7 - pos) * 8);
pos++;
}
return (double) v;
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitStringLiteral(this, context);


@ -76,4 +76,16 @@ public class VarcharLiteral extends Literal {
}
return dateLiteral;
}
@Override
public double getDouble() {
long v = 0;
int pos = 0;
int len = Math.min(value.length(), 8);
while (pos < len) {
v += ((long) value.charAt(pos)) << ((7 - pos) * 8);
pos++;
}
return (double) v;
}
}


@ -153,7 +153,8 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
"phase", aggPhase,
"outputExpr", outputExpressions,
"groupByExpr", groupByExpressions,
"partitionExpr", partitionExpressions
"partitionExpr", partitionExpressions,
"stats", statsDeriveResult
);
}


@ -105,7 +105,8 @@ public class PhysicalOlapScan extends PhysicalRelation {
public String toString() {
return Utils.toSqlString("PhysicalOlapScan",
"qualified", Utils.qualifiedName(qualifier, olapTable.getName()),
"output", getOutput()
"output", getOutput(),
"stats", statsDeriveResult
);
}


@ -67,7 +67,8 @@ public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHIL
@Override
public String toString() {
return Utils.toSqlString("PhysicalProject",
"projects", projects
"projects", projects,
"stats", statsDeriveResult
);
}


@ -30,14 +30,11 @@ import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.util.Util;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@ -76,11 +73,9 @@ public class ColumnStat {
private static final Predicate<Double> DESIRED_MAX_SIZE_PRED = (v) -> v >= -1L;
private static final Predicate<Double> DESIRED_NUM_NULLS_PRED = (v) -> v >= -1L;
private static final Set<Type> MAX_MIN_UNSUPPORTED_TYPE = new HashSet<>();
public static final Set<Type> MAX_MIN_UNSUPPORTED_TYPE = new HashSet<>();
static {
MAX_MIN_UNSUPPORTED_TYPE.add(Type.VARCHAR);
MAX_MIN_UNSUPPORTED_TYPE.add(Type.CHAR);
MAX_MIN_UNSUPPORTED_TYPE.add(Type.HLL);
MAX_MIN_UNSUPPORTED_TYPE.add(Type.BITMAP);
MAX_MIN_UNSUPPORTED_TYPE.add(Type.ARRAY);
@ -98,6 +93,8 @@ public class ColumnStat {
private LiteralExpr minExpr;
private LiteralExpr maxExpr;
private double selectivity = 1.0;
public static ColumnStat createDefaultColumnStats() {
ColumnStat columnStat = new ColumnStat();
columnStat.setAvgSizeByte(1);
@ -107,7 +104,7 @@ public class ColumnStat {
return columnStat;
}
public static boolean isInvalid(ColumnStat stats) {
public static boolean isUnKnown(ColumnStat stats) {
return stats == UNKNOWN;
}
@ -121,6 +118,7 @@ public class ColumnStat {
this.numNulls = other.numNulls;
this.minValue = other.minValue;
this.maxValue = other.maxValue;
this.selectivity = other.selectivity;
}
public ColumnStat(double ndv, double avgSizeByte,
@ -260,14 +258,14 @@ public class ColumnStat {
return Double.parseDouble(columnValue);
case DATE:
case DATEV2:
return LocalDate.parse(columnValue).atStartOfDay()
.atZone(ZoneId.systemDefault()).toInstant().getEpochSecond();
org.apache.doris.nereids.trees.expressions.literal.DateLiteral literal =
new org.apache.doris.nereids.trees.expressions.literal.DateLiteral(columnValue);
return literal.getDouble();
case DATETIMEV2:
case DATETIME:
DateTimeFormatter timeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
return LocalDateTime
.parse(columnValue, timeFormatter)
.atZone(ZoneId.systemDefault()).toInstant().getEpochSecond();
DateTimeLiteral dateTimeLiteral = new DateTimeLiteral(columnValue);
return dateTimeLiteral.getDouble();
case CHAR:
case VARCHAR:
return convertStringToDouble(columnValue);
@ -301,7 +299,10 @@ public class ColumnStat {
}
public ColumnStat updateBySelectivity(double selectivity, double rowCount) {
ndv = ndv * selectivity;
if (ColumnStat.isAlmostUnique(ndv, rowCount)) {
this.selectivity = selectivity;
ndv = ndv * selectivity;
}
numNulls = (long) Math.ceil(numNulls * selectivity);
if (ndv > rowCount) {
ndv = rowCount;
@ -434,4 +435,34 @@ public class ColumnStat {
}
}
public static boolean isAlmostUnique(double ndv, double rowCount) {
return rowCount * 0.9 < ndv && ndv < rowCount * 1.1;
}
public double getSelectivity() {
return selectivity;
}
public void setSelectivity(double selectivity) {
this.selectivity = selectivity;
}
public double ndvIntersection(ColumnStat other) {
if (maxValue == minValue) {
if (minValue <= other.maxValue && minValue >= other.minValue) {
return 1;
} else {
return 0;
}
}
double min = Math.max(minValue, other.minValue);
double max = Math.min(maxValue, other.maxValue);
if (min < max) {
return Math.ceil(ndv * (max - min) / (maxValue - minValue));
} else if (min > max) {
return 0;
} else {
return 1;
}
}
}


@ -26,6 +26,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
/**
* This structure is maintained in each operator to store the statistical information results obtained by the operator.
@ -40,9 +41,10 @@ public class StatsDeriveResult {
private final Map<Id, Long> columnIdToNdv = Maps.newHashMap();
private Map<Slot, ColumnStat> slotToColumnStats;
private int width = 1;
private double penalty = 0.0;
//TODO: isReduced is to be removed after removing StatsCalculatorV1
public boolean isReduced = false;
public int width = 1;
public StatsDeriveResult(double rowCount, Map<Slot, ColumnStat> slotToColumnStats) {
this.rowCount = rowCount;
@ -65,6 +67,7 @@ public class StatsDeriveResult {
}
this.isReduced = another.isReduced;
this.width = another.width;
this.penalty = another.penalty;
}
public double computeSize() {
@ -118,20 +121,37 @@ public class StatsDeriveResult {
this.slotToColumnStats = slotToColumnStats;
}
public StatsDeriveResult updateRowCountBySelectivity(double selectivity) {
rowCount *= selectivity;
public void updateColumnStatsForSlot(Slot slot, ColumnStat columnStat) {
slotToColumnStats.put(slot, columnStat);
}
public StatsDeriveResult updateBySelectivity(double selectivity, Set<Slot> exclude) {
double originRowCount = rowCount;
for (Entry<Slot, ColumnStat> entry : slotToColumnStats.entrySet()) {
entry.getValue().updateBySelectivity(selectivity, rowCount);
if (!exclude.contains(entry.getKey())) {
entry.getValue().updateBySelectivity(selectivity, originRowCount);
}
}
rowCount *= selectivity;
return this;
}
public StatsDeriveResult updateBySelectivity(double selectivity) {
double originRowCount = rowCount;
for (Entry<Slot, ColumnStat> entry : slotToColumnStats.entrySet()) {
entry.getValue().updateBySelectivity(selectivity, originRowCount);
}
rowCount *= selectivity;
return this;
}
public StatsDeriveResult updateRowCountByLimit(long limit) {
double originRowCount = rowCount;
if (limit > 0 && rowCount > 0 && rowCount > limit) {
double selectivity = ((double) limit) / rowCount;
rowCount = limit;
for (Entry<Slot, ColumnStat> entry : slotToColumnStats.entrySet()) {
entry.getValue().updateBySelectivity(selectivity, rowCount);
entry.getValue().updateBySelectivity(selectivity, originRowCount);
}
}
return this;
@ -151,9 +171,10 @@ public class StatsDeriveResult {
@Override
public String toString() {
StringBuilder builder = new StringBuilder();
builder.append("(rows=").append(rowCount)
builder.append("(rows=").append((long) rowCount)
.append(", isReduced=").append(isReduced)
.append(", width=").append(width).append(")");
.append(", width=").append(width)
.append(", penalty=").append(penalty).append(")");
return builder.toString();
}
@ -182,4 +203,20 @@ public class StatsDeriveResult {
public ColumnStat getColumnStatsBySlot(Slot slot) {
return slotToColumnStats.get(slot);
}
public int getWidth() {
return width;
}
public void setWidth(int width) {
this.width = width;
}
public double getPenalty() {
return penalty;
}
public void setPenalty(double penalty) {
this.penalty = penalty;
}
}


@ -88,7 +88,10 @@ class FilterEstimationTest {
FilterEstimation filterEstimation = new FilterEstimation(stat);
StatsDeriveResult expected = filterEstimation.estimate(or);
Assertions.assertTrue(
Precision.equals((0.5 * 0.1 + 0.1 / 0.3 - 0.5 * 0.1 * 0.1 / 0.3) * 1000, expected.getRowCount(), 0.01));
Precision.equals((0.5 * 0.1
+ FilterEstimation.DEFAULT_INEQUALITY_COMPARISON_SELECTIVITY
- 0.5 * 0.1 * FilterEstimation.DEFAULT_INEQUALITY_COMPARISON_SELECTIVITY) * 1000,
expected.getRowCount(), 0.01));
}
// a >= 500