filter estimation refactor (#18170)

This commit is contained in:
minghong
2023-03-31 08:49:38 +08:00
committed by GitHub
parent a88e80f8ee
commit 1abb19d0fd
5 changed files with 116 additions and 3 deletions

View File

@ -0,0 +1,64 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.stats;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;
/**
 * Pushes the column statistics recorded for an expression down onto its input expression,
 * so that a filter on a derived expression also updates the statistics of what it was derived from.
 *
 * <p>Example (state BEFORE this visitor runs):
 * <pre>
 * table: T(A, B)
 * T.stats = (rows=10,
 *   {
 *     A->ndv=10, rows=10
 *     B->...
 *   }
 * )
 * after node: filter(cast(A as double)=1.0)
 * filter.stats = (rows=1,
 *   {
 *     A->ndv=m, rows=1
 *     B->ndv=m, rows=1
 *     cast(A as double)->ndv=1, rows=1
 *   }
 * )
 * </pre>
 * where m is computed by function computeNdv().
 *
 * <p>filter.stats should then be adjusted: A.columnStats should be equal to
 * "cast(A as double)".columnStats. For other expressions (except cast), the visitor
 * descends into their children so that any cast found deeper in the tree is adjusted too.
 *
 * <p>The visitor works purely by side effect on the {@link Statistics} passed as context;
 * every visit method returns {@code null}.
 */
public class ColumnStatsAdjustVisitor extends ExpressionVisitor<ColumnStatistic, Statistics> {
    /**
     * Default handling: recurse into every child of {@code expr}.
     *
     * @param expr expression whose children are visited
     * @param context statistics that may be updated by nested casts
     * @return always {@code null}
     */
    @Override
    public ColumnStatistic visit(Expression expr, Statistics context) {
        expr.children().forEach(child -> child.accept(this, context));
        return null;
    }

    /**
     * If {@code context} holds statistics for the cast expression itself, registers the same
     * statistics for the cast's child. Does not descend below the cast.
     *
     * @param cast cast expression to adjust
     * @param context statistics to read from and write to
     * @return always {@code null}
     */
    @Override
    public ColumnStatistic visitCast(Cast cast, Statistics context) {
        ColumnStatistic castStats = context.findColumnStatistics(cast);
        if (castStats != null) {
            context.addColumnStats(cast.child(), castStats);
        }
        return null;
    }
}

View File

@ -141,6 +141,10 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta
}
/**
 * Estimates the statistics of a cast expression.
 *
 * <p>Statistics already registered in {@code context} for the cast expression itself take
 * precedence; when none are found, the estimate of the cast's child is returned instead.
 *
 * @param cast cast expression to estimate
 * @param context statistics used for the lookup and for estimating the child
 * @return statistics found for the cast, or otherwise the child's estimate
 */
public ColumnStatistic visitCast(Cast cast, Statistics context) {
    ColumnStatistic recorded = context.findColumnStatistics(cast);
    return recorded != null ? recorded : cast.child().accept(this, context);
}

View File

@ -175,9 +175,11 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
if (statsForLeft.histogram != null) {
return estimateLessThanLiteralWithHistogram(leftExpr, statsForLeft, val, context, contains);
}
//rightRange.distinctValues should not be used
StatisticRange rightRange = new StatisticRange(statsForLeft.minValue, val, statsForLeft.ndv);
return estimateBinaryComparisonFilter(leftExpr,
statsForLeft,
new StatisticRange(Double.NEGATIVE_INFINITY, val, statsForLeft.ndv), context);
rightRange, context);
}
private Statistics updateGreaterThanLiteral(Expression leftExpr, ColumnStatistic statsForLeft,
@ -185,7 +187,8 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
if (statsForLeft.histogram != null) {
return estimateGreaterThanLiteralWithHistogram(leftExpr, statsForLeft, val, context, contains);
}
StatisticRange rightRange = new StatisticRange(val, Double.POSITIVE_INFINITY,
//rightRange.distinctValues should not be used
StatisticRange rightRange = new StatisticRange(val, statsForLeft.maxValue,
statsForLeft.ndv);
return estimateBinaryComparisonFilter(leftExpr, statsForLeft, rightRange, context);
}
@ -358,7 +361,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
StatisticRange rightRange, EstimationContext context) {
StatisticRange leftRange =
new StatisticRange(leftStats.minValue, leftStats.maxValue, leftStats.ndv);
StatisticRange intersectRange = leftRange.intersect(rightRange);
StatisticRange intersectRange = leftRange.cover(rightRange);
ColumnStatisticBuilder leftColumnStatisticBuilder = new ColumnStatisticBuilder(leftStats)
.setMinValue(intersectRange.getLow())
.setMaxValue(intersectRange.getHigh())
@ -366,6 +369,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
double sel = leftRange.overlapPercentWith(rightRange);
Statistics updatedStatistics = context.statistics.withSel(sel);
updatedStatistics.addColumnStats(leftExpr, leftColumnStatisticBuilder.build());
leftExpr.accept(new ColumnStatsAdjustVisitor(), updatedStatistics);
return updatedStatistics;
}

