branch-2.1: [fix](nereids) do eliminate constant group by key in normalizeagg #49589 (#50212)

Cherry-picked from https://github.com/apache/doris/pull/49589
This commit is contained in:
feiniaofeiafei
2025-05-08 18:52:40 +08:00
committed by GitHub
parent 995f1e5dc0
commit ebe302cb7e
7 changed files with 588 additions and 179 deletions

View File

@ -30,7 +30,6 @@ import org.apache.doris.nereids.rules.analysis.CheckPolicy;
import org.apache.doris.nereids.rules.analysis.CollectJoinConstraint;
import org.apache.doris.nereids.rules.analysis.CollectSubQueryAlias;
import org.apache.doris.nereids.rules.analysis.EliminateDistinctConstant;
import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.analysis.EliminateLogicalSelectHint;
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
import org.apache.doris.nereids.rules.analysis.HavingToFilter;
@ -136,8 +135,6 @@ public class Analyzer extends AbstractBatchJobExecutor {
// select SUM(lo_tax) FROM lineorder group by 1;
// errCode = 2, detailMessage = GROUP BY expression must not contain aggregate functions: sum(lo_tax)
bottomUp(new CheckAnalysis()),
topDown(new EliminateGroupByConstant()),
topDown(new SimplifyAggGroupBy()),
topDown(new NormalizeAggregate()),
topDown(new HavingToFilter()),

View File

@ -25,7 +25,6 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
@ -158,7 +157,6 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(
new EliminateOrderByConstant(),
new EliminateSortUnderSubqueryOrView(),
new EliminateGroupByConstant(),
// MergeProjects depends on this rule
new LogicalSubQueryAliasToLogicalProject(),
// TODO: we should do expression normalization after plan normalization

View File

@ -17,9 +17,12 @@
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
@ -35,6 +38,7 @@ import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
@ -50,6 +54,7 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@ -111,14 +116,16 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
public List<Rule> buildRules() {
return ImmutableList.of(
logicalHaving(logicalAggregate().whenNot(LogicalAggregate::isNormalized))
.then(having -> normalizeAgg(having.child(), Optional.of(having)))
.thenApply(ctx -> normalizeAgg(ctx.root.child(), Optional.of(ctx.root), ctx.cascadesContext))
.toRule(RuleType.NORMALIZE_AGGREGATE),
logicalAggregate().whenNot(LogicalAggregate::isNormalized)
.then(aggregate -> normalizeAgg(aggregate, Optional.empty()))
.thenApply(ctx -> normalizeAgg(ctx.root, Optional.empty(), ctx.cascadesContext))
.toRule(RuleType.NORMALIZE_AGGREGATE));
}
private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having) {
@SuppressWarnings("checkstyle:UnusedLocalVariable")
private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having,
CascadesContext ctx) {
// The LogicalAggregate node may contain window agg functions and usual agg functions
// we call window agg functions as window-agg and usual agg functions as trivial-agg for short
// This rule simplify LogicalAggregate node by:
@ -279,8 +286,10 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
groupByExprContext, argsOfAggFuncNeedPushDownContext, normalizedAggFuncsToSlotContext);
// create a parent project node
LogicalProject<Plan> project = new LogicalProject<>(upperProjects, newAggregate);
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx);
LogicalProject<Plan> project = eliminateGroupByConstant(groupByExprContext, rewriteContext,
normalizedGroupExprs, normalizedAggOutput, bottomProjects, aggregate, upperProjects, newAggregate);
// verify project used slots are all coming from agg's output
List<Slot> slots = collectAllUsedSlots(upperProjects);
if (!slots.isEmpty()) {
@ -389,4 +398,93 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
return expr;
}
}
/**
 * Builds the project placed on top of the normalized aggregate, removing every group-by key
 * that folds to a constant.
 *
 * <p>If no group-by key is constant, the result is simply {@code upperProjects} over
 * {@code newAggregate}. Otherwise a new aggregate is derived from {@code aggregate} without the
 * constant keys, the matching expressions are dropped from the bottom project, and the upper
 * project re-emits each removed key as an alias of its folded expression.
 *
 * @param groupByExprContext mapping from original group-by expressions to their normalized slots
 * @param rewriteContext context handed to {@link FoldConstantRuleOnFE#evaluate}
 * @param normalizedGroupExprs group-by keys after normalization (each is cast to {@link Slot})
 * @param normalizedAggOutput output expressions of the normalized aggregate
 * @param bottomProjects expressions of the project below the aggregate; may be empty
 * @param aggregate the aggregate being normalized; supplies the child plan
 * @param upperProjects expressions of the project above the aggregate
 * @param newAggregate the already-built normalized aggregate, used when nothing is eliminated
 * @return the project that should replace the top of the normalized plan
 */
