[fix](nereids) fix some arrgregate bugs in Nereids (#15326)

1. the agg function without distinct keyword should be a "merge" funcion in threePhaseAggregateWithDistinct
2. use aggregateParam.aggMode.consumeAggregateBuffer instead of aggregateParam.aggPhase.isGlobal() to indicate if a agg function is a "merge" function
3. add an AvgDistinctToSumDivCount rule to support avg(distinct xxx) in some case
4. AggregateExpression's nullable method should call inner function's nullable method.
5. add a bind slot rule to bind pattern "logicalSort(logicalHaving(logicalProject()))"
6. don't remove project node in PhysicalPlanTranslator
7. add a cast to bigint expr when count( distinct datelike type )
8. fallback to old optimizer if bitmap runtime filter is enabled.
9. fix exchange node mem leak
This commit is contained in:
starocean999
2022-12-30 23:07:37 +08:00
committed by GitHub
parent cc7a9d92ad
commit 100834df8b
20 changed files with 202 additions and 57 deletions

View File

@ -444,8 +444,7 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
true, true, nullableMode
);
boolean isMergeFn = aggregateParam.aggPhase.isGlobal();
boolean isMergeFn = aggregateParam.aggMode.consumeAggregateBuffer;
// create catalog FunctionCallExpr without analyze again
return new FunctionCallExpr(catalogFunction, fnParams, aggFnParams, isMergeFn, catalogArguments);
}

View File

