[fix](Nereids): fix AssertNumRows StatsCalculator (#30053)

This commit is contained in:
jakevin
2024-01-18 17:18:41 +08:00
committed by yiguolei
parent a5ca8833d7
commit 097641b543
9 changed files with 269 additions and 257 deletions

View File

@ -26,6 +26,7 @@ import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@ -167,7 +168,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
private CascadesContext cascadesContext;
private StatsCalculator(GroupExpression groupExpression, boolean forbidUnknownColStats,
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump,
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump,
Map<CTEId, Statistics> cteIdToStats, CascadesContext context) {
this.groupExpression = groupExpression;
this.forbidUnknownColStats = forbidUnknownColStats;
@ -193,7 +194,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
* estimate stats
*/
public static StatsCalculator estimate(GroupExpression groupExpression, boolean forbidUnknownColStats,
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump,
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump,
Map<CTEId, Statistics> cteIdToStats, CascadesContext context) {
StatsCalculator statsCalculator = new StatsCalculator(
groupExpression, forbidUnknownColStats, columnStatisticMap, isPlayNereidsDump, cteIdToStats, context);
@ -369,7 +370,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
/**
 * Derives statistics for a logical assert-num-rows node.
 *
 * <p>Passes the node's whole {@link AssertNumRowsElement} to
 * {@code computeAssertNumRows}, rather than only its desired row count, so the
 * helper can see which assertion is being applied.
 *
 * @param assertNumRows the logical assert-num-rows plan node
 * @param context unused
 * @return the statistics computed by {@code computeAssertNumRows}
 */
@Override
public Statistics visitLogicalAssertNumRows(
        LogicalAssertNumRows<? extends Plan> assertNumRows, Void context) {
    return computeAssertNumRows(assertNumRows.getAssertNumRowsElement());
}
@Override
@ -533,7 +534,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
/**
 * Derives statistics for a physical assert-num-rows node.
 *
 * <p>Passes the node's whole {@link AssertNumRowsElement} to
 * {@code computeAssertNumRows}, rather than only its desired row count, so the
 * helper can see which assertion is being applied.
 *
 * @param assertNumRows the physical assert-num-rows plan node
 * @param context unused
 * @return the statistics computed by {@code computeAssertNumRows}
 */
@Override
public Statistics visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows,
        Void context) {
    return computeAssertNumRows(assertNumRows.getAssertNumRowsElement());
}
@Override
@ -556,11 +557,34 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
return computeGenerate(generate);
}
private Statistics computeAssertNumRows(long desiredNumOfRows) {
private Statistics computeAssertNumRows(AssertNumRowsElement assertNumRowsElement) {
Statistics statistics = groupExpression.childStatistics(0);
statistics.withRowCountAndEnforceValid(Math.min(1, statistics.getRowCount()));
statistics = new StatisticsBuilder(statistics).setWidthInJoinCluster(1).build();
return statistics;
long newRowCount;
long rowCount = (long) statistics.getRowCount();
long desiredNumOfRows = assertNumRowsElement.getDesiredNumOfRows();
switch (assertNumRowsElement.getAssertion()) {
case EQ:
newRowCount = desiredNumOfRows;
break;
case GE:
newRowCount = statistics.getRowCount() >= desiredNumOfRows ? rowCount : desiredNumOfRows;
break;
case GT:
newRowCount = statistics.getRowCount() > desiredNumOfRows ? rowCount : desiredNumOfRows;
break;
case LE:
newRowCount = statistics.getRowCount() <= desiredNumOfRows ? rowCount : desiredNumOfRows;
break;
case LT:
newRowCount = statistics.getRowCount() < desiredNumOfRows ? rowCount : desiredNumOfRows;
break;
case NE:
return statistics;
default:
throw new IllegalArgumentException("Unknown assertion: " + assertNumRowsElement.getAssertion());
}
Statistics newStatistics = statistics.withRowCountAndEnforceValid(newRowCount);
return new StatisticsBuilder(newStatistics).setWidthInJoinCluster(1).build();
}
private Statistics computeFilter(Filter filter) {
@ -610,7 +634,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
}
} else {
return Env.getCurrentEnv().getStatisticsCache().getColumnStatistics(
catalogId, dbId, table.getId(), colName);
catalogId, dbId, table.getId(), colName);
}
}
@ -701,7 +725,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
rowCount = rowCount * DEFAULT_COLUMN_NDV_RATIO;
} else {
rowCount = Math.min(rowCount, partitionByKeyStats.stream().map(s -> s.ndv)
.max(Double::compare).get() * partitionTopN.getPartitionLimit());
.max(Double::compare).get() * partitionTopN.getPartitionLimit());
}
} else {
rowCount = Math.min(rowCount, partitionTopN.getPartitionLimit());