[fix](nereids) build agg for random distributed agg table in bindRelation phase (#40181) (#40702)

pick from master #40181
This commit is contained in:
starocean999
2024-09-12 14:08:50 +08:00
committed by GitHub
parent e2dc7544dd
commit 0f8176dee0
9 changed files with 302 additions and 309 deletions

View File

@ -26,7 +26,6 @@ import org.apache.doris.nereids.rules.analysis.BindExpression;
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver;
import org.apache.doris.nereids.rules.analysis.BindSink;
import org.apache.doris.nereids.rules.analysis.BuildAggForRandomDistributedTable;
import org.apache.doris.nereids.rules.analysis.CheckAfterBind;
import org.apache.doris.nereids.rules.analysis.CheckAnalysis;
import org.apache.doris.nereids.rules.analysis.CheckPolicy;
@ -163,8 +162,6 @@ public class Analyzer extends AbstractBatchJobExecutor {
topDown(new EliminateGroupByConstant()),
topDown(new SimplifyAggGroupBy()),
// run BuildAggForRandomDistributedTable before NormalizeAggregate in order to optimize the agg plan
topDown(new BuildAggForRandomDistributedTable()),
topDown(new NormalizeAggregate()),
topDown(new HavingToFilter()),
bottomUp(new SemiJoinCommute()),

View File

@ -336,10 +336,6 @@ public enum RuleType {
// topn opts
DEFER_MATERIALIZE_TOP_N_RESULT(RuleTypeClass.REWRITE),
// pre agg for random distributed table
BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN(RuleTypeClass.REWRITE),
BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN(RuleTypeClass.REWRITE),
BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN(RuleTypeClass.REWRITE),
// short circuit rule
SHOR_CIRCUIT_POINT_QUERY(RuleTypeClass.REWRITE),
// exploration rules

View File

@ -17,10 +17,17 @@
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.catalog.AggStateType;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.DistributionInfo;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Partition;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.catalog.Type;
import org.apache.doris.catalog.View;
import org.apache.doris.common.Config;
import org.apache.doris.common.Pair;
@ -44,13 +51,26 @@ import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.trees.plans.algebra.Relation;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalEsScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan;
@ -74,6 +94,7 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
@ -215,25 +236,127 @@ public class BindRelation extends OneAnalysisRuleFactory {
unboundRelation.getTableSample());
}
}
if (!Util.showHiddenColumns() && scan.getTable().hasDeleteSign()
&& !ConnectContext.get().getSessionVariable().skipDeleteSign()) {
// table qualifier is catalog.db.table, we make db.table.column
Slot deleteSlot = null;
for (Slot slot : scan.getOutput()) {
if (slot.getName().equals(Column.DELETE_SIGN)) {
deleteSlot = slot;
if (needGenerateLogicalAggForRandomDistAggTable(scan)) {
// it's a random distribution agg table
// add agg on olap scan
return preAggForRandomDistribution(scan);
} else {
// it's a duplicate, unique or hash distribution agg table
// add delete sign filter on olap scan if needed
if (!Util.showHiddenColumns() && scan.getTable().hasDeleteSign()
&& !ConnectContext.get().getSessionVariable().skipDeleteSign()) {
// table qualifier is catalog.db.table, we make db.table.column
Slot deleteSlot = null;
for (Slot slot : scan.getOutput()) {
if (slot.getName().equals(Column.DELETE_SIGN)) {
deleteSlot = slot;
break;
}
}
Preconditions.checkArgument(deleteSlot != null);
Expression conjunct = new EqualTo(new TinyIntLiteral((byte) 0), deleteSlot);
if (!((OlapTable) table).getEnableUniqueKeyMergeOnWrite()) {
scan = scan.withPreAggStatus(
PreAggStatus.off(Column.DELETE_SIGN + " is used as conjuncts."));
}
return new LogicalFilter<>(Sets.newHashSet(conjunct), scan);
}
return scan;
}
}
private boolean needGenerateLogicalAggForRandomDistAggTable(LogicalOlapScan olapScan) {
if (ConnectContext.get() != null && ConnectContext.get().getState() != null
&& ConnectContext.get().getState().isQuery()) {
// we only need to add an agg node for query, and should not do it for deleting
// from random distributed table. see https://github.com/apache/doris/pull/37985 for more info
OlapTable olapTable = olapScan.getTable();
KeysType keysType = olapTable.getKeysType();
DistributionInfo distributionInfo = olapTable.getDefaultDistributionInfo();
return keysType == KeysType.AGG_KEYS
&& distributionInfo.getType() == DistributionInfo.DistributionInfoType.RANDOM;
} else {
return false;
}
}
/**
* add LogicalAggregate above olapScan for preAgg
* @param olapScan olap scan plan
* @return rewritten plan
*/
private LogicalPlan preAggForRandomDistribution(LogicalOlapScan olapScan) {
OlapTable olapTable = olapScan.getTable();
List<Slot> childOutputSlots = olapScan.computeOutput();
List<Expression> groupByExpressions = new ArrayList<>();
List<NamedExpression> outputExpressions = new ArrayList<>();
List<Column> columns = olapTable.getBaseSchema();
for (Column col : columns) {
// use exist slot in the plan
SlotReference slot = SlotReference.fromColumn(olapTable, col, col.getName(), olapScan.qualified());
ExprId exprId = slot.getExprId();
for (Slot childSlot : childOutputSlots) {
if (childSlot instanceof SlotReference && ((SlotReference) childSlot).getName() == col.getName()) {
exprId = childSlot.getExprId();
slot = slot.withExprId(exprId);
break;
}
}
Preconditions.checkArgument(deleteSlot != null);
Expression conjunct = new EqualTo(new TinyIntLiteral((byte) 0), deleteSlot);
if (!((OlapTable) table).getEnableUniqueKeyMergeOnWrite()) {
scan = scan.withPreAggStatus(PreAggStatus.off(
Column.DELETE_SIGN + " is used as conjuncts."));
if (col.isKey()) {
groupByExpressions.add(slot);
outputExpressions.add(slot);
} else {
Expression function = generateAggFunction(slot, col);
// DO NOT rewrite
if (function == null) {
return olapScan;
}
Alias alias = new Alias(exprId, ImmutableList.of(function), col.getName(),
olapScan.qualified(), true);
outputExpressions.add(alias);
}
return new LogicalFilter<>(Sets.newHashSet(conjunct), scan);
}
return scan;
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(groupByExpressions, outputExpressions,
olapScan);
return aggregate;
}
/**
* generate aggregation function according to the aggType of column
*
* @param slot slot of column
* @return aggFunction generated
*/
private Expression generateAggFunction(SlotReference slot, Column column) {
AggregateType aggregateType = column.getAggregationType();
switch (aggregateType) {
case SUM:
return new Sum(slot);
case MAX:
return new Max(slot);
case MIN:
return new Min(slot);
case HLL_UNION:
return new HllUnion(slot);
case BITMAP_UNION:
return new BitmapUnion(slot);
case QUANTILE_UNION:
return new QuantileUnion(slot);
case GENERIC:
Type type = column.getType();
if (!type.isAggStateType()) {
return null;
}
AggStateType aggState = (AggStateType) type;
// use AGGREGATE_FUNCTION_UNION to aggregate multiple agg_state into one
String funcName = aggState.getFunctionName() + AggCombinerFunctionBuilder.UNION_SUFFIX;
FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry();
FunctionBuilder builder = functionRegistry.findFunctionBuilder(funcName, slot);
return builder.build(funcName, ImmutableList.of(slot)).first;
default:
return null;
}
}
private LogicalPlan getLogicalPlan(TableIf table, UnboundRelation unboundRelation,

View File

@ -1,271 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.catalog.AggStateType;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.DistributionInfo;
import org.apache.doris.catalog.DistributionInfo.DistributionInfoType;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
/**
* build agg plan for querying random distributed table
*/
public class BuildAggForRandomDistributedTable implements AnalysisRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
// Project(Scan) -> project(agg(scan))
logicalProject(logicalOlapScan())
.when(this::isQuery)
.when(project -> isRandomDistributedTbl(project.child()))
.then(project -> preAggForRandomDistribution(project, project.child()))
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN),
// agg(scan) -> agg(agg(scan)), agg(agg) may optimized by MergeAggregate
logicalAggregate(logicalOlapScan())
.when(this::isQuery)
.when(agg -> isRandomDistributedTbl(agg.child()))
.whenNot(agg -> {
Set<AggregateFunction> functions = agg.getAggregateFunctions();
List<Expression> groupByExprs = agg.getGroupByExpressions();
// check if need generate an inner agg plan or not
// should not rewrite twice if we had rewritten olapScan to aggregate(olapScan)
return functions.stream().allMatch(this::aggTypeMatch) && groupByExprs.stream()
.allMatch(this::isKeyOrConstantExpr);
})
.then(agg -> preAggForRandomDistribution(agg, agg.child()))
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN),
// filter(scan) -> filter(agg(scan))
logicalFilter(logicalOlapScan())
.when(this::isQuery)
.when(filter -> isRandomDistributedTbl(filter.child()))
.then(filter -> preAggForRandomDistribution(filter, filter.child()))
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN));
}
/**
* check the olapTable of olapScan is randomDistributed table
*
* @param olapScan olap scan plan
* @return true if olapTable is randomDistributed table
*/
private boolean isRandomDistributedTbl(LogicalOlapScan olapScan) {
OlapTable olapTable = olapScan.getTable();
KeysType keysType = olapTable.getKeysType();
DistributionInfo distributionInfo = olapTable.getDefaultDistributionInfo();
return keysType == KeysType.AGG_KEYS && distributionInfo.getType() == DistributionInfoType.RANDOM;
}
private boolean isQuery(LogicalPlan plan) {
return ConnectContext.get() != null
&& ConnectContext.get().getState() != null
&& ConnectContext.get().getState().isQuery();
}
/**
* add LogicalAggregate above olapScan for preAgg
*
* @param logicalPlan parent plan of olapScan
* @param olapScan olap scan plan, it may be LogicalProject, LogicalFilter, LogicalAggregate
* @return rewritten plan
*/
private Plan preAggForRandomDistribution(LogicalPlan logicalPlan, LogicalOlapScan olapScan) {
OlapTable olapTable = olapScan.getTable();
List<Slot> childOutputSlots = olapScan.computeOutput();
List<Expression> groupByExpressions = new ArrayList<>();
List<NamedExpression> outputExpressions = new ArrayList<>();
List<Column> columns = olapTable.getBaseSchema();
for (Column col : columns) {
// use exist slot in the plan
SlotReference slot = SlotReference.fromColumn(olapTable, col, col.getName(), olapScan.getQualifier());
ExprId exprId = slot.getExprId();
for (Slot childSlot : childOutputSlots) {
if (childSlot instanceof SlotReference && ((SlotReference) childSlot).getName() == col.getName()) {
exprId = childSlot.getExprId();
slot = slot.withExprId(exprId);
break;
}
}
if (col.isKey()) {
groupByExpressions.add(slot);
outputExpressions.add(slot);
} else {
Expression function = generateAggFunction(slot, col);
// DO NOT rewrite
if (function == null) {
return logicalPlan;
}
Alias alias = new Alias(exprId, function, col.getName());
outputExpressions.add(alias);
}
}
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(groupByExpressions, outputExpressions,
olapScan);
return logicalPlan.withChildren(aggregate);
}
/**
* generate aggregation function according to the aggType of column
*
* @param slot slot of column
* @return aggFunction generated
*/
private Expression generateAggFunction(SlotReference slot, Column column) {
AggregateType aggregateType = column.getAggregationType();
switch (aggregateType) {
case SUM:
return new Sum(slot);
case MAX:
return new Max(slot);
case MIN:
return new Min(slot);
case HLL_UNION:
return new HllUnion(slot);
case BITMAP_UNION:
return new BitmapUnion(slot);
case QUANTILE_UNION:
return new QuantileUnion(slot);
case GENERIC:
Type type = column.getType();
if (!type.isAggStateType()) {
return null;
}
AggStateType aggState = (AggStateType) type;
// use AGGREGATE_FUNCTION_UNION to aggregate multiple agg_state into one
String funcName = aggState.getFunctionName() + AggCombinerFunctionBuilder.UNION_SUFFIX;
FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry();
FunctionBuilder builder = functionRegistry.findFunctionBuilder(funcName, slot);
return builder.build(funcName, ImmutableList.of(slot)).first;
default:
return null;
}
}
/**
* if the agg type of AggregateFunction is as same as the agg type of column, DO NOT need to rewrite
*
* @param function agg function to check
* @return true if agg type match
*/
private boolean aggTypeMatch(AggregateFunction function) {
List<Expression> children = function.children();
if (function.getName().equalsIgnoreCase("count")) {
Count count = (Count) function;
// do not rewrite for count distinct for key column
if (count.isDistinct()) {
return children.stream().allMatch(this::isKeyOrConstantExpr);
}
if (count.isStar()) {
return false;
}
}
return children.stream().allMatch(child -> aggTypeMatch(function, child));
}
/**
* check if the agg type of functionCall match the agg type of column
*
* @param function the functionCall
* @param expression expr to check
* @return true if agg type match
*/
private boolean aggTypeMatch(AggregateFunction function, Expression expression) {
if (expression.children().isEmpty()) {
if (expression instanceof SlotReference && ((SlotReference) expression).getColumn().isPresent()) {
Column col = ((SlotReference) expression).getColumn().get();
String functionName = function.getName();
if (col.isKey()) {
return functionName.equalsIgnoreCase("max") || functionName.equalsIgnoreCase("min");
}
if (col.isAggregated()) {
AggregateType aggType = col.getAggregationType();
// agg type not mach
if (aggType == AggregateType.GENERIC) {
return col.getType().isAggStateType();
}
if (aggType == AggregateType.HLL_UNION) {
return function instanceof HllFunction;
}
if (aggType == AggregateType.BITMAP_UNION) {
return function instanceof BitmapFunction;
}
return functionName.equalsIgnoreCase(aggType.name());
}
}
return false;
}
List<Expression> children = expression.children();
return children.stream().allMatch(child -> aggTypeMatch(function, child));
}
/**
* check if the columns in expr is key column or constant, if group by clause contains value column, need rewrite
*
* @param expr expr to check
* @return true if all columns is key column or constant
*/
private boolean isKeyOrConstantExpr(Expression expr) {
if (expr instanceof SlotReference && ((SlotReference) expr).getColumn().isPresent()) {
Column col = ((SlotReference) expr).getColumn().get();
return col.isKey();
} else if (expr.isConstant()) {
return true;
}
List<Expression> children = expr.children();
return children.stream().allMatch(this::isKeyOrConstantExpr);
}
}

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy;
import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy.RelatedPolicy;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
@ -49,12 +50,23 @@ public class CheckPolicy implements AnalysisRuleFactory {
logicalCheckPolicy(any().when(child -> !(child instanceof UnboundRelation))).thenApply(ctx -> {
LogicalCheckPolicy<Plan> checkPolicy = ctx.root;
LogicalFilter<Plan> upperFilter = null;
Plan upAgg = null;
Plan child = checkPolicy.child();
// Because the unique table will automatically include a filter condition
if (child instanceof LogicalFilter && child.bound() && child
.child(0) instanceof LogicalRelation) {
if ((child instanceof LogicalFilter) && child.bound()) {
upperFilter = (LogicalFilter) child;
if (child.child(0) instanceof LogicalRelation) {
child = child.child(0);
} else if (child.child(0) instanceof LogicalAggregate
&& child.child(0).child(0) instanceof LogicalRelation) {
upAgg = child.child(0);
child = child.child(0).child(0);
}
}
if ((child instanceof LogicalAggregate)
&& child.bound() && child.child(0) instanceof LogicalRelation) {
upAgg = child;
child = child.child(0);
}
if (!(child instanceof LogicalRelation)
@ -76,16 +88,17 @@ public class CheckPolicy implements AnalysisRuleFactory {
RelatedPolicy relatedPolicy = checkPolicy.findPolicy(relation, ctx.cascadesContext);
relatedPolicy.rowPolicyFilter.ifPresent(expression -> combineFilter.addAll(
ExpressionUtils.extractConjunctionToSet(expression)));
Plan result = relation;
Plan result = upAgg != null ? upAgg.withChildren(relation) : relation;
if (upperFilter != null) {
combineFilter.addAll(upperFilter.getConjuncts());
}
if (!combineFilter.isEmpty()) {
result = new LogicalFilter<>(combineFilter, relation);
result = new LogicalFilter<>(combineFilter, result);
}
if (relatedPolicy.dataMaskProjects.isPresent()) {
result = new LogicalProject<>(relatedPolicy.dataMaskProjects.get(), result);
}
return result;
})
)