[Enhancement](Nereids) Refactor AggregateFunction and support explain plan (#14380)
# Proposed changes
- Refactor AggregateFunction
1. AggregateFunction implement ComputeSignature
3. Add a CustomSignature to dynamic compute signature, we can check input type and compute implicit cast type in the `customSignature` method
2. Add PartialAggType to record some type information before disassemble aggregate
4. Refine and create a custom catalog function when translate AggregateFunction, without `finalizeForNereids`
- Support explain plan
1. explain parsed plan select ...
5. explain analyzed plan select ...
6. explain rewritten/logical plan select ...
7. explain optimized/physical plan select ...
8. explain all plan select ...
This commit is contained in:
@ -94,6 +94,7 @@ AFTER: 'AFTER';
|
||||
ALL: 'ALL';
|
||||
ALTER: 'ALTER';
|
||||
ANALYZE: 'ANALYZE';
|
||||
ANALYZED: 'ANALYZED';
|
||||
AND: 'AND';
|
||||
ANTI: 'ANTI';
|
||||
ANY: 'ANY';
|
||||
@ -247,6 +248,7 @@ NULLS: 'NULLS';
|
||||
OF: 'OF';
|
||||
ON: 'ON';
|
||||
ONLY: 'ONLY';
|
||||
OPTIMIZED: 'OPTIMIZED';
|
||||
OPTION: 'OPTION';
|
||||
OPTIONS: 'OPTIONS';
|
||||
OR: 'OR';
|
||||
@ -258,13 +260,16 @@ OVER: 'OVER';
|
||||
OVERLAPS: 'OVERLAPS';
|
||||
OVERLAY: 'OVERLAY';
|
||||
OVERWRITE: 'OVERWRITE';
|
||||
PARSED: 'PARSED';
|
||||
PARTITION: 'PARTITION';
|
||||
PARTITIONED: 'PARTITIONED';
|
||||
PARTITIONS: 'PARTITIONS';
|
||||
PERCENTILE_CONT: 'PERCENTILE_CONT';
|
||||
PERCENTLIT: 'PERCENT';
|
||||
PHYSICAL: 'PHYSICAL';
|
||||
PIVOT: 'PIVOT';
|
||||
PLACING: 'PLACING';
|
||||
PLAN: 'PLAN';
|
||||
POSITION: 'POSITION';
|
||||
PRECEDING: 'PRECEDING';
|
||||
PRIMARY: 'PRIMARY';
|
||||
@ -288,6 +293,7 @@ RESET: 'RESET';
|
||||
RESPECT: 'RESPECT';
|
||||
RESTRICT: 'RESTRICT';
|
||||
REVOKE: 'REVOKE';
|
||||
REWRITTEN: 'REWRITTEN';
|
||||
RIGHT: 'RIGHT';
|
||||
// original optimizer only support REGEXP, the new optimizer should be consistent with it
|
||||
RLIKE: 'RLIKE';
|
||||
|
||||
@ -50,7 +50,16 @@ singleStatement
|
||||
|
||||
statement
|
||||
: cte? query #statementDefault
|
||||
| (EXPLAIN | DESC | DESCRIBE) level=(VERBOSE | GRAPH)? query #explain
|
||||
| (EXPLAIN planType? | DESC | DESCRIBE)
|
||||
level=(VERBOSE | GRAPH | PLAN)? query #explain
|
||||
;
|
||||
|
||||
planType
|
||||
: PARSED
|
||||
| ANALYZED
|
||||
| REWRITTEN | LOGICAL // same type
|
||||
| OPTIMIZED | PHYSICAL // same type
|
||||
| ALL // default type
|
||||
;
|
||||
|
||||
// -----------------Query-----------------
|
||||
@ -335,6 +344,7 @@ ansiNonReserved
|
||||
| AFTER
|
||||
| ALTER
|
||||
| ANALYZE
|
||||
| ANALYZED
|
||||
| ANTI
|
||||
| ARCHIVE
|
||||
| ARRAY
|
||||
@ -441,6 +451,7 @@ ansiNonReserved
|
||||
| NO
|
||||
| NULLS
|
||||
| OF
|
||||
| OPTIMIZED
|
||||
| OPTION
|
||||
| OPTIONS
|
||||
| OUT
|
||||
@ -448,12 +459,15 @@ ansiNonReserved
|
||||
| OVER
|
||||
| OVERLAY
|
||||
| OVERWRITE
|
||||
| PARSED
|
||||
| PARTITION
|
||||
| PARTITIONED
|
||||
| PARTITIONS
|
||||
| PERCENTLIT
|
||||
| PHYSICAL
|
||||
| PIVOT
|
||||
| PLACING
|
||||
| PLAN
|
||||
| POSITION
|
||||
| PRECEDING
|
||||
| PRINCIPALS
|
||||
@ -474,6 +488,7 @@ ansiNonReserved
|
||||
| RESPECT
|
||||
| RESTRICT
|
||||
| REVOKE
|
||||
| REWRITTEN
|
||||
| RLIKE
|
||||
| ROLE
|
||||
| ROLES
|
||||
@ -575,6 +590,7 @@ nonReserved
|
||||
| ALL
|
||||
| ALTER
|
||||
| ANALYZE
|
||||
| ANALYZED
|
||||
| AND
|
||||
| ANY
|
||||
| ARCHIVE
|
||||
@ -716,6 +732,7 @@ nonReserved
|
||||
| NULLS
|
||||
| OF
|
||||
| ONLY
|
||||
| OPTIMIZED
|
||||
| OPTION
|
||||
| OPTIONS
|
||||
| OR
|
||||
@ -727,13 +744,16 @@ nonReserved
|
||||
| OVERLAPS
|
||||
| OVERLAY
|
||||
| OVERWRITE
|
||||
| PARSED
|
||||
| PARTITION
|
||||
| PARTITIONED
|
||||
| PARTITIONS
|
||||
| PERCENTILE_CONT
|
||||
| PERCENTLIT
|
||||
| PHYSICAL
|
||||
| PIVOT
|
||||
| PLACING
|
||||
| PLAN
|
||||
| POSITION
|
||||
| PRECEDING
|
||||
| PRIMARY
|
||||
@ -756,6 +776,7 @@ nonReserved
|
||||
| RESPECT
|
||||
| RESTRICT
|
||||
| REVOKE
|
||||
| REWRITTEN
|
||||
| RLIKE
|
||||
| ROLE
|
||||
| ROLES
|
||||
|
||||
@ -17,21 +17,38 @@
|
||||
|
||||
package org.apache.doris.analysis;
|
||||
|
||||
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand;
|
||||
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel;
|
||||
|
||||
public class ExplainOptions {
|
||||
|
||||
private boolean isVerbose;
|
||||
private boolean isGraph;
|
||||
|
||||
private ExplainCommand.ExplainLevel explainLevel;
|
||||
|
||||
public ExplainOptions(ExplainCommand.ExplainLevel explainLevel) {
|
||||
this.explainLevel = explainLevel;
|
||||
}
|
||||
|
||||
public ExplainOptions(boolean isVerbose, boolean isGraph) {
|
||||
this.isVerbose = isVerbose;
|
||||
this.isGraph = isGraph;
|
||||
}
|
||||
|
||||
public boolean isVerbose() {
|
||||
return isVerbose;
|
||||
return explainLevel == ExplainLevel.VERBOSE || isVerbose;
|
||||
}
|
||||
|
||||
public boolean isGraph() {
|
||||
return isGraph;
|
||||
return explainLevel == ExplainLevel.GRAPH || isGraph;
|
||||
}
|
||||
|
||||
public boolean hasExplainLevel() {
|
||||
return explainLevel != null;
|
||||
}
|
||||
|
||||
public ExplainLevel getExplainLevel() {
|
||||
return explainLevel;
|
||||
}
|
||||
}
|
||||
|
||||
@ -220,16 +220,22 @@ public class FunctionCallExpr extends Expr {
|
||||
this.argTypesForNereids = argTypes;
|
||||
}
|
||||
|
||||
// nereids constructor without finalize/analyze
|
||||
public FunctionCallExpr(FunctionName functionName, Function function, FunctionParams functionParams) {
|
||||
this.fnName = functionName;
|
||||
// nereids scalar function call expr constructor without finalize/analyze
|
||||
public FunctionCallExpr(Function function, FunctionParams functionParams) {
|
||||
this(function, functionParams, null, false, functionParams.exprs());
|
||||
}
|
||||
|
||||
// nereids aggregate function call expr constructor without finalize/analyze
|
||||
public FunctionCallExpr(Function function, FunctionParams functionParams, FunctionParams aggFnParams,
|
||||
boolean isMergeAggFn, List<Expr> children) {
|
||||
this.fnName = function.getFunctionName();
|
||||
this.fn = function;
|
||||
this.type = function.getReturnType();
|
||||
this.fnParams = functionParams;
|
||||
if (functionParams.exprs() != null) {
|
||||
this.children.addAll(functionParams.exprs());
|
||||
}
|
||||
this.aggFnParams = aggFnParams;
|
||||
this.children.addAll(children);
|
||||
this.originChildSize = children.size();
|
||||
this.isMergeAggFn = isMergeAggFn;
|
||||
this.shouldFinalizeForNereids = false;
|
||||
}
|
||||
|
||||
|
||||
@ -84,6 +84,19 @@ public class SlotRef extends Expr {
|
||||
analysisDone();
|
||||
}
|
||||
|
||||
// nerieds use this constructor to build aggFnParam
|
||||
public SlotRef(Type type, boolean nullable) {
|
||||
super();
|
||||
// tuple id and slot id is meaningless here, nereids just use type and nullable
|
||||
// to build the TAggregateExpr.param_types
|
||||
TupleDescriptor tupleDescriptor = new TupleDescriptor(new TupleId(-1));
|
||||
desc = new SlotDescriptor(new SlotId(-1), tupleDescriptor);
|
||||
tupleDescriptor.addSlot(desc);
|
||||
desc.setIsNullable(nullable);
|
||||
desc.setType(type);
|
||||
this.type = type;
|
||||
}
|
||||
|
||||
protected SlotRef(SlotRef other) {
|
||||
super(other);
|
||||
tblName = other.tblName;
|
||||
|
||||
@ -183,6 +183,30 @@ public class AggregateFunction extends Function {
|
||||
returnsNonNullOnEmpty = false;
|
||||
}
|
||||
|
||||
public AggregateFunction(FunctionName fnName, List<Type> argTypes,
|
||||
Type retType, Type intermediateType, boolean hasVarArgs,
|
||||
URI location, String updateFnSymbol, String initFnSymbol,
|
||||
String serializeFnSymbol, String mergeFnSymbol, String getValueFnSymbol,
|
||||
String removeFnSymbol, String finalizeFnSymbol, boolean ignoresDistinct,
|
||||
boolean isAnalyticFn, boolean returnsNonNullOnEmpty, TFunctionBinaryType binaryType,
|
||||
boolean userVisible, boolean vectorized, NullableMode nullableMode) {
|
||||
// only `count` is always not nullable, other aggregate function is always nullable
|
||||
super(0, fnName, argTypes, retType, hasVarArgs, binaryType, userVisible, vectorized, nullableMode);
|
||||
setLocation(location);
|
||||
this.intermediateType = (intermediateType.equals(retType)) ? null : intermediateType;
|
||||
this.updateFnSymbol = updateFnSymbol;
|
||||
this.initFnSymbol = initFnSymbol;
|
||||
this.serializeFnSymbol = serializeFnSymbol;
|
||||
this.mergeFnSymbol = mergeFnSymbol;
|
||||
this.getValueFnSymbol = getValueFnSymbol;
|
||||
this.removeFnSymbol = removeFnSymbol;
|
||||
this.finalizeFnSymbol = finalizeFnSymbol;
|
||||
this.ignoresDistinct = ignoresDistinct;
|
||||
this.isAnalyticFn = isAnalyticFn;
|
||||
this.isAggregateFn = true;
|
||||
this.returnsNonNullOnEmpty = returnsNonNullOnEmpty;
|
||||
}
|
||||
|
||||
public static AggregateFunction createBuiltin(String name,
|
||||
List<Type> argTypes, Type retType, Type intermediateType,
|
||||
String initFnSymbol, String updateFnSymbol, String mergeFnSymbol,
|
||||
|
||||
@ -19,6 +19,7 @@ package org.apache.doris.catalog;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
@ -32,10 +33,10 @@ import java.util.function.BiFunction;
|
||||
public class FunctionSignature {
|
||||
public final DataType returnType;
|
||||
public final boolean hasVarArgs;
|
||||
public final List<DataType> argumentsTypes;
|
||||
public final List<AbstractDataType> argumentsTypes;
|
||||
public final int arity;
|
||||
|
||||
public FunctionSignature(DataType returnType, boolean hasVarArgs, List<DataType> argumentsTypes) {
|
||||
public FunctionSignature(DataType returnType, boolean hasVarArgs, List<AbstractDataType> argumentsTypes) {
|
||||
this.returnType = Objects.requireNonNull(returnType, "returnType is not null");
|
||||
this.argumentsTypes = ImmutableList.copyOf(
|
||||
Objects.requireNonNull(argumentsTypes, "argumentsTypes is not null"));
|
||||
@ -43,11 +44,11 @@ public class FunctionSignature {
|
||||
this.arity = argumentsTypes.size();
|
||||
}
|
||||
|
||||
public Optional<DataType> getVarArgType() {
|
||||
public Optional<AbstractDataType> getVarArgType() {
|
||||
return hasVarArgs ? Optional.of(argumentsTypes.get(arity - 1)) : Optional.empty();
|
||||
}
|
||||
|
||||
public DataType getArgType(int index) {
|
||||
public AbstractDataType getArgType(int index) {
|
||||
if (hasVarArgs && index >= arity) {
|
||||
return argumentsTypes.get(arity - 1);
|
||||
}
|
||||
@ -58,7 +59,7 @@ public class FunctionSignature {
|
||||
return new FunctionSignature(returnType, hasVarArgs, argumentsTypes);
|
||||
}
|
||||
|
||||
public FunctionSignature withArgumentTypes(boolean hasVarArgs, List<DataType> argumentsTypes) {
|
||||
public FunctionSignature withArgumentTypes(boolean hasVarArgs, List<AbstractDataType> argumentsTypes) {
|
||||
return new FunctionSignature(returnType, hasVarArgs, argumentsTypes);
|
||||
}
|
||||
|
||||
@ -69,27 +70,27 @@ public class FunctionSignature {
|
||||
* @return
|
||||
*/
|
||||
public FunctionSignature withArgumentTypes(List<Expression> arguments,
|
||||
BiFunction<DataType, Expression, DataType> transform) {
|
||||
List<DataType> newTypes = Lists.newArrayList();
|
||||
BiFunction<AbstractDataType, Expression, AbstractDataType> transform) {
|
||||
List<AbstractDataType> newTypes = Lists.newArrayList();
|
||||
for (int i = 0; i < arguments.size(); i++) {
|
||||
newTypes.add(transform.apply(getArgType(i), arguments.get(i)));
|
||||
}
|
||||
return withArgumentTypes(hasVarArgs, newTypes);
|
||||
}
|
||||
|
||||
public static FunctionSignature of(DataType returnType, List<DataType> argumentsTypes) {
|
||||
public static FunctionSignature of(DataType returnType, List<AbstractDataType> argumentsTypes) {
|
||||
return of(returnType, false, argumentsTypes);
|
||||
}
|
||||
|
||||
public static FunctionSignature of(DataType returnType, boolean hasVarArgs, List<DataType> argumentsTypes) {
|
||||
public static FunctionSignature of(DataType returnType, boolean hasVarArgs, List<AbstractDataType> argumentsTypes) {
|
||||
return new FunctionSignature(returnType, hasVarArgs, argumentsTypes);
|
||||
}
|
||||
|
||||
public static FunctionSignature of(DataType returnType, DataType... argumentsTypes) {
|
||||
public static FunctionSignature of(DataType returnType, AbstractDataType... argumentsTypes) {
|
||||
return of(returnType, false, argumentsTypes);
|
||||
}
|
||||
|
||||
public static FunctionSignature of(DataType returnType, boolean hasVarArgs, DataType... argumentsTypes) {
|
||||
public static FunctionSignature of(DataType returnType, boolean hasVarArgs, AbstractDataType... argumentsTypes) {
|
||||
return new FunctionSignature(returnType, hasVarArgs, Arrays.asList(argumentsTypes));
|
||||
}
|
||||
|
||||
@ -104,11 +105,11 @@ public class FunctionSignature {
|
||||
this.returnType = returnType;
|
||||
}
|
||||
|
||||
public FunctionSignature args(DataType...argTypes) {
|
||||
public FunctionSignature args(AbstractDataType...argTypes) {
|
||||
return FunctionSignature.of(returnType, false, argTypes);
|
||||
}
|
||||
|
||||
public FunctionSignature varArgs(DataType...argTypes) {
|
||||
public FunctionSignature varArgs(AbstractDataType...argTypes) {
|
||||
return FunctionSignature.of(returnType, true, argTypes);
|
||||
}
|
||||
}
|
||||
|
||||
@ -74,6 +74,13 @@ public class ScalarFunction extends Function {
|
||||
NullableMode.DEPEND_ON_ARGUMENT);
|
||||
}
|
||||
|
||||
/** nerieds custom scalar function */
|
||||
public ScalarFunction(FunctionName fnName, List<Type> argTypes, Type retType, boolean hasVarArgs, String symbolName,
|
||||
TFunctionBinaryType binaryType, boolean userVisible, boolean isVec, NullableMode nullableMode) {
|
||||
super(0, fnName, argTypes, retType, hasVarArgs, binaryType, userVisible, isVec, nullableMode);
|
||||
this.symbolName = symbolName;
|
||||
}
|
||||
|
||||
public ScalarFunction(FunctionName fnName, List<Type> argTypes,
|
||||
Type retType, URI location, String symbolName, String initFnSymbol,
|
||||
String closeFnSymbol) {
|
||||
|
||||
@ -18,9 +18,10 @@
|
||||
package org.apache.doris.nereids;
|
||||
|
||||
import org.apache.doris.analysis.DescriptorTable;
|
||||
import org.apache.doris.analysis.ExplainOptions;
|
||||
import org.apache.doris.analysis.StatementBase;
|
||||
import org.apache.doris.common.AnalysisException;
|
||||
import org.apache.doris.common.UserException;
|
||||
import org.apache.doris.common.NereidsException;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
|
||||
import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
|
||||
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
|
||||
@ -36,6 +37,7 @@ import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.rules.joinreorder.HyperGraphJoinReorder;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
|
||||
import org.apache.doris.planner.PlanFragment;
|
||||
@ -64,18 +66,29 @@ public class NereidsPlanner extends Planner {
|
||||
private List<ScanNode> scanNodeList = null;
|
||||
private DescriptorTable descTable;
|
||||
|
||||
private Plan parsedPlan;
|
||||
private Plan analyzedPlan;
|
||||
private Plan rewrittenPlan;
|
||||
private Plan optimizedPlan;
|
||||
|
||||
public NereidsPlanner(StatementContext statementContext) {
|
||||
this.statementContext = statementContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void plan(StatementBase queryStmt, org.apache.doris.thrift.TQueryOptions queryOptions) throws UserException {
|
||||
public void plan(StatementBase queryStmt, org.apache.doris.thrift.TQueryOptions queryOptions) {
|
||||
if (!(queryStmt instanceof LogicalPlanAdapter)) {
|
||||
throw new RuntimeException("Wrong type of queryStmt, expected: <? extends LogicalPlanAdapter>");
|
||||
}
|
||||
|
||||
LogicalPlanAdapter logicalPlanAdapter = (LogicalPlanAdapter) queryStmt;
|
||||
PhysicalPlan physicalPlan = plan(logicalPlanAdapter.getLogicalPlan(), PhysicalProperties.ANY);
|
||||
ExplainLevel explainLevel = getExplainLevel(queryStmt.getExplainOptions());
|
||||
Plan resultPlan = plan(logicalPlanAdapter.getLogicalPlan(), PhysicalProperties.ANY, explainLevel);
|
||||
if (explainLevel.isPlanLevel) {
|
||||
return;
|
||||
}
|
||||
|
||||
PhysicalPlan physicalPlan = (PhysicalPlan) resultPlan;
|
||||
PhysicalPlanTranslator physicalPlanTranslator = new PhysicalPlanTranslator();
|
||||
PlanTranslatorContext planTranslatorContext = new PlanTranslatorContext(cascadesContext);
|
||||
if (ConnectContext.get().getSessionVariable().isEnableNereidsTrace()) {
|
||||
@ -103,20 +116,30 @@ public class NereidsPlanner extends Planner {
|
||||
public void plan(StatementBase queryStmt) {
|
||||
try {
|
||||
plan(queryStmt, statementContext.getConnectContext().getSessionVariable().toThrift());
|
||||
} catch (UserException e) {
|
||||
throw new RuntimeException(e);
|
||||
} catch (Exception e) {
|
||||
throw new NereidsException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public PhysicalPlan plan(LogicalPlan plan, PhysicalProperties outputProperties) {
|
||||
return (PhysicalPlan) plan(plan, outputProperties, ExplainLevel.NONE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Do analyze and optimize for query plan.
|
||||
*
|
||||
* @param plan wait for plan
|
||||
* @param outputProperties physical properties constraints
|
||||
* @return physical plan generated by this planner
|
||||
* @return plan generated by this planner
|
||||
* @throws AnalysisException throw exception if failed in ant stage
|
||||
*/
|
||||
public PhysicalPlan plan(LogicalPlan plan, PhysicalProperties outputProperties) throws AnalysisException {
|
||||
public Plan plan(LogicalPlan plan, PhysicalProperties outputProperties, ExplainLevel explainLevel) {
|
||||
if (explainLevel == ExplainLevel.PARSED_PLAN || explainLevel == ExplainLevel.ALL_PLAN) {
|
||||
parsedPlan = plan;
|
||||
if (explainLevel == ExplainLevel.PARSED_PLAN) {
|
||||
return parsedPlan;
|
||||
}
|
||||
}
|
||||
|
||||
// pre-process logical plan out of memo, e.g. process SET_VAR hint
|
||||
plan = preprocess(plan);
|
||||
@ -125,9 +148,21 @@ public class NereidsPlanner extends Planner {
|
||||
|
||||
// resolve column, table and function
|
||||
analyze();
|
||||
if (explainLevel == ExplainLevel.ANALYZED_PLAN || explainLevel == ExplainLevel.ALL_PLAN) {
|
||||
analyzedPlan = cascadesContext.getMemo().copyOut(false);
|
||||
if (explainLevel == ExplainLevel.ANALYZED_PLAN) {
|
||||
return analyzedPlan;
|
||||
}
|
||||
}
|
||||
|
||||
// rule-based optimize
|
||||
rewrite();
|
||||
if (explainLevel == ExplainLevel.REWRITTEN_PLAN || explainLevel == ExplainLevel.ALL_PLAN) {
|
||||
rewrittenPlan = cascadesContext.getMemo().copyOut(false);
|
||||
if (explainLevel == ExplainLevel.REWRITTEN_PLAN) {
|
||||
return rewrittenPlan;
|
||||
}
|
||||
}
|
||||
|
||||
deriveStats();
|
||||
|
||||
@ -142,7 +177,12 @@ public class NereidsPlanner extends Planner {
|
||||
PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), PhysicalProperties.ANY);
|
||||
|
||||
// post-process physical plan out of memo, just for future use.
|
||||
return postProcess(physicalPlan);
|
||||
physicalPlan = postProcess(physicalPlan);
|
||||
if (explainLevel == ExplainLevel.OPTIMIZED_PLAN || explainLevel == ExplainLevel.ALL_PLAN) {
|
||||
optimizedPlan = physicalPlan;
|
||||
}
|
||||
|
||||
return physicalPlan;
|
||||
}
|
||||
|
||||
private LogicalPlan preprocess(LogicalPlan logicalPlan) {
|
||||
@ -229,6 +269,33 @@ public class NereidsPlanner extends Planner {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getExplainString(ExplainOptions explainOptions) {
|
||||
ExplainLevel explainLevel = getExplainLevel(explainOptions);
|
||||
switch (explainLevel) {
|
||||
case PARSED_PLAN:
|
||||
return parsedPlan.treeString();
|
||||
case ANALYZED_PLAN:
|
||||
return analyzedPlan.treeString();
|
||||
case REWRITTEN_PLAN:
|
||||
return rewrittenPlan.treeString();
|
||||
case OPTIMIZED_PLAN:
|
||||
return optimizedPlan.treeString();
|
||||
case ALL_PLAN:
|
||||
String explainString = "========== PARSED PLAN ==========\n"
|
||||
+ parsedPlan.treeString() + "\n\n"
|
||||
+ "========== ANALYZED PLAN ==========\n"
|
||||
+ analyzedPlan.treeString() + "\n\n"
|
||||
+ "========== REWRITTEN PLAN ==========\n"
|
||||
+ rewrittenPlan.treeString() + "\n\n"
|
||||
+ "========== OPTIMIZED PLAN ==========\n"
|
||||
+ optimizedPlan.treeString();
|
||||
return explainString;
|
||||
default:
|
||||
return super.getExplainString(explainOptions);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isBlockQuery() {
|
||||
return true;
|
||||
@ -253,4 +320,12 @@ public class NereidsPlanner extends Planner {
|
||||
public CascadesContext getCascadesContext() {
|
||||
return cascadesContext;
|
||||
}
|
||||
|
||||
private ExplainLevel getExplainLevel(ExplainOptions explainOptions) {
|
||||
if (explainOptions == null) {
|
||||
return ExplainLevel.NONE;
|
||||
}
|
||||
ExplainLevel explainLevel = explainOptions.getExplainLevel();
|
||||
return explainLevel == null ? ExplainLevel.NONE : explainLevel;
|
||||
}
|
||||
}
|
||||
|
||||
@ -43,6 +43,8 @@ public class StatementContext {
|
||||
private CTEContext cteContext;
|
||||
|
||||
public StatementContext() {
|
||||
this.connectContext = ConnectContext.get();
|
||||
this.cteContext = new CTEContext();
|
||||
}
|
||||
|
||||
public StatementContext(ConnectContext connectContext, OriginStatement originStatement) {
|
||||
|
||||
@ -33,10 +33,8 @@ import org.apache.doris.analysis.FunctionParams;
|
||||
import org.apache.doris.analysis.LikePredicate;
|
||||
import org.apache.doris.analysis.SlotRef;
|
||||
import org.apache.doris.analysis.TimestampArithmeticExpr;
|
||||
import org.apache.doris.catalog.Function;
|
||||
import org.apache.doris.catalog.Function.NullableMode;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.nereids.annotation.Developing;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.And;
|
||||
@ -67,8 +65,10 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
import org.apache.doris.thrift.TFunctionBinaryType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
@ -256,68 +256,56 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
|
||||
|
||||
// TODO: Supports for `distinct`
|
||||
@Override
|
||||
@Developing("Generate FunctionCallExpr and Function without analyze/finalize")
|
||||
public Expr visitAggregateFunction(AggregateFunction function, PlanTranslatorContext context) {
|
||||
// inputTypesBeforeDissemble is used to find the origin function's input type before disassemble aggregate.
|
||||
//
|
||||
// For example, 'double avg(int)' will be disassembled to 'varchar avg(int)' and 'double avg(varchar)'
|
||||
// which the varchar contains sum(double) and count(int).
|
||||
//
|
||||
// We save the origin input type 'int' for the global aggregate 'varchar avg(int)', and get it in the
|
||||
// 'inputTypesBeforeDissemble' variable, so we can find the catalog function 'avg(int)' in **frontend**.
|
||||
//
|
||||
// Vectorized engine in backend will find the 'avg(int)' function, and forwarding to the correct global
|
||||
// aggregate function 'double avg(varchar)' by FunctionCallExpr.isMergeAggFn.
|
||||
Optional<List<Type>> inputTypesBeforeDissemble = function.inputTypesBeforeDissemble()
|
||||
.map(types -> types.stream()
|
||||
.map(DataType::toCatalogDataType)
|
||||
.collect(Collectors.toList())
|
||||
);
|
||||
List<Expr> catalogArguments = function.getArguments()
|
||||
.stream()
|
||||
.map(arg -> arg.accept(this, context))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
|
||||
// We should change the global aggregate function's temporary input type(varchar) to the origin input type.
|
||||
//
|
||||
// For example: the global aggregate function expression 'avg(slotRef(type=varchar))' of the origin function
|
||||
// 'avg(int)' should change to 'avg(slotRef(type=int))', because FunctionCallExpr will be converted to thrift
|
||||
// format, and compute signature string by the children's type, we should pass the signature 'avg(int)' to
|
||||
// **backend**. If we pass 'avg(varchar)' to backend, it will throw an exception: 'Agg Function avg is not
|
||||
// implemented'.
|
||||
List<Expr> catalogParams = new ArrayList<>();
|
||||
for (int i = 0; i < function.arity(); i++) {
|
||||
Expr catalogExpr = function.child(i).accept(this, context);
|
||||
if (catalogExpr instanceof SlotRef && inputTypesBeforeDissemble.isPresent()
|
||||
// count(*) in local aggregate contains empty children
|
||||
// but contains one child in global aggregate: 'count(count(*))'.
|
||||
// so the size of inputTypesBeforeDissemble maybe less than global aggregate param.
|
||||
&& inputTypesBeforeDissemble.get().size() > i) {
|
||||
SlotRef intermediateSlot = (SlotRef) catalogExpr.clone();
|
||||
// change the slot type to origin input type
|
||||
intermediateSlot.setType(inputTypesBeforeDissemble.get().get(i));
|
||||
catalogExpr = intermediateSlot;
|
||||
}
|
||||
catalogParams.add(catalogExpr);
|
||||
// aggFnArguments is used to build TAggregateExpr.param_types, so backend can find the aggregate function
|
||||
List<Expr> aggFnArguments = function.getArgumentsBeforeDisassembled()
|
||||
.stream()
|
||||
.map(arg -> new SlotRef(arg.getDataType().toCatalogDataType(), arg.nullable()))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
|
||||
FunctionParams aggFnParams;
|
||||
if (function instanceof Count && ((Count) function).isStar()) {
|
||||
aggFnParams = FunctionParams.createStarParam();
|
||||
} else {
|
||||
aggFnParams = new FunctionParams(function.isDistinct(), aggFnArguments);
|
||||
}
|
||||
|
||||
boolean distinct = function.isDistinct();
|
||||
FunctionParams aggFnParams = new FunctionParams(distinct, catalogParams);
|
||||
ImmutableList<Type> argTypes = catalogArguments.stream()
|
||||
.map(arg -> arg.getType())
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
|
||||
if (function instanceof Count) {
|
||||
Count count = (Count) function;
|
||||
if (count.isStar()) {
|
||||
return new FunctionCallExpr(function.getName(), FunctionParams.createStarParam(),
|
||||
aggFnParams, inputTypesBeforeDissemble);
|
||||
} else if (count.isDistinct()) {
|
||||
return new FunctionCallExpr(function.getName(), new FunctionParams(distinct, catalogParams),
|
||||
aggFnParams, inputTypesBeforeDissemble);
|
||||
}
|
||||
}
|
||||
return new FunctionCallExpr(function.getName(), new FunctionParams(distinct, catalogParams),
|
||||
aggFnParams, inputTypesBeforeDissemble);
|
||||
NullableMode nullableMode = function.nullable()
|
||||
? NullableMode.ALWAYS_NULLABLE
|
||||
: NullableMode.ALWAYS_NOT_NULLABLE;
|
||||
|
||||
boolean isAnalyticFunction = false;
|
||||
org.apache.doris.catalog.AggregateFunction catalogFunction = new org.apache.doris.catalog.AggregateFunction(
|
||||
new FunctionName(function.getName()), argTypes,
|
||||
function.getDataType().toCatalogDataType(),
|
||||
function.getIntermediateTypes().toCatalogDataType(),
|
||||
function.hasVarArguments(),
|
||||
null, "", "", null, "",
|
||||
null, "", null, false,
|
||||
isAnalyticFunction, false, TFunctionBinaryType.BUILTIN,
|
||||
true, true, nullableMode
|
||||
);
|
||||
|
||||
boolean isMergeFn = function.isGlobal() && function.isDisassembled();
|
||||
|
||||
// create catalog FunctionCallExpr without analyze again
|
||||
return new FunctionCallExpr(catalogFunction, aggFnParams, aggFnParams, isMergeFn, catalogArguments);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expr visitScalarFunction(ScalarFunction function, PlanTranslatorContext context) {
|
||||
List<Expr> arguments = function.getArguments()
|
||||
.stream().map(arg -> arg.accept(this, context))
|
||||
.stream()
|
||||
.map(arg -> arg.accept(this, context))
|
||||
.collect(Collectors.toList());
|
||||
List<Type> argTypes = function.expectedInputTypes().stream()
|
||||
.map(AbstractDataType::toCatalogDataType)
|
||||
@ -327,12 +315,13 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
|
||||
? NullableMode.ALWAYS_NULLABLE
|
||||
: NullableMode.ALWAYS_NOT_NULLABLE;
|
||||
|
||||
Function catalogFunction = new Function(new FunctionName(function.getName()), argTypes,
|
||||
function.getDataType().toCatalogDataType(), function.hasVarArguments(), true, nullableMode);
|
||||
org.apache.doris.catalog.ScalarFunction catalogFunction = new org.apache.doris.catalog.ScalarFunction(
|
||||
new FunctionName(function.getName()), argTypes,
|
||||
function.getDataType().toCatalogDataType(), function.hasVarArguments(),
|
||||
"", TFunctionBinaryType.BUILTIN, true, true, nullableMode);
|
||||
|
||||
// create catalog FunctionCallExpr without analyze again
|
||||
return new FunctionCallExpr(catalogFunction.getFunctionName(), catalogFunction,
|
||||
new FunctionParams(false, arguments));
|
||||
return new FunctionCallExpr(catalogFunction, new FunctionParams(false, arguments));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -217,11 +217,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
|
||||
outputTupleDesc = localAggNode.getAggInfo().getOutputTupleDesc();
|
||||
}
|
||||
|
||||
if (aggregate.getAggPhase() == AggPhase.GLOBAL) {
|
||||
for (FunctionCallExpr execAggregateFunction : execAggregateFunctions) {
|
||||
execAggregateFunction.setMergeForNereids(true);
|
||||
}
|
||||
}
|
||||
// TODO: move setMergeForNereids to ExpressionTranslator.visitAggregateFunction
|
||||
if (aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
|
||||
for (FunctionCallExpr execAggregateFunction : execAggregateFunctions) {
|
||||
if (!execAggregateFunction.isDistinct()) {
|
||||
|
||||
@ -53,6 +53,7 @@ import org.apache.doris.nereids.DorisParser.NamedExpressionContext;
|
||||
import org.apache.doris.nereids.DorisParser.NamedExpressionSeqContext;
|
||||
import org.apache.doris.nereids.DorisParser.NullLiteralContext;
|
||||
import org.apache.doris.nereids.DorisParser.ParenthesizedExpressionContext;
|
||||
import org.apache.doris.nereids.DorisParser.PlanTypeContext;
|
||||
import org.apache.doris.nereids.DorisParser.PredicateContext;
|
||||
import org.apache.doris.nereids.DorisParser.PredicatedContext;
|
||||
import org.apache.doris.nereids.DorisParser.QualifiedNameContext;
|
||||
@ -219,8 +220,10 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
|
||||
List<Pair<LogicalPlan, StatementContext>> logicalPlans = Lists.newArrayList();
|
||||
for (org.apache.doris.nereids.DorisParser.StatementContext statement : ctx.statement()) {
|
||||
StatementContext statementContext = new StatementContext();
|
||||
if (ConnectContext.get() != null) {
|
||||
ConnectContext.get().setStatementContext(statementContext);
|
||||
ConnectContext connectContext = ConnectContext.get();
|
||||
if (connectContext != null) {
|
||||
connectContext.setStatementContext(statementContext);
|
||||
statementContext.setConnectContext(connectContext);
|
||||
}
|
||||
logicalPlans.add(Pair.of(
|
||||
ParserUtils.withOrigin(ctx, () -> (LogicalPlan) visit(statement)), statementContext));
|
||||
@ -258,12 +261,25 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
|
||||
|
||||
@Override
|
||||
public Command visitExplain(ExplainContext ctx) {
|
||||
LogicalPlan logicalPlan = plan(ctx.query());
|
||||
ExplainLevel explainLevel = ExplainLevel.NORMAL;
|
||||
if (ctx.level != null) {
|
||||
explainLevel = ExplainLevel.valueOf(ctx.level.getText().toUpperCase(Locale.ROOT));
|
||||
}
|
||||
return new ExplainCommand(explainLevel, logicalPlan);
|
||||
return ParserUtils.withOrigin(ctx, () -> {
|
||||
LogicalPlan logicalPlan = plan(ctx.query());
|
||||
ExplainLevel explainLevel = ExplainLevel.NORMAL;
|
||||
|
||||
if (ctx.planType() != null) {
|
||||
if (ctx.level == null || !ctx.level.getText().equalsIgnoreCase("plan")) {
|
||||
throw new ParseException("Only explain plan can use plan type: " + ctx.planType().getText(), ctx);
|
||||
}
|
||||
}
|
||||
|
||||
if (ctx.level != null) {
|
||||
if (!ctx.level.getText().equalsIgnoreCase("plan")) {
|
||||
explainLevel = ExplainLevel.valueOf(ctx.level.getText().toUpperCase(Locale.ROOT));
|
||||
} else {
|
||||
explainLevel = parseExplainPlanType(ctx.planType());
|
||||
}
|
||||
}
|
||||
return new ExplainCommand(explainLevel, logicalPlan);
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -1119,4 +1135,23 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
|
||||
}
|
||||
return item.getText();
|
||||
}
|
||||
|
||||
private ExplainLevel parseExplainPlanType(PlanTypeContext planTypeContext) {
|
||||
if (planTypeContext == null || planTypeContext.ALL() != null) {
|
||||
return ExplainLevel.ALL_PLAN;
|
||||
}
|
||||
if (planTypeContext.PHYSICAL() != null || planTypeContext.OPTIMIZED() != null) {
|
||||
return ExplainLevel.OPTIMIZED_PLAN;
|
||||
}
|
||||
if (planTypeContext.REWRITTEN() != null || planTypeContext.LOGICAL() != null) {
|
||||
return ExplainLevel.REWRITTEN_PLAN;
|
||||
}
|
||||
if (planTypeContext.ANALYZED() != null) {
|
||||
return ExplainLevel.ANALYZED_PLAN;
|
||||
}
|
||||
if (planTypeContext.PARSED() != null) {
|
||||
return ExplainLevel.PARSED_PLAN;
|
||||
}
|
||||
return ExplainLevel.ALL_PLAN;
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,19 +53,20 @@ public class NereidsParser {
|
||||
public List<StatementBase> parseSQL(String originStr) {
|
||||
List<Pair<LogicalPlan, StatementContext>> logicalPlans = parseMultiple(originStr);
|
||||
List<StatementBase> statementBases = Lists.newArrayList();
|
||||
for (Pair<LogicalPlan, StatementContext> logicalPlan : logicalPlans) {
|
||||
for (Pair<LogicalPlan, StatementContext> parsedPlanToContext : logicalPlans) {
|
||||
// TODO: this is a trick to support explain. Since we do not support any other command in a short time.
|
||||
// It is acceptable. In the future, we need to refactor this.
|
||||
if (logicalPlan.first instanceof ExplainCommand) {
|
||||
ExplainCommand explainCommand = (ExplainCommand) logicalPlan.first;
|
||||
StatementContext statementContext = parsedPlanToContext.second;
|
||||
if (parsedPlanToContext.first instanceof ExplainCommand) {
|
||||
ExplainCommand explainCommand = (ExplainCommand) parsedPlanToContext.first;
|
||||
LogicalPlan innerPlan = explainCommand.getLogicalPlan();
|
||||
LogicalPlanAdapter logicalPlanAdapter = new LogicalPlanAdapter(innerPlan, logicalPlan.second);
|
||||
logicalPlanAdapter.setIsExplain(new ExplainOptions(
|
||||
explainCommand.getLevel() == ExplainLevel.VERBOSE,
|
||||
explainCommand.getLevel() == ExplainLevel.GRAPH));
|
||||
LogicalPlanAdapter logicalPlanAdapter = new LogicalPlanAdapter(innerPlan, statementContext);
|
||||
ExplainLevel explainLevel = explainCommand.getLevel();
|
||||
ExplainOptions explainOptions = new ExplainOptions(explainLevel);
|
||||
logicalPlanAdapter.setIsExplain(explainOptions);
|
||||
statementBases.add(logicalPlanAdapter);
|
||||
} else {
|
||||
statementBases.add(new LogicalPlanAdapter(logicalPlan.first, logicalPlan.second));
|
||||
statementBases.add(new LogicalPlanAdapter(parsedPlanToContext.first, statementContext));
|
||||
}
|
||||
}
|
||||
return statementBases;
|
||||
|
||||
@ -170,8 +170,8 @@ public class BindFunction implements AnalysisRuleFactory {
|
||||
return new Count();
|
||||
}
|
||||
if (arguments.size() == 1) {
|
||||
boolean isGlobalAgg = true;
|
||||
AggregateParam aggregateParam = new AggregateParam(unboundFunction.isDistinct(), isGlobalAgg);
|
||||
AggregateParam aggregateParam = new AggregateParam(
|
||||
unboundFunction.isDistinct(), true, false);
|
||||
return new Count(aggregateParam, unboundFunction.getArguments().get(0));
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,7 +28,6 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio
|
||||
import org.apache.doris.nereids.trees.plans.AggPhase;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
@ -191,20 +190,15 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
|
||||
AggregateFunction localAggregateFunction = aggregateFunction.withAggregateParam(
|
||||
aggregateFunction.getAggregateParam()
|
||||
.withDistinct(false)
|
||||
.withGlobal(false)
|
||||
.withGlobalAndDisassembled(false, true)
|
||||
);
|
||||
NamedExpression localOutputExpr = new Alias(localAggregateFunction, aggregateFunction.toSql());
|
||||
|
||||
List<DataType> inputTypesBeforeDissemble = aggregateFunction.children()
|
||||
.stream()
|
||||
.map(Expression::getDataType)
|
||||
.collect(Collectors.toList());
|
||||
AggregateFunction substitutionValue = aggregateFunction
|
||||
// save the origin input types to the global aggregate functions
|
||||
.withAggregateParam(aggregateFunction.getAggregateParam()
|
||||
.withDistinct(false)
|
||||
.withGlobal(true)
|
||||
.withInputTypesBeforeDissemble(Optional.of(inputTypesBeforeDissemble)))
|
||||
.withGlobalAndDisassembled(true, true))
|
||||
.withChildren(Lists.newArrayList(localOutputExpr.toSlot()));
|
||||
|
||||
inputSubstitutionMap.put(aggregateFunction, substitutionValue);
|
||||
|
||||
@ -36,7 +36,7 @@ public class Cast extends Expression implements UnaryExpression {
|
||||
|
||||
public Cast(Expression child, DataType targetType) {
|
||||
super(child);
|
||||
this.targetType = targetType;
|
||||
this.targetType = Objects.requireNonNull(targetType, "targetType can not be null");
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -67,7 +67,7 @@ public class Cast extends Expression implements UnaryExpression {
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return toSql();
|
||||
return "CAST(" + child() + " AS " + targetType + ")";
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -60,6 +60,6 @@ public class TVFProperties extends Expression implements LeafExpression {
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "KeyValuesExpression(" + toSql() + ")";
|
||||
return "TVFProperties(" + toSql() + ")";
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,18 +17,44 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.catalog.ScalarType;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.common.Config;
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.annotation.Developing;
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.ArrayType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DateTimeV2Type;
|
||||
import org.apache.doris.nereids.types.DecimalV2Type;
|
||||
import org.apache.doris.nereids.types.DecimalV3Type;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
import org.apache.doris.nereids.util.ResponsibilityChain;
|
||||
|
||||
import com.google.common.base.Suppliers;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.function.BiFunction;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** BoundFunction. */
|
||||
public abstract class BoundFunction extends Expression implements FunctionTrait {
|
||||
public abstract class BoundFunction extends Expression implements FunctionTrait, ComputeSignature {
|
||||
private final String name;
|
||||
|
||||
private final Supplier<FunctionSignature> signatureCache = Suppliers.memoize(() -> {
|
||||
// first step: find the candidate signature in the signature list
|
||||
List<Expression> originArguments = getOriginArguments();
|
||||
FunctionSignature matchedSignature = searchSignature(
|
||||
getOriginArgumentTypes(), originArguments, getSignatures());
|
||||
// second step: change the signature, e.g. fill precision for decimal v2
|
||||
return computeSignature(matchedSignature, originArguments);
|
||||
});
|
||||
|
||||
public BoundFunction(String name, Expression... arguments) {
|
||||
super(arguments);
|
||||
this.name = Objects.requireNonNull(name, "name can not be null");
|
||||
@ -39,10 +65,40 @@ public abstract class BoundFunction extends Expression implements FunctionTrait
|
||||
this.name = Objects.requireNonNull(name, "name can not be null");
|
||||
}
|
||||
|
||||
protected FunctionSignature computeSignature(FunctionSignature signature, List<Expression> arguments) {
|
||||
// NOTE:
|
||||
// this computed chain only process the common cases.
|
||||
// If you want to add some common cases to here, please separate the process code
|
||||
// to the other methods and add to this chain.
|
||||
// If you want to add some special cases, please override this method in the special
|
||||
// function class, like 'If' function and 'Substring' function.
|
||||
return ComputeSignatureChain.from(signature, arguments)
|
||||
.then(this::computePrecisionForDatetimeV2)
|
||||
.then(this::upgradeDateOrDateTimeToV2)
|
||||
.then(this::upgradeDecimalV2ToV3)
|
||||
.then(this::computePrecisionForDecimal)
|
||||
.then(this::dynamicComputePropertiesOfArray)
|
||||
.get();
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public FunctionSignature getSignature() {
|
||||
return signatureCache.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AbstractDataType> expectedInputTypes() {
|
||||
return ComputeSignature.super.expectedInputTypes();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType() {
|
||||
return ComputeSignature.super.getDataType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitBoundFunction(this, context);
|
||||
@ -82,4 +138,114 @@ public abstract class BoundFunction extends Expression implements FunctionTrait
|
||||
.collect(Collectors.joining(", "));
|
||||
return name + "(" + args + ")";
|
||||
}
|
||||
|
||||
private FunctionSignature computePrecisionForDatetimeV2(
|
||||
FunctionSignature signature, List<Expression> arguments) {
|
||||
|
||||
// fill for arguments type
|
||||
signature = signature.withArgumentTypes(arguments, (sigArgType, realArgType) -> {
|
||||
if (sigArgType instanceof DateTimeV2Type && realArgType.getDataType() instanceof DateTimeV2Type) {
|
||||
return realArgType.getDataType();
|
||||
}
|
||||
return sigArgType;
|
||||
});
|
||||
|
||||
// fill for return type
|
||||
if (signature.returnType instanceof DateTimeV2Type) {
|
||||
Integer maxScale = signature.argumentsTypes.stream()
|
||||
.filter(DateTimeV2Type.class::isInstance)
|
||||
.map(t -> ((DateTimeV2Type) t).getScale())
|
||||
.reduce(Math::max)
|
||||
.orElse(((DateTimeV2Type) signature.returnType).getScale());
|
||||
signature = signature.withReturnType(DateTimeV2Type.of(maxScale));
|
||||
}
|
||||
|
||||
return signature;
|
||||
}
|
||||
|
||||
private FunctionSignature upgradeDateOrDateTimeToV2(
|
||||
FunctionSignature signature, List<Expression> arguments) {
|
||||
DataType returnType = signature.returnType;
|
||||
Type type = returnType.toCatalogDataType();
|
||||
if ((type.isDate() || type.isDatetime()) && Config.enable_date_conversion) {
|
||||
Type legacyReturnType = ScalarType.getDefaultDateType(returnType.toCatalogDataType());
|
||||
signature = signature.withReturnType(DataType.fromCatalogType(legacyReturnType));
|
||||
}
|
||||
return signature;
|
||||
}
|
||||
|
||||
@Developing
|
||||
private FunctionSignature computePrecisionForDecimal(
|
||||
FunctionSignature signature, List<Expression> arguments) {
|
||||
if (signature.returnType instanceof DecimalV3Type || signature.returnType instanceof DecimalV2Type) {
|
||||
if (this instanceof DecimalSamePrecision) {
|
||||
signature = signature.withReturnType(arguments.get(0).getDataType());
|
||||
} else if (this instanceof DecimalWiderPrecision) {
|
||||
ScalarType widerType = ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
|
||||
((ScalarType) arguments.get(0).getDataType().toCatalogDataType()).getScalarScale());
|
||||
signature = signature.withReturnType(DataType.fromCatalogType(widerType));
|
||||
} else if (this instanceof DecimalStddevPrecision) {
|
||||
// for all stddev function, use decimal(38,9) as computing result
|
||||
ScalarType stddevDecimalType = ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
|
||||
DecimalStddevPrecision.STDDEV_DECIMAL_SCALE);
|
||||
signature = signature.withReturnType(DataType.fromCatalogType(stddevDecimalType));
|
||||
}
|
||||
}
|
||||
|
||||
return signature;
|
||||
}
|
||||
|
||||
private FunctionSignature upgradeDecimalV2ToV3(
|
||||
FunctionSignature signature, List<Expression> arguments) {
|
||||
DataType returnType = signature.returnType;
|
||||
Type type = returnType.toCatalogDataType();
|
||||
if (type.isDecimalV2() && Config.enable_decimal_conversion && Config.enable_decimalv3) {
|
||||
Type v3Type = ScalarType.createDecimalV3Type(type.getPrecision(), ((ScalarType) type).getScalarScale());
|
||||
signature = signature.withReturnType(DataType.fromCatalogType(v3Type));
|
||||
}
|
||||
return signature;
|
||||
}
|
||||
|
||||
private FunctionSignature dynamicComputePropertiesOfArray(
|
||||
FunctionSignature signature, List<Expression> arguments) {
|
||||
if (!(signature.returnType instanceof ArrayType)) {
|
||||
return signature;
|
||||
}
|
||||
|
||||
// TODO: compute array(...) function's itemType
|
||||
|
||||
// fill item type by the type of first item
|
||||
ArrayType arrayType = (ArrayType) signature.returnType;
|
||||
|
||||
// fill containsNull if any array argument contains null
|
||||
boolean containsNull = signature.argumentsTypes
|
||||
.stream()
|
||||
.filter(argType -> argType instanceof ArrayType)
|
||||
.map(ArrayType.class::cast)
|
||||
.anyMatch(ArrayType::containsNull);
|
||||
return signature.withReturnType(
|
||||
ArrayType.of(arrayType.getItemType(), arrayType.containsNull() || containsNull));
|
||||
}
|
||||
|
||||
static class ComputeSignatureChain {
|
||||
private ResponsibilityChain<Pair<FunctionSignature, List<Expression>>> computeChain;
|
||||
|
||||
public ComputeSignatureChain(ResponsibilityChain<Pair<FunctionSignature, List<Expression>>> computeChain) {
|
||||
this.computeChain = computeChain;
|
||||
}
|
||||
|
||||
public static ComputeSignatureChain from(FunctionSignature signature, List<Expression> arguments) {
|
||||
return new ComputeSignatureChain(ResponsibilityChain.from(Pair.of(signature, arguments)));
|
||||
}
|
||||
|
||||
public ComputeSignatureChain then(
|
||||
BiFunction<FunctionSignature, List<Expression>, FunctionSignature> computeFunction) {
|
||||
computeChain.then(pair -> Pair.of(computeFunction.apply(pair.first, pair.second), pair.second));
|
||||
return this;
|
||||
}
|
||||
|
||||
public FunctionSignature get() {
|
||||
return computeChain.get().first;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions.functions;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.annotation.Developing;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
@ -47,7 +48,8 @@ public interface ComputeSignature extends FunctionTrait, ImplicitCastInputTypes
|
||||
*
|
||||
* @return the matched signature
|
||||
*/
|
||||
FunctionSignature searchSignature();
|
||||
FunctionSignature searchSignature(List<DataType> argumentTypes, List<Expression> arguments,
|
||||
List<FunctionSignature> signatures);
|
||||
|
||||
///// re-defined other interface's methods, so we can mixin this interfaces like a trait /////
|
||||
|
||||
@ -64,7 +66,7 @@ public interface ComputeSignature extends FunctionTrait, ImplicitCastInputTypes
|
||||
*/
|
||||
@Override
|
||||
default List<AbstractDataType> expectedInputTypes() {
|
||||
return (List) getSignature().argumentsTypes;
|
||||
return getSignature().argumentsTypes;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -0,0 +1,47 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/** CustomSignature */
|
||||
public interface CustomSignature extends ComputeSignature {
|
||||
|
||||
// custom generate a function signature.
|
||||
FunctionSignature customSignature(List<DataType> argumentTypes, List<Expression> arguments);
|
||||
|
||||
@Override
|
||||
default List<FunctionSignature> getSignatures() {
|
||||
List<DataType> originArgumentTypes = getOriginArgumentTypes();
|
||||
List<Expression> originArguments = getOriginArguments();
|
||||
return ImmutableList.of(customSignature(originArgumentTypes, originArguments));
|
||||
}
|
||||
|
||||
// use the first signature as the candidate signature.
|
||||
@Override
|
||||
default FunctionSignature searchSignature(List<DataType> argumentTypes, List<Expression> arguments,
|
||||
List<FunctionSignature> signatures) {
|
||||
return signatures.get(0);
|
||||
}
|
||||
}
|
||||
@ -41,17 +41,17 @@ public abstract class DateTimeWithPrecision extends ScalarFunction {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FunctionSignature computeSignature(FunctionSignature signature) {
|
||||
protected FunctionSignature computeSignature(FunctionSignature signature, List<Expression> arguments) {
|
||||
if (arity() == 1 && signature.returnType instanceof DateTimeV2Type) {
|
||||
// For functions in TIME_FUNCTIONS_WITH_PRECISION, we can't figure out which function should be use when
|
||||
// searching in FunctionSet. So we adjust the return type by hand here.
|
||||
if (child(0) instanceof IntegerLikeLiteral) {
|
||||
IntegerLikeLiteral integerLikeLiteral = (IntegerLikeLiteral) child(0);
|
||||
if (arguments.get(0) instanceof IntegerLikeLiteral) {
|
||||
IntegerLikeLiteral integerLikeLiteral = (IntegerLikeLiteral) arguments.get(0);
|
||||
signature = signature.withReturnType(DateTimeV2Type.of(integerLikeLiteral.getIntValue()));
|
||||
} else {
|
||||
signature = signature.withReturnType(DateTimeV2Type.of(6));
|
||||
}
|
||||
}
|
||||
return super.computeSignature(signature);
|
||||
return super.computeSignature(signature, arguments);
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,8 +19,12 @@ package org.apache.doris.nereids.trees.expressions.functions;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Explicitly castable signature. This class equals to 'CompareMode.IS_NONSTRICT_SUPERTYPE_OF'.
|
||||
*
|
||||
@ -34,8 +38,9 @@ public interface ExplicitlyCastableSignature extends ComputeSignature {
|
||||
}
|
||||
|
||||
@Override
|
||||
default FunctionSignature searchSignature() {
|
||||
return SearchSignature.from(getSignatures(), getArguments())
|
||||
default FunctionSignature searchSignature(List<DataType> argumentTypes, List<Expression> arguments,
|
||||
List<FunctionSignature> signatures) {
|
||||
return SearchSignature.from(signatures, arguments)
|
||||
// first round, use identical strategy to find signature
|
||||
.orElseSearch(IdenticalSignature::isIdentical)
|
||||
// second round: if not found, use nullOrIdentical strategy
|
||||
|
||||
@ -39,6 +39,10 @@ public interface ExpressionTrait extends TreeNode<Expression> {
|
||||
return children();
|
||||
}
|
||||
|
||||
default Expression getArgument(int index) {
|
||||
return child(index);
|
||||
}
|
||||
|
||||
default DataType getDataType() throws UnboundException {
|
||||
throw new UnboundException("dataType");
|
||||
}
|
||||
|
||||
@ -17,6 +17,13 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* FunctionTrait.
|
||||
*/
|
||||
@ -24,4 +31,15 @@ public interface FunctionTrait extends ExpressionTrait {
|
||||
String getName();
|
||||
|
||||
boolean hasVarArguments();
|
||||
|
||||
default List<Expression> getOriginArguments() {
|
||||
return getArguments();
|
||||
}
|
||||
|
||||
default List<DataType> getOriginArgumentTypes() {
|
||||
return getArguments()
|
||||
.stream()
|
||||
.map(Expression::getDataType)
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,8 +18,12 @@
|
||||
package org.apache.doris.nereids.trees.expressions.functions;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Identical function signature. This class equals to 'CompareMode.IS_IDENTICAL'.
|
||||
*
|
||||
@ -33,8 +37,9 @@ public interface IdenticalSignature extends ComputeSignature {
|
||||
}
|
||||
|
||||
@Override
|
||||
default FunctionSignature searchSignature() {
|
||||
return SearchSignature.from(getSignatures(), getArguments())
|
||||
default FunctionSignature searchSignature(List<DataType> argumentTypes, List<Expression> arguments,
|
||||
List<FunctionSignature> signatures) {
|
||||
return SearchSignature.from(signatures, arguments)
|
||||
// first round, use identical strategy to find signature
|
||||
.orElseSearch(IdenticalSignature::isIdentical)
|
||||
.resultOrException(getName());
|
||||
|
||||
@ -19,8 +19,12 @@ package org.apache.doris.nereids.trees.expressions.functions;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Implicitly castable function signature. This class equals to 'CompareMode.IS_SUPERTYPE_OF'.
|
||||
*
|
||||
@ -34,8 +38,9 @@ public interface ImplicitlyCastableSignature extends ComputeSignature {
|
||||
}
|
||||
|
||||
@Override
|
||||
default FunctionSignature searchSignature() {
|
||||
return SearchSignature.from(getSignatures(), getArguments())
|
||||
default FunctionSignature searchSignature(List<DataType> argumentTypes, List<Expression> arguments,
|
||||
List<FunctionSignature> signatures) {
|
||||
return SearchSignature.from(signatures, arguments)
|
||||
// first round, use identical strategy to find signature
|
||||
.orElseSearch(IdenticalSignature::isIdentical)
|
||||
// second round: if not found, use nullOrIdentical strategy
|
||||
|
||||
@ -18,9 +18,13 @@
|
||||
package org.apache.doris.nereids.trees.expressions.functions;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.NullType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Null or identical function signature. This class equals to 'CompareMode.IS_INDISTINGUISHABLE'.
|
||||
*
|
||||
@ -35,8 +39,9 @@ public interface NullOrIdenticalSignature extends ComputeSignature {
|
||||
}
|
||||
|
||||
@Override
|
||||
default FunctionSignature searchSignature() {
|
||||
return SearchSignature.from(getSignatures(), getArguments())
|
||||
default FunctionSignature searchSignature(List<DataType> argumentTypes, List<Expression> arguments,
|
||||
List<FunctionSignature> signatures) {
|
||||
return SearchSignature.from(signatures, arguments)
|
||||
// first round, use identical strategy to find signature
|
||||
.orElseSearch(IdenticalSignature::isIdentical)
|
||||
// second round: if not found, use nullOrIdentical strategy
|
||||
|
||||
@ -19,17 +19,21 @@ package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ExpectsInputTypes;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.PartialAggType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* The function which consume arguments in lots of rows and product one value.
|
||||
*/
|
||||
public abstract class AggregateFunction extends BoundFunction {
|
||||
public abstract class AggregateFunction extends BoundFunction implements ExpectsInputTypes {
|
||||
|
||||
private final AggregateParam aggregateParam;
|
||||
|
||||
@ -42,15 +46,69 @@ public abstract class AggregateFunction extends BoundFunction {
|
||||
this.aggregateParam = Objects.requireNonNull(aggregateParam, "aggregateParam can not be null");
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Expression> getOriginArguments() {
|
||||
return getArgumentsBeforeDisassembled();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> getOriginArgumentTypes() {
|
||||
return getArgumentTypesBeforeDisassembled();
|
||||
}
|
||||
|
||||
@Override
|
||||
public abstract AggregateFunction withChildren(List<Expression> children);
|
||||
|
||||
public abstract DataType getFinalType();
|
||||
|
||||
public abstract DataType getIntermediateType();
|
||||
|
||||
public abstract AggregateFunction withAggregateParam(AggregateParam aggregateParam);
|
||||
|
||||
protected abstract List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments);
|
||||
|
||||
/** getIntermediateTypes */
|
||||
public final PartialAggType getIntermediateTypes() {
|
||||
if (isGlobal() && isDisassembled()) {
|
||||
return (PartialAggType) child(0).getDataType();
|
||||
}
|
||||
List<Expression> arguments = getArgumentsBeforeDisassembled();
|
||||
List<DataType> types = getArgumentTypesBeforeDisassembled();
|
||||
return new PartialAggType(getArguments(), intermediateTypes(types, arguments));
|
||||
}
|
||||
|
||||
public final DataType getFinalType() {
|
||||
return getSignature().returnType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public final DataType getDataType() {
|
||||
if (aggregateParam.isGlobal) {
|
||||
return getFinalType();
|
||||
} else {
|
||||
return getIntermediateTypes();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public final List<AbstractDataType> expectedInputTypes() {
|
||||
if (isGlobal() && isDisassembled()) {
|
||||
return ImmutableList.of(getIntermediateTypes());
|
||||
} else {
|
||||
return getSignature().argumentsTypes;
|
||||
}
|
||||
}
|
||||
|
||||
public List<Expression> getArgumentsBeforeDisassembled() {
|
||||
if (arity() == 1 && getArgument(0).getDataType() instanceof PartialAggType) {
|
||||
return ((PartialAggType) getArgument(0).getDataType()).getOriginArguments();
|
||||
}
|
||||
return getArguments();
|
||||
}
|
||||
|
||||
public List<DataType> getArgumentTypesBeforeDisassembled() {
|
||||
return getArgumentsBeforeDisassembled()
|
||||
.stream()
|
||||
.map(Expression::getDataType)
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
|
||||
public boolean isDistinct() {
|
||||
return aggregateParam.isDistinct;
|
||||
}
|
||||
@ -59,8 +117,8 @@ public abstract class AggregateFunction extends BoundFunction {
|
||||
return aggregateParam.isGlobal;
|
||||
}
|
||||
|
||||
public Optional<List<DataType>> inputTypesBeforeDissemble() {
|
||||
return aggregateParam.inputTypesBeforeDissemble;
|
||||
public boolean isDisassembled() {
|
||||
return aggregateParam.isDisassembled;
|
||||
}
|
||||
|
||||
public AggregateParam getAggregateParam() {
|
||||
@ -95,13 +153,4 @@ public abstract class AggregateFunction extends BoundFunction {
|
||||
public boolean hasVarArguments() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public final DataType getDataType() {
|
||||
if (aggregateParam.isGlobal) {
|
||||
return getFinalType();
|
||||
} else {
|
||||
return getIntermediateType();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,11 +17,9 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
/** AggregateParam. */
|
||||
public class AggregateParam {
|
||||
@ -29,50 +27,41 @@ public class AggregateParam {
|
||||
|
||||
public final boolean isDistinct;
|
||||
|
||||
// When AggregateDisassemble rule disassemble the aggregate function, say double avg(int), the local
|
||||
// aggregate keep the origin signature, but the global aggregate change to double avg(double).
|
||||
// This behavior is difference from the legacy optimizer, because legacy optimizer keep the same signature
|
||||
// between local aggregate and global aggregate. If the signatures are different, the result would wrong.
|
||||
// So we use this field to record the originInputTypes, and find the catalog function by the origin input types.
|
||||
public final Optional<List<DataType>> inputTypesBeforeDissemble;
|
||||
public final boolean isDisassembled;
|
||||
|
||||
public AggregateParam() {
|
||||
this(false, true, Optional.empty());
|
||||
}
|
||||
|
||||
public AggregateParam(boolean distinct) {
|
||||
this(distinct, true, Optional.empty());
|
||||
}
|
||||
|
||||
public AggregateParam(boolean isDistinct, boolean isGlobal) {
|
||||
this(isDistinct, isGlobal, Optional.empty());
|
||||
}
|
||||
|
||||
public AggregateParam(boolean isDistinct, boolean isGlobal, Optional<List<DataType>> inputTypesBeforeDissemble) {
|
||||
/** AggregateParam */
|
||||
public AggregateParam(boolean isDistinct, boolean isGlobal, boolean isDisassembled) {
|
||||
this.isDistinct = isDistinct;
|
||||
this.isGlobal = isGlobal;
|
||||
this.inputTypesBeforeDissemble = Objects.requireNonNull(inputTypesBeforeDissemble,
|
||||
"inputTypesBeforeDissemble can not be null");
|
||||
}
|
||||
|
||||
public static AggregateParam distinctAndGlobal() {
|
||||
return new AggregateParam(true, true, Optional.empty());
|
||||
this.isDisassembled = isDisassembled;
|
||||
if (!isGlobal) {
|
||||
Preconditions.checkArgument(isDisassembled == true,
|
||||
"local aggregate should be disassembed");
|
||||
}
|
||||
}
|
||||
|
||||
public static AggregateParam global() {
|
||||
return new AggregateParam(false, true, Optional.empty());
|
||||
return new AggregateParam(false, true, false);
|
||||
}
|
||||
|
||||
public static AggregateParam distinctAndGlobal() {
|
||||
return new AggregateParam(true, true, false);
|
||||
}
|
||||
|
||||
public AggregateParam withDistinct(boolean isDistinct) {
|
||||
return new AggregateParam(isDistinct, isGlobal, inputTypesBeforeDissemble);
|
||||
return new AggregateParam(isDistinct, isGlobal, isDisassembled);
|
||||
}
|
||||
|
||||
public AggregateParam withGlobal(boolean isGlobal) {
|
||||
return new AggregateParam(isDistinct, isGlobal, inputTypesBeforeDissemble);
|
||||
return new AggregateParam(isDistinct, isGlobal, isDisassembled);
|
||||
}
|
||||
|
||||
public AggregateParam withInputTypesBeforeDissemble(Optional<List<DataType>> inputTypesBeforeDissemble) {
|
||||
return new AggregateParam(isDistinct, isGlobal, inputTypesBeforeDissemble);
|
||||
public AggregateParam withDisassembled(boolean isDisassembled) {
|
||||
return new AggregateParam(isDistinct, isGlobal, isDisassembled);
|
||||
}
|
||||
|
||||
public AggregateParam withGlobalAndDisassembled(boolean isGlobal, boolean isDisassembled) {
|
||||
return new AggregateParam(isDistinct, isGlobal, isDisassembled);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -86,11 +75,11 @@ public class AggregateParam {
|
||||
AggregateParam that = (AggregateParam) o;
|
||||
return isDistinct == that.isDistinct
|
||||
&& Objects.equals(isGlobal, that.isGlobal)
|
||||
&& Objects.equals(inputTypesBeforeDissemble, that.inputTypesBeforeDissemble);
|
||||
&& Objects.equals(isDisassembled, that.isDisassembled);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(isDistinct, isGlobal, inputTypesBeforeDissemble);
|
||||
return Objects.hash(isDistinct, isGlobal, isDisassembled);
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,19 +17,20 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DateTimeType;
|
||||
import org.apache.doris.nereids.types.DateType;
|
||||
import org.apache.doris.nereids.types.DecimalV2Type;
|
||||
import org.apache.doris.nereids.types.DoubleType;
|
||||
import org.apache.doris.nereids.types.VarcharType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
import org.apache.doris.nereids.types.coercion.NumericType;
|
||||
import org.apache.doris.nereids.types.coercion.TypeCollection;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -37,12 +38,7 @@ import com.google.common.collect.ImmutableList;
|
||||
import java.util.List;
|
||||
|
||||
/** avg agg function. */
|
||||
public class Avg extends AggregateFunction implements UnaryExpression, ImplicitCastInputTypes {
|
||||
|
||||
// used in interface expectedInputTypes to avoid new list in each time it be called
|
||||
private static final List<AbstractDataType> EXPECTED_INPUT_TYPES = ImmutableList.of(
|
||||
new TypeCollection(NumericType.INSTANCE, DateTimeType.INSTANCE, DateType.INSTANCE)
|
||||
);
|
||||
public class Avg extends AggregateFunction implements UnaryExpression, PropagateNullable, CustomSignature {
|
||||
|
||||
public Avg(Expression child) {
|
||||
super("avg", child);
|
||||
@ -53,31 +49,16 @@ public class Avg extends AggregateFunction implements UnaryExpression, ImplicitC
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getFinalType() {
|
||||
DataType argumentType = inputTypesBeforeDissemble()
|
||||
.map(types -> types.get(0))
|
||||
.orElse(child().getDataType());
|
||||
if (argumentType instanceof DecimalV2Type) {
|
||||
return DecimalV2Type.SYSTEM_DEFAULT;
|
||||
} else if (argumentType.isDate()) {
|
||||
return DateType.INSTANCE;
|
||||
} else if (argumentType.isDateTime()) {
|
||||
return DateTimeType.INSTANCE;
|
||||
} else {
|
||||
return DoubleType.INSTANCE;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: We should return a complex type: PartialAggType(bufferTypes=[Double, Int], inputTypes=[Int])
|
||||
// to denote sum(double) and count(int)
|
||||
@Override
|
||||
public DataType getIntermediateType() {
|
||||
return VarcharType.createVarcharType(-1);
|
||||
public FunctionSignature customSignature(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
DataType implicitCastType = implicitCast(argumentTypes.get(0));
|
||||
return FunctionSignature.ret(implicitCastType).args(implicitCastType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean nullable() {
|
||||
return child().nullable();
|
||||
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
DataType sumType = getFinalType();
|
||||
BigIntType countType = BigIntType.INSTANCE;
|
||||
return ImmutableList.of(sumType, countType);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -91,17 +72,22 @@ public class Avg extends AggregateFunction implements UnaryExpression, ImplicitC
|
||||
return new Avg(aggregateParam, child());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AbstractDataType> expectedInputTypes() {
|
||||
if (isGlobal() && inputTypesBeforeDissemble().isPresent()) {
|
||||
return ImmutableList.of();
|
||||
} else {
|
||||
return EXPECTED_INPUT_TYPES;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitAvg(this, context);
|
||||
}
|
||||
|
||||
private DataType implicitCast(DataType dataType) {
|
||||
if (dataType instanceof DecimalV2Type) {
|
||||
return DecimalV2Type.SYSTEM_DEFAULT;
|
||||
} else if (dataType.isDate()) {
|
||||
return DateType.INSTANCE;
|
||||
} else if (dataType.isDateTime()) {
|
||||
return DateTimeType.INSTANCE;
|
||||
} else if (dataType instanceof NumericType) {
|
||||
return DoubleType.INSTANCE;
|
||||
} else {
|
||||
throw new AnalysisException("avg requires a numeric parameter: " + dataType);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,13 +17,13 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.types.BitmapType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -32,7 +32,11 @@ import java.util.List;
|
||||
|
||||
/** BitmapIntersect */
|
||||
public class BitmapIntersect extends AggregateFunction
|
||||
implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes {
|
||||
implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature {
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE)
|
||||
);
|
||||
|
||||
public BitmapIntersect(Expression arg0) {
|
||||
super("bitmap_intersect", arg0);
|
||||
}
|
||||
@ -41,29 +45,24 @@ public class BitmapIntersect extends AggregateFunction
|
||||
super("bitmap_intersect", aggregateParam, arg0);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return ImmutableList.of(BitmapType.INSTANCE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BitmapIntersect withChildren(List<Expression> children) {
|
||||
Preconditions.checkArgument(children.size() == 1);
|
||||
return new BitmapIntersect(getAggregateParam(), children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AbstractDataType> expectedInputTypes() {
|
||||
return ImmutableList.of(BitmapType.INSTANCE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getFinalType() {
|
||||
return BitmapType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getIntermediateType() {
|
||||
return BitmapType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BitmapIntersect withAggregateParam(AggregateParam aggregateParam) {
|
||||
return new BitmapIntersect(aggregateParam, child());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return SIGNATURES;
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,13 +17,13 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.types.BitmapType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -32,7 +32,11 @@ import java.util.List;
|
||||
|
||||
/** BitmapUnion */
|
||||
public class BitmapUnion extends AggregateFunction
|
||||
implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes {
|
||||
implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature {
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE)
|
||||
);
|
||||
|
||||
public BitmapUnion(Expression arg0) {
|
||||
super("bitmap_union", arg0);
|
||||
}
|
||||
@ -41,29 +45,24 @@ public class BitmapUnion extends AggregateFunction
|
||||
super("bitmap_union", aggregateParam, arg0);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return ImmutableList.of(BitmapType.INSTANCE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BitmapUnion withChildren(List<Expression> children) {
|
||||
Preconditions.checkArgument(children.size() == 1);
|
||||
return new BitmapUnion(getAggregateParam(), children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AbstractDataType> expectedInputTypes() {
|
||||
return ImmutableList.of(BitmapType.INSTANCE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getFinalType() {
|
||||
return BitmapType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getIntermediateType() {
|
||||
return BitmapType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BitmapUnion withAggregateParam(AggregateParam aggregateParam) {
|
||||
return new BitmapUnion(aggregateParam, child());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return SIGNATURES;
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,14 +17,14 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.BitmapType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -33,7 +33,11 @@ import java.util.List;
|
||||
|
||||
/** BitmapUnionCount */
|
||||
public class BitmapUnionCount extends AggregateFunction
|
||||
implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes {
|
||||
implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature {
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(BigIntType.INSTANCE).args(BitmapType.INSTANCE)
|
||||
);
|
||||
|
||||
public BitmapUnionCount(Expression arg0) {
|
||||
super("bitmap_union_count", arg0);
|
||||
}
|
||||
@ -42,29 +46,24 @@ public class BitmapUnionCount extends AggregateFunction
|
||||
super("bitmap_union_count", aggregateParam, arg0);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return ImmutableList.of(BitmapType.INSTANCE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BitmapUnionCount withChildren(List<Expression> children) {
|
||||
Preconditions.checkArgument(children.size() == 1);
|
||||
return new BitmapUnionCount(getAggregateParam(), children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AbstractDataType> expectedInputTypes() {
|
||||
return ImmutableList.of(BitmapType.INSTANCE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getFinalType() {
|
||||
return BigIntType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getIntermediateType() {
|
||||
return BitmapType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BitmapUnionCount withAggregateParam(AggregateParam aggregateParam) {
|
||||
return new BitmapUnionCount(aggregateParam, child());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return SIGNATURES;
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,18 +17,17 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.BitmapType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.types.SmallIntType;
|
||||
import org.apache.doris.nereids.types.TinyIntType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
import org.apache.doris.nereids.types.coercion.TypeCollection;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -37,7 +36,14 @@ import java.util.List;
|
||||
|
||||
/** BitmapUnionInt */
|
||||
public class BitmapUnionInt extends AggregateFunction
|
||||
implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes {
|
||||
implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature {
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
|
||||
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
|
||||
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
|
||||
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE)
|
||||
);
|
||||
|
||||
public BitmapUnionInt(Expression arg0) {
|
||||
super("bitmap_union_int", arg0);
|
||||
}
|
||||
@ -46,34 +52,24 @@ public class BitmapUnionInt extends AggregateFunction
|
||||
super("bitmap_union_int", aggregateParam, arg0);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return ImmutableList.of(BitmapType.INSTANCE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BitmapUnionInt withChildren(List<Expression> children) {
|
||||
Preconditions.checkArgument(children.size() == 1);
|
||||
return new BitmapUnionInt(getAggregateParam(), children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AbstractDataType> expectedInputTypes() {
|
||||
if (isGlobal() && inputTypesBeforeDissemble().isPresent()) {
|
||||
return ImmutableList.of();
|
||||
} else {
|
||||
return ImmutableList.of(new TypeCollection(
|
||||
TinyIntType.INSTANCE, SmallIntType.INSTANCE, IntegerType.INSTANCE, BigIntType.INSTANCE));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getFinalType() {
|
||||
return BigIntType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getIntermediateType() {
|
||||
return BitmapType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BitmapUnionInt withAggregateParam(AggregateParam aggregateParam) {
|
||||
return new BitmapUnionInt(aggregateParam, child());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return SIGNATURES;
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,20 +17,23 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** count agg function. */
|
||||
public class Count extends AggregateFunction implements AlwaysNotNullable {
|
||||
public class Count extends AggregateFunction implements AlwaysNotNullable, CustomSignature {
|
||||
|
||||
private final boolean isStar;
|
||||
|
||||
@ -59,13 +62,13 @@ public class Count extends AggregateFunction implements AlwaysNotNullable {
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getFinalType() {
|
||||
return BigIntType.INSTANCE;
|
||||
public FunctionSignature customSignature(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return FunctionSignature.of(BigIntType.INSTANCE, (List) argumentTypes);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getIntermediateType() {
|
||||
return getFinalType();
|
||||
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return ImmutableList.of(BigIntType.INSTANCE);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -17,13 +17,13 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.types.BitmapType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -32,7 +32,11 @@ import java.util.List;
|
||||
|
||||
/** GroupBitmapXor */
|
||||
public class GroupBitmapXor extends AggregateFunction
|
||||
implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes {
|
||||
implements UnaryExpression, PropagateNullable, ExplicitlyCastableSignature {
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE)
|
||||
);
|
||||
|
||||
public GroupBitmapXor(Expression arg0) {
|
||||
super("group_bitmap_xor", arg0);
|
||||
}
|
||||
@ -41,33 +45,24 @@ public class GroupBitmapXor extends AggregateFunction
|
||||
super("group_bitmap_xor", aggregateParam, arg0);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return ImmutableList.of(BitmapType.INSTANCE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public GroupBitmapXor withChildren(List<Expression> children) {
|
||||
Preconditions.checkArgument(children.size() == 1);
|
||||
return new GroupBitmapXor(getAggregateParam(), children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AbstractDataType> expectedInputTypes() {
|
||||
if (isGlobal() && inputTypesBeforeDissemble().isPresent()) {
|
||||
return ImmutableList.of();
|
||||
} else {
|
||||
return ImmutableList.of(BitmapType.INSTANCE);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getFinalType() {
|
||||
return BitmapType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getIntermediateType() {
|
||||
return BitmapType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public GroupBitmapXor withAggregateParam(AggregateParam aggregateParam) {
|
||||
return new GroupBitmapXor(aggregateParam, child());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return SIGNATURES;
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,13 +17,13 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.HllType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -32,15 +32,22 @@ import java.util.List;
|
||||
|
||||
/** HllUnion */
|
||||
public class HllUnion extends AggregateFunction
|
||||
implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes {
|
||||
implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature {
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(HllType.INSTANCE).args(HllType.INSTANCE)
|
||||
);
|
||||
|
||||
public HllUnion(Expression arg0) {
|
||||
// TODO: change to hll_union in the future
|
||||
super("hll_raw_agg", arg0);
|
||||
super("hll_union", arg0);
|
||||
}
|
||||
|
||||
public HllUnion(AggregateParam aggregateParam, Expression arg0) {
|
||||
// TODO: change to hll_union in the future
|
||||
super("hll_raw_agg", aggregateParam, arg0);
|
||||
super("hll_union", aggregateParam, arg0);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return ImmutableList.of(HllType.INSTANCE);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -49,23 +56,13 @@ public class HllUnion extends AggregateFunction
|
||||
return new HllUnion(getAggregateParam(), children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AbstractDataType> expectedInputTypes() {
|
||||
return ImmutableList.of(HllType.INSTANCE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getFinalType() {
|
||||
return HllType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getIntermediateType() {
|
||||
return HllType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public HllUnion withAggregateParam(AggregateParam aggregateParam) {
|
||||
return new HllUnion(aggregateParam, child());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return SIGNATURES;
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,14 +17,14 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.HllType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -33,7 +33,11 @@ import java.util.List;
|
||||
|
||||
/** HllUnionAgg */
|
||||
public class HllUnionAgg extends AggregateFunction
|
||||
implements UnaryExpression, PropagateNullable, ImplicitCastInputTypes {
|
||||
implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature {
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(BigIntType.INSTANCE).args(HllType.INSTANCE)
|
||||
);
|
||||
|
||||
public HllUnionAgg(Expression arg0) {
|
||||
super("hll_union_agg", arg0);
|
||||
}
|
||||
@ -42,29 +46,24 @@ public class HllUnionAgg extends AggregateFunction
|
||||
super("hll_union_agg", aggregateParam, arg0);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return ImmutableList.of(HllType.INSTANCE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public HllUnionAgg withChildren(List<Expression> children) {
|
||||
Preconditions.checkArgument(children.size() == 1);
|
||||
return new HllUnionAgg(getAggregateParam(), children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AbstractDataType> expectedInputTypes() {
|
||||
return ImmutableList.of(HllType.INSTANCE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getFinalType() {
|
||||
return BigIntType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getIntermediateType() {
|
||||
return HllType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public HllUnionAgg withAggregateParam(AggregateParam aggregateParam) {
|
||||
return new HllUnionAgg(aggregateParam, child());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return SIGNATURES;
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,7 +17,10 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
@ -27,8 +30,7 @@ import com.google.common.base.Preconditions;
|
||||
import java.util.List;
|
||||
|
||||
/** max agg function. */
|
||||
public class Max extends AggregateFunction implements UnaryExpression {
|
||||
|
||||
public class Max extends AggregateFunction implements UnaryExpression, PropagateNullable, CustomSignature {
|
||||
public Max(Expression child) {
|
||||
super("max", child);
|
||||
}
|
||||
@ -38,18 +40,13 @@ public class Max extends AggregateFunction implements UnaryExpression {
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getFinalType() {
|
||||
return child().getDataType();
|
||||
public FunctionSignature customSignature(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return FunctionSignature.ret(argumentTypes.get(0)).args(argumentTypes.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getIntermediateType() {
|
||||
return getFinalType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean nullable() {
|
||||
return child().nullable();
|
||||
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return argumentTypes;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -17,7 +17,10 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
@ -27,7 +30,7 @@ import com.google.common.base.Preconditions;
|
||||
import java.util.List;
|
||||
|
||||
/** min agg function. */
|
||||
public class Min extends AggregateFunction implements UnaryExpression {
|
||||
public class Min extends AggregateFunction implements UnaryExpression, PropagateNullable, CustomSignature {
|
||||
|
||||
public Min(Expression child) {
|
||||
super("min", child);
|
||||
@ -38,18 +41,13 @@ public class Min extends AggregateFunction implements UnaryExpression {
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getFinalType() {
|
||||
return child().getDataType();
|
||||
public FunctionSignature customSignature(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return FunctionSignature.ret(argumentTypes.get(0)).args(argumentTypes.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getIntermediateType() {
|
||||
return getFinalType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean nullable() {
|
||||
return child().nullable();
|
||||
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return argumentTypes;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -17,17 +17,18 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DecimalV2Type;
|
||||
import org.apache.doris.nereids.types.DoubleType;
|
||||
import org.apache.doris.nereids.types.LargeIntType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
import org.apache.doris.nereids.types.coercion.FractionalType;
|
||||
import org.apache.doris.nereids.types.coercion.IntegralType;
|
||||
import org.apache.doris.nereids.types.coercion.NumericType;
|
||||
|
||||
@ -37,11 +38,7 @@ import com.google.common.collect.ImmutableList;
|
||||
import java.util.List;
|
||||
|
||||
/** sum agg function. */
|
||||
public class Sum extends AggregateFunction implements UnaryExpression, ImplicitCastInputTypes {
|
||||
|
||||
// used in interface expectedInputTypes to avoid new list in each time it be called
|
||||
private static final List<AbstractDataType> EXPECTED_INPUT_TYPES = ImmutableList.of(NumericType.INSTANCE);
|
||||
|
||||
public class Sum extends AggregateFunction implements UnaryExpression, PropagateNullable, CustomSignature {
|
||||
public Sum(Expression child) {
|
||||
super("sum", child);
|
||||
}
|
||||
@ -51,34 +48,14 @@ public class Sum extends AggregateFunction implements UnaryExpression, ImplicitC
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getFinalType() {
|
||||
DataType dataType = child().getDataType();
|
||||
if (dataType instanceof LargeIntType) {
|
||||
return dataType;
|
||||
} else if (dataType instanceof DecimalV2Type) {
|
||||
return DecimalV2Type.SYSTEM_DEFAULT;
|
||||
} else if (dataType instanceof IntegralType) {
|
||||
return BigIntType.INSTANCE;
|
||||
} else if (dataType instanceof FractionalType) {
|
||||
return DoubleType.INSTANCE;
|
||||
} else {
|
||||
throw new IllegalStateException("Unsupported sum type: " + dataType);
|
||||
}
|
||||
public FunctionSignature customSignature(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
DataType implicitCastType = implicitCast(argumentTypes.get(0));
|
||||
return FunctionSignature.ret(implicitCastType).args(NumericType.INSTANCE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getIntermediateType() {
|
||||
return getFinalType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean nullable() {
|
||||
return child().nullable();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AbstractDataType> expectedInputTypes() {
|
||||
return EXPECTED_INPUT_TYPES;
|
||||
protected List<DataType> intermediateTypes(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return ImmutableList.of(getFinalType());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -96,4 +73,18 @@ public class Sum extends AggregateFunction implements UnaryExpression, ImplicitC
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitSum(this, context);
|
||||
}
|
||||
|
||||
private DataType implicitCast(DataType dataType) {
|
||||
if (dataType instanceof LargeIntType) {
|
||||
return dataType;
|
||||
} else if (dataType instanceof DecimalV2Type) {
|
||||
return DecimalV2Type.SYSTEM_DEFAULT;
|
||||
} else if (dataType instanceof IntegralType) {
|
||||
return BigIntType.INSTANCE;
|
||||
} else if (dataType instanceof NumericType) {
|
||||
return DoubleType.INSTANCE;
|
||||
} else {
|
||||
throw new AnalysisException("sum requires a numeric parameter: " + dataType);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -41,6 +41,7 @@ import org.apache.doris.nereids.types.SmallIntType;
|
||||
import org.apache.doris.nereids.types.StringType;
|
||||
import org.apache.doris.nereids.types.TinyIntType;
|
||||
import org.apache.doris.nereids.types.VarcharType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Suppliers;
|
||||
@ -90,7 +91,7 @@ public class If extends ScalarFunction
|
||||
);
|
||||
|
||||
private final Supplier<DataType> widerType = Suppliers.memoize(() -> {
|
||||
List<DataType> argumentsTypes = getSignature().argumentsTypes;
|
||||
List<AbstractDataType> argumentsTypes = getSignature().argumentsTypes;
|
||||
Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType(
|
||||
argumentsTypes.get(1).toCatalogDataType(),
|
||||
argumentsTypes.get(2).toCatalogDataType(),
|
||||
@ -106,11 +107,11 @@ public class If extends ScalarFunction
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FunctionSignature computeSignature(FunctionSignature signature) {
|
||||
protected FunctionSignature computeSignature(FunctionSignature signature, List<Expression> arguments) {
|
||||
DataType widerType = this.widerType.get();
|
||||
signature = signature.withArgumentTypes(children(), (sigType, argType) -> widerType)
|
||||
signature = signature.withArgumentTypes(arguments, (sigType, argType) -> widerType)
|
||||
.withReturnType(widerType);
|
||||
return super.computeSignature(signature);
|
||||
return super.computeSignature(signature, arguments);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -17,45 +17,17 @@
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.scalar;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.catalog.ScalarType;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.common.Config;
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.annotation.Developing;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ComputeSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.DecimalSamePrecision;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.DecimalStddevPrecision;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.DecimalWiderPrecision;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.ArrayType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DateTimeV2Type;
|
||||
import org.apache.doris.nereids.types.DecimalV2Type;
|
||||
import org.apache.doris.nereids.types.DecimalV3Type;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
import org.apache.doris.nereids.util.ResponsibilityChain;
|
||||
|
||||
import com.google.common.base.Suppliers;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.BiFunction;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
/**
|
||||
* The function which consume zero or more arguments in a row and product one value.
|
||||
*/
|
||||
public abstract class ScalarFunction extends BoundFunction implements ComputeSignature {
|
||||
@Developing("this field will move to BoundFunction, when we support compute signature for AggregateFunction")
|
||||
private final Supplier<FunctionSignature> signatureCache = Suppliers.memoize(() -> {
|
||||
// first step: find the candidate signature in the signature list
|
||||
FunctionSignature matchedSignature = searchSignature();
|
||||
// second step: change the signature, e.g. fill precision for decimal v2
|
||||
return computeSignature(matchedSignature);
|
||||
});
|
||||
|
||||
public ScalarFunction(String name, Expression... arguments) {
|
||||
super(name, arguments);
|
||||
}
|
||||
@ -64,150 +36,8 @@ public abstract class ScalarFunction extends BoundFunction implements ComputeSig
|
||||
super(name, arguments);
|
||||
}
|
||||
|
||||
public FunctionSignature getSignature() {
|
||||
return signatureCache.get();
|
||||
}
|
||||
|
||||
protected FunctionSignature computeSignature(FunctionSignature signature) {
|
||||
// NOTE:
|
||||
// this computed chain only process the common cases.
|
||||
// If you want to add some common cases to here, please separate the process code
|
||||
// to the other methods and add to this chain.
|
||||
// If you want to add some special cases, please override this method in the special
|
||||
// function class, like 'If' function and 'Substring' function.
|
||||
return ComputeSignatureChain.from(signature, getArguments())
|
||||
.then(this::computePrecisionForDatetimeV2)
|
||||
.then(this::upgradeDateOrDateTimeToV2)
|
||||
.then(this::upgradeDecimalV2ToV3)
|
||||
.then(this::computePrecisionForDecimal)
|
||||
.then(this::dynamicComputePropertiesOfArray)
|
||||
.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
@Developing("this method will move to BoundFunction, when we support compute signature for AggregateFunction")
|
||||
public final List<AbstractDataType> expectedInputTypes() {
|
||||
return ComputeSignature.super.expectedInputTypes();
|
||||
}
|
||||
|
||||
@Override
|
||||
@Developing("this method will move to BoundFunction, when we support compute signature for AggregateFunction")
|
||||
public final DataType getDataType() {
|
||||
return ComputeSignature.super.getDataType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitScalarFunction(this, context);
|
||||
}
|
||||
|
||||
private FunctionSignature computePrecisionForDatetimeV2(
|
||||
FunctionSignature signature, List<Expression> arguments) {
|
||||
|
||||
// fill for arguments type
|
||||
signature = signature.withArgumentTypes(arguments, (sigArgType, realArgType) -> {
|
||||
if (sigArgType instanceof DateTimeV2Type && realArgType.getDataType() instanceof DateTimeV2Type) {
|
||||
return realArgType.getDataType();
|
||||
}
|
||||
return sigArgType;
|
||||
});
|
||||
|
||||
// fill for return type
|
||||
if (signature.returnType instanceof DateTimeV2Type) {
|
||||
Integer maxScale = signature.argumentsTypes.stream()
|
||||
.filter(DateTimeV2Type.class::isInstance)
|
||||
.map(t -> ((DateTimeV2Type) t).getScale())
|
||||
.reduce(Math::max)
|
||||
.orElse(((DateTimeV2Type) signature.returnType).getScale());
|
||||
signature = signature.withReturnType(DateTimeV2Type.of(maxScale));
|
||||
}
|
||||
|
||||
return signature;
|
||||
}
|
||||
|
||||
private FunctionSignature upgradeDateOrDateTimeToV2(
|
||||
FunctionSignature signature, List<Expression> arguments) {
|
||||
DataType returnType = signature.returnType;
|
||||
Type type = returnType.toCatalogDataType();
|
||||
if ((type.isDate() || type.isDatetime()) && Config.enable_date_conversion) {
|
||||
Type legacyReturnType = ScalarType.getDefaultDateType(returnType.toCatalogDataType());
|
||||
signature = signature.withReturnType(DataType.fromCatalogType(legacyReturnType));
|
||||
}
|
||||
return signature;
|
||||
}
|
||||
|
||||
@Developing
|
||||
private FunctionSignature computePrecisionForDecimal(
|
||||
FunctionSignature signature, List<Expression> arguments) {
|
||||
if (signature.returnType instanceof DecimalV3Type || signature.returnType instanceof DecimalV2Type) {
|
||||
if (this instanceof DecimalSamePrecision) {
|
||||
signature = signature.withReturnType(signature.argumentsTypes.get(0));
|
||||
} else if (this instanceof DecimalWiderPrecision) {
|
||||
ScalarType widerType = ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
|
||||
((ScalarType) signature.argumentsTypes.get(0).toCatalogDataType()).getScalarScale());
|
||||
signature = signature.withReturnType(DataType.fromCatalogType(widerType));
|
||||
} else if (this instanceof DecimalStddevPrecision) {
|
||||
// for all stddev function, use decimal(38,9) as computing result
|
||||
ScalarType stddevDecimalType = ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
|
||||
DecimalStddevPrecision.STDDEV_DECIMAL_SCALE);
|
||||
signature = signature.withReturnType(DataType.fromCatalogType(stddevDecimalType));
|
||||
}
|
||||
}
|
||||
|
||||
return signature;
|
||||
}
|
||||
|
||||
private FunctionSignature upgradeDecimalV2ToV3(
|
||||
FunctionSignature signature, List<Expression> arguments) {
|
||||
DataType returnType = signature.returnType;
|
||||
Type type = returnType.toCatalogDataType();
|
||||
if (type.isDecimalV2() && Config.enable_decimal_conversion && Config.enable_decimalv3) {
|
||||
Type v3Type = ScalarType.createDecimalV3Type(type.getPrecision(), ((ScalarType) type).getScalarScale());
|
||||
signature = signature.withReturnType(DataType.fromCatalogType(v3Type));
|
||||
}
|
||||
return signature;
|
||||
}
|
||||
|
||||
private FunctionSignature dynamicComputePropertiesOfArray(
|
||||
FunctionSignature signature, List<Expression> arguments) {
|
||||
if (!(signature.returnType instanceof ArrayType)) {
|
||||
return signature;
|
||||
}
|
||||
|
||||
// TODO: compute array(...) function's itemType
|
||||
|
||||
// fill item type by the type of first item
|
||||
ArrayType arrayType = (ArrayType) signature.returnType;
|
||||
|
||||
// fill containsNull if any array argument contains null
|
||||
boolean containsNull = signature.argumentsTypes
|
||||
.stream()
|
||||
.filter(argType -> argType instanceof ArrayType)
|
||||
.map(ArrayType.class::cast)
|
||||
.anyMatch(ArrayType::containsNull);
|
||||
return signature.withReturnType(
|
||||
ArrayType.of(arrayType.getItemType(), arrayType.containsNull() || containsNull));
|
||||
}
|
||||
|
||||
static class ComputeSignatureChain {
|
||||
private ResponsibilityChain<Pair<FunctionSignature, List<Expression>>> computeChain;
|
||||
|
||||
public ComputeSignatureChain(ResponsibilityChain<Pair<FunctionSignature, List<Expression>>> computeChain) {
|
||||
this.computeChain = computeChain;
|
||||
}
|
||||
|
||||
public static ComputeSignatureChain from(FunctionSignature signature, List<Expression> arguments) {
|
||||
return new ComputeSignatureChain(ResponsibilityChain.from(Pair.of(signature, arguments)));
|
||||
}
|
||||
|
||||
public ComputeSignatureChain then(
|
||||
BiFunction<FunctionSignature, List<Expression>, FunctionSignature> computeFunction) {
|
||||
computeChain.then(pair -> Pair.of(computeFunction.apply(pair.first, pair.second), pair.second));
|
||||
return this;
|
||||
}
|
||||
|
||||
public FunctionSignature get() {
|
||||
return computeChain.get().first;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -56,7 +56,7 @@ public class StrToDate extends ScalarFunction
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FunctionSignature computeSignature(FunctionSignature signature) {
|
||||
protected FunctionSignature computeSignature(FunctionSignature signature, List<Expression> arguments) {
|
||||
/*
|
||||
* The return type of str_to_date depends on whether the time part is included in the format.
|
||||
* If included, it is datetime, otherwise it is date.
|
||||
@ -83,8 +83,8 @@ public class StrToDate extends ScalarFunction
|
||||
* Return type is DATETIME
|
||||
*/
|
||||
DataType returnType;
|
||||
if (child(1) instanceof StringLikeLiteral) {
|
||||
if (DateLiteral.hasTimePart(((StringLikeLiteral) child(1)).getStringValue())) {
|
||||
if (arguments.get(1) instanceof StringLikeLiteral) {
|
||||
if (DateLiteral.hasTimePart(((StringLikeLiteral) arguments.get(1)).getStringValue())) {
|
||||
returnType = DataType.fromCatalogType(ScalarType.getDefaultDateType(Type.DATETIME));
|
||||
} else {
|
||||
returnType = DataType.fromCatalogType(ScalarType.getDefaultDateType(Type.DATE));
|
||||
|
||||
@ -66,8 +66,9 @@ public class Substring extends ScalarFunction
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FunctionSignature computeSignature(FunctionSignature signature) {
|
||||
Optional<Expression> length = getLength();
|
||||
protected FunctionSignature computeSignature(FunctionSignature signature, List<Expression> arguments) {
|
||||
Optional<Expression> length = arguments.size() == 3
|
||||
? Optional.of(arguments.get(2)) : Optional.empty();
|
||||
DataType returnType = VarcharType.SYSTEM_DEFAULT;
|
||||
if (length.isPresent() && length.get() instanceof IntegerLiteral) {
|
||||
returnType = VarcharType.createVarcharType(((IntegerLiteral) length.get()).getValue());
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.trees.expressions.functions.table;
|
||||
|
||||
import org.apache.doris.analysis.IntLiteral;
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.common.Id;
|
||||
import org.apache.doris.common.NereidsException;
|
||||
@ -26,6 +27,8 @@ import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.TVFProperties;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.statistics.ColumnStatistic;
|
||||
import org.apache.doris.statistics.StatsDeriveResult;
|
||||
import org.apache.doris.tablefunction.NumbersTableValuedFunction;
|
||||
@ -43,10 +46,15 @@ public class Numbers extends TableValuedFunction {
|
||||
super("numbers", properties);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FunctionSignature customSignature(List<DataType> argumentTypes, List<Expression> arguments) {
|
||||
return FunctionSignature.of(BigIntType.INSTANCE, (List) argumentTypes);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TableValuedFunctionIf toCatalogFunction() {
|
||||
try {
|
||||
Map<String, String> arguments = getKeyValuesExpression().getMap();
|
||||
Map<String, String> arguments = getTVFProperties().getMap();
|
||||
return new NumbersTableValuedFunction(arguments);
|
||||
} catch (Throwable t) {
|
||||
throw new AnalysisException("Can not build NumbersTableValuedFunction by "
|
||||
@ -82,9 +90,4 @@ public class Numbers extends TableValuedFunction {
|
||||
&& children().get(0) instanceof TVFProperties);
|
||||
return new Numbers((TVFProperties) children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasVarArguments() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -24,6 +24,7 @@ import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.TVFProperties;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
@ -38,7 +39,7 @@ import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** TableValuedFunction */
|
||||
public abstract class TableValuedFunction extends BoundFunction implements UnaryExpression {
|
||||
public abstract class TableValuedFunction extends BoundFunction implements UnaryExpression, CustomSignature {
|
||||
protected final Supplier<TableValuedFunctionIf> catalogFunctionCache = Suppliers.memoize(() -> toCatalogFunction());
|
||||
protected final Supplier<FunctionGenTable> tableCache = Suppliers.memoize(() -> {
|
||||
try {
|
||||
@ -58,7 +59,7 @@ public abstract class TableValuedFunction extends BoundFunction implements Unary
|
||||
|
||||
public abstract StatsDeriveResult computeStats(List<Slot> slots);
|
||||
|
||||
public TVFProperties getKeyValuesExpression() {
|
||||
public TVFProperties getTVFProperties() {
|
||||
return (TVFProperties) child(0);
|
||||
}
|
||||
|
||||
@ -95,7 +96,7 @@ public abstract class TableValuedFunction extends BoundFunction implements Unary
|
||||
|
||||
@Override
|
||||
public String toSql() {
|
||||
String args = getKeyValuesExpression()
|
||||
String args = getTVFProperties()
|
||||
.getMap()
|
||||
.entrySet()
|
||||
.stream()
|
||||
|
||||
@ -29,10 +29,22 @@ public class ExplainCommand implements Command {
|
||||
* explain level.
|
||||
*/
|
||||
public enum ExplainLevel {
|
||||
NORMAL,
|
||||
VERBOSE,
|
||||
GRAPH,
|
||||
NONE(false),
|
||||
NORMAL(false),
|
||||
VERBOSE(false),
|
||||
GRAPH(false),
|
||||
PARSED_PLAN(true),
|
||||
ANALYZED_PLAN(true),
|
||||
REWRITTEN_PLAN(true),
|
||||
OPTIMIZED_PLAN(true),
|
||||
ALL_PLAN(true)
|
||||
;
|
||||
|
||||
public final boolean isPlanLevel;
|
||||
|
||||
ExplainLevel(boolean isPlanLevel) {
|
||||
this.isPlanLevel = isPlanLevel;
|
||||
}
|
||||
}
|
||||
|
||||
private final ExplainLevel level;
|
||||
|
||||
@ -78,7 +78,7 @@ public class PhysicalTVFRelation extends PhysicalRelation implements TVFRelation
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Utils.toSqlString("LogicalTVFRelation",
|
||||
return Utils.toSqlString("PhysicalTVFRelation",
|
||||
"qualified", Utils.qualifiedName(qualifier, getTable().getName()),
|
||||
"output", getOutput(),
|
||||
"function", function.toSql()
|
||||
|
||||
@ -0,0 +1,99 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.types;
|
||||
|
||||
import org.apache.doris.catalog.ScalarType;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/** PartialAggType */
|
||||
public class PartialAggType extends DataType {
|
||||
public final List<Expression> originArguments;
|
||||
public final List<DataType> intermediateTypes;
|
||||
|
||||
/** PartialAggType */
|
||||
public PartialAggType(List<Expression> originArguments, List<DataType> intermediateTypes) {
|
||||
this.originArguments = ImmutableList.copyOf(
|
||||
Objects.requireNonNull(originArguments, "originArguments can not be null"));
|
||||
this.intermediateTypes = ImmutableList.copyOf(
|
||||
Objects.requireNonNull(intermediateTypes, "intermediateTypes can not be null"));
|
||||
Preconditions.checkArgument(intermediateTypes.size() > 0, "intermediateTypes can not empty");
|
||||
}
|
||||
|
||||
public List<Expression> getOriginArguments() {
|
||||
return originArguments;
|
||||
}
|
||||
|
||||
public List<DataType> getIntermediateTypes() {
|
||||
return intermediateTypes;
|
||||
}
|
||||
|
||||
public List<DataType> getOriginInputTypes() {
|
||||
return originArguments.stream()
|
||||
.map(Expression::getDataType)
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toSql() {
|
||||
return "PartialAggType(types=" + intermediateTypes + ")";
|
||||
}
|
||||
|
||||
@Override
|
||||
public int width() {
|
||||
return intermediateTypes.stream()
|
||||
.map(DataType::width)
|
||||
.reduce((w1, w2) -> w1 + w2)
|
||||
.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Type toCatalogDataType() {
|
||||
if (intermediateTypes.size() == 1) {
|
||||
return intermediateTypes.get(0).toCatalogDataType();
|
||||
}
|
||||
return ScalarType.createVarcharType(-1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
if (!super.equals(o)) {
|
||||
return false;
|
||||
}
|
||||
PartialAggType that = (PartialAggType) o;
|
||||
return Objects.equals(originArguments, that.originArguments)
|
||||
&& Objects.equals(intermediateTypes, that.intermediateTypes);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(super.hashCode(), originArguments, intermediateTypes);
|
||||
}
|
||||
}
|
||||
@ -18,15 +18,19 @@
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.catalog.FunctionRegistry;
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.Substring;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.Year;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
@ -136,18 +140,27 @@ public class FunctionRegistryTest implements PatternMatchSupported {
|
||||
});
|
||||
}
|
||||
|
||||
public static class ExtendFunction extends BoundFunction implements UnaryExpression, PropagateNullable {
|
||||
public static class ExtendFunction extends BoundFunction implements UnaryExpression, PropagateNullable,
|
||||
ExplicitlyCastableSignature {
|
||||
public ExtendFunction(Expression a1) {
|
||||
super("foo", a1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return ImmutableList.of(
|
||||
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasVarArguments() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public static class AmbiguousFunction extends BoundFunction implements UnaryExpression, PropagateNullable {
|
||||
public static class AmbiguousFunction extends ScalarFunction implements UnaryExpression, PropagateNullable,
|
||||
ExplicitlyCastableSignature {
|
||||
public AmbiguousFunction(Expression a1) {
|
||||
super("abc", a1);
|
||||
}
|
||||
@ -156,6 +169,13 @@ public class FunctionRegistryTest implements PatternMatchSupported {
|
||||
super("abc", a1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return ImmutableList.of(
|
||||
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasVarArguments() {
|
||||
return false;
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.common.NereidsException;
|
||||
import org.apache.doris.nereids.NereidsPlanner;
|
||||
import org.apache.doris.nereids.StatementContext;
|
||||
import org.apache.doris.nereids.analyzer.UnboundAlias;
|
||||
@ -282,7 +283,7 @@ public class RegisterCTETest extends TestWithFeService implements PatternMatchSu
|
||||
String sql = "WITH cte1 (a1, A1) AS (SELECT * FROM supplier)"
|
||||
+ "SELECT * FROM cte1";
|
||||
|
||||
AnalysisException exception = Assertions.assertThrows(AnalysisException.class, () -> {
|
||||
NereidsException exception = Assertions.assertThrows(NereidsException.class, () -> {
|
||||
PlanChecker.from(connectContext).checkPlannerResult(sql);
|
||||
}, "Not throw expected exception.");
|
||||
Assertions.assertTrue(exception.getMessage().contains("Duplicated CTE column alias: [a1] in CTE [cte1]"));
|
||||
@ -294,7 +295,7 @@ public class RegisterCTETest extends TestWithFeService implements PatternMatchSu
|
||||
+ "(SELECT s_suppkey FROM supplier)"
|
||||
+ "SELECT * FROM cte1";
|
||||
|
||||
AnalysisException exception = Assertions.assertThrows(AnalysisException.class, () -> {
|
||||
NereidsException exception = Assertions.assertThrows(NereidsException.class, () -> {
|
||||
PlanChecker.from(connectContext).checkPlannerResult(sql);
|
||||
}, "Not throw expected exception.");
|
||||
System.out.println(exception.getMessage());
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
|
||||
import org.apache.doris.nereids.trees.expressions.CaseWhen;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
@ -48,6 +49,7 @@ import org.apache.doris.nereids.types.TinyIntType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
@ -100,16 +102,16 @@ public class TypeCoercionTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
public void testSumImplicitCast() {
|
||||
Expression expression = new Sum(new StringLiteral("1"));
|
||||
Expression expected = new Sum(new Cast(new StringLiteral("1"), DoubleType.INSTANCE));
|
||||
assertRewrite(expression, expected);
|
||||
Assertions.assertThrows(AnalysisException.class, () -> {
|
||||
new Sum(new StringLiteral("1")).getDataType();
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAvgImplicitCast() {
|
||||
Expression expression = new Avg(new StringLiteral("1"));
|
||||
Expression expected = new Avg(new Cast(new StringLiteral("1"), DoubleType.INSTANCE));
|
||||
assertRewrite(expression, expected);
|
||||
Assertions.assertThrows(AnalysisException.class, () -> {
|
||||
new Avg(new StringLiteral("1")).getDataType();
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@ -45,7 +45,6 @@ import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInstance;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
public class AggregateDisassembleTest implements PatternMatchSupported {
|
||||
@ -277,7 +276,7 @@ public class AggregateDisassembleTest implements PatternMatchSupported {
|
||||
// id
|
||||
Expression localOutput0 = rStudent.getOutput().get(0);
|
||||
// sum
|
||||
Sum localOutput1 = new Sum(new AggregateParam(false, false, Optional.empty()), rStudent.getOutput().get(0).toSlot());
|
||||
Sum localOutput1 = new Sum(new AggregateParam(false, false, true), rStudent.getOutput().get(0).toSlot());
|
||||
// age
|
||||
Expression localOutput2 = rStudent.getOutput().get(2);
|
||||
// id
|
||||
|
||||
@ -17,12 +17,10 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.common.AnalysisException;
|
||||
import org.apache.doris.nereids.NereidsPlanner;
|
||||
import org.apache.doris.nereids.parser.NereidsParser;
|
||||
import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
|
||||
import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.utframe.TestWithFeService;
|
||||
@ -33,13 +31,6 @@ import org.junit.jupiter.api.Test;
|
||||
import java.util.List;
|
||||
|
||||
public class PushdownExpressionsInHashConditionTest extends TestWithFeService implements PatternMatchSupported {
|
||||
private final List<String> testSql = ImmutableList.of(
|
||||
"SELECT * FROM T1 JOIN T2 ON T1.ID + 1 = T2.ID + 2 AND T1.ID + 1 > 2",
|
||||
"SELECT * FROM (SELECT * FROM T1) X JOIN (SELECT * FROM T2) Y ON X.ID + 1 = Y.ID + 2 AND X.ID + 1 > 2",
|
||||
"SELECT * FROM T1 JOIN (SELECT ID, SUM(SCORE) SCORE FROM T2 GROUP BY ID) T ON T1.ID + 1 = T.ID AND T.SCORE < 10",
|
||||
"SELECT * FROM T1 JOIN (SELECT ID, SUM(SCORE) SCORE FROM T2 GROUP BY ID ORDER BY ID) T ON T1.ID + 1 = T.ID AND T.SCORE < 10"
|
||||
);
|
||||
|
||||
@Override
|
||||
protected void runBeforeAll() throws Exception {
|
||||
createDatabase("test");
|
||||
@ -81,15 +72,10 @@ public class PushdownExpressionsInHashConditionTest extends TestWithFeService im
|
||||
"SELECT * FROM T1 JOIN (SELECT ID, SUM(SCORE) SCORE FROM T2 GROUP BY ID ORDER BY ID) T ON T1.ID + 1 = T.ID AND T.SCORE < 10"
|
||||
);
|
||||
testSql.forEach(sql -> {
|
||||
try {
|
||||
PhysicalPlan plan = new NereidsPlanner(createStatementCtx(sql)).plan(
|
||||
new NereidsParser().parseSingle(sql),
|
||||
PhysicalProperties.ANY
|
||||
);
|
||||
System.out.println(plan.treeString());
|
||||
} catch (AnalysisException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
new NereidsPlanner(createStatementCtx(sql)).plan(
|
||||
new NereidsParser().parseSingle(sql),
|
||||
PhysicalProperties.ANY
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -32,4 +32,19 @@ suite("explain") {
|
||||
contains "project output tuple id: 1"
|
||||
}
|
||||
|
||||
|
||||
explain {
|
||||
sql("physical plan select 100")
|
||||
contains "PhysicalOneRowRelation"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql("logical plan select 100")
|
||||
contains "LogicalOneRowRelation"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql("parsed plan select 100")
|
||||
contains "UnboundOneRowRelation"
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user