[Enhancement](Nereids) Refactor AggregateFunction and support explain plan (#14380)

# Proposed changes

- Refactor AggregateFunction
    1. AggregateFunction implement ComputeSignature
    3. Add a CustomSignature to dynamic compute signature, we can check input type and compute implicit cast type in the `customSignature` method
    2. Add PartialAggType to record some type information before disassemble aggregate
    4. Refine and create a custom catalog function when translate AggregateFunction, without `finalizeForNereids`
-  Support explain plan
    1. explain parsed plan select ...
    5. explain analyzed plan select ...
    6. explain rewritten/logical plan select ...
    7. explain optimized/physical plan select ...
    8. explain all plan select ...
This commit is contained in:
924060929
2022-11-18 23:40:33 +08:00
committed by GitHub
parent c4bade71c8
commit 63a2344e68
57 changed files with 1070 additions and 659 deletions

View File

@ -94,6 +94,7 @@ AFTER: 'AFTER';
ALL: 'ALL';
ALTER: 'ALTER';
ANALYZE: 'ANALYZE';
ANALYZED: 'ANALYZED';
AND: 'AND';
ANTI: 'ANTI';
ANY: 'ANY';
@ -247,6 +248,7 @@ NULLS: 'NULLS';
OF: 'OF';
ON: 'ON';
ONLY: 'ONLY';
OPTIMIZED: 'OPTIMIZED';
OPTION: 'OPTION';
OPTIONS: 'OPTIONS';
OR: 'OR';
@ -258,13 +260,16 @@ OVER: 'OVER';
OVERLAPS: 'OVERLAPS';
OVERLAY: 'OVERLAY';
OVERWRITE: 'OVERWRITE';
PARSED: 'PARSED';
PARTITION: 'PARTITION';
PARTITIONED: 'PARTITIONED';
PARTITIONS: 'PARTITIONS';
PERCENTILE_CONT: 'PERCENTILE_CONT';
PERCENTLIT: 'PERCENT';
PHYSICAL: 'PHYSICAL';
PIVOT: 'PIVOT';
PLACING: 'PLACING';
PLAN: 'PLAN';
POSITION: 'POSITION';
PRECEDING: 'PRECEDING';
PRIMARY: 'PRIMARY';
@ -288,6 +293,7 @@ RESET: 'RESET';
RESPECT: 'RESPECT';
RESTRICT: 'RESTRICT';
REVOKE: 'REVOKE';
REWRITTEN: 'REWRITTEN';
RIGHT: 'RIGHT';
// original optimizer only support REGEXP, the new optimizer should be consistent with it
RLIKE: 'RLIKE';

View File

@ -50,7 +50,16 @@ singleStatement
statement
: cte? query #statementDefault
| (EXPLAIN | DESC | DESCRIBE) level=(VERBOSE | GRAPH)? query #explain
| (EXPLAIN planType? | DESC | DESCRIBE)
level=(VERBOSE | GRAPH | PLAN)? query #explain
;
planType
: PARSED
| ANALYZED
| REWRITTEN | LOGICAL // same type
| OPTIMIZED | PHYSICAL // same type
| ALL // default type
;
// -----------------Query-----------------
@ -335,6 +344,7 @@ ansiNonReserved
| AFTER
| ALTER
| ANALYZE
| ANALYZED
| ANTI
| ARCHIVE
| ARRAY
@ -441,6 +451,7 @@ ansiNonReserved
| NO
| NULLS
| OF
| OPTIMIZED
| OPTION
| OPTIONS
| OUT
@ -448,12 +459,15 @@ ansiNonReserved
| OVER
| OVERLAY
| OVERWRITE
| PARSED
| PARTITION
| PARTITIONED
| PARTITIONS
| PERCENTLIT
| PHYSICAL
| PIVOT
| PLACING
| PLAN
| POSITION
| PRECEDING
| PRINCIPALS
@ -474,6 +488,7 @@ ansiNonReserved
| RESPECT
| RESTRICT
| REVOKE
| REWRITTEN
| RLIKE
| ROLE
| ROLES
@ -575,6 +590,7 @@ nonReserved
| ALL
| ALTER
| ANALYZE
| ANALYZED
| AND
| ANY
| ARCHIVE
@ -716,6 +732,7 @@ nonReserved
| NULLS
| OF
| ONLY
| OPTIMIZED
| OPTION
| OPTIONS
| OR
@ -727,13 +744,16 @@ nonReserved
| OVERLAPS
| OVERLAY
| OVERWRITE
| PARSED
| PARTITION
| PARTITIONED
| PARTITIONS
| PERCENTILE_CONT
| PERCENTLIT
| PHYSICAL
| PIVOT
| PLACING
| PLAN
| POSITION
| PRECEDING
| PRIMARY
@ -756,6 +776,7 @@ nonReserved
| RESPECT
| RESTRICT
| REVOKE
| REWRITTEN
| RLIKE
| ROLE
| ROLES

View File

@ -17,21 +17,38 @@
package org.apache.doris.analysis;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel;
public class ExplainOptions {
private boolean isVerbose;
private boolean isGraph;
private ExplainCommand.ExplainLevel explainLevel;
public ExplainOptions(ExplainCommand.ExplainLevel explainLevel) {
this.explainLevel = explainLevel;
}
public ExplainOptions(boolean isVerbose, boolean isGraph) {
this.isVerbose = isVerbose;
this.isGraph = isGraph;
}
public boolean isVerbose() {
return isVerbose;
return explainLevel == ExplainLevel.VERBOSE || isVerbose;
}
public boolean isGraph() {
return isGraph;
return explainLevel == ExplainLevel.GRAPH || isGraph;
}
public boolean hasExplainLevel() {
return explainLevel != null;
}
public ExplainLevel getExplainLevel() {
return explainLevel;
}
}

View File

@ -220,16 +220,22 @@ public class FunctionCallExpr extends Expr {
this.argTypesForNereids = argTypes;
}
// nereids constructor without finalize/analyze
public FunctionCallExpr(FunctionName functionName, Function function, FunctionParams functionParams) {
this.fnName = functionName;
// nereids scalar function call expr constructor without finalize/analyze
public FunctionCallExpr(Function function, FunctionParams functionParams) {
this(function, functionParams, null, false, functionParams.exprs());
}
// nereids aggregate function call expr constructor without finalize/analyze
public FunctionCallExpr(Function function, FunctionParams functionParams, FunctionParams aggFnParams,
boolean isMergeAggFn, List<Expr> children) {
this.fnName = function.getFunctionName();
this.fn = function;
this.type = function.getReturnType();
this.fnParams = functionParams;
if (functionParams.exprs() != null) {
this.children.addAll(functionParams.exprs());
}
this.aggFnParams = aggFnParams;
this.children.addAll(children);
this.originChildSize = children.size();
this.isMergeAggFn = isMergeAggFn;
this.shouldFinalizeForNereids = false;
}

View File

@ -84,6 +84,19 @@ public class SlotRef extends Expr {
analysisDone();
}
// nerieds use this constructor to build aggFnParam
public SlotRef(Type type, boolean nullable) {
super();
// tuple id and slot id is meaningless here, nereids just use type and nullable
// to build the TAggregateExpr.param_types
TupleDescriptor tupleDescriptor = new TupleDescriptor(new TupleId(-1));
desc = new SlotDescriptor(new SlotId(-1), tupleDescriptor);
tupleDescriptor.addSlot(desc);
desc.setIsNullable(nullable);
desc.setType(type);
this.type = type;
}
protected SlotRef(SlotRef other) {
super(other);
tblName = other.tblName;

View File

@ -183,6 +183,30 @@ public class AggregateFunction extends Function {
returnsNonNullOnEmpty = false;
}
public AggregateFunction(FunctionName fnName, List<Type> argTypes,
Type retType, Type intermediateType, boolean hasVarArgs,
URI location, String updateFnSymbol, String initFnSymbol,
String serializeFnSymbol, String mergeFnSymbol, String getValueFnSymbol,
String removeFnSymbol, String finalizeFnSymbol, boolean ignoresDistinct,
boolean isAnalyticFn, boolean returnsNonNullOnEmpty, TFunctionBinaryType binaryType,
boolean userVisible, boolean vectorized, NullableMode nullableMode) {
// only `count` is always not nullable, other aggregate function is always nullable
super(0, fnName, argTypes, retType, hasVarArgs, binaryType, userVisible, vectorized, nullableMode);
setLocation(location);
this.intermediateType = (intermediateType.equals(retType)) ? null : intermediateType;
this.updateFnSymbol = updateFnSymbol;
this.initFnSymbol = initFnSymbol;
this.serializeFnSymbol = serializeFnSymbol;
this.mergeFnSymbol = mergeFnSymbol;
this.getValueFnSymbol = getValueFnSymbol;
this.removeFnSymbol = removeFnSymbol;
this.finalizeFnSymbol = finalizeFnSymbol;
this.ignoresDistinct = ignoresDistinct;
this.isAnalyticFn = isAnalyticFn;
this.isAggregateFn = true;
this.returnsNonNullOnEmpty = returnsNonNullOnEmpty;
}
public static AggregateFunction createBuiltin(String name,
List<Type> argTypes, Type retType, Type intermediateType,
String initFnSymbol, String updateFnSymbol, String mergeFnSymbol,

View File

@ -19,6 +19,7 @@ package org.apache.doris.catalog;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
@ -32,10 +33,10 @@ import java.util.function.BiFunction;
public class FunctionSignature {
public final DataType returnType;
public final boolean hasVarArgs;
public final List<DataType> argumentsTypes;
public final List<AbstractDataType> argumentsTypes;
public final int arity;
public FunctionSignature(DataType returnType, boolean hasVarArgs, List<DataType> argumentsTypes) {
public FunctionSignature(DataType returnType, boolean hasVarArgs, List<AbstractDataType> argumentsTypes) {
this.returnType = Objects.requireNonNull(returnType, "returnType is not null");
this.argumentsTypes = ImmutableList.copyOf(
Objects.requireNonNull(argumentsTypes, "argumentsTypes is not null"));
@ -43,11 +44,11 @@ public class FunctionSignature {
this.arity = argumentsTypes.size();
}
public Optional<DataType> getVarArgType() {
public Optional<AbstractDataType> getVarArgType() {
return hasVarArgs ? Optional.of(argumentsTypes.get(arity - 1)) : Optional.empty();
}
public DataType getArgType(int index) {
public AbstractDataType getArgType(int index) {
if (hasVarArgs && index >= arity) {
return argumentsTypes.get(arity - 1);
}
@ -58,7 +59,7 @@ public class FunctionSignature {
return new FunctionSignature(returnType, hasVarArgs, argumentsTypes);
}
public FunctionSignature withArgumentTypes(boolean hasVarArgs, List<DataType> argumentsTypes) {
public FunctionSignature withArgumentTypes(boolean hasVarArgs, List<AbstractDataType> argumentsTypes) {
return new FunctionSignature(returnType, hasVarArgs, argumentsTypes);
}
@ -69,27 +70,27 @@ public class FunctionSignature {
* @return
*/
public FunctionSignature withArgumentTypes(List<Expression> arguments,
BiFunction<DataType, Expression, DataType> transform) {
List<DataType> newTypes = Lists.newArrayList();
BiFunction<AbstractDataType, Expression, AbstractDataType> transform) {
List<AbstractDataType> newTypes = Lists.newArrayList();
for (int i = 0; i < arguments.size(); i++) {
newTypes.add(transform.apply(getArgType(i), arguments.get(i)));
}
return withArgumentTypes(hasVarArgs, newTypes);
}
public static FunctionSignature of(DataType returnType, List<DataType> argumentsTypes) {
public static FunctionSignature of(DataType returnType, List<AbstractDataType> argumentsTypes) {
return of(returnType, false, argumentsTypes);
}
public static FunctionSignature of(DataType returnType, boolean hasVarArgs, List<DataType> argumentsTypes) {
public static FunctionSignature of(DataType returnType, boolean hasVarArgs, List<AbstractDataType> argumentsTypes) {
return new FunctionSignature(returnType, hasVarArgs, argumentsTypes);
}
public static FunctionSignature of(DataType returnType, DataType... argumentsTypes) {
public static FunctionSignature of(DataType returnType, AbstractDataType... argumentsTypes) {
return of(returnType, false, argumentsTypes);
}
public static FunctionSignature of(DataType returnType, boolean hasVarArgs, DataType... argumentsTypes) {
public static FunctionSignature of(DataType returnType, boolean hasVarArgs, AbstractDataType... argumentsTypes) {
return new FunctionSignature(returnType, hasVarArgs, Arrays.asList(argumentsTypes));
}
@ -104,11 +105,11 @@ public class FunctionSignature {
this.returnType = returnType;
}
public FunctionSignature args(DataType...argTypes) {
public FunctionSignature args(AbstractDataType...argTypes) {
return FunctionSignature.of(returnType, false, argTypes);
}
public FunctionSignature varArgs(DataType...argTypes) {
public FunctionSignature varArgs(AbstractDataType...argTypes) {
return FunctionSignature.of(returnType, true, argTypes);
}
}

View File

@ -74,6 +74,13 @@ public class ScalarFunction extends Function {
NullableMode.DEPEND_ON_ARGUMENT);
}
/** nerieds custom scalar function */
public ScalarFunction(FunctionName fnName, List<Type> argTypes, Type retType, boolean hasVarArgs, String symbolName,
TFunctionBinaryType binaryType, boolean userVisible, boolean isVec, NullableMode nullableMode) {
super(0, fnName, argTypes, retType, hasVarArgs, binaryType, userVisible, isVec, nullableMode);
this.symbolName = symbolName;
}
public ScalarFunction(FunctionName fnName, List<Type> argTypes,
Type retType, URI location, String symbolName, String initFnSymbol,
String closeFnSymbol) {

View File

@ -18,9 +18,10 @@
package org.apache.doris.nereids;
import org.apache.doris.analysis.DescriptorTable;
import org.apache.doris.analysis.ExplainOptions;
import org.apache.doris.analysis.StatementBase;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.UserException;
import org.apache.doris.common.NereidsException;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
@ -36,6 +37,7 @@ import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.joinreorder.HyperGraphJoinReorder;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.planner.PlanFragment;
@ -64,18 +66,29 @@ public class NereidsPlanner extends Planner {
private List<ScanNode> scanNodeList = null;
private DescriptorTable descTable;
private Plan parsedPlan;
private Plan analyzedPlan;
private Plan rewrittenPlan;
private Plan optimizedPlan;
public NereidsPlanner(StatementContext statementContext) {
this.statementContext = statementContext;
}
@Override
public void plan(StatementBase queryStmt, org.apache.doris.thrift.TQueryOptions queryOptions) throws UserException {
public void plan(StatementBase queryStmt, org.apache.doris.thrift.TQueryOptions queryOptions) {
if (!(queryStmt instanceof LogicalPlanAdapter)) {
throw new RuntimeException("Wrong type of queryStmt, expected: <? extends LogicalPlanAdapter>");
}
LogicalPlanAdapter logicalPlanAdapter = (LogicalPlanAdapter) queryStmt;
PhysicalPlan physicalPlan = plan(logicalPlanAdapter.getLogicalPlan(), PhysicalProperties.ANY);
ExplainLevel explainLevel = getExplainLevel(queryStmt.getExplainOptions());
Plan resultPlan = plan(logicalPlanAdapter.getLogicalPlan(), PhysicalProperties.ANY, explainLevel);
if (explainLevel.isPlanLevel) {
return;
}
PhysicalPlan physicalPlan = (PhysicalPlan) resultPlan;
PhysicalPlanTranslator physicalPlanTranslator = new PhysicalPlanTranslator();
PlanTranslatorContext planTranslatorContext = new PlanTranslatorContext(cascadesContext);
if (ConnectContext.get().getSessionVariable().isEnableNereidsTrace()) {
@ -103,20 +116,30 @@ public class NereidsPlanner extends Planner {
public void plan(StatementBase queryStmt) {
try {
plan(queryStmt, statementContext.getConnectContext().getSessionVariable().toThrift());
} catch (UserException e) {
throw new RuntimeException(e);
} catch (Exception e) {
throw new NereidsException(e);
}
}
public PhysicalPlan plan(LogicalPlan plan, PhysicalProperties outputProperties) {
return (PhysicalPlan) plan(plan, outputProperties, ExplainLevel.NONE);
}
/**
* Do analyze and optimize for query plan.
*
* @param plan wait for plan
* @param outputProperties physical properties constraints
* @return physical plan generated by this planner
* @return plan generated by this planner
* @throws AnalysisException throw exception if failed in ant stage
*/
public PhysicalPlan plan(LogicalPlan plan, PhysicalProperties outputProperties) throws AnalysisException {
public Plan plan(LogicalPlan plan, PhysicalProperties outputProperties, ExplainLevel explainLevel) {
if (explainLevel == ExplainLevel.PARSED_PLAN || explainLevel == ExplainLevel.ALL_PLAN) {
parsedPlan = plan;
if (explainLevel == ExplainLevel.PARSED_PLAN) {
return parsedPlan;
}
}
// pre-process logical plan out of memo, e.g. process SET_VAR hint
plan = preprocess(plan);
@ -125,9 +148,21 @@ public class NereidsPlanner extends Planner {
// resolve column, table and function
analyze();
if (explainLevel == ExplainLevel.ANALYZED_PLAN || explainLevel == ExplainLevel.ALL_PLAN) {
analyzedPlan = cascadesContext.getMemo().copyOut(false);
if (explainLevel == ExplainLevel.ANALYZED_PLAN) {
return analyzedPlan;
}
}
// rule-based optimize
rewrite();
if (explainLevel == ExplainLevel.REWRITTEN_PLAN || explainLevel == ExplainLevel.ALL_PLAN) {
rewrittenPlan = cascadesContext.getMemo().copyOut(false);
if (explainLevel == ExplainLevel.REWRITTEN_PLAN) {
return rewrittenPlan;
}
}
deriveStats();
@ -142,7 +177,12 @@ public class NereidsPlanner extends Planner {
PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), PhysicalProperties.ANY);
// post-process physical plan out of memo, just for future use.
return postProcess(physicalPlan);
physicalPlan = postProcess(physicalPlan);
if (explainLevel == ExplainLevel.OPTIMIZED_PLAN || explainLevel == ExplainLevel.ALL_PLAN) {
optimizedPlan = physicalPlan;
}
return physicalPlan;
}
private LogicalPlan preprocess(LogicalPlan logicalPlan) {
@ -229,6 +269,33 @@ public class NereidsPlanner extends Planner {
}
}
@Override
public String getExplainString(ExplainOptions explainOptions) {
ExplainLevel explainLevel = getExplainLevel(explainOptions);
switch (explainLevel) {
case PARSED_PLAN:
return parsedPlan.treeString();
case ANALYZED_PLAN:
return analyzedPlan.treeString();
case REWRITTEN_PLAN:
return rewrittenPlan.treeString();
case OPTIMIZED_PLAN:
return optimizedPlan.treeString();
case ALL_PLAN:
String explainString = "========== PARSED PLAN ==========\n"
+ parsedPlan.treeString() + "\n\n"
+ "========== ANALYZED PLAN ==========\n"
+ analyzedPlan.treeString() + "\n\n"
+ "========== REWRITTEN PLAN ==========\n"
+ rewrittenPlan.treeString() + "\n\n"
+ "========== OPTIMIZED PLAN ==========\n"
+ optimizedPlan.treeString();
return explainString;
default:
return super.getExplainString(explainOptions);
}
}
@Override
public boolean isBlockQuery() {
return true;
@ -253,4 +320,12 @@ public class NereidsPlanner extends Planner {
public CascadesContext getCascadesContext() {
return cascadesContext;
}
private ExplainLevel getExplainLevel(ExplainOptions explainOptions) {
if (explainOptions == null) {
return ExplainLevel.NONE;
}
ExplainLevel explainLevel = explainOptions.getExplainLevel();
return explainLevel == null ? ExplainLevel.NONE : explainLevel;
}
}

View File

@ -43,6 +43,8 @@ public class StatementContext {
private CTEContext cteContext;
public StatementContext() {
this.connectContext = ConnectContext.get();
this.cteContext = new CTEContext();
}
public StatementContext(ConnectContext connectContext, OriginStatement originStatement) {

View File

@ -33,10 +33,8 @@ import org.apache.doris.analysis.FunctionParams;
import org.apache.doris.analysis.LikePredicate;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.TimestampArithmeticExpr;
import org.apache.doris.catalog.Function;
import org.apache.doris.catalog.Function.NullableMode;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
@ -67,8 +65,10 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.thrift.TFunctionBinaryType;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
@ -256,68 +256,56 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
// TODO: Supports for `distinct`
@Override
@Developing("Generate FunctionCallExpr and Function without analyze/finalize")
public Expr visitAggregateFunction(AggregateFunction function, PlanTranslatorContext context) {
// inputTypesBeforeDissemble is used to find the origin function's input type before disassemble aggregate.
//
// For example, 'double avg(int)' will be disassembled to 'varchar avg(int)' and 'double avg(varchar)'
// which the varchar contains sum(double) and count(int).
//
// We save the origin input type 'int' for the global aggregate 'varchar avg(int)', and get it in the
// 'inputTypesBeforeDissemble' variable, so we can find the catalog function 'avg(int)' in **frontend**.
//
// Vectorized engine in backend will find the 'avg(int)' function, and forwarding to the correct global
// aggregate function 'double avg(varchar)' by FunctionCallExpr.isMergeAggFn.
Optional<List<Type>> inputTypesBeforeDissemble = function.inputTypesBeforeDissemble()
.map(types -> types.stream()
.map(DataType::toCatalogDataType)
.collect(Collectors.toList())
);
List<Expr> catalogArguments = function.getArguments()
.stream()
.map(arg -> arg.accept(this, context))
.collect(ImmutableList.toImmutableList());
// We should change the global aggregate function's temporary input type(varchar) to the origin input type.
//
// For example: the global aggregate function expression 'avg(slotRef(type=varchar))' of the origin function
// 'avg(int)' should change to 'avg(slotRef(type=int))', because FunctionCallExpr will be converted to thrift
// format, and compute signature string by the children's type, we should pass the signature 'avg(int)' to
// **backend**. If we pass 'avg(varchar)' to backend, it will throw an exception: 'Agg Function avg is not
// implemented'.
List<Expr> catalogParams = new ArrayList<>();
for (int i = 0; i < function.arity(); i++) {
Expr catalogExpr = function.child(i).accept(this, context);
if (catalogExpr instanceof SlotRef && inputTypesBeforeDissemble.isPresent()
// count(*) in local aggregate contains empty children
// but contains one child in global aggregate: 'count(count(*))'.
// so the size of inputTypesBeforeDissemble maybe less than global aggregate param.
&& inputTypesBeforeDissemble.get().size() > i) {
SlotRef intermediateSlot = (SlotRef) catalogExpr.clone();
// change the slot type to origin input type
intermediateSlot.setType(inputTypesBeforeDissemble.get().get(i));
catalogExpr = intermediateSlot;
}
catalogParams.add(catalogExpr);
// aggFnArguments is used to build TAggregateExpr.param_types, so backend can find the aggregate function
List<Expr> aggFnArguments = function.getArgumentsBeforeDisassembled()
.stream()
.map(arg -> new SlotRef(arg.getDataType().toCatalogDataType(), arg.nullable()))
.collect(ImmutableList.toImmutableList());
FunctionParams aggFnParams;
if (function instanceof Count && ((Count) function).isStar()) {
aggFnParams = FunctionParams.createStarParam();
} else {
aggFnParams = new FunctionParams(function.isDistinct(), aggFnArguments);
}
boolean distinct = function.isDistinct();
FunctionParams aggFnParams = new FunctionParams(distinct, catalogParams);
ImmutableList<Type> argTypes = catalogArguments.stream()
.map(arg -> arg.getType())
.collect(ImmutableList.toImmutableList());
if (function instanceof Count) {
Count count = (Count) function;
if (count.isStar()) {
return new FunctionCallExpr(function.getName(), FunctionParams.createStarParam(),
aggFnParams, inputTypesBeforeDissemble);
} else if (count.isDistinct()) {
return new FunctionCallExpr(function.getName(), new FunctionParams(distinct, catalogParams),
aggFnParams, inputTypesBeforeDissemble);
}
}
return new FunctionCallExpr(function.getName(), new FunctionParams(distinct, catalogParams),
aggFnParams, inputTypesBeforeDissemble);
NullableMode nullableMode = function.nullable()
? NullableMode.ALWAYS_NULLABLE
: NullableMode.ALWAYS_NOT_NULLABLE;
boolean isAnalyticFunction = false;
org.apache.doris.catalog.AggregateFunction catalogFunction = new org.apache.doris.catalog.AggregateFunction(
new FunctionName(function.getName()), argTypes,
function.getDataType().toCatalogDataType(),
function.getIntermediateTypes().toCatalogDataType(),
function.hasVarArguments(),
null, "", "", null, "",
null, "", null, false,
isAnalyticFunction, false, TFunctionBinaryType.BUILTIN,
true, true, nullableMode
);
boolean isMergeFn = function.isGlobal() && function.isDisassembled();
// create catalog FunctionCallExpr without analyze again
return new FunctionCallExpr(catalogFunction, aggFnParams, aggFnParams, isMergeFn, catalogArguments);
}
@Override
public Expr visitScalarFunction(ScalarFunction function, PlanTranslatorContext context) {
List<Expr> arguments = function.getArguments()
.stream().map(arg -> arg.accept(this, context))
.stream()
.map(arg -> arg.accept(this, context))
.collect(Collectors.toList());
List<Type> argTypes = function.expectedInputTypes().stream()
.map(AbstractDataType::toCatalogDataType)
@ -327,12 +315,13 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
? NullableMode.ALWAYS_NULLABLE
: NullableMode.ALWAYS_NOT_NULLABLE;
Function catalogFunction = new Function(new FunctionName(function.getName()), argTypes,
function.getDataType().toCatalogDataType(), function.hasVarArguments(), true, nullableMode);
org.apache.doris.catalog.ScalarFunction catalogFunction = new org.apache.doris.catalog.ScalarFunction(
new FunctionName(function.getName()), argTypes,
function.getDataType().toCatalogDataType(), function.hasVarArguments(),
"", TFunctionBinaryType.BUILTIN, true, true, nullableMode);
// create catalog FunctionCallExpr without analyze again
return new FunctionCallExpr(catalogFunction.getFunctionName(), catalogFunction,
new FunctionParams(false, arguments));
return new FunctionCallExpr(catalogFunction, new FunctionParams(false, arguments));
}
@Override

View File

@ -217,11 +217,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
outputTupleDesc = localAggNode.getAggInfo().getOutputTupleDesc();
}
if (aggregate.getAggPhase() == AggPhase.GLOBAL) {
for (FunctionCallExpr execAggregateFunction : execAggregateFunctions) {
execAggregateFunction.setMergeForNereids(true);
}
}
// TODO: move setMergeForNereids to ExpressionTranslator.visitAggregateFunction
if (aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
for (FunctionCallExpr execAggregateFunction : execAggregateFunctions) {
if (!execAggregateFunction.isDistinct()) {

View File

@ -53,6 +53,7 @@ import org.apache.doris.nereids.DorisParser.NamedExpressionContext;
import org.apache.doris.nereids.DorisParser.NamedExpressionSeqContext;
import org.apache.doris.nereids.DorisParser.NullLiteralContext;
import org.apache.doris.nereids.DorisParser.ParenthesizedExpressionContext;
import org.apache.doris.nereids.DorisParser.PlanTypeContext;
import org.apache.doris.nereids.DorisParser.PredicateContext;
import org.apache.doris.nereids.DorisParser.PredicatedContext;
import org.apache.doris.nereids.DorisParser.QualifiedNameContext;
@ -219,8 +220,10 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
List<Pair<LogicalPlan, StatementContext>> logicalPlans = Lists.newArrayList();
for (org.apache.doris.nereids.DorisParser.StatementContext statement : ctx.statement()) {
StatementContext statementContext = new StatementContext();
if (ConnectContext.get() != null) {
ConnectContext.get().setStatementContext(statementContext);
ConnectContext connectContext = ConnectContext.get();
if (connectContext != null) {
connectContext.setStatementContext(statementContext);
statementContext.setConnectContext(connectContext);
}
logicalPlans.add(Pair.of(
ParserUtils.withOrigin(ctx, () -> (LogicalPlan) visit(statement)), statementContext));
@ -258,12 +261,25 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
@Override
public Command visitExplain(ExplainContext ctx) {
LogicalPlan logicalPlan = plan(ctx.query());
ExplainLevel explainLevel = ExplainLevel.NORMAL;
if (ctx.level != null) {
explainLevel = ExplainLevel.valueOf(ctx.level.getText().toUpperCase(Locale.ROOT));
}
return new ExplainCommand(explainLevel, logicalPlan);
return ParserUtils.withOrigin(ctx, () -> {
LogicalPlan logicalPlan = plan(ctx.query());
ExplainLevel explainLevel = ExplainLevel.NORMAL;
if (ctx.planType() != null) {
if (ctx.level == null || !ctx.level.getText().equalsIgnoreCase("plan")) {
throw new ParseException("Only explain plan can use plan type: " + ctx.planType().getText(), ctx);
}
}
if (ctx.level != null) {
if (!ctx.level.getText().equalsIgnoreCase("plan")) {
explainLevel = ExplainLevel.valueOf(ctx.level.getText().toUpperCase(Locale.ROOT));
} else {
explainLevel = parseExplainPlanType(ctx.planType());
}
}
return new ExplainCommand(explainLevel, logicalPlan);
});
}
@Override
@ -1119,4 +1135,23 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
}
return item.getText();
}
private ExplainLevel parseExplainPlanType(PlanTypeContext planTypeContext) {
if (planTypeContext == null || planTypeContext.ALL() != null) {
return ExplainLevel.ALL_PLAN;
}
if (planTypeContext.PHYSICAL() != null || planTypeContext.OPTIMIZED() != null) {
return ExplainLevel.OPTIMIZED_PLAN;
}
if (planTypeContext.REWRITTEN() != null || planTypeContext.LOGICAL() != null) {
return ExplainLevel.REWRITTEN_PLAN;
}
if (planTypeContext.ANALYZED() != null) {
return ExplainLevel.ANALYZED_PLAN;
}
if (planTypeContext.PARSED() != null) {
return ExplainLevel.PARSED_PLAN;
}
return ExplainLevel.ALL_PLAN;
}
}

View File

@ -53,19 +53,20 @@ public class NereidsParser {
public List<StatementBase> parseSQL(String originStr) {
List<Pair<LogicalPlan, StatementContext>> logicalPlans = parseMultiple(originStr);
List<StatementBase> statementBases = Lists.newArrayList();
for (Pair<LogicalPlan, StatementContext> logicalPlan : logicalPlans) {
for (Pair<LogicalPlan, StatementContext> parsedPlanToContext : logicalPlans) {
// TODO: this is a trick to support explain. Since we do not support any other command in a short time.
// It is acceptable. In the future, we need to refactor this.
if (logicalPlan.first instanceof ExplainCommand) {
ExplainCommand explainCommand = (ExplainCommand) logicalPlan.first;
StatementContext statementContext = parsedPlanToContext.second;
if (parsedPlanToContext.first instanceof ExplainCommand) {
ExplainCommand explainCommand = (ExplainCommand) parsedPlanToContext.first;
LogicalPlan innerPlan = explainCommand.getLogicalPlan();
LogicalPlanAdapter logicalPlanAdapter = new LogicalPlanAdapter(innerPlan, logicalPlan.second);
logicalPlanAdapter.setIsExplain(new ExplainOptions(
explainCommand.getLevel() == ExplainLevel.VERBOSE,
explainCommand.getLevel() == ExplainLevel.GRAPH));
LogicalPlanAdapter logicalPlanAdapter = new LogicalPlanAdapter(innerPlan, statementContext);
ExplainLevel explainLevel = explainCommand.getLevel();
ExplainOptions explainOptions = new ExplainOptions(explainLevel);
logicalPlanAdapter.setIsExplain(explainOptions);
statementBases.add(logicalPlanAdapter);
} else {
statementBases.add(new LogicalPlanAdapter(logicalPlan.first, logicalPlan.second));
statementBases.add(new LogicalPlanAdapter(parsedPlanToContext.first, statementContext));
}
}
return statementBases;

View File

@ -170,8 +170,8 @@ public class BindFunction implements AnalysisRuleFactory {
return new Count();
}
if (arguments.size() == 1) {
boolean isGlobalAgg = true;
AggregateParam aggregateParam = new AggregateParam(unboundFunction.isDistinct(), isGlobalAgg);
AggregateParam aggregateParam = new AggregateParam(
unboundFunction.isDistinct(), true, false);
return new Count(aggregateParam, unboundFunction.getArguments().get(0));
}
}

View File

@ -28,7 +28,6 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio
import org.apache.doris.nereids.trees.plans.AggPhase;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
@ -191,20 +190,15 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
AggregateFunction localAggregateFunction = aggregateFunction.withAggregateParam(
aggregateFunction.getAggregateParam()
.withDistinct(false)
.withGlobal(false)
.withGlobalAndDisassembled(false, true)
);
NamedExpression localOutputExpr = new Alias(localAggregateFunction, aggregateFunction.toSql());
List<DataType> inputTypesBeforeDissemble = aggregateFunction.children()
.stream()
.map(Expression::getDataType)
.collect(Collectors.toList());
AggregateFunction substitutionValue = aggregateFunction
// save the origin input types to the global aggregate functions
.withAggregateParam(aggregateFunction.getAggregateParam()
.withDistinct(false)
.withGlobal(true)
.withInputTypesBeforeDissemble(Optional.of(inputTypesBeforeDissemble)))
.withGlobalAndDisassembled(true, true))
.withChildren(Lists.newArrayList(localOutputExpr.toSlot()));
inputSubstitutionMap.put(aggregateFunction, substitutionValue);

View File

@ -36,7 +36,7 @@ public class Cast extends Expression implements UnaryExpression {
public Cast(Expression child, DataType targetType) {
super(child);
this.targetType = targetType;
this.targetType = Objects.requireNonNull(targetType, "targetType can not be null");
}
@Override
@ -67,7 +67,7 @@ public class Cast extends Expression implements UnaryExpression {
@Override
public String toString() {
return toSql();
return "CAST(" + child() + " AS " + targetType + ")";
}
@Override

View File

@ -60,6 +60,6 @@ public class TVFProperties extends Expression implements LeafExpression {
@Override
public String toString() {
return "KeyValuesExpression(" + toSql() + ")";
return "TVFProperties(" + toSql() + ")";
}
}

View File

@ -17,18 +17,44 @@
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Config;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.util.ResponsibilityChain;
import com.google.common.base.Suppliers;
import java.util.List;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;
/** BoundFunction. */
public abstract class BoundFunction extends Expression implements FunctionTrait {
public abstract class BoundFunction extends Expression implements FunctionTrait, ComputeSignature {
private final String name;
private final Supplier<FunctionSignature> signatureCache = Suppliers.memoize(() -> {
// first step: find the candidate signature in the signature list
List<Expression> originArguments = getOriginArguments();
FunctionSignature matchedSignature = searchSignature(
getOriginArgumentTypes(), originArguments, getSignatures());
// second step: change the signature, e.g. fill precision for decimal v2
return computeSignature(matchedSignature, originArguments);
});
public BoundFunction(String name, Expression... arguments) {
super(arguments);
this.name = Objects.requireNonNull(name, "name can not be null");
@ -39,10 +65,40 @@ public abstract class BoundFunction extends Expression implements FunctionTrait
this.name = Objects.requireNonNull(name, "name can not be null");
}
protected FunctionSignature computeSignature(FunctionSignature signature, List<Expression> arguments) {
// NOTE:
// this computed chain only process the common cases.
// If you want to add some common cases to here, please separate the process code
// to the other methods and add to this chain.
// If you want to add some special cases, please override this method in the special
// function class, like 'If' function and 'Substring' function.
return ComputeSignatureChain.from(signature, arguments)
.then(this::computePrecisionForDatetimeV2)
.then(this::upgradeDateOrDateTimeToV2)
.then(this::upgradeDecimalV2ToV3)
.then(this::computePrecisionForDecimal)
.then(this::dynamicComputePropertiesOfArray)
.get();
}
public String getName() {
return name;
}
public FunctionSignature getSignature() {
return signatureCache.get();
}
@Override
public List<AbstractDataType> expectedInputTypes() {
return ComputeSignature.super.expectedInputTypes();
}
@Override
public DataType getDataType() {
return ComputeSignature.super.getDataType();
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitBoundFunction(this, context);
@ -82,4 +138,114 @@ public abstract class BoundFunction extends Expression implements FunctionTrait
.collect(Collectors.joining(", "));
return name + "(" + args + ")";
}
private FunctionSignature computePrecisionForDatetimeV2(
FunctionSignature signature, List<Expression> arguments) {
// fill for arguments type
signature = signature.withArgumentTypes(arguments, (sigArgType, realArgType) -> {
if (sigArgType instanceof DateTimeV2Type && realArgType.getDataType() instanceof DateTimeV2Type) {
return realArgType.getDataType();
}
return sigArgType;
});
// fill for return type
if (signature.returnType instanceof DateTimeV2Type) {
Integer maxScale = signature.argumentsTypes.stream()
.filter(DateTimeV2Type.class::isInstance)
.map(t -> ((DateTimeV2Type) t).getScale())
.reduce(Math::max)
.orElse(((DateTimeV2Type) signature.returnType).getScale());
signature = signature.withReturnType(DateTimeV2Type.of(maxScale));
}
return signature;
}
private FunctionSignature upgradeDateOrDateTimeToV2(
FunctionSignature signature, List<Expression> arguments) {
DataType returnType = signature.returnType;
Type type = returnType.toCatalogDataType();
if ((type.isDate() || type.isDatetime()) && Config.enable_date_conversion) {
Type legacyReturnType = ScalarType.getDefaultDateType(returnType.toCatalogDataType());
signature = signature.withReturnType(DataType.fromCatalogType(legacyReturnType));
}
return signature;
}
@Developing
private FunctionSignature computePrecisionForDecimal(
FunctionSignature signature, List<Expression> arguments) {
if (signature.returnType instanceof DecimalV3Type || signature.returnType instanceof DecimalV2Type) {
if (this instanceof DecimalSamePrecision) {
signature = signature.withReturnType(arguments.get(0).getDataType());
} else if (this instanceof DecimalWiderPrecision) {
ScalarType widerType = ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
((ScalarType) arguments.get(0).getDataType().toCatalogDataType()).getScalarScale());
signature = signature.withReturnType(DataType.fromCatalogType(widerType));
} else if (this instanceof DecimalStddevPrecision) {
// for all stddev function, use decimal(38,9) as computing result
ScalarType stddevDecimalType = ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
DecimalStddevPrecision.STDDEV_DECIMAL_SCALE);
signature = signature.withReturnType(DataType.fromCatalogType(stddevDecimalType));
}
}
return signature;
}
private FunctionSignature upgradeDecimalV2ToV3(
FunctionSignature signature, List<Expression> arguments) {
DataType returnType = signature.returnType;
Type type = returnType.toCatalogDataType();
if (type.isDecimalV2() && Config.enable_decimal_conversion && Config.enable_decimalv3) {
Type v3Type = ScalarType.createDecimalV3Type(type.getPrecision(), ((ScalarType) type).getScalarScale());
signature = signature.withReturnType(DataType.fromCatalogType(v3Type));
}
return signature;
}
private FunctionSignature dynamicComputePropertiesOfArray(
FunctionSignature signature, List<Expression> arguments) {
if (!(signature.returnType instanceof ArrayType)) {
return signature;
}
// TODO: compute array(...) function's itemType
// fill item type by the type of first item
ArrayType arrayType = (ArrayType) signature.returnType;
// fill containsNull if any array argument contains null
boolean containsNull = signature.argumentsTypes
.stream()
.filter(argType -> argType instanceof ArrayType)
.map(ArrayType.class::cast)
.anyMatch(ArrayType::containsNull);
return signature.withReturnType(
ArrayType.of(arrayType.getItemType(), arrayType.containsNull() || containsNull));
}
static class ComputeSignatureChain {
private ResponsibilityChain<Pair<FunctionSignature, List<Expression>>> computeChain;
public ComputeSignatureChain(ResponsibilityChain<Pair<FunctionSignature, List<Expression>>> computeChain) {
this.computeChain = computeChain;
}
public static ComputeSignatureChain from(FunctionSignature signature, List<Expression> arguments) {
return new ComputeSignatureChain(ResponsibilityChain.from(Pair.of(signature, arguments)));
}
public ComputeSignatureChain then(
BiFunction<FunctionSignature, List<Expression>, FunctionSignature> computeFunction) {
computeChain.then(pair -> Pair.of(computeFunction.apply(pair.first, pair.second), pair.second));
return this;
}
public FunctionSignature get() {
return computeChain.get().first;
}
}
}

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
@ -47,7 +48,8 @@ public interface ComputeSignature extends FunctionTrait, ImplicitCastInputTypes
*
* @return the matched signature
*/
FunctionSignature searchSignature();
FunctionSignature searchSignature(List<DataType> argumentTypes, List<Expression> arguments,
List<FunctionSignature> signatures);
///// re-defined other interface's methods, so we can mixin this interfaces like a trait /////
@ -64,7 +66,7 @@ public interface ComputeSignature extends FunctionTrait, ImplicitCastInputTypes
*/
@Override
default List<AbstractDataType> expectedInputTypes() {
return (List) getSignature().argumentsTypes;
return getSignature().argumentsTypes;
}
/**

View File

@ -0,0 +1,47 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.types.DataType;
import com.google.common.collect.ImmutableList;
import java.util.List;
/** CustomSignature */
public interface CustomSignature extends ComputeSignature {
// custom generate a function signature.
FunctionSignature customSignature(List<DataType> argumentTypes, List<Expression> arguments);
@Override
default List<FunctionSignature> getSignatures() {
List<DataType> originArgumentTypes = getOriginArgumentTypes();
List<Expression> originArguments = getOriginArguments();
return ImmutableList.of(customSignature(originArgumentTypes, originArguments));
}
// use the first signature as the candidate signature.
@Override
default FunctionSignature searchSignature(List<DataType> argumentTypes, List<Expression> arguments,
List<FunctionSignature> signatures) {
return signatures.get(0);
}
}

View File

@ -41,17 +41,17 @@ public abstract class DateTimeWithPrecision extends ScalarFunction {
}
@Override
protected FunctionSignature computeSignature(FunctionSignature signature) {
protected FunctionSignature computeSignature(FunctionSignature signature, List<Expression> arguments) {
if (arity() == 1 && signature.returnType instanceof DateTimeV2Type) {
// For functions in TIME_FUNCTIONS_WITH_PRECISION, we can't figure out which function should be use when
// searching in FunctionSet. So we adjust the return type by hand here.
if (child(0) instanceof IntegerLikeLiteral) {
IntegerLikeLiteral integerLikeLiteral = (IntegerLikeLiteral) child(0);
if (arguments.get(0) instanceof IntegerLikeLiteral) {
IntegerLikeLiteral integerLikeLiteral = (IntegerLikeLiteral) arguments.get(0);
signature = signature.withReturnType(DateTimeV2Type.of(integerLikeLiteral.getIntValue()));
} else {
signature = signature.withReturnType(DateTimeV2Type.of(6));
}
}
return super.computeSignature(signature);
return super.computeSignature(signature, arguments);
}
}

View File

@ -19,8 +19,12 @@ package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import java.util.List;
/**
* Explicitly castable signature. This class equals to 'CompareMode.IS_NONSTRICT_SUPERTYPE_OF'.
*
@ -34,8 +38,9 @@ public interface ExplicitlyCastableSignature extends ComputeSignature {
}
@Override
default FunctionSignature searchSignature() {
return SearchSignature.from(getSignatures(), getArguments())
default FunctionSignature searchSignature(List<DataType> argumentTypes, List<Expression> arguments,
List<FunctionSignature> signatures) {
return SearchSignature.from(signatures, arguments)
// first round, use identical strategy to find signature
.orElseSearch(IdenticalSignature::isIdentical)
// second round: if not found, use nullOrIdentical strategy

View File

@ -39,6 +39,10 @@ public interface ExpressionTrait extends TreeNode<Expression> {
return children();
}
default Expression getArgument(int index) {
return child(index);
}
default DataType getDataType() throws UnboundException {
throw new UnboundException("dataType");
}

View File

@ -17,6 +17,13 @@
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.types.DataType;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* FunctionTrait.
*/
@ -24,4 +31,15 @@ public interface FunctionTrait extends ExpressionTrait {
String getName();
boolean hasVarArguments();
default List<Expression> getOriginArguments() {
return getArguments();
}
default List<DataType> getOriginArgumentTypes() {
return getArguments()
.stream()
.map(Expression::getDataType)
.collect(ImmutableList.toImmutableList());
}
}

View File

@ -18,8 +18,12 @@
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import java.util.List;
/**
* Identical function signature. This class equals to 'CompareMode.IS_IDENTICAL'.
*
@ -33,8 +37,9 @@ public interface IdenticalSignature extends ComputeSignature {
}
@Override
default FunctionSignature searchSignature() {
return SearchSignature.from(getSignatures(), getArguments())
default FunctionSignature searchSignature(List<DataType> argumentTypes, List<Expression> arguments,
List<FunctionSignature> signatures) {
return SearchSignature.from(signatures, arguments)
// first round, use identical strategy to find signature
.orElseSearch(IdenticalSignature::isIdentical)
.resultOrException(getName());

View File

@ -19,8 +19,12 @@ package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import java.util.List;
/**
* Implicitly castable function signature. This class equals to 'CompareMode.IS_SUPERTYPE_OF'.
*
@ -34,8 +38,9 @@ public interface ImplicitlyCastableSignature extends ComputeSignature {
}
@Override
default FunctionSignature searchSignature() {
return SearchSignature.from(getSignatures(), getArguments())
default FunctionSignature searchSignature(List<DataType> argumentTypes, List<Expression> arguments,
List<FunctionSignature> signatures) {
return SearchSignature.from(signatures, arguments)
// first round, use identical strategy to find signature
.orElseSearch(IdenticalSignature::isIdentical)
// second round: if not found, use nullOrIdentical strategy

View File

@ -18,9 +18,13 @@
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.NullType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import java.util.List;
/**
* Null or identical function signature. This class equals to 'CompareMode.IS_INDISTINGUISHABLE'.
*
@ -35,8 +39,9 @@ public interface NullOrIdenticalSignature extends ComputeSignature {
}
@Override
default FunctionSignature searchSignature() {
return SearchSignature.from(getSignatures(), getArguments())
default FunctionSignature searchSignature(List<DataType> argumentTypes, List<Expression> arguments,
List<FunctionSignature> signatures) {
return SearchSignature.from(signatures, arguments)
// first round, use identical strategy to find signature
.orElseSearch(IdenticalSignature::isIdentical)
// second round: if not found, use nullOrIdentical strategy

View File

@ -19,17 +19,21 @@ package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.typecoercion.ExpectsInputTypes;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.PartialAggType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
/**
* The function which consume arguments in lots of rows and product one value.
*/
public abstract class AggregateFunction extends BoundFunction {
public abstract class AggregateFunction extends BoundFunction implements ExpectsInputTypes {
private final AggregateParam aggregateParam;
@ -42,15 +46,69 @@ public abstract class AggregateFunction extends BoundFunction {
this.aggregateParam = Objects.requireNonNull(aggregateParam, "aggregateParam can not be null");
}
@Override
public List<Expression> getOriginArguments() {
return getArgumentsBeforeDisassembled();
}
@Override
public List<DataType> getOriginArgumentTypes() {
return getArgumentTypesBeforeDisassembled();
}
@Override
public abstract AggregateFunction withChildren(List<Expression> children);
public abstract DataType getFinalType();
public abstract DataType getIntermediateType();
public abstract AggregateFunction withAggregateParam(AggregateParam aggregateParam);
protected abstract List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments);
/** getIntermediateTypes */
public final PartialAggType getIntermediateTypes() {
if (isGlobal() && isDisassembled()) {
return (PartialAggType) child(0).getDataType();
}
List<Expression> arguments = getArgumentsBeforeDisassembled();
List<DataType> types = getArgumentTypesBeforeDisassembled();
return new PartialAggType(getArguments(), intermediateTypes(types, arguments));
}
public final DataType getFinalType() {
return getSignature().returnType;
}
@Override
public final DataType getDataType() {
if (aggregateParam.isGlobal) {
return getFinalType();
} else {
return getIntermediateTypes();
}
}
@Override
public final List<AbstractDataType> expectedInputTypes() {
if (isGlobal() && isDisassembled()) {
return ImmutableList.of(getIntermediateTypes());
} else {
return getSignature().argumentsTypes;
}
}
public List<Expression> getArgumentsBeforeDisassembled() {
if (arity() == 1 && getArgument(0).getDataType() instanceof PartialAggType) {
return ((PartialAggType) getArgument(0).getDataType()).getOriginArguments();
}
return getArguments();
}
public List<DataType> getArgumentTypesBeforeDisassembled() {
return getArgumentsBeforeDisassembled()
.stream()
.map(Expression::getDataType)
.collect(ImmutableList.toImmutableList());
}
public boolean isDistinct() {
return aggregateParam.isDistinct;
}
@ -59,8 +117,8 @@ public abstract class AggregateFunction extends BoundFunction {
return aggregateParam.isGlobal;
}
public Optional<List<DataType>> inputTypesBeforeDissemble() {
return aggregateParam.inputTypesBeforeDissemble;
public boolean isDisassembled() {
return aggregateParam.isDisassembled;
}
public AggregateParam getAggregateParam() {
@ -95,13 +153,4 @@ public abstract class AggregateFunction extends BoundFunction {
public boolean hasVarArguments() {
return false;
}
@Override
public final DataType getDataType() {
if (aggregateParam.isGlobal) {
return getFinalType();
} else {
return getIntermediateType();
}
}
}

View File

@ -17,11 +17,9 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.nereids.types.DataType;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
/** AggregateParam. */
public class AggregateParam {
@ -29,50 +27,41 @@ public class AggregateParam {
public final boolean isDistinct;
// When AggregateDisassemble rule disassemble the aggregate function, say double avg(int), the local
// aggregate keep the origin signature, but the global aggregate change to double avg(double).
// This behavior is difference from the legacy optimizer, because legacy optimizer keep the same signature
// between local aggregate and global aggregate. If the signatures are different, the result would wrong.
// So we use this field to record the originInputTypes, and find the catalog function by the origin input types.
public final Optional<List<DataType>> inputTypesBeforeDissemble;
public final boolean isDisassembled;
public AggregateParam() {
this(false, true, Optional.empty());
}
public AggregateParam(boolean distinct) {
this(distinct, true, Optional.empty());
}
public AggregateParam(boolean isDistinct, boolean isGlobal) {
this(isDistinct, isGlobal, Optional.empty());
}
public AggregateParam(boolean isDistinct, boolean isGlobal, Optional<List<DataType>> inputTypesBeforeDissemble) {
/** AggregateParam */
public AggregateParam(boolean isDistinct, boolean isGlobal, boolean isDisassembled) {
this.isDistinct = isDistinct;
this.isGlobal = isGlobal;
this.inputTypesBeforeDissemble = Objects.requireNonNull(inputTypesBeforeDissemble,
"inputTypesBeforeDissemble can not be null");
}
public static AggregateParam distinctAndGlobal() {
return new AggregateParam(true, true, Optional.empty());
this.isDisassembled = isDisassembled;
if (!isGlobal) {
Preconditions.checkArgument(isDisassembled == true,
"local aggregate should be disassembed");
}
}
public static AggregateParam global() {
return new AggregateParam(false, true, Optional.empty());
return new AggregateParam(false, true, false);
}
public static AggregateParam distinctAndGlobal() {
return new AggregateParam(true, true, false);
}
public AggregateParam withDistinct(boolean isDistinct) {
return new AggregateParam(isDistinct, isGlobal, inputTypesBeforeDissemble);
return new AggregateParam(isDistinct, isGlobal, isDisassembled);
}
public AggregateParam withGlobal(boolean isGlobal) {
return new AggregateParam(isDistinct, isGlobal, inputTypesBeforeDissemble);
return new AggregateParam(isDistinct, isGlobal, isDisassembled);
}
public AggregateParam withInputTypesBeforeDissemble(Optional<List<DataType>> inputTypesBeforeDissemble) {
return new AggregateParam(isDistinct, isGlobal, inputTypesBeforeDissemble);
public AggregateParam withDisassembled(boolean isDisassembled) {
return new AggregateParam(isDistinct, isGlobal, isDisassembled);
}
public AggregateParam withGlobalAndDisassembled(boolean isGlobal, boolean isDisassembled) {
return new AggregateParam(isDistinct, isGlobal, isDisassembled);
}
@Override
@ -86,11 +75,11 @@ public class AggregateParam {
AggregateParam that = (AggregateParam) o;
return isDistinct == that.isDistinct
&& Objects.equals(isGlobal, that.isGlobal)
&& Objects.equals(inputTypesBeforeDissemble, that.inputTypesBeforeDissemble);
&& Objects.equals(isDisassembled, that.isDisassembled);
}
@Override
public int hashCode() {
return Objects.hash(isDistinct, isGlobal, inputTypesBeforeDissemble);
return Objects.hash(isDistinct, isGlobal, isDisassembled);
}
}

View File

@ -17,19 +17,20 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.types.coercion.NumericType;
import org.apache.doris.nereids.types.coercion.TypeCollection;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -37,12 +38,7 @@ import com.google.common.collect.ImmutableList;
import java.util.List;
/** avg agg function. */
public class Avg extends AggregateFunction implements UnaryExpression, ImplicitCastInputTypes {
// used in interface expectedInputTypes to avoid new list in each time it be called
private static final List<AbstractDataType> EXPECTED_INPUT_TYPES = ImmutableList.of(
new TypeCollection(NumericType.INSTANCE, DateTimeType.INSTANCE, DateType.INSTANCE)
);
public class Avg extends AggregateFunction implements UnaryExpression, PropagateNullable, CustomSignature {
public Avg(Expression child) {
super("avg", child);
@ -53,31 +49,16 @@ public class Avg extends AggregateFunction implements UnaryExpression, ImplicitC
}
@Override
public DataType getFinalType() {
DataType argumentType = inputTypesBeforeDissemble()
.map(types -> types.get(0))
.orElse(child().getDataType());
if (argumentType instanceof DecimalV2Type) {
return DecimalV2Type.SYSTEM_DEFAULT;
} else if (argumentType.isDate()) {
return DateType.INSTANCE;
} else if (argumentType.isDateTime()) {
return DateTimeType.INSTANCE;
} else {
return DoubleType.INSTANCE;
}
}
// TODO: We should return a complex type: PartialAggType(bufferTypes=[Double, Int], inputTypes=[Int])
// to denote sum(double) and count(int)
@Override
public DataType getIntermediateType() {
return VarcharType.createVarcharType(-1);
public FunctionSignature customSignature(List<DataType> argumentTypes, List<Expression> arguments) {
DataType implicitCastType = implicitCast(argumentTypes.get(0));
return FunctionSignature.ret(implicitCastType).args(implicitCastType);
}
@Override
public boolean nullable() {
return child().nullable();
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
DataType sumType = getFinalType();
BigIntType countType = BigIntType.INSTANCE;
return ImmutableList.of(sumType, countType);
}
@Override
@ -91,17 +72,22 @@ public class Avg extends AggregateFunction implements UnaryExpression, ImplicitC
return new Avg(aggregateParam, child());
}
@Override
public List<AbstractDataType> expectedInputTypes() {
if (isGlobal() && inputTypesBeforeDissemble().isPresent()) {
return ImmutableList.of();
} else {
return EXPECTED_INPUT_TYPES;
}
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitAvg(this, context);
}
private DataType implicitCast(DataType dataType) {
if (dataType instanceof DecimalV2Type) {
return DecimalV2Type.SYSTEM_DEFAULT;
} else if (dataType.isDate()) {
return DateType.INSTANCE;
} else if (dataType.isDateTime()) {
return DateTimeType.INSTANCE;
} else if (dataType instanceof NumericType) {
return DoubleType.INSTANCE;
} else {
throw new AnalysisException("avg requires a numeric parameter: " + dataType);
}
}
}

View File

@ -17,13 +17,13 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.BitmapType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -32,7 +32,11 @@ import java.util.List;
/** BitmapIntersect */
public class BitmapIntersect extends AggregateFunction
implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes {
implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE)
);
public BitmapIntersect(Expression arg0) {
super("bitmap_intersect", arg0);
}
@ -41,29 +45,24 @@ public class BitmapIntersect extends AggregateFunction
super("bitmap_intersect", aggregateParam, arg0);
}
@Override
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
return ImmutableList.of(BitmapType.INSTANCE);
}
@Override
public BitmapIntersect withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new BitmapIntersect(getAggregateParam(), children.get(0));
}
@Override
public List<AbstractDataType> expectedInputTypes() {
return ImmutableList.of(BitmapType.INSTANCE);
}
@Override
public DataType getFinalType() {
return BitmapType.INSTANCE;
}
@Override
public DataType getIntermediateType() {
return BitmapType.INSTANCE;
}
@Override
public BitmapIntersect withAggregateParam(AggregateParam aggregateParam) {
return new BitmapIntersect(aggregateParam, child());
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
}

View File

@ -17,13 +17,13 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.BitmapType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -32,7 +32,11 @@ import java.util.List;
/** BitmapUnion */
public class BitmapUnion extends AggregateFunction
implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes {
implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE)
);
public BitmapUnion(Expression arg0) {
super("bitmap_union", arg0);
}
@ -41,29 +45,24 @@ public class BitmapUnion extends AggregateFunction
super("bitmap_union", aggregateParam, arg0);
}
@Override
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
return ImmutableList.of(BitmapType.INSTANCE);
}
@Override
public BitmapUnion withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new BitmapUnion(getAggregateParam(), children.get(0));
}
@Override
public List<AbstractDataType> expectedInputTypes() {
return ImmutableList.of(BitmapType.INSTANCE);
}
@Override
public DataType getFinalType() {
return BitmapType.INSTANCE;
}
@Override
public DataType getIntermediateType() {
return BitmapType.INSTANCE;
}
@Override
public BitmapUnion withAggregateParam(AggregateParam aggregateParam) {
return new BitmapUnion(aggregateParam, child());
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
}

View File

@ -17,14 +17,14 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BitmapType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -33,7 +33,11 @@ import java.util.List;
/** BitmapUnionCount */
public class BitmapUnionCount extends AggregateFunction
implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes {
implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(BitmapType.INSTANCE)
);
public BitmapUnionCount(Expression arg0) {
super("bitmap_union_count", arg0);
}
@ -42,29 +46,24 @@ public class BitmapUnionCount extends AggregateFunction
super("bitmap_union_count", aggregateParam, arg0);
}
@Override
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
return ImmutableList.of(BitmapType.INSTANCE);
}
@Override
public BitmapUnionCount withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new BitmapUnionCount(getAggregateParam(), children.get(0));
}
@Override
public List<AbstractDataType> expectedInputTypes() {
return ImmutableList.of(BitmapType.INSTANCE);
}
@Override
public DataType getFinalType() {
return BigIntType.INSTANCE;
}
@Override
public DataType getIntermediateType() {
return BitmapType.INSTANCE;
}
@Override
public BitmapUnionCount withAggregateParam(AggregateParam aggregateParam) {
return new BitmapUnionCount(aggregateParam, child());
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
}

View File

@ -17,18 +17,17 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BitmapType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.types.coercion.TypeCollection;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -37,7 +36,14 @@ import java.util.List;
/** BitmapUnionInt */
public class BitmapUnionInt extends AggregateFunction
implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes {
implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE)
);
public BitmapUnionInt(Expression arg0) {
super("bitmap_union_int", arg0);
}
@ -46,34 +52,24 @@ public class BitmapUnionInt extends AggregateFunction
super("bitmap_union_int", aggregateParam, arg0);
}
@Override
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
return ImmutableList.of(BitmapType.INSTANCE);
}
@Override
public BitmapUnionInt withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new BitmapUnionInt(getAggregateParam(), children.get(0));
}
@Override
public List<AbstractDataType> expectedInputTypes() {
if (isGlobal() && inputTypesBeforeDissemble().isPresent()) {
return ImmutableList.of();
} else {
return ImmutableList.of(new TypeCollection(
TinyIntType.INSTANCE, SmallIntType.INSTANCE, IntegerType.INSTANCE, BigIntType.INSTANCE));
}
}
@Override
public DataType getFinalType() {
return BigIntType.INSTANCE;
}
@Override
public DataType getIntermediateType() {
return BitmapType.INSTANCE;
}
@Override
public BitmapUnionInt withAggregateParam(AggregateParam aggregateParam) {
return new BitmapUnionInt(aggregateParam, child());
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
}

View File

@ -17,20 +17,23 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Collectors;
/** count agg function. */
public class Count extends AggregateFunction implements AlwaysNotNullable {
public class Count extends AggregateFunction implements AlwaysNotNullable, CustomSignature {
private final boolean isStar;
@ -59,13 +62,13 @@ public class Count extends AggregateFunction implements AlwaysNotNullable {
}
@Override
public DataType getFinalType() {
return BigIntType.INSTANCE;
public FunctionSignature customSignature(List<DataType> argumentTypes, List<Expression> arguments) {
return FunctionSignature.of(BigIntType.INSTANCE, (List) argumentTypes);
}
@Override
public DataType getIntermediateType() {
return getFinalType();
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
return ImmutableList.of(BigIntType.INSTANCE);
}
@Override

View File

@ -17,13 +17,13 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.BitmapType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -32,7 +32,11 @@ import java.util.List;
/** GroupBitmapXor */
public class GroupBitmapXor extends AggregateFunction
implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes {
implements UnaryExpression, PropagateNullable, ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE)
);
public GroupBitmapXor(Expression arg0) {
super("group_bitmap_xor", arg0);
}
@ -41,33 +45,24 @@ public class GroupBitmapXor extends AggregateFunction
super("group_bitmap_xor", aggregateParam, arg0);
}
@Override
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
return ImmutableList.of(BitmapType.INSTANCE);
}
@Override
public GroupBitmapXor withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new GroupBitmapXor(getAggregateParam(), children.get(0));
}
@Override
public List<AbstractDataType> expectedInputTypes() {
if (isGlobal() && inputTypesBeforeDissemble().isPresent()) {
return ImmutableList.of();
} else {
return ImmutableList.of(BitmapType.INSTANCE);
}
}
@Override
public DataType getFinalType() {
return BitmapType.INSTANCE;
}
@Override
public DataType getIntermediateType() {
return BitmapType.INSTANCE;
}
@Override
public GroupBitmapXor withAggregateParam(AggregateParam aggregateParam) {
return new GroupBitmapXor(aggregateParam, child());
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
}

View File

@ -17,13 +17,13 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.HllType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -32,15 +32,22 @@ import java.util.List;
/** HllUnion */
public class HllUnion extends AggregateFunction
implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes {
implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(HllType.INSTANCE).args(HllType.INSTANCE)
);
public HllUnion(Expression arg0) {
// TODO: change to hll_union in the future
super("hll_raw_agg", arg0);
super("hll_union", arg0);
}
public HllUnion(AggregateParam aggregateParam, Expression arg0) {
// TODO: change to hll_union in the future
super("hll_raw_agg", aggregateParam, arg0);
super("hll_union", aggregateParam, arg0);
}
@Override
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
return ImmutableList.of(HllType.INSTANCE);
}
@Override
@ -49,23 +56,13 @@ public class HllUnion extends AggregateFunction
return new HllUnion(getAggregateParam(), children.get(0));
}
@Override
public List<AbstractDataType> expectedInputTypes() {
return ImmutableList.of(HllType.INSTANCE);
}
@Override
public DataType getFinalType() {
return HllType.INSTANCE;
}
@Override
public DataType getIntermediateType() {
return HllType.INSTANCE;
}
@Override
public HllUnion withAggregateParam(AggregateParam aggregateParam) {
return new HllUnion(aggregateParam, child());
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
}

View File

@ -17,14 +17,14 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.HllType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -33,7 +33,11 @@ import java.util.List;
/** HllUnionAgg */
public class HllUnionAgg extends AggregateFunction
implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes {
implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(HllType.INSTANCE)
);
public HllUnionAgg(Expression arg0) {
super("hll_union_agg", arg0);
}
@ -42,29 +46,24 @@ public class HllUnionAgg extends AggregateFunction
super("hll_union_agg", aggregateParam, arg0);
}
@Override
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
return ImmutableList.of(HllType.INSTANCE);
}
@Override
public HllUnionAgg withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new HllUnionAgg(getAggregateParam(), children.get(0));
}
@Override
public List<AbstractDataType> expectedInputTypes() {
return ImmutableList.of(HllType.INSTANCE);
}
@Override
public DataType getFinalType() {
return BigIntType.INSTANCE;
}
@Override
public DataType getIntermediateType() {
return HllType.INSTANCE;
}
@Override
public HllUnionAgg withAggregateParam(AggregateParam aggregateParam) {
return new HllUnionAgg(aggregateParam, child());
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
}

View File

@ -17,7 +17,10 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
@ -27,8 +30,7 @@ import com.google.common.base.Preconditions;
import java.util.List;
/** max agg function. */
public class Max extends AggregateFunction implements UnaryExpression {
public class Max extends AggregateFunction implements UnaryExpression, PropagateNullable, CustomSignature {
public Max(Expression child) {
super("max", child);
}
@ -38,18 +40,13 @@ public class Max extends AggregateFunction implements UnaryExpression {
}
@Override
public DataType getFinalType() {
return child().getDataType();
public FunctionSignature customSignature(List<DataType> argumentTypes, List<Expression> arguments) {
return FunctionSignature.ret(argumentTypes.get(0)).args(argumentTypes.get(0));
}
@Override
public DataType getIntermediateType() {
return getFinalType();
}
@Override
public boolean nullable() {
return child().nullable();
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
return argumentTypes;
}
@Override

View File

@ -17,7 +17,10 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
@ -27,7 +30,7 @@ import com.google.common.base.Preconditions;
import java.util.List;
/** min agg function. */
public class Min extends AggregateFunction implements UnaryExpression {
public class Min extends AggregateFunction implements UnaryExpression, PropagateNullable, CustomSignature {
public Min(Expression child) {
super("min", child);
@ -38,18 +41,13 @@ public class Min extends AggregateFunction implements UnaryExpression {
}
@Override
public DataType getFinalType() {
return child().getDataType();
public FunctionSignature customSignature(List<DataType> argumentTypes, List<Expression> arguments) {
return FunctionSignature.ret(argumentTypes.get(0)).args(argumentTypes.get(0));
}
@Override
public DataType getIntermediateType() {
return getFinalType();
}
@Override
public boolean nullable() {
return child().nullable();
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
return argumentTypes;
}
@Override

View File

@ -17,17 +17,18 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.types.coercion.FractionalType;
import org.apache.doris.nereids.types.coercion.IntegralType;
import org.apache.doris.nereids.types.coercion.NumericType;
@ -37,11 +38,7 @@ import com.google.common.collect.ImmutableList;
import java.util.List;
/** sum agg function. */
public class Sum extends AggregateFunction implements UnaryExpression, ImplicitCastInputTypes {
// used in interface expectedInputTypes to avoid new list in each time it be called
private static final List<AbstractDataType> EXPECTED_INPUT_TYPES = ImmutableList.of(NumericType.INSTANCE);
public class Sum extends AggregateFunction implements UnaryExpression, PropagateNullable, CustomSignature {
public Sum(Expression child) {
super("sum", child);
}
@ -51,34 +48,14 @@ public class Sum extends AggregateFunction implements UnaryExpression, ImplicitC
}
@Override
public DataType getFinalType() {
DataType dataType = child().getDataType();
if (dataType instanceof LargeIntType) {
return dataType;
} else if (dataType instanceof DecimalV2Type) {
return DecimalV2Type.SYSTEM_DEFAULT;
} else if (dataType instanceof IntegralType) {
return BigIntType.INSTANCE;
} else if (dataType instanceof FractionalType) {
return DoubleType.INSTANCE;
} else {
throw new IllegalStateException("Unsupported sum type: " + dataType);
}
public FunctionSignature customSignature(List<DataType> argumentTypes, List<Expression> arguments) {
DataType implicitCastType = implicitCast(argumentTypes.get(0));
return FunctionSignature.ret(implicitCastType).args(NumericType.INSTANCE);
}
@Override
public DataType getIntermediateType() {
return getFinalType();
}
@Override
public boolean nullable() {
return child().nullable();
}
@Override
public List<AbstractDataType> expectedInputTypes() {
return EXPECTED_INPUT_TYPES;
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
return ImmutableList.of(getFinalType());
}
@Override
@ -96,4 +73,18 @@ public class Sum extends AggregateFunction implements UnaryExpression, ImplicitC
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitSum(this, context);
}
private DataType implicitCast(DataType dataType) {
if (dataType instanceof LargeIntType) {
return dataType;
} else if (dataType instanceof DecimalV2Type) {
return DecimalV2Type.SYSTEM_DEFAULT;
} else if (dataType instanceof IntegralType) {
return BigIntType.INSTANCE;
} else if (dataType instanceof NumericType) {
return DoubleType.INSTANCE;
} else {
throw new AnalysisException("sum requires a numeric parameter: " + dataType);
}
}
}

View File

@ -41,6 +41,7 @@ import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
@ -90,7 +91,7 @@ public class If extends ScalarFunction
);
private final Supplier<DataType> widerType = Suppliers.memoize(() -> {
List<DataType> argumentsTypes = getSignature().argumentsTypes;
List<AbstractDataType> argumentsTypes = getSignature().argumentsTypes;
Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType(
argumentsTypes.get(1).toCatalogDataType(),
argumentsTypes.get(2).toCatalogDataType(),
@ -106,11 +107,11 @@ public class If extends ScalarFunction
}
@Override
protected FunctionSignature computeSignature(FunctionSignature signature) {
protected FunctionSignature computeSignature(FunctionSignature signature, List<Expression> arguments) {
DataType widerType = this.widerType.get();
signature = signature.withArgumentTypes(children(), (sigType, argType) -> widerType)
signature = signature.withArgumentTypes(arguments, (sigType, argType) -> widerType)
.withReturnType(widerType);
return super.computeSignature(signature);
return super.computeSignature(signature, arguments);
}
/**

View File

@ -17,45 +17,17 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Config;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.ComputeSignature;
import org.apache.doris.nereids.trees.expressions.functions.DecimalSamePrecision;
import org.apache.doris.nereids.trees.expressions.functions.DecimalStddevPrecision;
import org.apache.doris.nereids.trees.expressions.functions.DecimalWiderPrecision;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.util.ResponsibilityChain;
import com.google.common.base.Suppliers;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Supplier;
/**
* The function which consume zero or more arguments in a row and product one value.
*/
public abstract class ScalarFunction extends BoundFunction implements ComputeSignature {
@Developing("this field will move to BoundFunction, when we support compute signature for AggregateFunction")
private final Supplier<FunctionSignature> signatureCache = Suppliers.memoize(() -> {
// first step: find the candidate signature in the signature list
FunctionSignature matchedSignature = searchSignature();
// second step: change the signature, e.g. fill precision for decimal v2
return computeSignature(matchedSignature);
});
public ScalarFunction(String name, Expression... arguments) {
super(name, arguments);
}
@ -64,150 +36,8 @@ public abstract class ScalarFunction extends BoundFunction implements ComputeSig
super(name, arguments);
}
public FunctionSignature getSignature() {
return signatureCache.get();
}
protected FunctionSignature computeSignature(FunctionSignature signature) {
// NOTE:
// this computed chain only process the common cases.
// If you want to add some common cases to here, please separate the process code
// to the other methods and add to this chain.
// If you want to add some special cases, please override this method in the special
// function class, like 'If' function and 'Substring' function.
return ComputeSignatureChain.from(signature, getArguments())
.then(this::computePrecisionForDatetimeV2)
.then(this::upgradeDateOrDateTimeToV2)
.then(this::upgradeDecimalV2ToV3)
.then(this::computePrecisionForDecimal)
.then(this::dynamicComputePropertiesOfArray)
.get();
}
@Override
@Developing("this method will move to BoundFunction, when we support compute signature for AggregateFunction")
public final List<AbstractDataType> expectedInputTypes() {
return ComputeSignature.super.expectedInputTypes();
}
@Override
@Developing("this method will move to BoundFunction, when we support compute signature for AggregateFunction")
public final DataType getDataType() {
return ComputeSignature.super.getDataType();
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitScalarFunction(this, context);
}
private FunctionSignature computePrecisionForDatetimeV2(
FunctionSignature signature, List<Expression> arguments) {
// fill for arguments type
signature = signature.withArgumentTypes(arguments, (sigArgType, realArgType) -> {
if (sigArgType instanceof DateTimeV2Type && realArgType.getDataType() instanceof DateTimeV2Type) {
return realArgType.getDataType();
}
return sigArgType;
});
// fill for return type
if (signature.returnType instanceof DateTimeV2Type) {
Integer maxScale = signature.argumentsTypes.stream()
.filter(DateTimeV2Type.class::isInstance)
.map(t -> ((DateTimeV2Type) t).getScale())
.reduce(Math::max)
.orElse(((DateTimeV2Type) signature.returnType).getScale());
signature = signature.withReturnType(DateTimeV2Type.of(maxScale));
}
return signature;
}
private FunctionSignature upgradeDateOrDateTimeToV2(
FunctionSignature signature, List<Expression> arguments) {
DataType returnType = signature.returnType;
Type type = returnType.toCatalogDataType();
if ((type.isDate() || type.isDatetime()) && Config.enable_date_conversion) {
Type legacyReturnType = ScalarType.getDefaultDateType(returnType.toCatalogDataType());
signature = signature.withReturnType(DataType.fromCatalogType(legacyReturnType));
}
return signature;
}
@Developing
private FunctionSignature computePrecisionForDecimal(
FunctionSignature signature, List<Expression> arguments) {
if (signature.returnType instanceof DecimalV3Type || signature.returnType instanceof DecimalV2Type) {
if (this instanceof DecimalSamePrecision) {
signature = signature.withReturnType(signature.argumentsTypes.get(0));
} else if (this instanceof DecimalWiderPrecision) {
ScalarType widerType = ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
((ScalarType) signature.argumentsTypes.get(0).toCatalogDataType()).getScalarScale());
signature = signature.withReturnType(DataType.fromCatalogType(widerType));
} else if (this instanceof DecimalStddevPrecision) {
// for all stddev function, use decimal(38,9) as computing result
ScalarType stddevDecimalType = ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
DecimalStddevPrecision.STDDEV_DECIMAL_SCALE);
signature = signature.withReturnType(DataType.fromCatalogType(stddevDecimalType));
}
}
return signature;
}
private FunctionSignature upgradeDecimalV2ToV3(
FunctionSignature signature, List<Expression> arguments) {
DataType returnType = signature.returnType;
Type type = returnType.toCatalogDataType();
if (type.isDecimalV2() && Config.enable_decimal_conversion && Config.enable_decimalv3) {
Type v3Type = ScalarType.createDecimalV3Type(type.getPrecision(), ((ScalarType) type).getScalarScale());
signature = signature.withReturnType(DataType.fromCatalogType(v3Type));
}
return signature;
}
private FunctionSignature dynamicComputePropertiesOfArray(
FunctionSignature signature, List<Expression> arguments) {
if (!(signature.returnType instanceof ArrayType)) {
return signature;
}
// TODO: compute array(...) function's itemType
// fill item type by the type of first item
ArrayType arrayType = (ArrayType) signature.returnType;
// fill containsNull if any array argument contains null
boolean containsNull = signature.argumentsTypes
.stream()
.filter(argType -> argType instanceof ArrayType)
.map(ArrayType.class::cast)
.anyMatch(ArrayType::containsNull);
return signature.withReturnType(
ArrayType.of(arrayType.getItemType(), arrayType.containsNull() || containsNull));
}
static class ComputeSignatureChain {
private ResponsibilityChain<Pair<FunctionSignature, List<Expression>>> computeChain;
public ComputeSignatureChain(ResponsibilityChain<Pair<FunctionSignature, List<Expression>>> computeChain) {
this.computeChain = computeChain;
}
public static ComputeSignatureChain from(FunctionSignature signature, List<Expression> arguments) {
return new ComputeSignatureChain(ResponsibilityChain.from(Pair.of(signature, arguments)));
}
public ComputeSignatureChain then(
BiFunction<FunctionSignature, List<Expression>, FunctionSignature> computeFunction) {
computeChain.then(pair -> Pair.of(computeFunction.apply(pair.first, pair.second), pair.second));
return this;
}
public FunctionSignature get() {
return computeChain.get().first;
}
}
}

View File

@ -56,7 +56,7 @@ public class StrToDate extends ScalarFunction
}
@Override
protected FunctionSignature computeSignature(FunctionSignature signature) {
protected FunctionSignature computeSignature(FunctionSignature signature, List<Expression> arguments) {
/*
* The return type of str_to_date depends on whether the time part is included in the format.
* If included, it is datetime, otherwise it is date.
@ -83,8 +83,8 @@ public class StrToDate extends ScalarFunction
* Return type is DATETIME
*/
DataType returnType;
if (child(1) instanceof StringLikeLiteral) {
if (DateLiteral.hasTimePart(((StringLikeLiteral) child(1)).getStringValue())) {
if (arguments.get(1) instanceof StringLikeLiteral) {
if (DateLiteral.hasTimePart(((StringLikeLiteral) arguments.get(1)).getStringValue())) {
returnType = DataType.fromCatalogType(ScalarType.getDefaultDateType(Type.DATETIME));
} else {
returnType = DataType.fromCatalogType(ScalarType.getDefaultDateType(Type.DATE));

View File

@ -66,8 +66,9 @@ public class Substring extends ScalarFunction
}
@Override
protected FunctionSignature computeSignature(FunctionSignature signature) {
Optional<Expression> length = getLength();
protected FunctionSignature computeSignature(FunctionSignature signature, List<Expression> arguments) {
Optional<Expression> length = arguments.size() == 3
? Optional.of(arguments.get(2)) : Optional.empty();
DataType returnType = VarcharType.SYSTEM_DEFAULT;
if (length.isPresent() && length.get() instanceof IntegerLiteral) {
returnType = VarcharType.createVarcharType(((IntegerLiteral) length.get()).getValue());

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions.table;
import org.apache.doris.analysis.IntLiteral;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Id;
import org.apache.doris.common.NereidsException;
@ -26,6 +27,8 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.TVFProperties;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.StatsDeriveResult;
import org.apache.doris.tablefunction.NumbersTableValuedFunction;
@ -43,10 +46,15 @@ public class Numbers extends TableValuedFunction {
super("numbers", properties);
}
@Override
public FunctionSignature customSignature(List<DataType> argumentTypes, List<Expression> arguments) {
return FunctionSignature.of(BigIntType.INSTANCE, (List) argumentTypes);
}
@Override
protected TableValuedFunctionIf toCatalogFunction() {
try {
Map<String, String> arguments = getKeyValuesExpression().getMap();
Map<String, String> arguments = getTVFProperties().getMap();
return new NumbersTableValuedFunction(arguments);
} catch (Throwable t) {
throw new AnalysisException("Can not build NumbersTableValuedFunction by "
@ -82,9 +90,4 @@ public class Numbers extends TableValuedFunction {
&& children().get(0) instanceof TVFProperties);
return new Numbers((TVFProperties) children.get(0));
}
@Override
public boolean hasVarArguments() {
return false;
}
}

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.TVFProperties;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
@ -38,7 +39,7 @@ import java.util.function.Supplier;
import java.util.stream.Collectors;
/** TableValuedFunction */
public abstract class TableValuedFunction extends BoundFunction implements UnaryExpression {
public abstract class TableValuedFunction extends BoundFunction implements UnaryExpression, CustomSignature {
protected final Supplier<TableValuedFunctionIf> catalogFunctionCache = Suppliers.memoize(() -> toCatalogFunction());
protected final Supplier<FunctionGenTable> tableCache = Suppliers.memoize(() -> {
try {
@ -58,7 +59,7 @@ public abstract class TableValuedFunction extends BoundFunction implements Unary
public abstract StatsDeriveResult computeStats(List<Slot> slots);
public TVFProperties getKeyValuesExpression() {
public TVFProperties getTVFProperties() {
return (TVFProperties) child(0);
}
@ -95,7 +96,7 @@ public abstract class TableValuedFunction extends BoundFunction implements Unary
@Override
public String toSql() {
String args = getKeyValuesExpression()
String args = getTVFProperties()
.getMap()
.entrySet()
.stream()

View File

@ -29,10 +29,22 @@ public class ExplainCommand implements Command {
* explain level.
*/
public enum ExplainLevel {
NORMAL,
VERBOSE,
GRAPH,
NONE(false),
NORMAL(false),
VERBOSE(false),
GRAPH(false),
PARSED_PLAN(true),
ANALYZED_PLAN(true),
REWRITTEN_PLAN(true),
OPTIMIZED_PLAN(true),
ALL_PLAN(true)
;
public final boolean isPlanLevel;
ExplainLevel(boolean isPlanLevel) {
this.isPlanLevel = isPlanLevel;
}
}
private final ExplainLevel level;

View File

@ -78,7 +78,7 @@ public class PhysicalTVFRelation extends PhysicalRelation implements TVFRelation
@Override
public String toString() {
return Utils.toSqlString("LogicalTVFRelation",
return Utils.toSqlString("PhysicalTVFRelation",
"qualified", Utils.qualifiedName(qualifier, getTable().getName()),
"output", getOutput(),
"function", function.toSql()

View File

@ -0,0 +1,99 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.types;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
/** PartialAggType */
public class PartialAggType extends DataType {
public final List<Expression> originArguments;
public final List<DataType> intermediateTypes;
/** PartialAggType */
public PartialAggType(List<Expression> originArguments, List<DataType> intermediateTypes) {
this.originArguments = ImmutableList.copyOf(
Objects.requireNonNull(originArguments, "originArguments can not be null"));
this.intermediateTypes = ImmutableList.copyOf(
Objects.requireNonNull(intermediateTypes, "intermediateTypes can not be null"));
Preconditions.checkArgument(intermediateTypes.size() > 0, "intermediateTypes can not empty");
}
public List<Expression> getOriginArguments() {
return originArguments;
}
public List<DataType> getIntermediateTypes() {
return intermediateTypes;
}
public List<DataType> getOriginInputTypes() {
return originArguments.stream()
.map(Expression::getDataType)
.collect(ImmutableList.toImmutableList());
}
@Override
public String toSql() {
return "PartialAggType(types=" + intermediateTypes + ")";
}
@Override
public int width() {
return intermediateTypes.stream()
.map(DataType::width)
.reduce((w1, w2) -> w1 + w2)
.get();
}
@Override
public Type toCatalogDataType() {
if (intermediateTypes.size() == 1) {
return intermediateTypes.get(0).toCatalogDataType();
}
return ScalarType.createVarcharType(-1);
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
PartialAggType that = (PartialAggType) o;
return Objects.equals(originArguments, that.originArguments)
&& Objects.equals(intermediateTypes, that.intermediateTypes);
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), originArguments, intermediateTypes);
}
}

View File

@ -18,15 +18,19 @@
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Substring;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Year;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
@ -136,18 +140,27 @@ public class FunctionRegistryTest implements PatternMatchSupported {
});
}
public static class ExtendFunction extends BoundFunction implements UnaryExpression, PropagateNullable {
public static class ExtendFunction extends BoundFunction implements UnaryExpression, PropagateNullable,
ExplicitlyCastableSignature {
public ExtendFunction(Expression a1) {
super("foo", a1);
}
@Override
public List<FunctionSignature> getSignatures() {
return ImmutableList.of(
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE)
);
}
@Override
public boolean hasVarArguments() {
return false;
}
}
public static class AmbiguousFunction extends BoundFunction implements UnaryExpression, PropagateNullable {
public static class AmbiguousFunction extends ScalarFunction implements UnaryExpression, PropagateNullable,
ExplicitlyCastableSignature {
public AmbiguousFunction(Expression a1) {
super("abc", a1);
}
@ -156,6 +169,13 @@ public class FunctionRegistryTest implements PatternMatchSupported {
super("abc", a1);
}
@Override
public List<FunctionSignature> getSignatures() {
return ImmutableList.of(
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE)
);
}
@Override
public boolean hasVarArguments() {
return false;

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.common.NereidsException;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.UnboundAlias;
@ -282,7 +283,7 @@ public class RegisterCTETest extends TestWithFeService implements PatternMatchSu
String sql = "WITH cte1 (a1, A1) AS (SELECT * FROM supplier)"
+ "SELECT * FROM cte1";
AnalysisException exception = Assertions.assertThrows(AnalysisException.class, () -> {
NereidsException exception = Assertions.assertThrows(NereidsException.class, () -> {
PlanChecker.from(connectContext).checkPlannerResult(sql);
}, "Not throw expected exception.");
Assertions.assertTrue(exception.getMessage().contains("Duplicated CTE column alias: [a1] in CTE [cte1]"));
@ -294,7 +295,7 @@ public class RegisterCTETest extends TestWithFeService implements PatternMatchSu
+ "(SELECT s_suppkey FROM supplier)"
+ "SELECT * FROM cte1";
AnalysisException exception = Assertions.assertThrows(AnalysisException.class, () -> {
NereidsException exception = Assertions.assertThrows(NereidsException.class, () -> {
PlanChecker.from(connectContext).checkPlannerResult(sql);
}, "Not throw expected exception.");
System.out.println(exception.getMessage());

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.expression.rewrite;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
@ -48,6 +49,7 @@ import org.apache.doris.nereids.types.TinyIntType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@ -100,16 +102,16 @@ public class TypeCoercionTest extends ExpressionRewriteTestHelper {
@Test
public void testSumImplicitCast() {
Expression expression = new Sum(new StringLiteral("1"));
Expression expected = new Sum(new Cast(new StringLiteral("1"), DoubleType.INSTANCE));
assertRewrite(expression, expected);
Assertions.assertThrows(AnalysisException.class, () -> {
new Sum(new StringLiteral("1")).getDataType();
});
}
@Test
public void testAvgImplicitCast() {
Expression expression = new Avg(new StringLiteral("1"));
Expression expected = new Avg(new Cast(new StringLiteral("1"), DoubleType.INSTANCE));
assertRewrite(expression, expected);
Assertions.assertThrows(AnalysisException.class, () -> {
new Avg(new StringLiteral("1")).getDataType();
});
}
@Test

View File

@ -45,7 +45,6 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import java.util.List;
import java.util.Optional;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class AggregateDisassembleTest implements PatternMatchSupported {
@ -277,7 +276,7 @@ public class AggregateDisassembleTest implements PatternMatchSupported {
// id
Expression localOutput0 = rStudent.getOutput().get(0);
// sum
Sum localOutput1 = new Sum(new AggregateParam(false, false, Optional.empty()), rStudent.getOutput().get(0).toSlot());
Sum localOutput1 = new Sum(new AggregateParam(false, false, true), rStudent.getOutput().get(0).toSlot());
// age
Expression localOutput2 = rStudent.getOutput().get(2);
// id

View File

@ -17,12 +17,10 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;
@ -33,13 +31,6 @@ import org.junit.jupiter.api.Test;
import java.util.List;
public class PushdownExpressionsInHashConditionTest extends TestWithFeService implements PatternMatchSupported {
private final List<String> testSql = ImmutableList.of(
"SELECT * FROM T1 JOIN T2 ON T1.ID + 1 = T2.ID + 2 AND T1.ID + 1 > 2",
"SELECT * FROM (SELECT * FROM T1) X JOIN (SELECT * FROM T2) Y ON X.ID + 1 = Y.ID + 2 AND X.ID + 1 > 2",
"SELECT * FROM T1 JOIN (SELECT ID, SUM(SCORE) SCORE FROM T2 GROUP BY ID) T ON T1.ID + 1 = T.ID AND T.SCORE < 10",
"SELECT * FROM T1 JOIN (SELECT ID, SUM(SCORE) SCORE FROM T2 GROUP BY ID ORDER BY ID) T ON T1.ID + 1 = T.ID AND T.SCORE < 10"
);
@Override
protected void runBeforeAll() throws Exception {
createDatabase("test");
@ -81,15 +72,10 @@ public class PushdownExpressionsInHashConditionTest extends TestWithFeService im
"SELECT * FROM T1 JOIN (SELECT ID, SUM(SCORE) SCORE FROM T2 GROUP BY ID ORDER BY ID) T ON T1.ID + 1 = T.ID AND T.SCORE < 10"
);
testSql.forEach(sql -> {
try {
PhysicalPlan plan = new NereidsPlanner(createStatementCtx(sql)).plan(
new NereidsParser().parseSingle(sql),
PhysicalProperties.ANY
);
System.out.println(plan.treeString());
} catch (AnalysisException e) {
throw new RuntimeException(e);
}
new NereidsPlanner(createStatementCtx(sql)).plan(
new NereidsParser().parseSingle(sql),
PhysicalProperties.ANY
);
});
}

View File

@ -32,4 +32,19 @@ suite("explain") {
contains "project output tuple id: 1"
}
explain {
sql("physical plan select 100")
contains "PhysicalOneRowRelation"
}
explain {
sql("logical plan select 100")
contains "LogicalOneRowRelation"
}
explain {
sql("parsed plan select 100")
contains "UnboundOneRowRelation"
}
}