private LogicalProject<Plan> eliminateGroupByConstant(NormalizeToSlotContext groupByExprContext,
        ExpressionRewriteContext rewriteContext, List<Expression> normalizedGroupExprs,
        List<NamedExpression> normalizedAggOutput, Set<NamedExpression> bottomProjects,
        LogicalAggregate<Plan> aggregate, List<NamedExpression> upperProjects, LogicalAggregate<?> newAggregate) {
    // Step 1: fold every original group-by expression and remember, per normalized slot,
    // the folded expression whenever the result is constant.
    Map<Slot, Expression> constantBySlot = new HashMap<>();
    for (Map.Entry<Expression, NormalizeToSlotTriplet> mapping
            : groupByExprContext.getNormalizeToSlotMap().entrySet()) {
        Expression folded = FoldConstantRuleOnFE.evaluate(mapping.getKey(), rewriteContext);
        if (folded.isConstant()) {
            constantBySlot.put(mapping.getValue().remainExpr, folded);
        }
    }
    // Covers both "no group-by mapping at all" and "nothing folded to a constant".
    if (constantBySlot.isEmpty()) {
        return new LogicalProject<>(upperProjects, newAggregate);
    }

    // Step 2: keep only the group-by keys that are not constant.
    List<Expression> keptGroupKeys = new ArrayList<>();
    for (Expression groupKey : normalizedGroupExprs) {
        if (constantBySlot.containsKey((Slot) groupKey)) {
            continue;
        }
        keptGroupKeys.add(groupKey);
    }
    if (keptGroupKeys.size() == normalizedGroupExprs.size()) {
        // No key was actually removed, so the original normalized plan stands.
        return new LogicalProject<>(upperProjects, newAggregate);
    }

    Set<NamedExpression> childProjects = bottomProjects;
    List<NamedExpression> aggOutput = normalizedAggOutput;
    if (keptGroupKeys.isEmpty()) {
        // Every key was constant: group by a TINYINT literal 1 projected from below instead,
        // so the aggregate still has a group-by key.
        Alias placeholder = new Alias(new TinyIntLiteral((byte) 1));
        childProjects = new HashSet<>(bottomProjects);
        childProjects.add(placeholder);
        aggOutput = new ArrayList<>(normalizedAggOutput);
        Slot placeholderSlot = placeholder.toSlot();
        aggOutput.add(placeholderSlot);
        keptGroupKeys.add(placeholderSlot);
    }

    // Step 3: rewrite the aggregate output. Aliases get constant slots substituted,
    // bare slots that refer to a removed key are dropped, everything else is kept.
    ImmutableList.Builder<NamedExpression> aggOutputBuilder = ImmutableList.builder();
    for (NamedExpression output : aggOutput) {
        if (output instanceof Alias) {
            aggOutputBuilder.add(ExpressionUtils.replaceNameExpression(output, constantBySlot));
        } else if (!(output instanceof Slot) || !constantBySlot.containsKey(output)) {
            aggOutputBuilder.add(output);
        }
    }

    // Step 4: the bottom project no longer needs to compute the constant expressions;
    // they are produced by the upper project instead.
    Plan aggChild;
    if (childProjects.isEmpty()) {
        aggChild = aggregate.child();
    } else {
        ImmutableList.Builder<NamedExpression> keptChildProjects = ImmutableList.builder();
        for (NamedExpression childProject : childProjects) {
            if (!constantBySlot.containsKey(childProject.toSlot())) {
                keptChildProjects.add(childProject);
            }
        }
        aggChild = new LogicalProject<>(keptChildProjects.build(), aggregate.child());
    }
    LogicalAggregate<Plan> aggWithoutConstKeys = aggregate.withNormalized(keptGroupKeys,
            aggOutputBuilder.build(), aggChild);

    // Step 5: in the upper project, substitute constants inside aliases and turn a bare
    // reference to a removed key into an alias of the constant that keeps the same
    // expr id and name.
    ImmutableList.Builder<NamedExpression> upperBuilder = ImmutableList.builder();
    for (NamedExpression project : upperProjects) {
        if (project instanceof Alias) {
            upperBuilder.add(ExpressionUtils.replaceNameExpression(project, constantBySlot));
        } else if (project instanceof Slot && constantBySlot.containsKey(project)) {
            upperBuilder.add(new Alias(project.getExprId(), constantBySlot.get(project), project.getName()));
        } else {
            upperBuilder.add(project);
        }
    }
    return new LogicalProject<>(upperBuilder.build(), aggWithoutConstKeys);
}
}

