diff --git a/be/src/vec/exprs/lambda_function/varray_map_function.cpp b/be/src/vec/exprs/lambda_function/varray_map_function.cpp index 67bd6bf4aa..d32bb094f8 100644 --- a/be/src/vec/exprs/lambda_function/varray_map_function.cpp +++ b/be/src/vec/exprs/lambda_function/varray_map_function.cpp @@ -81,7 +81,7 @@ public: // offset column MutableColumnPtr array_column_offset; int nested_array_column_rows = 0; - const ColumnArray::Offsets64* array_offsets = nullptr; + ColumnPtr first_array_offsets = nullptr; //2. get the result column from executed expr, and the needed is nested column of array Block lambda_block; for (int i = 0; i < arguments.size(); ++i) { @@ -113,17 +113,20 @@ public: if (i == 0) { nested_array_column_rows = col_array.get_data_ptr()->size(); - array_offsets = &col_array.get_offsets(); + first_array_offsets = col_array.get_offsets_ptr(); auto& off_data = assert_cast( col_array.get_offsets_column()); array_column_offset = off_data.clone_resized(col_array.get_offsets_column().size()); } else { // select array_map((x,y)->x+y,c_array1,[0,1,2,3]) from array_test2; // c_array1: [0,1,2,3,4,5,6,7,8,9] + auto& array_offsets = + assert_cast(*first_array_offsets) + .get_data(); if (nested_array_column_rows != col_array.get_data_ptr()->size() || - (array_offsets->size() > 0 && - memcmp(array_offsets->data(), col_array.get_offsets().data(), - sizeof((*array_offsets)[0]) * array_offsets->size()) != 0)) { + (array_offsets.size() > 0 && + memcmp(array_offsets.data(), col_array.get_offsets().data(), + sizeof(array_offsets[0]) * array_offsets.size()) != 0)) { return Status::InternalError( "in array map function, the input column size " "are " diff --git a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 index b19f6123ed..d52473cd0c 100644 --- a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 +++ b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 @@ -408,6 +408,15 @@ namedExpressionSeq expression : booleanExpression + | lambdaExpression + ; + +lambdaExpression + : args+=errorCapturingIdentifier ARROW body=booleanExpression + | LEFT_PAREN + args+=errorCapturingIdentifier (COMMA args+=errorCapturingIdentifier)+ + RIGHT_PAREN + ARROW body=booleanExpression ; booleanExpression diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionCallExpr.java index 4c9455d4e0..2abee78581 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionCallExpr.java @@ -61,6 +61,11 @@ public class LambdaFunctionCallExpr extends FunctionCallExpr { super(other); } + // nereids high order function call expr constructor without finalize/analyze + public LambdaFunctionCallExpr(Function function, FunctionParams functionParams) { + super(function, functionParams, null, false, functionParams.exprs()); + } + @Override public Expr clone() { return new LambdaFunctionCallExpr(this); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionExpr.java index e2e7b90bfb..434893d554 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionExpr.java @@ -53,6 +53,16 @@ public class LambdaFunctionExpr extends Expr { this.setType(Type.LAMBDA_FUNCTION); } + // for Nereids + public LambdaFunctionExpr(Expr lambdaBody, List argNames, List slotExpr) { + this.slotExpr.add(lambdaBody); + this.slotExpr.addAll(slotExpr); + this.names.addAll(argNames); + this.params.addAll(slotExpr); + this.children.add(lambdaBody); + this.setType(Type.LAMBDA_FUNCTION); + } + public LambdaFunctionExpr(LambdaFunctionExpr rhs) { super(rhs); this.names.addAll(rhs.names); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java index 90789100cd..d65f84d876 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java @@ -35,6 +35,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayEnumerat import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExcept; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayIntersect; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayJoin; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMax; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMin; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayPopBack; @@ -392,6 +393,7 @@ public class BuiltinScalarFunctions implements FunctionHelper { scalar(ArrayExcept.class, "array_except"), scalar(ArrayIntersect.class, "array_intersect"), scalar(ArrayJoin.class, "array_join"), + scalar(ArrayMap.class, "array_map"), scalar(ArrayMax.class, "array_max"), scalar(ArrayMin.class, "array_min"), scalar(ArrayPopBack.class, "array_popback"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundOneRowRelation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundOneRowRelation.java index dd04e9625a..d2574c6ff2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundOneRowRelation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundOneRowRelation.java @@ -32,7 +32,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.util.Utils; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import java.util.List; @@ -55,8 +54,6 @@ public class UnboundOneRowRelation extends LogicalRelation implements Unbound, O Optional groupExpression, Optional logicalProperties) { super(id, PlanType.LOGICAL_UNBOUND_ONE_ROW_RELATION, groupExpression, logicalProperties); - Preconditions.checkArgument(projects.stream().noneMatch(p -> p.containsType(Slot.class)), - "OneRowRelation can not contains any slot"); this.projects = ImmutableList.copyOf(projects); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java index 0244a80c2a..d6e9911145 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java @@ -25,6 +25,7 @@ import org.apache.doris.analysis.BoolLiteral; import org.apache.doris.analysis.CaseExpr; import org.apache.doris.analysis.CaseWhenClause; import org.apache.doris.analysis.CastExpr; +import org.apache.doris.analysis.ColumnRefExpr; import org.apache.doris.analysis.CompoundPredicate; import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.FunctionCallExpr; @@ -33,12 +34,15 @@ import org.apache.doris.analysis.FunctionParams; import org.apache.doris.analysis.IndexDef; import org.apache.doris.analysis.InvertedIndexUtil; import org.apache.doris.analysis.IsNullPredicate; +import org.apache.doris.analysis.LambdaFunctionCallExpr; +import org.apache.doris.analysis.LambdaFunctionExpr; import org.apache.doris.analysis.MatchPredicate; import org.apache.doris.analysis.OrderByElement; import org.apache.doris.analysis.SlotDescriptor; import org.apache.doris.analysis.SlotRef; import org.apache.doris.analysis.TimestampArithmeticExpr; import org.apache.doris.analysis.TupleDescriptor; +import org.apache.doris.catalog.ArrayType; import org.apache.doris.catalog.Function; import org.apache.doris.catalog.Function.NullableMode; import org.apache.doris.catalog.Index; @@ -48,6 +52,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.AggregateExpression; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.ArrayItemReference; import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; import org.apache.doris.nereids.trees.expressions.CaseWhen; @@ -82,6 +87,7 @@ import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeComb import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator; import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator; import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda; import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction; import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf; import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdf; @@ -262,6 +268,16 @@ public class ExpressionTranslator extends DefaultExpressionVisitor arguments = lambda.getLambdaArguments().stream().map(e -> e.accept(this, context)) + .collect(Collectors.toList()); + return new LambdaFunctionExpr(func, lambda.getLambdaArgumentNames(), arguments); + } + + private Expr visitHighOrderFunction(ScalarFunction function, PlanTranslatorContext context) { + Lambda lambda = (Lambda) function.child(0); + List arguments = new ArrayList<>(function.children().size()); + arguments.add(null); + int columnId = 0; + for (ArrayItemReference arrayItemReference : lambda.getLambdaArguments()) { + String argName = arrayItemReference.getName(); + Expr expr = arrayItemReference.getArrayExpression().accept(this, context); + arguments.add(expr); + + ColumnRefExpr column = new ColumnRefExpr(); + column.setName(argName); + column.setColumnId(columnId); + column.setNullable(true); + column.setType(((ArrayType) expr.getType()).getItemType()); + context.addExprIdColumnRefPair(arrayItemReference.getExprId(), column); + columnId += 1; + } + + List argTypes = function.getArguments().stream() + .map(Expression::getDataType) + .map(DataType::toCatalogDataType) + .collect(Collectors.toList()); + lambda.getLambdaArguments().stream() + .map(ArrayItemReference::getArrayExpression) + .map(Expression::getDataType) + .map(DataType::toCatalogDataType) + .forEach(argTypes::add); + org.apache.doris.catalog.Function catalogFunction = new Function( + new FunctionName(function.getName()), argTypes, + ArrayType.create(lambda.getRetType().toCatalogDataType(), true), + true, true, NullableMode.DEPEND_ON_ARGUMENT); + + // create catalog FunctionCallExpr without analyze again + Expr lambdaBody = visitLambda(lambda, context); + arguments.set(0, lambdaBody); + return new LambdaFunctionCallExpr(catalogFunction, new FunctionParams(false, arguments)); + } + @Override public Expr visitScalarFunction(ScalarFunction function, PlanTranslatorContext context) { + if (function.isHighOrder()) { + return visitHighOrderFunction(function, context); + } + List arguments = function.getArguments().stream() .map(arg -> arg.accept(this, context)) .collect(Collectors.toList()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java index b04b3d0901..721eea37b7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.glue.translator; +import org.apache.doris.analysis.ColumnRefExpr; import org.apache.doris.analysis.DescriptorTable; import org.apache.doris.analysis.SlotDescriptor; import org.apache.doris.analysis.SlotId; @@ -77,6 +78,13 @@ public class PlanTranslatorContext { */ private final Map slotIdToExprId = Maps.newHashMap(); + /** + * For each lambda argument (ArrayItemReference), + * we create a ColumnRef representing it and + * then translate it based on the ExprId of the ArrayItemReference. + */ + private final Map exprIdToColumnRef = Maps.newHashMap(); + private final List scanNodes = Lists.newArrayList(); private final IdGenerator fragmentIdGenerator = PlanFragmentId.createGenerator(); @@ -187,6 +195,10 @@ public class PlanTranslatorContext { slotIdToExprId.put(slotRef.getDesc().getId(), exprId); } + public void addExprIdColumnRefPair(ExprId exprId, ColumnRefExpr columnRefExpr) { + exprIdToColumnRef.put(exprId, columnRefExpr); + } + public void mergePlanFragment(PlanFragment srcFragment, PlanFragment targetFragment) { srcFragment.getTargetRuntimeFilterIds().forEach(targetFragment::setTargetRuntimeFilterIds); srcFragment.getBuilderRuntimeFilterIds().forEach(targetFragment::setBuilderRuntimeFilterIds); @@ -197,6 +209,10 @@ public class PlanTranslatorContext { return exprIdToSlotRef.get(exprId); } + public ColumnRefExpr findColumnRef(ExprId exprId) { + return exprIdToColumnRef.get(exprId); + } + public void addScanNode(ScanNode scanNode) { scanNodes.add(scanNode); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java index 5a50c25523..23675cc56b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java @@ -86,6 +86,7 @@ import org.apache.doris.nereids.DorisParser.Is_not_null_predContext; import org.apache.doris.nereids.DorisParser.IsnullContext; import org.apache.doris.nereids.DorisParser.JoinCriteriaContext; import org.apache.doris.nereids.DorisParser.JoinRelationContext; +import org.apache.doris.nereids.DorisParser.LambdaExpressionContext; import org.apache.doris.nereids.DorisParser.LateralViewContext; import org.apache.doris.nereids.DorisParser.LessThanPartitionDefContext; import org.apache.doris.nereids.DorisParser.LimitClauseContext; @@ -230,6 +231,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.HourFloor; import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursAdd; import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursDiff; import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursSub; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinuteCeil; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinuteFloor; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinutesAdd; @@ -949,6 +951,15 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { }); } + @Override + public Expression visitLambdaExpression(LambdaExpressionContext ctx) { + ImmutableList args = ctx.args.stream() + .map(RuleContext::getText) + .collect(ImmutableList.toImmutableList()); + Expression body = (Expression) visit(ctx.body); + return new Lambda(args, body); + } + private Expression expressionCombiner(Expression left, Expression right, LogicalBinaryContext ctx) { switch (ctx.operator.getType()) { case DorisParser.LOGICALAND: diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index fae733edce..1dc24fe476 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -47,6 +47,7 @@ import org.apache.doris.nereids.trees.expressions.functions.Function; import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda; import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.trees.plans.JoinType; @@ -326,7 +327,7 @@ public class BindExpression implements AnalysisRuleFactory { return e; }) .collect(Collectors.toList()); - groupBy.forEach(expression -> checkBound(expression, ctx.root)); + groupBy.forEach(expression -> checkBoundExceptLambda(expression, ctx.root)); groupBy = groupBy.stream() .map(expr -> bindFunction(expr, ctx.root, ctx.cascadesContext)) .collect(ImmutableList.toImmutableList()); @@ -649,7 +650,7 @@ public class BindExpression implements AnalysisRuleFactory { @SuppressWarnings("unchecked") private E bindFunction(E expr, Plan plan, CascadesContext cascadesContext) { - return (E) FunctionBinder.INSTANCE.rewrite(checkBound(expr, plan), + return (E) FunctionBinder.INSTANCE.rewrite(checkBoundExceptLambda(expr, plan), new ExpressionRewriteContext(cascadesContext)); } @@ -752,20 +753,22 @@ public class BindExpression implements AnalysisRuleFactory { } } - private E checkBound(E expression, Plan plan) { - expression.foreachUp(e -> { - if (e instanceof UnboundSlot) { - UnboundSlot unboundSlot = (UnboundSlot) e; - String tableName = StringUtils.join(unboundSlot.getQualifier(), "."); - if (tableName.isEmpty()) { - tableName = "table list"; - } - throw new AnalysisException("Unknown column '" - + unboundSlot.getNameParts().get(unboundSlot.getNameParts().size() - 1) - + "' in '" + tableName + "' in " - + plan.getType().toString().substring("LOGICAL_".length()) + " clause"); + private E checkBoundExceptLambda(E expression, Plan plan) { + if (expression instanceof Lambda) { + return expression; + } + if (expression instanceof UnboundSlot) { + UnboundSlot unboundSlot = (UnboundSlot) expression; + String tableName = StringUtils.join(unboundSlot.getQualifier(), "."); + if (tableName.isEmpty()) { + tableName = "table list"; } - }); + throw new AnalysisException("Unknown column '" + + unboundSlot.getNameParts().get(unboundSlot.getNameParts().size() - 1) + + "' in '" + tableName + "' in " + + plan.getType().toString().substring("LOGICAL_".length()) + " clause"); + } + expression.children().forEach(e -> checkBoundExceptLambda(e, plan)); return expression; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java index 321dbe9f76..d32823799e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java @@ -21,7 +21,6 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction; @@ -79,7 +78,6 @@ public class CheckAnalysis implements AnalysisRuleFactory { .put(LogicalOneRowRelation.class, ImmutableSet.of( AggregateFunction.class, GroupingScalarFunction.class, - SlotReference.class, TableGeneratingFunction.class, WindowExpression.class)) .put(LogicalProject.class, ImmutableSet.of( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SlotBinder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SlotBinder.java index 303254868c..e8af1a832e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SlotBinder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SlotBinder.java @@ -52,7 +52,10 @@ import java.util.List; import java.util.Optional; import java.util.stream.Collectors; -class SlotBinder extends SubExprAnalyzer { +/** + * SlotBinder is used to bind slot + */ +public class SlotBinder extends SubExprAnalyzer { /* bounded={table.a, a} unbound=a diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FunctionBinder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FunctionBinder.java index cd9c16dd87..6488e5e76a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FunctionBinder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FunctionBinder.java @@ -20,11 +20,15 @@ package org.apache.doris.nereids.rules.expression.rules; import org.apache.doris.analysis.ArithmeticExpr.Operator; import org.apache.doris.catalog.Env; import org.apache.doris.catalog.FunctionRegistry; +import org.apache.doris.nereids.analyzer.Scope; import org.apache.doris.nereids.analyzer.UnboundFunction; +import org.apache.doris.nereids.analyzer.UnboundSlot; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.rules.analysis.ArithmeticFunctionBinder; +import org.apache.doris.nereids.rules.analysis.SlotBinder; import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.ArrayItemReference; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; import org.apache.doris.nereids.trees.expressions.BitNot; import org.apache.doris.nereids.trees.expressions.CaseWhen; @@ -40,10 +44,12 @@ import org.apache.doris.nereids.trees.expressions.IntegralDivide; import org.apache.doris.nereids.trees.expressions.ListQuery; import org.apache.doris.nereids.trees.expressions.Match; import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda; import org.apache.doris.nereids.trees.expressions.functions.udf.AliasUdfBuilder; import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.types.ArrayType; @@ -54,6 +60,7 @@ import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.collect.ImmutableList; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.stream.Collectors; @@ -82,11 +89,56 @@ public class FunctionBinder extends AbstractExpressionRewriteRule { /* ******************************************************************************************** * bind function * ******************************************************************************************** */ + private void checkBoundLambda(Expression lambdaFunction, List argumentNames) { + lambdaFunction.foreachUp(e -> { + if (e instanceof UnboundSlot) { + UnboundSlot unboundSlot = (UnboundSlot) e; + throw new AnalysisException("Unknown lambda slot '" + + unboundSlot.getNameParts().get(unboundSlot.getNameParts().size() - 1) + + " in lambda arguments" + argumentNames); + } + }); + } + + private UnboundFunction bindHighOrderFunction(UnboundFunction unboundFunction, ExpressionRewriteContext context) { + int childrenSize = unboundFunction.children().size(); + List subChildren = new ArrayList<>(); + for (int i = 1; i < childrenSize; i++) { + subChildren.add(unboundFunction.child(i).accept(this, context)); + } + + // bindLambdaFunction + Lambda lambda = (Lambda) unboundFunction.children().get(0); + Expression lambdaFunction = lambda.getLambdaFunction(); + List arrayItemReferences = lambda.makeArguments(subChildren); + + // 1.bindSlot + List boundedSlots = arrayItemReferences.stream() + .map(ArrayItemReference::toSlot) + .collect(ImmutableList.toImmutableList()); + lambdaFunction = new SlotBinder(new Scope(boundedSlots), context.cascadesContext, + true, false).bind(lambdaFunction); + checkBoundLambda(lambdaFunction, lambda.getLambdaArgumentNames()); + + // 2.bindFunction + lambdaFunction = lambdaFunction.accept(this, context); + + Lambda lambdaClosure = lambda.withLambdaFunctionArguments(lambdaFunction, arrayItemReferences); + + // We don't add the ArrayExpression in high order function at all + return unboundFunction.withChildren(ImmutableList.builder() + .add(lambdaClosure) + .build()); + } @Override public Expression visitUnboundFunction(UnboundFunction unboundFunction, ExpressionRewriteContext context) { - unboundFunction = unboundFunction.withChildren(unboundFunction.children().stream() - .map(e -> e.accept(this, context)).collect(Collectors.toList())); + if (unboundFunction.isHighOrder()) { + unboundFunction = bindHighOrderFunction(unboundFunction, context); + } else { + unboundFunction = unboundFunction.withChildren(unboundFunction.children().stream() + .map(e -> e.accept(this, context)).collect(Collectors.toList())); + } // bind function FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ArrayItemReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ArrayItemReference.java new file mode 100644 index 0000000000..afebaa3016 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ArrayItemReference.java @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions; + +import org.apache.doris.nereids.trees.expressions.typecoercion.ExpectsInputTypes; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.coercion.AnyDataType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +/** + * it is item from array, which used in lambda function + */ +public class ArrayItemReference extends NamedExpression implements ExpectsInputTypes { + + protected final ExprId exprId; + protected final String name; + + /** ArrayItemReference */ + public ArrayItemReference(String name, Expression arrayExpression) { + this(StatementScopeIdGenerator.newExprId(), name, arrayExpression); + } + + public ArrayItemReference(ExprId exprId, String name, Expression arrayExpression) { + super(ImmutableList.of(arrayExpression)); + Preconditions.checkArgument(arrayExpression.getDataType() instanceof ArrayType, + String.format("ArrayItemReference' child %s must return array", child(0))); + this.exprId = exprId; + this.name = name; + } + + public Expression getArrayExpression() { + return children.get(0); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitArrayItemReference(this, context); + } + + @Override + public String getName() { + return name; + } + + @Override + public ExprId getExprId() { + return exprId; + } + + @Override + public List getQualifier() { + return ImmutableList.of(name); + } + + @Override + public boolean nullable() { + return ((ArrayType) (this.children.get(0).getDataType())).containsNull(); + } + + @Override + public ArrayItemReference withChildren(List expressions) { + return new ArrayItemReference(exprId, name, expressions.get(0)); + } + + @Override + public DataType getDataType() { + return ((ArrayType) (this.children.get(0).getDataType())).getItemType(); + } + + @Override + public String toSql() { + return child(0).toSql(); + } + + @Override + public Slot toSlot() { + return new ArrayItemSlot(exprId, name, getDataType(), nullable()); + } + + @Override + public String toString() { + String str = getName() + "#" + getExprId(); + str += " of " + child(0).toString(); + return str; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrayItemReference that = (ArrayItemReference) o; + return exprId.equals(that.exprId); + } + + @Override + public int hashCode() { + return Objects.hash(exprId); + } + + @Override + public List expectedInputTypes() { + return ImmutableList.of(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX)); + } + + static class ArrayItemSlot extends SlotReference implements SlotNotFromChildren { + /** + * Constructor for SlotReference. + * + * @param exprId UUID for this slot reference + * @param name slot reference name + * @param dataType slot reference logical data type + * @param nullable true if nullable + */ + public ArrayItemSlot(ExprId exprId, String name, DataType dataType, boolean nullable) { + super(exprId, name, dataType, nullable, ImmutableList.of(), null); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitArrayItemSlot(this, context); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Function.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Function.java index 71d4519275..d1d23c192b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Function.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Function.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions.functions; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda; import java.util.List; @@ -32,4 +33,8 @@ public abstract class Function extends Expression { public Function(List children) { super(children); } + + public boolean isHighOrder() { + return !children.isEmpty() && children.get(0) instanceof Lambda; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayMap.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayMap.java new file mode 100644 index 0000000000..f152b7d386 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayMap.java @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.LambdaType; +import org.apache.doris.nereids.types.coercion.FollowToAnyDataType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * ScalarFunction 'array_map'. + */ +public class ArrayMap extends ScalarFunction + implements ExplicitlyCastableSignature, PropagateNullable { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(new FollowToAnyDataType(0)).args(LambdaType.INSTANCE) + ); + + /** + * constructor with arguments. + */ + public ArrayMap(Expression... arg) { + super("array_map", arg); + } + + public ArrayMap(List arg) { + super("array_map", arg); + } + + /** + * withChildren. + */ + @Override + public ArrayMap withChildren(List children) { + return new ArrayMap(children); + } + + @Override + public DataType getDataType() { + Preconditions.checkArgument(children.get(0) instanceof Lambda, + "The first arg of array_map must be lambda"); + return ArrayType.of(((Lambda) children.get(0)).getRetType(), true); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitArrayMap(this, context); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Lambda.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Lambda.java new file mode 100644 index 0000000000..e4df28d22d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Lambda.java @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.trees.expressions.ArrayItemReference; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.LambdaType; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * Lambda includes lambda arguments and function body + * Before bind, x -> x : arguments("x") -> children: Expression(x) + * After bind, x -> x : arguments("x") -> children: Expression(x) ArrayItemReference(x) + */ +public class Lambda extends Expression { + + private final List argumentNames; + + /** + * constructor + */ + public Lambda(List argumentNames, Expression lambdaFunction) { + this(argumentNames, ImmutableList.of(lambdaFunction)); + } + + public Lambda(List argumentNames, Expression lambdaFunction, List arguments) { + this(argumentNames, ImmutableList.builder().add(lambdaFunction).addAll(arguments).build()); + } + + public Lambda(List argumentNames, List children) { + super(children); + this.argumentNames = ImmutableList.copyOf(Objects.requireNonNull( + argumentNames, "argumentNames should not null")); + } + + /** + * make slot according array expression + * @param arrays array expression + * @return item slots of array expression + */ + public ImmutableList makeArguments(List arrays) { + Builder builder = new ImmutableList.Builder<>(); + if (arrays.size() != argumentNames.size()) { + throw new AnalysisException(String.format("lambda %s arguments' size is not equal parameters' size", + toSql())); + } + for (int i = 0; i < arrays.size(); i++) { + Expression array = arrays.get(i); + if (!(array.getDataType() instanceof ArrayType)) { + throw new AnalysisException(String.format("lambda argument must be array but is %s", array)); + } + String name = argumentNames.get(i); + builder.add(new ArrayItemReference(name, array)); + } + return builder.build(); + } + + public String getLambdaArgumentName(int i) { + return argumentNames.get(i); + } + + public ArrayItemReference getLambdaArgument(int i) { + return (ArrayItemReference) children.get(i + 1); + } + + public List getLambdaArguments() { + return children.stream() + .skip(1) + .map(e -> (ArrayItemReference) e) + .collect(Collectors.toList()); + } + + public List getLambdaArgumentNames() { + return argumentNames; + } + + public Expression getLambdaFunction() { + return child(0); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitLambda(this, context); + } + + public Lambda withLambdaFunctionArguments(Expression lambdaFunction, List arguments) { + return new Lambda(argumentNames, lambdaFunction, arguments); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Lambda that = (Lambda) o; + return argumentNames.equals(that.argumentNames) + && Objects.equals(children(), that.children()); + } + + @Override + public String toSql() { + StringBuilder builder = new StringBuilder(); + String argStr = argumentNames.stream().collect(Collectors.joining(", ", "(", ")")); + builder.append(String.format("%s -> %s", argStr, getLambdaFunction().toString())); + for (int i = 1; i < getArguments().size(); i++) { + builder.append(", ").append(getArgument(i).toSql()); + } + return builder.toString(); + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + String argStr = argumentNames.stream().collect(Collectors.joining(", ", "(", ")")); + builder.append(String.format("%s -> %s", argStr, getLambdaFunction().toString())); + for (int i = 1; i < getArguments().size(); i++) { + builder.append(", ").append(getArgument(i).toString()); + } + return builder.toString(); + } + + @Override + public Lambda withChildren(List children) { + return new Lambda(argumentNames, children); + } + + @Override + public boolean nullable() { + return getLambdaArguments().stream().anyMatch(ArrayItemReference::nullable); + } + + @Override + public DataType getDataType() { + return new LambdaType(); + } + + public DataType getRetType() { + return getLambdaFunction().getDataType(); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/ArrayLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/ArrayLiteral.java index 5fc4ded2cb..daee005522 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/ArrayLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/ArrayLiteral.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.NullType; import com.google.common.collect.ImmutableList; @@ -33,9 +34,19 @@ public class ArrayLiteral extends Literal { private final List items; + /** + * construct array literal + */ public ArrayLiteral(List items) { super(computeDataType(items)); - this.items = ImmutableList.copyOf(items); + this.items = items.stream() + .map(i -> { + if (i instanceof NullLiteral) { + DataType type = ((ArrayType) (this.getDataType())).getItemType(); + return new NullLiteral(type); + } + return i; + }).collect(ImmutableList.toImmutableList()); } @Override @@ -64,7 +75,7 @@ public class ArrayLiteral extends Literal { String items = this.items.stream() .map(Literal::toString) .collect(Collectors.joining(", ")); - return "array(" + items + ")"; + return "[" + items + "]"; } @Override @@ -72,7 +83,7 @@ public class ArrayLiteral extends Literal { String items = this.items.stream() .map(Literal::toSql) .collect(Collectors.joining(", ")); - return "array(" + items + ")"; + return "[" + items + "]"; } @Override @@ -84,6 +95,12 @@ public class ArrayLiteral extends Literal { if (items.isEmpty()) { return ArrayType.SYSTEM_DEFAULT; } - return ArrayType.of(items.get(0).dataType); + DataType dataType = NullType.INSTANCE; + for (Literal item : items) { + if (!item.dataType.isNullType()) { + dataType = item.dataType; + } + } + return ArrayType.of(dataType); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/NullLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/NullLiteral.java index 6dd8492ba1..04da582728 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/NullLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/NullLiteral.java @@ -56,7 +56,7 @@ public class NullLiteral extends Literal { @Override public LiteralExpr toLegacyLiteral() { - return new org.apache.doris.analysis.NullLiteral(); + return org.apache.doris.analysis.NullLiteral.create(dataType.toCatalogDataType()); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java index a9aa0047ba..0f3bbf3b31 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java @@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.AggregateExpression; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.ArrayItemReference; import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; import org.apache.doris.nereids.trees.expressions.BinaryOperator; @@ -82,6 +83,7 @@ import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction; import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda; import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction; import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction; import org.apache.doris.nereids.trees.expressions.functions.window.WindowFunction; @@ -123,6 +125,10 @@ public abstract class ExpressionVisitor return visitBoundFunction(aggregateFunction, context); } + public R visitLambda(Lambda lambda, C context) { + return visit(lambda, context); + } + @Override public R visitScalarFunction(ScalarFunction scalarFunction, C context) { return visitBoundFunction(scalarFunction, context); @@ -211,6 +217,10 @@ public abstract class ExpressionVisitor return visitSlot(slotReference, context); } + public R visitArrayItemSlot(SlotReference arrayItemSlot, C context) { + return visit(arrayItemSlot, context); + } + public R visitMarkJoinReference(MarkJoinSlotReference markJoinSlotReference, C context) { return visitSlotReference(markJoinSlotReference, context); } @@ -411,6 +421,10 @@ public abstract class ExpressionVisitor return visit(virtualSlotReference, context); } + public R visitArrayItemReference(ArrayItemReference arrayItemReference, C context) { + return visit(arrayItemReference, context); + } + public R visitVariableDesc(VariableDesc variableDesc, C context) { return visit(variableDesc, context); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java index 51a86474af..7f3d50322b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java @@ -39,6 +39,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayEnumerat import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExcept; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayIntersect; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayJoin; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMax; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMin; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayPopBack; @@ -477,6 +478,10 @@ public interface ScalarFunctionVisitor { return visitScalarFunction(arraySort, context); } + default R visitArrayMap(ArrayMap arraySort, C context) { + return visitScalarFunction(arraySort, context); + } + default R visitArraySum(ArraySum arraySum, C context) { return visitScalarFunction(arraySum, context); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOneRowRelation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOneRowRelation.java index 8a03ae0882..68351c9a90 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOneRowRelation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOneRowRelation.java @@ -52,8 +52,6 @@ public class LogicalOneRowRelation extends LogicalRelation implements OneRowRela private LogicalOneRowRelation(RelationId relationId, List projects, Optional groupExpression, Optional logicalProperties) { super(relationId, PlanType.LOGICAL_ONE_ROW_RELATION, groupExpression, logicalProperties); - Preconditions.checkArgument(projects.stream().noneMatch(p -> p.containsType(Slot.class)), - "OneRowRelation can not contains any slot"); Preconditions.checkArgument(projects.stream().noneMatch(p -> p.containsType(AggregateFunction.class)), "OneRowRelation can not contains any aggregate function"); this.projects = ImmutableList.copyOf(Objects.requireNonNull(projects, "projects can not be null")); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LambdaType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LambdaType.java new file mode 100644 index 0000000000..fb461bef62 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LambdaType.java @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.types; + +import org.apache.doris.catalog.Type; + +/** + * Func type in Nereids. + */ +public class LambdaType extends DataType { + public static final LambdaType INSTANCE = new LambdaType(); + + public LambdaType() {} + + @Override + public Type toCatalogDataType() { + return org.apache.doris.catalog.Type.LAMBDA_FUNCTION; + } + + @Override + public String toSql() { + return "Lambda"; + } + + @Override + public int width() { + return 0; + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ArrayContainsToArrayOverlapTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ArrayContainsToArrayOverlapTest.java index dfee1a7cae..1ef76c1434 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ArrayContainsToArrayOverlapTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ArrayContainsToArrayOverlapTest.java @@ -41,8 +41,8 @@ class ArrayContainsToArrayOverlapTest extends ExpressionRewriteTestHelper { .getPlan(); Expression expression = plan.child(0).getExpressions().get(0).child(0); Assertions.assertTrue(expression instanceof ArraysOverlap); - Assertions.assertEquals("array(1)", expression.child(0).toSql()); - Assertions.assertEquals("array(1, 2, 3)", expression.child(1).toSql()); + Assertions.assertEquals("[1]", expression.child(0).toSql()); + Assertions.assertEquals("[1, 2, 3]", expression.child(1).toSql()); } @Test @@ -92,8 +92,8 @@ class ArrayContainsToArrayOverlapTest extends ExpressionRewriteTestHelper { .rewrite() .getPlan(); Expression expression = plan.child(0).getExpressions().get(0).child(0); - Assertions.assertEquals("(array_contains(array(1), 0) OR " - + "(array_contains(array(1), 1) AND arrays_overlap(array(1), array(2, 3, 4))))", + Assertions.assertEquals("(array_contains([1], 0) OR " + + "(array_contains([1], 1) AND arrays_overlap([1], [2, 3, 4])))", expression.toSql()); } } diff --git a/regression-test/data/nereids_syntax_p0/array_function.out b/regression-test/data/nereids_syntax_p0/array_function.out new file mode 100644 index 0000000000..6c94ead374 --- /dev/null +++ b/regression-test/data/nereids_syntax_p0/array_function.out @@ -0,0 +1,19 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !0 -- +[NULL, NULL, NULL, NULL] + +-- !1 -- +[NULL, NULL, NULL] + +-- !2 -- +[88, 34, -48] + +-- !3 -- +[57, NULL, NULL] + +-- !4 -- +[-81.31, -71.18, 36.59, -66.13] + +-- !5 -- +[-35.47, 67.31, 98.60, -92.89] + diff --git a/regression-test/suites/nereids_syntax_p0/array_function.groovy b/regression-test/suites/nereids_syntax_p0/array_function.groovy index bb6e1444d2..8cc857f07d 100644 --- a/regression-test/suites/nereids_syntax_p0/array_function.groovy +++ b/regression-test/suites/nereids_syntax_p0/array_function.groovy @@ -18,7 +18,12 @@ suite("array_function") { sql "SET enable_nereids_planner=true" sql "SET enable_fallback_to_original_planner=false" - + qt_0 """SELECT ARRAY_MAP(x->x+1, ARRAY('crqdt', 'oxpaa', 'xwadf', 'znwln'))""" + qt_1 "SELECT ARRAY_MAP((x,y)->x+y, ARRAY('kdjah', 'ptytj', 'quxhq'), ARRAY('vzhwj', 'bmkrc', 'snaek'))" + qt_2 "SELECT ARRAY_MAP(x->x+1, ARRAY(87, 33, -49))" + qt_3 "SELECT ARRAY_MAP((x,y)->x+y, ARRAY(-41, NULL, -18), ARRAY(98, 47, NULL))" + qt_4 "SELECT ARRAY_MAP(x->x+1, ARRAY(-82.31, -72.18, 35.59, -67.13))" + qt_5 "SELECT ARRAY_MAP((x,y)->x+y, ARRAY(-37.03, 81.89, 56.38, -36.76), ARRAY(1.56, -14.58, 42.22, -56.13))" // test { // sql "select array(), array(null), array(1), array('abc'), array(null, 1), array(1, null)" // result([["[]", "[NULL]", "[1]", "['abc']", "[NULL, 1]", "[1, NULL]"]]) diff --git a/regression-test/suites/nereids_syntax_p0/scripts/gen_array_func.py b/regression-test/suites/nereids_syntax_p0/scripts/gen_array_func.py new file mode 100644 index 0000000000..0fc07cace0 --- /dev/null +++ b/regression-test/suites/nereids_syntax_p0/scripts/gen_array_func.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import random +from typing import List + +DATA_TYPES = {"INT", "DOUBLE", "STR"} + +def generate_lambda_sql(array_func, lambda_func: str, data_type, nullable) -> str: + n = len(lambda_func.split(",")) + sql = "SELECT {array_func}({lambda_func}, {arrays})".format( + array_func=array_func, + lambda_func=lambda_func, + arrays=", ".join(generate_array_values(n, data_type, nullable)) + ) + return sql + +def generate_normal_sql(array_func, data_type, n, nullable) -> str: + sql = "SELECT {array_func}({arrays})".format( + array_func=array_func, + arrays=", ".join(generate_array_values(n, data_type, nullable)) + ) + return sql + +def generate_array_values(n, data_type, nullable) -> List[str]: + array_values = [] + array_len = random.randint(3, 5) + for array in range(n): + array_values.append(generate_array_value(array_len, data_type, nullable)) + return array_values + +def generate_array_value(array_len: int, data_type: str, nullable: bool) -> str: + array_values = [generate_value(data_type, nullable) for _ in range(array_len)] + return "ARRAY({values})".format( + values=", ".join(array_values) + ) + +def generate_value(value_type: str, nullable: bool): + if generate_null(nullable): + return "NULL" + if value_type == "INT": + return str(random.randint(-100, 100)) + elif value_type == "DOUBLE": + return str(round(random.uniform(-100, 100), 2)) + elif value_type == "STR": + return "'{}'".format(generate_random_string(5)) + else: + raise ValueError("Unsupported data type") + +def generate_null(nullable): + return nullable and random.random() > 0.9 + +def generate_random_string(length: int) -> str: + characters = "abcdefghijklmnopqrstuvwxyz" + return ''.join(random.choice(characters) for _ in range(length)) + +lambda_func_set = {"x->x+1", "(x,y)->x+y"} +id = 0 +for dt in DATA_TYPES: + for lambda_func in lambda_func_set: + sql = generate_lambda_sql("ARRAY_MAP", lambda_func, dt, True) + print(f"sql_{id} \"{sql}\"") + id += 1 \ No newline at end of file