[feat](Nereids): Simplify Agg GroupBy (#18887)
This commit is contained in:
@ -72,6 +72,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinAggTransposeProjec
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinCommute;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinLogicalJoinTranspose;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinLogicalJoinTransposeProject;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.SimplifyAggGroupBy;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.SplitLimit;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -155,6 +156,9 @@ public class NereidsRewriter extends BatchRewriteJob {
|
||||
// e.g. sum(sum(c1)) over(partition by avg(c1))
|
||||
new NormalizeAggregate(),
|
||||
new CheckAndStandardizeWindowFunctionAndFrame()
|
||||
),
|
||||
topDown(
|
||||
new SimplifyAggGroupBy()
|
||||
)
|
||||
),
|
||||
|
||||
|
||||
@ -95,6 +95,7 @@ public enum RuleType {
|
||||
EXTRACT_AND_NORMALIZE_WINDOW_EXPRESSIONS(RuleTypeClass.REWRITE),
|
||||
CHECK_AND_STANDARDIZE_WINDOW_FUNCTION_AND_FRAME(RuleTypeClass.REWRITE),
|
||||
AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
|
||||
SIMPLIFY_AGG_GROUP_BY(RuleTypeClass.REWRITE),
|
||||
DISTINCT_AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
|
||||
LOGICAL_SUB_QUERY_ALIAS_TO_LOGICAL_PROJECT(RuleTypeClass.REWRITE),
|
||||
ELIMINATE_GROUP_BY_CONSTANT(RuleTypeClass.REWRITE),
|
||||
|
||||
@ -0,0 +1,68 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Simplify Aggregate group by Multiple to One. For example
|
||||
* <p>
|
||||
* GROUP BY ClientIP, ClientIP - 1, ClientIP - 2, ClientIP - 3
|
||||
* -->
|
||||
* GROUP BY ClientIP
|
||||
*/
|
||||
public class SimplifyAggGroupBy extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalAggregate()
|
||||
.when(agg -> agg.getGroupByExpressions().size() > 1)
|
||||
.when(agg -> agg.getGroupByExpressions().stream().allMatch(this::isBinaryArithmeticSlot))
|
||||
.then(agg -> {
|
||||
Set<Expression> slots = agg.getGroupByExpressions().stream()
|
||||
.flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toSet());
|
||||
if (slots.size() != 1) {
|
||||
return null;
|
||||
}
|
||||
return agg.withGroupByAndOutput(ImmutableList.copyOf(slots), agg.getOutputExpressions());
|
||||
})
|
||||
.toRule(RuleType.SIMPLIFY_AGG_GROUP_BY);
|
||||
}
|
||||
|
||||
private boolean isBinaryArithmeticSlot(Expression expr) {
|
||||
if (expr instanceof Slot) {
|
||||
return true;
|
||||
}
|
||||
if (!(expr instanceof BinaryArithmetic)) {
|
||||
return false;
|
||||
}
|
||||
return ExpressionUtils.isSlotOrCastOnSlot(expr.child(0)).isPresent() && expr.child(1) instanceof Literal
|
||||
|| ExpressionUtils.isSlotOrCastOnSlot(expr.child(1)).isPresent() && expr.child(0) instanceof Literal;
|
||||
}
|
||||
}
|
||||
@ -60,6 +60,10 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
|
||||
super(Optional.empty(), children);
|
||||
}
|
||||
|
||||
public Alias alias(String alias) {
|
||||
return new Alias(this, alias);
|
||||
}
|
||||
|
||||
/**
|
||||
* check input data types
|
||||
*/
|
||||
|
||||
@ -111,12 +111,6 @@ public class ExpressionUtils {
|
||||
}
|
||||
}
|
||||
|
||||
public static Set<Expression> extractToSet(Expression predicate) {
|
||||
Set<Expression> result = Sets.newHashSet();
|
||||
extract(predicate.getClass(), predicate, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
public static Optional<Expression> optionalAnd(List<Expression> expressions) {
|
||||
if (expressions.isEmpty()) {
|
||||
return Optional.empty();
|
||||
|
||||
Reference in New Issue
Block a user