[fix](nereids)logicalhaving is in wrong place after logicalagg and logicalwindow (#29463)

This commit is contained in:
starocean999
2024-01-10 09:57:29 +08:00
committed by yiguolei
parent 883d6dfc73
commit e17809a684
7 changed files with 214 additions and 130 deletions

View File

@ -33,6 +33,7 @@ import org.apache.doris.nereids.rules.analysis.CollectSubQueryAlias;
import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.analysis.EliminateLogicalSelectHint;
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
import org.apache.doris.nereids.rules.analysis.HavingToFilter;
import org.apache.doris.nereids.rules.analysis.LeadingJoin;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.analysis.NormalizeRepeat;
@ -161,6 +162,7 @@ public class Analyzer extends AbstractBatchJobExecutor {
bottomUp(new CheckAnalysis()),
topDown(new EliminateGroupByConstant()),
topDown(new NormalizeAggregate()),
topDown(new HavingToFilter()),
bottomUp(new SemiJoinCommute()),
bottomUp(
new CollectSubQueryAlias(),

View File

@ -63,6 +63,7 @@ public enum RuleType {
RESOLVE_PROJECT_ALIAS(RuleTypeClass.REWRITE),
RESOLVE_AGGREGATE_ALIAS(RuleTypeClass.REWRITE),
PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE),
HAVING_TO_FILTER(RuleTypeClass.REWRITE),
ONE_ROW_RELATION_EXTRACT_AGGREGATE(RuleTypeClass.REWRITE),
PROJECT_WITH_DISTINCT_TO_AGGREGATE(RuleTypeClass.REWRITE),
AVG_DISTINCT_TO_SUM_DIV_COUNT(RuleTypeClass.REWRITE),

View File

@ -30,7 +30,6 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
@ -174,10 +173,6 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
return new LogicalProject<>(ImmutableList.copyOf(project.getOutput()),
having.withChildren(new LogicalProject<>(projects, project.child())));
})
),
// Convert having to filter
RuleType.FILL_UP_HAVING_PROJECT.build(
logicalHaving().then(having -> new LogicalFilter<>(having.getConjuncts(), having.child()))
)
);
}

View File

@ -0,0 +1,34 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
/**
* HavingToFilter
*/
public class HavingToFilter extends OneAnalysisRuleFactory {
@Override
public Rule build() {
return logicalHaving()
.then(having -> new LogicalFilter<>(having.getConjuncts(), having.child()))
.toRule(RuleType.HAVING_TO_FILTER);
}
}

View File

