diff --git a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisLexer.g4 b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisLexer.g4 index f076731864..b79b3836b4 100644 --- a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisLexer.g4 +++ b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisLexer.g4 @@ -94,6 +94,7 @@ AFTER: 'AFTER'; ALL: 'ALL'; ALTER: 'ALTER'; ANALYZE: 'ANALYZE'; +ANALYZED: 'ANALYZED'; AND: 'AND'; ANTI: 'ANTI'; ANY: 'ANY'; @@ -247,6 +248,7 @@ NULLS: 'NULLS'; OF: 'OF'; ON: 'ON'; ONLY: 'ONLY'; +OPTIMIZED: 'OPTIMIZED'; OPTION: 'OPTION'; OPTIONS: 'OPTIONS'; OR: 'OR'; @@ -258,13 +260,16 @@ OVER: 'OVER'; OVERLAPS: 'OVERLAPS'; OVERLAY: 'OVERLAY'; OVERWRITE: 'OVERWRITE'; +PARSED: 'PARSED'; PARTITION: 'PARTITION'; PARTITIONED: 'PARTITIONED'; PARTITIONS: 'PARTITIONS'; PERCENTILE_CONT: 'PERCENTILE_CONT'; PERCENTLIT: 'PERCENT'; +PHYSICAL: 'PHYSICAL'; PIVOT: 'PIVOT'; PLACING: 'PLACING'; +PLAN: 'PLAN'; POSITION: 'POSITION'; PRECEDING: 'PRECEDING'; PRIMARY: 'PRIMARY'; @@ -288,6 +293,7 @@ RESET: 'RESET'; RESPECT: 'RESPECT'; RESTRICT: 'RESTRICT'; REVOKE: 'REVOKE'; +REWRITTEN: 'REWRITTEN'; RIGHT: 'RIGHT'; // original optimizer only support REGEXP, the new optimizer should be consistent with it RLIKE: 'RLIKE'; diff --git a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 index fac0f3575a..4641926d4a 100644 --- a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 +++ b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 @@ -50,7 +50,16 @@ singleStatement statement : cte? query #statementDefault - | (EXPLAIN | DESC | DESCRIBE) level=(VERBOSE | GRAPH)? query #explain + | (EXPLAIN planType? | DESC | DESCRIBE) + level=(VERBOSE | GRAPH | PLAN)? query #explain + ; + +planType + : PARSED + | ANALYZED + | REWRITTEN | LOGICAL // same type + | OPTIMIZED | PHYSICAL // same type + | ALL // default type ; // -----------------Query----------------- @@ -335,6 +344,7 @@ ansiNonReserved | AFTER | ALTER | ANALYZE + | ANALYZED | ANTI | ARCHIVE | ARRAY @@ -441,6 +451,7 @@ ansiNonReserved | NO | NULLS | OF + | OPTIMIZED | OPTION | OPTIONS | OUT @@ -448,12 +459,15 @@ ansiNonReserved | OVER | OVERLAY | OVERWRITE + | PARSED | PARTITION | PARTITIONED | PARTITIONS | PERCENTLIT + | PHYSICAL | PIVOT | PLACING + | PLAN | POSITION | PRECEDING | PRINCIPALS @@ -474,6 +488,7 @@ ansiNonReserved | RESPECT | RESTRICT | REVOKE + | REWRITTEN | RLIKE | ROLE | ROLES @@ -575,6 +590,7 @@ nonReserved | ALL | ALTER | ANALYZE + | ANALYZED | AND | ANY | ARCHIVE @@ -716,6 +732,7 @@ nonReserved | NULLS | OF | ONLY + | OPTIMIZED | OPTION | OPTIONS | OR @@ -727,13 +744,16 @@ nonReserved | OVERLAPS | OVERLAY | OVERWRITE + | PARSED | PARTITION | PARTITIONED | PARTITIONS | PERCENTILE_CONT | PERCENTLIT + | PHYSICAL | PIVOT | PLACING + | PLAN | POSITION | PRECEDING | PRIMARY @@ -756,6 +776,7 @@ nonReserved | RESPECT | RESTRICT | REVOKE + | REWRITTEN | RLIKE | ROLE | ROLES diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/ExplainOptions.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/ExplainOptions.java index 543fd5bbae..14286f1de8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/ExplainOptions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/ExplainOptions.java @@ -17,21 +17,38 @@ package org.apache.doris.analysis; +import org.apache.doris.nereids.trees.plans.commands.ExplainCommand; +import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel; + public class ExplainOptions { private boolean isVerbose; private boolean isGraph; + private ExplainCommand.ExplainLevel explainLevel; + + public ExplainOptions(ExplainCommand.ExplainLevel explainLevel) { + this.explainLevel = explainLevel; + } + public ExplainOptions(boolean isVerbose, boolean isGraph) { this.isVerbose = isVerbose; this.isGraph = isGraph; } public boolean isVerbose() { - return isVerbose; + return explainLevel == ExplainLevel.VERBOSE || isVerbose; } public boolean isGraph() { - return isGraph; + return explainLevel == ExplainLevel.GRAPH || isGraph; + } + + public boolean hasExplainLevel() { + return explainLevel != null; + } + + public ExplainLevel getExplainLevel() { + return explainLevel; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index 0e618b6404..bf1869b277 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -220,16 +220,22 @@ public class FunctionCallExpr extends Expr { this.argTypesForNereids = argTypes; } - // nereids constructor without finalize/analyze - public FunctionCallExpr(FunctionName functionName, Function function, FunctionParams functionParams) { - this.fnName = functionName; + // nereids scalar function call expr constructor without finalize/analyze + public FunctionCallExpr(Function function, FunctionParams functionParams) { + this(function, functionParams, null, false, functionParams.exprs()); + } + + // nereids aggregate function call expr constructor without finalize/analyze + public FunctionCallExpr(Function function, FunctionParams functionParams, FunctionParams aggFnParams, + boolean isMergeAggFn, List children) { + this.fnName = function.getFunctionName(); this.fn = function; this.type = function.getReturnType(); this.fnParams = functionParams; - if (functionParams.exprs() != null) { - this.children.addAll(functionParams.exprs()); - } + this.aggFnParams = aggFnParams; + this.children.addAll(children); this.originChildSize = children.size(); + this.isMergeAggFn = isMergeAggFn; this.shouldFinalizeForNereids = false; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java index 3573d5620b..2c9db0fdad 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java @@ -84,6 +84,19 @@ public class SlotRef extends Expr { analysisDone(); } + // nerieds use this constructor to build aggFnParam + public SlotRef(Type type, boolean nullable) { + super(); + // tuple id and slot id is meaningless here, nereids just use type and nullable + // to build the TAggregateExpr.param_types + TupleDescriptor tupleDescriptor = new TupleDescriptor(new TupleId(-1)); + desc = new SlotDescriptor(new SlotId(-1), tupleDescriptor); + tupleDescriptor.addSlot(desc); + desc.setIsNullable(nullable); + desc.setType(type); + this.type = type; + } + protected SlotRef(SlotRef other) { super(other); tblName = other.tblName; diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java index 6ce6b37f7d..12363ed818 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java @@ -183,6 +183,30 @@ public class AggregateFunction extends Function { returnsNonNullOnEmpty = false; } + public AggregateFunction(FunctionName fnName, List argTypes, + Type retType, Type intermediateType, boolean hasVarArgs, + URI location, String updateFnSymbol, String initFnSymbol, + String serializeFnSymbol, String mergeFnSymbol, String getValueFnSymbol, + String removeFnSymbol, String finalizeFnSymbol, boolean ignoresDistinct, + boolean isAnalyticFn, boolean returnsNonNullOnEmpty, TFunctionBinaryType binaryType, + boolean userVisible, boolean vectorized, NullableMode nullableMode) { + // only `count` is always not nullable, other aggregate function is always nullable + super(0, fnName, argTypes, retType, hasVarArgs, binaryType, userVisible, vectorized, nullableMode); + setLocation(location); + this.intermediateType = (intermediateType.equals(retType)) ? null : intermediateType; + this.updateFnSymbol = updateFnSymbol; + this.initFnSymbol = initFnSymbol; + this.serializeFnSymbol = serializeFnSymbol; + this.mergeFnSymbol = mergeFnSymbol; + this.getValueFnSymbol = getValueFnSymbol; + this.removeFnSymbol = removeFnSymbol; + this.finalizeFnSymbol = finalizeFnSymbol; + this.ignoresDistinct = ignoresDistinct; + this.isAnalyticFn = isAnalyticFn; + this.isAggregateFn = true; + this.returnsNonNullOnEmpty = returnsNonNullOnEmpty; + } + public static AggregateFunction createBuiltin(String name, List argTypes, Type retType, Type intermediateType, String initFnSymbol, String updateFnSymbol, String mergeFnSymbol, diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSignature.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSignature.java index a1104a71e2..66de21153b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSignature.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSignature.java @@ -19,6 +19,7 @@ package org.apache.doris.catalog; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.coercion.AbstractDataType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; @@ -32,10 +33,10 @@ import java.util.function.BiFunction; public class FunctionSignature { public final DataType returnType; public final boolean hasVarArgs; - public final List argumentsTypes; + public final List argumentsTypes; public final int arity; - public FunctionSignature(DataType returnType, boolean hasVarArgs, List argumentsTypes) { + public FunctionSignature(DataType returnType, boolean hasVarArgs, List argumentsTypes) { this.returnType = Objects.requireNonNull(returnType, "returnType is not null"); this.argumentsTypes = ImmutableList.copyOf( Objects.requireNonNull(argumentsTypes, "argumentsTypes is not null")); @@ -43,11 +44,11 @@ public class FunctionSignature { this.arity = argumentsTypes.size(); } - public Optional getVarArgType() { + public Optional getVarArgType() { return hasVarArgs ? Optional.of(argumentsTypes.get(arity - 1)) : Optional.empty(); } - public DataType getArgType(int index) { + public AbstractDataType getArgType(int index) { if (hasVarArgs && index >= arity) { return argumentsTypes.get(arity - 1); } @@ -58,7 +59,7 @@ public class FunctionSignature { return new FunctionSignature(returnType, hasVarArgs, argumentsTypes); } - public FunctionSignature withArgumentTypes(boolean hasVarArgs, List argumentsTypes) { + public FunctionSignature withArgumentTypes(boolean hasVarArgs, List argumentsTypes) { return new FunctionSignature(returnType, hasVarArgs, argumentsTypes); } @@ -69,27 +70,27 @@ public class FunctionSignature { * @return */ public FunctionSignature withArgumentTypes(List arguments, - BiFunction transform) { - List newTypes = Lists.newArrayList(); + BiFunction transform) { + List newTypes = Lists.newArrayList(); for (int i = 0; i < arguments.size(); i++) { newTypes.add(transform.apply(getArgType(i), arguments.get(i))); } return withArgumentTypes(hasVarArgs, newTypes); } - public static FunctionSignature of(DataType returnType, List argumentsTypes) { + public static FunctionSignature of(DataType returnType, List argumentsTypes) { return of(returnType, false, argumentsTypes); } - public static FunctionSignature of(DataType returnType, boolean hasVarArgs, List argumentsTypes) { + public static FunctionSignature of(DataType returnType, boolean hasVarArgs, List argumentsTypes) { return new FunctionSignature(returnType, hasVarArgs, argumentsTypes); } - public static FunctionSignature of(DataType returnType, DataType... argumentsTypes) { + public static FunctionSignature of(DataType returnType, AbstractDataType... argumentsTypes) { return of(returnType, false, argumentsTypes); } - public static FunctionSignature of(DataType returnType, boolean hasVarArgs, DataType... argumentsTypes) { + public static FunctionSignature of(DataType returnType, boolean hasVarArgs, AbstractDataType... argumentsTypes) { return new FunctionSignature(returnType, hasVarArgs, Arrays.asList(argumentsTypes)); } @@ -104,11 +105,11 @@ public class FunctionSignature { this.returnType = returnType; } - public FunctionSignature args(DataType...argTypes) { + public FunctionSignature args(AbstractDataType...argTypes) { return FunctionSignature.of(returnType, false, argTypes); } - public FunctionSignature varArgs(DataType...argTypes) { + public FunctionSignature varArgs(AbstractDataType...argTypes) { return FunctionSignature.of(returnType, true, argTypes); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java index a04cc5aaa3..148faa8ecd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java @@ -74,6 +74,13 @@ public class ScalarFunction extends Function { NullableMode.DEPEND_ON_ARGUMENT); } + /** nerieds custom scalar function */ + public ScalarFunction(FunctionName fnName, List argTypes, Type retType, boolean hasVarArgs, String symbolName, + TFunctionBinaryType binaryType, boolean userVisible, boolean isVec, NullableMode nullableMode) { + super(0, fnName, argTypes, retType, hasVarArgs, binaryType, userVisible, isVec, nullableMode); + this.symbolName = symbolName; + } + public ScalarFunction(FunctionName fnName, List argTypes, Type retType, URI location, String symbolName, String initFnSymbol, String closeFnSymbol) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java index 8638b0217d..08157e5c62 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java @@ -18,9 +18,10 @@ package org.apache.doris.nereids; import org.apache.doris.analysis.DescriptorTable; +import org.apache.doris.analysis.ExplainOptions; import org.apache.doris.analysis.StatementBase; -import org.apache.doris.common.AnalysisException; -import org.apache.doris.common.UserException; +import org.apache.doris.common.NereidsException; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.glue.LogicalPlanAdapter; import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator; import org.apache.doris.nereids.glue.translator.PlanTranslatorContext; @@ -36,6 +37,7 @@ import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.rules.joinreorder.HyperGraphJoinReorder; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.planner.PlanFragment; @@ -64,18 +66,29 @@ public class NereidsPlanner extends Planner { private List scanNodeList = null; private DescriptorTable descTable; + private Plan parsedPlan; + private Plan analyzedPlan; + private Plan rewrittenPlan; + private Plan optimizedPlan; + public NereidsPlanner(StatementContext statementContext) { this.statementContext = statementContext; } @Override - public void plan(StatementBase queryStmt, org.apache.doris.thrift.TQueryOptions queryOptions) throws UserException { + public void plan(StatementBase queryStmt, org.apache.doris.thrift.TQueryOptions queryOptions) { if (!(queryStmt instanceof LogicalPlanAdapter)) { throw new RuntimeException("Wrong type of queryStmt, expected: "); } LogicalPlanAdapter logicalPlanAdapter = (LogicalPlanAdapter) queryStmt; - PhysicalPlan physicalPlan = plan(logicalPlanAdapter.getLogicalPlan(), PhysicalProperties.ANY); + ExplainLevel explainLevel = getExplainLevel(queryStmt.getExplainOptions()); + Plan resultPlan = plan(logicalPlanAdapter.getLogicalPlan(), PhysicalProperties.ANY, explainLevel); + if (explainLevel.isPlanLevel) { + return; + } + + PhysicalPlan physicalPlan = (PhysicalPlan) resultPlan; PhysicalPlanTranslator physicalPlanTranslator = new PhysicalPlanTranslator(); PlanTranslatorContext planTranslatorContext = new PlanTranslatorContext(cascadesContext); if (ConnectContext.get().getSessionVariable().isEnableNereidsTrace()) { @@ -103,20 +116,30 @@ public class NereidsPlanner extends Planner { public void plan(StatementBase queryStmt) { try { plan(queryStmt, statementContext.getConnectContext().getSessionVariable().toThrift()); - } catch (UserException e) { - throw new RuntimeException(e); + } catch (Exception e) { + throw new NereidsException(e); } } + public PhysicalPlan plan(LogicalPlan plan, PhysicalProperties outputProperties) { + return (PhysicalPlan) plan(plan, outputProperties, ExplainLevel.NONE); + } + /** * Do analyze and optimize for query plan. * * @param plan wait for plan * @param outputProperties physical properties constraints - * @return physical plan generated by this planner + * @return plan generated by this planner * @throws AnalysisException throw exception if failed in ant stage */ - public PhysicalPlan plan(LogicalPlan plan, PhysicalProperties outputProperties) throws AnalysisException { + public Plan plan(LogicalPlan plan, PhysicalProperties outputProperties, ExplainLevel explainLevel) { + if (explainLevel == ExplainLevel.PARSED_PLAN || explainLevel == ExplainLevel.ALL_PLAN) { + parsedPlan = plan; + if (explainLevel == ExplainLevel.PARSED_PLAN) { + return parsedPlan; + } + } // pre-process logical plan out of memo, e.g. process SET_VAR hint plan = preprocess(plan); @@ -125,9 +148,21 @@ public class NereidsPlanner extends Planner { // resolve column, table and function analyze(); + if (explainLevel == ExplainLevel.ANALYZED_PLAN || explainLevel == ExplainLevel.ALL_PLAN) { + analyzedPlan = cascadesContext.getMemo().copyOut(false); + if (explainLevel == ExplainLevel.ANALYZED_PLAN) { + return analyzedPlan; + } + } // rule-based optimize rewrite(); + if (explainLevel == ExplainLevel.REWRITTEN_PLAN || explainLevel == ExplainLevel.ALL_PLAN) { + rewrittenPlan = cascadesContext.getMemo().copyOut(false); + if (explainLevel == ExplainLevel.REWRITTEN_PLAN) { + return rewrittenPlan; + } + } deriveStats(); @@ -142,7 +177,12 @@ public class NereidsPlanner extends Planner { PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), PhysicalProperties.ANY); // post-process physical plan out of memo, just for future use. - return postProcess(physicalPlan); + physicalPlan = postProcess(physicalPlan); + if (explainLevel == ExplainLevel.OPTIMIZED_PLAN || explainLevel == ExplainLevel.ALL_PLAN) { + optimizedPlan = physicalPlan; + } + + return physicalPlan; } private LogicalPlan preprocess(LogicalPlan logicalPlan) { @@ -229,6 +269,33 @@ public class NereidsPlanner extends Planner { } } + @Override + public String getExplainString(ExplainOptions explainOptions) { + ExplainLevel explainLevel = getExplainLevel(explainOptions); + switch (explainLevel) { + case PARSED_PLAN: + return parsedPlan.treeString(); + case ANALYZED_PLAN: + return analyzedPlan.treeString(); + case REWRITTEN_PLAN: + return rewrittenPlan.treeString(); + case OPTIMIZED_PLAN: + return optimizedPlan.treeString(); + case ALL_PLAN: + String explainString = "========== PARSED PLAN ==========\n" + + parsedPlan.treeString() + "\n\n" + + "========== ANALYZED PLAN ==========\n" + + analyzedPlan.treeString() + "\n\n" + + "========== REWRITTEN PLAN ==========\n" + + rewrittenPlan.treeString() + "\n\n" + + "========== OPTIMIZED PLAN ==========\n" + + optimizedPlan.treeString(); + return explainString; + default: + return super.getExplainString(explainOptions); + } + } + @Override public boolean isBlockQuery() { return true; @@ -253,4 +320,12 @@ public class NereidsPlanner extends Planner { public CascadesContext getCascadesContext() { return cascadesContext; } + + private ExplainLevel getExplainLevel(ExplainOptions explainOptions) { + if (explainOptions == null) { + return ExplainLevel.NONE; + } + ExplainLevel explainLevel = explainOptions.getExplainLevel(); + return explainLevel == null ? ExplainLevel.NONE : explainLevel; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java index db4bfa6b4e..8b04dc8f06 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java @@ -43,6 +43,8 @@ public class StatementContext { private CTEContext cteContext; public StatementContext() { + this.connectContext = ConnectContext.get(); + this.cteContext = new CTEContext(); } public StatementContext(ConnectContext connectContext, OriginStatement originStatement) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java index c6f8589515..15e92af4d5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java @@ -33,10 +33,8 @@ import org.apache.doris.analysis.FunctionParams; import org.apache.doris.analysis.LikePredicate; import org.apache.doris.analysis.SlotRef; import org.apache.doris.analysis.TimestampArithmeticExpr; -import org.apache.doris.catalog.Function; import org.apache.doris.catalog.Function.NullableMode; import org.apache.doris.catalog.Type; -import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.And; @@ -67,8 +65,10 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; -import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.coercion.AbstractDataType; +import org.apache.doris.thrift.TFunctionBinaryType; + +import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.List; @@ -256,68 +256,56 @@ public class ExpressionTranslator extends DefaultExpressionVisitor> inputTypesBeforeDissemble = function.inputTypesBeforeDissemble() - .map(types -> types.stream() - .map(DataType::toCatalogDataType) - .collect(Collectors.toList()) - ); + List catalogArguments = function.getArguments() + .stream() + .map(arg -> arg.accept(this, context)) + .collect(ImmutableList.toImmutableList()); - // We should change the global aggregate function's temporary input type(varchar) to the origin input type. - // - // For example: the global aggregate function expression 'avg(slotRef(type=varchar))' of the origin function - // 'avg(int)' should change to 'avg(slotRef(type=int))', because FunctionCallExpr will be converted to thrift - // format, and compute signature string by the children's type, we should pass the signature 'avg(int)' to - // **backend**. If we pass 'avg(varchar)' to backend, it will throw an exception: 'Agg Function avg is not - // implemented'. - List catalogParams = new ArrayList<>(); - for (int i = 0; i < function.arity(); i++) { - Expr catalogExpr = function.child(i).accept(this, context); - if (catalogExpr instanceof SlotRef && inputTypesBeforeDissemble.isPresent() - // count(*) in local aggregate contains empty children - // but contains one child in global aggregate: 'count(count(*))'. - // so the size of inputTypesBeforeDissemble maybe less than global aggregate param. - && inputTypesBeforeDissemble.get().size() > i) { - SlotRef intermediateSlot = (SlotRef) catalogExpr.clone(); - // change the slot type to origin input type - intermediateSlot.setType(inputTypesBeforeDissemble.get().get(i)); - catalogExpr = intermediateSlot; - } - catalogParams.add(catalogExpr); + // aggFnArguments is used to build TAggregateExpr.param_types, so backend can find the aggregate function + List aggFnArguments = function.getArgumentsBeforeDisassembled() + .stream() + .map(arg -> new SlotRef(arg.getDataType().toCatalogDataType(), arg.nullable())) + .collect(ImmutableList.toImmutableList()); + + FunctionParams aggFnParams; + if (function instanceof Count && ((Count) function).isStar()) { + aggFnParams = FunctionParams.createStarParam(); + } else { + aggFnParams = new FunctionParams(function.isDistinct(), aggFnArguments); } - boolean distinct = function.isDistinct(); - FunctionParams aggFnParams = new FunctionParams(distinct, catalogParams); + ImmutableList argTypes = catalogArguments.stream() + .map(arg -> arg.getType()) + .collect(ImmutableList.toImmutableList()); - if (function instanceof Count) { - Count count = (Count) function; - if (count.isStar()) { - return new FunctionCallExpr(function.getName(), FunctionParams.createStarParam(), - aggFnParams, inputTypesBeforeDissemble); - } else if (count.isDistinct()) { - return new FunctionCallExpr(function.getName(), new FunctionParams(distinct, catalogParams), - aggFnParams, inputTypesBeforeDissemble); - } - } - return new FunctionCallExpr(function.getName(), new FunctionParams(distinct, catalogParams), - aggFnParams, inputTypesBeforeDissemble); + NullableMode nullableMode = function.nullable() + ? NullableMode.ALWAYS_NULLABLE + : NullableMode.ALWAYS_NOT_NULLABLE; + + boolean isAnalyticFunction = false; + org.apache.doris.catalog.AggregateFunction catalogFunction = new org.apache.doris.catalog.AggregateFunction( + new FunctionName(function.getName()), argTypes, + function.getDataType().toCatalogDataType(), + function.getIntermediateTypes().toCatalogDataType(), + function.hasVarArguments(), + null, "", "", null, "", + null, "", null, false, + isAnalyticFunction, false, TFunctionBinaryType.BUILTIN, + true, true, nullableMode + ); + + boolean isMergeFn = function.isGlobal() && function.isDisassembled(); + + // create catalog FunctionCallExpr without analyze again + return new FunctionCallExpr(catalogFunction, aggFnParams, aggFnParams, isMergeFn, catalogArguments); } @Override public Expr visitScalarFunction(ScalarFunction function, PlanTranslatorContext context) { List arguments = function.getArguments() - .stream().map(arg -> arg.accept(this, context)) + .stream() + .map(arg -> arg.accept(this, context)) .collect(Collectors.toList()); List argTypes = function.expectedInputTypes().stream() .map(AbstractDataType::toCatalogDataType) @@ -327,12 +315,13 @@ public class ExpressionTranslator extends DefaultExpressionVisitor { List> logicalPlans = Lists.newArrayList(); for (org.apache.doris.nereids.DorisParser.StatementContext statement : ctx.statement()) { StatementContext statementContext = new StatementContext(); - if (ConnectContext.get() != null) { - ConnectContext.get().setStatementContext(statementContext); + ConnectContext connectContext = ConnectContext.get(); + if (connectContext != null) { + connectContext.setStatementContext(statementContext); + statementContext.setConnectContext(connectContext); } logicalPlans.add(Pair.of( ParserUtils.withOrigin(ctx, () -> (LogicalPlan) visit(statement)), statementContext)); @@ -258,12 +261,25 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { @Override public Command visitExplain(ExplainContext ctx) { - LogicalPlan logicalPlan = plan(ctx.query()); - ExplainLevel explainLevel = ExplainLevel.NORMAL; - if (ctx.level != null) { - explainLevel = ExplainLevel.valueOf(ctx.level.getText().toUpperCase(Locale.ROOT)); - } - return new ExplainCommand(explainLevel, logicalPlan); + return ParserUtils.withOrigin(ctx, () -> { + LogicalPlan logicalPlan = plan(ctx.query()); + ExplainLevel explainLevel = ExplainLevel.NORMAL; + + if (ctx.planType() != null) { + if (ctx.level == null || !ctx.level.getText().equalsIgnoreCase("plan")) { + throw new ParseException("Only explain plan can use plan type: " + ctx.planType().getText(), ctx); + } + } + + if (ctx.level != null) { + if (!ctx.level.getText().equalsIgnoreCase("plan")) { + explainLevel = ExplainLevel.valueOf(ctx.level.getText().toUpperCase(Locale.ROOT)); + } else { + explainLevel = parseExplainPlanType(ctx.planType()); + } + } + return new ExplainCommand(explainLevel, logicalPlan); + }); } @Override @@ -1119,4 +1135,23 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { } return item.getText(); } + + private ExplainLevel parseExplainPlanType(PlanTypeContext planTypeContext) { + if (planTypeContext == null || planTypeContext.ALL() != null) { + return ExplainLevel.ALL_PLAN; + } + if (planTypeContext.PHYSICAL() != null || planTypeContext.OPTIMIZED() != null) { + return ExplainLevel.OPTIMIZED_PLAN; + } + if (planTypeContext.REWRITTEN() != null || planTypeContext.LOGICAL() != null) { + return ExplainLevel.REWRITTEN_PLAN; + } + if (planTypeContext.ANALYZED() != null) { + return ExplainLevel.ANALYZED_PLAN; + } + if (planTypeContext.PARSED() != null) { + return ExplainLevel.PARSED_PLAN; + } + return ExplainLevel.ALL_PLAN; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/NereidsParser.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/NereidsParser.java index 22476d361d..72c8c9368e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/NereidsParser.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/NereidsParser.java @@ -53,19 +53,20 @@ public class NereidsParser { public List parseSQL(String originStr) { List> logicalPlans = parseMultiple(originStr); List statementBases = Lists.newArrayList(); - for (Pair logicalPlan : logicalPlans) { + for (Pair parsedPlanToContext : logicalPlans) { // TODO: this is a trick to support explain. Since we do not support any other command in a short time. // It is acceptable. In the future, we need to refactor this. - if (logicalPlan.first instanceof ExplainCommand) { - ExplainCommand explainCommand = (ExplainCommand) logicalPlan.first; + StatementContext statementContext = parsedPlanToContext.second; + if (parsedPlanToContext.first instanceof ExplainCommand) { + ExplainCommand explainCommand = (ExplainCommand) parsedPlanToContext.first; LogicalPlan innerPlan = explainCommand.getLogicalPlan(); - LogicalPlanAdapter logicalPlanAdapter = new LogicalPlanAdapter(innerPlan, logicalPlan.second); - logicalPlanAdapter.setIsExplain(new ExplainOptions( - explainCommand.getLevel() == ExplainLevel.VERBOSE, - explainCommand.getLevel() == ExplainLevel.GRAPH)); + LogicalPlanAdapter logicalPlanAdapter = new LogicalPlanAdapter(innerPlan, statementContext); + ExplainLevel explainLevel = explainCommand.getLevel(); + ExplainOptions explainOptions = new ExplainOptions(explainLevel); + logicalPlanAdapter.setIsExplain(explainOptions); statementBases.add(logicalPlanAdapter); } else { - statementBases.add(new LogicalPlanAdapter(logicalPlan.first, logicalPlan.second)); + statementBases.add(new LogicalPlanAdapter(parsedPlanToContext.first, statementContext)); } } return statementBases; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java index 032bf7a744..474e77ee2f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java @@ -170,8 +170,8 @@ public class BindFunction implements AnalysisRuleFactory { return new Count(); } if (arguments.size() == 1) { - boolean isGlobalAgg = true; - AggregateParam aggregateParam = new AggregateParam(unboundFunction.isDistinct(), isGlobalAgg); + AggregateParam aggregateParam = new AggregateParam( + unboundFunction.isDistinct(), true, false); return new Count(aggregateParam, unboundFunction.getArguments().get(0)); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java index 4a632e188b..8fc151b271 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java @@ -28,7 +28,6 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio import org.apache.doris.nereids.trees.plans.AggPhase; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.base.Preconditions; @@ -191,20 +190,15 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { AggregateFunction localAggregateFunction = aggregateFunction.withAggregateParam( aggregateFunction.getAggregateParam() .withDistinct(false) - .withGlobal(false) + .withGlobalAndDisassembled(false, true) ); NamedExpression localOutputExpr = new Alias(localAggregateFunction, aggregateFunction.toSql()); - List inputTypesBeforeDissemble = aggregateFunction.children() - .stream() - .map(Expression::getDataType) - .collect(Collectors.toList()); AggregateFunction substitutionValue = aggregateFunction // save the origin input types to the global aggregate functions .withAggregateParam(aggregateFunction.getAggregateParam() .withDistinct(false) - .withGlobal(true) - .withInputTypesBeforeDissemble(Optional.of(inputTypesBeforeDissemble))) + .withGlobalAndDisassembled(true, true)) .withChildren(Lists.newArrayList(localOutputExpr.toSlot())); inputSubstitutionMap.put(aggregateFunction, substitutionValue); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java index 321a334e39..32f30a9944 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java @@ -36,7 +36,7 @@ public class Cast extends Expression implements UnaryExpression { public Cast(Expression child, DataType targetType) { super(child); - this.targetType = targetType; + this.targetType = Objects.requireNonNull(targetType, "targetType can not be null"); } @Override @@ -67,7 +67,7 @@ public class Cast extends Expression implements UnaryExpression { @Override public String toString() { - return toSql(); + return "CAST(" + child() + " AS " + targetType + ")"; } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TVFProperties.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TVFProperties.java index a6db2f5780..78992add63 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TVFProperties.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TVFProperties.java @@ -60,6 +60,6 @@ public class TVFProperties extends Expression implements LeafExpression { @Override public String toString() { - return "KeyValuesExpression(" + toSql() + ")"; + return "TVFProperties(" + toSql() + ")"; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java index fedb42825c..23352c2655 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java @@ -17,18 +17,44 @@ package org.apache.doris.nereids.trees.expressions.functions; +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.catalog.ScalarType; +import org.apache.doris.catalog.Type; +import org.apache.doris.common.Config; +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.exceptions.UnboundException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DateTimeV2Type; +import org.apache.doris.nereids.types.DecimalV2Type; +import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.coercion.AbstractDataType; +import org.apache.doris.nereids.util.ResponsibilityChain; + +import com.google.common.base.Suppliers; import java.util.List; import java.util.Objects; +import java.util.function.BiFunction; +import java.util.function.Supplier; import java.util.stream.Collectors; /** BoundFunction. */ -public abstract class BoundFunction extends Expression implements FunctionTrait { +public abstract class BoundFunction extends Expression implements FunctionTrait, ComputeSignature { private final String name; + private final Supplier signatureCache = Suppliers.memoize(() -> { + // first step: find the candidate signature in the signature list + List originArguments = getOriginArguments(); + FunctionSignature matchedSignature = searchSignature( + getOriginArgumentTypes(), originArguments, getSignatures()); + // second step: change the signature, e.g. fill precision for decimal v2 + return computeSignature(matchedSignature, originArguments); + }); + public BoundFunction(String name, Expression... arguments) { super(arguments); this.name = Objects.requireNonNull(name, "name can not be null"); @@ -39,10 +65,40 @@ public abstract class BoundFunction extends Expression implements FunctionTrait this.name = Objects.requireNonNull(name, "name can not be null"); } + protected FunctionSignature computeSignature(FunctionSignature signature, List arguments) { + // NOTE: + // this computed chain only process the common cases. + // If you want to add some common cases to here, please separate the process code + // to the other methods and add to this chain. + // If you want to add some special cases, please override this method in the special + // function class, like 'If' function and 'Substring' function. + return ComputeSignatureChain.from(signature, arguments) + .then(this::computePrecisionForDatetimeV2) + .then(this::upgradeDateOrDateTimeToV2) + .then(this::upgradeDecimalV2ToV3) + .then(this::computePrecisionForDecimal) + .then(this::dynamicComputePropertiesOfArray) + .get(); + } + public String getName() { return name; } + public FunctionSignature getSignature() { + return signatureCache.get(); + } + + @Override + public List expectedInputTypes() { + return ComputeSignature.super.expectedInputTypes(); + } + + @Override + public DataType getDataType() { + return ComputeSignature.super.getDataType(); + } + @Override public R accept(ExpressionVisitor visitor, C context) { return visitor.visitBoundFunction(this, context); @@ -82,4 +138,114 @@ public abstract class BoundFunction extends Expression implements FunctionTrait .collect(Collectors.joining(", ")); return name + "(" + args + ")"; } + + private FunctionSignature computePrecisionForDatetimeV2( + FunctionSignature signature, List arguments) { + + // fill for arguments type + signature = signature.withArgumentTypes(arguments, (sigArgType, realArgType) -> { + if (sigArgType instanceof DateTimeV2Type && realArgType.getDataType() instanceof DateTimeV2Type) { + return realArgType.getDataType(); + } + return sigArgType; + }); + + // fill for return type + if (signature.returnType instanceof DateTimeV2Type) { + Integer maxScale = signature.argumentsTypes.stream() + .filter(DateTimeV2Type.class::isInstance) + .map(t -> ((DateTimeV2Type) t).getScale()) + .reduce(Math::max) + .orElse(((DateTimeV2Type) signature.returnType).getScale()); + signature = signature.withReturnType(DateTimeV2Type.of(maxScale)); + } + + return signature; + } + + private FunctionSignature upgradeDateOrDateTimeToV2( + FunctionSignature signature, List arguments) { + DataType returnType = signature.returnType; + Type type = returnType.toCatalogDataType(); + if ((type.isDate() || type.isDatetime()) && Config.enable_date_conversion) { + Type legacyReturnType = ScalarType.getDefaultDateType(returnType.toCatalogDataType()); + signature = signature.withReturnType(DataType.fromCatalogType(legacyReturnType)); + } + return signature; + } + + @Developing + private FunctionSignature computePrecisionForDecimal( + FunctionSignature signature, List arguments) { + if (signature.returnType instanceof DecimalV3Type || signature.returnType instanceof DecimalV2Type) { + if (this instanceof DecimalSamePrecision) { + signature = signature.withReturnType(arguments.get(0).getDataType()); + } else if (this instanceof DecimalWiderPrecision) { + ScalarType widerType = ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION, + ((ScalarType) arguments.get(0).getDataType().toCatalogDataType()).getScalarScale()); + signature = signature.withReturnType(DataType.fromCatalogType(widerType)); + } else if (this instanceof DecimalStddevPrecision) { + // for all stddev function, use decimal(38,9) as computing result + ScalarType stddevDecimalType = ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION, + DecimalStddevPrecision.STDDEV_DECIMAL_SCALE); + signature = signature.withReturnType(DataType.fromCatalogType(stddevDecimalType)); + } + } + + return signature; + } + + private FunctionSignature upgradeDecimalV2ToV3( + FunctionSignature signature, List arguments) { + DataType returnType = signature.returnType; + Type type = returnType.toCatalogDataType(); + if (type.isDecimalV2() && Config.enable_decimal_conversion && Config.enable_decimalv3) { + Type v3Type = ScalarType.createDecimalV3Type(type.getPrecision(), ((ScalarType) type).getScalarScale()); + signature = signature.withReturnType(DataType.fromCatalogType(v3Type)); + } + return signature; + } + + private FunctionSignature dynamicComputePropertiesOfArray( + FunctionSignature signature, List arguments) { + if (!(signature.returnType instanceof ArrayType)) { + return signature; + } + + // TODO: compute array(...) function's itemType + + // fill item type by the type of first item + ArrayType arrayType = (ArrayType) signature.returnType; + + // fill containsNull if any array argument contains null + boolean containsNull = signature.argumentsTypes + .stream() + .filter(argType -> argType instanceof ArrayType) + .map(ArrayType.class::cast) + .anyMatch(ArrayType::containsNull); + return signature.withReturnType( + ArrayType.of(arrayType.getItemType(), arrayType.containsNull() || containsNull)); + } + + static class ComputeSignatureChain { + private ResponsibilityChain>> computeChain; + + public ComputeSignatureChain(ResponsibilityChain>> computeChain) { + this.computeChain = computeChain; + } + + public static ComputeSignatureChain from(FunctionSignature signature, List arguments) { + return new ComputeSignatureChain(ResponsibilityChain.from(Pair.of(signature, arguments))); + } + + public ComputeSignatureChain then( + BiFunction, FunctionSignature> computeFunction) { + computeChain.then(pair -> Pair.of(computeFunction.apply(pair.first, pair.second), pair.second)); + return this; + } + + public FunctionSignature get() { + return computeChain.get().first; + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignature.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignature.java index 7416b56910..7448d00f31 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignature.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignature.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions.functions; import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.annotation.Developing; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.coercion.AbstractDataType; @@ -47,7 +48,8 @@ public interface ComputeSignature extends FunctionTrait, ImplicitCastInputTypes * * @return the matched signature */ - FunctionSignature searchSignature(); + FunctionSignature searchSignature(List argumentTypes, List arguments, + List signatures); ///// re-defined other interface's methods, so we can mixin this interfaces like a trait ///// @@ -64,7 +66,7 @@ public interface ComputeSignature extends FunctionTrait, ImplicitCastInputTypes */ @Override default List expectedInputTypes() { - return (List) getSignature().argumentsTypes; + return getSignature().argumentsTypes; } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/CustomSignature.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/CustomSignature.java new file mode 100644 index 0000000000..2c5020fd51 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/CustomSignature.java @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.types.DataType; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** CustomSignature */ +public interface CustomSignature extends ComputeSignature { + + // custom generate a function signature. + FunctionSignature customSignature(List argumentTypes, List arguments); + + @Override + default List getSignatures() { + List originArgumentTypes = getOriginArgumentTypes(); + List originArguments = getOriginArguments(); + return ImmutableList.of(customSignature(originArgumentTypes, originArguments)); + } + + // use the first signature as the candidate signature. + @Override + default FunctionSignature searchSignature(List argumentTypes, List arguments, + List signatures) { + return signatures.get(0); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/DateTimeWithPrecision.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/DateTimeWithPrecision.java index bc42728c0f..2726888da4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/DateTimeWithPrecision.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/DateTimeWithPrecision.java @@ -41,17 +41,17 @@ public abstract class DateTimeWithPrecision extends ScalarFunction { } @Override - protected FunctionSignature computeSignature(FunctionSignature signature) { + protected FunctionSignature computeSignature(FunctionSignature signature, List arguments) { if (arity() == 1 && signature.returnType instanceof DateTimeV2Type) { // For functions in TIME_FUNCTIONS_WITH_PRECISION, we can't figure out which function should be use when // searching in FunctionSet. So we adjust the return type by hand here. - if (child(0) instanceof IntegerLikeLiteral) { - IntegerLikeLiteral integerLikeLiteral = (IntegerLikeLiteral) child(0); + if (arguments.get(0) instanceof IntegerLikeLiteral) { + IntegerLikeLiteral integerLikeLiteral = (IntegerLikeLiteral) arguments.get(0); signature = signature.withReturnType(DateTimeV2Type.of(integerLikeLiteral.getIntValue())); } else { signature = signature.withReturnType(DateTimeV2Type.of(6)); } } - return super.computeSignature(signature); + return super.computeSignature(signature, arguments); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExplicitlyCastableSignature.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExplicitlyCastableSignature.java index b728d0320c..2ddaec4d01 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExplicitlyCastableSignature.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExplicitlyCastableSignature.java @@ -19,8 +19,12 @@ package org.apache.doris.nereids.trees.expressions.functions; import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.catalog.Type; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.coercion.AbstractDataType; +import java.util.List; + /** * Explicitly castable signature. This class equals to 'CompareMode.IS_NONSTRICT_SUPERTYPE_OF'. * @@ -34,8 +38,9 @@ public interface ExplicitlyCastableSignature extends ComputeSignature { } @Override - default FunctionSignature searchSignature() { - return SearchSignature.from(getSignatures(), getArguments()) + default FunctionSignature searchSignature(List argumentTypes, List arguments, + List signatures) { + return SearchSignature.from(signatures, arguments) // first round, use identical strategy to find signature .orElseSearch(IdenticalSignature::isIdentical) // second round: if not found, use nullOrIdentical strategy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExpressionTrait.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExpressionTrait.java index b936544650..ac72ac3db4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExpressionTrait.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExpressionTrait.java @@ -39,6 +39,10 @@ public interface ExpressionTrait extends TreeNode { return children(); } + default Expression getArgument(int index) { + return child(index); + } + default DataType getDataType() throws UnboundException { throw new UnboundException("dataType"); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionTrait.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionTrait.java index 152f0f777a..68c1a557d3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionTrait.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionTrait.java @@ -17,6 +17,13 @@ package org.apache.doris.nereids.trees.expressions.functions; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.types.DataType; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * FunctionTrait. */ @@ -24,4 +31,15 @@ public interface FunctionTrait extends ExpressionTrait { String getName(); boolean hasVarArguments(); + + default List getOriginArguments() { + return getArguments(); + } + + default List getOriginArgumentTypes() { + return getArguments() + .stream() + .map(Expression::getDataType) + .collect(ImmutableList.toImmutableList()); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/IdenticalSignature.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/IdenticalSignature.java index 2beab26027..e5018be876 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/IdenticalSignature.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/IdenticalSignature.java @@ -18,8 +18,12 @@ package org.apache.doris.nereids.trees.expressions.functions; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.coercion.AbstractDataType; +import java.util.List; + /** * Identical function signature. This class equals to 'CompareMode.IS_IDENTICAL'. * @@ -33,8 +37,9 @@ public interface IdenticalSignature extends ComputeSignature { } @Override - default FunctionSignature searchSignature() { - return SearchSignature.from(getSignatures(), getArguments()) + default FunctionSignature searchSignature(List argumentTypes, List arguments, + List signatures) { + return SearchSignature.from(signatures, arguments) // first round, use identical strategy to find signature .orElseSearch(IdenticalSignature::isIdentical) .resultOrException(getName()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ImplicitlyCastableSignature.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ImplicitlyCastableSignature.java index 3c830ef31b..77cfd54d08 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ImplicitlyCastableSignature.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ImplicitlyCastableSignature.java @@ -19,8 +19,12 @@ package org.apache.doris.nereids.trees.expressions.functions; import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.catalog.Type; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.coercion.AbstractDataType; +import java.util.List; + /** * Implicitly castable function signature. This class equals to 'CompareMode.IS_SUPERTYPE_OF'. * @@ -34,8 +38,9 @@ public interface ImplicitlyCastableSignature extends ComputeSignature { } @Override - default FunctionSignature searchSignature() { - return SearchSignature.from(getSignatures(), getArguments()) + default FunctionSignature searchSignature(List argumentTypes, List arguments, + List signatures) { + return SearchSignature.from(signatures, arguments) // first round, use identical strategy to find signature .orElseSearch(IdenticalSignature::isIdentical) // second round: if not found, use nullOrIdentical strategy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/NullOrIdenticalSignature.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/NullOrIdenticalSignature.java index 164d175729..5f1dd0ee1a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/NullOrIdenticalSignature.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/NullOrIdenticalSignature.java @@ -18,9 +18,13 @@ package org.apache.doris.nereids.trees.expressions.functions; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.NullType; import org.apache.doris.nereids.types.coercion.AbstractDataType; +import java.util.List; + /** * Null or identical function signature. This class equals to 'CompareMode.IS_INDISTINGUISHABLE'. * @@ -35,8 +39,9 @@ public interface NullOrIdenticalSignature extends ComputeSignature { } @Override - default FunctionSignature searchSignature() { - return SearchSignature.from(getSignatures(), getArguments()) + default FunctionSignature searchSignature(List argumentTypes, List arguments, + List signatures) { + return SearchSignature.from(signatures, arguments) // first round, use identical strategy to find signature .orElseSearch(IdenticalSignature::isIdentical) // second round: if not found, use nullOrIdentical strategy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java index f01e1839cd..80c799391b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java @@ -19,17 +19,21 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; +import org.apache.doris.nereids.trees.expressions.typecoercion.ExpectsInputTypes; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.PartialAggType; +import org.apache.doris.nereids.types.coercion.AbstractDataType; + +import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Objects; -import java.util.Optional; /** * The function which consume arguments in lots of rows and product one value. */ -public abstract class AggregateFunction extends BoundFunction { +public abstract class AggregateFunction extends BoundFunction implements ExpectsInputTypes { private final AggregateParam aggregateParam; @@ -42,15 +46,69 @@ public abstract class AggregateFunction extends BoundFunction { this.aggregateParam = Objects.requireNonNull(aggregateParam, "aggregateParam can not be null"); } + @Override + public List getOriginArguments() { + return getArgumentsBeforeDisassembled(); + } + + @Override + public List getOriginArgumentTypes() { + return getArgumentTypesBeforeDisassembled(); + } + @Override public abstract AggregateFunction withChildren(List children); - public abstract DataType getFinalType(); - - public abstract DataType getIntermediateType(); - public abstract AggregateFunction withAggregateParam(AggregateParam aggregateParam); + protected abstract List intermediateTypes(List argumentTypes, List arguments); + + /** getIntermediateTypes */ + public final PartialAggType getIntermediateTypes() { + if (isGlobal() && isDisassembled()) { + return (PartialAggType) child(0).getDataType(); + } + List arguments = getArgumentsBeforeDisassembled(); + List types = getArgumentTypesBeforeDisassembled(); + return new PartialAggType(getArguments(), intermediateTypes(types, arguments)); + } + + public final DataType getFinalType() { + return getSignature().returnType; + } + + @Override + public final DataType getDataType() { + if (aggregateParam.isGlobal) { + return getFinalType(); + } else { + return getIntermediateTypes(); + } + } + + @Override + public final List expectedInputTypes() { + if (isGlobal() && isDisassembled()) { + return ImmutableList.of(getIntermediateTypes()); + } else { + return getSignature().argumentsTypes; + } + } + + public List getArgumentsBeforeDisassembled() { + if (arity() == 1 && getArgument(0).getDataType() instanceof PartialAggType) { + return ((PartialAggType) getArgument(0).getDataType()).getOriginArguments(); + } + return getArguments(); + } + + public List getArgumentTypesBeforeDisassembled() { + return getArgumentsBeforeDisassembled() + .stream() + .map(Expression::getDataType) + .collect(ImmutableList.toImmutableList()); + } + public boolean isDistinct() { return aggregateParam.isDistinct; } @@ -59,8 +117,8 @@ public abstract class AggregateFunction extends BoundFunction { return aggregateParam.isGlobal; } - public Optional> inputTypesBeforeDissemble() { - return aggregateParam.inputTypesBeforeDissemble; + public boolean isDisassembled() { + return aggregateParam.isDisassembled; } public AggregateParam getAggregateParam() { @@ -95,13 +153,4 @@ public abstract class AggregateFunction extends BoundFunction { public boolean hasVarArguments() { return false; } - - @Override - public final DataType getDataType() { - if (aggregateParam.isGlobal) { - return getFinalType(); - } else { - return getIntermediateType(); - } - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java index bf31599c1a..9e4b8e1394 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java @@ -17,11 +17,9 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; -import org.apache.doris.nereids.types.DataType; +import com.google.common.base.Preconditions; -import java.util.List; import java.util.Objects; -import java.util.Optional; /** AggregateParam. */ public class AggregateParam { @@ -29,50 +27,41 @@ public class AggregateParam { public final boolean isDistinct; - // When AggregateDisassemble rule disassemble the aggregate function, say double avg(int), the local - // aggregate keep the origin signature, but the global aggregate change to double avg(double). - // This behavior is difference from the legacy optimizer, because legacy optimizer keep the same signature - // between local aggregate and global aggregate. If the signatures are different, the result would wrong. - // So we use this field to record the originInputTypes, and find the catalog function by the origin input types. - public final Optional> inputTypesBeforeDissemble; + public final boolean isDisassembled; - public AggregateParam() { - this(false, true, Optional.empty()); - } - - public AggregateParam(boolean distinct) { - this(distinct, true, Optional.empty()); - } - - public AggregateParam(boolean isDistinct, boolean isGlobal) { - this(isDistinct, isGlobal, Optional.empty()); - } - - public AggregateParam(boolean isDistinct, boolean isGlobal, Optional> inputTypesBeforeDissemble) { + /** AggregateParam */ + public AggregateParam(boolean isDistinct, boolean isGlobal, boolean isDisassembled) { this.isDistinct = isDistinct; this.isGlobal = isGlobal; - this.inputTypesBeforeDissemble = Objects.requireNonNull(inputTypesBeforeDissemble, - "inputTypesBeforeDissemble can not be null"); - } - - public static AggregateParam distinctAndGlobal() { - return new AggregateParam(true, true, Optional.empty()); + this.isDisassembled = isDisassembled; + if (!isGlobal) { + Preconditions.checkArgument(isDisassembled == true, + "local aggregate should be disassembed"); + } } public static AggregateParam global() { - return new AggregateParam(false, true, Optional.empty()); + return new AggregateParam(false, true, false); + } + + public static AggregateParam distinctAndGlobal() { + return new AggregateParam(true, true, false); } public AggregateParam withDistinct(boolean isDistinct) { - return new AggregateParam(isDistinct, isGlobal, inputTypesBeforeDissemble); + return new AggregateParam(isDistinct, isGlobal, isDisassembled); } public AggregateParam withGlobal(boolean isGlobal) { - return new AggregateParam(isDistinct, isGlobal, inputTypesBeforeDissemble); + return new AggregateParam(isDistinct, isGlobal, isDisassembled); } - public AggregateParam withInputTypesBeforeDissemble(Optional> inputTypesBeforeDissemble) { - return new AggregateParam(isDistinct, isGlobal, inputTypesBeforeDissemble); + public AggregateParam withDisassembled(boolean isDisassembled) { + return new AggregateParam(isDistinct, isGlobal, isDisassembled); + } + + public AggregateParam withGlobalAndDisassembled(boolean isGlobal, boolean isDisassembled) { + return new AggregateParam(isDistinct, isGlobal, isDisassembled); } @Override @@ -86,11 +75,11 @@ public class AggregateParam { AggregateParam that = (AggregateParam) o; return isDistinct == that.isDistinct && Objects.equals(isGlobal, that.isGlobal) - && Objects.equals(inputTypesBeforeDissemble, that.inputTypesBeforeDissemble); + && Objects.equals(isDisassembled, that.isDisassembled); } @Override public int hashCode() { - return Objects.hash(isDistinct, isGlobal, inputTypesBeforeDissemble); + return Objects.hash(isDistinct, isGlobal, isDisassembled); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java index 642310ce4e..17f4a0ba39 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java @@ -17,19 +17,20 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.CustomSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; -import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DateTimeType; import org.apache.doris.nereids.types.DateType; import org.apache.doris.nereids.types.DecimalV2Type; import org.apache.doris.nereids.types.DoubleType; -import org.apache.doris.nereids.types.VarcharType; -import org.apache.doris.nereids.types.coercion.AbstractDataType; import org.apache.doris.nereids.types.coercion.NumericType; -import org.apache.doris.nereids.types.coercion.TypeCollection; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -37,12 +38,7 @@ import com.google.common.collect.ImmutableList; import java.util.List; /** avg agg function. */ -public class Avg extends AggregateFunction implements UnaryExpression, ImplicitCastInputTypes { - - // used in interface expectedInputTypes to avoid new list in each time it be called - private static final List EXPECTED_INPUT_TYPES = ImmutableList.of( - new TypeCollection(NumericType.INSTANCE, DateTimeType.INSTANCE, DateType.INSTANCE) - ); +public class Avg extends AggregateFunction implements UnaryExpression, PropagateNullable, CustomSignature { public Avg(Expression child) { super("avg", child); @@ -53,31 +49,16 @@ public class Avg extends AggregateFunction implements UnaryExpression, ImplicitC } @Override - public DataType getFinalType() { - DataType argumentType = inputTypesBeforeDissemble() - .map(types -> types.get(0)) - .orElse(child().getDataType()); - if (argumentType instanceof DecimalV2Type) { - return DecimalV2Type.SYSTEM_DEFAULT; - } else if (argumentType.isDate()) { - return DateType.INSTANCE; - } else if (argumentType.isDateTime()) { - return DateTimeType.INSTANCE; - } else { - return DoubleType.INSTANCE; - } - } - - // TODO: We should return a complex type: PartialAggType(bufferTypes=[Double, Int], inputTypes=[Int]) - // to denote sum(double) and count(int) - @Override - public DataType getIntermediateType() { - return VarcharType.createVarcharType(-1); + public FunctionSignature customSignature(List argumentTypes, List arguments) { + DataType implicitCastType = implicitCast(argumentTypes.get(0)); + return FunctionSignature.ret(implicitCastType).args(implicitCastType); } @Override - public boolean nullable() { - return child().nullable(); + protected List intermediateTypes(List argumentTypes, List arguments) { + DataType sumType = getFinalType(); + BigIntType countType = BigIntType.INSTANCE; + return ImmutableList.of(sumType, countType); } @Override @@ -91,17 +72,22 @@ public class Avg extends AggregateFunction implements UnaryExpression, ImplicitC return new Avg(aggregateParam, child()); } - @Override - public List expectedInputTypes() { - if (isGlobal() && inputTypesBeforeDissemble().isPresent()) { - return ImmutableList.of(); - } else { - return EXPECTED_INPUT_TYPES; - } - } - @Override public R accept(ExpressionVisitor visitor, C context) { return visitor.visitAvg(this, context); } + + private DataType implicitCast(DataType dataType) { + if (dataType instanceof DecimalV2Type) { + return DecimalV2Type.SYSTEM_DEFAULT; + } else if (dataType.isDate()) { + return DateType.INSTANCE; + } else if (dataType.isDateTime()) { + return DateTimeType.INSTANCE; + } else if (dataType instanceof NumericType) { + return DoubleType.INSTANCE; + } else { + throw new AnalysisException("avg requires a numeric parameter: " + dataType); + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapIntersect.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapIntersect.java index c519e6df12..d62066a5b7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapIntersect.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapIntersect.java @@ -17,13 +17,13 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; +import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; -import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.types.BitmapType; import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.types.coercion.AbstractDataType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -32,7 +32,11 @@ import java.util.List; /** BitmapIntersect */ public class BitmapIntersect extends AggregateFunction - implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes { + implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature { + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE) + ); + public BitmapIntersect(Expression arg0) { super("bitmap_intersect", arg0); } @@ -41,29 +45,24 @@ public class BitmapIntersect extends AggregateFunction super("bitmap_intersect", aggregateParam, arg0); } + @Override + protected List intermediateTypes(List argumentTypes, List arguments) { + return ImmutableList.of(BitmapType.INSTANCE); + } + @Override public BitmapIntersect withChildren(List children) { Preconditions.checkArgument(children.size() == 1); return new BitmapIntersect(getAggregateParam(), children.get(0)); } - @Override - public List expectedInputTypes() { - return ImmutableList.of(BitmapType.INSTANCE); - } - - @Override - public DataType getFinalType() { - return BitmapType.INSTANCE; - } - - @Override - public DataType getIntermediateType() { - return BitmapType.INSTANCE; - } - @Override public BitmapIntersect withAggregateParam(AggregateParam aggregateParam) { return new BitmapIntersect(aggregateParam, child()); } + + @Override + public List getSignatures() { + return SIGNATURES; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnion.java index 56807aad90..f3d2ef0d52 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnion.java @@ -17,13 +17,13 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; +import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; -import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.types.BitmapType; import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.types.coercion.AbstractDataType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -32,7 +32,11 @@ import java.util.List; /** BitmapUnion */ public class BitmapUnion extends AggregateFunction - implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes { + implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature { + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE) + ); + public BitmapUnion(Expression arg0) { super("bitmap_union", arg0); } @@ -41,29 +45,24 @@ public class BitmapUnion extends AggregateFunction super("bitmap_union", aggregateParam, arg0); } + @Override + protected List intermediateTypes(List argumentTypes, List arguments) { + return ImmutableList.of(BitmapType.INSTANCE); + } + @Override public BitmapUnion withChildren(List children) { Preconditions.checkArgument(children.size() == 1); return new BitmapUnion(getAggregateParam(), children.get(0)); } - @Override - public List expectedInputTypes() { - return ImmutableList.of(BitmapType.INSTANCE); - } - - @Override - public DataType getFinalType() { - return BitmapType.INSTANCE; - } - - @Override - public DataType getIntermediateType() { - return BitmapType.INSTANCE; - } - @Override public BitmapUnion withAggregateParam(AggregateParam aggregateParam) { return new BitmapUnion(aggregateParam, child()); } + + @Override + public List getSignatures() { + return SIGNATURES; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionCount.java index 9b7a247002..a98653a6e4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionCount.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionCount.java @@ -17,14 +17,14 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; +import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; -import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.BitmapType; import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.types.coercion.AbstractDataType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -33,7 +33,11 @@ import java.util.List; /** BitmapUnionCount */ public class BitmapUnionCount extends AggregateFunction - implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes { + implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature { + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(BigIntType.INSTANCE).args(BitmapType.INSTANCE) + ); + public BitmapUnionCount(Expression arg0) { super("bitmap_union_count", arg0); } @@ -42,29 +46,24 @@ public class BitmapUnionCount extends AggregateFunction super("bitmap_union_count", aggregateParam, arg0); } + @Override + protected List intermediateTypes(List argumentTypes, List arguments) { + return ImmutableList.of(BitmapType.INSTANCE); + } + @Override public BitmapUnionCount withChildren(List children) { Preconditions.checkArgument(children.size() == 1); return new BitmapUnionCount(getAggregateParam(), children.get(0)); } - @Override - public List expectedInputTypes() { - return ImmutableList.of(BitmapType.INSTANCE); - } - - @Override - public DataType getFinalType() { - return BigIntType.INSTANCE; - } - - @Override - public DataType getIntermediateType() { - return BitmapType.INSTANCE; - } - @Override public BitmapUnionCount withAggregateParam(AggregateParam aggregateParam) { return new BitmapUnionCount(aggregateParam, child()); } + + @Override + public List getSignatures() { + return SIGNATURES; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionInt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionInt.java index 5d434505cb..2e22ce64c3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionInt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/BitmapUnionInt.java @@ -17,18 +17,17 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; +import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; -import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.BitmapType; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.SmallIntType; import org.apache.doris.nereids.types.TinyIntType; -import org.apache.doris.nereids.types.coercion.AbstractDataType; -import org.apache.doris.nereids.types.coercion.TypeCollection; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -37,7 +36,14 @@ import java.util.List; /** BitmapUnionInt */ public class BitmapUnionInt extends AggregateFunction - implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes { + implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature { + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE), + FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE), + FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE), + FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE) + ); + public BitmapUnionInt(Expression arg0) { super("bitmap_union_int", arg0); } @@ -46,34 +52,24 @@ public class BitmapUnionInt extends AggregateFunction super("bitmap_union_int", aggregateParam, arg0); } + @Override + protected List intermediateTypes(List argumentTypes, List arguments) { + return ImmutableList.of(BitmapType.INSTANCE); + } + @Override public BitmapUnionInt withChildren(List children) { Preconditions.checkArgument(children.size() == 1); return new BitmapUnionInt(getAggregateParam(), children.get(0)); } - @Override - public List expectedInputTypes() { - if (isGlobal() && inputTypesBeforeDissemble().isPresent()) { - return ImmutableList.of(); - } else { - return ImmutableList.of(new TypeCollection( - TinyIntType.INSTANCE, SmallIntType.INSTANCE, IntegerType.INSTANCE, BigIntType.INSTANCE)); - } - } - - @Override - public DataType getFinalType() { - return BigIntType.INSTANCE; - } - - @Override - public DataType getIntermediateType() { - return BitmapType.INSTANCE; - } - @Override public BitmapUnionInt withAggregateParam(AggregateParam aggregateParam) { return new BitmapUnionInt(aggregateParam, child()); } + + @Override + public List getSignatures() { + return SIGNATURES; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java index 637bc60e01..2a3ca1947e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java @@ -17,20 +17,23 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; +import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.exceptions.UnboundException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; +import org.apache.doris.nereids.trees.expressions.functions.CustomSignature; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DataType; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import java.util.List; import java.util.stream.Collectors; /** count agg function. */ -public class Count extends AggregateFunction implements AlwaysNotNullable { +public class Count extends AggregateFunction implements AlwaysNotNullable, CustomSignature { private final boolean isStar; @@ -59,13 +62,13 @@ public class Count extends AggregateFunction implements AlwaysNotNullable { } @Override - public DataType getFinalType() { - return BigIntType.INSTANCE; + public FunctionSignature customSignature(List argumentTypes, List arguments) { + return FunctionSignature.of(BigIntType.INSTANCE, (List) argumentTypes); } @Override - public DataType getIntermediateType() { - return getFinalType(); + protected List intermediateTypes(List argumentTypes, List arguments) { + return ImmutableList.of(BigIntType.INSTANCE); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupBitmapXor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupBitmapXor.java index 2f641a4185..201e73d6e2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupBitmapXor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupBitmapXor.java @@ -17,13 +17,13 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; +import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; -import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.types.BitmapType; import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.types.coercion.AbstractDataType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -32,7 +32,11 @@ import java.util.List; /** GroupBitmapXor */ public class GroupBitmapXor extends AggregateFunction - implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes { + implements UnaryExpression, PropagateNullable, ExplicitlyCastableSignature { + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE) + ); + public GroupBitmapXor(Expression arg0) { super("group_bitmap_xor", arg0); } @@ -41,33 +45,24 @@ public class GroupBitmapXor extends AggregateFunction super("group_bitmap_xor", aggregateParam, arg0); } + @Override + protected List intermediateTypes(List argumentTypes, List arguments) { + return ImmutableList.of(BitmapType.INSTANCE); + } + @Override public GroupBitmapXor withChildren(List children) { Preconditions.checkArgument(children.size() == 1); return new GroupBitmapXor(getAggregateParam(), children.get(0)); } - @Override - public List expectedInputTypes() { - if (isGlobal() && inputTypesBeforeDissemble().isPresent()) { - return ImmutableList.of(); - } else { - return ImmutableList.of(BitmapType.INSTANCE); - } - } - - @Override - public DataType getFinalType() { - return BitmapType.INSTANCE; - } - - @Override - public DataType getIntermediateType() { - return BitmapType.INSTANCE; - } - @Override public GroupBitmapXor withAggregateParam(AggregateParam aggregateParam) { return new GroupBitmapXor(aggregateParam, child()); } + + @Override + public List getSignatures() { + return SIGNATURES; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnion.java index 5b421bf6f4..5802872b19 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnion.java @@ -17,13 +17,13 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; +import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; -import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.HllType; -import org.apache.doris.nereids.types.coercion.AbstractDataType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -32,15 +32,22 @@ import java.util.List; /** HllUnion */ public class HllUnion extends AggregateFunction - implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes { + implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature { + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(HllType.INSTANCE).args(HllType.INSTANCE) + ); + public HllUnion(Expression arg0) { - // TODO: change to hll_union in the future - super("hll_raw_agg", arg0); + super("hll_union", arg0); } public HllUnion(AggregateParam aggregateParam, Expression arg0) { - // TODO: change to hll_union in the future - super("hll_raw_agg", aggregateParam, arg0); + super("hll_union", aggregateParam, arg0); + } + + @Override + protected List intermediateTypes(List argumentTypes, List arguments) { + return ImmutableList.of(HllType.INSTANCE); } @Override @@ -49,23 +56,13 @@ public class HllUnion extends AggregateFunction return new HllUnion(getAggregateParam(), children.get(0)); } - @Override - public List expectedInputTypes() { - return ImmutableList.of(HllType.INSTANCE); - } - - @Override - public DataType getFinalType() { - return HllType.INSTANCE; - } - - @Override - public DataType getIntermediateType() { - return HllType.INSTANCE; - } - @Override public HllUnion withAggregateParam(AggregateParam aggregateParam) { return new HllUnion(aggregateParam, child()); } + + @Override + public List getSignatures() { + return SIGNATURES; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnionAgg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnionAgg.java index df7bf95478..eade606a42 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnionAgg.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/HllUnionAgg.java @@ -17,14 +17,14 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; +import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; -import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.HllType; -import org.apache.doris.nereids.types.coercion.AbstractDataType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -33,7 +33,11 @@ import java.util.List; /** HllUnionAgg */ public class HllUnionAgg extends AggregateFunction - implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes { + implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature { + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(BigIntType.INSTANCE).args(HllType.INSTANCE) + ); + public HllUnionAgg(Expression arg0) { super("hll_union_agg", arg0); } @@ -42,29 +46,24 @@ public class HllUnionAgg extends AggregateFunction super("hll_union_agg", aggregateParam, arg0); } + @Override + protected List intermediateTypes(List argumentTypes, List arguments) { + return ImmutableList.of(HllType.INSTANCE); + } + @Override public HllUnionAgg withChildren(List children) { Preconditions.checkArgument(children.size() == 1); return new HllUnionAgg(getAggregateParam(), children.get(0)); } - @Override - public List expectedInputTypes() { - return ImmutableList.of(HllType.INSTANCE); - } - - @Override - public DataType getFinalType() { - return BigIntType.INSTANCE; - } - - @Override - public DataType getIntermediateType() { - return HllType.INSTANCE; - } - @Override public HllUnionAgg withAggregateParam(AggregateParam aggregateParam) { return new HllUnionAgg(aggregateParam, child()); } + + @Override + public List getSignatures() { + return SIGNATURES; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java index 36e6e79a41..5c81e3ca40 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java @@ -17,7 +17,10 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; +import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.CustomSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; @@ -27,8 +30,7 @@ import com.google.common.base.Preconditions; import java.util.List; /** max agg function. */ -public class Max extends AggregateFunction implements UnaryExpression { - +public class Max extends AggregateFunction implements UnaryExpression, PropagateNullable, CustomSignature { public Max(Expression child) { super("max", child); } @@ -38,18 +40,13 @@ public class Max extends AggregateFunction implements UnaryExpression { } @Override - public DataType getFinalType() { - return child().getDataType(); + public FunctionSignature customSignature(List argumentTypes, List arguments) { + return FunctionSignature.ret(argumentTypes.get(0)).args(argumentTypes.get(0)); } @Override - public DataType getIntermediateType() { - return getFinalType(); - } - - @Override - public boolean nullable() { - return child().nullable(); + protected List intermediateTypes(List argumentTypes, List arguments) { + return argumentTypes; } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java index c350640d6d..8a5d62fe54 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java @@ -17,7 +17,10 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; +import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.CustomSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; @@ -27,7 +30,7 @@ import com.google.common.base.Preconditions; import java.util.List; /** min agg function. */ -public class Min extends AggregateFunction implements UnaryExpression { +public class Min extends AggregateFunction implements UnaryExpression, PropagateNullable, CustomSignature { public Min(Expression child) { super("min", child); @@ -38,18 +41,13 @@ public class Min extends AggregateFunction implements UnaryExpression { } @Override - public DataType getFinalType() { - return child().getDataType(); + public FunctionSignature customSignature(List argumentTypes, List arguments) { + return FunctionSignature.ret(argumentTypes.get(0)).args(argumentTypes.get(0)); } @Override - public DataType getIntermediateType() { - return getFinalType(); - } - - @Override - public boolean nullable() { - return child().nullable(); + protected List intermediateTypes(List argumentTypes, List arguments) { + return argumentTypes; } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java index 20a2b275f4..25385ccb0d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java @@ -17,17 +17,18 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.CustomSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; -import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DecimalV2Type; import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.types.LargeIntType; -import org.apache.doris.nereids.types.coercion.AbstractDataType; -import org.apache.doris.nereids.types.coercion.FractionalType; import org.apache.doris.nereids.types.coercion.IntegralType; import org.apache.doris.nereids.types.coercion.NumericType; @@ -37,11 +38,7 @@ import com.google.common.collect.ImmutableList; import java.util.List; /** sum agg function. */ -public class Sum extends AggregateFunction implements UnaryExpression, ImplicitCastInputTypes { - - // used in interface expectedInputTypes to avoid new list in each time it be called - private static final List EXPECTED_INPUT_TYPES = ImmutableList.of(NumericType.INSTANCE); - +public class Sum extends AggregateFunction implements UnaryExpression, PropagateNullable, CustomSignature { public Sum(Expression child) { super("sum", child); } @@ -51,34 +48,14 @@ public class Sum extends AggregateFunction implements UnaryExpression, ImplicitC } @Override - public DataType getFinalType() { - DataType dataType = child().getDataType(); - if (dataType instanceof LargeIntType) { - return dataType; - } else if (dataType instanceof DecimalV2Type) { - return DecimalV2Type.SYSTEM_DEFAULT; - } else if (dataType instanceof IntegralType) { - return BigIntType.INSTANCE; - } else if (dataType instanceof FractionalType) { - return DoubleType.INSTANCE; - } else { - throw new IllegalStateException("Unsupported sum type: " + dataType); - } + public FunctionSignature customSignature(List argumentTypes, List arguments) { + DataType implicitCastType = implicitCast(argumentTypes.get(0)); + return FunctionSignature.ret(implicitCastType).args(NumericType.INSTANCE); } @Override - public DataType getIntermediateType() { - return getFinalType(); - } - - @Override - public boolean nullable() { - return child().nullable(); - } - - @Override - public List expectedInputTypes() { - return EXPECTED_INPUT_TYPES; + protected List intermediateTypes(List argumentTypes, List arguments) { + return ImmutableList.of(getFinalType()); } @Override @@ -96,4 +73,18 @@ public class Sum extends AggregateFunction implements UnaryExpression, ImplicitC public R accept(ExpressionVisitor visitor, C context) { return visitor.visitSum(this, context); } + + private DataType implicitCast(DataType dataType) { + if (dataType instanceof LargeIntType) { + return dataType; + } else if (dataType instanceof DecimalV2Type) { + return DecimalV2Type.SYSTEM_DEFAULT; + } else if (dataType instanceof IntegralType) { + return BigIntType.INSTANCE; + } else if (dataType instanceof NumericType) { + return DoubleType.INSTANCE; + } else { + throw new AnalysisException("sum requires a numeric parameter: " + dataType); + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java index abf6797062..44256d52a9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java @@ -41,6 +41,7 @@ import org.apache.doris.nereids.types.SmallIntType; import org.apache.doris.nereids.types.StringType; import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.types.VarcharType; +import org.apache.doris.nereids.types.coercion.AbstractDataType; import com.google.common.base.Preconditions; import com.google.common.base.Suppliers; @@ -90,7 +91,7 @@ public class If extends ScalarFunction ); private final Supplier widerType = Suppliers.memoize(() -> { - List argumentsTypes = getSignature().argumentsTypes; + List argumentsTypes = getSignature().argumentsTypes; Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType( argumentsTypes.get(1).toCatalogDataType(), argumentsTypes.get(2).toCatalogDataType(), @@ -106,11 +107,11 @@ public class If extends ScalarFunction } @Override - protected FunctionSignature computeSignature(FunctionSignature signature) { + protected FunctionSignature computeSignature(FunctionSignature signature, List arguments) { DataType widerType = this.widerType.get(); - signature = signature.withArgumentTypes(children(), (sigType, argType) -> widerType) + signature = signature.withArgumentTypes(arguments, (sigType, argType) -> widerType) .withReturnType(widerType); - return super.computeSignature(signature); + return super.computeSignature(signature, arguments); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ScalarFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ScalarFunction.java index 83f295faa2..7267ecc899 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ScalarFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ScalarFunction.java @@ -17,45 +17,17 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; -import org.apache.doris.catalog.FunctionSignature; -import org.apache.doris.catalog.ScalarType; -import org.apache.doris.catalog.Type; -import org.apache.doris.common.Config; -import org.apache.doris.common.Pair; -import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.expressions.functions.ComputeSignature; -import org.apache.doris.nereids.trees.expressions.functions.DecimalSamePrecision; -import org.apache.doris.nereids.trees.expressions.functions.DecimalStddevPrecision; -import org.apache.doris.nereids.trees.expressions.functions.DecimalWiderPrecision; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; -import org.apache.doris.nereids.types.ArrayType; -import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.types.DateTimeV2Type; -import org.apache.doris.nereids.types.DecimalV2Type; -import org.apache.doris.nereids.types.DecimalV3Type; -import org.apache.doris.nereids.types.coercion.AbstractDataType; -import org.apache.doris.nereids.util.ResponsibilityChain; - -import com.google.common.base.Suppliers; import java.util.List; -import java.util.function.BiFunction; -import java.util.function.Supplier; /** * The function which consume zero or more arguments in a row and product one value. */ public abstract class ScalarFunction extends BoundFunction implements ComputeSignature { - @Developing("this field will move to BoundFunction, when we support compute signature for AggregateFunction") - private final Supplier signatureCache = Suppliers.memoize(() -> { - // first step: find the candidate signature in the signature list - FunctionSignature matchedSignature = searchSignature(); - // second step: change the signature, e.g. fill precision for decimal v2 - return computeSignature(matchedSignature); - }); - public ScalarFunction(String name, Expression... arguments) { super(name, arguments); } @@ -64,150 +36,8 @@ public abstract class ScalarFunction extends BoundFunction implements ComputeSig super(name, arguments); } - public FunctionSignature getSignature() { - return signatureCache.get(); - } - - protected FunctionSignature computeSignature(FunctionSignature signature) { - // NOTE: - // this computed chain only process the common cases. - // If you want to add some common cases to here, please separate the process code - // to the other methods and add to this chain. - // If you want to add some special cases, please override this method in the special - // function class, like 'If' function and 'Substring' function. - return ComputeSignatureChain.from(signature, getArguments()) - .then(this::computePrecisionForDatetimeV2) - .then(this::upgradeDateOrDateTimeToV2) - .then(this::upgradeDecimalV2ToV3) - .then(this::computePrecisionForDecimal) - .then(this::dynamicComputePropertiesOfArray) - .get(); - } - - @Override - @Developing("this method will move to BoundFunction, when we support compute signature for AggregateFunction") - public final List expectedInputTypes() { - return ComputeSignature.super.expectedInputTypes(); - } - - @Override - @Developing("this method will move to BoundFunction, when we support compute signature for AggregateFunction") - public final DataType getDataType() { - return ComputeSignature.super.getDataType(); - } - @Override public R accept(ExpressionVisitor visitor, C context) { return visitor.visitScalarFunction(this, context); } - - private FunctionSignature computePrecisionForDatetimeV2( - FunctionSignature signature, List arguments) { - - // fill for arguments type - signature = signature.withArgumentTypes(arguments, (sigArgType, realArgType) -> { - if (sigArgType instanceof DateTimeV2Type && realArgType.getDataType() instanceof DateTimeV2Type) { - return realArgType.getDataType(); - } - return sigArgType; - }); - - // fill for return type - if (signature.returnType instanceof DateTimeV2Type) { - Integer maxScale = signature.argumentsTypes.stream() - .filter(DateTimeV2Type.class::isInstance) - .map(t -> ((DateTimeV2Type) t).getScale()) - .reduce(Math::max) - .orElse(((DateTimeV2Type) signature.returnType).getScale()); - signature = signature.withReturnType(DateTimeV2Type.of(maxScale)); - } - - return signature; - } - - private FunctionSignature upgradeDateOrDateTimeToV2( - FunctionSignature signature, List arguments) { - DataType returnType = signature.returnType; - Type type = returnType.toCatalogDataType(); - if ((type.isDate() || type.isDatetime()) && Config.enable_date_conversion) { - Type legacyReturnType = ScalarType.getDefaultDateType(returnType.toCatalogDataType()); - signature = signature.withReturnType(DataType.fromCatalogType(legacyReturnType)); - } - return signature; - } - - @Developing - private FunctionSignature computePrecisionForDecimal( - FunctionSignature signature, List arguments) { - if (signature.returnType instanceof DecimalV3Type || signature.returnType instanceof DecimalV2Type) { - if (this instanceof DecimalSamePrecision) { - signature = signature.withReturnType(signature.argumentsTypes.get(0)); - } else if (this instanceof DecimalWiderPrecision) { - ScalarType widerType = ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION, - ((ScalarType) signature.argumentsTypes.get(0).toCatalogDataType()).getScalarScale()); - signature = signature.withReturnType(DataType.fromCatalogType(widerType)); - } else if (this instanceof DecimalStddevPrecision) { - // for all stddev function, use decimal(38,9) as computing result - ScalarType stddevDecimalType = ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION, - DecimalStddevPrecision.STDDEV_DECIMAL_SCALE); - signature = signature.withReturnType(DataType.fromCatalogType(stddevDecimalType)); - } - } - - return signature; - } - - private FunctionSignature upgradeDecimalV2ToV3( - FunctionSignature signature, List arguments) { - DataType returnType = signature.returnType; - Type type = returnType.toCatalogDataType(); - if (type.isDecimalV2() && Config.enable_decimal_conversion && Config.enable_decimalv3) { - Type v3Type = ScalarType.createDecimalV3Type(type.getPrecision(), ((ScalarType) type).getScalarScale()); - signature = signature.withReturnType(DataType.fromCatalogType(v3Type)); - } - return signature; - } - - private FunctionSignature dynamicComputePropertiesOfArray( - FunctionSignature signature, List arguments) { - if (!(signature.returnType instanceof ArrayType)) { - return signature; - } - - // TODO: compute array(...) function's itemType - - // fill item type by the type of first item - ArrayType arrayType = (ArrayType) signature.returnType; - - // fill containsNull if any array argument contains null - boolean containsNull = signature.argumentsTypes - .stream() - .filter(argType -> argType instanceof ArrayType) - .map(ArrayType.class::cast) - .anyMatch(ArrayType::containsNull); - return signature.withReturnType( - ArrayType.of(arrayType.getItemType(), arrayType.containsNull() || containsNull)); - } - - static class ComputeSignatureChain { - private ResponsibilityChain>> computeChain; - - public ComputeSignatureChain(ResponsibilityChain>> computeChain) { - this.computeChain = computeChain; - } - - public static ComputeSignatureChain from(FunctionSignature signature, List arguments) { - return new ComputeSignatureChain(ResponsibilityChain.from(Pair.of(signature, arguments))); - } - - public ComputeSignatureChain then( - BiFunction, FunctionSignature> computeFunction) { - computeChain.then(pair -> Pair.of(computeFunction.apply(pair.first, pair.second), pair.second)); - return this; - } - - public FunctionSignature get() { - return computeChain.get().first; - } - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/StrToDate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/StrToDate.java index fbaf438496..b70a8d687a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/StrToDate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/StrToDate.java @@ -56,7 +56,7 @@ public class StrToDate extends ScalarFunction } @Override - protected FunctionSignature computeSignature(FunctionSignature signature) { + protected FunctionSignature computeSignature(FunctionSignature signature, List arguments) { /* * The return type of str_to_date depends on whether the time part is included in the format. * If included, it is datetime, otherwise it is date. @@ -83,8 +83,8 @@ public class StrToDate extends ScalarFunction * Return type is DATETIME */ DataType returnType; - if (child(1) instanceof StringLikeLiteral) { - if (DateLiteral.hasTimePart(((StringLikeLiteral) child(1)).getStringValue())) { + if (arguments.get(1) instanceof StringLikeLiteral) { + if (DateLiteral.hasTimePart(((StringLikeLiteral) arguments.get(1)).getStringValue())) { returnType = DataType.fromCatalogType(ScalarType.getDefaultDateType(Type.DATETIME)); } else { returnType = DataType.fromCatalogType(ScalarType.getDefaultDateType(Type.DATE)); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Substring.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Substring.java index a653a792c7..9b419fd675 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Substring.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Substring.java @@ -66,8 +66,9 @@ public class Substring extends ScalarFunction } @Override - protected FunctionSignature computeSignature(FunctionSignature signature) { - Optional length = getLength(); + protected FunctionSignature computeSignature(FunctionSignature signature, List arguments) { + Optional length = arguments.size() == 3 + ? Optional.of(arguments.get(2)) : Optional.empty(); DataType returnType = VarcharType.SYSTEM_DEFAULT; if (length.isPresent() && length.get() instanceof IntegerLiteral) { returnType = VarcharType.createVarcharType(((IntegerLiteral) length.get()).getValue()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/Numbers.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/Numbers.java index 7d0e501702..dafc19d8fb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/Numbers.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/Numbers.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions.functions.table; import org.apache.doris.analysis.IntLiteral; +import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.catalog.Type; import org.apache.doris.common.Id; import org.apache.doris.common.NereidsException; @@ -26,6 +27,8 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.TVFProperties; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.statistics.ColumnStatistic; import org.apache.doris.statistics.StatsDeriveResult; import org.apache.doris.tablefunction.NumbersTableValuedFunction; @@ -43,10 +46,15 @@ public class Numbers extends TableValuedFunction { super("numbers", properties); } + @Override + public FunctionSignature customSignature(List argumentTypes, List arguments) { + return FunctionSignature.of(BigIntType.INSTANCE, (List) argumentTypes); + } + @Override protected TableValuedFunctionIf toCatalogFunction() { try { - Map arguments = getKeyValuesExpression().getMap(); + Map arguments = getTVFProperties().getMap(); return new NumbersTableValuedFunction(arguments); } catch (Throwable t) { throw new AnalysisException("Can not build NumbersTableValuedFunction by " @@ -82,9 +90,4 @@ public class Numbers extends TableValuedFunction { && children().get(0) instanceof TVFProperties); return new Numbers((TVFProperties) children.get(0)); } - - @Override - public boolean hasVarArguments() { - return false; - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/TableValuedFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/TableValuedFunction.java index 5df73b4594..e570abe49d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/TableValuedFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/table/TableValuedFunction.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.exceptions.UnboundException; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.TVFProperties; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; +import org.apache.doris.nereids.trees.expressions.functions.CustomSignature; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; @@ -38,7 +39,7 @@ import java.util.function.Supplier; import java.util.stream.Collectors; /** TableValuedFunction */ -public abstract class TableValuedFunction extends BoundFunction implements UnaryExpression { +public abstract class TableValuedFunction extends BoundFunction implements UnaryExpression, CustomSignature { protected final Supplier catalogFunctionCache = Suppliers.memoize(() -> toCatalogFunction()); protected final Supplier tableCache = Suppliers.memoize(() -> { try { @@ -58,7 +59,7 @@ public abstract class TableValuedFunction extends BoundFunction implements Unary public abstract StatsDeriveResult computeStats(List slots); - public TVFProperties getKeyValuesExpression() { + public TVFProperties getTVFProperties() { return (TVFProperties) child(0); } @@ -95,7 +96,7 @@ public abstract class TableValuedFunction extends BoundFunction implements Unary @Override public String toSql() { - String args = getKeyValuesExpression() + String args = getTVFProperties() .getMap() .entrySet() .stream() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/ExplainCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/ExplainCommand.java index 5d559627bd..adf9cd5ba4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/ExplainCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/ExplainCommand.java @@ -29,10 +29,22 @@ public class ExplainCommand implements Command { * explain level. */ public enum ExplainLevel { - NORMAL, - VERBOSE, - GRAPH, + NONE(false), + NORMAL(false), + VERBOSE(false), + GRAPH(false), + PARSED_PLAN(true), + ANALYZED_PLAN(true), + REWRITTEN_PLAN(true), + OPTIMIZED_PLAN(true), + ALL_PLAN(true) ; + + public final boolean isPlanLevel; + + ExplainLevel(boolean isPlanLevel) { + this.isPlanLevel = isPlanLevel; + } } private final ExplainLevel level; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalTVFRelation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalTVFRelation.java index ab05a24846..6074cba757 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalTVFRelation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalTVFRelation.java @@ -78,7 +78,7 @@ public class PhysicalTVFRelation extends PhysicalRelation implements TVFRelation @Override public String toString() { - return Utils.toSqlString("LogicalTVFRelation", + return Utils.toSqlString("PhysicalTVFRelation", "qualified", Utils.qualifiedName(qualifier, getTable().getName()), "output", getOutput(), "function", function.toSql() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/PartialAggType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/PartialAggType.java new file mode 100644 index 0000000000..14d4a7ac35 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/PartialAggType.java @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.types; + +import org.apache.doris.catalog.ScalarType; +import org.apache.doris.catalog.Type; +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +/** PartialAggType */ +public class PartialAggType extends DataType { + public final List originArguments; + public final List intermediateTypes; + + /** PartialAggType */ + public PartialAggType(List originArguments, List intermediateTypes) { + this.originArguments = ImmutableList.copyOf( + Objects.requireNonNull(originArguments, "originArguments can not be null")); + this.intermediateTypes = ImmutableList.copyOf( + Objects.requireNonNull(intermediateTypes, "intermediateTypes can not be null")); + Preconditions.checkArgument(intermediateTypes.size() > 0, "intermediateTypes can not empty"); + } + + public List getOriginArguments() { + return originArguments; + } + + public List getIntermediateTypes() { + return intermediateTypes; + } + + public List getOriginInputTypes() { + return originArguments.stream() + .map(Expression::getDataType) + .collect(ImmutableList.toImmutableList()); + } + + @Override + public String toSql() { + return "PartialAggType(types=" + intermediateTypes + ")"; + } + + @Override + public int width() { + return intermediateTypes.stream() + .map(DataType::width) + .reduce((w1, w2) -> w1 + w2) + .get(); + } + + @Override + public Type toCatalogDataType() { + if (intermediateTypes.size() == 1) { + return intermediateTypes.get(0).toCatalogDataType(); + } + return ScalarType.createVarcharType(-1); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + PartialAggType that = (PartialAggType) o; + return Objects.equals(originArguments, that.originArguments) + && Objects.equals(intermediateTypes, that.intermediateTypes); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), originArguments, intermediateTypes); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FunctionRegistryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FunctionRegistryTest.java index 36b56d0342..4108893cfa 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FunctionRegistryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FunctionRegistryTest.java @@ -18,15 +18,19 @@ package org.apache.doris.nereids.rules.analysis; import org.apache.doris.catalog.FunctionRegistry; +import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction; import org.apache.doris.nereids.trees.expressions.functions.scalar.Substring; import org.apache.doris.nereids.trees.expressions.functions.scalar.Year; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; +import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; @@ -136,18 +140,27 @@ public class FunctionRegistryTest implements PatternMatchSupported { }); } - public static class ExtendFunction extends BoundFunction implements UnaryExpression, PropagateNullable { + public static class ExtendFunction extends BoundFunction implements UnaryExpression, PropagateNullable, + ExplicitlyCastableSignature { public ExtendFunction(Expression a1) { super("foo", a1); } + @Override + public List getSignatures() { + return ImmutableList.of( + FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE) + ); + } + @Override public boolean hasVarArguments() { return false; } } - public static class AmbiguousFunction extends BoundFunction implements UnaryExpression, PropagateNullable { + public static class AmbiguousFunction extends ScalarFunction implements UnaryExpression, PropagateNullable, + ExplicitlyCastableSignature { public AmbiguousFunction(Expression a1) { super("abc", a1); } @@ -156,6 +169,13 @@ public class FunctionRegistryTest implements PatternMatchSupported { super("abc", a1); } + @Override + public List getSignatures() { + return ImmutableList.of( + FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE) + ); + } + @Override public boolean hasVarArguments() { return false; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/RegisterCTETest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/RegisterCTETest.java index 5ac061ed58..d0d4692f93 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/RegisterCTETest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/RegisterCTETest.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.analysis; +import org.apache.doris.common.NereidsException; import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.analyzer.UnboundAlias; @@ -282,7 +283,7 @@ public class RegisterCTETest extends TestWithFeService implements PatternMatchSu String sql = "WITH cte1 (a1, A1) AS (SELECT * FROM supplier)" + "SELECT * FROM cte1"; - AnalysisException exception = Assertions.assertThrows(AnalysisException.class, () -> { + NereidsException exception = Assertions.assertThrows(NereidsException.class, () -> { PlanChecker.from(connectContext).checkPlannerResult(sql); }, "Not throw expected exception."); Assertions.assertTrue(exception.getMessage().contains("Duplicated CTE column alias: [a1] in CTE [cte1]")); @@ -294,7 +295,7 @@ public class RegisterCTETest extends TestWithFeService implements PatternMatchSu + "(SELECT s_suppkey FROM supplier)" + "SELECT * FROM cte1"; - AnalysisException exception = Assertions.assertThrows(AnalysisException.class, () -> { + NereidsException exception = Assertions.assertThrows(NereidsException.class, () -> { PlanChecker.from(connectContext).checkPlannerResult(sql); }, "Not throw expected exception."); System.out.println(exception.getMessage()); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/TypeCoercionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/TypeCoercionTest.java index 0fce3b8040..f9e3a463ba 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/TypeCoercionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/TypeCoercionTest.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.expression.rewrite; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion; import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Cast; @@ -48,6 +49,7 @@ import org.apache.doris.nereids.types.TinyIntType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -100,16 +102,16 @@ public class TypeCoercionTest extends ExpressionRewriteTestHelper { @Test public void testSumImplicitCast() { - Expression expression = new Sum(new StringLiteral("1")); - Expression expected = new Sum(new Cast(new StringLiteral("1"), DoubleType.INSTANCE)); - assertRewrite(expression, expected); + Assertions.assertThrows(AnalysisException.class, () -> { + new Sum(new StringLiteral("1")).getDataType(); + }); } @Test public void testAvgImplicitCast() { - Expression expression = new Avg(new StringLiteral("1")); - Expression expected = new Avg(new Cast(new StringLiteral("1"), DoubleType.INSTANCE)); - assertRewrite(expression, expected); + Assertions.assertThrows(AnalysisException.class, () -> { + new Avg(new StringLiteral("1")).getDataType(); + }); } @Test diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java index af737bdeb3..f10b6325a2 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java @@ -45,7 +45,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import java.util.List; -import java.util.Optional; @TestInstance(TestInstance.Lifecycle.PER_CLASS) public class AggregateDisassembleTest implements PatternMatchSupported { @@ -277,7 +276,7 @@ public class AggregateDisassembleTest implements PatternMatchSupported { // id Expression localOutput0 = rStudent.getOutput().get(0); // sum - Sum localOutput1 = new Sum(new AggregateParam(false, false, Optional.empty()), rStudent.getOutput().get(0).toSlot()); + Sum localOutput1 = new Sum(new AggregateParam(false, false, true), rStudent.getOutput().get(0).toSlot()); // age Expression localOutput2 = rStudent.getOutput().get(2); // id diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownExpressionsInHashConditionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownExpressionsInHashConditionTest.java index b60d1acad1..f6735c4f4c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownExpressionsInHashConditionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownExpressionsInHashConditionTest.java @@ -17,12 +17,10 @@ package org.apache.doris.nereids.rules.rewrite.logical; -import org.apache.doris.common.AnalysisException; import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil; -import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.nereids.util.PatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; @@ -33,13 +31,6 @@ import org.junit.jupiter.api.Test; import java.util.List; public class PushdownExpressionsInHashConditionTest extends TestWithFeService implements PatternMatchSupported { - private final List testSql = ImmutableList.of( - "SELECT * FROM T1 JOIN T2 ON T1.ID + 1 = T2.ID + 2 AND T1.ID + 1 > 2", - "SELECT * FROM (SELECT * FROM T1) X JOIN (SELECT * FROM T2) Y ON X.ID + 1 = Y.ID + 2 AND X.ID + 1 > 2", - "SELECT * FROM T1 JOIN (SELECT ID, SUM(SCORE) SCORE FROM T2 GROUP BY ID) T ON T1.ID + 1 = T.ID AND T.SCORE < 10", - "SELECT * FROM T1 JOIN (SELECT ID, SUM(SCORE) SCORE FROM T2 GROUP BY ID ORDER BY ID) T ON T1.ID + 1 = T.ID AND T.SCORE < 10" - ); - @Override protected void runBeforeAll() throws Exception { createDatabase("test"); @@ -81,15 +72,10 @@ public class PushdownExpressionsInHashConditionTest extends TestWithFeService im "SELECT * FROM T1 JOIN (SELECT ID, SUM(SCORE) SCORE FROM T2 GROUP BY ID ORDER BY ID) T ON T1.ID + 1 = T.ID AND T.SCORE < 10" ); testSql.forEach(sql -> { - try { - PhysicalPlan plan = new NereidsPlanner(createStatementCtx(sql)).plan( - new NereidsParser().parseSingle(sql), - PhysicalProperties.ANY - ); - System.out.println(plan.treeString()); - } catch (AnalysisException e) { - throw new RuntimeException(e); - } + new NereidsPlanner(createStatementCtx(sql)).plan( + new NereidsParser().parseSingle(sql), + PhysicalProperties.ANY + ); }); } diff --git a/regression-test/suites/nereids_syntax_p0/explain.groovy b/regression-test/suites/nereids_syntax_p0/explain.groovy index 52e13de757..e588b510f3 100644 --- a/regression-test/suites/nereids_syntax_p0/explain.groovy +++ b/regression-test/suites/nereids_syntax_p0/explain.groovy @@ -32,4 +32,19 @@ suite("explain") { contains "project output tuple id: 1" } + + explain { + sql("physical plan select 100") + contains "PhysicalOneRowRelation" + } + + explain { + sql("logical plan select 100") + contains "LogicalOneRowRelation" + } + + explain { + sql("parsed plan select 100") + contains "UnboundOneRowRelation" + } }