[fix](Nereids) fix bind having aggregate failed (#32490)

Fix the failure when binding a HAVING clause on top of an aggregate, keeping the behavior consistent with MySQL.
This commit is contained in:
924060929
2024-03-22 11:53:46 +08:00
committed by yiguolei
parent 1c521cd94e
commit 6812b575b2
10 changed files with 366 additions and 27 deletions

View File

@ -53,14 +53,14 @@ public class FunctionRegistry {
// to record the global alias function and other udf.
private static final String GLOBAL_FUNCTION = "__GLOBAL_FUNCTION__";
private final Map<String, List<FunctionBuilder>> name2InternalBuiltinBuilders;
private final Map<String, List<FunctionBuilder>> name2BuiltinBuilders;
private final Map<String, Map<String, List<FunctionBuilder>>> name2UdfBuilders;
/**
 * Creates the registry: allocates the builtin and udf builder maps as ConcurrentHashMaps, then
 * fills the builtin map through registerBuiltinFunctions followed by afterRegisterBuiltinFunctions.
 */
public FunctionRegistry() {
name2InternalBuiltinBuilders = new ConcurrentHashMap<>();
name2BuiltinBuilders = new ConcurrentHashMap<>();
name2UdfBuilders = new ConcurrentHashMap<>();
registerBuiltinFunctions(name2InternalBuiltinBuilders);
afterRegisterBuiltinFunctions(name2InternalBuiltinBuilders);
registerBuiltinFunctions(name2BuiltinBuilders);
afterRegisterBuiltinFunctions(name2BuiltinBuilders);
}
// this function is used to test.
@ -78,12 +78,33 @@ public class FunctionRegistry {
}
/**
 * Look up the builtin function builders registered under {@code name}.
 * The name is used as the map key exactly as given; no case normalization happens here.
 *
 * @param name the key to look up in the builtin builder map
 * @return Optional.empty() when nothing is registered under the name,
 *         otherwise an immutable copy of the registered builders
 */
public Optional<List<FunctionBuilder>> tryGetBuiltinBuilders(String name) {
List<FunctionBuilder> builders = name2InternalBuiltinBuilders.get(name);
return name2InternalBuiltinBuilders.get(name) == null
List<FunctionBuilder> builders = name2BuiltinBuilders.get(name);
// NOTE(review): the null test below reads the map a second time instead of testing `builders`;
// checking the local would avoid the extra lookup and keep the test and the copy on the same value.
return name2BuiltinBuilders.get(name) == null
? Optional.empty()
: Optional.of(ImmutableList.copyOf(builders));
}
/**
 * Tell whether a function name refers to an aggregate function.
 *
 * <p>When {@code dbName} is empty the builtin functions are consulted first; in every case the
 * udf builders returned by {@code findUdfBuilder(dbName, name)} are consulted afterwards.
 * A builder counts as aggregate when its {@code functionClass()} is assignable to
 * AggregateFunction.
 *
 * @param dbName database qualifier of the function call, null or empty when unqualified
 * @param name function name, compared in lower case
 * @return true if any matching builder reports an AggregateFunction class, false otherwise
 */
public boolean isAggregateFunction(String dbName, String name) {
    name = name.toLowerCase();
    Class<?> aggClass = org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction.class;
    if (StringUtils.isEmpty(dbName)) {
        // the builtin map has no entry for names that are not builtin (e.g. a udf invoked
        // without a db qualifier), so a missing key must fall through to the udf lookup
        // instead of failing with a NullPointerException in the loop
        List<FunctionBuilder> functionBuilders = name2BuiltinBuilders.get(name);
        if (functionBuilders != null) {
            for (FunctionBuilder functionBuilder : functionBuilders) {
                if (aggClass.isAssignableFrom(functionBuilder.functionClass())) {
                    return true;
                }
            }
        }
    }

    List<FunctionBuilder> udfBuilders = findUdfBuilder(dbName, name);
    for (FunctionBuilder udfBuilder : udfBuilders) {
        if (aggClass.isAssignableFrom(udfBuilder.functionClass())) {
            return true;
        }
    }
    return false;
}
// currently we only find function by name and arity and args' types.
public FunctionBuilder findFunctionBuilder(String dbName, String name, List<?> arguments) {
List<FunctionBuilder> functionBuilders = null;
@ -92,11 +113,11 @@ public class FunctionRegistry {
if (StringUtils.isEmpty(dbName)) {
// search internal function only if dbName is empty
functionBuilders = name2InternalBuiltinBuilders.get(name.toLowerCase());
functionBuilders = name2BuiltinBuilders.get(name.toLowerCase());
if (CollectionUtils.isEmpty(functionBuilders) && AggCombinerFunctionBuilder.isAggStateCombinator(name)) {
String nestedName = AggCombinerFunctionBuilder.getNestedName(name);
String combinatorSuffix = AggCombinerFunctionBuilder.getCombinatorSuffix(name);
functionBuilders = name2InternalBuiltinBuilders.get(nestedName.toLowerCase());
functionBuilders = name2BuiltinBuilders.get(nestedName.toLowerCase());
if (functionBuilders != null) {
List<FunctionBuilder> candidateBuilders = Lists.newArrayListWithCapacity(functionBuilders.size());
for (FunctionBuilder functionBuilder : functionBuilders) {
@ -199,8 +220,8 @@ public class FunctionRegistry {
}
synchronized (name2UdfBuilders) {
Map<String, List<FunctionBuilder>> builders = name2UdfBuilders.getOrDefault(dbName, ImmutableMap.of());
builders.getOrDefault(name, Lists.newArrayList()).removeIf(builder -> ((UdfBuilder) builder).getArgTypes()
.equals(argTypes));
builders.getOrDefault(name, Lists.newArrayList())
.removeIf(builder -> ((UdfBuilder) builder).getArgTypes().equals(argTypes));
}
}
}

View File

@ -19,11 +19,13 @@ package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.MappingSlot;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.analyzer.UnboundOneRowRelation;
import org.apache.doris.nereids.analyzer.UnboundResultSink;
import org.apache.doris.nereids.analyzer.UnboundSlot;
@ -351,12 +353,12 @@ public class BindExpression implements AnalysisRuleFactory {
CascadesContext cascadesContext = ctx.cascadesContext;
// bind slot by child.output first
Scope defaultScope = toScope(cascadesContext, childPlan.getOutput());
Scope childOutput = toScope(cascadesContext, childPlan.getOutput());
// then bind slot by child.children.output
Supplier<Scope> backupScope = Suppliers.memoize(() ->
Supplier<Scope> childChildrenOutput = Suppliers.memoize(() ->
toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(childPlan.children()))
);
return bindHavingByScopes(having, cascadesContext, defaultScope, backupScope);
return bindHavingByScopes(having, cascadesContext, childOutput, childChildrenOutput);
}
private LogicalHaving<Plan> bindHavingAggregate(
@ -365,13 +367,115 @@ public class BindExpression implements AnalysisRuleFactory {
Aggregate<Plan> aggregate = having.child();
CascadesContext cascadesContext = ctx.cascadesContext;
// having(aggregate) should bind slot by aggregate.child.output first
Scope defaultScope = toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(aggregate.children()));
// then bind slot by aggregate.output
Supplier<Scope> backupScope = Suppliers.memoize(() ->
toScope(cascadesContext, aggregate.getOutput())
);
return bindHavingByScopes(ctx.root, ctx.cascadesContext, defaultScope, backupScope);
// keep same behavior as mysql
Supplier<CustomSlotBinderAnalyzer> bindByAggChild = Suppliers.memoize(() -> {
Scope aggChildOutputScope
= toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(aggregate.children()));
return (analyzer, unboundSlot) -> analyzer.bindSlotByScope(unboundSlot, aggChildOutputScope);
});
Scope aggOutputScope = toScope(cascadesContext, aggregate.getOutput());
Supplier<CustomSlotBinderAnalyzer> bindByGroupByThenAggOutputThenAggChild = Suppliers.memoize(() -> {
List<Expression> groupByExprs = aggregate.getGroupByExpressions();
ImmutableList.Builder<Slot> groupBySlots
= ImmutableList.builderWithExpectedSize(groupByExprs.size());
for (Expression groupBy : groupByExprs) {
if (groupBy instanceof Slot) {
groupBySlots.add((Slot) groupBy);
}
}
Scope groupBySlotsScope = toScope(cascadesContext, groupBySlots.build());
Supplier<Pair<Scope, Scope>> separateAggOutputScopes = Suppliers.memoize(() -> {
ImmutableList.Builder<Slot> groupByOutputs = ImmutableList.builderWithExpectedSize(
aggregate.getOutputExpressions().size());
ImmutableList.Builder<Slot> aggFunOutputs = ImmutableList.builderWithExpectedSize(
aggregate.getOutputExpressions().size());
for (NamedExpression outputExpression : aggregate.getOutputExpressions()) {
if (outputExpression.anyMatch(AggregateFunction.class::isInstance)) {
aggFunOutputs.add(outputExpression.toSlot());
} else {
groupByOutputs.add(outputExpression.toSlot());
}
}
Scope nonAggFunSlotsScope = toScope(cascadesContext, groupByOutputs.build());
Scope aggFuncSlotsScope = toScope(cascadesContext, aggFunOutputs.build());
return Pair.of(nonAggFunSlotsScope, aggFuncSlotsScope);
});
return (analyzer, unboundSlot) -> {
List<Slot> boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope);
if (boundInGroupBy.size() == 1) {
return boundInGroupBy;
}
Pair<Scope, Scope> separateAggOutputScope = separateAggOutputScopes.get();
List<Slot> boundInNonAggFuncs = analyzer.bindSlotByScope(unboundSlot, separateAggOutputScope.first);
if (boundInNonAggFuncs.size() == 1) {
return boundInNonAggFuncs;
}
List<Slot> boundInAggFuncs = analyzer.bindSlotByScope(unboundSlot, separateAggOutputScope.second);
if (boundInAggFuncs.size() == 1) {
return boundInAggFuncs;
}
return bindByAggChild.get().bindSlot(analyzer, unboundSlot);
};
});
FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry();
ExpressionAnalyzer havingAnalyzer = new ExpressionAnalyzer(having, aggOutputScope, cascadesContext,
false, true) {
private boolean currentIsInAggregateFunction;
/**
 * Visit an aggregate function while remembering that its arguments are inside an aggregate.
 * Only the outermost aggregate function sets and later clears the flag; nested ones leave it alone.
 */
@Override
public Expression visitAggregateFunction(AggregateFunction aggregateFunction,
        ExpressionRewriteContext context) {
    if (currentIsInAggregateFunction) {
        // already inside an aggregate function: the outer visit owns the flag
        return super.visitAggregateFunction(aggregateFunction, context);
    }
    currentIsInAggregateFunction = true;
    try {
        return super.visitAggregateFunction(aggregateFunction, context);
    } finally {
        // always reset, even if the analysis throws
        currentIsInAggregateFunction = false;
    }
}
/**
 * Visit a not-yet-bound function. If we are not already inside an aggregate function and the
 * function registry reports this name as an aggregate function, the flag is raised for the
 * duration of the visit so the arguments are analyzed as aggregate-function arguments.
 */
@Override
public Expression visitUnboundFunction(UnboundFunction unboundFunction, ExpressionRewriteContext context) {
    // the registry is only asked when the flag is not set yet
    if (currentIsInAggregateFunction || !isAggregateFunction(unboundFunction, functionRegistry)) {
        return super.visitUnboundFunction(unboundFunction, context);
    }
    currentIsInAggregateFunction = true;
    try {
        return super.visitUnboundFunction(unboundFunction, context);
    } finally {
        // always reset, even if the analysis throws
        currentIsInAggregateFunction = false;
    }
}
/**
 * Choose the slot binder by position: slots inside an aggregate function are bound by the
 * aggregate's child output, all other slots go through the
 * group by / aggregate output / aggregate child chain.
 */
@Override
protected List<? extends Expression> bindSlotByThisScope(UnboundSlot unboundSlot) {
    // only the selected (memoized) supplier is evaluated
    Supplier<CustomSlotBinderAnalyzer> slotBinder = currentIsInAggregateFunction
            ? bindByAggChild
            : bindByGroupByThenAggOutputThenAggChild;
    return slotBinder.get().bindSlot(this, unboundSlot);
}
};
Set<Expression> havingExprs = having.getConjuncts();
ImmutableSet.Builder<Expression> analyzedHaving = ImmutableSet.builderWithExpectedSize(havingExprs.size());
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext);
for (Expression expression : havingExprs) {
analyzedHaving.add(havingAnalyzer.analyze(expression, rewriteContext));
}
return new LogicalHaving<>(analyzedHaving.build(), having.child());
}
private LogicalHaving<Plan> bindHavingByScopes(
@ -764,6 +868,11 @@ public class BindExpression implements AnalysisRuleFactory {
}
}
/** Ask the function registry whether the unbound function's db name and name denote an aggregate function. */
private boolean isAggregateFunction(UnboundFunction unboundFunction, FunctionRegistry functionRegistry) {
    String dbName = unboundFunction.getDbName();
    String functionName = unboundFunction.getName();
    return functionRegistry.isAggregateFunction(dbName, functionName);
}
private <E extends Expression> E checkBoundExceptLambda(E expression, Plan plan) {
if (expression instanceof Lambda) {
return expression;
@ -797,6 +906,12 @@ public class BindExpression implements AnalysisRuleFactory {
boolean enableExactMatch, boolean bindSlotInOuterScope) {
List<Slot> childrenOutputs = PlanUtils.fastGetChildrenOutputs(children);
Scope scope = toScope(cascadesContext, childrenOutputs);
return buildSimpleExprAnalyzer(currentPlan, cascadesContext, scope, enableExactMatch, bindSlotInOuterScope);
}
private SimpleExprAnalyzer buildSimpleExprAnalyzer(
Plan currentPlan, CascadesContext cascadesContext, Scope scope,
boolean enableExactMatch, boolean bindSlotInOuterScope) {
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext);
ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(currentPlan,
scope, cascadesContext, enableExactMatch, bindSlotInOuterScope);

View File

@ -56,6 +56,11 @@ public class AggCombinerFunctionBuilder extends FunctionBuilder {
this.nestedBuilder = Objects.requireNonNull(nestedBuilder, "nestedBuilder can not be null");
}
/** A combinator builder reports the function class of the builder it wraps. */
@Override
public Class<? extends BoundFunction> functionClass() {
    return this.nestedBuilder.functionClass();
}
@Override
public boolean canApply(List<? extends Object> arguments) {
if (combinatorSuffix.equals(STATE) || combinatorSuffix.equals(FOREACH)) {

View File

@ -42,13 +42,21 @@ public class BuiltinFunctionBuilder extends FunctionBuilder {
// Concrete BoundFunction's constructor
private final Constructor<BoundFunction> builderMethod;
private final Class<? extends BoundFunction> functionClass;
/**
 * Build a function builder around one constructor of a BoundFunction class.
 * Both arguments must be non-null. The arity is the constructor's parameter count, and the
 * builder is variable-length when the constructor's last parameter is an array type.
 */
public BuiltinFunctionBuilder(Constructor<BoundFunction> builderMethod) {
public BuiltinFunctionBuilder(
Class<? extends BoundFunction> functionClass, Constructor<BoundFunction> builderMethod) {
this.functionClass = Objects.requireNonNull(functionClass, "functionClass can not be null");
this.builderMethod = Objects.requireNonNull(builderMethod, "builderMethod can not be null");
this.arity = builderMethod.getParameterCount();
// a trailing array parameter marks a varargs-style constructor
this.isVariableLength = arity > 0 && builderMethod.getParameterTypes()[arity - 1].isArray();
}
/** Report the function class that was handed to this builder at construction time. */
@Override
public Class<? extends BoundFunction> functionClass() {
    return this.functionClass;
}
@Override
public boolean canApply(List<? extends Object> arguments) {
if (isVariableLength && arity > arguments.size() + 1) {
@ -133,7 +141,7 @@ public class BuiltinFunctionBuilder extends FunctionBuilder {
+ functionClass.getSimpleName());
return Arrays.stream(functionClass.getConstructors())
.filter(constructor -> Modifier.isPublic(constructor.getModifiers()))
.map(constructor -> new BuiltinFunctionBuilder((Constructor<BoundFunction>) constructor))
.map(constructor -> new BuiltinFunctionBuilder(functionClass, (Constructor<BoundFunction>) constructor))
.collect(ImmutableList.toImmutableList());
}
}

View File

@ -27,6 +27,8 @@ import java.util.List;
* This class used to build BoundFunction(Builtin or Combinator) by a list of Expressions.
*/
public abstract class FunctionBuilder {
/** the BoundFunction class associated with this builder, e.g. to test whether it is an aggregate function */
public abstract Class<? extends BoundFunction> functionClass();
/** check whether arguments can apply to the constructor */
public abstract boolean canApply(List<? extends Object> arguments);

View File

@ -50,6 +50,11 @@ public class AliasUdfBuilder extends UdfBuilder {
return aliasUdf.getArgTypes();
}
/** Every alias udf builder reports {@code AliasUdf} as its function class. */
@Override
public Class<? extends BoundFunction> functionClass() {
return AliasUdf.class;
}
@Override
public boolean canApply(List<?> arguments) {
if (arguments.size() != aliasUdf.arity()) {

View File

@ -49,6 +49,11 @@ public class JavaUdafBuilder extends UdfBuilder {
.collect(Collectors.toList())).get();
}
/**
 * Every java udaf builder reports {@code JavaUdaf} as its function class.
 * NOTE(review): presumably JavaUdaf is an AggregateFunction subtype so that aggregate checks
 * based on this class recognize udafs; confirm against the JavaUdaf declaration.
 */
@Override
public Class<? extends BoundFunction> functionClass() {
return JavaUdaf.class;
}
@Override
public boolean canApply(List<?> arguments) {
if ((isVarArgs && arity > arguments.size() + 1) || (!isVarArgs && arguments.size() != arity)) {

View File

@ -51,6 +51,11 @@ public class JavaUdfBuilder extends UdfBuilder {
.collect(Collectors.toList())).get();
}
/** Every java udf builder reports {@code JavaUdf} as its function class. */
@Override
public Class<? extends BoundFunction> functionClass() {
return JavaUdf.class;
}
@Override
public boolean canApply(List<?> arguments) {
if ((isVarArgs && arity > arguments.size() + 1) || (!isVarArgs && arguments.size() != arity)) {