@ -21,7 +21,7 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@ -34,6 +34,8 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
@ -46,6 +48,7 @@ import com.google.common.collect.Sets;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@ -97,141 +100,164 @@ import java.util.stream.Collectors;
* </pre>
* More example could get from UT {NormalizeAggregateTest}
*/
public class NormalizeAggregate extends OneRewriteRuleFactory implements NormalizeToSlot {
public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
@Override
public Rule build() {
return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
// The LogicalAggregate node may contain window agg functions and usual agg functions
// we call window agg functions as window-agg and usual agg functions as trival-agg for short
// This rule simplify LogicalAggregate node by:
// 1. Push down some exprs from old LogicalAggregate node to a new child LogicalProject Node,
// 2. create a new LogicalAggregate with normalized group by exprs and trival-aggs
// 3. Pull up normalized old LogicalAggregate's output exprs to a new parent LogicalProject Node
// Push down exprs:
// 1. all group by exprs
// 2. child contains subquery expr in trival-agg
// 3. child contains window expr in trival-agg
// 4. all input slots of trival-agg
// 5. expr(including subquery) in distinct trival-agg
// Normalize LogicalAggregate's output.
// 1. normalize group by exprs by outputs of bottom LogicalProject
// 2. normalize trival-aggs by outputs of bottom LogicalProject
// 3. build normalized agg outputs
// Pull up exprs:
// normalize all output exprs in old LogicalAggregate to build a parent project node, typically includes:
// 1. simple slots
// 2. aliases
// a. alias with no aggs child
// b. alias with trival-agg child
// c. alias with window-agg
public List<Rule> buildRules() {
return ImmutableList.of(
logicalHaving(logicalAggregate().whenNot(LogicalAggregate::isNormalized))
.then(having -> normalizeAgg(having.child(), Optional.of(having)))
.toRule(RuleType.NORMALIZE_AGGREGATE),
logicalAggregate().whenNot(LogicalAggregate::isNormalized)
.then(aggregate -> normalizeAgg(aggregate, Optional.empty()))
.toRule(RuleType.NORMALIZE_AGGREGATE));
}
// Push down exprs:
// collect group by exprs
Set<Expression> groupingByExprs =
ImmutableSet.copyOf(aggregate.getGroupByExpressions());
private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having) {
// The LogicalAggregate node may contain window agg functions and usual agg functions
// we call window agg functions as window-agg and usual agg functions as trival-agg for short
// This rule simplify LogicalAggregate node by:
// 1. Push down some exprs from old LogicalAggregate node to a new child LogicalProject Node,
// 2. create a new LogicalAggregate with normalized group by exprs and trival-aggs
// 3. Pull up normalized old LogicalAggregate's output exprs to a new parent LogicalProject Node
// Push down exprs:
// 1. all group by exprs
// 2. child contains subquery expr in trival-agg
// 3. child contains window expr in trival-agg
// 4. all input slots of trival-agg
// 5. expr(including subquery) in distinct trival-agg
// Normalize LogicalAggregate's output.
// 1. normalize group by exprs by outputs of bottom LogicalProject
// 2. normalize trival-aggs by outputs of bottom LogicalProject
// 3. build normalized agg outputs
// Pull up exprs:
// normalize all output exprs in old LogicalAggregate to build a parent project node, typically includes:
// 1. simple slots
// 2. aliases
// a. alias with no aggs child
// b. alias with trival-agg child
// c. alias with window-agg
// collect all trival-agg
List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
List<AggregateFunction> aggFuncs = Lists.newArrayList();
aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
// Push down exprs:
// collect group by exprs
Set<Expression> groupingByExprs =
ImmutableSet.copyOf(aggregate.getGroupByExpressions());
// split non-distinct agg child as two part
// TRUE part 1: need push down itself, if it contains subqury or window expression
// FALSE part 2: need push down its input slots, if it DOES NOT contain subqury or window expression
Map<Boolean, Set<Expression>> categorizedNoDistinctAggsChildren = aggFuncs.stream()
.filter(aggFunc -> !aggFunc.isDistinct())
.flatMap(agg -> agg.children().stream())
.collect(Collectors.groupingBy(
child -> child.containsType(SubqueryExpr.class, WindowExpression.class),
Collectors.toSet()));
// collect all trival-agg
List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
List<AggregateFunction> aggFuncs = Lists.newArrayList();
aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
// split distinct agg child as two parts
// TRUE part 1: need push down itself, if it is NOT SlotReference or Literal
// FALSE part 2: need push down its input slots, if it is SlotReference or Literal
Map<Boolean, Set<Expression>> categorizedDistinctAggsChildren = aggFuncs.stream()
.filter(aggFunc -> aggFunc.isDistinct()).flatMap(agg -> agg.children().stream())
.collect(Collectors.groupingBy(
child -> !(child instanceof SlotReference || child instanceof Literal),
Collectors.toSet()));
// split non-distinct agg child as two part
// TRUE part 1: need push down itself, if it contains subqury or window expression
// FALSE part 2: need push down its input slots, if it DOES NOT contain subqury or window expression
Map<Boolean, Set<Expression>> categorizedNoDistinctAggsChildren = aggFuncs.stream()
.filter(aggFunc -> !aggFunc.isDistinct())
.flatMap(agg -> agg.children().stream())
.collect(Collectors.groupingBy(
child -> child.containsType(SubqueryExpr.class, WindowExpression.class),
Collectors.toSet()));
Set<Expression> needPushSelf = Sets.union(
categorizedNoDistinctAggsChildren.getOrDefault(true, new HashSet<>()),
categorizedDistinctAggsChildren.getOrDefault(true, new HashSet<>()));
Set<Slot> needPushInputSlots = ExpressionUtils.getInputSlotSet(Sets.union(
categorizedNoDistinctAggsChildren.getOrDefault(false, new HashSet<>()),
categorizedDistinctAggsChildren.getOrDefault(false, new HashSet<>())));
// split distinct agg child as two parts
// TRUE part 1: need push down itself, if it is NOT SlotReference or Literal
// FALSE part 2: need push down its input slots, if it is SlotReference or Literal
Map<Boolean, Set<Expression>> categorizedDistinctAggsChildren = aggFuncs.stream()
.filter(aggFunc -> aggFunc.isDistinct()).flatMap(agg -> agg.children().stream())
.collect(Collectors.groupingBy(
child -> !(child instanceof SlotReference || child instanceof Literal),
Collectors.toSet()));
Set<Alias> existsAlias =
ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
Set<Expression> needPushSelf = Sets.union(
categorizedNoDistinctAggsChildren.getOrDefault(true, new HashSet<>()),
categorizedDistinctAggsChildren.getOrDefault(true, new HashSet<>()));
Set<Slot> needPushInputSlots = ExpressionUtils.getInputSlotSet(Sets.union(
categorizedNoDistinctAggsChildren.getOrDefault(false, new HashSet<>()),
categorizedDistinctAggsChildren.getOrDefault(false, new HashSet<>())));
// push down 3 kinds of exprs, these pushed exprs will be used to normalize agg output later
// 1. group by exprs
// 2. trivalAgg children
// 3. trivalAgg input slots
Set<Expression> allPushDownExprs =
Sets.union(groupingByExprs, Sets.union(needPushSelf, needPushInputSlots));
NormalizeToSlotContext bottomSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, allPushDownExprs);
Set<NamedExpression> pushedGroupByExprs =
bottomSlotContext.pushDownToNamedExpression(groupingByExprs);
Set<NamedExpression> pushedTrivalAggChildren =
bottomSlotContext.pushDownToNamedExpression(needPushSelf);
Set<NamedExpression> pushedTrivalAggInputSlots =
bottomSlotContext.pushDownToNamedExpression(needPushInputSlots);
Set<NamedExpression> bottomProjects = Sets.union(pushedGroupByExprs,
Sets.union(pushedTrivalAggChildren, pushedTrivalAggInputSlots));
Set<Alias> existsAlias =
ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
// create bottom project
Plan bottomPlan;
if (!bottomProjects.isEmpty()) {
bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects),
aggregate.child());
// push down 3 kinds of exprs, these pushed exprs will be used to normalize agg output later
// 1. group by exprs
// 2. trivalAgg children
// 3. trivalAgg input slots
Set<Expression> allPushDownExprs =
Sets.union(groupingByExprs, Sets.union(needPushSelf, needPushInputSlots));
NormalizeToSlotContext bottomSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, allPushDownExprs);
Set<NamedExpression> pushedGroupByExprs =
bottomSlotContext.pushDownToNamedExpression(groupingByExprs);
Set<NamedExpression> pushedTrivalAggChildren =
bottomSlotContext.pushDownToNamedExpression(needPushSelf);
Set<NamedExpression> pushedTrivalAggInputSlots =
bottomSlotContext.pushDownToNamedExpression(needPushInputSlots);
Set<NamedExpression> bottomProjects = Sets.union(pushedGroupByExprs,
Sets.union(pushedTrivalAggChildren, pushedTrivalAggInputSlots));
// create bottom project
Plan bottomPlan;
if (!bottomProjects.isEmpty()) {
bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects),
aggregate.child());
} else {
bottomPlan = aggregate.child();
}
// use group by context to normalize agg functions to process
// sql like: select sum(a + 1) from t group by a + 1
//
// before normalize:
// agg(output: sum(a[#0] + 1)[#2], group_by: alias(a + 1)[#1])
// +-- project(a[#0], (a[#0] + 1)[#1])
//
// after normalize:
// agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
// +-- project((a[#0] + 1)[#1])
// normalize group by exprs by bottomProjects
List<Expression> normalizedGroupExprs =
bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
// normalize trival-aggs by bottomProjects
List<AggregateFunction> normalizedAggFuncs =
bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
// build normalized agg output
NormalizeToSlotContext normalizedAggFuncsToSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
// agg output include 2 parts
// pushedGroupByExprs and normalized agg functions
List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(pushedGroupByExprs.stream().map(NamedExpression::toSlot).iterator())
.addAll(normalizedAggFuncsToSlotContext
.pushDownToNamedExpression(normalizedAggFuncs))
.build();
// create new agg node
LogicalAggregate newAggregate =
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);
// create upper projects by normalize all output exprs in old LogicalAggregate
List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
bottomSlotContext, normalizedAggFuncsToSlotContext);
// create a parent project node
LogicalProject<Plan> project = new LogicalProject<>(upperProjects, newAggregate);
if (having.isPresent()) {
if (upperProjects.stream().anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance))) {
// when project contains window functions, in order to get the correct result
// push having through project to make it the parent node of logicalAgg
return project.withChildren(ImmutableList.of(
new LogicalHaving<>(
ExpressionUtils.replace(having.get().getConjuncts(), project.getAliasToProducer()),
project.child()
)));
} else {
bottomPlan = aggregate.child();
return (LogicalPlan) having.get().withChildren(project);
}
// use group by context to normalize agg functions to process
// sql like: select sum(a + 1) from t group by a + 1
//
// before normalize:
// agg(output: sum(a[#0] + 1)[#2], group_by: alias(a + 1)[#1])
// +-- project(a[#0], (a[#0] + 1)[#1])
//
// after normalize:
// agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
// +-- project((a[#0] + 1)[#1])
// normalize group by exprs by bottomProjects
List<Expression> normalizedGroupExprs =
bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
// normalize trival-aggs by bottomProjects
List<AggregateFunction> normalizedAggFuncs =
bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
// build normalized agg output
NormalizeToSlotContext normalizedAggFuncsToSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
// agg output include 2 parts
// pushedGroupByExprs and normalized agg functions
List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(pushedGroupByExprs.stream().map(NamedExpression::toSlot).iterator())
.addAll(normalizedAggFuncsToSlotContext
.pushDownToNamedExpression(normalizedAggFuncs))
.build();
// create new agg node
LogicalAggregate newAggregate =
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);
// create upper projects by normalize all output exprs in old LogicalAggregate
List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
bottomSlotContext, normalizedAggFuncsToSlotContext);
// create a parent project node
return new LogicalProject<>(upperProjects, newAggregate);
}).toRule(RuleType.NORMALIZE_AGGREGATE);
} else {
return project;
}
}
private List<NamedExpression> normalizeOutput(List<NamedExpression> aggregateOutput,