View File

@ -1,165 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.PartitionInfo;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.thrift.TStorageType;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
/**
 * Tests for {@link EliminateGroupByConstant}.
 *
 * <p>Each case builds an aggregate over a three-column scan, applies the rule top-down followed by
 * {@link CheckAfterRewrite}, and checks which group-by expressions remain.
 */
class EliminateGroupByConstantTest implements MemoPatternMatchSupported {
    // Table fixture with three INT columns: k1, k2 and k3.
    private static final OlapTable table = new OlapTable(0L, "student",
            ImmutableList.of(new Column("k1", Type.INT, true, AggregateType.NONE, "0", ""),
                    new Column("k2", Type.INT, false, AggregateType.NONE, "0", ""),
                    new Column("k3", Type.INT, true, AggregateType.NONE, "", "")),
            KeysType.PRIMARY_KEYS, new PartitionInfo(), null);

    // Must run before the scan below is created, since the scan is built from this table.
    static {
        table.setIndexMeta(-1,
                "t1",
                table.getFullSchema(),
                0, 0, (short) 0,
                TStorageType.COLUMN,
                KeysType.PRIMARY_KEYS);
    }

    private static final LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
    private static final Slot k1 = scan.getOutput().get(0);
    private static final Slot k2 = scan.getOutput().get(1);

    /** An integer literal key next to a column key: only the column key should remain. */
    @Test
    void testIntegerLiteral() {
        LogicalPlan plan = new LogicalPlanBuilder(scan)
                .agg(ImmutableList.of(new IntegerLiteral(1), k2),
                        ImmutableList.of(k1, k2))
                .build();
        assertGroupByKeys(plan, ImmutableList.of(k2));
    }

    /** A string literal key next to a column key: only the column key should remain. */
    @Test
    void testOtherLiteral() {
        LogicalPlan plan = new LogicalPlanBuilder(scan)
                .agg(ImmutableList.of(
                                new StringLiteral("str"), k2),
                        ImmutableList.of(
                                new Alias(new StringLiteral("str"), "str"), k1, k2))
                .build();
        assertGroupByKeys(plan, ImmutableList.of(k2));
    }

    /** String and integer literals mixed with a column and an expression key. */
    @Test
    void testMixedLiteral() {
        LogicalPlan plan = new LogicalPlanBuilder(scan)
                .agg(ImmutableList.of(
                                new StringLiteral("str"), k2,
                                new IntegerLiteral(1),
                                new IntegerLiteral(2),
                                new IntegerLiteral(3),
                                new Add(k1, k2)),
                        ImmutableList.of(
                                new Alias(new StringLiteral("str"), "str"),
                                k2, k1, new Alias(new IntegerLiteral(1), "integer")))
                .build();
        assertGroupByKeys(plan, ImmutableList.of(k2, new Add(k1, k2)));
    }

