|
|
|
|
@ -21,7 +21,7 @@ import org.apache.doris.nereids.rules.Rule;
|
|
|
|
|
import org.apache.doris.nereids.rules.RuleType;
|
|
|
|
|
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
|
|
|
|
|
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
|
|
|
|
|
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
|
|
|
|
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
|
|
|
|
|
import org.apache.doris.nereids.trees.expressions.Alias;
|
|
|
|
|
import org.apache.doris.nereids.trees.expressions.Expression;
|
|
|
|
|
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
|
|
|
|
@ -34,6 +34,8 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
|
|
|
|
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.Plan;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
|
|
|
|
import org.apache.doris.nereids.util.ExpressionUtils;
|
|
|
|
|
|
|
|
|
|
@ -46,6 +48,7 @@ import com.google.common.collect.Sets;
|
|
|
|
|
import java.util.HashSet;
|
|
|
|
|
import java.util.List;
|
|
|
|
|
import java.util.Map;
|
|
|
|
|
import java.util.Optional;
|
|
|
|
|
import java.util.Set;
|
|
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
|
|
|
|
@ -97,141 +100,164 @@ import java.util.stream.Collectors;
|
|
|
|
|
* </pre>
|
|
|
|
|
* More examples can be found in UT {NormalizeAggregateTest}
|
|
|
|
|
*/
|
|
|
|
|
public class NormalizeAggregate extends OneRewriteRuleFactory implements NormalizeToSlot {
|
|
|
|
|
public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
|
|
|
|
|
@Override
|
|
|
|
|
public Rule build() {
|
|
|
|
|
return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
|
|
|
|
|
// The LogicalAggregate node may contain window agg functions and usual agg functions
|
|
|
|
|
// we call window agg functions as window-agg and usual agg functions as trival-agg for short
|
|
|
|
|
// This rule simplify LogicalAggregate node by:
|
|
|
|
|
// 1. Push down some exprs from old LogicalAggregate node to a new child LogicalProject Node,
|
|
|
|
|
// 2. create a new LogicalAggregate with normalized group by exprs and trival-aggs
|
|
|
|
|
// 3. Pull up normalized old LogicalAggregate's output exprs to a new parent LogicalProject Node
|
|
|
|
|
// Push down exprs:
|
|
|
|
|
// 1. all group by exprs
|
|
|
|
|
// 2. child contains subquery expr in trival-agg
|
|
|
|
|
// 3. child contains window expr in trival-agg
|
|
|
|
|
// 4. all input slots of trival-agg
|
|
|
|
|
// 5. expr(including subquery) in distinct trival-agg
|
|
|
|
|
// Normalize LogicalAggregate's output.
|
|
|
|
|
// 1. normalize group by exprs by outputs of bottom LogicalProject
|
|
|
|
|
// 2. normalize trival-aggs by outputs of bottom LogicalProject
|
|
|
|
|
// 3. build normalized agg outputs
|
|
|
|
|
// Pull up exprs:
|
|
|
|
|
// normalize all output exprs in old LogicalAggregate to build a parent project node, typically includes:
|
|
|
|
|
// 1. simple slots
|
|
|
|
|
// 2. aliases
|
|
|
|
|
// a. alias with no aggs child
|
|
|
|
|
// b. alias with trival-agg child
|
|
|
|
|
// c. alias with window-agg
|
|
|
|
|
public List<Rule> buildRules() {
// Produces two rules, both registered under RuleType.NORMALIZE_AGGREGATE.
// Each one only matches an aggregate for which LogicalAggregate::isNormalized is false,
// and each one delegates the actual rewrite to normalizeAgg.
|
|
|
|
|
return ImmutableList.of(
|
|
|
|
|
// Rule 1: a LogicalHaving whose child is a not-yet-normalized LogicalAggregate.
logicalHaving(logicalAggregate().whenNot(LogicalAggregate::isNormalized))
|
|
|
|
|
// The having node is handed over as well, so normalizeAgg can decide where to place it
// in the rewritten plan.
.then(having -> normalizeAgg(having.child(), Optional.of(having)))
|
|
|
|
|
.toRule(RuleType.NORMALIZE_AGGREGATE),
|
|
|
|
|
// Rule 2: a not-yet-normalized LogicalAggregate matched on its own.
logicalAggregate().whenNot(LogicalAggregate::isNormalized)
|
|
|
|
|
// No having node is involved here, so an empty Optional is passed.
.then(aggregate -> normalizeAgg(aggregate, Optional.empty()))
|
|
|
|
|
.toRule(RuleType.NORMALIZE_AGGREGATE));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Push down exprs:
|
|
|
|
|
// collect group by exprs
|
|
|
|
|
Set<Expression> groupingByExprs =
|
|
|
|
|
ImmutableSet.copyOf(aggregate.getGroupByExpressions());
|
|
|
|
|
private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having) {
|
|
|
|
|
// The LogicalAggregate node may contain window agg functions and usual agg functions
|
|
|
|
|
// we call window agg functions as window-agg and usual agg functions as trival-agg for short
|
|
|
|
|
// This rule simplify LogicalAggregate node by:
|
|
|
|
|
// 1. Push down some exprs from old LogicalAggregate node to a new child LogicalProject Node,
|
|
|
|
|
// 2. create a new LogicalAggregate with normalized group by exprs and trival-aggs
|
|
|
|
|
// 3. Pull up normalized old LogicalAggregate's output exprs to a new parent LogicalProject Node
|
|
|
|
|
// Push down exprs:
|
|
|
|
|
// 1. all group by exprs
|
|
|
|
|
// 2. child contains subquery expr in trival-agg
|
|
|
|
|
// 3. child contains window expr in trival-agg
|
|
|
|
|
// 4. all input slots of trival-agg
|
|
|
|
|
// 5. expr(including subquery) in distinct trival-agg
|
|
|
|
|
// Normalize LogicalAggregate's output.
|
|
|
|
|
// 1. normalize group by exprs by outputs of bottom LogicalProject
|
|
|
|
|
// 2. normalize trival-aggs by outputs of bottom LogicalProject
|
|
|
|
|
// 3. build normalized agg outputs
|
|
|
|
|
// Pull up exprs:
|
|
|
|
|
// normalize all output exprs in old LogicalAggregate to build a parent project node, typically includes:
|
|
|
|
|
// 1. simple slots
|
|
|
|
|
// 2. aliases
|
|
|
|
|
// a. alias with no aggs child
|
|
|
|
|
// b. alias with trival-agg child
|
|
|
|
|
// c. alias with window-agg
|
|
|
|
|
|
|
|
|
|
// collect all trival-agg
|
|
|
|
|
List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
|
|
|
|
|
List<AggregateFunction> aggFuncs = Lists.newArrayList();
|
|
|
|
|
aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
|
|
|
|
|
// Push down exprs:
|
|
|
|
|
// collect group by exprs
|
|
|
|
|
Set<Expression> groupingByExprs =
|
|
|
|
|
ImmutableSet.copyOf(aggregate.getGroupByExpressions());
|
|
|
|
|
|
|
|
|
|
// split non-distinct agg child as two part
|
|
|
|
|
// TRUE part 1: need push down itself, if it contains subquery or window expression
|
|
|
|
|
// FALSE part 2: need push down its input slots, if it DOES NOT contain subquery or window expression
|
|
|
|
|
Map<Boolean, Set<Expression>> categorizedNoDistinctAggsChildren = aggFuncs.stream()
|
|
|
|
|
.filter(aggFunc -> !aggFunc.isDistinct())
|
|
|
|
|
.flatMap(agg -> agg.children().stream())
|
|
|
|
|
.collect(Collectors.groupingBy(
|
|
|
|
|
child -> child.containsType(SubqueryExpr.class, WindowExpression.class),
|
|
|
|
|
Collectors.toSet()));
|
|
|
|
|
// collect all trival-agg
|
|
|
|
|
List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
|
|
|
|
|
List<AggregateFunction> aggFuncs = Lists.newArrayList();
|
|
|
|
|
aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
|
|
|
|
|
|
|
|
|
|
// split distinct agg child as two parts
|
|
|
|
|
// TRUE part 1: need push down itself, if it is NOT SlotReference or Literal
|
|
|
|
|
// FALSE part 2: need push down its input slots, if it is SlotReference or Literal
|
|
|
|
|
Map<Boolean, Set<Expression>> categorizedDistinctAggsChildren = aggFuncs.stream()
|
|
|
|
|
.filter(aggFunc -> aggFunc.isDistinct()).flatMap(agg -> agg.children().stream())
|
|
|
|
|
.collect(Collectors.groupingBy(
|
|
|
|
|
child -> !(child instanceof SlotReference || child instanceof Literal),
|
|
|
|
|
Collectors.toSet()));
|
|
|
|
|
// split non-distinct agg child as two part
|
|
|
|
|
// TRUE part 1: need push down itself, if it contains subquery or window expression
|
|
|
|
|
// FALSE part 2: need push down its input slots, if it DOES NOT contain subquery or window expression
|
|
|
|
|
Map<Boolean, Set<Expression>> categorizedNoDistinctAggsChildren = aggFuncs.stream()
|
|
|
|
|
.filter(aggFunc -> !aggFunc.isDistinct())
|
|
|
|
|
.flatMap(agg -> agg.children().stream())
|
|
|
|
|
.collect(Collectors.groupingBy(
|
|
|
|
|
child -> child.containsType(SubqueryExpr.class, WindowExpression.class),
|
|
|
|
|
Collectors.toSet()));
|
|
|
|
|
|
|
|
|
|
Set<Expression> needPushSelf = Sets.union(
|
|
|
|
|
categorizedNoDistinctAggsChildren.getOrDefault(true, new HashSet<>()),
|
|
|
|
|
categorizedDistinctAggsChildren.getOrDefault(true, new HashSet<>()));
|
|
|
|
|
Set<Slot> needPushInputSlots = ExpressionUtils.getInputSlotSet(Sets.union(
|
|
|
|
|
categorizedNoDistinctAggsChildren.getOrDefault(false, new HashSet<>()),
|
|
|
|
|
categorizedDistinctAggsChildren.getOrDefault(false, new HashSet<>())));
|
|
|
|
|
// split distinct agg child as two parts
|
|
|
|
|
// TRUE part 1: need push down itself, if it is NOT SlotReference or Literal
|
|
|
|
|
// FALSE part 2: need push down its input slots, if it is SlotReference or Literal
|
|
|
|
|
Map<Boolean, Set<Expression>> categorizedDistinctAggsChildren = aggFuncs.stream()
|
|
|
|
|
.filter(aggFunc -> aggFunc.isDistinct()).flatMap(agg -> agg.children().stream())
|
|
|
|
|
.collect(Collectors.groupingBy(
|
|
|
|
|
child -> !(child instanceof SlotReference || child instanceof Literal),
|
|
|
|
|
Collectors.toSet()));
|
|
|
|
|
|
|
|
|
|
Set<Alias> existsAlias =
|
|
|
|
|
ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
|
|
|
|
|
Set<Expression> needPushSelf = Sets.union(
|
|
|
|
|
categorizedNoDistinctAggsChildren.getOrDefault(true, new HashSet<>()),
|
|
|
|
|
categorizedDistinctAggsChildren.getOrDefault(true, new HashSet<>()));
|
|
|
|
|
Set<Slot> needPushInputSlots = ExpressionUtils.getInputSlotSet(Sets.union(
|
|
|
|
|
categorizedNoDistinctAggsChildren.getOrDefault(false, new HashSet<>()),
|
|
|
|
|
categorizedDistinctAggsChildren.getOrDefault(false, new HashSet<>())));
|
|
|
|
|
|
|
|
|
|
// push down 3 kinds of exprs, these pushed exprs will be used to normalize agg output later
|
|
|
|
|
// 1. group by exprs
|
|
|
|
|
// 2. trivalAgg children
|
|
|
|
|
// 3. trivalAgg input slots
|
|
|
|
|
Set<Expression> allPushDownExprs =
|
|
|
|
|
Sets.union(groupingByExprs, Sets.union(needPushSelf, needPushInputSlots));
|
|
|
|
|
NormalizeToSlotContext bottomSlotContext =
|
|
|
|
|
NormalizeToSlotContext.buildContext(existsAlias, allPushDownExprs);
|
|
|
|
|
Set<NamedExpression> pushedGroupByExprs =
|
|
|
|
|
bottomSlotContext.pushDownToNamedExpression(groupingByExprs);
|
|
|
|
|
Set<NamedExpression> pushedTrivalAggChildren =
|
|
|
|
|
bottomSlotContext.pushDownToNamedExpression(needPushSelf);
|
|
|
|
|
Set<NamedExpression> pushedTrivalAggInputSlots =
|
|
|
|
|
bottomSlotContext.pushDownToNamedExpression(needPushInputSlots);
|
|
|
|
|
Set<NamedExpression> bottomProjects = Sets.union(pushedGroupByExprs,
|
|
|
|
|
Sets.union(pushedTrivalAggChildren, pushedTrivalAggInputSlots));
|
|
|
|
|
Set<Alias> existsAlias =
|
|
|
|
|
ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
|
|
|
|
|
|
|
|
|
|
// create bottom project
|
|
|
|
|
Plan bottomPlan;
|
|
|
|
|
if (!bottomProjects.isEmpty()) {
|
|
|
|
|
bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects),
|
|
|
|
|
aggregate.child());
|
|
|
|
|
// push down 3 kinds of exprs, these pushed exprs will be used to normalize agg output later
|
|
|
|
|
// 1. group by exprs
|
|
|
|
|
// 2. trivalAgg children
|
|
|
|
|
// 3. trivalAgg input slots
|
|
|
|
|
Set<Expression> allPushDownExprs =
|
|
|
|
|
Sets.union(groupingByExprs, Sets.union(needPushSelf, needPushInputSlots));
|
|
|
|
|
NormalizeToSlotContext bottomSlotContext =
|
|
|
|
|
NormalizeToSlotContext.buildContext(existsAlias, allPushDownExprs);
|
|
|
|
|
Set<NamedExpression> pushedGroupByExprs =
|
|
|
|
|
bottomSlotContext.pushDownToNamedExpression(groupingByExprs);
|
|
|
|
|
Set<NamedExpression> pushedTrivalAggChildren =
|
|
|
|
|
bottomSlotContext.pushDownToNamedExpression(needPushSelf);
|
|
|
|
|
Set<NamedExpression> pushedTrivalAggInputSlots =
|
|
|
|
|
bottomSlotContext.pushDownToNamedExpression(needPushInputSlots);
|
|
|
|
|
Set<NamedExpression> bottomProjects = Sets.union(pushedGroupByExprs,
|
|
|
|
|
Sets.union(pushedTrivalAggChildren, pushedTrivalAggInputSlots));
|
|
|
|
|
|
|
|
|
|
// create bottom project
|
|
|
|
|
Plan bottomPlan;
|
|
|
|
|
if (!bottomProjects.isEmpty()) {
|
|
|
|
|
bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects),
|
|
|
|
|
aggregate.child());
|
|
|
|
|
} else {
|
|
|
|
|
bottomPlan = aggregate.child();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// use group by context to normalize agg functions to process
|
|
|
|
|
// sql like: select sum(a + 1) from t group by a + 1
|
|
|
|
|
//
|
|
|
|
|
// before normalize:
|
|
|
|
|
// agg(output: sum(a[#0] + 1)[#2], group_by: alias(a + 1)[#1])
|
|
|
|
|
// +-- project(a[#0], (a[#0] + 1)[#1])
|
|
|
|
|
//
|
|
|
|
|
// after normalize:
|
|
|
|
|
// agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
|
|
|
|
|
// +-- project((a[#0] + 1)[#1])
|
|
|
|
|
|
|
|
|
|
// normalize group by exprs by bottomProjects
|
|
|
|
|
List<Expression> normalizedGroupExprs =
|
|
|
|
|
bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
|
|
|
|
|
|
|
|
|
|
// normalize trival-aggs by bottomProjects
|
|
|
|
|
List<AggregateFunction> normalizedAggFuncs =
|
|
|
|
|
bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
|
|
|
|
|
|
|
|
|
|
// build normalized agg output
|
|
|
|
|
NormalizeToSlotContext normalizedAggFuncsToSlotContext =
|
|
|
|
|
NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
|
|
|
|
|
|
|
|
|
|
// agg output include 2 parts
|
|
|
|
|
// pushedGroupByExprs and normalized agg functions
|
|
|
|
|
List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
|
|
|
|
|
.addAll(pushedGroupByExprs.stream().map(NamedExpression::toSlot).iterator())
|
|
|
|
|
.addAll(normalizedAggFuncsToSlotContext
|
|
|
|
|
.pushDownToNamedExpression(normalizedAggFuncs))
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
// create new agg node
|
|
|
|
|
LogicalAggregate newAggregate =
|
|
|
|
|
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);
|
|
|
|
|
|
|
|
|
|
// create upper projects by normalize all output exprs in old LogicalAggregate
|
|
|
|
|
List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
|
|
|
|
|
bottomSlotContext, normalizedAggFuncsToSlotContext);
|
|
|
|
|
|
|
|
|
|
// create a parent project node
|
|
|
|
|
LogicalProject<Plan> project = new LogicalProject<>(upperProjects, newAggregate);
|
|
|
|
|
if (having.isPresent()) {
|
|
|
|
|
if (upperProjects.stream().anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance))) {
|
|
|
|
|
// when project contains window functions, in order to get the correct result
|
|
|
|
|
// push having through project to make it the parent node of logicalAgg
|
|
|
|
|
return project.withChildren(ImmutableList.of(
|
|
|
|
|
new LogicalHaving<>(
|
|
|
|
|
ExpressionUtils.replace(having.get().getConjuncts(), project.getAliasToProducer()),
|
|
|
|
|
project.child()
|
|
|
|
|
)));
|
|
|
|
|
} else {
|
|
|
|
|
bottomPlan = aggregate.child();
|
|
|
|
|
return (LogicalPlan) having.get().withChildren(project);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// use group by context to normalize agg functions to process
|
|
|
|
|
// sql like: select sum(a + 1) from t group by a + 1
|
|
|
|
|
//
|
|
|
|
|
// before normalize:
|
|
|
|
|
// agg(output: sum(a[#0] + 1)[#2], group_by: alias(a + 1)[#1])
|
|
|
|
|
// +-- project(a[#0], (a[#0] + 1)[#1])
|
|
|
|
|
//
|
|
|
|
|
// after normalize:
|
|
|
|
|
// agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
|
|
|
|
|
// +-- project((a[#0] + 1)[#1])
|
|
|
|
|
|
|
|
|
|
// normalize group by exprs by bottomProjects
|
|
|
|
|
List<Expression> normalizedGroupExprs =
|
|
|
|
|
bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
|
|
|
|
|
|
|
|
|
|
// normalize trival-aggs by bottomProjects
|
|
|
|
|
List<AggregateFunction> normalizedAggFuncs =
|
|
|
|
|
bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
|
|
|
|
|
|
|
|
|
|
// build normalized agg output
|
|
|
|
|
NormalizeToSlotContext normalizedAggFuncsToSlotContext =
|
|
|
|
|
NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
|
|
|
|
|
|
|
|
|
|
// agg output include 2 parts
|
|
|
|
|
// pushedGroupByExprs and normalized agg functions
|
|
|
|
|
List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
|
|
|
|
|
.addAll(pushedGroupByExprs.stream().map(NamedExpression::toSlot).iterator())
|
|
|
|
|
.addAll(normalizedAggFuncsToSlotContext
|
|
|
|
|
.pushDownToNamedExpression(normalizedAggFuncs))
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
// create new agg node
|
|
|
|
|
LogicalAggregate newAggregate =
|
|
|
|
|
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);
|
|
|
|
|
|
|
|
|
|
// create upper projects by normalize all output exprs in old LogicalAggregate
|
|
|
|
|
List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
|
|
|
|
|
bottomSlotContext, normalizedAggFuncsToSlotContext);
|
|
|
|
|
|
|
|
|
|
// create a parent project node
|
|
|
|
|
return new LogicalProject<>(upperProjects, newAggregate);
|
|
|
|
|
}).toRule(RuleType.NORMALIZE_AGGREGATE);
|
|
|
|
|
} else {
|
|
|
|
|
return project;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private List<NamedExpression> normalizeOutput(List<NamedExpression> aggregateOutput,
|
|
|
|
|
|