[feature](Nereids) add lambda argument and array_map function (#23598)

add array_map function

SELECT ARRAY_MAP(x->x+1, ARRAY(87, 33, -49))
+----------------------------------------------------------------------+
| array_map([x] -> (x + 1), x#1 of array(87, 33, -49))     |
+----------------------------------------------------------------------+
| [88, 34, -48]                                                                 |
+----------------------------------------------------------------------+
This commit is contained in:
谢健
2023-09-13 14:24:16 +08:00
committed by GitHub
parent edd711105a
commit 335064f897
27 changed files with 797 additions and 40 deletions

View File

@ -408,6 +408,15 @@ namedExpressionSeq
// An expression is either an ordinary boolean expression or a lambda expression
// (see the lambdaExpression rule).
expression
: booleanExpression
| lambdaExpression
;
// A lambda expression: argument identifier(s), ARROW, then a boolean expression as body.
// Alternative 1: a single bare identifier, e.g. x -> x + 1.
// Alternative 2: a parenthesized, comma-separated identifier list, e.g. (x, y) -> x + y.
// NOTE(review): alternative 2 requires at least one COMMA, i.e. two or more identifiers,
// so a single parenthesized argument such as (x) -> x is not matched by this rule;
// confirm that this is intended.
lambdaExpression
: args+=errorCapturingIdentifier ARROW body=booleanExpression
| LEFT_PAREN
args+=errorCapturingIdentifier (COMMA args+=errorCapturingIdentifier)+
RIGHT_PAREN
ARROW body=booleanExpression
;
booleanExpression

View File

@ -61,6 +61,11 @@ public class LambdaFunctionCallExpr extends FunctionCallExpr {
super(other);
}
/**
 * Constructor used by Nereids for high order function call expressions; it is meant to
 * build the expression without going through finalize/analyze.
 * It only delegates to the super constructor, forwarding {@code functionParams.exprs()}
 * as the last argument.
 *
 * @param function the function this call expression invokes
 * @param functionParams the call parameters whose expressions are forwarded to the super constructor
 */
public LambdaFunctionCallExpr(Function function, FunctionParams functionParams) {
super(function, functionParams, null, false, functionParams.exprs());
}
@Override
public Expr clone() {
return new LambdaFunctionCallExpr(this);

View File

@ -53,6 +53,16 @@ public class LambdaFunctionExpr extends Expr {
this.setType(Type.LAMBDA_FUNCTION);
}
/**
 * Constructor for Nereids: assembles a lambda function expression from parts that were
 * already translated, without any further analysis.
 *
 * @param lambdaBody the translated lambda body; becomes the only child and the first slot expression
 * @param argNames the lambda argument names
 * @param slotExpr the translated lambda arguments; stored as params and appended after the body
 *                 in the slot expression list
 */
public LambdaFunctionExpr(Expr lambdaBody, List<String> argNames, List<Expr> slotExpr) {
    setType(Type.LAMBDA_FUNCTION);
    children.add(lambdaBody);
    names.addAll(argNames);
    params.addAll(slotExpr);
    // The parameter shadows the field of the same name, so the field needs an explicit "this".
    this.slotExpr.add(lambdaBody);
    this.slotExpr.addAll(slotExpr);
}
public LambdaFunctionExpr(LambdaFunctionExpr rhs) {
super(rhs);
this.names.addAll(rhs.names);

View File

@ -35,6 +35,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayEnumerat
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExcept;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayIntersect;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayJoin;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMax;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMin;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayPopBack;
@ -392,6 +393,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(ArrayExcept.class, "array_except"),
scalar(ArrayIntersect.class, "array_intersect"),
scalar(ArrayJoin.class, "array_join"),
scalar(ArrayMap.class, "array_map"),
scalar(ArrayMax.class, "array_max"),
scalar(ArrayMin.class, "array_min"),
scalar(ArrayPopBack.class, "array_popback"),

View File

@ -32,7 +32,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
@ -55,8 +54,6 @@ public class UnboundOneRowRelation extends LogicalRelation implements Unbound, O
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties) {
super(id, PlanType.LOGICAL_UNBOUND_ONE_ROW_RELATION, groupExpression, logicalProperties);
Preconditions.checkArgument(projects.stream().noneMatch(p -> p.containsType(Slot.class)),
"OneRowRelation can not contains any slot");
this.projects = ImmutableList.copyOf(projects);
}

View File

@ -25,6 +25,7 @@ import org.apache.doris.analysis.BoolLiteral;
import org.apache.doris.analysis.CaseExpr;
import org.apache.doris.analysis.CaseWhenClause;
import org.apache.doris.analysis.CastExpr;
import org.apache.doris.analysis.ColumnRefExpr;
import org.apache.doris.analysis.CompoundPredicate;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
@ -33,12 +34,15 @@ import org.apache.doris.analysis.FunctionParams;
import org.apache.doris.analysis.IndexDef;
import org.apache.doris.analysis.InvertedIndexUtil;
import org.apache.doris.analysis.IsNullPredicate;
import org.apache.doris.analysis.LambdaFunctionCallExpr;
import org.apache.doris.analysis.LambdaFunctionExpr;
import org.apache.doris.analysis.MatchPredicate;
import org.apache.doris.analysis.OrderByElement;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.TimestampArithmeticExpr;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.catalog.ArrayType;
import org.apache.doris.catalog.Function;
import org.apache.doris.catalog.Function.NullableMode;
import org.apache.doris.catalog.Index;
@ -48,6 +52,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.AggregateExpression;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
@ -82,6 +87,7 @@ import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeComb
import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdf;
@ -262,6 +268,16 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
return context.findSlotRef(slotReference.getExprId());
}
/**
 * Translates the slot form of a lambda argument by returning whatever
 * {@code context.findColumnRef} yields for the slot's ExprId.
 * NOTE(review): presumably the column ref has been registered through
 * {@code addExprIdColumnRefPair} before the lambda body is translated; confirm against callers.
 */
@Override
public Expr visitArrayItemSlot(SlotReference slotReference, PlanTranslatorContext context) {
return context.findColumnRef(slotReference.getExprId());
}
/**
 * Translates a lambda argument (an item taken from an array) by returning whatever
 * {@code context.findColumnRef} yields for the reference's ExprId.
 * NOTE(review): presumably the column ref has been registered through
 * {@code addExprIdColumnRefPair} beforehand; confirm against callers.
 */
@Override
public Expr visitArrayItemReference(ArrayItemReference arrayItemReference, PlanTranslatorContext context) {
return context.findColumnRef(arrayItemReference.getExprId());
}
@Override
public Expr visitMarkJoinReference(MarkJoinSlotReference markJoinSlotReference, PlanTranslatorContext context) {
return markJoinSlotReference.isExistsHasAgg()
@ -376,8 +392,59 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
}
/**
 * Translates a Nereids {@link Lambda} into a {@link LambdaFunctionExpr}.
 * The lambda body is translated first, then every lambda argument in declaration order.
 */
@Override
public Expr visitLambda(Lambda lambda, PlanTranslatorContext context) {
    Expr translatedBody = lambda.getLambdaFunction().accept(this, context);
    List<Expr> translatedArgs = new ArrayList<>();
    for (ArrayItemReference itemReference : lambda.getLambdaArguments()) {
        translatedArgs.add(itemReference.accept(this, context));
    }
    return new LambdaFunctionExpr(translatedBody, lambda.getLambdaArgumentNames(), translatedArgs);
}
/**
 * Translates a high order scalar function, i.e. one whose first child is a {@link Lambda},
 * into a {@link LambdaFunctionCallExpr}.
 *
 * <p>For every lambda argument, the array expression it iterates over is translated and
 * appended to the call arguments, and a {@link ColumnRefExpr} describing one item of that
 * array is registered in the context under the argument's ExprId. The lambda itself is
 * translated afterwards, because translating references to the lambda arguments looks those
 * registrations up, and is stored as call argument 0.
 *
 * @param function the high order function; its first child must be a {@link Lambda}
 * @param context the translator context that receives the ExprId to column ref registrations
 * @return the translated call expression, built without running analysis again
 */
private Expr visitHighOrderFunction(ScalarFunction function, PlanTranslatorContext context) {
    Lambda lambda = (Lambda) function.child(0);
    List<ArrayItemReference> lambdaArguments = lambda.getLambdaArguments();

    List<Expr> arguments = new ArrayList<>(lambdaArguments.size() + 1);
    // Position 0 is reserved for the translated lambda, which is filled in at the end.
    arguments.add(null);
    int columnId = 0;
    for (ArrayItemReference arrayItemReference : lambdaArguments) {
        String argName = arrayItemReference.getName();
        Expr expr = arrayItemReference.getArrayExpression().accept(this, context);
        arguments.add(expr);

        ColumnRefExpr column = new ColumnRefExpr();
        column.setName(argName);
        column.setColumnId(columnId);
        column.setNullable(true);
        column.setType(((ArrayType) expr.getType()).getItemType());
        context.addExprIdColumnRefPair(arrayItemReference.getExprId(), column);
        columnId += 1;
    }

    // Collect into an explicit ArrayList: the list is appended to below, and
    // Collectors.toList() gives no guarantee that its result is mutable.
    List<Type> argTypes = function.getArguments().stream()
            .map(Expression::getDataType)
            .map(DataType::toCatalogDataType)
            .collect(Collectors.toCollection(ArrayList::new));
    lambdaArguments.stream()
            .map(ArrayItemReference::getArrayExpression)
            .map(Expression::getDataType)
            .map(DataType::toCatalogDataType)
            .forEach(argTypes::add);

    org.apache.doris.catalog.Function catalogFunction = new Function(
            new FunctionName(function.getName()), argTypes,
            ArrayType.create(lambda.getRetType().toCatalogDataType(), true),
            true, true, NullableMode.DEPEND_ON_ARGUMENT);

    // create catalog FunctionCallExpr without analyze again
    Expr lambdaBody = visitLambda(lambda, context);
    arguments.set(0, lambdaBody);
    return new LambdaFunctionCallExpr(catalogFunction, new FunctionParams(false, arguments));
}
@Override
public Expr visitScalarFunction(ScalarFunction function, PlanTranslatorContext context) {
if (function.isHighOrder()) {
return visitHighOrderFunction(function, context);
}
List<Expr> arguments = function.getArguments().stream()
.map(arg -> arg.accept(this, context))
.collect(Collectors.toList());

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.glue.translator;
import org.apache.doris.analysis.ColumnRefExpr;
import org.apache.doris.analysis.DescriptorTable;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotId;
@ -77,6 +78,13 @@ public class PlanTranslatorContext {
*/
private final Map<SlotId, ExprId> slotIdToExprId = Maps.newHashMap();
/**
* For each lambda argument (ArrayItemReference),
* we create a ColumnRef representing it and
* then translate it based on the ExprId of the ArrayItemReference.
*/
private final Map<ExprId, ColumnRefExpr> exprIdToColumnRef = Maps.newHashMap();
private final List<ScanNode> scanNodes = Lists.newArrayList();
private final IdGenerator<PlanFragmentId> fragmentIdGenerator = PlanFragmentId.createGenerator();
@ -187,6 +195,10 @@ public class PlanTranslatorContext {
slotIdToExprId.put(slotRef.getDesc().getId(), exprId);
}
/**
 * Records the {@link ColumnRefExpr} that represents a lambda argument, keyed by the
 * ExprId of that argument. A later call with the same ExprId replaces the earlier entry.
 */
public void addExprIdColumnRefPair(ExprId exprId, ColumnRefExpr columnRefExpr) {
    this.exprIdToColumnRef.put(exprId, columnRefExpr);
}
public void mergePlanFragment(PlanFragment srcFragment, PlanFragment targetFragment) {
srcFragment.getTargetRuntimeFilterIds().forEach(targetFragment::setTargetRuntimeFilterIds);
srcFragment.getBuilderRuntimeFilterIds().forEach(targetFragment::setBuilderRuntimeFilterIds);
@ -197,6 +209,10 @@ public class PlanTranslatorContext {
return exprIdToSlotRef.get(exprId);
}
/**
 * Looks up the {@link ColumnRefExpr} recorded for the given ExprId.
 *
 * @return the recorded column ref, or {@code null} when none was recorded for this ExprId
 */
public ColumnRefExpr findColumnRef(ExprId exprId) {
    return this.exprIdToColumnRef.get(exprId);
}
public void addScanNode(ScanNode scanNode) {
scanNodes.add(scanNode);
}

View File

@ -86,6 +86,7 @@ import org.apache.doris.nereids.DorisParser.Is_not_null_predContext;
import org.apache.doris.nereids.DorisParser.IsnullContext;
import org.apache.doris.nereids.DorisParser.JoinCriteriaContext;
import org.apache.doris.nereids.DorisParser.JoinRelationContext;
import org.apache.doris.nereids.DorisParser.LambdaExpressionContext;
import org.apache.doris.nereids.DorisParser.LateralViewContext;
import org.apache.doris.nereids.DorisParser.LessThanPartitionDefContext;
import org.apache.doris.nereids.DorisParser.LimitClauseContext;
@ -230,6 +231,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.HourFloor;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursDiff;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursSub;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MinuteCeil;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MinuteFloor;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MinutesAdd;
@ -949,6 +951,15 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
});
}
/**
 * Builds an unbound {@link Lambda} from a parsed lambda expression: the argument names are
 * the texts of the parsed argument identifiers, and the body is the visited body expression.
 */
@Override
public Expression visitLambdaExpression(LambdaExpressionContext ctx) {
    ImmutableList.Builder<String> argumentNames = ImmutableList.builder();
    for (RuleContext argument : ctx.args) {
        argumentNames.add(argument.getText());
    }
    Expression lambdaBody = (Expression) visit(ctx.body);
    return new Lambda(argumentNames.build(), lambdaBody);
}
private Expression expressionCombiner(Expression left, Expression right, LogicalBinaryContext ctx) {
switch (ctx.operator.getType()) {
case DorisParser.LOGICALAND:

View File

@ -47,6 +47,7 @@ import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.JoinType;
@ -326,7 +327,7 @@ public class BindExpression implements AnalysisRuleFactory {
return e;
})
.collect(Collectors.toList());
groupBy.forEach(expression -> checkBound(expression, ctx.root));
groupBy.forEach(expression -> checkBoundExceptLambda(expression, ctx.root));
groupBy = groupBy.stream()
.map(expr -> bindFunction(expr, ctx.root, ctx.cascadesContext))
.collect(ImmutableList.toImmutableList());
@ -649,7 +650,7 @@ public class BindExpression implements AnalysisRuleFactory {
@SuppressWarnings("unchecked")
private <E extends Expression> E bindFunction(E expr, Plan plan, CascadesContext cascadesContext) {
return (E) FunctionBinder.INSTANCE.rewrite(checkBound(expr, plan),
return (E) FunctionBinder.INSTANCE.rewrite(checkBoundExceptLambda(expr, plan),
new ExpressionRewriteContext(cascadesContext));
}
@ -752,20 +753,22 @@ public class BindExpression implements AnalysisRuleFactory {
}
}
private <E extends Expression> E checkBound(E expression, Plan plan) {
expression.foreachUp(e -> {
if (e instanceof UnboundSlot) {
UnboundSlot unboundSlot = (UnboundSlot) e;
String tableName = StringUtils.join(unboundSlot.getQualifier(), ".");
if (tableName.isEmpty()) {
tableName = "table list";
}
throw new AnalysisException("Unknown column '"
+ unboundSlot.getNameParts().get(unboundSlot.getNameParts().size() - 1)
+ "' in '" + tableName + "' in "
+ plan.getType().toString().substring("LOGICAL_".length()) + " clause");
private <E extends Expression> E checkBoundExceptLambda(E expression, Plan plan) {
if (expression instanceof Lambda) {
return expression;
}
if (expression instanceof UnboundSlot) {
UnboundSlot unboundSlot = (UnboundSlot) expression;
String tableName = StringUtils.join(unboundSlot.getQualifier(), ".");
if (tableName.isEmpty()) {
tableName = "table list";
}
});
throw new AnalysisException("Unknown column '"
+ unboundSlot.getNameParts().get(unboundSlot.getNameParts().size() - 1)
+ "' in '" + tableName + "' in "
+ plan.getType().toString().substring("LOGICAL_".length()) + " clause");
}
expression.children().forEach(e -> checkBoundExceptLambda(e, plan));
return expression;
}
}

View File

@ -21,7 +21,6 @@ import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
@ -79,7 +78,6 @@ public class CheckAnalysis implements AnalysisRuleFactory {
.put(LogicalOneRowRelation.class, ImmutableSet.of(
AggregateFunction.class,
GroupingScalarFunction.class,
SlotReference.class,
TableGeneratingFunction.class,
WindowExpression.class))
.put(LogicalProject.class, ImmutableSet.of(

View File

@ -52,7 +52,10 @@ import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
class SlotBinder extends SubExprAnalyzer {
/**
* SlotBinder is used to bind slot
*/
public class SlotBinder extends SubExprAnalyzer {
/*
bounded={table.a, a}
unbound=a

View File

@ -20,11 +20,15 @@ package org.apache.doris.nereids.rules.expression.rules;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.analysis.ArithmeticFunctionBinder;
import org.apache.doris.nereids.rules.analysis.SlotBinder;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.BitNot;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
@ -40,10 +44,12 @@ import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.ListQuery;
import org.apache.doris.nereids.trees.expressions.Match;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.udf.AliasUdfBuilder;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.ArrayType;
@ -54,6 +60,7 @@ import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
@ -82,11 +89,56 @@ public class FunctionBinder extends AbstractExpressionRewriteRule {
/* ********************************************************************************************
* bind function
* ******************************************************************************************** */
/**
 * Verifies that no {@link UnboundSlot} is left anywhere inside a lambda body.
 *
 * @param lambdaFunction the lambda body to check, after slot binding
 * @param argumentNames the declared lambda argument names, reported in the error message
 * @throws AnalysisException if the body still contains an unbound slot
 */
private void checkBoundLambda(Expression lambdaFunction, List<String> argumentNames) {
    lambdaFunction.foreachUp(e -> {
        if (e instanceof UnboundSlot) {
            UnboundSlot unboundSlot = (UnboundSlot) e;
            // Close the quote opened before the slot name so the message reads 'name'.
            throw new AnalysisException("Unknown lambda slot '"
                    + unboundSlot.getNameParts().get(unboundSlot.getNameParts().size() - 1)
                    + "' in lambda arguments" + argumentNames);
        }
    });
}
/**
 * Binds a high order function, i.e. one whose first child is a {@link Lambda}.
 *
 * <p>All children after the lambda are bound first and turned into the lambda's
 * {@link ArrayItemReference} arguments. The lambda body then has its slots bound against
 * those arguments only, is checked for leftover unbound slots, and finally has its
 * functions bound. The returned function keeps a single child: the bound lambda, which
 * carries the array expressions through its arguments.
 */
private UnboundFunction bindHighOrderFunction(UnboundFunction unboundFunction, ExpressionRewriteContext context) {
    // Bind every child that follows the lambda.
    List<Expression> boundArrays = unboundFunction.children().stream()
            .skip(1)
            .map(child -> child.accept(this, context))
            .collect(Collectors.toList());

    Lambda lambda = (Lambda) unboundFunction.children().get(0);
    List<ArrayItemReference> itemReferences = lambda.makeArguments(boundArrays);

    // 1. bind slots: only the lambda arguments are visible inside the body
    List<Slot> lambdaScopeSlots = itemReferences.stream()
            .map(ArrayItemReference::toSlot)
            .collect(ImmutableList.toImmutableList());
    SlotBinder slotBinder = new SlotBinder(new Scope(lambdaScopeSlots),
            context.cascadesContext, true, false);
    Expression lambdaBody = slotBinder.bind(lambda.getLambdaFunction());
    checkBoundLambda(lambdaBody, lambda.getLambdaArgumentNames());

    // 2. bind functions inside the body
    lambdaBody = lambdaBody.accept(this, context);

    // We don't add the ArrayExpression in high order function at all
    Lambda boundLambda = lambda.withLambdaFunctionArguments(lambdaBody, itemReferences);
    return unboundFunction.withChildren(ImmutableList.<Expression>of(boundLambda));
}
@Override
public Expression visitUnboundFunction(UnboundFunction unboundFunction, ExpressionRewriteContext context) {
unboundFunction = unboundFunction.withChildren(unboundFunction.children().stream()
.map(e -> e.accept(this, context)).collect(Collectors.toList()));
if (unboundFunction.isHighOrder()) {
unboundFunction = bindHighOrderFunction(unboundFunction, context);
} else {
unboundFunction = unboundFunction.withChildren(unboundFunction.children().stream()
.map(e -> e.accept(this, context)).collect(Collectors.toList()));
}
// bind function
FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry();

View File

@ -0,0 +1,149 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.trees.expressions.typecoercion.ExpectsInputTypes;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
/**
 * An item taken from an array, used as an argument of a lambda function.
 * Its only child is the array expression the item comes from; its data type and
 * nullability are derived from that array's type.
 */
public class ArrayItemReference extends NamedExpression implements ExpectsInputTypes {

    protected final ExprId exprId;
    protected final String name;

    /**
     * Creates a reference with a freshly generated ExprId.
     *
     * @param name the lambda argument name
     * @param arrayExpression the array this item is taken from; must have an array data type
     */
    public ArrayItemReference(String name, Expression arrayExpression) {
        this(StatementScopeIdGenerator.newExprId(), name, arrayExpression);
    }

    /**
     * Creates a reference with an explicit ExprId.
     *
     * @param exprId the id of this reference
     * @param name the lambda argument name
     * @param arrayExpression the array this item is taken from; must have an array data type
     * @throws IllegalArgumentException if {@code arrayExpression} does not have an array data type
     */
    public ArrayItemReference(ExprId exprId, String name, Expression arrayExpression) {
        super(ImmutableList.of(arrayExpression));
        // Use the template overload so the child is only rendered to text when the check fails.
        Preconditions.checkArgument(arrayExpression.getDataType() instanceof ArrayType,
                "ArrayItemReference' child %s must return array", child(0));
        this.exprId = exprId;
        this.name = name;
    }

    /** Returns the array expression this item is taken from. */
    public Expression getArrayExpression() {
        return children.get(0);
    }

    @Override
    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
        return visitor.visitArrayItemReference(this, context);
    }

    @Override
    public String getName() {
        return name;
    }

    @Override
    public ExprId getExprId() {
        return exprId;
    }

    /** The qualifier consists of the argument name only. */
    @Override
    public List<String> getQualifier() {
        return ImmutableList.of(name);
    }

    /** The item is nullable when the array type reports that it contains nulls. */
    @Override
    public boolean nullable() {
        return ((ArrayType) (this.children.get(0).getDataType())).containsNull();
    }

    /** Rebuilds the reference around a new array expression, keeping ExprId and name. */
    @Override
    public ArrayItemReference withChildren(List<Expression> expressions) {
        return new ArrayItemReference(exprId, name, expressions.get(0));
    }

    /** The data type is the item type of the array. */
    @Override
    public DataType getDataType() {
        return ((ArrayType) (this.children.get(0).getDataType())).getItemType();
    }

    /** Renders as the SQL of the underlying array expression. */
    @Override
    public String toSql() {
        return child(0).toSql();
    }

    /** Returns the slot form of this reference, sharing its ExprId, name, type and nullability. */
    @Override
    public Slot toSlot() {
        return new ArrayItemSlot(exprId, name, getDataType(), nullable());
    }

    @Override
    public String toString() {
        String str = getName() + "#" + getExprId();
        str += " of " + child(0).toString();
        return str;
    }

    /** Two references are equal when they are of the same class and share the same ExprId. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        ArrayItemReference that = (ArrayItemReference) o;
        return exprId.equals(that.exprId);
    }

    @Override
    public int hashCode() {
        return Objects.hash(exprId);
    }

    @Override
    public List<DataType> expectedInputTypes() {
        return ImmutableList.of(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX));
    }

    /**
     * Slot form of an {@link ArrayItemReference}; visitors are dispatched to
     * {@code visitArrayItemSlot} rather than the regular slot reference method.
     */
    static class ArrayItemSlot extends SlotReference implements SlotNotFromChildren {
        /**
         * Constructor for ArrayItemSlot.
         *
         * @param exprId id of this slot, shared with the originating ArrayItemReference
         * @param name slot reference name
         * @param dataType slot reference logical data type
         * @param nullable true if nullable
         */
        public ArrayItemSlot(ExprId exprId, String name, DataType dataType, boolean nullable) {
            super(exprId, name, dataType, nullable, ImmutableList.of(), null);
        }

        @Override
        public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
            return visitor.visitArrayItemSlot(this, context);
        }
    }
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import java.util.List;
@ -32,4 +33,8 @@ public abstract class Function extends Expression {
public Function(List<Expression> children) {
super(children);
}
/**
 * Tells whether this is a high order function, i.e. whether its first child is a {@link Lambda}.
 * A function without children is never high order.
 */
public boolean isHighOrder() {
    if (children.isEmpty()) {
        return false;
    }
    return children.get(0) instanceof Lambda;
}
}

View File

@ -0,0 +1,80 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.LambdaType;
import org.apache.doris.nereids.types.coercion.FollowToAnyDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
 * ScalarFunction 'array_map': a high order function whose first argument is a lambda
 * and whose result is an array of the lambda's return type.
 */
public class ArrayMap extends ScalarFunction
        implements ExplicitlyCastableSignature, PropagateNullable {

    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
            FunctionSignature.ret(new FollowToAnyDataType(0)).args(LambdaType.INSTANCE)
    );

    private static final String NAME = "array_map";

    /**
     * Constructor with arguments.
     *
     * @param arg the children; the first one is expected to be the lambda
     */
    public ArrayMap(Expression... arg) {
        super(NAME, arg);
    }

    /**
     * Constructor with a list of arguments.
     *
     * @param arg the children; the first one is expected to be the lambda
     */
    public ArrayMap(List<Expression> arg) {
        super(NAME, arg);
    }

    /**
     * Creates a copy of this function over the given children.
     */
    @Override
    public ArrayMap withChildren(List<Expression> children) {
        return new ArrayMap(children);
    }

    /**
     * The result type is an array of the lambda's return type, created with the
     * contains-null flag set to true.
     *
     * @throws IllegalArgumentException if the first child is not a {@link Lambda}
     */
    @Override
    public DataType getDataType() {
        Expression firstArgument = children.get(0);
        Preconditions.checkArgument(firstArgument instanceof Lambda,
                "The first arg of array_map must be lambda");
        DataType itemType = ((Lambda) firstArgument).getRetType();
        return ArrayType.of(itemType, true);
    }

    @Override
    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
        return visitor.visitArrayMap(this, context);
    }

    @Override
    public List<FunctionSignature> getSignatures() {
        return SIGNATURES;
    }
}

View File

@ -0,0 +1,168 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.LambdaType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* Lambda includes lambda arguments and function body
* Before bind, x -> x : arguments("x") -> children: Expression(x)
* After bind, x -> x : arguments("x") -> children: Expression(x) ArrayItemReference(x)
*/
public class Lambda extends Expression {
private final List<String> argumentNames;
/**
* constructor
*/
public Lambda(List<String> argumentNames, Expression lambdaFunction) {
this(argumentNames, ImmutableList.of(lambdaFunction));
}
public Lambda(List<String> argumentNames, Expression lambdaFunction, List<ArrayItemReference> arguments) {
this(argumentNames, ImmutableList.<Expression>builder().add(lambdaFunction).addAll(arguments).build());
}
public Lambda(List<String> argumentNames, List<Expression> children) {
super(children);
this.argumentNames = ImmutableList.copyOf(Objects.requireNonNull(
argumentNames, "argumentNames should not null"));
}
/**
* make slot according array expression
* @param arrays array expression
* @return item slots of array expression
*/
public ImmutableList<ArrayItemReference> makeArguments(List<Expression> arrays) {
Builder<ArrayItemReference> builder = new ImmutableList.Builder<>();
if (arrays.size() != argumentNames.size()) {
throw new AnalysisException(String.format("lambda %s arguments' size is not equal parameters' size",
toSql()));
}
for (int i = 0; i < arrays.size(); i++) {
Expression array = arrays.get(i);
if (!(array.getDataType() instanceof ArrayType)) {
throw new AnalysisException(String.format("lambda argument must be array but is %s", array));
}
String name = argumentNames.get(i);
builder.add(new ArrayItemReference(name, array));
}
return builder.build();
}
public String getLambdaArgumentName(int i) {
return argumentNames.get(i);
}
public ArrayItemReference getLambdaArgument(int i) {
return (ArrayItemReference) children.get(i + 1);
}
public List<ArrayItemReference> getLambdaArguments() {
return children.stream()
.skip(1)
.map(e -> (ArrayItemReference) e)
.collect(Collectors.toList());
}
public List<String> getLambdaArgumentNames() {
return argumentNames;
}
public Expression getLambdaFunction() {
return child(0);
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitLambda(this, context);
}
public Lambda withLambdaFunctionArguments(Expression lambdaFunction, List<ArrayItemReference> arguments) {
return new Lambda(argumentNames, lambdaFunction, arguments);
}
/**
 * Two lambdas are equal when they have the same runtime class, equal parameter
 * names and equal children.
 */
@Override
public boolean equals(Object o) {
    if (o == this) {
        return true;
    }
    if (o == null) {
        return false;
    }
    if (getClass() != o.getClass()) {
        return false;
    }
    Lambda other = (Lambda) o;
    if (!argumentNames.equals(other.argumentNames)) {
        return false;
    }
    return Objects.equals(children(), other.children());
}
/**
 * Renders this lambda as SQL text: {@code (a, b) -> body} followed by each bound
 * argument, comma separated.
 *
 * <p>The body is rendered with {@code toSql()} so that every part of the result uses
 * the SQL form; the arguments were already rendered that way.
 *
 * @return the SQL representation of this lambda
 */
@Override
public String toSql() {
    String argStr = argumentNames.stream().collect(Collectors.joining(", ", "(", ")"));
    StringBuilder builder = new StringBuilder();
    builder.append(argStr).append(" -> ").append(getLambdaFunction().toSql());
    // Index 0 is the lambda body, already rendered above.
    for (int i = 1; i < getArguments().size(); i++) {
        builder.append(", ").append(getArgument(i).toSql());
    }
    return builder.toString();
}
/**
 * Renders this lambda for debugging: {@code (a, b) -> body} followed by the
 * {@code toString()} of each bound argument, comma separated.
 */
@Override
public String toString() {
    String parameterList = "(" + String.join(", ", argumentNames) + ")";
    StringBuilder text = new StringBuilder();
    text.append(String.format("%s -> %s", parameterList, getLambdaFunction().toString()));
    // Index 0 is the lambda body, already rendered above.
    int index = 1;
    while (index < getArguments().size()) {
        text.append(", ");
        text.append(getArgument(index).toString());
        index++;
    }
    return text.toString();
}
/**
 * Returns a new lambda with the same parameter names and the given children.
 *
 * @param children the replacement child expressions
 */
@Override
public Lambda withChildren(List<Expression> children) {
    return new Lambda(this.argumentNames, children);
}
/**
 * Reports whether any lambda argument is nullable.
 *
 * @return true if at least one argument reference reports nullable, false otherwise
 *         (including when there are no arguments)
 */
@Override
public boolean nullable() {
    for (ArrayItemReference argument : getLambdaArguments()) {
        if (argument.nullable()) {
            return true;
        }
    }
    return false;
}
/**
 * Returns the data type of the lambda expression itself, which is always
 * {@link LambdaType}. Use {@link #getRetType()} for the type of the lambda body.
 *
 * <p>Returns the shared {@code LambdaType.INSTANCE} instead of allocating a new
 * object on every call.
 */
@Override
public DataType getDataType() {
    return LambdaType.INSTANCE;
}
/**
 * Returns the data type of the lambda body.
 */
public DataType getRetType() {
    Expression body = getLambdaFunction();
    return body.getDataType();
}
}

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.NullType;
import com.google.common.collect.ImmutableList;
@ -33,9 +34,19 @@ public class ArrayLiteral extends Literal {
private final List<Literal> items;
/**
* construct array literal
*/
public ArrayLiteral(List<Literal> items) {
super(computeDataType(items));
this.items = ImmutableList.copyOf(items);
this.items = items.stream()
.map(i -> {
if (i instanceof NullLiteral) {
DataType type = ((ArrayType) (this.getDataType())).getItemType();
return new NullLiteral(type);
}
return i;
}).collect(ImmutableList.toImmutableList());
}
@Override
@ -64,7 +75,7 @@ public class ArrayLiteral extends Literal {
String items = this.items.stream()
.map(Literal::toString)
.collect(Collectors.joining(", "));
return "array(" + items + ")";
return "[" + items + "]";
}
@Override
@ -72,7 +83,7 @@ public class ArrayLiteral extends Literal {
String items = this.items.stream()
.map(Literal::toSql)
.collect(Collectors.joining(", "));
return "array(" + items + ")";
return "[" + items + "]";
}
@Override
@ -84,6 +95,12 @@ public class ArrayLiteral extends Literal {
if (items.isEmpty()) {
return ArrayType.SYSTEM_DEFAULT;
}
return ArrayType.of(items.get(0).dataType);
DataType dataType = NullType.INSTANCE;
for (Literal item : items) {
if (!item.dataType.isNullType()) {
dataType = item.dataType;
}
}
return ArrayType.of(dataType);
}
}

View File

@ -56,7 +56,7 @@ public class NullLiteral extends Literal {
/**
 * Converts this literal to a legacy-planner null literal created from this
 * literal's data type converted to its catalog type.
 *
 * <p>The earlier untyped {@code new NullLiteral()} return is removed; it made the
 * typed return below it unreachable.
 */
@Override
public LiteralExpr toLegacyLiteral() {
    return org.apache.doris.analysis.NullLiteral.create(dataType.toCatalogDataType());
}
@Override

View File

@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.AggregateExpression;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
@ -82,6 +83,7 @@ import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction;
import org.apache.doris.nereids.trees.expressions.functions.window.WindowFunction;
@ -123,6 +125,10 @@ public abstract class ExpressionVisitor<R, C>
return visitBoundFunction(aggregateFunction, context);
}
/**
 * Visits a {@link Lambda}. By default this delegates to the generic
 * {@code visit} method.
 */
public R visitLambda(Lambda lambda, C context) {
    R result = visit(lambda, context);
    return result;
}
@Override
public R visitScalarFunction(ScalarFunction scalarFunction, C context) {
return visitBoundFunction(scalarFunction, context);
@ -211,6 +217,10 @@ public abstract class ExpressionVisitor<R, C>
return visitSlot(slotReference, context);
}
/**
 * Visits an array item slot. By default this delegates to the generic
 * {@code visit} method.
 */
public R visitArrayItemSlot(SlotReference arrayItemSlot, C context) {
    R result = visit(arrayItemSlot, context);
    return result;
}
/**
 * Visits a {@link MarkJoinSlotReference}. By default this delegates to
 * {@code visitSlotReference}.
 */
public R visitMarkJoinReference(MarkJoinSlotReference markJoinSlotReference, C context) {
    R result = visitSlotReference(markJoinSlotReference, context);
    return result;
}
@ -411,6 +421,10 @@ public abstract class ExpressionVisitor<R, C>
return visit(virtualSlotReference, context);
}
/**
 * Visits an {@link ArrayItemReference}. By default this delegates to the generic
 * {@code visit} method.
 */
public R visitArrayItemReference(ArrayItemReference arrayItemReference, C context) {
    R result = visit(arrayItemReference, context);
    return result;
}
/**
 * Visits a {@link VariableDesc}. By default this delegates to the generic
 * {@code visit} method.
 */
public R visitVariableDesc(VariableDesc variableDesc, C context) {
    R result = visit(variableDesc, context);
    return result;
}

View File

@ -39,6 +39,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayEnumerat
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExcept;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayIntersect;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayJoin;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMax;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMin;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayPopBack;
@ -477,6 +478,10 @@ public interface ScalarFunctionVisitor<R, C> {
return visitScalarFunction(arraySort, context);
}
/**
 * Visits an {@link ArrayMap} function. By default this delegates to
 * {@code visitScalarFunction}.
 *
 * @param arrayMap the array_map function being visited
 * @param context passed through unchanged
 */
default R visitArrayMap(ArrayMap arrayMap, C context) {
    return visitScalarFunction(arrayMap, context);
}
/**
 * Visits an {@link ArraySum} function. By default this delegates to
 * {@code visitScalarFunction}.
 */
default R visitArraySum(ArraySum arraySum, C context) {
    R result = visitScalarFunction(arraySum, context);
    return result;
}

View File

@ -52,8 +52,6 @@ public class LogicalOneRowRelation extends LogicalRelation implements OneRowRela
private LogicalOneRowRelation(RelationId relationId, List<NamedExpression> projects,
Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties) {
super(relationId, PlanType.LOGICAL_ONE_ROW_RELATION, groupExpression, logicalProperties);
Preconditions.checkArgument(projects.stream().noneMatch(p -> p.containsType(Slot.class)),
"OneRowRelation can not contains any slot");
Preconditions.checkArgument(projects.stream().noneMatch(p -> p.containsType(AggregateFunction.class)),
"OneRowRelation can not contains any aggregate function");
this.projects = ImmutableList.copyOf(Objects.requireNonNull(projects, "projects can not be null"));

View File

@ -0,0 +1,44 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.types;
import org.apache.doris.catalog.Type;
/**
 * Nereids data type of a lambda (function) expression.
 */
public class LambdaType extends DataType {

    /** Shared instance of this type. */
    public static final LambdaType INSTANCE = new LambdaType();

    /** Creates a lambda type. */
    public LambdaType() {
    }

    /**
     * Returns the catalog type for lambdas, {@code Type.LAMBDA_FUNCTION}.
     */
    @Override
    public Type toCatalogDataType() {
        return Type.LAMBDA_FUNCTION;
    }

    /**
     * Returns the SQL name of this type, {@code "Lambda"}.
     */
    @Override
    public String toSql() {
        return "Lambda";
    }

    /**
     * Returns the width of this type, which is always 0.
     */
    @Override
    public int width() {
        return 0;
    }
}

View File

@ -41,8 +41,8 @@ class ArrayContainsToArrayOverlapTest extends ExpressionRewriteTestHelper {
.getPlan();
Expression expression = plan.child(0).getExpressions().get(0).child(0);
Assertions.assertTrue(expression instanceof ArraysOverlap);
Assertions.assertEquals("array(1)", expression.child(0).toSql());
Assertions.assertEquals("array(1, 2, 3)", expression.child(1).toSql());
Assertions.assertEquals("[1]", expression.child(0).toSql());
Assertions.assertEquals("[1, 2, 3]", expression.child(1).toSql());
}
@Test
@ -92,8 +92,8 @@ class ArrayContainsToArrayOverlapTest extends ExpressionRewriteTestHelper {
.rewrite()
.getPlan();
Expression expression = plan.child(0).getExpressions().get(0).child(0);
Assertions.assertEquals("(array_contains(array(1), 0) OR "
+ "(array_contains(array(1), 1) AND arrays_overlap(array(1), array(2, 3, 4))))",
Assertions.assertEquals("(array_contains([1], 0) OR "
+ "(array_contains([1], 1) AND arrays_overlap([1], [2, 3, 4])))",
expression.toSql());
}
}