From 598dc6960a362b3527dd4409ad41b7e9dc125b6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E5=81=A5?= Date: Tue, 29 Aug 2023 15:01:26 +0800 Subject: [PATCH] [fix](Nereids) make agg output unchanged after normalized (#23499) The NormalizeAggregate rule can change the output of an aggregate. For example: ``` select c1 as c, c1 from t having c1 > 0 ``` The NormalizeAggregate rule will make a plan with output c, which can cause the having filter to fail. Therefore, the output ExprIds should remain unchanged after normalization. --- .../rules/rewrite/NormalizeAggregate.java | 49 ++++++++++++------- .../nereids/trees/expressions/CaseWhen.java | 11 ++++- .../nereids/trees/expressions/WhenClause.java | 2 +- .../nereids_p0/aggregate/aggregate.groovy | 3 ++ 4 files changed, 45 insertions(+), 20 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java index 90e997941a..eb683e8b58 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java @@ -33,6 +33,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -185,24 +186,6 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali bottomProjects.addAll(aggInputSlots); // build group by exprs List normalizedGroupExprs = groupByToSlotContext.normalizeToUseSlotRef(groupingByExprs); - // build upper project, use two context to do pop up, because agg output maybe contain two part: - // group by keys and agg expressions - List upperProjects = 
groupByToSlotContext - .normalizeToUseSlotRefWithoutWindowFunction(aggregateOutput); - upperProjects = normalizedAggFuncsToSlotContext.normalizeToUseSlotRefWithoutWindowFunction(upperProjects); - // process Expression like Alias(SlotReference#0)#0 - upperProjects = upperProjects.stream().map(e -> { - if (e instanceof Alias) { - Alias alias = (Alias) e; - if (alias.child() instanceof SlotReference) { - SlotReference slotReference = (SlotReference) alias.child(); - if (slotReference.getExprId().equals(alias.getExprId())) { - return slotReference; - } - } - } - return e; - }).collect(Collectors.toList()); Plan bottomPlan; if (!bottomProjects.isEmpty()) { @@ -211,11 +194,41 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali bottomPlan = aggregate.child(); } + List upperProjects = normalizeOutput(aggregateOutput, + groupByToSlotContext, normalizedAggFuncsToSlotContext); + return new LogicalProject<>(upperProjects, aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan)); }).toRule(RuleType.NORMALIZE_AGGREGATE); } + private List normalizeOutput(List aggregateOutput, + NormalizeToSlotContext groupByToSlotContext, NormalizeToSlotContext normalizedAggFuncsToSlotContext) { + // build upper project, use two context to do pop up, because agg output maybe contain two part: + // group by keys and agg expressions + List upperProjects = groupByToSlotContext + .normalizeToUseSlotRefWithoutWindowFunction(aggregateOutput); + upperProjects = normalizedAggFuncsToSlotContext.normalizeToUseSlotRefWithoutWindowFunction(upperProjects); + + Builder builder = new ImmutableList.Builder<>(); + for (int i = 0; i < aggregateOutput.size(); i++) { + NamedExpression e = upperProjects.get(i); + // process Expression like Alias(SlotReference#0)#0 + if (e instanceof Alias && e.child(0) instanceof SlotReference) { + SlotReference slotReference = (SlotReference) e.child(0); + if (slotReference.getExprId().equals(e.getExprId())) { + e = 
slotReference; + } + } + // Make the output ExprId unchanged + if (!e.getExprId().equals(aggregateOutput.get(i).getExprId())) { + e = new Alias(aggregateOutput.get(i).getExprId(), e, aggregateOutput.get(i).getName()); + } + builder.add(e); + } + return builder.build(); + } + private static class CollectNonWindowedAggFuncs extends DefaultExpressionVisitor> { private static final CollectNonWindowedAggFuncs INSTANCE = new CollectNonWindowedAggFuncs(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java index e61d599e7b..c9233d5c14 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java @@ -91,7 +91,16 @@ public class CaseWhen extends Expression { @Override public String toString() { - return toSql(); + StringBuilder output = new StringBuilder("CASE"); + for (Expression child : children()) { + if (child instanceof WhenClause) { + output.append(child); + } else { + output.append(" ELSE ").append(child.toString()); + } + } + output.append(" END"); + return output.toString(); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java index dea93d216d..3cc3586990 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java @@ -111,6 +111,6 @@ public class WhenClause extends Expression implements BinaryExpression, ExpectsI @Override public String toString() { - return toSql(); + return " WHEN " + left().toString() + " THEN " + right().toString(); } } diff --git a/regression-test/suites/nereids_p0/aggregate/aggregate.groovy 
b/regression-test/suites/nereids_p0/aggregate/aggregate.groovy index 7ac3fbe9c5..e1ae3131b2 100644 --- a/regression-test/suites/nereids_p0/aggregate/aggregate.groovy +++ b/regression-test/suites/nereids_p0/aggregate/aggregate.groovy @@ -314,4 +314,7 @@ suite("aggregate") { qt_aggregate """ select avg(distinct c_bigint), avg(distinct c_double) from regression_test_nereids_p0_aggregate.${tableName} """ qt_aggregate """ select count(distinct c_bigint),count(distinct c_double),count(distinct c_string),count(distinct c_date_1),count(distinct c_timestamp_1),count(distinct c_timestamp_2),count(distinct c_timestamp_3),count(distinct c_boolean) from regression_test_nereids_p0_aggregate.${tableName} """ qt_select_quantile_percent """ select QUANTILE_PERCENT(QUANTILE_UNION(TO_QUANTILE_STATE(c_bigint,2048)),0.5) from regression_test_nereids_p0_aggregate.${tableName}; """ + + sql "select k1 as k, k1 from tempbaseall group by k1 having k1 > 0" + sql "select k1 as k, k1 from tempbaseall group by k1 having k > 0" }