[fix](Nereids): fix AssertNumRows StatsCalculator (#30053)
This commit is contained in:
@ -26,6 +26,7 @@ import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
|
||||
import org.apache.doris.nereids.trees.expressions.CTEId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
@ -167,7 +168,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
|
||||
private CascadesContext cascadesContext;
|
||||
|
||||
private StatsCalculator(GroupExpression groupExpression, boolean forbidUnknownColStats,
|
||||
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump,
|
||||
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump,
|
||||
Map<CTEId, Statistics> cteIdToStats, CascadesContext context) {
|
||||
this.groupExpression = groupExpression;
|
||||
this.forbidUnknownColStats = forbidUnknownColStats;
|
||||
@ -193,7 +194,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
|
||||
* estimate stats
|
||||
*/
|
||||
public static StatsCalculator estimate(GroupExpression groupExpression, boolean forbidUnknownColStats,
|
||||
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump,
|
||||
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump,
|
||||
Map<CTEId, Statistics> cteIdToStats, CascadesContext context) {
|
||||
StatsCalculator statsCalculator = new StatsCalculator(
|
||||
groupExpression, forbidUnknownColStats, columnStatisticMap, isPlayNereidsDump, cteIdToStats, context);
|
||||
@ -369,7 +370,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
|
||||
@Override
|
||||
public Statistics visitLogicalAssertNumRows(
|
||||
LogicalAssertNumRows<? extends Plan> assertNumRows, Void context) {
|
||||
return computeAssertNumRows(assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows());
|
||||
return computeAssertNumRows(assertNumRows.getAssertNumRowsElement());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -533,7 +534,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
|
||||
@Override
|
||||
public Statistics visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows,
|
||||
Void context) {
|
||||
return computeAssertNumRows(assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows());
|
||||
return computeAssertNumRows(assertNumRows.getAssertNumRowsElement());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -556,11 +557,34 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
|
||||
return computeGenerate(generate);
|
||||
}
|
||||
|
||||
private Statistics computeAssertNumRows(long desiredNumOfRows) {
|
||||
private Statistics computeAssertNumRows(AssertNumRowsElement assertNumRowsElement) {
|
||||
Statistics statistics = groupExpression.childStatistics(0);
|
||||
statistics.withRowCountAndEnforceValid(Math.min(1, statistics.getRowCount()));
|
||||
statistics = new StatisticsBuilder(statistics).setWidthInJoinCluster(1).build();
|
||||
return statistics;
|
||||
long newRowCount;
|
||||
long rowCount = (long) statistics.getRowCount();
|
||||
long desiredNumOfRows = assertNumRowsElement.getDesiredNumOfRows();
|
||||
switch (assertNumRowsElement.getAssertion()) {
|
||||
case EQ:
|
||||
newRowCount = desiredNumOfRows;
|
||||
break;
|
||||
case GE:
|
||||
newRowCount = statistics.getRowCount() >= desiredNumOfRows ? rowCount : desiredNumOfRows;
|
||||
break;
|
||||
case GT:
|
||||
newRowCount = statistics.getRowCount() > desiredNumOfRows ? rowCount : desiredNumOfRows;
|
||||
break;
|
||||
case LE:
|
||||
newRowCount = statistics.getRowCount() <= desiredNumOfRows ? rowCount : desiredNumOfRows;
|
||||
break;
|
||||
case LT:
|
||||
newRowCount = statistics.getRowCount() < desiredNumOfRows ? rowCount : desiredNumOfRows;
|
||||
break;
|
||||
case NE:
|
||||
return statistics;
|
||||
default:
|
||||
throw new IllegalArgumentException("Unknown assertion: " + assertNumRowsElement.getAssertion());
|
||||
}
|
||||
Statistics newStatistics = statistics.withRowCountAndEnforceValid(newRowCount);
|
||||
return new StatisticsBuilder(newStatistics).setWidthInJoinCluster(1).build();
|
||||
}
|
||||
|
||||
private Statistics computeFilter(Filter filter) {
|
||||
@ -610,7 +634,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
|
||||
}
|
||||
} else {
|
||||
return Env.getCurrentEnv().getStatisticsCache().getColumnStatistics(
|
||||
catalogId, dbId, table.getId(), colName);
|
||||
catalogId, dbId, table.getId(), colName);
|
||||
}
|
||||
}
|
||||
|
||||
@ -701,7 +725,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
|
||||
rowCount = rowCount * DEFAULT_COLUMN_NDV_RATIO;
|
||||
} else {
|
||||
rowCount = Math.min(rowCount, partitionByKeyStats.stream().map(s -> s.ndv)
|
||||
.max(Double::compare).get() * partitionTopN.getPartitionLimit());
|
||||
.max(Double::compare).get() * partitionTopN.getPartitionLimit());
|
||||
}
|
||||
} else {
|
||||
rowCount = Math.min(rowCount, partitionTopN.getPartitionLimit());
|
||||
|
||||
Reference in New Issue
Block a user