[refactor](nereids) unify withSel/updateRowCountOnly/withRowCount (#24997)

1. Refactor the statistics functions withSel/updateRowCountOnly/withRowCount.
2. Do not use Double.MAX_VALUE/Double.MIN_VALUE in stats estimation; use Double.POSITIVE_INFINITY/Double.NEGATIVE_INFINITY instead.
3. DateLikeType.rangeLength() no longer throws DateTimeException.
This commit is contained in:
minghong
2023-10-07 16:22:30 +08:00
committed by GitHub
parent 335804bb25
commit 1405f1efd2
9 changed files with 53 additions and 118 deletions

View File

@ -125,7 +125,7 @@ public class DeriveStatsJob extends Job {
// child group's row count unchanged when the parent group expression is a project operation.
double parentRowCount = groupExpression.getOwnerGroup().getStatistics().getRowCount();
groupExpression.children().forEach(g -> g.setStatistics(
g.getStatistics().updateRowCountAndColStats(parentRowCount))
g.getStatistics().withRowCountAndEnforceValid(parentRowCount))
);
}
}

View File

@ -151,7 +151,7 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta
return new ColumnStatisticBuilder()
.setNdv(2)
.setMinValue(0)
.setMaxValue(Double.MAX_VALUE)
.setMaxValue(Double.POSITIVE_INFINITY)
.setAvgSizeByte(8)
.setNumNulls(0)
.build();

View File