View File

@ -110,6 +110,18 @@ public class StatisticRange {
return empty();
}
/**
 * Restricts this range to the part that also lies inside {@code other}.
 *
 * <p>The bounds of the result are the larger of the two lows and the smaller of the two highs.
 * The distinct-value count is derived from THIS range only: this range's distinctValues scaled
 * by {@code overlapPercentWith(other)}, combined with the unscaled count through
 * {@code minExcludeNaN} (presumably the smaller of the two, ignoring NaN — confirm against
 * that helper). {@code other.distinctValues} is not read here.
 *
 * @param other range to restrict this range to
 * @return the restricted range, or {@code empty()} when the bounds do not overlap
 */
public StatisticRange cover(StatisticRange other) {
    double coveredLow = Math.max(low, other.low);
    double coveredHigh = Math.min(high, other.high);
    // Written as a negated "<=" so that NaN bounds also take the empty() path.
    if (!(coveredLow <= coveredHigh)) {
        return empty();
    }
    double overlapFraction = overlapPercentWith(other);
    double scaledDistinctValues = overlapFraction * distinctValues;
    return new StatisticRange(coveredLow, coveredHigh,
            minExcludeNaN(distinctValues, scaledDistinctValues));
}
public StatisticRange union(StatisticRange other) {
double overlapPercentThis = this.overlapPercentWith(other);
double overlapPercentOther = other.overlapPercentWith(this);

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.stats;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
@ -28,7 +29,9 @@ import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.ColumnStatisticBuilder;
@ -838,4 +841,30 @@ class FilterEstimationTest {
Assertions.assertEquals(10, statsC.minValue);
Assertions.assertEquals(40, statsC.maxValue);
}
/**
 * A range predicate written on cast(a as double) must also shrink the statistics of
 * the underlying column a, not only those of the cast expression.
 */
@Test
public void testBetweenCastFilter() {
    // Integer column a: 100 rows, 100 distinct values in [0, 100], no nulls.
    SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
    ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
            .setNdv(100)
            .setAvgSizeByte(4)
            .setNumNulls(0)
            .setMaxValue(100)
            .setMinValue(0)
            .setSelectivity(1.0)
            .setCount(100);

    // Predicate: cast(a as double) < 50.0 AND cast(a as double) > 40.0
    DoubleLiteral begin = new DoubleLiteral(40.0);
    DoubleLiteral end = new DoubleLiteral(50.0);
    LessThan less = new LessThan(new Cast(a, DoubleType.INSTANCE), end);
    GreaterThan greater = new GreaterThan(new Cast(a, DoubleType.INSTANCE), begin);
    And and = new And(less, greater);

    Statistics stats = new Statistics(100, new HashMap<>());
    stats.addColumnStats(a, builder.build());

    FilterEstimation filterEstimation = new FilterEstimation();
    Statistics result = filterEstimation.estimate(and, stats);

    // Expected value goes first so a failure message reports expected/actual correctly.
    Assertions.assertEquals(10, result.getRowCount(), 0.01);
    ColumnStatistic colStats = result.findColumnStatistics(a);
    Assertions.assertNotNull(colStats);
    Assertions.assertEquals(10, colStats.ndv, 0.1);
}
}