[fix](Nereids) fix bug in case-when/if stats estimation (#30265)

This commit is contained in:
minghong
2024-01-23 19:07:28 +08:00
committed by yiguolei
parent 1a51d04cb8
commit 4af3fd2a2e
2 changed files with 114 additions and 17 deletions

View File

@ -39,6 +39,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
@ -137,21 +138,37 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta
//TODO: case-when need to re-implemented
@Override
public ColumnStatistic visitCaseWhen(CaseWhen caseWhen, Statistics context) {
double ndv = caseWhen.getWhenClauses().size();
if (caseWhen.getDefaultValue().isPresent()) {
ndv += 1;
}
for (WhenClause clause : caseWhen.getWhenClauses()) {
ColumnStatistic colStats = ExpressionEstimation.estimate(clause.getResult(), context);
ndv = Math.max(ndv, colStats.ndv);
}
if (caseWhen.getDefaultValue().isPresent()) {
ColumnStatistic colStats = ExpressionEstimation.estimate(caseWhen.getDefaultValue().get(), context);
ndv = Math.max(ndv, colStats.ndv);
}
return new ColumnStatisticBuilder()
.setNdv(caseWhen.getWhenClauses().size() + 1)
.setMinValue(0)
.setMaxValue(Double.MAX_VALUE)
.setNdv(ndv)
.setMinValue(Double.NEGATIVE_INFINITY)
.setMaxValue(Double.POSITIVE_INFINITY)
.setAvgSizeByte(8)
.setNumNulls(0)
.build();
}
@Override
public ColumnStatistic visitIf(If function, Statistics context) {
// TODO: copy from visitCaseWhen, polish them.
public ColumnStatistic visitIf(If ifClause, Statistics context) {
double ndv = 2;
ColumnStatistic colStatsThen = ExpressionEstimation.estimate(ifClause.child(1), context);
ndv = Math.max(ndv, colStatsThen.ndv);
ColumnStatistic colStatsElse = ExpressionEstimation.estimate(ifClause.child(2), context);
ndv = Math.max(ndv, colStatsElse.ndv);
return new ColumnStatisticBuilder()
.setNdv(2)
.setMinValue(0)
.setNdv(ndv)
.setMinValue(Double.NEGATIVE_INFINITY)
.setMaxValue(Double.POSITIVE_INFINITY)
.setAvgSizeByte(8)
.setNumNulls(0)
@ -577,13 +594,22 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta
if (childColumnStats.minOrMaxIsInf()) {
return columnStatisticBuilder.build();
}
double minValue = getDatetimeFromLong((long) childColumnStats.minValue).toLocalDate()
.atStartOfDay(ZoneId.systemDefault()).toEpochSecond();
double maxValue = getDatetimeFromLong((long) childColumnStats.maxValue).toLocalDate()
.atStartOfDay(ZoneId.systemDefault()).toEpochSecond();
double minValue;
double maxValue;
try {
// min/max value is infinite, but they may be too large to convert to date
minValue = getDatetimeFromLong((long) childColumnStats.minValue).toLocalDate()
.atStartOfDay(ZoneId.systemDefault()).toEpochSecond();
maxValue = getDatetimeFromLong((long) childColumnStats.maxValue).toLocalDate()
.atStartOfDay(ZoneId.systemDefault()).toEpochSecond();
} catch (Exception e) {
// ignore DateTimeException
minValue = Double.NEGATIVE_INFINITY;
maxValue = Double.POSITIVE_INFINITY;
}
return columnStatisticBuilder.setMaxValue(maxValue)
.setMinValue(minValue)
.build();
.setMinValue(minValue).build();
}
private LocalDateTime getDatetimeFromLong(long dateTime) {
@ -599,10 +625,18 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta
if (childColumnStats.minOrMaxIsInf()) {
return columnStatisticBuilder.build();
}
double minValue = getDatetimeFromLong((long) childColumnStats.minValue).toLocalDate().toEpochDay()
+ (double) DAYS_FROM_0_TO_1970;
double maxValue = getDatetimeFromLong((long) childColumnStats.maxValue).toLocalDate().toEpochDay()
+ (double) DAYS_FROM_0_TO_1970;
double minValue;
double maxValue;
try {
minValue = getDatetimeFromLong((long) childColumnStats.minValue).toLocalDate().toEpochDay()
+ (double) DAYS_FROM_0_TO_1970;
maxValue = getDatetimeFromLong((long) childColumnStats.maxValue).toLocalDate().toEpochDay()
+ (double) DAYS_FROM_0_TO_1970;
} catch (Exception e) {
// ignore DateTimeException
minValue = Double.NEGATIVE_INFINITY;
maxValue = Double.POSITIVE_INFINITY;
}
return columnStatisticBuilder.setMaxValue(maxValue)
.setMinValue(minValue)
.build();