@ -91,7 +91,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
@Override
public Statistics visit(Expression expr, EstimationContext context) {
return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT, false);
return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT);
}
@Override
@ -106,7 +106,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
} else if (predicate instanceof Or) {
Statistics rightStats = rightExpr.accept(this, context);
double rowCount = leftStats.getRowCount() + rightStats.getRowCount() - andStats.getRowCount();
Statistics orStats = context.statistics.setRowCount(rowCount);
Statistics orStats = context.statistics.withRowCount(rowCount);
Set<Slot> leftInputSlots = leftExpr.getInputSlots();
Set<Slot> rightInputSlots = rightExpr.getInputSlots();
for (Slot slot : context.keyColumns) {
@ -164,7 +164,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
double rowCount = context.statistics.getRowCount();
double newRowCount = Math.max(rowCount * DEFAULT_HAVING_COEFFICIENT,
Math.max(statsForLeft.ndv, statsForRight.ndv));
return context.statistics.setRowCount(newRowCount);
return context.statistics.withRowCount(newRowCount);
}
}
if (!left.isConstant() && !right.isConstant()) {
@ -207,7 +207,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
private Statistics calculateWhenLiteralRight(ComparisonPredicate cp,
ColumnStatistic statsForLeft, ColumnStatistic statsForRight, EstimationContext context) {
if (statsForLeft.isUnKnown) {
return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT, false);
return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT);
}
if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
@ -241,7 +241,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
return estimateEqualToWithHistogram(cp.left(), statsForLeft, val, context);
}
Statistics equalStats = context.statistics.withSel(selectivity, false);
Statistics equalStats = context.statistics.withSel(selectivity);
Expression left = cp.left();
equalStats.addColumnStats(left, statsForRight);
context.addKeyIfSlot(left);
@ -272,7 +272,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
Expression compareExpr = inPredicate.getCompareExpr();
ColumnStatistic compareExprStats = ExpressionEstimation.estimate(compareExpr, context.statistics);
if (compareExprStats.isUnKnown || compareExpr instanceof Function) {
return context.statistics.withSel(DEFAULT_IN_COEFFICIENT, false);
return context.statistics.withSel(DEFAULT_IN_COEFFICIENT);
}
List<Expression> options = inPredicate.getOptions();
// init minOption and maxOption by compareExpr.max and compareExpr.min respectively,
@ -348,7 +348,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
}
}
Statistics estimated = new Statistics(context.statistics);
estimated = estimated.withSel(selectivity, false);
estimated = estimated.withSel(selectivity);
estimated.addColumnStats(compareExpr,
compareExprStatsBuilder.build());
context.addKeyIfSlot(compareExpr);
@ -468,7 +468,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
ColumnStatisticBuilder leftColumnStatisticBuilder;
Statistics updatedStatistics;
if (intersectRange.isEmpty()) {
updatedStatistics = context.statistics.setRowCount(0);
updatedStatistics = context.statistics.withRowCount(0);
leftColumnStatisticBuilder = new ColumnStatisticBuilder(leftStats)
.setMinValue(Double.NEGATIVE_INFINITY)
.setMinExpr(null)
@ -484,7 +484,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
.setMaxExpr(intersectRange.getHighExpr())
.setNdv(intersectRange.getDistinctValues());
double sel = leftRange.overlapPercentWith(rightRange);
updatedStatistics = context.statistics.withSel(sel, false);
updatedStatistics = context.statistics.withSel(sel);
leftColumnStatisticBuilder.setCount(updatedStatistics.getRowCount());
}
updatedStatistics.addColumnStats(leftExpr, leftColumnStatisticBuilder.build());
@ -504,7 +504,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
intersectBuilder.setMinValue(intersect.getLow());
intersectBuilder.setMaxValue(intersect.getHigh());
double sel = 1 / StatsMathUtil.nonZeroDivisor(Math.max(leftStats.ndv, rightStats.ndv));
Statistics updatedStatistics = context.statistics.withSel(sel, false);
Statistics updatedStatistics = context.statistics.withSel(sel);
updatedStatistics.addColumnStats(leftExpr, intersectBuilder.build());
updatedStatistics.addColumnStats(rightExpr, intersectBuilder.build());
context.addKeyIfSlot(leftExpr);
@ -520,7 +520,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
// Left always less than Right
if (leftRange.getHigh() < rightRange.getLow()) {
statistics =
context.statistics.setRowCount(Math.min(context.statistics.getRowCount() - leftStats.numNulls,
context.statistics.withRowCount(Math.min(context.statistics.getRowCount() - leftStats.numNulls,
context.statistics.getRowCount() - rightStats.numNulls));
statistics.addColumnStats(leftExpr, new ColumnStatisticBuilder(leftStats).setNumNulls(0.0).build());
statistics.addColumnStats(rightExpr, new ColumnStatisticBuilder(rightStats).setNumNulls(0.0).build());
@ -531,7 +531,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
double leftOverlapPercent = leftRange.overlapPercentWith(rightRange);
// Left always greater than right
if (leftOverlapPercent == 0) {
return context.statistics.setRowCount(0.0);
return context.statistics.withRowCount(0.0);
}
StatisticRange leftAlwaysLessThanRightRange = new StatisticRange(leftStats.minValue, leftStats.minExpr,
rightStats.minValue, rightStats.minExpr, Double.NaN, leftExpr.getDataType());
@ -564,7 +564,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
+ leftOverlapPercent * rightAlwaysGreaterRangeFraction;
context.addKeyIfSlot(leftExpr);
context.addKeyIfSlot(rightExpr);
return context.statistics.withSel(sel, false)
return context.statistics.withSel(sel)
.addColumnStats(leftExpr, leftColumnStatistic)
.addColumnStats(rightExpr, rightColumnStatistic);
}
@ -598,10 +598,10 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
.setHistogram(new HistogramBuilder(leftHist).setBuckets(updatedBucketList).build())
.build();
context.addKeyIfSlot(leftExpr);
return context.statistics.withSel(sel, false).addColumnStats(leftExpr, columnStatistic);
return context.statistics.withSel(sel).addColumnStats(leftExpr, columnStatistic);
}
}
return context.statistics.withSel(0, false);
return context.statistics.withSel(0);
}
private Statistics estimateGreaterThanLiteralWithHistogram(Expression leftExpr, ColumnStatistic leftStats,
@ -635,10 +635,10 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
.setHistogram(new HistogramBuilder(leftHist).setBuckets(updatedBucketList).build())
.build();
context.addKeyIfSlot(leftExpr);
return context.statistics.withSel(sel, false).addColumnStats(leftExpr, columnStatistic);
return context.statistics.withSel(sel).addColumnStats(leftExpr, columnStatistic);
}
}
return context.statistics.withSel(0, false);
return context.statistics.withSel(0);
}
private Statistics estimateEqualToWithHistogram(Expression leftExpr, ColumnStatistic leftStats,
@ -663,7 +663,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
.setMinValue(numVal)
.build();
context.addKeyIfSlot(leftExpr);
return context.statistics.withSel(sel, false).addColumnStats(leftExpr, columnStatistic);
return context.statistics.withSel(sel).addColumnStats(leftExpr, columnStatistic);
}
@Override

View File

