[enhancement](nereids) eliminate repeat node if there is only 1 grouping set and no grouping scalar function (#35872)
This commit is contained in:
@ -82,6 +82,12 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
|
||||
public Rule build() {
|
||||
return RuleType.NORMALIZE_REPEAT.build(
|
||||
logicalRepeat(any()).when(LogicalRepeat::canBindVirtualSlot).then(repeat -> {
|
||||
if (repeat.getGroupingSets().size() == 1
|
||||
&& ExpressionUtils.collect(repeat.getOutputExpressions(),
|
||||
GroupingScalarFunction.class::isInstance).isEmpty()) {
|
||||
return new LogicalAggregate<>(repeat.getGroupByExpressions(),
|
||||
repeat.getOutputExpressions(), repeat.child());
|
||||
}
|
||||
checkRepeatLegality(repeat);
|
||||
repeat = removeDuplicateColumns(repeat);
|
||||
// add virtual slot, LogicalAggregate and LogicalProject for normalize
|
||||
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.analysis;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingId;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
|
||||
@ -41,7 +42,7 @@ public class NormalizeRepeatTest implements MemoPatternMatchSupported {
|
||||
Slot name = scan1.getOutput().get(1);
|
||||
Alias alias = new Alias(new Sum(name), "sum(name)");
|
||||
Plan plan = new LogicalRepeat<>(
|
||||
ImmutableList.of(ImmutableList.of(id)),
|
||||
ImmutableList.of(ImmutableList.of(id), ImmutableList.of(name)),
|
||||
ImmutableList.of(idNotNull, alias),
|
||||
scan1
|
||||
);
|
||||
@ -51,4 +52,40 @@ public class NormalizeRepeatTest implements MemoPatternMatchSupported {
|
||||
logicalRepeat().when(repeat -> repeat.getOutputExpressions().get(0).nullable())
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Checks the repeat-elimination path of {@link NormalizeRepeat}.
 *
 * <p>Builds a {@code LogicalRepeat} over {@code scan1} with exactly one grouping set
 * ({@code id}) and outputs that contain only a slot and a {@code Sum} alias, applies the
 * rule top-down, and expects the root of the resulting plan to match
 * {@code logicalAggregate(logicalOlapScan())}, i.e. an aggregate directly over the scan.
 */
@Test
|
||||
public void testEliminateRepeat() {
|
||||
Slot id = scan1.getOutput().get(0);
|
||||
// NOTE(review): despite the name, this slot is created with withNullable(true).
Slot idNotNull = id.withNullable(true);
|
||||
Slot name = scan1.getOutput().get(1);
|
||||
Alias alias = new Alias(new Sum(name), "sum(name)");
|
||||
Plan plan = new LogicalRepeat<>(
|
||||
// A single grouping set containing only `id`.
ImmutableList.of(ImmutableList.of(id)),
|
||||
// Output expressions: the nullable id slot and the sum alias.
ImmutableList.of(idNotNull, alias),
|
||||
scan1
|
||||
);
|
||||
PlanChecker.from(MemoTestUtils.createCascadesContext(plan))
|
||||
.applyTopDown(new NormalizeRepeat())
|
||||
.matchesFromRoot(
|
||||
// Expected shape: aggregate directly on top of the OLAP scan.
logicalAggregate(logicalOlapScan())
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Checks that {@link NormalizeRepeat} keeps the repeat node when the outputs contain a
 * {@code GroupingId} expression.
 *
 * <p>Builds a {@code LogicalRepeat} over {@code scan1} with one grouping set ({@code id})
 * whose outputs include an alias of {@code GroupingId(name)}, applies the rule top-down,
 * and expects the root of the resulting plan to match
 * {@code logicalAggregate(logicalRepeat(logicalOlapScan()))}.
 */
@Test
|
||||
public void testNoEliminateRepeat() {
|
||||
Slot id = scan1.getOutput().get(0);
|
||||
// NOTE(review): despite the name, this slot is created with withNullable(true).
Slot idNotNull = id.withNullable(true);
|
||||
Slot name = scan1.getOutput().get(1);
|
||||
// Presumably GroupingId is a grouping scalar function, which the rule checks for
// before eliminating the repeat -- confirm against the GroupingId class hierarchy.
Alias alias = new Alias(new GroupingId(name), "grouping_id(name)");
|
||||
Plan plan = new LogicalRepeat<>(
|
||||
// A single grouping set containing only `id`.
ImmutableList.of(ImmutableList.of(id)),
|
||||
// Output expressions: the nullable id slot and the grouping_id alias.
ImmutableList.of(idNotNull, alias),
|
||||
scan1
|
||||
);
|
||||
PlanChecker.from(MemoTestUtils.createCascadesContext(plan))
|
||||
.applyTopDown(new NormalizeRepeat())
|
||||
.matchesFromRoot(
|
||||
// Expected shape: the repeat node is still present beneath the aggregate.
logicalAggregate(logicalRepeat(logicalOlapScan()))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -39,4 +39,19 @@ suite("grouping_normalize_test"){
|
||||
SELECT ROUND( SUM(pk + 1) - 3) col_alias1, MAX( DISTINCT col_int_undef_signed - 5) AS col_alias2, pk + 1 AS col_alias3
|
||||
FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed,col_int_undef_signed2,pk),()) order by 1,2,3;
|
||||
"""
|
||||
|
||||
explain {
|
||||
sql("SELECT col_int_undef_signed, col_int_undef_signed2, SUM(pk) FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed, col_int_undef_signed2));")
|
||||
notContains("VREPEAT_NODE")
|
||||
}
|
||||
|
||||
explain {
|
||||
sql("SELECT col_int_undef_signed, col_int_undef_signed2, SUM(pk), grouping_id(col_int_undef_signed2) FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed, col_int_undef_signed2),());")
|
||||
contains("VREPEAT_NODE")
|
||||
}
|
||||
|
||||
explain {
|
||||
sql("SELECT col_int_undef_signed, col_int_undef_signed2, SUM(pk) FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed, col_int_undef_signed2));")
|
||||
notContains("VREPEAT_NODE")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user