@ -127,7 +127,6 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.commons.collections.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -184,13 +183,6 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
rootFragment = exchangeToMergeFragment(rootFragment, context);
}
List<Expr> outputExprs = Lists.newArrayList();
if (physicalPlan instanceof PhysicalProject) {
PhysicalProject project = (PhysicalProject) physicalPlan;
if (isUnnecessaryProject(project) && !projectOnAgg(project)) {
List<Slot> slotReferences = removeAlias(project);
physicalPlan = (PhysicalPlan) physicalPlan.child(0).withOutput(slotReferences);
}
}
physicalPlan.getOutput().stream().map(Slot::getExprId)
.forEach(exprId -> outputExprs.add(context.findSlotRef(exprId)));
rootFragment.setOutputExprs(outputExprs);
@ -1079,7 +1071,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
for (Expr expr : predicateList) {
extractExecSlot(expr, requiredSlotIdList);
}
boolean nonPredicate = CollectionUtils.isEmpty(requiredSlotIdList);
for (Expr expr : execExprList) {
extractExecSlot(expr, requiredSlotIdList);
}
@ -1087,21 +1079,11 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
TableFunctionNode tableFunctionNode = (TableFunctionNode) inputPlanNode;
tableFunctionNode.setOutputSlotIds(Lists.newArrayList(requiredSlotIdList));
}
if (!hasExprCalc(project) && (!hasPrune(project) || nonPredicate) && !projectOnAgg(project)) {
List<NamedExpression> namedExpressions = project.getProjects();
for (int i = 0; i < namedExpressions.size(); i++) {
NamedExpression n = namedExpressions.get(i);
for (Expression e : n.children()) {
SlotReference slotReference = (SlotReference) e;
SlotRef slotRef = context.findSlotRef(slotReference.getExprId());
context.addExprIdSlotRefPair(slotList.get(i).getExprId(), slotRef);
}
}
} else {
TupleDescriptor tupleDescriptor = generateTupleDesc(slotList, null, context);
inputPlanNode.setProjectList(execExprList);
inputPlanNode.setOutputTupleDesc(tupleDescriptor);
}
TupleDescriptor tupleDescriptor = generateTupleDesc(slotList, null, context);
inputPlanNode.setProjectList(execExprList);
inputPlanNode.setOutputTupleDesc(tupleDescriptor);
if (inputPlanNode instanceof OlapScanNode) {
updateChildSlotsMaterialization(inputPlanNode, requiredSlotIdList, context);
return inputFragment;

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
import org.apache.doris.nereids.rules.analysis.BindFunction;
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.BindSlotReference;
@ -69,6 +70,7 @@ public class AnalyzeRulesJob extends BatchRulesJob {
// should make sure isDisinct property is correctly passed around.
// please see rule BindSlotReference or BindFunction for example
new ProjectWithDistinctToAggregate(),
new AvgDistinctToSumDivCount(),
new ResolveOrdinalInOrderByAndGroupBy(),
new ReplaceExpressionByChildOutput(),
new HideOneRowRelationUnderUnion(),

View File

@ -68,6 +68,7 @@ public enum RuleType {
RESOLVE_AGGREGATE_ALIAS(RuleTypeClass.REWRITE),
PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE),
PROJECT_WITH_DISTINCT_TO_AGGREGATE(RuleTypeClass.REWRITE),
AVG_DISTINCT_TO_SUM_DIV_COUNT(RuleTypeClass.REWRITE),
REGISTER_CTE(RuleTypeClass.REWRITE),
RELATION_AUTHENTICATION(RuleTypeClass.VALIDATION),

View File

@ -0,0 +1,70 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* AvgDistinctToSumDivCount.
*
* change avg( distinct a ) into sum( distinct a ) / count( distinct a ) if there are more than 1 distinct arguments
*/
public class AvgDistinctToSumDivCount extends OneAnalysisRuleFactory {
@Override
public Rule build() {
return RuleType.AVG_DISTINCT_TO_SUM_DIV_COUNT.build(
logicalAggregate().when(agg -> agg.getDistinctArguments().size() > 1).then(agg -> {
Map<AggregateFunction, Expression> avgToSumDivCount = agg.getAggregateFunctions()
.stream()
.filter(function -> function instanceof Avg && function.isDistinct())
.collect(ImmutableMap.toImmutableMap(function -> function, function -> {
Sum sum = new Sum(true, ((Avg) function).child());
Count count = new Count(true, ((Avg) function).child());
Divide divide = new Divide(sum, count);
return divide;
}));
if (!avgToSumDivCount.isEmpty()) {
List<NamedExpression> newOutput = agg.getOutputExpressions().stream()
.map(expr -> (NamedExpression) ExpressionUtils.replace(expr, avgToSumDivCount))
.collect(Collectors.toList());
return new LogicalAggregate<>(agg.getGroupByExpressions(), newOutput,
agg.child());
} else {
return agg;
}
})
);
}
}

View File

@ -278,6 +278,22 @@ public class BindSlotReference implements AnalysisRuleFactory {
return bindSortWithAggregateFunction(sort, aggregate, ctx.cascadesContext);
})
),
RuleType.BINDING_SORT_SLOT.build(
logicalSort(logicalHaving(logicalProject())).when(Plan::canBind).thenApply(ctx -> {
LogicalSort<LogicalHaving<LogicalProject<GroupPlan>>> sort = ctx.root;
List<OrderKey> sortItemList = sort.getOrderKeys()
.stream()
.map(orderKey -> {
Expression item = bind(orderKey.getExpr(), sort.children(), sort, ctx.cascadesContext);
if (item.containsType(UnboundSlot.class)) {
item = bind(item, sort.child().children(), sort, ctx.cascadesContext);
}
return new OrderKey(item, orderKey.isAsc(), orderKey.isNullFirst());
}).collect(Collectors.toList());
return new LogicalSort<>(sortItemList, sort.child());
})
),
RuleType.BINDING_SORT_SLOT.build(
logicalSort(logicalProject()).when(Plan::canBind).thenApply(ctx -> {
LogicalSort<LogicalProject<GroupPlan>> sort = ctx.root;

View File

@ -787,7 +787,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
false, Optional.empty(), logicalAgg.getLogicalProperties(),
requireGather, logicalAgg.child());
AggregateParam inputToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.INPUT_TO_RESULT);
AggregateParam bufferToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT);
List<NamedExpression> globalOutput = ExpressionUtils.rewriteDownShortCircuit(
logicalAgg.getOutputExpressions(), outputChild -> {
if (outputChild instanceof AggregateFunction) {
@ -800,7 +800,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
} else {
Alias alias = nonDistinctAggFunctionToAliasPhase1.get(outputChild);
return new AggregateExpression(
aggregateFunction, inputToResultParam, alias.toSlot());
aggregateFunction, bufferToResultParam, alias.toSlot());
}
} else {
return outputChild;
@ -809,7 +809,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
PhysicalHashAggregate<Plan> gatherLocalGatherGlobalAgg
= new PhysicalHashAggregate<>(logicalAgg.getGroupByExpressions(), globalOutput,
Optional.empty(), inputToResultParam, false,
Optional.empty(), bufferToResultParam, false,
logicalAgg.getLogicalProperties(), requireGather, gatherLocalAgg);
if (logicalAgg.getGroupByExpressions().isEmpty()) {
@ -949,7 +949,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
bufferToResultParam, aggregateFunction.child(0));
} else {
Alias alias = nonDistinctAggFunctionToAliasPhase2.get(expr);
return new AggregateExpression(aggregateFunction, bufferToResultParam, alias.toSlot());
return new AggregateExpression(aggregateFunction,
new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_RESULT),
alias.toSlot());
}
}
return expr;

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
@ -27,10 +28,13 @@ import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.types.BitmapType;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.Lists;
import java.util.List;
/**
* Convert InApply to LogicalJoin.
* <p>
@ -52,14 +56,20 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
apply.right().getOutput().get(0));
}
//TODO nereids should support bitmap runtime filter in future
List<Expression> conjuncts = ExpressionUtils.extractConjunction(predicate);
if (conjuncts.stream().anyMatch(expression -> expression.children().stream()
.anyMatch(expr -> expr.getDataType() == BitmapType.INSTANCE))) {
throw new AnalysisException("nereids don't support bitmap runtime filter");
}
if (((InSubquery) apply.getSubqueryExpr()).isNot()) {
return new LogicalJoin<>(JoinType.NULL_AWARE_LEFT_ANTI_JOIN, Lists.newArrayList(),
ExpressionUtils.extractConjunction(predicate),
conjuncts,
JoinHint.NONE,
apply.left(), apply.right());
} else {
return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(),
ExpressionUtils.extractConjunction(predicate),
conjuncts,
JoinHint.NONE,
apply.left(), apply.right());
}

View File

@ -17,7 +17,6 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
@ -38,7 +37,7 @@ import java.util.Objects;
* so the aggregate function don't need to care about the phase of
* aggregate.
*/
public class AggregateExpression extends Expression implements UnaryExpression, PropagateNullable {
public class AggregateExpression extends Expression implements UnaryExpression {
private final AggregateFunction function;
private final AggregateParam aggregateParam;
@ -143,4 +142,9 @@ public class AggregateExpression extends Expression implements UnaryExpression,
public int hashCode() {
return Objects.hash(super.hashCode(), function, aggregateParam, child());
}
@Override
public boolean nullable() {
return function.nullable();
}
}

View File

@ -18,28 +18,38 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Collectors;
/** MultiDistinctCount */
public class MultiDistinctCount extends AggregateFunction
implements AlwaysNotNullable, ExplicitlyCastableSignature {
// MultiDistinctCount is created in AggregateStrategies phase
// can't change getSignatures to use type coercion rule to add a cast expr
// because AggregateStrategies phase is after type coercion
public MultiDistinctCount(Expression arg0, Expression... varArgs) {
super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs));
super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs).stream()
.map(arg -> arg.getDataType() instanceof DateLikeType ? new Cast(arg, BigIntType.INSTANCE) : arg)
.collect(Collectors.toList()));
}
public MultiDistinctCount(boolean isDistinct, Expression arg0, Expression... varArgs) {
super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs));
super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs).stream()
.map(arg -> arg.getDataType() instanceof DateLikeType ? new Cast(arg, BigIntType.INSTANCE) : arg)
.collect(Collectors.toList()));
}
@Override

