From 7977bebfed79d987593a76ca2cfdff6c36ba78d5 Mon Sep 17 00:00:00 2001 From: shee <13843187+qzsee@users.noreply.github.com> Date: Mon, 26 Sep 2022 17:16:23 +0800 Subject: [PATCH] [feature](Nereids) constant expression folding (#12151) --- .../org/apache/doris/nereids/DorisParser.g4 | 1 + .../nereids/analyzer/NereidsAnalyzer.java | 5 + .../jobs/batch/NereidsRewriteJobExecutor.java | 4 +- .../nereids/jobs/batch/TypeCoercionJob.java | 41 +++ .../nereids/parser/LogicalPlanBuilder.java | 5 +- .../nereids/rules/analysis/BindFunction.java | 29 +- .../rewrite/ExpressionNormalization.java | 15 +- .../rewrite/ExpressionRewriteContext.java | 7 + .../rewrite/ExpressionRuleExecutor.java | 12 +- .../rewrite/rules/FoldConstantRule.java | 39 ++ .../rewrite/rules/FoldConstantRuleOnBE.java | 200 ++++++++++ .../rewrite/rules/FoldConstantRuleOnFE.java | 310 ++++++++++++++++ .../rewrite/rules/TypeCoercion.java | 25 +- .../trees/expressions/BinaryArithmetic.java | 3 +- .../trees/expressions/CompoundPredicate.java | 1 + .../nereids/trees/expressions/EqualTo.java | 3 +- .../trees/expressions/ExecFunction.java | 46 +++ .../trees/expressions/ExecFunctionList.java | 35 ++ .../nereids/trees/expressions/Expression.java | 9 + .../expressions/ExpressionEvaluator.java | 208 +++++++++++ .../trees/expressions/GreaterThan.java | 3 +- .../trees/expressions/GreaterThanEqual.java | 3 +- .../nereids/trees/expressions/IsNull.java | 78 ++++ .../nereids/trees/expressions/LessThan.java | 3 +- .../trees/expressions/LessThanEqual.java | 3 +- .../doris/nereids/trees/expressions/Not.java | 3 +- .../trees/expressions/NullSafeEqual.java | 1 + .../trees/expressions/SlotReference.java | 4 + .../expressions/TimestampArithmetic.java | 32 +- .../nereids/trees/expressions/WhenClause.java | 2 +- .../functions/ExecutableFunctions.java | 243 +++++++++++++ .../functions/scalar/Substring.java | 3 +- .../expressions/literal/DateLiteral.java | 39 +- .../expressions/literal/DateTimeLiteral.java | 85 +++-- .../trees/expressions/literal/Literal.java | 109 +++++- .../expressions/literal/StringLiteral.java | 27 -- .../expressions/literal/VarcharLiteral.java | 15 - .../visitor/ExpressionVisitor.java | 5 + .../apache/doris/nereids/types/DataType.java | 50 ++- .../doris/nereids/util/ExpressionUtils.java | 18 + .../doris/nereids/util/TypeCoercionUtils.java | 38 ++ .../datasets/ssb/SSBJoinReorderTest.java | 12 +- .../nereids/parser/HavingClauseTest.java | 33 +- .../rules/analysis/FunctionRegistryTest.java | 18 +- .../rewrite/ExpressionRewriteTest.java | 11 +- .../expression/rewrite/FoldConstantTest.java | 341 ++++++++++++++++++ .../expression/rewrite/TypeCoercionTest.java | 5 +- .../expressions/ExpressionParserTest.java | 9 + 48 files changed, 2005 insertions(+), 186 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/TypeCoercionJob.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRule.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnBE.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExecFunction.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExecFunctionList.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/IsNull.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java diff --git a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 index d2a03ecccf..fa81545d31 100644 --- a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 +++ b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 @@ -203,6 +203,7 @@ predicate | NOT? kind=(LIKE | REGEXP) pattern=valueExpression | NOT? kind=IN LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN | NOT? kind=IN LEFT_PAREN query RIGHT_PAREN + | IS NOT? kind=NULL ; valueExpression diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/NereidsAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/NereidsAnalyzer.java index 596bf9cd0e..8e9876f856 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/NereidsAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/NereidsAnalyzer.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.jobs.batch.AnalyzeRulesJob; import org.apache.doris.nereids.jobs.batch.AnalyzeSubqueryRulesJob; import org.apache.doris.nereids.jobs.batch.CheckAnalysisJob; import org.apache.doris.nereids.jobs.batch.FinalizeAnalyzeJob; +import org.apache.doris.nereids.jobs.batch.TypeCoercionJob; import org.apache.doris.nereids.rules.analysis.Scope; import java.util.Objects; @@ -44,9 +45,13 @@ public class NereidsAnalyzer { this.outerScope = Objects.requireNonNull(outerScope, "outerScope can not be null"); } + /** + * nereids analyze sql. + */ public void analyze() { new AnalyzeRulesJob(cascadesContext, outerScope).execute(); new AnalyzeSubqueryRulesJob(cascadesContext).execute(); + new TypeCoercionJob(cascadesContext).execute(); new FinalizeAnalyzeJob(cascadesContext).execute(); // check whether analyze result is meaningful new CheckAnalysisJob(cascadesContext).execute(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java index 5534871b50..ac008d8dac 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java @@ -37,7 +37,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin; import com.google.common.collect.ImmutableList; /** - * Apply rules to normalize expressions. + * Apply rules to optimize logical plan. */ public class NereidsRewriteJobExecutor extends BatchRulesJob { @@ -58,7 +58,7 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob { */ .addAll(new AdjustApplyFromCorrelatToUnCorrelatJob(cascadesContext).rulesJob) .addAll(new ConvertApplyToJoinJob(cascadesContext).rulesJob) - .add(topDownBatch(ImmutableList.of(new ExpressionNormalization()))) + .add(topDownBatch(ImmutableList.of(new ExpressionNormalization(cascadesContext.getConnectContext())))) .add(topDownBatch(ImmutableList.of(new ExpressionOptimization()))) .add(topDownBatch(ImmutableList.of(new ExtractSingleTableExpressionFromDisjunction()))) .add(topDownBatch(ImmutableList.of(new NormalizeAggregate()))) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/TypeCoercionJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/TypeCoercionJob.java new file mode 100644 index 0000000000..92dfa89c7e --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/TypeCoercionJob.java @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.jobs.batch; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization; +import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion; + +import com.google.common.collect.ImmutableList; + +/** + * type coercion job. + */ +public class TypeCoercionJob extends BatchRulesJob { + /** + * constructor. + */ + public TypeCoercionJob(CascadesContext cascadesContext) { + super(cascadesContext); + rulesJob.addAll(ImmutableList.of( + topDownBatch(ImmutableList.of( + new ExpressionNormalization(cascadesContext.getConnectContext(), + ImmutableList.of(TypeCoercion.INSTANCE))) + ))); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java index d45092d8fa..b84e0328bc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java @@ -95,6 +95,7 @@ import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.InSubquery; +import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; import org.apache.doris.nereids.trees.expressions.Like; @@ -385,7 +386,6 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { public Expression visitPredicated(PredicatedContext ctx) { return ParserUtils.withOrigin(ctx, () -> { Expression e = getExpression(ctx.valueExpression()); - // TODO: add predicate(is not null ...) return ctx.predicate() == null ? e : withPredicate(e, ctx.predicate()); }); } @@ -948,6 +948,9 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { ); } break; + case DorisParser.NULL: + outExpression = new IsNull(valueExpression); + break; default: throw new ParseException("Unsupported predicate type: " + ctx.kind.getText(), ctx); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java index f0cda69ff9..9249e52968 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java @@ -36,13 +36,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import org.apache.doris.nereids.types.DateTimeType; -import org.apache.doris.nereids.types.DateType; -import org.apache.doris.nereids.types.IntegerType; import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.Locale; import java.util.stream.Collectors; /** @@ -134,8 +132,12 @@ public class BindFunction implements AnalysisRuleFactory { return builder.build(functionName, unboundFunction.getArguments()); } + /** + * gets the method for calculating the time. + * e.g. YEARS_ADD、YEARS_SUB、DAYS_ADD 、DAYS_SUB + */ @Override - public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Env env) { + public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Env context) { String funcOpName; if (arithmetic.getFuncName() == null) { // e.g. YEARS_ADD, MONTHS_SUB @@ -144,24 +146,7 @@ public class BindFunction implements AnalysisRuleFactory { } else { funcOpName = arithmetic.getFuncName(); } - - Expression left = arithmetic.left(); - Expression right = arithmetic.right(); - - if (!left.getDataType().isDateType()) { - try { - left = left.castTo(DateTimeType.INSTANCE); - } catch (Exception e) { - // ignore - } - if (!left.getDataType().isDateType() && !arithmetic.getTimeUnit().isDateTimeUnit()) { - left = arithmetic.left().castTo(DateType.INSTANCE); - } - } - if (!right.getDataType().isIntType()) { - right = right.castTo(IntegerType.INSTANCE); - } - return arithmetic.withFuncName(funcOpName).withChildren(ImmutableList.of(left, right)); + return arithmetic.withFuncName(funcOpName.toLowerCase(Locale.ROOT)); } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionNormalization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionNormalization.java index 99087e3393..57d6b509d7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionNormalization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionNormalization.java @@ -18,11 +18,13 @@ package org.apache.doris.nereids.rules.expression.rewrite; import org.apache.doris.nereids.rules.expression.rewrite.rules.BetweenToCompoundRule; +import org.apache.doris.nereids.rules.expression.rewrite.rules.FoldConstantRule; import org.apache.doris.nereids.rules.expression.rewrite.rules.InPredicateToEqualToRule; import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPredicatesRule; import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyCastRule; import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule; import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion; +import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; @@ -39,11 +41,16 @@ public class ExpressionNormalization extends ExpressionRewrite { InPredicateToEqualToRule.INSTANCE, SimplifyNotExprRule.INSTANCE, SimplifyCastRule.INSTANCE, - TypeCoercion.INSTANCE + TypeCoercion.INSTANCE, + FoldConstantRule.INSTANCE ); - private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES); - public ExpressionNormalization() { - super(EXECUTOR); + public ExpressionNormalization(ConnectContext context) { + super(new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES, context)); + } + + public ExpressionNormalization(ConnectContext context, List rules) { + super(new ExpressionRuleExecutor(rules, context)); } } + diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteContext.java index 13e2267f2b..dd32382810 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteContext.java @@ -17,8 +17,15 @@ package org.apache.doris.nereids.rules.expression.rewrite; +import org.apache.doris.qe.ConnectContext; + /** * expression rewrite context. */ public class ExpressionRewriteContext { + public final ConnectContext connectContext; + + public ExpressionRewriteContext(ConnectContext connectContext) { + this.connectContext = connectContext; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java index 13fbdb20ca..44829a2a31 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java @@ -18,8 +18,7 @@ package org.apache.doris.nereids.rules.expression.rewrite; import org.apache.doris.nereids.trees.expressions.Expression; - -import com.google.common.collect.Lists; +import org.apache.doris.qe.ConnectContext; import java.util.List; import java.util.Optional; @@ -33,14 +32,13 @@ public class ExpressionRuleExecutor { private final ExpressionRewriteContext ctx; private final List rules; - public ExpressionRuleExecutor(List rules) { + public ExpressionRuleExecutor(List rules, ConnectContext context) { this.rules = rules; - this.ctx = new ExpressionRewriteContext(); + this.ctx = new ExpressionRewriteContext(context); } - public ExpressionRuleExecutor(ExpressionRewriteRule rule) { - this.rules = Lists.newArrayList(rule); - this.ctx = new ExpressionRewriteContext(); + public ExpressionRuleExecutor(List rules) { + this(rules, null); } public List rewrite(List exprs) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRule.java new file mode 100644 index 0000000000..175e40b0aa --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRule.java @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rewrite.rules; + +import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.Expression; + +/** + * Constant evaluation of an expression. + */ +public class FoldConstantRule extends AbstractExpressionRewriteRule { + + public static final FoldConstantRule INSTANCE = new FoldConstantRule(); + + @Override + public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { + if (ctx.connectContext != null && ctx.connectContext.getSessionVariable().isEnableFoldConstantByBe()) { + return FoldConstantRuleOnBE.INSTANCE.rewrite(expr, ctx); + } + return FoldConstantRuleOnFE.INSTANCE.rewrite(expr, ctx); + } +} + diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnBE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnBE.java new file mode 100644 index 0000000000..895462020e --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnBE.java @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rewrite.rules; + +import org.apache.doris.analysis.Expr; +import org.apache.doris.analysis.ExprId; +import org.apache.doris.analysis.LiteralExpr; +import org.apache.doris.catalog.Env; +import org.apache.doris.catalog.PrimitiveType; +import org.apache.doris.catalog.Type; +import org.apache.doris.common.IdGenerator; +import org.apache.doris.common.UserException; +import org.apache.doris.common.util.TimeUtils; +import org.apache.doris.common.util.VectorizedUtil; +import org.apache.doris.nereids.glue.translator.ExpressionTranslator; +import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.Between; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.proto.InternalService; +import org.apache.doris.proto.InternalService.PConstantExprResult; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.rpc.BackendServiceProxy; +import org.apache.doris.system.Backend; +import org.apache.doris.thrift.TExpr; +import org.apache.doris.thrift.TFoldConstantParams; +import org.apache.doris.thrift.TNetworkAddress; +import org.apache.doris.thrift.TPrimitiveType; +import org.apache.doris.thrift.TQueryGlobals; + +import com.google.common.collect.Maps; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +/** + * Constant evaluation of an expression. + */ +public class FoldConstantRuleOnBE extends AbstractExpressionRewriteRule { + public static final FoldConstantRuleOnBE INSTANCE = new FoldConstantRuleOnBE(); + private static final Logger LOG = LogManager.getLogger(FoldConstantRuleOnBE.class); + private static final DateTimeFormatter DATE_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"); + private final IdGenerator idGenerator = ExprId.createGenerator(); + + @Override + public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { + Expression expression = FoldConstantRuleOnFE.INSTANCE.rewrite(expr, ctx); + return foldByBE(expression, ctx); + } + + private Expression foldByBE(Expression root, ExpressionRewriteContext context) { + Map constMap = Maps.newHashMap(); + Map staleConstTExprMap = Maps.newHashMap(); + collectConst(root, constMap, staleConstTExprMap); + if (constMap.isEmpty()) { + return root; + } + Map> paramMap = new HashMap<>(); + paramMap.put("0", staleConstTExprMap); + Map resultMap = evalOnBE(paramMap, constMap, context.connectContext); + if (!resultMap.isEmpty()) { + return replace(root, constMap, resultMap); + } + return root; + } + + private Expression replace(Expression root, Map constMap, Map resultMap) { + for (Entry entry : constMap.entrySet()) { + if (entry.getValue().equals(root)) { + return resultMap.get(entry.getKey()); + } + } + List newChildren = new ArrayList<>(); + boolean hasNewChildren = false; + for (Expression child : root.children()) { + Expression newChild = replace(child, constMap, resultMap); + if (newChild != child) { + hasNewChildren = true; + } + newChildren.add(newChild); + } + return hasNewChildren ? root.withChildren(newChildren) : root; + } + + private void collectConst(Expression expr, Map constMap, Map tExprMap) { + if (expr.isConstant()) { + // Do not constant fold cast(null as dataType) because we cannot preserve the + // cast-to-types and that can lead to query failures, e.g., CTAS + if (expr instanceof Cast) { + if (((Cast) expr).child().isNullLiteral()) { + return; + } + } + // skip literal expr + if (expr.isLiteral()) { + return; + } + // skip BetweenPredicate need to be rewrite to CompoundPredicate + if (expr instanceof Between) { + return; + } + String id = idGenerator.getNextId().toString(); + constMap.put(id, expr); + Expr staleExpr = ExpressionTranslator.translate(expr, null); + tExprMap.put(id, staleExpr.treeToThrift()); + } else { + for (int i = 0; i < expr.children().size(); i++) { + final Expression child = expr.children().get(i); + collectConst(child, constMap, tExprMap); + } + } + } + + private Map evalOnBE(Map> paramMap, + Map constMap, ConnectContext context) { + + Map resultMap = new HashMap<>(); + try { + List backendIds = Env.getCurrentSystemInfo().getBackendIds(true); + if (backendIds.isEmpty()) { + throw new UserException("No alive backends"); + } + Collections.shuffle(backendIds); + Backend be = Env.getCurrentSystemInfo().getBackend(backendIds.get(0)); + TNetworkAddress brpcAddress = new TNetworkAddress(be.getHost(), be.getBrpcPort()); + + TQueryGlobals queryGlobals = new TQueryGlobals(); + queryGlobals.setNowString(DATE_FORMAT.format(LocalDateTime.now())); + queryGlobals.setTimestampMs(System.currentTimeMillis()); + queryGlobals.setTimeZone(TimeUtils.DEFAULT_TIME_ZONE); + if (context.getSessionVariable().getTimeZone().equals("CST")) { + queryGlobals.setTimeZone(TimeUtils.DEFAULT_TIME_ZONE); + } else { + queryGlobals.setTimeZone(context.getSessionVariable().getTimeZone()); + } + + TFoldConstantParams tParams = new TFoldConstantParams(paramMap, queryGlobals); + tParams.setVecExec(VectorizedUtil.isVectorized()); + + Future future = + BackendServiceProxy.getInstance().foldConstantExpr(brpcAddress, tParams); + PConstantExprResult result = future.get(5, TimeUnit.SECONDS); + + if (result.getStatus().getStatusCode() == 0) { + for (Entry e : result.getExprResultMapMap().entrySet()) { + for (Entry e1 : e.getValue().getMapMap().entrySet()) { + Expression ret; + if (e1.getValue().getSuccess()) { + TPrimitiveType type = TPrimitiveType.findByValue(e1.getValue().getType().getType()); + Type t = Type.fromPrimitiveType(PrimitiveType.fromThrift(Objects.requireNonNull(type))); + Expr staleExpr = LiteralExpr.create(e1.getValue().getContent(), Objects.requireNonNull(t)); + // Nereids type + DataType t1 = DataType.convertFromString(staleExpr.getType().getPrimitiveType().toString()); + ret = Literal.of(staleExpr.getStringValue()).castTo(t1); + } else { + ret = constMap.get(e.getKey()); + } + resultMap.put(e.getKey(), ret); + } + } + + } else { + LOG.warn("failed to get const expr value from be: {}", result.getStatus().getErrorMsgsList()); + } + } catch (Exception e) { + LOG.warn("failed to get const expr value from be: {}", e.getMessage()); + } + return resultMap; + } +} + diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java new file mode 100644 index 0000000000..d576195389 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRuleOnFE.java @@ -0,0 +1,310 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rewrite.rules; + +import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.ExpressionEvaluator; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; +import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.trees.expressions.IsNull; +import org.apache.doris.nereids.trees.expressions.LessThan; +import org.apache.doris.nereids.trees.expressions.LessThanEqual; +import org.apache.doris.nereids.trees.expressions.Like; +import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.NullSafeEqual; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; +import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.Lists; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +/** + * evaluate an expression on fe. + */ +public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule { + public static final FoldConstantRuleOnFE INSTANCE = new FoldConstantRuleOnFE(); + + @Override + public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { + return process(expr, ctx); + } + + @Override + public Expression visit(Expression expr, ExpressionRewriteContext context) { + return expr; + } + + /** + * process constant expression. + */ + public Expression process(Expression expr, ExpressionRewriteContext ctx) { + if (expr instanceof PropagateNullable) { + List children = expr.children() + .stream() + .map(child -> process(child, ctx)) + .collect(Collectors.toList()); + + if (ExpressionUtils.hasNullLiteral(children)) { + return Literal.of(null); + } + + if (!ExpressionUtils.isAllLiteral(children)) { + return expr.withChildren(children); + } + return expr.withChildren(children).accept(this, ctx); + } else { + return expr.accept(this, ctx); + } + } + + @Override + public Expression visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) { + return BooleanLiteral.of(((Literal) equalTo.left()).compareTo((Literal) equalTo.right()) == 0); + } + + @Override + public Expression visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteContext context) { + return BooleanLiteral.of(((Literal) greaterThan.left()).compareTo((Literal) greaterThan.right()) > 0); + } + + @Override + public Expression visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, ExpressionRewriteContext context) { + return BooleanLiteral.of(((Literal) greaterThanEqual.left()) + .compareTo((Literal) greaterThanEqual.right()) >= 0); + } + + @Override + public Expression visitLessThan(LessThan lessThan, ExpressionRewriteContext context) { + return BooleanLiteral.of(((Literal) lessThan.left()).compareTo((Literal) lessThan.right()) < 0); + + } + + @Override + public Expression visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewriteContext context) { + return BooleanLiteral.of(((Literal) lessThanEqual.left()).compareTo((Literal) lessThanEqual.right()) <= 0); + } + + @Override + public Expression visitNullSafeEqual(NullSafeEqual nullSafeEqual, ExpressionRewriteContext context) { + Expression left = process(nullSafeEqual.left(), context); + Expression right = process(nullSafeEqual.right(), context); + if (ExpressionUtils.isAllLiteral(left, right)) { + Literal l = (Literal) left; + Literal r = (Literal) right; + if (l.isNullLiteral() && r.isNullLiteral()) { + return BooleanLiteral.TRUE; + } else if (!l.isNullLiteral() && !r.isNullLiteral()) { + return BooleanLiteral.of(l.compareTo(r) == 0); + } else { + return BooleanLiteral.FALSE; + } + } + return nullSafeEqual.withChildren(left, right); + } + + @Override + public Expression visitNot(Not not, ExpressionRewriteContext context) { + return BooleanLiteral.of(!((BooleanLiteral) not.child()).getValue()); + } + + @Override + public Expression visitSlot(Slot slot, ExpressionRewriteContext context) { + return slot; + } + + @Override + public Expression visitLiteral(Literal literal, ExpressionRewriteContext context) { + return literal; + } + + @Override + public Expression visitAnd(And and, ExpressionRewriteContext context) { + List children = Lists.newArrayList(); + for (Expression child : and.children()) { + Expression newChild = process(child, context); + if (newChild.equals(BooleanLiteral.FALSE)) { + return BooleanLiteral.FALSE; + } + if (!newChild.equals(BooleanLiteral.TRUE)) { + children.add(newChild); + } + } + if (children.isEmpty()) { + return BooleanLiteral.TRUE; + } + if (children.size() == 1) { + return children.get(0); + } + if (ExpressionUtils.isAllNullLiteral(children)) { + return Literal.of(null); + } + return and.withChildren(children); + } + + @Override + public Expression visitOr(Or or, ExpressionRewriteContext context) { + List children = Lists.newArrayList(); + for (Expression child : or.children()) { + Expression newChild = process(child, context); + if (newChild.equals(BooleanLiteral.TRUE)) { + return BooleanLiteral.TRUE; + } + if (!newChild.equals(BooleanLiteral.FALSE)) { + children.add(newChild); + } + } + if (children.isEmpty()) { + return BooleanLiteral.FALSE; + } + if (children.size() == 1) { + return children.get(0); + } + if (ExpressionUtils.isAllNullLiteral(children)) { + return Literal.of(null); + } + return or.withChildren(children); + } + + @Override + public Expression visitLike(Like like, ExpressionRewriteContext context) { + return like; + } + + @Override + public Expression visitCast(Cast cast, ExpressionRewriteContext context) { + Expression child = process(cast.child(), context); + // todo: process other null case + if (child.isNullLiteral()) { + return Literal.of(null); + } + if (child.isLiteral()) { + return child.castTo(cast.getDataType()); + } + return cast.withChildren(child); + } + + @Override + public Expression visitBoundFunction(BoundFunction boundFunction, ExpressionRewriteContext context) { + List newArgs = boundFunction.getArguments().stream().map(arg -> process(arg, context)) + .collect(Collectors.toList()); + if (ExpressionUtils.isAllLiteral(newArgs)) { + return ExpressionEvaluator.INSTANCE.eval(boundFunction.withChildren(newArgs)); + } + return boundFunction.withChildren(newArgs); + } + + @Override + public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, ExpressionRewriteContext context) { + return ExpressionEvaluator.INSTANCE.eval(binaryArithmetic); + } + + @Override + public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) { + Expression newDefault = null; + boolean foundNewDefault = false; + + List whenClauses = new ArrayList<>(); + for (WhenClause whenClause : caseWhen.getWhenClauses()) { + Expression whenOperand = process(whenClause.getOperand(), context); + + if (!(whenOperand.isLiteral())) { + whenClauses.add(new WhenClause(whenOperand, process(whenClause.getResult(), context))); + } else if (BooleanLiteral.TRUE.equals(whenOperand)) { + foundNewDefault = true; + newDefault = process(whenClause.getResult(), context); + break; + } + } + + Expression defaultResult; + if (foundNewDefault) { + defaultResult = newDefault; + } else { + defaultResult = process(caseWhen.getDefaultValue().orElse(Literal.of(null)), context); + } + + if (whenClauses.isEmpty()) { + return defaultResult; + } + return new CaseWhen(whenClauses, defaultResult); + } + + @Override + public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) { + Expression value = process(inPredicate.child(0), context); + List children = Lists.newArrayList(); + children.add(value); + if (value.isNullLiteral()) { + return Literal.of(null); + } + boolean hasNull = false; + boolean hasUnresolvedValue = !value.isLiteral(); + for (int i = 1; i < inPredicate.children().size(); i++) { + Expression inValue = process(inPredicate.child(i), context); + children.add(inValue); + if (!inValue.isLiteral()) { + hasUnresolvedValue = true; + } + if (inValue.isNullLiteral()) { + hasNull = true; + } + if (inValue.isLiteral() && value.isLiteral() && ((Literal) value).compareTo((Literal) inValue) == 0) { + return Literal.of(true); + } + } + if (hasUnresolvedValue) { + return inPredicate.withChildren(children); + } + return hasNull ? Literal.of(null) : Literal.of(false); + } + + @Override + public Expression visitIsNull(IsNull isNull, ExpressionRewriteContext context) { + Expression child = process(isNull.child(), context); + if (child.isNullLiteral()) { + return Literal.of(true); + } else if (!child.nullable()) { + return Literal.of(false); + } + return isNull.withChildren(child); + } + + @Override + public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, ExpressionRewriteContext context) { + return ExpressionEvaluator.INSTANCE.eval(arithmetic); + } +} + diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/TypeCoercion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/TypeCoercion.java index f13d20ec87..d3adede8be 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/TypeCoercion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/TypeCoercion.java @@ -17,16 +17,19 @@ package org.apache.doris.nereids.rules.expression.rewrite.rules; +import org.apache.doris.analysis.ArithmeticExpr.Operator; import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule; import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext; import org.apache.doris.nereids.trees.expressions.BinaryOperator; import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.types.coercion.AbstractDataType; import org.apache.doris.nereids.util.TypeCoercionUtils; @@ -45,9 +48,8 @@ public class TypeCoercion extends AbstractExpressionRewriteRule { // TODO: // 1. DecimalPrecision Process - // 2. Divide process - // 3. String promote with numeric in binary arithmetic - // 4. Date and DateTime process + // 2. String promote with numeric in binary arithmetic + // 3. Date and DateTime process public static final TypeCoercion INSTANCE = new TypeCoercion(); @@ -82,6 +84,23 @@ public class TypeCoercion extends AbstractExpressionRewriteRule { .orElse(binaryOperator.withChildren(left, right)); } + @Override + public Expression visitDivide(Divide divide, ExpressionRewriteContext context) { + Expression left = rewrite(divide.left(), context); + Expression right = rewrite(divide.right(), context); + DataType t1 = TypeCoercionUtils.getNumResultType(left.getDataType()); + DataType t2 = TypeCoercionUtils.getNumResultType(right.getDataType()); + DataType commonType = TypeCoercionUtils.findCommonNumericsType(t1, t2); + if (divide.getLegacyOperator() == Operator.DIVIDE) { + if (commonType.isBigIntType() || commonType.isLargeIntType()) { + commonType = DoubleType.INSTANCE; + } + } + Expression newLeft = TypeCoercionUtils.castIfNotSameType(left, commonType); + Expression newRight = TypeCoercionUtils.castIfNotSameType(right, commonType); + return divide.withChildren(newLeft, newRight); + } + @Override public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) { List rewrittenChildren = caseWhen.children().stream() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryArithmetic.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryArithmetic.java index fc453f988d..ec99e5e2d2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryArithmetic.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryArithmetic.java @@ -19,13 +19,14 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.analysis.ArithmeticExpr.Operator; import org.apache.doris.nereids.exceptions.UnboundException; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; /** * binary arithmetic operator. Such as +, -, *, /. */ -public abstract class BinaryArithmetic extends BinaryOperator { +public abstract class BinaryArithmetic extends BinaryOperator implements PropagateNullable { private final Operator legacyOperator; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java index ac6834e4fd..4b8c3b2464 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java @@ -71,5 +71,6 @@ public abstract class CompoundPredicate extends BinaryOperator { public abstract CompoundPredicate flip(Expression left, Expression right); public abstract Class flipType(); + } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java index 0d193487d4..c4ddb89443 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.nereids.exceptions.UnboundException; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import com.google.common.base.Preconditions; @@ -27,7 +28,7 @@ import java.util.List; /** * Equal to expression: a = b. */ -public class EqualTo extends ComparisonPredicate { +public class EqualTo extends ComparisonPredicate implements PropagateNullable { public EqualTo(Expression left, Expression right) { super(left, right, "="); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExecFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExecFunction.java new file mode 100644 index 0000000000..6778c0971e --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExecFunction.java @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * nereids function annotation. + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface ExecFunction { + + /** + * function name + */ + String name(); + + /** + * args type + */ + String[] argTypes(); + + /** + * return type + */ + String returnType(); +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExecFunctionList.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExecFunctionList.java new file mode 100644 index 0000000000..72f62b18c9 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExecFunctionList.java @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * exec function list + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface ExecFunctionList { + /** + * exec functions. + */ + ExecFunction[] value(); +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java index d2c76bf533..e911c739f0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.exceptions.UnboundException; import org.apache.doris.nereids.trees.AbstractTreeNode; import org.apache.doris.nereids.trees.expressions.functions.ComputeNullable; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.shape.LeafExpression; import org.apache.doris.nereids.trees.expressions.typecoercion.ExpectsInputTypes; import org.apache.doris.nereids.trees.expressions.typecoercion.TypeCheckResult; @@ -139,6 +140,14 @@ public abstract class Expression extends AbstractTreeNode implements return collect(Slot.class::isInstance); } + public boolean isLiteral() { + return this instanceof Literal; + } + + public boolean isNullLiteral() { + return this instanceof NullLiteral; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java new file mode 100644 index 0000000000..ef452e7d68 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions; + +import org.apache.doris.catalog.Env; +import org.apache.doris.common.AnalysisException; +import org.apache.doris.nereids.trees.expressions.functions.ExecutableFunctions; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.types.DataType; + +import com.google.common.collect.ImmutableMultimap; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * An expression evaluator that evaluates the value of an expression. + */ +public enum ExpressionEvaluator { + INSTANCE; + private ImmutableMultimap functions; + + ExpressionEvaluator() { + registerFunctions(); + } + + /** + * Evaluate the value of the expression. + */ + public Expression eval(Expression expression) { + + if (!expression.isConstant() || expression instanceof AggregateFunction) { + return expression; + } + + String fnName = null; + DataType[] args = null; + if (expression instanceof BinaryArithmetic) { + BinaryArithmetic arithmetic = (BinaryArithmetic) expression; + fnName = arithmetic.getLegacyOperator().getName(); + args = new DataType[]{arithmetic.left().getDataType(), arithmetic.right().getDataType()}; + } else if (expression instanceof TimestampArithmetic) { + TimestampArithmetic arithmetic = (TimestampArithmetic) expression; + fnName = arithmetic.getFuncName(); + args = new DataType[]{arithmetic.left().getDataType(), arithmetic.right().getDataType()}; + } + + if ((Env.getCurrentEnv().isNullResultWithOneNullParamFunction(fnName))) { + for (Expression e : expression.children()) { + if (e instanceof NullLiteral) { + return Literal.of(null); + } + } + } + + return invoke(expression, fnName, args); + } + + private Expression invoke(Expression expression, String fnName, DataType[] args) { + FunctionSignature signature = new FunctionSignature(fnName, args, null); + FunctionInvoker invoker = getFunction(signature); + if (invoker != null) { + try { + return invoker.invoke(expression.children()); + } catch (AnalysisException e) { + return expression; + } + } + return expression; + } + + private FunctionInvoker getFunction(FunctionSignature signature) { + Collection functionInvokers = functions.get(signature.getName()); + if (functionInvokers == null) { + return null; + } + for (FunctionInvoker candidate : functionInvokers) { + DataType[] candidateTypes = candidate.getSignature().getArgTypes(); + DataType[] expectedTypes = signature.getArgTypes(); + + if (candidateTypes.length != expectedTypes.length) { + continue; + } + boolean match = true; + for (int i = 0; i < candidateTypes.length; i++) { + if (!candidateTypes[i].equals(expectedTypes[i])) { + match = false; + break; + } + } + if (match) { + return candidate; + } + } + return null; + } + + private void registerFunctions() { + if (functions != null) { + return; + } + ImmutableMultimap.Builder mapBuilder = + new ImmutableMultimap.Builder(); + Class clazz = ExecutableFunctions.class; + for (Method method : clazz.getDeclaredMethods()) { + ExecFunctionList annotationList = method.getAnnotation(ExecFunctionList.class); + if (annotationList != null) { + for (ExecFunction f : annotationList.value()) { + registerFEFunction(mapBuilder, method, f); + } + } + registerFEFunction(mapBuilder, method, method.getAnnotation(ExecFunction.class)); + } + this.functions = mapBuilder.build(); + } + + private void registerFEFunction(ImmutableMultimap.Builder mapBuilder, + Method method, ExecFunction annotation) { + if (annotation != null) { + String name = annotation.name(); + DataType returnType = DataType.convertFromString(annotation.returnType()); + List argTypes = new ArrayList<>(); + for (String type : annotation.argTypes()) { + argTypes.add(DataType.convertFromString(type)); + } + FunctionSignature signature = new FunctionSignature(name, + argTypes.toArray(new DataType[argTypes.size()]), returnType); + mapBuilder.put(name, new FunctionInvoker(method, signature)); + } + } + + /** + * function invoker. + */ + public static class FunctionInvoker { + private final Method method; + private final FunctionSignature signature; + + public FunctionInvoker(Method method, FunctionSignature signature) { + this.method = method; + this.signature = signature; + } + + public Method getMethod() { + return method; + } + + public FunctionSignature getSignature() { + return signature; + } + + public Literal invoke(List args) throws AnalysisException { + try { + return (Literal) method.invoke(null, args.toArray()); + } catch (InvocationTargetException | IllegalAccessException | IllegalArgumentException e) { + throw new AnalysisException(e.getLocalizedMessage()); + } + } + } + + /** + * function signature. + */ + public static class FunctionSignature { + private final String name; + private final DataType[] argTypes; + private final DataType returnType; + + public FunctionSignature(String name, DataType[] argTypes, DataType returnType) { + this.name = name; + this.argTypes = argTypes; + this.returnType = returnType; + } + + public DataType[] getArgTypes() { + return argTypes; + } + + public DataType getReturnType() { + return returnType; + } + + public String getName() { + return name; + } + } + +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThan.java index 6f4fc44a6b..dc673f5318 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThan.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.nereids.exceptions.UnboundException; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import com.google.common.base.Preconditions; @@ -27,7 +28,7 @@ import java.util.List; /** * Greater than expression: a > b. */ -public class GreaterThan extends ComparisonPredicate { +public class GreaterThan extends ComparisonPredicate implements PropagateNullable { /** * Constructor of Greater Than ComparisonPredicate. * diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThanEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThanEqual.java index acda6a8219..0e1b5a45bc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThanEqual.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThanEqual.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.nereids.exceptions.UnboundException; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import com.google.common.base.Preconditions; @@ -27,7 +28,7 @@ import java.util.List; /** * Greater than and equal expression: a >= b. */ -public class GreaterThanEqual extends ComparisonPredicate { +public class GreaterThanEqual extends ComparisonPredicate implements PropagateNullable { /** * Constructor of Greater Than And Equal. * diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/IsNull.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/IsNull.java new file mode 100644 index 0000000000..57983cc130 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/IsNull.java @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions; + +import org.apache.doris.nereids.exceptions.UnboundException; +import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; + +import com.google.common.base.Preconditions; + +import java.util.List; +import java.util.Objects; + +/** + * expr is null predicate. + */ +public class IsNull extends Expression implements UnaryExpression { + + public IsNull(Expression e) { + super(e); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitIsNull(this, context); + } + + @Override + public boolean nullable() { + return false; + } + + @Override + public IsNull withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new IsNull(children.get(0)); + } + + @Override + public String toSql() throws UnboundException { + return child().toSql() + " IS NULL"; + } + + @Override + public String toString() { + return toSql(); + } + + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; + } + IsNull other = (IsNull) o; + return Objects.equals(child(), other.child()); + } + + @Override + public int hashCode() { + return child().hashCode(); + } + +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThan.java index 0382c5b151..c682ef4678 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThan.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.nereids.exceptions.UnboundException; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import com.google.common.base.Preconditions; @@ -27,7 +28,7 @@ import java.util.List; /** * Less than expression: a < b. */ -public class LessThan extends ComparisonPredicate { +public class LessThan extends ComparisonPredicate implements PropagateNullable { /** * Constructor of Less Than Comparison Predicate. * diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThanEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThanEqual.java index af4fe1524d..92389869de 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThanEqual.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThanEqual.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.nereids.exceptions.UnboundException; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import com.google.common.base.Preconditions; @@ -27,7 +28,7 @@ import java.util.List; /** * Less than and equal expression: a <= b. */ -public class LessThanEqual extends ComparisonPredicate { +public class LessThanEqual extends ComparisonPredicate implements PropagateNullable { /** * Constructor of Less Than And Equal. * diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java index b1a2eee13b..a1da66c75d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.nereids.exceptions.UnboundException; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.typecoercion.ExpectsInputTypes; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; @@ -34,7 +35,7 @@ import java.util.Objects; /** * Not expression: not a. */ -public class Not extends Expression implements UnaryExpression, ExpectsInputTypes { +public class Not extends Expression implements UnaryExpression, ExpectsInputTypes, PropagateNullable { public static final List EXPECTS_INPUT_TYPES = ImmutableList.of(BooleanType.INSTANCE); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java index dc2ecbed42..b8a0a50b90 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java @@ -64,4 +64,5 @@ public class NullSafeEqual extends ComparisonPredicate { public ComparisonPredicate commute() { return new NullSafeEqual(right(), left()); } + } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java index 6634933318..64f5fd8f58 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java @@ -79,6 +79,10 @@ public class SlotReference extends Slot { this.column = column; } + public static SlotReference of(String name, DataType type) { + return new SlotReference(name, type); + } + public static SlotReference fromColumn(Column column, List qualifier) { DataType dataType = DataType.convertFromCatalogDataType(column.getType()); return new SlotReference(NamedExpressionUtil.newExprId(), column.getName(), dataType, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TimestampArithmetic.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TimestampArithmetic.java index ffe40448a3..c3365050b6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TimestampArithmetic.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/TimestampArithmetic.java @@ -22,16 +22,21 @@ import org.apache.doris.nereids.exceptions.UnboundException; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.literal.IntervalLiteral.TimeUnit; import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; +import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DateTimeType; import org.apache.doris.nereids.types.DateType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.coercion.AbstractDataType; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import java.util.List; +import java.util.Objects; /** * Describes the addition and subtraction of time units from timestamps. @@ -40,7 +45,14 @@ import java.util.List; * Example: '1996-01-01' + INTERVAL '3' month; * TODO: we need to rethink this, and maybe need to add a new type of Interval then implement IntervalLiteral as others */ -public class TimestampArithmetic extends Expression implements BinaryExpression, PropagateNullable { +public class TimestampArithmetic extends Expression implements BinaryExpression, ImplicitCastInputTypes, + PropagateNullable { + + private static final List EXPECTED_INPUT_TYPES = ImmutableList.of( + DateTimeType.INSTANCE, + IntegerType.INSTANCE + ); + private static final Logger LOG = LogManager.getLogger(TimestampArithmetic.class); private final String funcName; private final boolean intervalFirst; @@ -149,4 +161,22 @@ public class TimestampArithmetic extends Expression implements BinaryExpression, } return strBuilder.toString(); } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TimestampArithmetic other = (TimestampArithmetic) o; + return Objects.equals(funcName, other.funcName) && Objects.equals(timeUnit, other.timeUnit) + && Objects.equals(left(), other.left()) && Objects.equals(right(), other.right()); + } + + @Override + public List expectedInputTypes() { + return EXPECTED_INPUT_TYPES; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java index be0c34c6c4..1ef30448bb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java @@ -54,7 +54,7 @@ public class WhenClause extends Expression implements BinaryExpression, ExpectsI @Override public String toSql() { - return "WHEN " + left().toSql() + " THEN " + right().toSql(); + return " WHEN " + left().toSql() + " THEN " + right().toSql(); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java new file mode 100644 index 0000000000..4c65d19883 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java @@ -0,0 +1,243 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions; + +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.trees.expressions.ExecFunction; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.math.BigDecimal; + +/** + * functions that can be executed in FE. + */ +public class ExecutableFunctions { + public static final ExecutableFunctions INSTANCE = new ExecutableFunctions(); + + private static final Logger LOG = LogManager.getLogger(ExecutableFunctions.class); + + /** + * Executable arithmetic functions + */ + + @ExecFunction(name = "add", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT") + public static TinyIntLiteral addTinyint(TinyIntLiteral first, TinyIntLiteral second) { + byte result = (byte) Math.addExact(first.getValue(), second.getValue()); + return new TinyIntLiteral(result); + } + + @ExecFunction(name = "add", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT") + public static SmallIntLiteral addSmallint(SmallIntLiteral first, SmallIntLiteral second) { + short result = (short) Math.addExact(first.getValue(), second.getValue()); + return new SmallIntLiteral(result); + } + + @ExecFunction(name = "add", argTypes = {"INT", "INT"}, returnType = "INT") + public static IntegerLiteral addInt(IntegerLiteral first, IntegerLiteral second) { + int result = Math.addExact(first.getValue(), second.getValue()); + return new IntegerLiteral(result); + } + + @ExecFunction(name = "add", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT") + public static BigIntLiteral addBigint(BigIntLiteral first, BigIntLiteral second) { + long result = Math.addExact(first.getValue(), second.getValue()); + return new BigIntLiteral(result); + } + + @ExecFunction(name = "add", argTypes = {"DOUBLE", "DOUBLE"}, returnType = "DOUBLE") + public static DoubleLiteral addDouble(DoubleLiteral first, DoubleLiteral second) { + double result = first.getValue() + second.getValue(); + return new DoubleLiteral(result); + } + + @ExecFunction(name = "add", argTypes = {"DECIMAL", "DECIMAL"}, returnType = "DECIMAL") + public static DecimalLiteral addDecimal(DecimalLiteral first, DecimalLiteral second) { + BigDecimal result = first.getValue().add(second.getValue()); + return new DecimalLiteral(result); + } + + @ExecFunction(name = "subtract", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT") + public static TinyIntLiteral subtractTinyint(TinyIntLiteral first, TinyIntLiteral second) { + byte result = (byte) Math.subtractExact(first.getValue(), second.getValue()); + return new TinyIntLiteral(result); + } + + @ExecFunction(name = "subtract", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT") + public static SmallIntLiteral subtractSmallint(SmallIntLiteral first, SmallIntLiteral second) { + short result = (short) Math.subtractExact(first.getValue(), second.getValue()); + return new SmallIntLiteral(result); + } + + @ExecFunction(name = "subtract", argTypes = {"INT", "INT"}, returnType = "INT") + public static IntegerLiteral subtractInt(IntegerLiteral first, IntegerLiteral second) { + int result = Math.subtractExact(first.getValue(), second.getValue()); + return new IntegerLiteral(result); + } + + @ExecFunction(name = "subtract", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT") + public static BigIntLiteral subtractBigint(BigIntLiteral first, BigIntLiteral second) { + long result = Math.subtractExact(first.getValue(), second.getValue()); + return new BigIntLiteral(result); + } + + @ExecFunction(name = "subtract", argTypes = {"DOUBLE", "DOUBLE"}, returnType = "DOUBLE") + public static DoubleLiteral subtractDouble(DoubleLiteral first, DoubleLiteral second) { + double result = first.getValue() - second.getValue(); + return new DoubleLiteral(result); + } + + @ExecFunction(name = "subtract", argTypes = {"DECIMAL", "DECIMAL"}, returnType = "DECIMAL") + public static DecimalLiteral subtractDecimal(DecimalLiteral first, DecimalLiteral second) { + BigDecimal result = first.getValue().subtract(second.getValue()); + return new DecimalLiteral(result); + } + + @ExecFunction(name = "multiply", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT") + public static TinyIntLiteral multiplyTinyint(TinyIntLiteral first, TinyIntLiteral second) { + byte result = (byte) Math.multiplyExact(first.getValue(), second.getValue()); + return new TinyIntLiteral(result); + } + + @ExecFunction(name = "multiply", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT") + public static SmallIntLiteral multiplySmallint(SmallIntLiteral first, SmallIntLiteral second) { + short result = (short) Math.multiplyExact(first.getValue(), second.getValue()); + return new SmallIntLiteral(result); + } + + @ExecFunction(name = "multiply", argTypes = {"INT", "INT"}, returnType = "INT") + public static IntegerLiteral multiplyInt(IntegerLiteral first, IntegerLiteral second) { + int result = Math.multiplyExact(first.getValue(), second.getValue()); + return new IntegerLiteral(result); + } + + @ExecFunction(name = "multiply", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT") + public static BigIntLiteral multiplyBigint(BigIntLiteral first, BigIntLiteral second) { + long result = Math.multiplyExact(first.getValue(), second.getValue()); + return new BigIntLiteral(result); + } + + @ExecFunction(name = "multiply", argTypes = {"DOUBLE", "DOUBLE"}, returnType = "DOUBLE") + public static DoubleLiteral multiplyDouble(DoubleLiteral first, DoubleLiteral second) { + double result = first.getValue() * second.getValue(); + return new DoubleLiteral(result); + } + + @ExecFunction(name = "multiply", argTypes = {"DECIMAL", "DECIMAL"}, returnType = "DECIMAL") + public static DecimalLiteral multiplyDecimal(DecimalLiteral first, DecimalLiteral second) { + BigDecimal result = first.getValue().multiply(second.getValue()); + return new DecimalLiteral(result); + } + + @ExecFunction(name = "divide", argTypes = {"DOUBLE", "DOUBLE"}, returnType = "DOUBLE") + public static DoubleLiteral divideDouble(DoubleLiteral first, DoubleLiteral second) { + if (second.getValue() == 0.0) { + return null; + } + double result = first.getValue() / second.getValue(); + return new DoubleLiteral(result); + } + + @ExecFunction(name = "divide", argTypes = {"DECIMAL", "DECIMAL"}, returnType = "DECIMAL") + public static DecimalLiteral divideDecimal(DecimalLiteral first, DecimalLiteral second) { + if (first.getValue().compareTo(BigDecimal.ZERO) == 0) { + return null; + } + BigDecimal result = first.getValue().divide(second.getValue()); + return new DecimalLiteral(result); + } + + @ExecFunction(name = "date_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME") + public static DateTimeLiteral dateSub(DateTimeLiteral date, IntegerLiteral day) throws AnalysisException { + return dateAdd(date, new IntegerLiteral(-day.getValue())); + } + + @ExecFunction(name = "date_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME") + public static DateTimeLiteral dateAdd(DateTimeLiteral date, IntegerLiteral day) throws AnalysisException { + return daysAdd(date, day); + } + + @ExecFunction(name = "years_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME") + public static DateTimeLiteral yearsAdd(DateTimeLiteral date, IntegerLiteral year) throws AnalysisException { + return date.plusYears(year.getValue()); + } + + @ExecFunction(name = "months_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME") + public static DateTimeLiteral monthsAdd(DateTimeLiteral date, IntegerLiteral month) throws AnalysisException { + return date.plusMonths(month.getValue()); + } + + @ExecFunction(name = "days_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME") + public static DateTimeLiteral daysAdd(DateTimeLiteral date, IntegerLiteral day) throws AnalysisException { + return date.plusDays(day.getValue()); + } + + @ExecFunction(name = "hours_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME") + public static DateTimeLiteral hoursAdd(DateTimeLiteral date, IntegerLiteral hour) throws AnalysisException { + return date.plusHours(hour.getValue()); + } + + @ExecFunction(name = "minutes_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME") + public static DateTimeLiteral minutesAdd(DateTimeLiteral date, IntegerLiteral minute) throws AnalysisException { + return date.plusMinutes(minute.getValue()); + } + + @ExecFunction(name = "seconds_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME") + public static DateTimeLiteral secondsAdd(DateTimeLiteral date, IntegerLiteral second) throws AnalysisException { + return date.plusSeconds(second.getValue()); + } + + @ExecFunction(name = "years_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME") + public static DateTimeLiteral yearsSub(DateTimeLiteral date, IntegerLiteral year) throws AnalysisException { + return yearsAdd(date, new IntegerLiteral(-year.getValue())); + } + + @ExecFunction(name = "months_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME") + public static DateTimeLiteral monthsSub(DateTimeLiteral date, IntegerLiteral month) throws AnalysisException { + return monthsAdd(date, new IntegerLiteral(-month.getValue())); + } + + @ExecFunction(name = "days_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME") + public static DateTimeLiteral daysSub(DateTimeLiteral date, IntegerLiteral day) throws AnalysisException { + return daysAdd(date, new IntegerLiteral(-day.getValue())); + } + + @ExecFunction(name = "hours_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME") + public static DateTimeLiteral hoursSub(DateTimeLiteral date, IntegerLiteral hour) throws AnalysisException { + return hoursAdd(date, new IntegerLiteral(-hour.getValue())); + } + + @ExecFunction(name = "minutes_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME") + public static DateTimeLiteral minutesSub(DateTimeLiteral date, IntegerLiteral minute) throws AnalysisException { + return minutesAdd(date, new IntegerLiteral(-minute.getValue())); + } + + @ExecFunction(name = "seconds_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME") + public static DateTimeLiteral secondsSub(DateTimeLiteral date, IntegerLiteral second) throws AnalysisException { + return secondsAdd(date, new IntegerLiteral(-second.getValue())); + } + +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Substring.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Substring.java index f1e38db5c5..ea751658c5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Substring.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Substring.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.IntegerType; @@ -50,7 +51,7 @@ public class Substring extends ScalarFunction implements ImplicitCastInputTypes, } public Substring(Expression str, Expression pos) { - super("substring", str, pos); + super("substring", str, pos, Literal.of(Integer.MAX_VALUE)); } public Expression getSource() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateLiteral.java index e7a4b98395..68243b3cea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateLiteral.java @@ -19,14 +19,11 @@ package org.apache.doris.nereids.trees.expressions.literal; import org.apache.doris.analysis.LiteralExpr; import org.apache.doris.nereids.exceptions.AnalysisException; -import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.types.DateTimeType; import org.apache.doris.nereids.types.DateType; import org.apache.doris.nereids.util.DateUtils; -import com.google.common.base.Preconditions; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.joda.time.LocalDateTime; @@ -107,32 +104,6 @@ public class DateLiteral extends Literal { } } - @Override - protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException { - if (getDataType().equals(targetType)) { - return this; - } - if (targetType.isDate()) { - if (getDataType().equals(targetType)) { - return this; - } - if (targetType.equals(DateType.INSTANCE)) { - return new DateLiteral(this.year, this.month, this.day); - } else if (targetType.equals(DateTimeType.INSTANCE)) { - return new DateTimeLiteral(this.year, this.month, this.day, 0, 0, 0); - } else { - throw new AnalysisException("Error date literal type"); - } - } - //todo other target type cast - return this; - } - - public DateLiteral withDataType(DataType type) { - Preconditions.checkArgument(type.isDate() || type.isDateTime()); - return new DateLiteral(this, type); - } - @Override public R accept(ExpressionVisitor visitor, C context) { return visitor.visitDateLiteral(this, context); @@ -153,6 +124,11 @@ public class DateLiteral extends Literal { return String.format("%04d-%02d-%02d", year, month, day); } + @Override + public String getStringValue() { + return String.format("%04d-%02d-%02d", year, month, day); + } + @Override public LiteralExpr toLegacyLiteral() { return new org.apache.doris.analysis.DateLiteral(year, month, day); @@ -169,6 +145,11 @@ public class DateLiteral extends Literal { public long getDay() { return day; } + + public DateLiteral plusDays(int days) { + LocalDateTime dateTime = LocalDateTime.parse(getStringValue(), DATE_FORMATTER).plusDays(days); + return new DateLiteral(dateTime.getYear(), dateTime.getMonthOfYear(), dateTime.getDayOfMonth()); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateTimeLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateTimeLiteral.java index 60e84b66d8..5b62296057 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateTimeLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateTimeLiteral.java @@ -19,11 +19,8 @@ package org.apache.doris.nereids.trees.expressions.literal; import org.apache.doris.analysis.LiteralExpr; import org.apache.doris.nereids.exceptions.AnalysisException; -import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; -import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DateTimeType; -import org.apache.doris.nereids.types.DateType; import org.apache.doris.nereids.util.DateUtils; import org.apache.logging.log4j.LogManager; @@ -31,6 +28,8 @@ import org.apache.logging.log4j.Logger; import org.joda.time.LocalDateTime; import org.joda.time.format.DateTimeFormatter; +import java.util.Objects; + /** * date time literal. */ @@ -39,7 +38,8 @@ public class DateTimeLiteral extends DateLiteral { private static final int DATETIME_TO_MINUTE_STRING_LENGTH = 16; private static final int DATETIME_TO_HOUR_STRING_LENGTH = 13; - + private static final int DATETIME_DEFAULT_STRING_LENGTH = 10; + private static DateTimeFormatter DATE_TIME_DEFAULT_FORMATTER = null; private static DateTimeFormatter DATE_TIME_FORMATTER = null; private static DateTimeFormatter DATE_TIME_FORMATTER_TO_HOUR = null; private static DateTimeFormatter DATE_TIME_FORMATTER_TO_MINUTE = null; @@ -55,6 +55,7 @@ public class DateTimeLiteral extends DateLiteral { DATE_TIME_FORMATTER_TO_HOUR = DateUtils.formatBuilder("%Y-%m-%d %H").toFormatter(); DATE_TIME_FORMATTER_TO_MINUTE = DateUtils.formatBuilder("%Y-%m-%d %H:%i").toFormatter(); DATE_TIME_FORMATTER_TWO_DIGIT = DateUtils.formatBuilder("%y-%m-%d %H:%i:%s").toFormatter(); + DATE_TIME_DEFAULT_FORMATTER = DateUtils.formatBuilder("%Y-%m-%d").toFormatter(); } catch (AnalysisException e) { LOG.error("invalid date format", e); System.exit(-1); @@ -89,6 +90,8 @@ public class DateTimeLiteral extends DateLiteral { dateTime = DATE_TIME_FORMATTER_TO_MINUTE.parseLocalDateTime(s); } else if (s.length() == DATETIME_TO_HOUR_STRING_LENGTH) { dateTime = DATE_TIME_FORMATTER_TO_HOUR.parseLocalDateTime(s); + } else if (s.length() == DATETIME_DEFAULT_STRING_LENGTH) { + dateTime = DATE_TIME_DEFAULT_FORMATTER.parseLocalDateTime(s); } else { dateTime = DATE_TIME_FORMATTER.parseLocalDateTime(s); } @@ -104,27 +107,6 @@ public class DateTimeLiteral extends DateLiteral { } } - @Override - protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException { - if (getDataType().equals(targetType)) { - return this; - } - if (targetType.isDate()) { - if (getDataType().equals(targetType)) { - return this; - } - if (targetType.equals(DateType.INSTANCE)) { - return new DateLiteral(this.year, this.month, this.day); - } else if (targetType.equals(DateTimeType.INSTANCE)) { - return new DateTimeLiteral(this.year, this.month, this.day, this.hour, this.minute, this.second); - } else { - throw new AnalysisException("Error date literal type"); - } - } - //todo other target type cast - return this; - } - @Override public R accept(ExpressionVisitor visitor, C context) { return visitor.visitDateTimeLiteral(this, context); @@ -145,6 +127,47 @@ public class DateTimeLiteral extends DateLiteral { return String.format("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second); } + @Override + public String getStringValue() { + return String.format("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second); + } + + public DateTimeLiteral plusYears(int years) { + LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusYears(years); + return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(), + d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute()); + } + + public DateTimeLiteral plusMonths(int months) { + LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusMonths(months); + return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(), + d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute()); + } + + public DateTimeLiteral plusDays(int days) { + LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusDays(days); + return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(), + d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute()); + } + + public DateTimeLiteral plusHours(int hours) { + LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusHours(hours); + return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(), + d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute()); + } + + public DateTimeLiteral plusMinutes(int minutes) { + LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusMinutes(minutes); + return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(), + d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute()); + } + + public DateTimeLiteral plusSeconds(int seconds) { + LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusSeconds(seconds); + return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(), + d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute()); + } + @Override public LiteralExpr toLegacyLiteral() { return new org.apache.doris.analysis.DateLiteral(year, month, day, hour, minute, second); @@ -161,4 +184,16 @@ public class DateTimeLiteral extends DateLiteral { public long getSecond() { return second; } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DateTimeLiteral other = (DateTimeLiteral) o; + return Objects.equals(getValue(), other.getValue()); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java index 6ecb7032da..d88dbacac9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java @@ -18,13 +18,20 @@ package org.apache.doris.nereids.trees.expressions.literal; import org.apache.doris.analysis.LiteralExpr; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.exceptions.UnboundException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.shape.LeafExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.CharType; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.StringType; +import org.apache.commons.lang3.StringUtils; + +import java.math.BigDecimal; import java.math.BigInteger; +import java.util.Locale; import java.util.Objects; /** @@ -51,17 +58,21 @@ public abstract class Literal extends Expression implements LeafExpression { if (value == null) { return new NullLiteral(); } else if (value instanceof Byte) { - return new TinyIntLiteral(((Byte) value).byteValue()); + return new TinyIntLiteral((byte) value); } else if (value instanceof Short) { - return new SmallIntLiteral(((Short) value).shortValue()); + return new SmallIntLiteral((short) value); } else if (value instanceof Integer) { - return new IntegerLiteral((Integer) value); + return new IntegerLiteral((int) value); } else if (value instanceof Long) { - return new BigIntLiteral(((Long) value).longValue()); + return new BigIntLiteral((long) value); } else if (value instanceof BigInteger) { return new LargeIntLiteral((BigInteger) value); + } else if (value instanceof Float) { + return new FloatLiteral((float) value); + } else if (value instanceof Double) { + return new DoubleLiteral((double) value); } else if (value instanceof Boolean) { - return BooleanLiteral.of((Boolean) value); + return BooleanLiteral.of((boolean) value); } else if (value instanceof String) { return new StringLiteral((String) value); } else { @@ -71,6 +82,10 @@ public abstract class Literal extends Expression implements LeafExpression { public abstract Object getValue(); + public String getStringValue() { + return String.valueOf(getValue()); + } + @Override public DataType getDataType() throws UnboundException { return dataType; @@ -91,6 +106,89 @@ public abstract class Literal extends Expression implements LeafExpression { return visitor.visitLiteral(this, context); } + /** + * literal expr compare. + */ + public int compareTo(Literal other) { + if (isNullLiteral() && other.isNullLiteral()) { + return 0; + } else if (isNullLiteral() || other.isNullLiteral()) { + return isNullLiteral() ? -1 : 1; + } + + DataType oType = other.getDataType(); + DataType type = getDataType(); + if (!type.equals(oType)) { + throw new RuntimeException("data type not equal!"); + } else if (type.isBooleanType()) { + return Boolean.compare((boolean) getValue(), (boolean) other.getValue()); + } else if (type.isTinyIntType()) { + return Byte.compare((byte) getValue(), (byte) other.getValue()); + } else if (type.isSmallIntType()) { + return Short.compare((short) getValue(), (short) other.getValue()); + } else if (type.isIntType()) { + return Integer.compare((int) getValue(), (int) other.getValue()); + } else if (type.isBigIntType()) { + return Long.compare((long) getValue(), (long) other.getValue()); + } else if (type.isLargeIntType()) { + return ((BigInteger) getValue()).compareTo((BigInteger) other.getValue()); + } else if (type.isFloatType()) { + return Float.compare((float) getValue(), (float) other.getValue()); + } else if (type.isDoubleType()) { + return Double.compare((double) getValue(), (double) other.getValue()); + } else if (type.isDecimalType()) { + return Long.compare((Long) getValue(), (Long) other.getValue()); + } else if (type.isDateType()) { + // todo process date + } else if (type.isDecimalType()) { + return ((BigDecimal) getValue()).compareTo((BigDecimal) other.getValue()); + } else if (type instanceof StringType) { + return StringUtils.compare((String) getValue(), (String) other.getValue()); + } + return -1; + } + + @Override + protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException { + String desc = getStringValue(); + if (targetType.isBooleanType()) { + if ("0".equals(desc) || "false".equals(desc.toLowerCase(Locale.ROOT))) { + return Literal.of(false); + } + if ("1".equals(desc) || "true".equals(desc.toLowerCase(Locale.ROOT))) { + return Literal.of(true); + } + } + if (targetType.isTinyIntType()) { + return Literal.of(Double.valueOf(desc).byteValue()); + } else if (targetType.isSmallIntType()) { + return Literal.of(Double.valueOf(desc).shortValue()); + } else if (targetType.isIntType()) { + return Literal.of(Double.valueOf(desc).intValue()); + } else if (targetType.isBigIntType()) { + return Literal.of(Double.valueOf(desc).longValue()); + } else if (targetType.isLargeIntType()) { + return Literal.of(new BigInteger(desc)); + } else if (targetType.isFloatType()) { + return Literal.of(Float.parseFloat(desc)); + } else if (targetType.isDoubleType()) { + return Literal.of(Double.parseDouble(desc)); + } else if (targetType.isCharType()) { + return new CharLiteral(desc, ((CharType) targetType).getLen()); + } else if (targetType.isVarcharType()) { + return new VarcharLiteral(desc, desc.length()); + } else if (targetType.isStringType()) { + return Literal.of(desc); + } else if (targetType.isDate()) { + return new DateLiteral(desc); + } else if (targetType.isDateTime()) { + return new DateTimeLiteral(desc); + } else if (targetType.isDecimalType()) { + return new DecimalLiteral(BigDecimal.valueOf(Double.parseDouble(desc))); + } + throw new AnalysisException("no support cast!"); + } + @Override public boolean equals(Object o) { if (this == o) { @@ -114,4 +212,5 @@ public abstract class Literal extends Expression implements LeafExpression { } public abstract LiteralExpr toLegacyLiteral(); + } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/StringLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/StringLiteral.java index bc19a7f506..b7c8a06e9f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/StringLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/StringLiteral.java @@ -18,10 +18,7 @@ package org.apache.doris.nereids.trees.expressions.literal; import org.apache.doris.analysis.LiteralExpr; -import org.apache.doris.nereids.exceptions.AnalysisException; -import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; -import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.StringType; /** @@ -56,30 +53,6 @@ public class StringLiteral extends Literal { return new org.apache.doris.analysis.StringLiteral(value); } - @Override - protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException { - if (getDataType().equals(targetType)) { - return this; - } - if (targetType.isDateType()) { - return convertToDate(targetType); - } else if (targetType.isIntType()) { - return new IntegerLiteral(Integer.parseInt(value)); - } - //todo other target type cast - return this; - } - - private DateLiteral convertToDate(DataType targetType) throws AnalysisException { - DateLiteral dateLiteral = null; - if (targetType.isDate()) { - dateLiteral = new DateLiteral(value); - } else if (targetType.isDateTime()) { - dateLiteral = new DateTimeLiteral(value); - } - return dateLiteral; - } - @Override public String toString() { return "'" + value + "'"; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/VarcharLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/VarcharLiteral.java index 393214aa18..a51e0a815b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/VarcharLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/VarcharLiteral.java @@ -20,7 +20,6 @@ package org.apache.doris.nereids.trees.expressions.literal; import org.apache.doris.analysis.LiteralExpr; import org.apache.doris.analysis.StringLiteral; import org.apache.doris.nereids.exceptions.AnalysisException; -import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.VarcharType; @@ -68,20 +67,6 @@ public class VarcharLiteral extends Literal { return "'" + value + "'"; } - // Temporary way to process type coercion in TimestampArithmetic, should be replaced by TypeCoercion rule. - @Override - protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException { - if (getDataType().equals(targetType)) { - return this; - } - if (targetType.isDateType()) { - return convertToDate(targetType); - } else if (targetType.isIntType()) { - return new IntegerLiteral(Integer.parseInt(value)); - } - return this; - } - private DateLiteral convertToDate(DataType targetType) throws AnalysisException { DateLiteral dateLiteral = null; if (targetType.isDate()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java index 873bad058f..424e4ce291 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java @@ -40,6 +40,7 @@ import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.InSubquery; +import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; import org.apache.doris.nereids.trees.expressions.Like; @@ -277,6 +278,10 @@ public abstract class ExpressionVisitor { return visit(inPredicate, context); } + public R visitIsNull(IsNull isNull, C context) { + return visit(isNull, context); + } + public R visitInSubquery(InSubquery in, C context) { return visitSubqueryExpr(in, context); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java index 6a2bfa8365..09ee991f39 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java @@ -165,6 +165,12 @@ public abstract class DataType implements AbstractDataType { return FloatType.INSTANCE; case "double": return DoubleType.INSTANCE; + case "decimal": + return DecimalType.SYSTEM_DEFAULT; + case "char": + return CharType.INSTANCE; + case "varchar": + return VarcharType.SYSTEM_DEFAULT; case "text": case "string": return StringType.INSTANCE; @@ -303,18 +309,50 @@ public abstract class DataType implements AbstractDataType { return 0; } - public boolean isDate() { - return this instanceof DateType; + public boolean isBooleanType() { + return this instanceof BooleanType; + } + + public boolean isTinyIntType() { + return this instanceof TinyIntType; + } + + public boolean isSmallIntType() { + return this instanceof SmallIntType; } public boolean isIntType() { return this instanceof IntegerType; } + public boolean isBigIntType() { + return this instanceof BigIntType; + } + + public boolean isLargeIntType() { + return this instanceof LargeIntType; + } + + public boolean isFloatType() { + return this instanceof FloatType; + } + + public boolean isDoubleType() { + return this instanceof DoubleType; + } + + public boolean isDecimalType() { + return this instanceof DecimalType; + } + public boolean isDateTime() { return this instanceof DateTimeType; } + public boolean isDate() { + return this instanceof DateType; + } + public boolean isDateType() { return isDate() || isDateTime(); } @@ -327,6 +365,14 @@ public abstract class DataType implements AbstractDataType { return this instanceof NumericType; } + public boolean isCharType() { + return this instanceof CharType; + } + + public boolean isVarcharType() { + return this instanceof VarcharType; + } + public boolean isStringType() { return this instanceof CharacterType; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index b084bafa1a..65084986f5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -26,6 +26,8 @@ import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import com.google.common.base.Preconditions; @@ -258,4 +260,20 @@ public class ExpressionUtils { } return builder.build(); } + + public static boolean isAllLiteral(Expression... children) { + return Arrays.stream(children).allMatch(c -> c instanceof Literal); + } + + public static boolean isAllLiteral(List children) { + return children.stream().allMatch(c -> c instanceof Literal); + } + + public static boolean hasNullLiteral(List children) { + return children.stream().anyMatch(c -> c instanceof NullLiteral); + } + + public static boolean isAllNullLiteral(List children) { + return children.stream().allMatch(c -> c instanceof NullLiteral); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index ef190c8ad2..03795e6cdf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.util; import org.apache.doris.nereids.annotation.Developing; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.literal.Literal; @@ -123,6 +124,10 @@ public class TypeCoercionUtils { } else if (expected instanceof DateTimeType) { returnType = DateTimeType.INSTANCE; } + } else if (input.isDate()) { + if (expected instanceof DateTimeType) { + returnType = expected.defaultConcreteType(); + } } if (returnType == null && input instanceof PrimitiveType @@ -202,6 +207,39 @@ public class TypeCoercionUtils { return Optional.ofNullable(tightestCommonType); } + /** + * The type used for arithmetic operations. + */ + public static DataType getNumResultType(DataType type) { + if (type.isTinyIntType() || type.isSmallIntType() || type.isIntType() || type.isBigIntType()) { + return BigIntType.INSTANCE; + } else if (type.isLargeIntType()) { + return LargeIntType.INSTANCE; + } else if (type.isFloatType() || type.isDoubleType() || type.isStringType()) { + return DoubleType.INSTANCE; + } else if (type.isDecimalType()) { + return DecimalType.SYSTEM_DEFAULT; + } else if (type.isNullType()) { + return NullType.INSTANCE; + } + throw new AnalysisException("no found appropriate data type."); + } + + /** + * The common type used by arithmetic operations. + */ + public static DataType findCommonNumericsType(DataType t1, DataType t2) { + if (t1.isDoubleType() || t2.isDoubleType()) { + return DoubleType.INSTANCE; + } else if (t1.isDecimalType() || t2.isDecimalType()) { + return DecimalType.SYSTEM_DEFAULT; + } else if (t1.isLargeIntType() || t2.isLargeIntType()) { + return LargeIntType.INSTANCE; + } else { + return BigIntType.INSTANCE; + } + } + /** * find wider common type for data type list. */ diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java index 8c79c278d5..7aa28fe259 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java @@ -47,7 +47,9 @@ public class SSBJoinReorderTest extends SSBTestBase { "(lo_suppkey = s_suppkey)", "(lo_partkey = p_partkey)" ), - ImmutableList.of("(((c_region = 'AMERICA') AND (s_region = 'AMERICA')) AND ((p_mfgr = 'MFGR#1') OR (p_mfgr = 'MFGR#2')))") + ImmutableList.of("(((CAST(c_region AS STRING) = CAST('AMERICA' AS STRING)) AND (CAST(s_region AS STRING) " + + "= CAST('AMERICA' AS STRING))) AND ((CAST(p_mfgr AS STRING) = CAST('MFGR#1' AS STRING)) " + + "OR (CAST(p_mfgr AS STRING) = CAST('MFGR#2' AS STRING))))") ); } @@ -62,7 +64,10 @@ public class SSBJoinReorderTest extends SSBTestBase { "(lo_partkey = p_partkey)" ), ImmutableList.of( - "((((c_region = 'AMERICA') AND (s_region = 'AMERICA')) AND ((d_year = 1997) OR (d_year = 1998))) AND ((p_mfgr = 'MFGR#1') OR (p_mfgr = 'MFGR#2')))") + "((((CAST(c_region AS STRING) = CAST('AMERICA' AS STRING)) AND (CAST(s_region AS STRING) " + + "= CAST('AMERICA' AS STRING))) AND ((d_year = 1997) OR (d_year = 1998))) " + + "AND ((CAST(p_mfgr AS STRING) = CAST('MFGR#1' AS STRING)) OR (CAST(p_mfgr AS STRING) " + + "= CAST('MFGR#2' AS STRING))))") ); } @@ -76,7 +81,8 @@ public class SSBJoinReorderTest extends SSBTestBase { "(lo_suppkey = s_suppkey)", "(lo_partkey = p_partkey)" ), - ImmutableList.of("(((s_nation = 'UNITED STATES') AND ((d_year = 1997) OR (d_year = 1998))) AND (p_category = 'MFGR#14'))") + ImmutableList.of("(((CAST(s_nation AS STRING) = CAST('UNITED STATES' AS STRING)) AND ((d_year = 1997) " + + "OR (d_year = 1998))) AND (CAST(p_category AS STRING) = CAST('MFGR#14' AS STRING)))") ); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java index d839853f4a..05d256f395 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil; @@ -34,6 +35,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; +import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.util.FieldChecker; import org.apache.doris.nereids.util.PatternMatchSupported; @@ -168,7 +170,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat logicalAggregate( logicalOlapScan() ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) - ).when(FieldChecker.check("predicates", new GreaterThan(sumA2.toSlot(), new TinyIntLiteral((byte) 0)))) + ).when(FieldChecker.check("predicates", new GreaterThan(sumA2.toSlot(), Literal.of(0L)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); NamedExpressionUtil.clear(); @@ -181,7 +183,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat logicalAggregate( logicalOlapScan() ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) - ).when(FieldChecker.check("predicates", new GreaterThan(sumA2.toSlot(), new TinyIntLiteral((byte) 0)))))); + ).when(FieldChecker.check("predicates", new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))); NamedExpressionUtil.clear(); sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING SUM(a2) > 0"; @@ -201,7 +203,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat logicalAggregate( logicalOlapScan() ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))) - ).when(FieldChecker.check("predicates", new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0)))))); + ).when(FieldChecker.check("predicates", new GreaterThan(value.toSlot(), Literal.of(0L)))))); NamedExpressionUtil.clear(); sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING value > 0"; @@ -212,7 +214,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat logicalAggregate( logicalOlapScan() ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))) - ).when(FieldChecker.check("predicates", new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0)))))); + ).when(FieldChecker.check("predicates", new GreaterThan(value.toSlot(), Literal.of(0L)))))); NamedExpressionUtil.clear(); sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0"; @@ -237,7 +239,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat logicalAggregate( logicalOlapScan() ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK))) - ).when(FieldChecker.check("predicates", new GreaterThan(minPK.toSlot(), new TinyIntLiteral((byte) 0)))) + ).when(FieldChecker.check("predicates", new GreaterThan(minPK.toSlot(), Literal.of((byte) 0)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); NamedExpressionUtil.clear(); @@ -250,7 +252,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat logicalAggregate( logicalOlapScan() ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))) - ).when(FieldChecker.check("predicates", new GreaterThan(sumA1A2.toSlot(), new TinyIntLiteral((byte) 0)))))); + ).when(FieldChecker.check("predicates", new GreaterThan(sumA1A2.toSlot(), Literal.of(0L)))))); NamedExpressionUtil.clear(); sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + 3) > 0"; @@ -263,7 +265,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat logicalAggregate( logicalOlapScan() ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23))) - ).when(FieldChecker.check("predicates", new GreaterThan(sumA1A23.toSlot(), new TinyIntLiteral((byte) 0)))) + ).when(FieldChecker.check("predicates", new GreaterThan(sumA1A23.toSlot(), Literal.of(0L)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot())))); NamedExpressionUtil.clear(); @@ -276,7 +278,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat logicalAggregate( logicalOlapScan() ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar))) - ).when(FieldChecker.check("predicates", new GreaterThan(countStar.toSlot(), new TinyIntLiteral((byte) 0)))) + ).when(FieldChecker.check("predicates", new GreaterThan(countStar.toSlot(), Literal.of(0L)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); NamedExpressionUtil.clear(); } @@ -310,7 +312,8 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat ) ) ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1))) - ).when(FieldChecker.check("predicates", new GreaterThan(a1, sumB1.toSlot()))) + ).when(FieldChecker.check("predicates", new GreaterThan(new Cast(a1, BigIntType.INSTANCE), + sumB1.toSlot()))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); NamedExpressionUtil.clear(); } @@ -353,14 +356,14 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat ImmutableList.of("default_cluster:test_having", "t1") ); SlotReference a2 = new SlotReference( - new ExprId(3), "a1", TinyIntType.INSTANCE, true, + new ExprId(3), "a2", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_having", "t1") ); Alias pk1 = new Alias(new ExprId(7), new Add(pk, Literal.of((byte) 1)), "(pk + 1)"); Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)"); Alias pk2 = new Alias(new ExprId(9), new Add(pk, Literal.of((byte) 2)), "(pk + 2)"); Alias sumA1 = new Alias(new ExprId(10), new Sum(a1), "SUM(a1)"); - Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1, false), Literal.of((byte) 1)), "(COUNT(a1) + 1)"); + Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1, false), Literal.of(1L)), "(COUNT(a1) + 1)"); Alias sumA1A2 = new Alias(new ExprId(12), new Sum(new Add(a1, a2)), "SUM((a1 + a2))"); Alias v1 = new Alias(new ExprId(0), new Count(a2, false), "v1"); PlanChecker.from(connectContext).analyze(sql) @@ -382,11 +385,11 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat new And( new And( new GreaterThan(pk.toSlot(), Literal.of((byte) 0)), - new GreaterThan(countA11.toSlot(), Literal.of((byte) 0))), - new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of((byte) 1)), Literal.of((byte) 0))), - new GreaterThan(new Add(v1.toSlot(), Literal.of((byte) 1)), Literal.of((byte) 0)) + new GreaterThan(countA11.toSlot(), Literal.of(0L))), + new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of(1L)), Literal.of(0L))), + new GreaterThan(new Add(v1.toSlot(), Literal.of(1L)), Literal.of(0L)) ), - new GreaterThan(v1.toSlot(), Literal.of((byte) 0)) + new GreaterThan(v1.toSlot(), Literal.of(0L)) ) )) ).when(FieldChecker.check( diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FunctionRegistryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FunctionRegistryTest.java index 8cec091371..b2591c69bf 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FunctionRegistryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FunctionRegistryTest.java @@ -53,7 +53,7 @@ public class FunctionRegistryTest implements PatternMatchSupported { logicalOneRowRelation().when(r -> { Year year = (Year) r.getProjects().get(0).child(0); Assertions.assertEquals("2021-01-01", - ((Literal) year.getArguments().get(0)).getValue()); + ((Literal) year.getArguments().get(0).child(0)).getValue()); return true; }) ); @@ -70,13 +70,13 @@ public class FunctionRegistryTest implements PatternMatchSupported { logicalOneRowRelation().when(r -> { Substring firstSubstring = (Substring) r.getProjects().get(0).child(0); Assertions.assertEquals("abc", ((Literal) firstSubstring.getSource()).getValue()); - Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition()).getValue()); - Assertions.assertEquals((byte) 2, ((Literal) firstSubstring.getLength().get()).getValue()); + Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition().child(0)).getValue()); + Assertions.assertEquals((byte) 2, ((Literal) firstSubstring.getLength().get().child(0)).getValue()); Substring secondSubstring = (Substring) r.getProjects().get(1).child(0); Assertions.assertTrue(secondSubstring.getSource() instanceof Substring); - Assertions.assertEquals((byte) 1, ((Literal) secondSubstring.getPosition()).getValue()); - Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getLength().get()).getValue()); + Assertions.assertEquals((byte) 1, ((Literal) secondSubstring.getPosition().child(0)).getValue()); + Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getLength().get().child(0)).getValue()); return true; }) ); @@ -93,13 +93,13 @@ public class FunctionRegistryTest implements PatternMatchSupported { logicalOneRowRelation().when(r -> { Substring firstSubstring = (Substring) r.getProjects().get(0).child(0); Assertions.assertEquals("abc", ((Literal) firstSubstring.getSource()).getValue()); - Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition()).getValue()); - Assertions.assertFalse(firstSubstring.getLength().isPresent()); + Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition().child(0)).getValue()); + Assertions.assertTrue(firstSubstring.getLength().isPresent()); Substring secondSubstring = (Substring) r.getProjects().get(1).child(0); Assertions.assertEquals("def", ((Literal) secondSubstring.getSource()).getValue()); - Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getPosition()).getValue()); - Assertions.assertEquals((byte) 3, ((Literal) secondSubstring.getLength().get()).getValue()); + Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getPosition().child(0)).getValue()); + Assertions.assertEquals((byte) 3, ((Literal) secondSubstring.getLength().get().child(0)).getValue()); return true; }) ); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java index 2798e05a5b..b8fe4a7c4f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java @@ -40,7 +40,7 @@ public class ExpressionRewriteTest { @Test public void testNotRewrite() { - executor = new ExpressionRuleExecutor(SimplifyNotExprRule.INSTANCE); + executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyNotExprRule.INSTANCE)); assertRewrite("not x", "not x"); assertRewrite("not not x", "x"); @@ -65,7 +65,7 @@ public class ExpressionRewriteTest { @Test public void testNormalizeExpressionRewrite() { - executor = new ExpressionRuleExecutor(NormalizeBinaryPredicatesRule.INSTANCE); + executor = new ExpressionRuleExecutor(ImmutableList.of(NormalizeBinaryPredicatesRule.INSTANCE)); assertRewrite("1 = 1", "1 = 1"); assertRewrite("2 > x", "x < 2"); @@ -77,7 +77,7 @@ public class ExpressionRewriteTest { @Test public void testDistinctPredicatesRewrite() { - executor = new ExpressionRuleExecutor(DistinctPredicatesRule.INSTANCE); + executor = new ExpressionRuleExecutor(ImmutableList.of(DistinctPredicatesRule.INSTANCE)); assertRewrite("a = 1", "a = 1"); assertRewrite("a = 1 and a = 1", "a = 1"); @@ -89,7 +89,7 @@ public class ExpressionRewriteTest { @Test public void testExtractCommonFactorRewrite() { - executor = new ExpressionRuleExecutor(ExtractCommonFactorRule.INSTANCE); + executor = new ExpressionRuleExecutor(ImmutableList.of(ExtractCommonFactorRule.INSTANCE)); assertRewrite("a", "a"); @@ -142,7 +142,8 @@ public class ExpressionRewriteTest { @Test public void testBetweenToCompoundRule() { - executor = new ExpressionRuleExecutor(ImmutableList.of(BetweenToCompoundRule.INSTANCE, SimplifyNotExprRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of(BetweenToCompoundRule.INSTANCE, + SimplifyNotExprRule.INSTANCE)); assertRewrite("a between c and d", "(a >= c) and (a <= d)"); assertRewrite("a not between c and d)", "(a < c) or (a > d)"); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java new file mode 100644 index 0000000000..ec21777e20 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java @@ -0,0 +1,341 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rewrite; + +import org.apache.doris.analysis.ArithmeticExpr.Operator; +import org.apache.doris.nereids.analyzer.UnboundSlot; +import org.apache.doris.nereids.parser.NereidsParser; +import org.apache.doris.nereids.rules.expression.rewrite.rules.FoldConstantRuleOnFE; +import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion; +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; +import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntervalLiteral.TimeUnit; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.BooleanType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DateTimeType; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.TinyIntType; +import org.apache.doris.nereids.types.VarcharType; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Locale; +import java.util.Map; + +public class FoldConstantTest { + + private static final NereidsParser PARSER = new NereidsParser(); + private ExpressionRuleExecutor executor; + + @Test + public void testCaseWhenFold() { + executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE)); + assertRewrite("case when 1 = 2 then 1 when '1' < 2 then 2 else 3 end", "2"); + assertRewrite("case when 1 = 2 then 1 when '1' > 2 then 2 end", "null"); + assertRewrite("case when (1 + 5) / 2 > 2 then 4 when '1' < 2 then 2 else 3 end", "4"); + assertRewrite("case when not 1 = 2 then 1 when '1' > 2 then 2 end", "1"); + assertRewrite("case when 1 = 2 then 1 when 3 in ('1',2 + 8 / 2,3,4) then 2 end", "2"); + assertRewrite("case when TA = 2 then 1 when 3 in ('1',2 + 8 / 2,3,4) then 2 end", "CASE WHEN (TA = 2) THEN 1 ELSE 2 END"); + assertRewrite("case when TA = 2 then 5 when 3 in (2,3,4) then 2 else 4 end", "CASE WHEN (TA = 2) THEN 5 ELSE 2 END"); + assertRewrite("case when TA = 2 then 1 when TB in (2,3,4) then 2 else 4 end", "CASE WHEN (TA = 2) THEN 1 WHEN TB IN (2, 3, 4) THEN 2 ELSE 4 END"); + assertRewrite("case when null = 2 then 1 when 3 in (2,3,4) then 2 else 4 end", "2"); + assertRewrite("case when null = 2 then 1 else 4 end", "4"); + assertRewrite("case when null = 2 then 1 end", "null"); + assertRewrite("case when TA = TB then 1 when TC is null then 2 end", "CASE WHEN (TA = TB) THEN 1 WHEN TC IS NULL THEN 2 ELSE NULL END"); + } + + @Test + public void testInFold() { + executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE)); + assertRewrite("1 in (1,2,3,4)", "true"); + // Type Coercion trans all to string. + assertRewrite("3 in ('1',2 + 8 / 2,3,4)", "true"); + assertRewrite("4 / 2 * 1 - (5/2) in ('1',2 + 8 / 2,3,4)", "false"); + assertRewrite("null in ('1',2 + 8 / 2,3,4)", "null"); + assertRewrite("3 in ('1',null,3,4)", "true"); + assertRewrite("TA in (1,null,3,4)", "TA in (1, null, 3, 4)"); + assertRewrite("IA in (IB,IC,null)", "IA in (IB,IC,null)"); + } + + @Test + public void testLogicalFold() { + executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE)); + assertRewrite("10 + 1 > 1 and 1 > 2", "false"); + assertRewrite("10 + 1 > 1 and 1 < 2", "true"); + assertRewrite("null + 1 > 1 and 1 < 2", "null"); + assertRewrite("10 < 3 and 1 > 2", "false"); + assertRewrite("6 / 2 - 10 * (6 + 1) > 2 and 10 > 3 and 1 > 2", "false"); + + assertRewrite("10 + 1 > 1 or 1 > 2", "true"); + assertRewrite("null + 1 > 1 or 1 > 2", "null"); + assertRewrite("6 / 2 - 10 * (6 + 1) > 2 or 10 > 3 or 1 > 2", "true"); + + assertRewrite("(1 > 5 and 8 < 10 or 1 = 3) or (1 > 8 + 9 / (10 * 2) or ( 10 = 3))", "false"); + assertRewrite("(TA > 1 and 8 < 10 or 1 = 3) or (1 > 3 or ( 10 = 3))", "TA > 1"); + + assertRewrite("false or false", "false"); + assertRewrite("false or true", "true"); + assertRewrite("true or false", "true"); + assertRewrite("true or true", "true"); + + assertRewrite("true and true", "true"); + assertRewrite("false and true", "false"); + assertRewrite("true and false", "false"); + assertRewrite("false and false", "false"); + + assertRewrite("true and null", "null"); + assertRewrite("false and null", "false"); + assertRewrite("true or null", "true"); + assertRewrite("false or null", "null"); + + assertRewrite("null and null", "null"); + } + + @Test + public void testIsNullFold() { + executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE)); + assertRewrite("100 is null", "false"); + assertRewrite("null is null", "true"); + assertRewrite("null is not null", "false"); + assertRewrite("100 is not null", "true"); + assertRewrite("IA is not null", "IA is not null"); + assertRewrite("IA is null", "IA is null"); + } + + @Test + public void testNotFold() { + executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE)); + assertRewrite("not 1 > 2", "true"); + assertRewrite("not null + 1 > 2", "null"); + assertRewrite("not (1 + 5) / 2 + (10 - 1) * 3 > 3 * 5 + 1", "false"); + } + + @Test + public void testCastFold() { + executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + + // cast '1' as tinyint + Cast c = new Cast(Literal.of("1"), TinyIntType.INSTANCE); + Expression rewritten = executor.rewrite(c); + Literal expected = Literal.of((byte) 1); + Assertions.assertEquals(rewritten, expected); + } + + @Test + public void testCompareFold() { + executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE)); + assertRewrite("'1' = 2", "false"); + assertRewrite("1 = 2", "false"); + assertRewrite("1 != 2", "true"); + assertRewrite("2 > 2", "false"); + assertRewrite("3 * 10 + 1 / 2 >= 2", "true"); + assertRewrite("3 < 2", "false"); + assertRewrite("3 <= 2", "false"); + assertRewrite("3 <= null", "null"); + assertRewrite("3 >= null", "null"); + assertRewrite("null <=> null", "true"); + assertRewrite("2 <=> null", "false"); + assertRewrite("2 <=> 2", "true"); + } + + @Test + public void testArithmeticFold() { + executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE)); + assertRewrite("1 + 1", Literal.of((byte) 2)); + assertRewrite("1 - 1", Literal.of((byte) 0)); + assertRewrite("100 + 100", Literal.of((byte) 200)); + assertRewrite("1 - 2", Literal.of((byte) -1)); + + assertRewrite("1 - 2 > 1", "false"); + assertRewrite("1 - 2 + 1 > 1 + 1 - 100", "true"); + assertRewrite("10 * 2 / 1 + 1 > (1 + 1) - 100", "true"); + + // a + 1 > 2 + Slot a = SlotReference.of("a", IntegerType.INSTANCE); + Expression e1 = new Add(a, Literal.of(1L)); + Expression e2 = new Add(new Cast(a, BigIntType.INSTANCE), Literal.of(1L)); + assertRewrite(e1, e2); + + // a > (1 + 10) / 2 * (10 + 1) + Expression e3 = PARSER.parseExpression("(1 + 10) / 2 * (10 + 1)"); + Expression e4 = new GreaterThan(a, e3); + Expression e5 = new GreaterThan(new Cast(a, DoubleType.INSTANCE), Literal.of(60.5D)); + assertRewrite(e4, e5); + + // a > 1 + Expression e6 = new GreaterThan(a, Literal.of(1)); + assertRewrite(e6, e6); + assertRewrite(a, a); + + // a + assertRewrite(a, a); + + // 1 + Literal one = Literal.of(1); + assertRewrite(one, one); + } + + @Test + public void testTimestampFold() { + executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE)); + String interval = "'1991-05-01' - interval 1 day"; + Expression e7 = process((TimestampArithmetic) PARSER.parseExpression(interval)); + Expression e8 = new DateTimeLiteral(1991, 4, 30, 0, 0, 0); + assertRewrite(e7, e8); + + interval = "'1991-05-01' + interval '1' day"; + e7 = process((TimestampArithmetic) PARSER.parseExpression(interval)); + e8 = new DateTimeLiteral(1991, 5, 2, 0, 0, 0); + assertRewrite(e7, e8); + + interval = "'1991-05-01' + interval 1+1 day"; + e7 = process((TimestampArithmetic) PARSER.parseExpression(interval)); + e8 = new DateTimeLiteral(1991, 5, 3, 0, 0, 0); + assertRewrite(e7, e8); + + interval = "date '1991-05-01' + interval 10 / 2 + 1 day"; + e7 = process((TimestampArithmetic) PARSER.parseExpression(interval)); + e8 = new DateTimeLiteral(1991, 5, 7, 0, 0, 0); + assertRewrite(e7, e8); + + interval = "interval '1' day + '1991-05-01'"; + e7 = process((TimestampArithmetic) PARSER.parseExpression(interval)); + e8 = new DateTimeLiteral(1991, 5, 2, 0, 0, 0); + assertRewrite(e7, e8); + + interval = "interval '3' month + '1991-05-01'"; + e7 = process((TimestampArithmetic) PARSER.parseExpression(interval)); + e8 = new DateTimeLiteral(1991, 8, 1, 0, 0, 0); + assertRewrite(e7, e8); + + interval = "interval 3 + 1 month + '1991-05-01'"; + e7 = process((TimestampArithmetic) PARSER.parseExpression(interval)); + e8 = new DateTimeLiteral(1991, 9, 1, 0, 0, 0); + assertRewrite(e7, e8); + + interval = "interval 3 + 1 year + '1991-05-01'"; + e7 = process((TimestampArithmetic) PARSER.parseExpression(interval)); + e8 = new DateTimeLiteral(1995, 5, 1, 0, 0, 0); + assertRewrite(e7, e8); + + interval = "interval 3 + 3 / 2 hour + '1991-05-01 10:00:00'"; + e7 = process((TimestampArithmetic) PARSER.parseExpression(interval)); + e8 = new DateTimeLiteral(1991, 5, 1, 14, 0, 0); + assertRewrite(e7, e8); + + interval = "interval 3 * 2 / 3 minute + '1991-05-01 10:00:00'"; + e7 = process((TimestampArithmetic) PARSER.parseExpression(interval)); + e8 = new DateTimeLiteral(1991, 5, 1, 10, 2, 0); + assertRewrite(e7, e8); + + interval = "interval 3 / 2 + 1 second + '1991-05-01 10:00:00'"; + e7 = process((TimestampArithmetic) PARSER.parseExpression(interval)); + e8 = new DateTimeLiteral(1991, 5, 1, 10, 0, 2); + assertRewrite(e7, e8); + + // a + interval 1 day + Slot a = SlotReference.of("a", DateTimeType.INSTANCE); + TimestampArithmetic arithmetic = new TimestampArithmetic(Operator.ADD, a, Literal.of(1), TimeUnit.DAY, false); + Expression process = process(arithmetic); + assertRewrite(process, process); + } + + public Expression process(TimestampArithmetic arithmetic) { + String funcOpName; + if (arithmetic.getFuncName() == null) { + funcOpName = String.format("%sS_%s", arithmetic.getTimeUnit(), + (arithmetic.getOp() == Operator.ADD) ? "ADD" : "SUB"); + } else { + funcOpName = arithmetic.getFuncName(); + } + return arithmetic.withFuncName(funcOpName.toLowerCase(Locale.ROOT)); + } + + private void assertRewrite(String expression, String expected) { + Map mem = Maps.newHashMap(); + Expression needRewriteExpression = PARSER.parseExpression(expression); + needRewriteExpression = replaceUnboundSlot(needRewriteExpression, mem); + Expression expectedExpression = PARSER.parseExpression(expected); + expectedExpression = replaceUnboundSlot(expectedExpression, mem); + Expression rewrittenExpression = executor.rewrite(needRewriteExpression); + Assertions.assertEquals(expectedExpression, rewrittenExpression); + } + + private void assertRewrite(String expression, Expression expectedExpression) { + Expression needRewriteExpression = PARSER.parseExpression(expression); + Expression rewrittenExpression = executor.rewrite(needRewriteExpression); + Assertions.assertEquals(expectedExpression, rewrittenExpression); + } + + private void assertRewrite(Expression expression, Expression expectedExpression) { + Expression rewrittenExpression = executor.rewrite(expression); + Assertions.assertEquals(expectedExpression, rewrittenExpression); + } + + private Expression replaceUnboundSlot(Expression expression, Map mem) { + List children = Lists.newArrayList(); + boolean hasNewChildren = false; + for (Expression child : expression.children()) { + Expression newChild = replaceUnboundSlot(child, mem); + if (newChild != child) { + hasNewChildren = true; + } + children.add(newChild); + } + if (expression instanceof UnboundSlot) { + String name = ((UnboundSlot) expression).getName(); + mem.putIfAbsent(name, SlotReference.of(name, getType(name.charAt(0)))); + return mem.get(name); + } + return hasNewChildren ? expression.withChildren(children) : expression; + } + + private DataType getType(char t) { + switch (t) { + case 'T': + return TinyIntType.INSTANCE; + case 'I': + return IntegerType.INSTANCE; + case 'D': + return DoubleType.INSTANCE; + case 'S': + return StringType.INSTANCE; + case 'V': + return VarcharType.INSTANCE; + case 'B': + return BooleanType.INSTANCE; + default: + return BigIntType.INSTANCE; + } + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/TypeCoercionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/TypeCoercionTest.java index bb7ca2102a..8427d46cae 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/TypeCoercionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/TypeCoercionTest.java @@ -36,6 +36,7 @@ import org.apache.doris.nereids.trees.expressions.literal.DateLiteral; import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; @@ -176,8 +177,8 @@ public class TypeCoercionTest { @Test public void testBinaryOperator() { Expression actual = new Divide(new SmallIntLiteral((short) 1), new BigIntLiteral(10L)); - Expression expected = new Divide(new BigIntLiteral(1L), - new BigIntLiteral(10L)); + Expression expected = new Divide(new Cast(Literal.of((short) 1), DoubleType.INSTANCE), + new Cast(Literal.of(10L), DoubleType.INSTANCE)); assertRewrite(expected, actual); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java index 11dbb1d5c4..1510cf743e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java @@ -272,4 +272,13 @@ public class ExpressionParserTest extends ParserTestBase { String cast2 = "SELECT CAST(A AS INT) AS I FROM TEST;"; assertSql(cast2); } + + @Test + public void testIsNull() { + String e1 = "a is null"; + assertExpr(e1); + + String e2 = "a is not null"; + assertExpr(e2); + } }