    /** Literal keys together with an expression key and aggregate functions in the output. */
    @Test
    void testComplexGroupBy() {
        LogicalPlan plan = new LogicalPlanBuilder(scan)
                .agg(ImmutableList.of(
                                new IntegerLiteral(1),
                                new IntegerLiteral(2),
                                new Add(k1, k2)),
                        ImmutableList.of(
                                new Alias(new Max(k1), "max"),
                                new Alias(new Min(k2), "min"),
                                new Alias(new Add(k1, k2), "add")))
                .build();
        assertGroupByKeys(plan, ImmutableList.of(new Add(k1, k2)));
    }

    /** Same as the mixed case, plus an integer literal (5) larger than the number of outputs. */
    @Test
    void testOutOfRange() {
        LogicalPlan plan = new LogicalPlanBuilder(scan)
                .agg(ImmutableList.of(
                                new StringLiteral("str"), k2,
                                new IntegerLiteral(1),
                                new IntegerLiteral(2),
                                new IntegerLiteral(3),
                                new IntegerLiteral(5),
                                new Add(k1, k2)),
                        ImmutableList.of(
                                new Alias(new StringLiteral("str"), "str"),
                                k2, k1, new Alias(new IntegerLiteral(1), "integer")))
                .build();
        assertGroupByKeys(plan, ImmutableList.of(k2, new Add(k1, k2)));
    }

    /**
     * Applies {@link EliminateGroupByConstant} top-down and {@link CheckAfterRewrite} bottom-up
     * to {@code plan}, then requires an aggregate whose group-by list equals {@code expectedKeys}.
     */
    private void assertGroupByKeys(LogicalPlan plan, ImmutableList<?> expectedKeys) {
        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
                .applyTopDown(new EliminateGroupByConstant())
                .applyBottomUp(new CheckAfterRewrite())
                .matches(
                        aggregate().when(agg -> agg.getGroupByExpressions().equals(expectedKeys))
                );
    }
}

View File