View File

@ -265,7 +265,7 @@ public class AggregateStrategiesTest implements PatternMatchSupported {
// id
AggregateParam phaseTwoCountAggParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT);
AggregateParam phaseOneSumAggParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER);
AggregateParam phaseTwoSumAggParam = new AggregateParam(AggPhase.GLOBAL, AggMode.INPUT_TO_RESULT);
AggregateParam phaseTwoSumAggParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT);
// sum
Sum sumId = new Sum(false, id.toSlot());

View File

@ -26,20 +26,15 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Set;
/**
* test ELIMINATE_UNNECESSARY_PROJECT rule.
@ -103,18 +98,19 @@ public class EliminateUnnecessaryProjectTest extends TestWithFeService {
Assertions.assertTrue(actual instanceof LogicalProject);
}
@Test
public void testEliminationForThoseNeitherDoPruneNorDoExprCalc() {
PlanChecker.from(connectContext).checkPlannerResult("SELECT col1 FROM t1",
p -> {
List<PlanFragment> fragments = p.getFragments();
Assertions.assertTrue(fragments.stream()
.flatMap(fragment -> {
Set<OlapScanNode> scans = Sets.newHashSet();
fragment.getPlanRoot().collect(OlapScanNode.class, scans);
return scans.stream();
})
.noneMatch(s -> s.getProjectList() != null));
});
}
// TODO: uncomment this after the Elimination project rule is correctly implemented
// @Test
// public void testEliminationForThoseNeitherDoPruneNorDoExprCalc() {
// PlanChecker.from(connectContext).checkPlannerResult("SELECT col1 FROM t1",
// p -> {
// List<PlanFragment> fragments = p.getFragments();
// Assertions.assertTrue(fragments.stream()
// .flatMap(fragment -> {
// Set<OlapScanNode> scans = Sets.newHashSet();
// fragment.getPlanRoot().collect(OlapScanNode.class, scans);
// return scans.stream();
// })
// .noneMatch(s -> s.getProjectList() != null));
// });
// }
}