[fix](Nereids) fix bind having aggregate failed (#32490)
Fix the failure to bind HAVING expressions over an aggregate, keeping the behavior consistent with MySQL.
This commit is contained in:
@ -53,14 +53,14 @@ public class FunctionRegistry {
|
||||
// to record the global alias function and other udf.
|
||||
private static final String GLOBAL_FUNCTION = "__GLOBAL_FUNCTION__";
|
||||
|
||||
private final Map<String, List<FunctionBuilder>> name2InternalBuiltinBuilders;
|
||||
private final Map<String, List<FunctionBuilder>> name2BuiltinBuilders;
|
||||
private final Map<String, Map<String, List<FunctionBuilder>>> name2UdfBuilders;
|
||||
|
||||
public FunctionRegistry() {
|
||||
name2InternalBuiltinBuilders = new ConcurrentHashMap<>();
|
||||
name2BuiltinBuilders = new ConcurrentHashMap<>();
|
||||
name2UdfBuilders = new ConcurrentHashMap<>();
|
||||
registerBuiltinFunctions(name2InternalBuiltinBuilders);
|
||||
afterRegisterBuiltinFunctions(name2InternalBuiltinBuilders);
|
||||
registerBuiltinFunctions(name2BuiltinBuilders);
|
||||
afterRegisterBuiltinFunctions(name2BuiltinBuilders);
|
||||
}
|
||||
|
||||
// this function is used to test.
|
||||
@ -78,12 +78,33 @@ public class FunctionRegistry {
|
||||
}
|
||||
|
||||
public Optional<List<FunctionBuilder>> tryGetBuiltinBuilders(String name) {
|
||||
List<FunctionBuilder> builders = name2InternalBuiltinBuilders.get(name);
|
||||
return name2InternalBuiltinBuilders.get(name) == null
|
||||
List<FunctionBuilder> builders = name2BuiltinBuilders.get(name);
|
||||
return name2BuiltinBuilders.get(name) == null
|
||||
? Optional.empty()
|
||||
: Optional.of(ImmutableList.copyOf(builders));
|
||||
}
|
||||
|
||||
public boolean isAggregateFunction(String dbName, String name) {
|
||||
name = name.toLowerCase();
|
||||
Class<?> aggClass = org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction.class;
|
||||
if (StringUtils.isEmpty(dbName)) {
|
||||
List<FunctionBuilder> functionBuilders = name2BuiltinBuilders.get(name);
|
||||
for (FunctionBuilder functionBuilder : functionBuilders) {
|
||||
if (aggClass.isAssignableFrom(functionBuilder.functionClass())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
List<FunctionBuilder> udfBuilders = findUdfBuilder(dbName, name);
|
||||
for (FunctionBuilder udfBuilder : udfBuilders) {
|
||||
if (aggClass.isAssignableFrom(udfBuilder.functionClass())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// currently we only find function by name and arity and args' types.
|
||||
public FunctionBuilder findFunctionBuilder(String dbName, String name, List<?> arguments) {
|
||||
List<FunctionBuilder> functionBuilders = null;
|
||||
@ -92,11 +113,11 @@ public class FunctionRegistry {
|
||||
|
||||
if (StringUtils.isEmpty(dbName)) {
|
||||
// search internal function only if dbName is empty
|
||||
functionBuilders = name2InternalBuiltinBuilders.get(name.toLowerCase());
|
||||
functionBuilders = name2BuiltinBuilders.get(name.toLowerCase());
|
||||
if (CollectionUtils.isEmpty(functionBuilders) && AggCombinerFunctionBuilder.isAggStateCombinator(name)) {
|
||||
String nestedName = AggCombinerFunctionBuilder.getNestedName(name);
|
||||
String combinatorSuffix = AggCombinerFunctionBuilder.getCombinatorSuffix(name);
|
||||
functionBuilders = name2InternalBuiltinBuilders.get(nestedName.toLowerCase());
|
||||
functionBuilders = name2BuiltinBuilders.get(nestedName.toLowerCase());
|
||||
if (functionBuilders != null) {
|
||||
List<FunctionBuilder> candidateBuilders = Lists.newArrayListWithCapacity(functionBuilders.size());
|
||||
for (FunctionBuilder functionBuilder : functionBuilders) {
|
||||
@ -199,8 +220,8 @@ public class FunctionRegistry {
|
||||
}
|
||||
synchronized (name2UdfBuilders) {
|
||||
Map<String, List<FunctionBuilder>> builders = name2UdfBuilders.getOrDefault(dbName, ImmutableMap.of());
|
||||
builders.getOrDefault(name, Lists.newArrayList()).removeIf(builder -> ((UdfBuilder) builder).getArgTypes()
|
||||
.equals(argTypes));
|
||||
builders.getOrDefault(name, Lists.newArrayList())
|
||||
.removeIf(builder -> ((UdfBuilder) builder).getArgTypes().equals(argTypes));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,11 +19,13 @@ package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.catalog.FunctionRegistry;
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.NereidsPlanner;
|
||||
import org.apache.doris.nereids.StatementContext;
|
||||
import org.apache.doris.nereids.analyzer.MappingSlot;
|
||||
import org.apache.doris.nereids.analyzer.Scope;
|
||||
import org.apache.doris.nereids.analyzer.UnboundFunction;
|
||||
import org.apache.doris.nereids.analyzer.UnboundOneRowRelation;
|
||||
import org.apache.doris.nereids.analyzer.UnboundResultSink;
|
||||
import org.apache.doris.nereids.analyzer.UnboundSlot;
|
||||
@ -351,12 +353,12 @@ public class BindExpression implements AnalysisRuleFactory {
|
||||
CascadesContext cascadesContext = ctx.cascadesContext;
|
||||
|
||||
// bind slot by child.output first
|
||||
Scope defaultScope = toScope(cascadesContext, childPlan.getOutput());
|
||||
Scope childOutput = toScope(cascadesContext, childPlan.getOutput());
|
||||
// then bind slot by child.children.output
|
||||
Supplier<Scope> backupScope = Suppliers.memoize(() ->
|
||||
Supplier<Scope> childChildrenOutput = Suppliers.memoize(() ->
|
||||
toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(childPlan.children()))
|
||||
);
|
||||
return bindHavingByScopes(having, cascadesContext, defaultScope, backupScope);
|
||||
return bindHavingByScopes(having, cascadesContext, childOutput, childChildrenOutput);
|
||||
}
|
||||
|
||||
private LogicalHaving<Plan> bindHavingAggregate(
|
||||
@ -365,13 +367,115 @@ public class BindExpression implements AnalysisRuleFactory {
|
||||
Aggregate<Plan> aggregate = having.child();
|
||||
CascadesContext cascadesContext = ctx.cascadesContext;
|
||||
|
||||
// having(aggregate) should bind slot by aggregate.child.output first
|
||||
Scope defaultScope = toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(aggregate.children()));
|
||||
// then bind slot by aggregate.output
|
||||
Supplier<Scope> backupScope = Suppliers.memoize(() ->
|
||||
toScope(cascadesContext, aggregate.getOutput())
|
||||
);
|
||||
return bindHavingByScopes(ctx.root, ctx.cascadesContext, defaultScope, backupScope);
|
||||
// keep same behavior as mysql
|
||||
Supplier<CustomSlotBinderAnalyzer> bindByAggChild = Suppliers.memoize(() -> {
|
||||
Scope aggChildOutputScope
|
||||
= toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(aggregate.children()));
|
||||
return (analyzer, unboundSlot) -> analyzer.bindSlotByScope(unboundSlot, aggChildOutputScope);
|
||||
});
|
||||
|
||||
Scope aggOutputScope = toScope(cascadesContext, aggregate.getOutput());
|
||||
Supplier<CustomSlotBinderAnalyzer> bindByGroupByThenAggOutputThenAggChild = Suppliers.memoize(() -> {
|
||||
List<Expression> groupByExprs = aggregate.getGroupByExpressions();
|
||||
ImmutableList.Builder<Slot> groupBySlots
|
||||
= ImmutableList.builderWithExpectedSize(groupByExprs.size());
|
||||
for (Expression groupBy : groupByExprs) {
|
||||
if (groupBy instanceof Slot) {
|
||||
groupBySlots.add((Slot) groupBy);
|
||||
}
|
||||
}
|
||||
Scope groupBySlotsScope = toScope(cascadesContext, groupBySlots.build());
|
||||
|
||||
Supplier<Pair<Scope, Scope>> separateAggOutputScopes = Suppliers.memoize(() -> {
|
||||
ImmutableList.Builder<Slot> groupByOutputs = ImmutableList.builderWithExpectedSize(
|
||||
aggregate.getOutputExpressions().size());
|
||||
ImmutableList.Builder<Slot> aggFunOutputs = ImmutableList.builderWithExpectedSize(
|
||||
aggregate.getOutputExpressions().size());
|
||||
for (NamedExpression outputExpression : aggregate.getOutputExpressions()) {
|
||||
if (outputExpression.anyMatch(AggregateFunction.class::isInstance)) {
|
||||
aggFunOutputs.add(outputExpression.toSlot());
|
||||
} else {
|
||||
groupByOutputs.add(outputExpression.toSlot());
|
||||
}
|
||||
}
|
||||
Scope nonAggFunSlotsScope = toScope(cascadesContext, groupByOutputs.build());
|
||||
Scope aggFuncSlotsScope = toScope(cascadesContext, aggFunOutputs.build());
|
||||
return Pair.of(nonAggFunSlotsScope, aggFuncSlotsScope);
|
||||
});
|
||||
|
||||
return (analyzer, unboundSlot) -> {
|
||||
List<Slot> boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope);
|
||||
if (boundInGroupBy.size() == 1) {
|
||||
return boundInGroupBy;
|
||||
}
|
||||
|
||||
Pair<Scope, Scope> separateAggOutputScope = separateAggOutputScopes.get();
|
||||
List<Slot> boundInNonAggFuncs = analyzer.bindSlotByScope(unboundSlot, separateAggOutputScope.first);
|
||||
if (boundInNonAggFuncs.size() == 1) {
|
||||
return boundInNonAggFuncs;
|
||||
}
|
||||
|
||||
List<Slot> boundInAggFuncs = analyzer.bindSlotByScope(unboundSlot, separateAggOutputScope.second);
|
||||
if (boundInAggFuncs.size() == 1) {
|
||||
return boundInAggFuncs;
|
||||
}
|
||||
|
||||
return bindByAggChild.get().bindSlot(analyzer, unboundSlot);
|
||||
};
|
||||
});
|
||||
|
||||
FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry();
|
||||
ExpressionAnalyzer havingAnalyzer = new ExpressionAnalyzer(having, aggOutputScope, cascadesContext,
|
||||
false, true) {
|
||||
private boolean currentIsInAggregateFunction;
|
||||
|
||||
@Override
|
||||
public Expression visitAggregateFunction(AggregateFunction aggregateFunction,
|
||||
ExpressionRewriteContext context) {
|
||||
if (!currentIsInAggregateFunction) {
|
||||
currentIsInAggregateFunction = true;
|
||||
try {
|
||||
return super.visitAggregateFunction(aggregateFunction, context);
|
||||
} finally {
|
||||
currentIsInAggregateFunction = false;
|
||||
}
|
||||
} else {
|
||||
return super.visitAggregateFunction(aggregateFunction, context);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitUnboundFunction(UnboundFunction unboundFunction, ExpressionRewriteContext context) {
|
||||
if (!currentIsInAggregateFunction && isAggregateFunction(unboundFunction, functionRegistry)) {
|
||||
currentIsInAggregateFunction = true;
|
||||
try {
|
||||
return super.visitUnboundFunction(unboundFunction, context);
|
||||
} finally {
|
||||
currentIsInAggregateFunction = false;
|
||||
}
|
||||
} else {
|
||||
return super.visitUnboundFunction(unboundFunction, context);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<? extends Expression> bindSlotByThisScope(UnboundSlot unboundSlot) {
|
||||
if (currentIsInAggregateFunction) {
|
||||
return bindByAggChild.get().bindSlot(this, unboundSlot);
|
||||
} else {
|
||||
return bindByGroupByThenAggOutputThenAggChild.get().bindSlot(this, unboundSlot);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Set<Expression> havingExprs = having.getConjuncts();
|
||||
ImmutableSet.Builder<Expression> analyzedHaving = ImmutableSet.builderWithExpectedSize(havingExprs.size());
|
||||
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext);
|
||||
for (Expression expression : havingExprs) {
|
||||
analyzedHaving.add(havingAnalyzer.analyze(expression, rewriteContext));
|
||||
}
|
||||
|
||||
return new LogicalHaving<>(analyzedHaving.build(), having.child());
|
||||
}
|
||||
|
||||
private LogicalHaving<Plan> bindHavingByScopes(
|
||||
@ -764,6 +868,11 @@ public class BindExpression implements AnalysisRuleFactory {
|
||||
}
|
||||
}
|
||||
|
||||
private boolean isAggregateFunction(UnboundFunction unboundFunction, FunctionRegistry functionRegistry) {
|
||||
return functionRegistry.isAggregateFunction(
|
||||
unboundFunction.getDbName(), unboundFunction.getName());
|
||||
}
|
||||
|
||||
private <E extends Expression> E checkBoundExceptLambda(E expression, Plan plan) {
|
||||
if (expression instanceof Lambda) {
|
||||
return expression;
|
||||
@ -797,6 +906,12 @@ public class BindExpression implements AnalysisRuleFactory {
|
||||
boolean enableExactMatch, boolean bindSlotInOuterScope) {
|
||||
List<Slot> childrenOutputs = PlanUtils.fastGetChildrenOutputs(children);
|
||||
Scope scope = toScope(cascadesContext, childrenOutputs);
|
||||
return buildSimpleExprAnalyzer(currentPlan, cascadesContext, scope, enableExactMatch, bindSlotInOuterScope);
|
||||
}
|
||||
|
||||
private SimpleExprAnalyzer buildSimpleExprAnalyzer(
|
||||
Plan currentPlan, CascadesContext cascadesContext, Scope scope,
|
||||
boolean enableExactMatch, boolean bindSlotInOuterScope) {
|
||||
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext);
|
||||
ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(currentPlan,
|
||||
scope, cascadesContext, enableExactMatch, bindSlotInOuterScope);
|
||||
|
||||
@ -56,6 +56,11 @@ public class AggCombinerFunctionBuilder extends FunctionBuilder {
|
||||
this.nestedBuilder = Objects.requireNonNull(nestedBuilder, "nestedBuilder can not be null");
|
||||
}
|
||||
|
||||
/** Returns the function class reported by the nested builder that this combinator builder wraps. */
@Override
|
||||
public Class<? extends BoundFunction> functionClass() {
|
||||
return nestedBuilder.functionClass();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean canApply(List<? extends Object> arguments) {
|
||||
if (combinatorSuffix.equals(STATE) || combinatorSuffix.equals(FOREACH)) {
|
||||
|
||||
@ -42,13 +42,21 @@ public class BuiltinFunctionBuilder extends FunctionBuilder {
|
||||
|
||||
// Concrete BoundFunction's constructor
|
||||
private final Constructor<BoundFunction> builderMethod;
|
||||
private final Class<? extends BoundFunction> functionClass;
|
||||
|
||||
public BuiltinFunctionBuilder(Constructor<BoundFunction> builderMethod) {
|
||||
public BuiltinFunctionBuilder(
|
||||
Class<? extends BoundFunction> functionClass, Constructor<BoundFunction> builderMethod) {
|
||||
this.functionClass = Objects.requireNonNull(functionClass, "functionClass can not be null");
|
||||
this.builderMethod = Objects.requireNonNull(builderMethod, "builderMethod can not be null");
|
||||
this.arity = builderMethod.getParameterCount();
|
||||
this.isVariableLength = arity > 0 && builderMethod.getParameterTypes()[arity - 1].isArray();
|
||||
}
|
||||
|
||||
/** Returns the function class stored in this builder's {@code functionClass} field. */
@Override
|
||||
public Class<? extends BoundFunction> functionClass() {
|
||||
return functionClass;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean canApply(List<? extends Object> arguments) {
|
||||
if (isVariableLength && arity > arguments.size() + 1) {
|
||||
@ -133,7 +141,7 @@ public class BuiltinFunctionBuilder extends FunctionBuilder {
|
||||
+ functionClass.getSimpleName());
|
||||
return Arrays.stream(functionClass.getConstructors())
|
||||
.filter(constructor -> Modifier.isPublic(constructor.getModifiers()))
|
||||
.map(constructor -> new BuiltinFunctionBuilder((Constructor<BoundFunction>) constructor))
|
||||
.map(constructor -> new BuiltinFunctionBuilder(functionClass, (Constructor<BoundFunction>) constructor))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
}
|
||||
|
||||
@ -27,6 +27,8 @@ import java.util.List;
|
||||
* This class used to build BoundFunction(Builtin or Combinator) by a list of Expressions.
|
||||
*/
|
||||
public abstract class FunctionBuilder {
|
||||
/** the BoundFunction class associated with this builder, e.g. used to check whether it is an aggregate function */
public abstract Class<? extends BoundFunction> functionClass();
|
||||
|
||||
/** check whether arguments can apply to the constructor */
|
||||
public abstract boolean canApply(List<? extends Object> arguments);
|
||||
|
||||
|
||||
@ -50,6 +50,11 @@ public class AliasUdfBuilder extends UdfBuilder {
|
||||
return aliasUdf.getArgTypes();
|
||||
}
|
||||
|
||||
/** Always reports {@code AliasUdf} as the function class of this builder. */
@Override
|
||||
public Class<? extends BoundFunction> functionClass() {
|
||||
return AliasUdf.class;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean canApply(List<?> arguments) {
|
||||
if (arguments.size() != aliasUdf.arity()) {
|
||||
|
||||
@ -49,6 +49,11 @@ public class JavaUdafBuilder extends UdfBuilder {
|
||||
.collect(Collectors.toList())).get();
|
||||
}
|
||||
|
||||
/** Always reports {@code JavaUdaf} as the function class of this builder. */
@Override
|
||||
public Class<? extends BoundFunction> functionClass() {
|
||||
return JavaUdaf.class;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean canApply(List<?> arguments) {
|
||||
if ((isVarArgs && arity > arguments.size() + 1) || (!isVarArgs && arguments.size() != arity)) {
|
||||
|
||||
@ -51,6 +51,11 @@ public class JavaUdfBuilder extends UdfBuilder {
|
||||
.collect(Collectors.toList())).get();
|
||||
}
|
||||
|
||||
/** Always reports {@code JavaUdf} as the function class of this builder. */
@Override
|
||||
public Class<? extends BoundFunction> functionClass() {
|
||||
return JavaUdf.class;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean canApply(List<?> arguments) {
|
||||
if ((isVarArgs && arity > arguments.size() + 1) || (!isVarArgs && arguments.size() != arity)) {
|
||||
|
||||
Reference in New Issue
Block a user