@ -37,23 +37,35 @@ import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import java.util.List;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class NormalizeAggregateTest implements MemoPatternMatchSupported {
public class NormalizeAggregateTest extends TestWithFeService implements MemoPatternMatchSupported {
private LogicalPlan rStudent;
@BeforeAll
public final void beforeAll() {
@Override
protected void runBeforeAll() throws Exception {
rStudent = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.student,
ImmutableList.of());
createDatabase("test");
connectContext.setDatabase("default_cluster:test");
createTables(
"CREATE TABLE IF NOT EXISTS t1 (\n"
+ " id int not null,\n"
+ " name char\n"
+ ")\n"
+ "DUPLICATE KEY(id)\n"
+ "DISTRIBUTED BY HASH(id) BUCKETS 10\n"
+ "PROPERTIES (\"replication_num\" = \"1\")\n"
);
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
}
/*-
@ -190,4 +202,105 @@ public class NormalizeAggregateTest implements MemoPatternMatchSupported {
);
}
// Tests for eliminating constant GROUP BY keys during aggregate normalization.
/**
 * Groups by one column and two literals (an integer and a string) and expects the
 * rewritten plan to contain an aggregate with exactly one group-by expression.
 */
@Test
void testEliminateGroupByConst() {
    PlanChecker.from(connectContext)
            .analyze("select id ,1, 'abc' from t1 group by 1,2,3")
            .rewrite()
            .matches(logicalAggregate().when(agg -> 1 == agg.getGroupByExpressions().size()));
}
/**
 * Groups only by literals. The aggregate is still expected to carry exactly one
 * group-by expression after rewriting, rather than none.
 */
@Test
void useTinyIntEliminateGroupByConst() {
    PlanChecker.from(connectContext)
            .analyze("select 1, 'abc' from t1 group by 1,2")
            .rewrite()
            .matches(logicalAggregate().when(aggregate -> 1 == aggregate.getGroupByExpressions().size()));
}
/**
 * Groups by a column plus integer, string and boolean literals and expects a single
 * group-by expression to remain on the aggregate.
 */
@Test
void testMixedConstTypes() {
    PlanChecker.from(connectContext)
            .analyze("select id, 1, 'abc', true from t1 group by 1, 2, 3, 4")
            .rewrite()
            .matches(logicalAggregate().when(aggregate -> 1 == aggregate.getGroupByExpressions().size()));
}
/**
 * Groups by a column and a NULL literal and expects a single group-by expression
 * to remain on the aggregate.
 */
@Test
void testNullConst() {
    PlanChecker.from(connectContext)
            .analyze("select id, NULL from t1 group by 1, 2")
            .rewrite()
            .matches(logicalAggregate().when(aggregate -> 1 == aggregate.getGroupByExpressions().size()));
}
/**
 * Groups by two NULL literals only and expects the aggregate to end up with exactly
 * one group-by expression.
 */
@Test
void testTwoNullConst() {
    PlanChecker.from(connectContext)
            .analyze("select Null, NULL from t1 group by 1, 2")
            .rewrite()
            .matches(logicalAggregate().when(aggregate -> 1 == aggregate.getGroupByExpressions().size()));
}
/**
 * Groups by a column and two constant expressions ({@code 1+1} and a CONCAT of literals)
 * and expects a single group-by expression to remain on the aggregate.
 */
@Test
void testExpressionConst() {
    PlanChecker.from(connectContext)
            .analyze("select id, 1+1, CONCAT('a','b') from t1 group by 1, 2, 3")
            .rewrite()
            .matches(logicalAggregate().when(aggregate -> 1 == aggregate.getGroupByExpressions().size()));
}
/**
 * Groups by a column and two argument-less function calls (NOW() and PI()) and expects
 * a single group-by expression to remain on the aggregate.
 */
@Test
void testFunctionCallConst() {
    PlanChecker.from(connectContext)
            .analyze("select id, NOW(), PI() from t1 group by 1, 2, 3")
            .rewrite()
            .matches(logicalAggregate().when(aggregate -> 1 == aggregate.getGroupByExpressions().size()));
}
/**
 * Lists the group-by ordinals in a different order than the select list and expects
 * a single group-by expression to remain on the aggregate.
 */
@Test
void testDifferentOrder() {
    PlanChecker.from(connectContext)
            .analyze("select 1, id, 'abc' from t1 group by 2, 1, 3")
            .rewrite()
            .matches(logicalAggregate().when(aggregate -> 1 == aggregate.getGroupByExpressions().size()));
}
/**
 * Groups by a column and the same integer literal twice and expects a single group-by
 * expression to remain on the aggregate.
 */
@Test
void testDuplicateConst() {
    PlanChecker.from(connectContext)
            .analyze("select id, 1, 1 from t1 group by 1, 2, 3")
            .rewrite()
            .matches(logicalAggregate().when(aggregate -> 1 == aggregate.getGroupByExpressions().size()));
}
/**
 * Groups by two literals while also selecting COUNT(*). Expects the aggregate to have
 * exactly one group-by expression and at least one output expression whose string form
 * contains "COUNT".
 */
@Test
void testWithAggFunction() {
    PlanChecker.from(connectContext)
            .analyze("select 'abc', 1, COUNT(*) from t1 group by 1, 2")
            .rewrite()
            .matches(logicalAggregate().when(aggregate -> {
                // Check the group-by size first; the output scan is skipped when it fails.
                if (aggregate.getGroupByExpressions().size() != 1) {
                    return false;
                }
                return aggregate.getOutputExpressions().stream()
                        .anyMatch(output -> output.toString().contains("COUNT"));
            }));
}
}