pick from master #40181
This commit is contained in:
@ -26,7 +26,6 @@ import org.apache.doris.nereids.rules.analysis.BindExpression;
|
||||
import org.apache.doris.nereids.rules.analysis.BindRelation;
|
||||
import org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver;
|
||||
import org.apache.doris.nereids.rules.analysis.BindSink;
|
||||
import org.apache.doris.nereids.rules.analysis.BuildAggForRandomDistributedTable;
|
||||
import org.apache.doris.nereids.rules.analysis.CheckAfterBind;
|
||||
import org.apache.doris.nereids.rules.analysis.CheckAnalysis;
|
||||
import org.apache.doris.nereids.rules.analysis.CheckPolicy;
|
||||
@ -163,8 +162,6 @@ public class Analyzer extends AbstractBatchJobExecutor {
|
||||
topDown(new EliminateGroupByConstant()),
|
||||
|
||||
topDown(new SimplifyAggGroupBy()),
|
||||
// run BuildAggForRandomDistributedTable before NormalizeAggregate in order to optimize the agg plan
|
||||
topDown(new BuildAggForRandomDistributedTable()),
|
||||
topDown(new NormalizeAggregate()),
|
||||
topDown(new HavingToFilter()),
|
||||
bottomUp(new SemiJoinCommute()),
|
||||
|
||||
@ -336,10 +336,6 @@ public enum RuleType {
|
||||
|
||||
// topn opts
|
||||
DEFER_MATERIALIZE_TOP_N_RESULT(RuleTypeClass.REWRITE),
|
||||
// pre agg for random distributed table
|
||||
BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN(RuleTypeClass.REWRITE),
|
||||
BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN(RuleTypeClass.REWRITE),
|
||||
BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN(RuleTypeClass.REWRITE),
|
||||
// short circuit rule
|
||||
SHOR_CIRCUIT_POINT_QUERY(RuleTypeClass.REWRITE),
|
||||
// exploration rules
|
||||
|
||||
@ -17,10 +17,17 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.catalog.AggStateType;
|
||||
import org.apache.doris.catalog.AggregateType;
|
||||
import org.apache.doris.catalog.Column;
|
||||
import org.apache.doris.catalog.DistributionInfo;
|
||||
import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.catalog.FunctionRegistry;
|
||||
import org.apache.doris.catalog.KeysType;
|
||||
import org.apache.doris.catalog.OlapTable;
|
||||
import org.apache.doris.catalog.Partition;
|
||||
import org.apache.doris.catalog.TableIf;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.catalog.View;
|
||||
import org.apache.doris.common.Config;
|
||||
import org.apache.doris.common.Pair;
|
||||
@ -44,13 +51,26 @@ import org.apache.doris.nereids.properties.LogicalProperties;
|
||||
import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.PreAggStatus;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.Relation;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalEsScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan;
|
||||
@ -74,6 +94,7 @@ import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Function;
|
||||
@ -215,25 +236,127 @@ public class BindRelation extends OneAnalysisRuleFactory {
|
||||
unboundRelation.getTableSample());
|
||||
}
|
||||
}
|
||||
if (!Util.showHiddenColumns() && scan.getTable().hasDeleteSign()
|
||||
&& !ConnectContext.get().getSessionVariable().skipDeleteSign()) {
|
||||
// table qualifier is catalog.db.table, we make db.table.column
|
||||
Slot deleteSlot = null;
|
||||
for (Slot slot : scan.getOutput()) {
|
||||
if (slot.getName().equals(Column.DELETE_SIGN)) {
|
||||
deleteSlot = slot;
|
||||
if (needGenerateLogicalAggForRandomDistAggTable(scan)) {
|
||||
// it's a random distribution agg table
|
||||
// add agg on olap scan
|
||||
return preAggForRandomDistribution(scan);
|
||||
} else {
|
||||
// it's a duplicate, unique or hash distribution agg table
|
||||
// add delete sign filter on olap scan if needed
|
||||
if (!Util.showHiddenColumns() && scan.getTable().hasDeleteSign()
|
||||
&& !ConnectContext.get().getSessionVariable().skipDeleteSign()) {
|
||||
// table qualifier is catalog.db.table, we make db.table.column
|
||||
Slot deleteSlot = null;
|
||||
for (Slot slot : scan.getOutput()) {
|
||||
if (slot.getName().equals(Column.DELETE_SIGN)) {
|
||||
deleteSlot = slot;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Preconditions.checkArgument(deleteSlot != null);
|
||||
Expression conjunct = new EqualTo(new TinyIntLiteral((byte) 0), deleteSlot);
|
||||
if (!((OlapTable) table).getEnableUniqueKeyMergeOnWrite()) {
|
||||
scan = scan.withPreAggStatus(
|
||||
PreAggStatus.off(Column.DELETE_SIGN + " is used as conjuncts."));
|
||||
}
|
||||
return new LogicalFilter<>(Sets.newHashSet(conjunct), scan);
|
||||
}
|
||||
return scan;
|
||||
}
|
||||
}
|
||||
|
||||
private boolean needGenerateLogicalAggForRandomDistAggTable(LogicalOlapScan olapScan) {
|
||||
if (ConnectContext.get() != null && ConnectContext.get().getState() != null
|
||||
&& ConnectContext.get().getState().isQuery()) {
|
||||
// we only need to add an agg node for query, and should not do it for deleting
|
||||
// from random distributed table. see https://github.com/apache/doris/pull/37985 for more info
|
||||
OlapTable olapTable = olapScan.getTable();
|
||||
KeysType keysType = olapTable.getKeysType();
|
||||
DistributionInfo distributionInfo = olapTable.getDefaultDistributionInfo();
|
||||
return keysType == KeysType.AGG_KEYS
|
||||
&& distributionInfo.getType() == DistributionInfo.DistributionInfoType.RANDOM;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* add LogicalAggregate above olapScan for preAgg
|
||||
* @param olapScan olap scan plan
|
||||
* @return rewritten plan
|
||||
*/
|
||||
private LogicalPlan preAggForRandomDistribution(LogicalOlapScan olapScan) {
|
||||
OlapTable olapTable = olapScan.getTable();
|
||||
List<Slot> childOutputSlots = olapScan.computeOutput();
|
||||
List<Expression> groupByExpressions = new ArrayList<>();
|
||||
List<NamedExpression> outputExpressions = new ArrayList<>();
|
||||
List<Column> columns = olapTable.getBaseSchema();
|
||||
|
||||
for (Column col : columns) {
|
||||
// use exist slot in the plan
|
||||
SlotReference slot = SlotReference.fromColumn(olapTable, col, col.getName(), olapScan.qualified());
|
||||
ExprId exprId = slot.getExprId();
|
||||
for (Slot childSlot : childOutputSlots) {
|
||||
if (childSlot instanceof SlotReference && ((SlotReference) childSlot).getName() == col.getName()) {
|
||||
exprId = childSlot.getExprId();
|
||||
slot = slot.withExprId(exprId);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Preconditions.checkArgument(deleteSlot != null);
|
||||
Expression conjunct = new EqualTo(new TinyIntLiteral((byte) 0), deleteSlot);
|
||||
if (!((OlapTable) table).getEnableUniqueKeyMergeOnWrite()) {
|
||||
scan = scan.withPreAggStatus(PreAggStatus.off(
|
||||
Column.DELETE_SIGN + " is used as conjuncts."));
|
||||
if (col.isKey()) {
|
||||
groupByExpressions.add(slot);
|
||||
outputExpressions.add(slot);
|
||||
} else {
|
||||
Expression function = generateAggFunction(slot, col);
|
||||
// DO NOT rewrite
|
||||
if (function == null) {
|
||||
return olapScan;
|
||||
}
|
||||
Alias alias = new Alias(exprId, ImmutableList.of(function), col.getName(),
|
||||
olapScan.qualified(), true);
|
||||
outputExpressions.add(alias);
|
||||
}
|
||||
return new LogicalFilter<>(Sets.newHashSet(conjunct), scan);
|
||||
}
|
||||
return scan;
|
||||
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(groupByExpressions, outputExpressions,
|
||||
olapScan);
|
||||
return aggregate;
|
||||
}
|
||||
|
||||
/**
|
||||
* generate aggregation function according to the aggType of column
|
||||
*
|
||||
* @param slot slot of column
|
||||
* @return aggFunction generated
|
||||
*/
|
||||
private Expression generateAggFunction(SlotReference slot, Column column) {
|
||||
AggregateType aggregateType = column.getAggregationType();
|
||||
switch (aggregateType) {
|
||||
case SUM:
|
||||
return new Sum(slot);
|
||||
case MAX:
|
||||
return new Max(slot);
|
||||
case MIN:
|
||||
return new Min(slot);
|
||||
case HLL_UNION:
|
||||
return new HllUnion(slot);
|
||||
case BITMAP_UNION:
|
||||
return new BitmapUnion(slot);
|
||||
case QUANTILE_UNION:
|
||||
return new QuantileUnion(slot);
|
||||
case GENERIC:
|
||||
Type type = column.getType();
|
||||
if (!type.isAggStateType()) {
|
||||
return null;
|
||||
}
|
||||
AggStateType aggState = (AggStateType) type;
|
||||
// use AGGREGATE_FUNCTION_UNION to aggregate multiple agg_state into one
|
||||
String funcName = aggState.getFunctionName() + AggCombinerFunctionBuilder.UNION_SUFFIX;
|
||||
FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry();
|
||||
FunctionBuilder builder = functionRegistry.findFunctionBuilder(funcName, slot);
|
||||
return builder.build(funcName, ImmutableList.of(slot)).first;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private LogicalPlan getLogicalPlan(TableIf table, UnboundRelation unboundRelation,
|
||||
|
||||
@ -1,271 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.catalog.AggStateType;
|
||||
import org.apache.doris.catalog.AggregateType;
|
||||
import org.apache.doris.catalog.Column;
|
||||
import org.apache.doris.catalog.DistributionInfo;
|
||||
import org.apache.doris.catalog.DistributionInfo.DistributionInfoType;
|
||||
import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.catalog.FunctionRegistry;
|
||||
import org.apache.doris.catalog.KeysType;
|
||||
import org.apache.doris.catalog.OlapTable;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.HllFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* build agg plan for querying random distributed table
|
||||
*/
|
||||
public class BuildAggForRandomDistributedTable implements AnalysisRuleFactory {
|
||||
|
||||
@Override
|
||||
public List<Rule> buildRules() {
|
||||
return ImmutableList.of(
|
||||
// Project(Scan) -> project(agg(scan))
|
||||
logicalProject(logicalOlapScan())
|
||||
.when(this::isQuery)
|
||||
.when(project -> isRandomDistributedTbl(project.child()))
|
||||
.then(project -> preAggForRandomDistribution(project, project.child()))
|
||||
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN),
|
||||
// agg(scan) -> agg(agg(scan)), agg(agg) may optimized by MergeAggregate
|
||||
logicalAggregate(logicalOlapScan())
|
||||
.when(this::isQuery)
|
||||
.when(agg -> isRandomDistributedTbl(agg.child()))
|
||||
.whenNot(agg -> {
|
||||
Set<AggregateFunction> functions = agg.getAggregateFunctions();
|
||||
List<Expression> groupByExprs = agg.getGroupByExpressions();
|
||||
// check if need generate an inner agg plan or not
|
||||
// should not rewrite twice if we had rewritten olapScan to aggregate(olapScan)
|
||||
return functions.stream().allMatch(this::aggTypeMatch) && groupByExprs.stream()
|
||||
.allMatch(this::isKeyOrConstantExpr);
|
||||
})
|
||||
.then(agg -> preAggForRandomDistribution(agg, agg.child()))
|
||||
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN),
|
||||
// filter(scan) -> filter(agg(scan))
|
||||
logicalFilter(logicalOlapScan())
|
||||
.when(this::isQuery)
|
||||
.when(filter -> isRandomDistributedTbl(filter.child()))
|
||||
.then(filter -> preAggForRandomDistribution(filter, filter.child()))
|
||||
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN));
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* check the olapTable of olapScan is randomDistributed table
|
||||
*
|
||||
* @param olapScan olap scan plan
|
||||
* @return true if olapTable is randomDistributed table
|
||||
*/
|
||||
private boolean isRandomDistributedTbl(LogicalOlapScan olapScan) {
|
||||
OlapTable olapTable = olapScan.getTable();
|
||||
KeysType keysType = olapTable.getKeysType();
|
||||
DistributionInfo distributionInfo = olapTable.getDefaultDistributionInfo();
|
||||
return keysType == KeysType.AGG_KEYS && distributionInfo.getType() == DistributionInfoType.RANDOM;
|
||||
}
|
||||
|
||||
private boolean isQuery(LogicalPlan plan) {
|
||||
return ConnectContext.get() != null
|
||||
&& ConnectContext.get().getState() != null
|
||||
&& ConnectContext.get().getState().isQuery();
|
||||
}
|
||||
|
||||
/**
|
||||
* add LogicalAggregate above olapScan for preAgg
|
||||
*
|
||||
* @param logicalPlan parent plan of olapScan
|
||||
* @param olapScan olap scan plan, it may be LogicalProject, LogicalFilter, LogicalAggregate
|
||||
* @return rewritten plan
|
||||
*/
|
||||
private Plan preAggForRandomDistribution(LogicalPlan logicalPlan, LogicalOlapScan olapScan) {
|
||||
OlapTable olapTable = olapScan.getTable();
|
||||
List<Slot> childOutputSlots = olapScan.computeOutput();
|
||||
List<Expression> groupByExpressions = new ArrayList<>();
|
||||
List<NamedExpression> outputExpressions = new ArrayList<>();
|
||||
List<Column> columns = olapTable.getBaseSchema();
|
||||
|
||||
for (Column col : columns) {
|
||||
// use exist slot in the plan
|
||||
SlotReference slot = SlotReference.fromColumn(olapTable, col, col.getName(), olapScan.getQualifier());
|
||||
ExprId exprId = slot.getExprId();
|
||||
for (Slot childSlot : childOutputSlots) {
|
||||
if (childSlot instanceof SlotReference && ((SlotReference) childSlot).getName() == col.getName()) {
|
||||
exprId = childSlot.getExprId();
|
||||
slot = slot.withExprId(exprId);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (col.isKey()) {
|
||||
groupByExpressions.add(slot);
|
||||
outputExpressions.add(slot);
|
||||
} else {
|
||||
Expression function = generateAggFunction(slot, col);
|
||||
// DO NOT rewrite
|
||||
if (function == null) {
|
||||
return logicalPlan;
|
||||
}
|
||||
Alias alias = new Alias(exprId, function, col.getName());
|
||||
outputExpressions.add(alias);
|
||||
}
|
||||
}
|
||||
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(groupByExpressions, outputExpressions,
|
||||
olapScan);
|
||||
return logicalPlan.withChildren(aggregate);
|
||||
}
|
||||
|
||||
/**
|
||||
* generate aggregation function according to the aggType of column
|
||||
*
|
||||
* @param slot slot of column
|
||||
* @return aggFunction generated
|
||||
*/
|
||||
private Expression generateAggFunction(SlotReference slot, Column column) {
|
||||
AggregateType aggregateType = column.getAggregationType();
|
||||
switch (aggregateType) {
|
||||
case SUM:
|
||||
return new Sum(slot);
|
||||
case MAX:
|
||||
return new Max(slot);
|
||||
case MIN:
|
||||
return new Min(slot);
|
||||
case HLL_UNION:
|
||||
return new HllUnion(slot);
|
||||
case BITMAP_UNION:
|
||||
return new BitmapUnion(slot);
|
||||
case QUANTILE_UNION:
|
||||
return new QuantileUnion(slot);
|
||||
case GENERIC:
|
||||
Type type = column.getType();
|
||||
if (!type.isAggStateType()) {
|
||||
return null;
|
||||
}
|
||||
AggStateType aggState = (AggStateType) type;
|
||||
// use AGGREGATE_FUNCTION_UNION to aggregate multiple agg_state into one
|
||||
String funcName = aggState.getFunctionName() + AggCombinerFunctionBuilder.UNION_SUFFIX;
|
||||
FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry();
|
||||
FunctionBuilder builder = functionRegistry.findFunctionBuilder(funcName, slot);
|
||||
return builder.build(funcName, ImmutableList.of(slot)).first;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* if the agg type of AggregateFunction is as same as the agg type of column, DO NOT need to rewrite
|
||||
*
|
||||
* @param function agg function to check
|
||||
* @return true if agg type match
|
||||
*/
|
||||
private boolean aggTypeMatch(AggregateFunction function) {
|
||||
List<Expression> children = function.children();
|
||||
if (function.getName().equalsIgnoreCase("count")) {
|
||||
Count count = (Count) function;
|
||||
// do not rewrite for count distinct for key column
|
||||
if (count.isDistinct()) {
|
||||
return children.stream().allMatch(this::isKeyOrConstantExpr);
|
||||
}
|
||||
if (count.isStar()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return children.stream().allMatch(child -> aggTypeMatch(function, child));
|
||||
}
|
||||
|
||||
/**
|
||||
* check if the agg type of functionCall match the agg type of column
|
||||
*
|
||||
* @param function the functionCall
|
||||
* @param expression expr to check
|
||||
* @return true if agg type match
|
||||
*/
|
||||
private boolean aggTypeMatch(AggregateFunction function, Expression expression) {
|
||||
if (expression.children().isEmpty()) {
|
||||
if (expression instanceof SlotReference && ((SlotReference) expression).getColumn().isPresent()) {
|
||||
Column col = ((SlotReference) expression).getColumn().get();
|
||||
String functionName = function.getName();
|
||||
if (col.isKey()) {
|
||||
return functionName.equalsIgnoreCase("max") || functionName.equalsIgnoreCase("min");
|
||||
}
|
||||
if (col.isAggregated()) {
|
||||
AggregateType aggType = col.getAggregationType();
|
||||
// agg type not mach
|
||||
if (aggType == AggregateType.GENERIC) {
|
||||
return col.getType().isAggStateType();
|
||||
}
|
||||
if (aggType == AggregateType.HLL_UNION) {
|
||||
return function instanceof HllFunction;
|
||||
}
|
||||
if (aggType == AggregateType.BITMAP_UNION) {
|
||||
return function instanceof BitmapFunction;
|
||||
}
|
||||
return functionName.equalsIgnoreCase(aggType.name());
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
List<Expression> children = expression.children();
|
||||
return children.stream().allMatch(child -> aggTypeMatch(function, child));
|
||||
}
|
||||
|
||||
/**
|
||||
* check if the columns in expr is key column or constant, if group by clause contains value column, need rewrite
|
||||
*
|
||||
* @param expr expr to check
|
||||
* @return true if all columns is key column or constant
|
||||
*/
|
||||
private boolean isKeyOrConstantExpr(Expression expr) {
|
||||
if (expr instanceof SlotReference && ((SlotReference) expr).getColumn().isPresent()) {
|
||||
Column col = ((SlotReference) expr).getColumn().get();
|
||||
return col.isKey();
|
||||
} else if (expr.isConstant()) {
|
||||
return true;
|
||||
}
|
||||
List<Expression> children = expr.children();
|
||||
return children.stream().allMatch(this::isKeyOrConstantExpr);
|
||||
}
|
||||
}
|
||||
@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy.RelatedPolicy;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
@ -49,12 +50,23 @@ public class CheckPolicy implements AnalysisRuleFactory {
|
||||
logicalCheckPolicy(any().when(child -> !(child instanceof UnboundRelation))).thenApply(ctx -> {
|
||||
LogicalCheckPolicy<Plan> checkPolicy = ctx.root;
|
||||
LogicalFilter<Plan> upperFilter = null;
|
||||
Plan upAgg = null;
|
||||
|
||||
Plan child = checkPolicy.child();
|
||||
// Because the unique table will automatically include a filter condition
|
||||
if (child instanceof LogicalFilter && child.bound() && child
|
||||
.child(0) instanceof LogicalRelation) {
|
||||
if ((child instanceof LogicalFilter) && child.bound()) {
|
||||
upperFilter = (LogicalFilter) child;
|
||||
if (child.child(0) instanceof LogicalRelation) {
|
||||
child = child.child(0);
|
||||
} else if (child.child(0) instanceof LogicalAggregate
|
||||
&& child.child(0).child(0) instanceof LogicalRelation) {
|
||||
upAgg = child.child(0);
|
||||
child = child.child(0).child(0);
|
||||
}
|
||||
}
|
||||
if ((child instanceof LogicalAggregate)
|
||||
&& child.bound() && child.child(0) instanceof LogicalRelation) {
|
||||
upAgg = child;
|
||||
child = child.child(0);
|
||||
}
|
||||
if (!(child instanceof LogicalRelation)
|
||||
@ -76,16 +88,17 @@ public class CheckPolicy implements AnalysisRuleFactory {
|
||||
RelatedPolicy relatedPolicy = checkPolicy.findPolicy(relation, ctx.cascadesContext);
|
||||
relatedPolicy.rowPolicyFilter.ifPresent(expression -> combineFilter.addAll(
|
||||
ExpressionUtils.extractConjunctionToSet(expression)));
|
||||
Plan result = relation;
|
||||
Plan result = upAgg != null ? upAgg.withChildren(relation) : relation;
|
||||
if (upperFilter != null) {
|
||||
combineFilter.addAll(upperFilter.getConjuncts());
|
||||
}
|
||||
if (!combineFilter.isEmpty()) {
|
||||
result = new LogicalFilter<>(combineFilter, relation);
|
||||
result = new LogicalFilter<>(combineFilter, result);
|
||||
}
|
||||
if (relatedPolicy.dataMaskProjects.isPresent()) {
|
||||
result = new LogicalProject<>(relatedPolicy.dataMaskProjects.get(), result);
|
||||
}
|
||||
|
||||
return result;
|
||||
})
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user