@ -142,7 +142,7 @@ public class JoinEstimation {
outputRowCount = Math.max(1, outputRowCount * ratio.get());
}
}
innerJoinStats = crossJoinStats.updateRowCountAndColStats(outputRowCount);
innerJoinStats = crossJoinStats.withRowCountAndEnforceValid(outputRowCount);
return innerJoinStats;
}
@ -257,10 +257,9 @@ public class JoinEstimation {
double baseRowCount =
join.getJoinType().isLeftSemiOrAntiJoin() ? leftStats.getRowCount() : rightStats.getRowCount();
rowCount = Math.min(innerJoinStats.getRowCount(), baseRowCount);
return innerJoinStats.withRowCount(rowCount);
return innerJoinStats.withRowCountAndEnforceValid(rowCount);
} else {
StatisticsBuilder builder;
double originalRowCount = leftStats.getRowCount();
if (join.getJoinType().isLeftSemiOrAntiJoin()) {
builder = new StatisticsBuilder(leftStats);
builder.setRowCount(rowCount);
@ -268,10 +267,9 @@ public class JoinEstimation {
//right semi or anti
builder = new StatisticsBuilder(rightStats);
builder.setRowCount(rowCount);
originalRowCount = rightStats.getRowCount();
}
Statistics outputStats = builder.build();
outputStats.fix(rowCount, originalRowCount);
outputStats.enforceValid();
return outputStats;
}
}
@ -291,15 +289,15 @@ public class JoinEstimation {
Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join);
double rowCount = Math.max(leftStats.getRowCount(), innerJoinStats.getRowCount());
rowCount = Math.max(leftStats.getRowCount(), rowCount);
return innerJoinStats.withRowCount(rowCount);
return innerJoinStats.withRowCountAndEnforceValid(rowCount);
} else if (joinType == JoinType.RIGHT_OUTER_JOIN) {
Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join);
double rowCount = Math.max(rightStats.getRowCount(), innerJoinStats.getRowCount());
rowCount = Math.max(rowCount, rightStats.getRowCount());
return innerJoinStats.withRowCount(rowCount);
return innerJoinStats.withRowCountAndEnforceValid(rowCount);
} else if (joinType == JoinType.FULL_OUTER_JOIN) {
Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join);
return innerJoinStats.withRowCount(leftStats.getRowCount()
return innerJoinStats.withRowCountAndEnforceValid(leftStats.getRowCount()
+ rightStats.getRowCount() + innerJoinStats.getRowCount());
} else if (joinType == JoinType.CROSS_JOIN) {
return new StatisticsBuilder()

View File

@ -539,7 +539,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
private Statistics computeAssertNumRows(long desiredNumOfRows) {
Statistics statistics = groupExpression.childStatistics(0);
statistics.withRowCount(Math.min(1, statistics.getRowCount()));
statistics.withRowCountAndEnforceValid(Math.min(1, statistics.getRowCount()));
return statistics;
}
@ -657,7 +657,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
private Statistics computeTopN(TopN topN) {
Statistics stats = groupExpression.childStatistics(0);
return stats.withRowCount(Math.min(stats.getRowCount(), topN.getLimit()));
return stats.withRowCountAndEnforceValid(Math.min(stats.getRowCount(), topN.getLimit()));
}
private Statistics computePartitionTopN(PartitionTopN partitionTopN) {
@ -690,12 +690,12 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
// TODO: for the filter push down window situation, we will prune the row count twice
// because we keep the pushed down filter. And it will be calculated twice, one of them in 'PartitionTopN'
// and the other is in 'Filter'. It's hard to dismiss.
return childStats.updateRowCountAndColStats(rowCount);
return childStats.withRowCountAndEnforceValid(rowCount);
}
private Statistics computeLimit(Limit limit) {
Statistics stats = groupExpression.childStatistics(0);
return stats.withRowCount(Math.min(stats.getRowCount(), limit.getLimit()));
return stats.withRowCountAndEnforceValid(Math.min(stats.getRowCount(), limit.getLimit()));
}
private double estimateGroupByRowCount(List<Expression> groupByExpressions, Statistics childStats) {
@ -878,7 +878,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
for (int i = 1; i < setOperation.getArity(); ++i) {
rowCount = Math.min(rowCount, groupExpression.childStatistics(i).getRowCount());
}
double minProd = Double.MAX_VALUE;
double minProd = Double.POSITIVE_INFINITY;
for (Group group : groupExpression.children()) {
Statistics statistics = group.getStatistics();
double prod = 1.0;
@ -896,7 +896,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
leftChildStats.addColumnStats(outputs.get(i),
leftChildStats.findColumnStatistics(leftChildOutputs.get(i)));
}
return leftChildStats.withRowCount(rowCount);
return leftChildStats.withRowCountAndEnforceValid(rowCount);
}
private Statistics computeGenerate(Generate generate) {
@ -910,8 +910,8 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
for (Slot output : generate.getGeneratorOutput()) {
ColumnStatistic columnStatistic = new ColumnStatisticBuilder()
.setCount(count)
.setMinValue(Double.MAX_VALUE)
.setMaxValue(Double.MIN_VALUE)
.setMinValue(Double.NEGATIVE_INFINITY)
.setMaxValue(Double.POSITIVE_INFINITY)
.setNdv(count)
.setNumNulls(0)
.setAvgSizeByte(output.getDataType().width())

View File

@ -27,6 +27,7 @@ import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import java.time.DateTimeException;
import java.time.LocalDate;
import java.time.temporal.ChronoUnit;
@ -49,11 +50,15 @@ public abstract class DateLikeType extends PrimitiveType {
return 0;
}
if (Double.isInfinite(high) || Double.isInfinite(low)) {
return high - low;
return Double.POSITIVE_INFINITY;
}
try {
LocalDate to = toLocalDate(high);
LocalDate from = toLocalDate(low);
return ChronoUnit.DAYS.between(from, to);
} catch (DateTimeException e) {
return Double.POSITIVE_INFINITY;
}
LocalDate to = toLocalDate(high);
LocalDate from = toLocalDate(low);
return ChronoUnit.DAYS.between(from, to);
}
/**

View File

@ -177,10 +177,10 @@ public class ColumnStatistic {
columnStatisticBuilder.setMinExpr(StatisticsUtil.readableValue(col.getType(), min));
} catch (AnalysisException e) {
LOG.warn("Failed to deserialize column {} min value {}.", col, min, e);
columnStatisticBuilder.setMinValue(Double.MIN_VALUE);
columnStatisticBuilder.setMinValue(Double.NEGATIVE_INFINITY);
}
} else {
columnStatisticBuilder.setMinValue(Double.MIN_VALUE);
columnStatisticBuilder.setMinValue(Double.NEGATIVE_INFINITY);
}
if (max != null && !max.equalsIgnoreCase("NULL")) {
try {
@ -188,10 +188,10 @@ public class ColumnStatistic {
columnStatisticBuilder.setMaxExpr(StatisticsUtil.readableValue(col.getType(), max));
} catch (AnalysisException e) {
LOG.warn("Failed to deserialize column {} max value {}.", col, max, e);
columnStatisticBuilder.setMaxValue(Double.MAX_VALUE);
columnStatisticBuilder.setMaxValue(Double.POSITIVE_INFINITY);
}
} else {
columnStatisticBuilder.setMaxValue(Double.MAX_VALUE);
columnStatisticBuilder.setMaxValue(Double.POSITIVE_INFINITY);
}
columnStatisticBuilder.setUpdatedTime(row.get(13));
return columnStatisticBuilder.build();

View File

@ -37,26 +37,6 @@ public class Statistics {
// the byte size of one tuple
private double tupleSize;
/**
* after filter, compute the new ndv of a column
* @param ndv original ndv of column
* @param newRowCount the row count of table after filter
* @param oldRowCount the row count of table before filter
* @return the new ndv after filter
*/
public static double computeNdv(double ndv, double newRowCount, double oldRowCount) {
if (newRowCount > oldRowCount) {
return ndv;
}
double selectOneTuple = newRowCount / StatsMathUtil.nonZeroDivisor(oldRowCount);
double allTuplesOfSameDistinctValueNotSelected = Math.pow((1 - selectOneTuple), oldRowCount / ndv);
if (allTuplesOfSameDistinctValueNotSelected == 1.0) {
// avoid NaN
return ndv;
}
return Math.min(ndv * (1 - allTuplesOfSameDistinctValueNotSelected), newRowCount);
}
public Statistics(Statistics another) {
this.rowCount = another.rowCount;
this.expressionToColumnStats = new HashMap<>(another.expressionToColumnStats);
@ -80,53 +60,19 @@ public class Statistics {
return rowCount;
}
/*
* Return a stats with new rowCount and fix each column stats.
*/
public Statistics withRowCount(double rowCount) {
if (Double.isNaN(rowCount)) {
return this;
}
Statistics statistics = new Statistics(rowCount, new HashMap<>(expressionToColumnStats));
statistics.fix(rowCount, StatsMathUtil.nonZeroDivisor(this.rowCount));
return statistics;
}
public Statistics setRowCount(double rowCount) {
return new Statistics(rowCount, new HashMap<>(expressionToColumnStats));
}
/**
* Update by count.
*/
public Statistics updateRowCountAndColStats(double rowCount) {
public Statistics withRowCountAndEnforceValid(double rowCount) {
Statistics statistics = new Statistics(rowCount, expressionToColumnStats);
for (Entry<Expression, ColumnStatistic> entry : expressionToColumnStats.entrySet()) {
ColumnStatistic columnStatistic = entry.getValue();
ColumnStatisticBuilder columnStatisticBuilder = new ColumnStatisticBuilder(columnStatistic);
columnStatisticBuilder.setNdv(Math.min(columnStatistic.ndv, rowCount));
columnStatisticBuilder.setNumNulls(rowCount - columnStatistic.numNulls);
columnStatisticBuilder.setCount(rowCount);
expressionToColumnStats.put(entry.getKey(), columnStatisticBuilder.build());
}
statistics.enforceValid();
return statistics;
}
/**
* Fix by sel.
*/
public void fix(double newRowCount, double originRowCount) {
double sel = newRowCount / originRowCount;
for (Entry<Expression, ColumnStatistic> entry : expressionToColumnStats.entrySet()) {
ColumnStatistic columnStatistic = entry.getValue();
ColumnStatisticBuilder columnStatisticBuilder = new ColumnStatisticBuilder(columnStatistic);
columnStatisticBuilder.setNdv(computeNdv(columnStatistic.ndv, newRowCount, originRowCount));
columnStatisticBuilder.setNumNulls(Math.min(columnStatistic.numNulls * sel, newRowCount));
columnStatisticBuilder.setCount(newRowCount);
expressionToColumnStats.put(entry.getKey(), columnStatisticBuilder.build());
}
}
public void enforceValid() {
for (Entry<Expression, ColumnStatistic> entry : expressionToColumnStats.entrySet()) {
ColumnStatistic columnStatistic = entry.getValue();
@ -137,8 +83,8 @@ public class Statistics {
columnStatisticBuilder.setNumNulls(Math.min(columnStatistic.numNulls, rowCount - ndv));
columnStatisticBuilder.setCount(rowCount);
columnStatistic = columnStatisticBuilder.build();
expressionToColumnStats.put(entry.getKey(), columnStatistic);
}
expressionToColumnStats.put(entry.getKey(), columnStatistic);
}
}
@ -148,21 +94,12 @@ public class Statistics {
}
public Statistics withSel(double sel) {
return withSel(sel, true);
}
public Statistics withSel(double sel, boolean updateColStats) {
sel = StatsMathUtil.minNonNaN(sel, 1);
if (Double.isNaN(rowCount)) {
return this;
}
double newCount = rowCount * sel;
double originCount = rowCount;
Statistics statistics = new Statistics(newCount, new HashMap<>(expressionToColumnStats));
if (updateColStats) {
statistics.fix(newCount, StatsMathUtil.nonZeroDivisor(originCount));
}
return statistics;
return new Statistics(newCount, new HashMap<>(expressionToColumnStats));
}
public Statistics addColumnStats(Expression expression, ColumnStatistic columnStatistic) {
@ -176,11 +113,6 @@ public class Statistics {
&& expressionToColumnStats.get(s).isUnKnown);
}
public Statistics merge(Statistics statistics) {
expressionToColumnStats.putAll(statistics.expressionToColumnStats);
return this;
}
private double computeTupleSize() {
if (tupleSize <= 0) {
double tempSize = 0.0;

View File

@ -650,8 +650,8 @@ public class StatisticsUtil {
TableScan tableScan = table.newScan().includeColumnStats();
ColumnStatisticBuilder columnStatisticBuilder = new ColumnStatisticBuilder();
columnStatisticBuilder.setCount(0);
columnStatisticBuilder.setMaxValue(Double.MAX_VALUE);
columnStatisticBuilder.setMinValue(Double.MIN_VALUE);
columnStatisticBuilder.setMaxValue(Double.POSITIVE_INFINITY);
columnStatisticBuilder.setMinValue(Double.NEGATIVE_INFINITY);
columnStatisticBuilder.setDataSize(0);
columnStatisticBuilder.setAvgSizeByte(0);
columnStatisticBuilder.setNumNulls(0);