[feature](Nereids) constant expression folding (#12151)
This commit is contained in:
@ -203,6 +203,7 @@ predicate
|
||||
| NOT? kind=(LIKE | REGEXP) pattern=valueExpression
|
||||
| NOT? kind=IN LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN
|
||||
| NOT? kind=IN LEFT_PAREN query RIGHT_PAREN
|
||||
| IS NOT? kind=NULL
|
||||
;
|
||||
|
||||
valueExpression
|
||||
|
||||
@ -22,6 +22,7 @@ import org.apache.doris.nereids.jobs.batch.AnalyzeRulesJob;
|
||||
import org.apache.doris.nereids.jobs.batch.AnalyzeSubqueryRulesJob;
|
||||
import org.apache.doris.nereids.jobs.batch.CheckAnalysisJob;
|
||||
import org.apache.doris.nereids.jobs.batch.FinalizeAnalyzeJob;
|
||||
import org.apache.doris.nereids.jobs.batch.TypeCoercionJob;
|
||||
import org.apache.doris.nereids.rules.analysis.Scope;
|
||||
|
||||
import java.util.Objects;
|
||||
@ -44,9 +45,13 @@ public class NereidsAnalyzer {
|
||||
this.outerScope = Objects.requireNonNull(outerScope, "outerScope can not be null");
|
||||
}
|
||||
|
||||
/**
|
||||
* nereids analyze sql.
|
||||
*/
|
||||
public void analyze() {
|
||||
new AnalyzeRulesJob(cascadesContext, outerScope).execute();
|
||||
new AnalyzeSubqueryRulesJob(cascadesContext).execute();
|
||||
new TypeCoercionJob(cascadesContext).execute();
|
||||
new FinalizeAnalyzeJob(cascadesContext).execute();
|
||||
// check whether analyze result is meaningful
|
||||
new CheckAnalysisJob(cascadesContext).execute();
|
||||
|
||||
@ -37,7 +37,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
/**
|
||||
* Apply rules to normalize expressions.
|
||||
* Apply rules to optimize logical plan.
|
||||
*/
|
||||
public class NereidsRewriteJobExecutor extends BatchRulesJob {
|
||||
|
||||
@ -58,7 +58,7 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob {
|
||||
*/
|
||||
.addAll(new AdjustApplyFromCorrelatToUnCorrelatJob(cascadesContext).rulesJob)
|
||||
.addAll(new ConvertApplyToJoinJob(cascadesContext).rulesJob)
|
||||
.add(topDownBatch(ImmutableList.of(new ExpressionNormalization())))
|
||||
.add(topDownBatch(ImmutableList.of(new ExpressionNormalization(cascadesContext.getConnectContext()))))
|
||||
.add(topDownBatch(ImmutableList.of(new ExpressionOptimization())))
|
||||
.add(topDownBatch(ImmutableList.of(new ExtractSingleTableExpressionFromDisjunction())))
|
||||
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
|
||||
|
||||
@ -0,0 +1,41 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.jobs.batch;
|
||||
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
/**
|
||||
* type coercion job.
|
||||
*/
|
||||
public class TypeCoercionJob extends BatchRulesJob {
|
||||
/**
|
||||
* constructor.
|
||||
*/
|
||||
public TypeCoercionJob(CascadesContext cascadesContext) {
|
||||
super(cascadesContext);
|
||||
rulesJob.addAll(ImmutableList.of(
|
||||
topDownBatch(ImmutableList.of(
|
||||
new ExpressionNormalization(cascadesContext.getConnectContext(),
|
||||
ImmutableList.of(TypeCoercion.INSTANCE)))
|
||||
)));
|
||||
}
|
||||
}
|
||||
@ -95,6 +95,7 @@ import org.apache.doris.nereids.trees.expressions.GreaterThan;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.InSubquery;
|
||||
import org.apache.doris.nereids.trees.expressions.IsNull;
|
||||
import org.apache.doris.nereids.trees.expressions.LessThan;
|
||||
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.Like;
|
||||
@ -385,7 +386,6 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
|
||||
public Expression visitPredicated(PredicatedContext ctx) {
|
||||
return ParserUtils.withOrigin(ctx, () -> {
|
||||
Expression e = getExpression(ctx.valueExpression());
|
||||
// TODO: add predicate(is not null ...)
|
||||
return ctx.predicate() == null ? e : withPredicate(e, ctx.predicate());
|
||||
});
|
||||
}
|
||||
@ -948,6 +948,9 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
|
||||
);
|
||||
}
|
||||
break;
|
||||
case DorisParser.NULL:
|
||||
outExpression = new IsNull(valueExpression);
|
||||
break;
|
||||
default:
|
||||
throw new ParseException("Unsupported predicate type: " + ctx.kind.getText(), ctx);
|
||||
}
|
||||
|
||||
@ -36,13 +36,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.types.DateTimeType;
|
||||
import org.apache.doris.nereids.types.DateType;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
@ -134,8 +132,12 @@ public class BindFunction implements AnalysisRuleFactory {
|
||||
return builder.build(functionName, unboundFunction.getArguments());
|
||||
}
|
||||
|
||||
/**
|
||||
* gets the method for calculating the time.
|
||||
* e.g. YEARS_ADD、YEARS_SUB、DAYS_ADD 、DAYS_SUB
|
||||
*/
|
||||
@Override
|
||||
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Env env) {
|
||||
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Env context) {
|
||||
String funcOpName;
|
||||
if (arithmetic.getFuncName() == null) {
|
||||
// e.g. YEARS_ADD, MONTHS_SUB
|
||||
@ -144,24 +146,7 @@ public class BindFunction implements AnalysisRuleFactory {
|
||||
} else {
|
||||
funcOpName = arithmetic.getFuncName();
|
||||
}
|
||||
|
||||
Expression left = arithmetic.left();
|
||||
Expression right = arithmetic.right();
|
||||
|
||||
if (!left.getDataType().isDateType()) {
|
||||
try {
|
||||
left = left.castTo(DateTimeType.INSTANCE);
|
||||
} catch (Exception e) {
|
||||
// ignore
|
||||
}
|
||||
if (!left.getDataType().isDateType() && !arithmetic.getTimeUnit().isDateTimeUnit()) {
|
||||
left = arithmetic.left().castTo(DateType.INSTANCE);
|
||||
}
|
||||
}
|
||||
if (!right.getDataType().isIntType()) {
|
||||
right = right.castTo(IntegerType.INSTANCE);
|
||||
}
|
||||
return arithmetic.withFuncName(funcOpName).withChildren(ImmutableList.of(left, right));
|
||||
return arithmetic.withFuncName(funcOpName.toLowerCase(Locale.ROOT));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,11 +18,13 @@
|
||||
package org.apache.doris.nereids.rules.expression.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.BetweenToCompoundRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.FoldConstantRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.InPredicateToEqualToRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPredicatesRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyCastRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
@ -39,11 +41,16 @@ public class ExpressionNormalization extends ExpressionRewrite {
|
||||
InPredicateToEqualToRule.INSTANCE,
|
||||
SimplifyNotExprRule.INSTANCE,
|
||||
SimplifyCastRule.INSTANCE,
|
||||
TypeCoercion.INSTANCE
|
||||
TypeCoercion.INSTANCE,
|
||||
FoldConstantRule.INSTANCE
|
||||
);
|
||||
private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES);
|
||||
|
||||
public ExpressionNormalization() {
|
||||
super(EXECUTOR);
|
||||
public ExpressionNormalization(ConnectContext context) {
|
||||
super(new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES, context));
|
||||
}
|
||||
|
||||
public ExpressionNormalization(ConnectContext context, List<ExpressionRewriteRule> rules) {
|
||||
super(new ExpressionRuleExecutor(rules, context));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -17,8 +17,15 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rewrite;
|
||||
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
/**
|
||||
* expression rewrite context.
|
||||
*/
|
||||
public class ExpressionRewriteContext {
|
||||
public final ConnectContext connectContext;
|
||||
|
||||
public ExpressionRewriteContext(ConnectContext connectContext) {
|
||||
this.connectContext = connectContext;
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,8 +18,7 @@
|
||||
package org.apache.doris.nereids.rules.expression.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
@ -33,14 +32,13 @@ public class ExpressionRuleExecutor {
|
||||
private final ExpressionRewriteContext ctx;
|
||||
private final List<ExpressionRewriteRule> rules;
|
||||
|
||||
public ExpressionRuleExecutor(List<ExpressionRewriteRule> rules) {
|
||||
public ExpressionRuleExecutor(List<ExpressionRewriteRule> rules, ConnectContext context) {
|
||||
this.rules = rules;
|
||||
this.ctx = new ExpressionRewriteContext();
|
||||
this.ctx = new ExpressionRewriteContext(context);
|
||||
}
|
||||
|
||||
public ExpressionRuleExecutor(ExpressionRewriteRule rule) {
|
||||
this.rules = Lists.newArrayList(rule);
|
||||
this.ctx = new ExpressionRewriteContext();
|
||||
public ExpressionRuleExecutor(List<ExpressionRewriteRule> rules) {
|
||||
this(rules, null);
|
||||
}
|
||||
|
||||
public List<Expression> rewrite(List<Expression> exprs) {
|
||||
|
||||
@ -0,0 +1,39 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rewrite.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
/**
|
||||
* Constant evaluation of an expression.
|
||||
*/
|
||||
public class FoldConstantRule extends AbstractExpressionRewriteRule {
|
||||
|
||||
public static final FoldConstantRule INSTANCE = new FoldConstantRule();
|
||||
|
||||
@Override
|
||||
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
|
||||
if (ctx.connectContext != null && ctx.connectContext.getSessionVariable().isEnableFoldConstantByBe()) {
|
||||
return FoldConstantRuleOnBE.INSTANCE.rewrite(expr, ctx);
|
||||
}
|
||||
return FoldConstantRuleOnFE.INSTANCE.rewrite(expr, ctx);
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,200 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rewrite.rules;
|
||||
|
||||
import org.apache.doris.analysis.Expr;
|
||||
import org.apache.doris.analysis.ExprId;
|
||||
import org.apache.doris.analysis.LiteralExpr;
|
||||
import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.catalog.PrimitiveType;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.common.IdGenerator;
|
||||
import org.apache.doris.common.UserException;
|
||||
import org.apache.doris.common.util.TimeUtils;
|
||||
import org.apache.doris.common.util.VectorizedUtil;
|
||||
import org.apache.doris.nereids.glue.translator.ExpressionTranslator;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Between;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.proto.InternalService;
|
||||
import org.apache.doris.proto.InternalService.PConstantExprResult;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.rpc.BackendServiceProxy;
|
||||
import org.apache.doris.system.Backend;
|
||||
import org.apache.doris.thrift.TExpr;
|
||||
import org.apache.doris.thrift.TFoldConstantParams;
|
||||
import org.apache.doris.thrift.TNetworkAddress;
|
||||
import org.apache.doris.thrift.TPrimitiveType;
|
||||
import org.apache.doris.thrift.TQueryGlobals;
|
||||
|
||||
import com.google.common.collect.Maps;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Constant evaluation of an expression.
|
||||
*/
|
||||
public class FoldConstantRuleOnBE extends AbstractExpressionRewriteRule {
|
||||
public static final FoldConstantRuleOnBE INSTANCE = new FoldConstantRuleOnBE();
|
||||
private static final Logger LOG = LogManager.getLogger(FoldConstantRuleOnBE.class);
|
||||
private static final DateTimeFormatter DATE_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
|
||||
private final IdGenerator<ExprId> idGenerator = ExprId.createGenerator();
|
||||
|
||||
@Override
|
||||
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
|
||||
Expression expression = FoldConstantRuleOnFE.INSTANCE.rewrite(expr, ctx);
|
||||
return foldByBE(expression, ctx);
|
||||
}
|
||||
|
||||
private Expression foldByBE(Expression root, ExpressionRewriteContext context) {
|
||||
Map<String, Expression> constMap = Maps.newHashMap();
|
||||
Map<String, TExpr> staleConstTExprMap = Maps.newHashMap();
|
||||
collectConst(root, constMap, staleConstTExprMap);
|
||||
if (constMap.isEmpty()) {
|
||||
return root;
|
||||
}
|
||||
Map<String, Map<String, TExpr>> paramMap = new HashMap<>();
|
||||
paramMap.put("0", staleConstTExprMap);
|
||||
Map<String, Expression> resultMap = evalOnBE(paramMap, constMap, context.connectContext);
|
||||
if (!resultMap.isEmpty()) {
|
||||
return replace(root, constMap, resultMap);
|
||||
}
|
||||
return root;
|
||||
}
|
||||
|
||||
private Expression replace(Expression root, Map<String, Expression> constMap, Map<String, Expression> resultMap) {
|
||||
for (Entry<String, Expression> entry : constMap.entrySet()) {
|
||||
if (entry.getValue().equals(root)) {
|
||||
return resultMap.get(entry.getKey());
|
||||
}
|
||||
}
|
||||
List<Expression> newChildren = new ArrayList<>();
|
||||
boolean hasNewChildren = false;
|
||||
for (Expression child : root.children()) {
|
||||
Expression newChild = replace(child, constMap, resultMap);
|
||||
if (newChild != child) {
|
||||
hasNewChildren = true;
|
||||
}
|
||||
newChildren.add(newChild);
|
||||
}
|
||||
return hasNewChildren ? root.withChildren(newChildren) : root;
|
||||
}
|
||||
|
||||
private void collectConst(Expression expr, Map<String, Expression> constMap, Map<String, TExpr> tExprMap) {
|
||||
if (expr.isConstant()) {
|
||||
// Do not constant fold cast(null as dataType) because we cannot preserve the
|
||||
// cast-to-types and that can lead to query failures, e.g., CTAS
|
||||
if (expr instanceof Cast) {
|
||||
if (((Cast) expr).child().isNullLiteral()) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
// skip literal expr
|
||||
if (expr.isLiteral()) {
|
||||
return;
|
||||
}
|
||||
// skip BetweenPredicate need to be rewrite to CompoundPredicate
|
||||
if (expr instanceof Between) {
|
||||
return;
|
||||
}
|
||||
String id = idGenerator.getNextId().toString();
|
||||
constMap.put(id, expr);
|
||||
Expr staleExpr = ExpressionTranslator.translate(expr, null);
|
||||
tExprMap.put(id, staleExpr.treeToThrift());
|
||||
} else {
|
||||
for (int i = 0; i < expr.children().size(); i++) {
|
||||
final Expression child = expr.children().get(i);
|
||||
collectConst(child, constMap, tExprMap);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Map<String, Expression> evalOnBE(Map<String, Map<String, TExpr>> paramMap,
|
||||
Map<String, Expression> constMap, ConnectContext context) {
|
||||
|
||||
Map<String, Expression> resultMap = new HashMap<>();
|
||||
try {
|
||||
List<Long> backendIds = Env.getCurrentSystemInfo().getBackendIds(true);
|
||||
if (backendIds.isEmpty()) {
|
||||
throw new UserException("No alive backends");
|
||||
}
|
||||
Collections.shuffle(backendIds);
|
||||
Backend be = Env.getCurrentSystemInfo().getBackend(backendIds.get(0));
|
||||
TNetworkAddress brpcAddress = new TNetworkAddress(be.getHost(), be.getBrpcPort());
|
||||
|
||||
TQueryGlobals queryGlobals = new TQueryGlobals();
|
||||
queryGlobals.setNowString(DATE_FORMAT.format(LocalDateTime.now()));
|
||||
queryGlobals.setTimestampMs(System.currentTimeMillis());
|
||||
queryGlobals.setTimeZone(TimeUtils.DEFAULT_TIME_ZONE);
|
||||
if (context.getSessionVariable().getTimeZone().equals("CST")) {
|
||||
queryGlobals.setTimeZone(TimeUtils.DEFAULT_TIME_ZONE);
|
||||
} else {
|
||||
queryGlobals.setTimeZone(context.getSessionVariable().getTimeZone());
|
||||
}
|
||||
|
||||
TFoldConstantParams tParams = new TFoldConstantParams(paramMap, queryGlobals);
|
||||
tParams.setVecExec(VectorizedUtil.isVectorized());
|
||||
|
||||
Future<PConstantExprResult> future =
|
||||
BackendServiceProxy.getInstance().foldConstantExpr(brpcAddress, tParams);
|
||||
PConstantExprResult result = future.get(5, TimeUnit.SECONDS);
|
||||
|
||||
if (result.getStatus().getStatusCode() == 0) {
|
||||
for (Entry<String, InternalService.PExprResultMap> e : result.getExprResultMapMap().entrySet()) {
|
||||
for (Entry<String, InternalService.PExprResult> e1 : e.getValue().getMapMap().entrySet()) {
|
||||
Expression ret;
|
||||
if (e1.getValue().getSuccess()) {
|
||||
TPrimitiveType type = TPrimitiveType.findByValue(e1.getValue().getType().getType());
|
||||
Type t = Type.fromPrimitiveType(PrimitiveType.fromThrift(Objects.requireNonNull(type)));
|
||||
Expr staleExpr = LiteralExpr.create(e1.getValue().getContent(), Objects.requireNonNull(t));
|
||||
// Nereids type
|
||||
DataType t1 = DataType.convertFromString(staleExpr.getType().getPrimitiveType().toString());
|
||||
ret = Literal.of(staleExpr.getStringValue()).castTo(t1);
|
||||
} else {
|
||||
ret = constMap.get(e.getKey());
|
||||
}
|
||||
resultMap.put(e.getKey(), ret);
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
LOG.warn("failed to get const expr value from be: {}", result.getStatus().getErrorMsgsList());
|
||||
}
|
||||
} catch (Exception e) {
|
||||
LOG.warn("failed to get const expr value from be: {}", e.getMessage());
|
||||
}
|
||||
return resultMap;
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,310 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rewrite.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.trees.expressions.And;
|
||||
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
|
||||
import org.apache.doris.nereids.trees.expressions.CaseWhen;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.ExpressionEvaluator;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThan;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.IsNull;
|
||||
import org.apache.doris.nereids.trees.expressions.LessThan;
|
||||
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.Like;
|
||||
import org.apache.doris.nereids.trees.expressions.Not;
|
||||
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.Or;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
|
||||
import org.apache.doris.nereids.trees.expressions.WhenClause;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* evaluate an expression on fe.
|
||||
*/
|
||||
public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
|
||||
public static final FoldConstantRuleOnFE INSTANCE = new FoldConstantRuleOnFE();
|
||||
|
||||
@Override
|
||||
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
|
||||
return process(expr, ctx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visit(Expression expr, ExpressionRewriteContext context) {
|
||||
return expr;
|
||||
}
|
||||
|
||||
/**
|
||||
* process constant expression.
|
||||
*/
|
||||
public Expression process(Expression expr, ExpressionRewriteContext ctx) {
|
||||
if (expr instanceof PropagateNullable) {
|
||||
List<Expression> children = expr.children()
|
||||
.stream()
|
||||
.map(child -> process(child, ctx))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (ExpressionUtils.hasNullLiteral(children)) {
|
||||
return Literal.of(null);
|
||||
}
|
||||
|
||||
if (!ExpressionUtils.isAllLiteral(children)) {
|
||||
return expr.withChildren(children);
|
||||
}
|
||||
return expr.withChildren(children).accept(this, ctx);
|
||||
} else {
|
||||
return expr.accept(this, ctx);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) {
|
||||
return BooleanLiteral.of(((Literal) equalTo.left()).compareTo((Literal) equalTo.right()) == 0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteContext context) {
|
||||
return BooleanLiteral.of(((Literal) greaterThan.left()).compareTo((Literal) greaterThan.right()) > 0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, ExpressionRewriteContext context) {
|
||||
return BooleanLiteral.of(((Literal) greaterThanEqual.left())
|
||||
.compareTo((Literal) greaterThanEqual.right()) >= 0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitLessThan(LessThan lessThan, ExpressionRewriteContext context) {
|
||||
return BooleanLiteral.of(((Literal) lessThan.left()).compareTo((Literal) lessThan.right()) < 0);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewriteContext context) {
|
||||
return BooleanLiteral.of(((Literal) lessThanEqual.left()).compareTo((Literal) lessThanEqual.right()) <= 0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitNullSafeEqual(NullSafeEqual nullSafeEqual, ExpressionRewriteContext context) {
|
||||
Expression left = process(nullSafeEqual.left(), context);
|
||||
Expression right = process(nullSafeEqual.right(), context);
|
||||
if (ExpressionUtils.isAllLiteral(left, right)) {
|
||||
Literal l = (Literal) left;
|
||||
Literal r = (Literal) right;
|
||||
if (l.isNullLiteral() && r.isNullLiteral()) {
|
||||
return BooleanLiteral.TRUE;
|
||||
} else if (!l.isNullLiteral() && !r.isNullLiteral()) {
|
||||
return BooleanLiteral.of(l.compareTo(r) == 0);
|
||||
} else {
|
||||
return BooleanLiteral.FALSE;
|
||||
}
|
||||
}
|
||||
return nullSafeEqual.withChildren(left, right);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitNot(Not not, ExpressionRewriteContext context) {
|
||||
return BooleanLiteral.of(!((BooleanLiteral) not.child()).getValue());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitSlot(Slot slot, ExpressionRewriteContext context) {
|
||||
return slot;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitLiteral(Literal literal, ExpressionRewriteContext context) {
|
||||
return literal;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitAnd(And and, ExpressionRewriteContext context) {
|
||||
List<Expression> children = Lists.newArrayList();
|
||||
for (Expression child : and.children()) {
|
||||
Expression newChild = process(child, context);
|
||||
if (newChild.equals(BooleanLiteral.FALSE)) {
|
||||
return BooleanLiteral.FALSE;
|
||||
}
|
||||
if (!newChild.equals(BooleanLiteral.TRUE)) {
|
||||
children.add(newChild);
|
||||
}
|
||||
}
|
||||
if (children.isEmpty()) {
|
||||
return BooleanLiteral.TRUE;
|
||||
}
|
||||
if (children.size() == 1) {
|
||||
return children.get(0);
|
||||
}
|
||||
if (ExpressionUtils.isAllNullLiteral(children)) {
|
||||
return Literal.of(null);
|
||||
}
|
||||
return and.withChildren(children);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitOr(Or or, ExpressionRewriteContext context) {
|
||||
List<Expression> children = Lists.newArrayList();
|
||||
for (Expression child : or.children()) {
|
||||
Expression newChild = process(child, context);
|
||||
if (newChild.equals(BooleanLiteral.TRUE)) {
|
||||
return BooleanLiteral.TRUE;
|
||||
}
|
||||
if (!newChild.equals(BooleanLiteral.FALSE)) {
|
||||
children.add(newChild);
|
||||
}
|
||||
}
|
||||
if (children.isEmpty()) {
|
||||
return BooleanLiteral.FALSE;
|
||||
}
|
||||
if (children.size() == 1) {
|
||||
return children.get(0);
|
||||
}
|
||||
if (ExpressionUtils.isAllNullLiteral(children)) {
|
||||
return Literal.of(null);
|
||||
}
|
||||
return or.withChildren(children);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitLike(Like like, ExpressionRewriteContext context) {
|
||||
return like;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitCast(Cast cast, ExpressionRewriteContext context) {
|
||||
Expression child = process(cast.child(), context);
|
||||
// todo: process other null case
|
||||
if (child.isNullLiteral()) {
|
||||
return Literal.of(null);
|
||||
}
|
||||
if (child.isLiteral()) {
|
||||
return child.castTo(cast.getDataType());
|
||||
}
|
||||
return cast.withChildren(child);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitBoundFunction(BoundFunction boundFunction, ExpressionRewriteContext context) {
|
||||
List<Expression> newArgs = boundFunction.getArguments().stream().map(arg -> process(arg, context))
|
||||
.collect(Collectors.toList());
|
||||
if (ExpressionUtils.isAllLiteral(newArgs)) {
|
||||
return ExpressionEvaluator.INSTANCE.eval(boundFunction.withChildren(newArgs));
|
||||
}
|
||||
return boundFunction.withChildren(newArgs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, ExpressionRewriteContext context) {
|
||||
return ExpressionEvaluator.INSTANCE.eval(binaryArithmetic);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) {
|
||||
Expression newDefault = null;
|
||||
boolean foundNewDefault = false;
|
||||
|
||||
List<WhenClause> whenClauses = new ArrayList<>();
|
||||
for (WhenClause whenClause : caseWhen.getWhenClauses()) {
|
||||
Expression whenOperand = process(whenClause.getOperand(), context);
|
||||
|
||||
if (!(whenOperand.isLiteral())) {
|
||||
whenClauses.add(new WhenClause(whenOperand, process(whenClause.getResult(), context)));
|
||||
} else if (BooleanLiteral.TRUE.equals(whenOperand)) {
|
||||
foundNewDefault = true;
|
||||
newDefault = process(whenClause.getResult(), context);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Expression defaultResult;
|
||||
if (foundNewDefault) {
|
||||
defaultResult = newDefault;
|
||||
} else {
|
||||
defaultResult = process(caseWhen.getDefaultValue().orElse(Literal.of(null)), context);
|
||||
}
|
||||
|
||||
if (whenClauses.isEmpty()) {
|
||||
return defaultResult;
|
||||
}
|
||||
return new CaseWhen(whenClauses, defaultResult);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
|
||||
Expression value = process(inPredicate.child(0), context);
|
||||
List<Expression> children = Lists.newArrayList();
|
||||
children.add(value);
|
||||
if (value.isNullLiteral()) {
|
||||
return Literal.of(null);
|
||||
}
|
||||
boolean hasNull = false;
|
||||
boolean hasUnresolvedValue = !value.isLiteral();
|
||||
for (int i = 1; i < inPredicate.children().size(); i++) {
|
||||
Expression inValue = process(inPredicate.child(i), context);
|
||||
children.add(inValue);
|
||||
if (!inValue.isLiteral()) {
|
||||
hasUnresolvedValue = true;
|
||||
}
|
||||
if (inValue.isNullLiteral()) {
|
||||
hasNull = true;
|
||||
}
|
||||
if (inValue.isLiteral() && value.isLiteral() && ((Literal) value).compareTo((Literal) inValue) == 0) {
|
||||
return Literal.of(true);
|
||||
}
|
||||
}
|
||||
if (hasUnresolvedValue) {
|
||||
return inPredicate.withChildren(children);
|
||||
}
|
||||
return hasNull ? Literal.of(null) : Literal.of(false);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitIsNull(IsNull isNull, ExpressionRewriteContext context) {
|
||||
Expression child = process(isNull.child(), context);
|
||||
if (child.isNullLiteral()) {
|
||||
return Literal.of(true);
|
||||
} else if (!child.nullable()) {
|
||||
return Literal.of(false);
|
||||
}
|
||||
return isNull.withChildren(child);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, ExpressionRewriteContext context) {
|
||||
return ExpressionEvaluator.INSTANCE.eval(arithmetic);
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,16 +17,19 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rewrite.rules;
|
||||
|
||||
import org.apache.doris.analysis.ArithmeticExpr.Operator;
|
||||
import org.apache.doris.nereids.annotation.Developing;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
|
||||
import org.apache.doris.nereids.trees.expressions.CaseWhen;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.Divide;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DoubleType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
import org.apache.doris.nereids.util.TypeCoercionUtils;
|
||||
|
||||
@ -45,9 +48,8 @@ public class TypeCoercion extends AbstractExpressionRewriteRule {
|
||||
|
||||
// TODO:
|
||||
// 1. DecimalPrecision Process
|
||||
// 2. Divide process
|
||||
// 3. String promote with numeric in binary arithmetic
|
||||
// 4. Date and DateTime process
|
||||
// 2. String promote with numeric in binary arithmetic
|
||||
// 3. Date and DateTime process
|
||||
|
||||
public static final TypeCoercion INSTANCE = new TypeCoercion();
|
||||
|
||||
@ -82,6 +84,23 @@ public class TypeCoercion extends AbstractExpressionRewriteRule {
|
||||
.orElse(binaryOperator.withChildren(left, right));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitDivide(Divide divide, ExpressionRewriteContext context) {
|
||||
Expression left = rewrite(divide.left(), context);
|
||||
Expression right = rewrite(divide.right(), context);
|
||||
DataType t1 = TypeCoercionUtils.getNumResultType(left.getDataType());
|
||||
DataType t2 = TypeCoercionUtils.getNumResultType(right.getDataType());
|
||||
DataType commonType = TypeCoercionUtils.findCommonNumericsType(t1, t2);
|
||||
if (divide.getLegacyOperator() == Operator.DIVIDE) {
|
||||
if (commonType.isBigIntType() || commonType.isLargeIntType()) {
|
||||
commonType = DoubleType.INSTANCE;
|
||||
}
|
||||
}
|
||||
Expression newLeft = TypeCoercionUtils.castIfNotSameType(left, commonType);
|
||||
Expression newRight = TypeCoercionUtils.castIfNotSameType(right, commonType);
|
||||
return divide.withChildren(newLeft, newRight);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) {
|
||||
List<Expression> rewrittenChildren = caseWhen.children().stream()
|
||||
|
||||
@ -19,13 +19,14 @@ package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.analysis.ArithmeticExpr.Operator;
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
|
||||
/**
|
||||
* binary arithmetic operator. Such as +, -, *, /.
|
||||
*/
|
||||
public abstract class BinaryArithmetic extends BinaryOperator {
|
||||
public abstract class BinaryArithmetic extends BinaryOperator implements PropagateNullable {
|
||||
|
||||
private final Operator legacyOperator;
|
||||
|
||||
|
||||
@ -71,5 +71,6 @@ public abstract class CompoundPredicate extends BinaryOperator {
|
||||
public abstract CompoundPredicate flip(Expression left, Expression right);
|
||||
|
||||
public abstract Class<? extends CompoundPredicate> flipType();
|
||||
|
||||
}
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
@ -27,7 +28,7 @@ import java.util.List;
|
||||
/**
|
||||
* Equal to expression: a = b.
|
||||
*/
|
||||
public class EqualTo extends ComparisonPredicate {
|
||||
public class EqualTo extends ComparisonPredicate implements PropagateNullable {
|
||||
|
||||
public EqualTo(Expression left, Expression right) {
|
||||
super(left, right, "=");
|
||||
|
||||
@ -0,0 +1,46 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
/**
|
||||
* nereids function annotation.
|
||||
*/
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
@Target(ElementType.METHOD)
|
||||
public @interface ExecFunction {
|
||||
|
||||
/**
|
||||
* function name
|
||||
*/
|
||||
String name();
|
||||
|
||||
/**
|
||||
* args type
|
||||
*/
|
||||
String[] argTypes();
|
||||
|
||||
/**
|
||||
* return type
|
||||
*/
|
||||
String returnType();
|
||||
}
|
||||
@ -0,0 +1,35 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
/**
|
||||
* exec function list
|
||||
*/
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
@Target(ElementType.METHOD)
|
||||
public @interface ExecFunctionList {
|
||||
/**
|
||||
* exec functions.
|
||||
*/
|
||||
ExecFunction[] value();
|
||||
}
|
||||
@ -23,6 +23,7 @@ import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.AbstractTreeNode;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ComputeNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.LeafExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ExpectsInputTypes;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.TypeCheckResult;
|
||||
@ -139,6 +140,14 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
|
||||
return collect(Slot.class::isInstance);
|
||||
}
|
||||
|
||||
public boolean isLiteral() {
|
||||
return this instanceof Literal;
|
||||
}
|
||||
|
||||
public boolean isNullLiteral() {
|
||||
return this instanceof NullLiteral;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
|
||||
@ -0,0 +1,208 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.common.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExecutableFunctions;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
|
||||
import com.google.common.collect.ImmutableMultimap;
|
||||
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* An expression evaluator that evaluates the value of an expression.
|
||||
*/
|
||||
public enum ExpressionEvaluator {
|
||||
INSTANCE;
|
||||
private ImmutableMultimap<String, FunctionInvoker> functions;
|
||||
|
||||
ExpressionEvaluator() {
|
||||
registerFunctions();
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluate the value of the expression.
|
||||
*/
|
||||
public Expression eval(Expression expression) {
|
||||
|
||||
if (!expression.isConstant() || expression instanceof AggregateFunction) {
|
||||
return expression;
|
||||
}
|
||||
|
||||
String fnName = null;
|
||||
DataType[] args = null;
|
||||
if (expression instanceof BinaryArithmetic) {
|
||||
BinaryArithmetic arithmetic = (BinaryArithmetic) expression;
|
||||
fnName = arithmetic.getLegacyOperator().getName();
|
||||
args = new DataType[]{arithmetic.left().getDataType(), arithmetic.right().getDataType()};
|
||||
} else if (expression instanceof TimestampArithmetic) {
|
||||
TimestampArithmetic arithmetic = (TimestampArithmetic) expression;
|
||||
fnName = arithmetic.getFuncName();
|
||||
args = new DataType[]{arithmetic.left().getDataType(), arithmetic.right().getDataType()};
|
||||
}
|
||||
|
||||
if ((Env.getCurrentEnv().isNullResultWithOneNullParamFunction(fnName))) {
|
||||
for (Expression e : expression.children()) {
|
||||
if (e instanceof NullLiteral) {
|
||||
return Literal.of(null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return invoke(expression, fnName, args);
|
||||
}
|
||||
|
||||
private Expression invoke(Expression expression, String fnName, DataType[] args) {
|
||||
FunctionSignature signature = new FunctionSignature(fnName, args, null);
|
||||
FunctionInvoker invoker = getFunction(signature);
|
||||
if (invoker != null) {
|
||||
try {
|
||||
return invoker.invoke(expression.children());
|
||||
} catch (AnalysisException e) {
|
||||
return expression;
|
||||
}
|
||||
}
|
||||
return expression;
|
||||
}
|
||||
|
||||
private FunctionInvoker getFunction(FunctionSignature signature) {
|
||||
Collection<FunctionInvoker> functionInvokers = functions.get(signature.getName());
|
||||
if (functionInvokers == null) {
|
||||
return null;
|
||||
}
|
||||
for (FunctionInvoker candidate : functionInvokers) {
|
||||
DataType[] candidateTypes = candidate.getSignature().getArgTypes();
|
||||
DataType[] expectedTypes = signature.getArgTypes();
|
||||
|
||||
if (candidateTypes.length != expectedTypes.length) {
|
||||
continue;
|
||||
}
|
||||
boolean match = true;
|
||||
for (int i = 0; i < candidateTypes.length; i++) {
|
||||
if (!candidateTypes[i].equals(expectedTypes[i])) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
return candidate;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private void registerFunctions() {
|
||||
if (functions != null) {
|
||||
return;
|
||||
}
|
||||
ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder =
|
||||
new ImmutableMultimap.Builder<String, FunctionInvoker>();
|
||||
Class clazz = ExecutableFunctions.class;
|
||||
for (Method method : clazz.getDeclaredMethods()) {
|
||||
ExecFunctionList annotationList = method.getAnnotation(ExecFunctionList.class);
|
||||
if (annotationList != null) {
|
||||
for (ExecFunction f : annotationList.value()) {
|
||||
registerFEFunction(mapBuilder, method, f);
|
||||
}
|
||||
}
|
||||
registerFEFunction(mapBuilder, method, method.getAnnotation(ExecFunction.class));
|
||||
}
|
||||
this.functions = mapBuilder.build();
|
||||
}
|
||||
|
||||
private void registerFEFunction(ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder,
|
||||
Method method, ExecFunction annotation) {
|
||||
if (annotation != null) {
|
||||
String name = annotation.name();
|
||||
DataType returnType = DataType.convertFromString(annotation.returnType());
|
||||
List<DataType> argTypes = new ArrayList<>();
|
||||
for (String type : annotation.argTypes()) {
|
||||
argTypes.add(DataType.convertFromString(type));
|
||||
}
|
||||
FunctionSignature signature = new FunctionSignature(name,
|
||||
argTypes.toArray(new DataType[argTypes.size()]), returnType);
|
||||
mapBuilder.put(name, new FunctionInvoker(method, signature));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* function invoker.
|
||||
*/
|
||||
public static class FunctionInvoker {
|
||||
private final Method method;
|
||||
private final FunctionSignature signature;
|
||||
|
||||
public FunctionInvoker(Method method, FunctionSignature signature) {
|
||||
this.method = method;
|
||||
this.signature = signature;
|
||||
}
|
||||
|
||||
public Method getMethod() {
|
||||
return method;
|
||||
}
|
||||
|
||||
public FunctionSignature getSignature() {
|
||||
return signature;
|
||||
}
|
||||
|
||||
public Literal invoke(List<Expression> args) throws AnalysisException {
|
||||
try {
|
||||
return (Literal) method.invoke(null, args.toArray());
|
||||
} catch (InvocationTargetException | IllegalAccessException | IllegalArgumentException e) {
|
||||
throw new AnalysisException(e.getLocalizedMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* function signature.
|
||||
*/
|
||||
public static class FunctionSignature {
|
||||
private final String name;
|
||||
private final DataType[] argTypes;
|
||||
private final DataType returnType;
|
||||
|
||||
public FunctionSignature(String name, DataType[] argTypes, DataType returnType) {
|
||||
this.name = name;
|
||||
this.argTypes = argTypes;
|
||||
this.returnType = returnType;
|
||||
}
|
||||
|
||||
public DataType[] getArgTypes() {
|
||||
return argTypes;
|
||||
}
|
||||
|
||||
public DataType getReturnType() {
|
||||
return returnType;
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
@ -27,7 +28,7 @@ import java.util.List;
|
||||
/**
|
||||
* Greater than expression: a > b.
|
||||
*/
|
||||
public class GreaterThan extends ComparisonPredicate {
|
||||
public class GreaterThan extends ComparisonPredicate implements PropagateNullable {
|
||||
/**
|
||||
* Constructor of Greater Than ComparisonPredicate.
|
||||
*
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
@ -27,7 +28,7 @@ import java.util.List;
|
||||
/**
|
||||
* Greater than and equal expression: a >= b.
|
||||
*/
|
||||
public class GreaterThanEqual extends ComparisonPredicate {
|
||||
public class GreaterThanEqual extends ComparisonPredicate implements PropagateNullable {
|
||||
/**
|
||||
* Constructor of Greater Than And Equal.
|
||||
*
|
||||
|
||||
@ -0,0 +1,78 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* expr is null predicate.
|
||||
*/
|
||||
public class IsNull extends Expression implements UnaryExpression {
|
||||
|
||||
public IsNull(Expression e) {
|
||||
super(e);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitIsNull(this, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean nullable() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public IsNull withChildren(List<Expression> children) {
|
||||
Preconditions.checkArgument(children.size() == 1);
|
||||
return new IsNull(children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toSql() throws UnboundException {
|
||||
return child().toSql() + " IS NULL";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return toSql();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (!super.equals(o)) {
|
||||
return false;
|
||||
}
|
||||
IsNull other = (IsNull) o;
|
||||
return Objects.equals(child(), other.child());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return child().hashCode();
|
||||
}
|
||||
|
||||
}
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
@ -27,7 +28,7 @@ import java.util.List;
|
||||
/**
|
||||
* Less than expression: a < b.
|
||||
*/
|
||||
public class LessThan extends ComparisonPredicate {
|
||||
public class LessThan extends ComparisonPredicate implements PropagateNullable {
|
||||
/**
|
||||
* Constructor of Less Than Comparison Predicate.
|
||||
*
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
@ -27,7 +28,7 @@ import java.util.List;
|
||||
/**
|
||||
* Less than and equal expression: a <= b.
|
||||
*/
|
||||
public class LessThanEqual extends ComparisonPredicate {
|
||||
public class LessThanEqual extends ComparisonPredicate implements PropagateNullable {
|
||||
/**
|
||||
* Constructor of Less Than And Equal.
|
||||
*
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ExpectsInputTypes;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
@ -34,7 +35,7 @@ import java.util.Objects;
|
||||
/**
|
||||
* Not expression: not a.
|
||||
*/
|
||||
public class Not extends Expression implements UnaryExpression, ExpectsInputTypes {
|
||||
public class Not extends Expression implements UnaryExpression, ExpectsInputTypes, PropagateNullable {
|
||||
|
||||
public static final List<AbstractDataType> EXPECTS_INPUT_TYPES = ImmutableList.of(BooleanType.INSTANCE);
|
||||
|
||||
|
||||
@ -64,4 +64,5 @@ public class NullSafeEqual extends ComparisonPredicate {
|
||||
public ComparisonPredicate commute() {
|
||||
return new NullSafeEqual(right(), left());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -79,6 +79,10 @@ public class SlotReference extends Slot {
|
||||
this.column = column;
|
||||
}
|
||||
|
||||
public static SlotReference of(String name, DataType type) {
|
||||
return new SlotReference(name, type);
|
||||
}
|
||||
|
||||
public static SlotReference fromColumn(Column column, List<String> qualifier) {
|
||||
DataType dataType = DataType.convertFromCatalogDataType(column.getType());
|
||||
return new SlotReference(NamedExpressionUtil.newExprId(), column.getName(), dataType,
|
||||
|
||||
@ -22,16 +22,21 @@ import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntervalLiteral.TimeUnit;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DateTimeType;
|
||||
import org.apache.doris.nereids.types.DateType;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Describes the addition and subtraction of time units from timestamps.
|
||||
@ -40,7 +45,14 @@ import java.util.List;
|
||||
* Example: '1996-01-01' + INTERVAL '3' month;
|
||||
* TODO: we need to rethink this, and maybe need to add a new type of Interval then implement IntervalLiteral as others
|
||||
*/
|
||||
public class TimestampArithmetic extends Expression implements BinaryExpression, PropagateNullable {
|
||||
public class TimestampArithmetic extends Expression implements BinaryExpression, ImplicitCastInputTypes,
|
||||
PropagateNullable {
|
||||
|
||||
private static final List<AbstractDataType> EXPECTED_INPUT_TYPES = ImmutableList.of(
|
||||
DateTimeType.INSTANCE,
|
||||
IntegerType.INSTANCE
|
||||
);
|
||||
|
||||
private static final Logger LOG = LogManager.getLogger(TimestampArithmetic.class);
|
||||
private final String funcName;
|
||||
private final boolean intervalFirst;
|
||||
@ -149,4 +161,22 @@ public class TimestampArithmetic extends Expression implements BinaryExpression,
|
||||
}
|
||||
return strBuilder.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
TimestampArithmetic other = (TimestampArithmetic) o;
|
||||
return Objects.equals(funcName, other.funcName) && Objects.equals(timeUnit, other.timeUnit)
|
||||
&& Objects.equals(left(), other.left()) && Objects.equals(right(), other.right());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AbstractDataType> expectedInputTypes() {
|
||||
return EXPECTED_INPUT_TYPES;
|
||||
}
|
||||
}
|
||||
|
||||
@ -54,7 +54,7 @@ public class WhenClause extends Expression implements BinaryExpression, ExpectsI
|
||||
|
||||
@Override
|
||||
public String toSql() {
|
||||
return "WHEN " + left().toSql() + " THEN " + right().toSql();
|
||||
return " WHEN " + left().toSql() + " THEN " + right().toSql();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -0,0 +1,243 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions;
|
||||
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.ExecFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
|
||||
/**
|
||||
* functions that can be executed in FE.
|
||||
*/
|
||||
public class ExecutableFunctions {
|
||||
public static final ExecutableFunctions INSTANCE = new ExecutableFunctions();
|
||||
|
||||
private static final Logger LOG = LogManager.getLogger(ExecutableFunctions.class);
|
||||
|
||||
/**
|
||||
* Executable arithmetic functions
|
||||
*/
|
||||
|
||||
@ExecFunction(name = "add", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT")
|
||||
public static TinyIntLiteral addTinyint(TinyIntLiteral first, TinyIntLiteral second) {
|
||||
byte result = (byte) Math.addExact(first.getValue(), second.getValue());
|
||||
return new TinyIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "add", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT")
|
||||
public static SmallIntLiteral addSmallint(SmallIntLiteral first, SmallIntLiteral second) {
|
||||
short result = (short) Math.addExact(first.getValue(), second.getValue());
|
||||
return new SmallIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "add", argTypes = {"INT", "INT"}, returnType = "INT")
|
||||
public static IntegerLiteral addInt(IntegerLiteral first, IntegerLiteral second) {
|
||||
int result = Math.addExact(first.getValue(), second.getValue());
|
||||
return new IntegerLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "add", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT")
|
||||
public static BigIntLiteral addBigint(BigIntLiteral first, BigIntLiteral second) {
|
||||
long result = Math.addExact(first.getValue(), second.getValue());
|
||||
return new BigIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "add", argTypes = {"DOUBLE", "DOUBLE"}, returnType = "DOUBLE")
|
||||
public static DoubleLiteral addDouble(DoubleLiteral first, DoubleLiteral second) {
|
||||
double result = first.getValue() + second.getValue();
|
||||
return new DoubleLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "add", argTypes = {"DECIMAL", "DECIMAL"}, returnType = "DECIMAL")
|
||||
public static DecimalLiteral addDecimal(DecimalLiteral first, DecimalLiteral second) {
|
||||
BigDecimal result = first.getValue().add(second.getValue());
|
||||
return new DecimalLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "subtract", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT")
|
||||
public static TinyIntLiteral subtractTinyint(TinyIntLiteral first, TinyIntLiteral second) {
|
||||
byte result = (byte) Math.subtractExact(first.getValue(), second.getValue());
|
||||
return new TinyIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "subtract", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT")
|
||||
public static SmallIntLiteral subtractSmallint(SmallIntLiteral first, SmallIntLiteral second) {
|
||||
short result = (short) Math.subtractExact(first.getValue(), second.getValue());
|
||||
return new SmallIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "subtract", argTypes = {"INT", "INT"}, returnType = "INT")
|
||||
public static IntegerLiteral subtractInt(IntegerLiteral first, IntegerLiteral second) {
|
||||
int result = Math.subtractExact(first.getValue(), second.getValue());
|
||||
return new IntegerLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "subtract", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT")
|
||||
public static BigIntLiteral subtractBigint(BigIntLiteral first, BigIntLiteral second) {
|
||||
long result = Math.subtractExact(first.getValue(), second.getValue());
|
||||
return new BigIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "subtract", argTypes = {"DOUBLE", "DOUBLE"}, returnType = "DOUBLE")
|
||||
public static DoubleLiteral subtractDouble(DoubleLiteral first, DoubleLiteral second) {
|
||||
double result = first.getValue() - second.getValue();
|
||||
return new DoubleLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "subtract", argTypes = {"DECIMAL", "DECIMAL"}, returnType = "DECIMAL")
|
||||
public static DecimalLiteral subtractDecimal(DecimalLiteral first, DecimalLiteral second) {
|
||||
BigDecimal result = first.getValue().subtract(second.getValue());
|
||||
return new DecimalLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "multiply", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT")
|
||||
public static TinyIntLiteral multiplyTinyint(TinyIntLiteral first, TinyIntLiteral second) {
|
||||
byte result = (byte) Math.multiplyExact(first.getValue(), second.getValue());
|
||||
return new TinyIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "multiply", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT")
|
||||
public static SmallIntLiteral multiplySmallint(SmallIntLiteral first, SmallIntLiteral second) {
|
||||
short result = (short) Math.multiplyExact(first.getValue(), second.getValue());
|
||||
return new SmallIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "multiply", argTypes = {"INT", "INT"}, returnType = "INT")
|
||||
public static IntegerLiteral multiplyInt(IntegerLiteral first, IntegerLiteral second) {
|
||||
int result = Math.multiplyExact(first.getValue(), second.getValue());
|
||||
return new IntegerLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "multiply", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT")
|
||||
public static BigIntLiteral multiplyBigint(BigIntLiteral first, BigIntLiteral second) {
|
||||
long result = Math.multiplyExact(first.getValue(), second.getValue());
|
||||
return new BigIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "multiply", argTypes = {"DOUBLE", "DOUBLE"}, returnType = "DOUBLE")
|
||||
public static DoubleLiteral multiplyDouble(DoubleLiteral first, DoubleLiteral second) {
|
||||
double result = first.getValue() * second.getValue();
|
||||
return new DoubleLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "multiply", argTypes = {"DECIMAL", "DECIMAL"}, returnType = "DECIMAL")
|
||||
public static DecimalLiteral multiplyDecimal(DecimalLiteral first, DecimalLiteral second) {
|
||||
BigDecimal result = first.getValue().multiply(second.getValue());
|
||||
return new DecimalLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "divide", argTypes = {"DOUBLE", "DOUBLE"}, returnType = "DOUBLE")
|
||||
public static DoubleLiteral divideDouble(DoubleLiteral first, DoubleLiteral second) {
|
||||
if (second.getValue() == 0.0) {
|
||||
return null;
|
||||
}
|
||||
double result = first.getValue() / second.getValue();
|
||||
return new DoubleLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "divide", argTypes = {"DECIMAL", "DECIMAL"}, returnType = "DECIMAL")
|
||||
public static DecimalLiteral divideDecimal(DecimalLiteral first, DecimalLiteral second) {
|
||||
if (first.getValue().compareTo(BigDecimal.ZERO) == 0) {
|
||||
return null;
|
||||
}
|
||||
BigDecimal result = first.getValue().divide(second.getValue());
|
||||
return new DecimalLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "date_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
|
||||
public static DateTimeLiteral dateSub(DateTimeLiteral date, IntegerLiteral day) throws AnalysisException {
|
||||
return dateAdd(date, new IntegerLiteral(-day.getValue()));
|
||||
}
|
||||
|
||||
@ExecFunction(name = "date_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
|
||||
public static DateTimeLiteral dateAdd(DateTimeLiteral date, IntegerLiteral day) throws AnalysisException {
|
||||
return daysAdd(date, day);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "years_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
|
||||
public static DateTimeLiteral yearsAdd(DateTimeLiteral date, IntegerLiteral year) throws AnalysisException {
|
||||
return date.plusYears(year.getValue());
|
||||
}
|
||||
|
||||
@ExecFunction(name = "months_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
|
||||
public static DateTimeLiteral monthsAdd(DateTimeLiteral date, IntegerLiteral month) throws AnalysisException {
|
||||
return date.plusMonths(month.getValue());
|
||||
}
|
||||
|
||||
@ExecFunction(name = "days_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
|
||||
public static DateTimeLiteral daysAdd(DateTimeLiteral date, IntegerLiteral day) throws AnalysisException {
|
||||
return date.plusDays(day.getValue());
|
||||
}
|
||||
|
||||
@ExecFunction(name = "hours_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
|
||||
public static DateTimeLiteral hoursAdd(DateTimeLiteral date, IntegerLiteral hour) throws AnalysisException {
|
||||
return date.plusHours(hour.getValue());
|
||||
}
|
||||
|
||||
@ExecFunction(name = "minutes_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
|
||||
public static DateTimeLiteral minutesAdd(DateTimeLiteral date, IntegerLiteral minute) throws AnalysisException {
|
||||
return date.plusMinutes(minute.getValue());
|
||||
}
|
||||
|
||||
@ExecFunction(name = "seconds_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
|
||||
public static DateTimeLiteral secondsAdd(DateTimeLiteral date, IntegerLiteral second) throws AnalysisException {
|
||||
return date.plusSeconds(second.getValue());
|
||||
}
|
||||
|
||||
@ExecFunction(name = "years_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
|
||||
public static DateTimeLiteral yearsSub(DateTimeLiteral date, IntegerLiteral year) throws AnalysisException {
|
||||
return yearsAdd(date, new IntegerLiteral(-year.getValue()));
|
||||
}
|
||||
|
||||
@ExecFunction(name = "months_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
|
||||
public static DateTimeLiteral monthsSub(DateTimeLiteral date, IntegerLiteral month) throws AnalysisException {
|
||||
return monthsAdd(date, new IntegerLiteral(-month.getValue()));
|
||||
}
|
||||
|
||||
@ExecFunction(name = "days_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
|
||||
public static DateTimeLiteral daysSub(DateTimeLiteral date, IntegerLiteral day) throws AnalysisException {
|
||||
return daysAdd(date, new IntegerLiteral(-day.getValue()));
|
||||
}
|
||||
|
||||
@ExecFunction(name = "hours_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
|
||||
public static DateTimeLiteral hoursSub(DateTimeLiteral date, IntegerLiteral hour) throws AnalysisException {
|
||||
return hoursAdd(date, new IntegerLiteral(-hour.getValue()));
|
||||
}
|
||||
|
||||
@ExecFunction(name = "minutes_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
|
||||
public static DateTimeLiteral minutesSub(DateTimeLiteral date, IntegerLiteral minute) throws AnalysisException {
|
||||
return minutesAdd(date, new IntegerLiteral(-minute.getValue()));
|
||||
}
|
||||
|
||||
@ExecFunction(name = "seconds_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
|
||||
public static DateTimeLiteral secondsSub(DateTimeLiteral date, IntegerLiteral second) throws AnalysisException {
|
||||
return secondsAdd(date, new IntegerLiteral(-second.getValue()));
|
||||
}
|
||||
|
||||
}
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
@ -50,7 +51,7 @@ public class Substring extends ScalarFunction implements ImplicitCastInputTypes,
|
||||
}
|
||||
|
||||
public Substring(Expression str, Expression pos) {
|
||||
super("substring", str, pos);
|
||||
super("substring", str, pos, Literal.of(Integer.MAX_VALUE));
|
||||
}
|
||||
|
||||
public Expression getSource() {
|
||||
|
||||
@ -19,14 +19,11 @@ package org.apache.doris.nereids.trees.expressions.literal;
|
||||
|
||||
import org.apache.doris.analysis.LiteralExpr;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DateTimeType;
|
||||
import org.apache.doris.nereids.types.DateType;
|
||||
import org.apache.doris.nereids.util.DateUtils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.joda.time.LocalDateTime;
|
||||
@ -107,32 +104,6 @@ public class DateLiteral extends Literal {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
|
||||
if (getDataType().equals(targetType)) {
|
||||
return this;
|
||||
}
|
||||
if (targetType.isDate()) {
|
||||
if (getDataType().equals(targetType)) {
|
||||
return this;
|
||||
}
|
||||
if (targetType.equals(DateType.INSTANCE)) {
|
||||
return new DateLiteral(this.year, this.month, this.day);
|
||||
} else if (targetType.equals(DateTimeType.INSTANCE)) {
|
||||
return new DateTimeLiteral(this.year, this.month, this.day, 0, 0, 0);
|
||||
} else {
|
||||
throw new AnalysisException("Error date literal type");
|
||||
}
|
||||
}
|
||||
//todo other target type cast
|
||||
return this;
|
||||
}
|
||||
|
||||
public DateLiteral withDataType(DataType type) {
|
||||
Preconditions.checkArgument(type.isDate() || type.isDateTime());
|
||||
return new DateLiteral(this, type);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitDateLiteral(this, context);
|
||||
@ -153,6 +124,11 @@ public class DateLiteral extends Literal {
|
||||
return String.format("%04d-%02d-%02d", year, month, day);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getStringValue() {
|
||||
return String.format("%04d-%02d-%02d", year, month, day);
|
||||
}
|
||||
|
||||
@Override
|
||||
public LiteralExpr toLegacyLiteral() {
|
||||
return new org.apache.doris.analysis.DateLiteral(year, month, day);
|
||||
@ -169,6 +145,11 @@ public class DateLiteral extends Literal {
|
||||
public long getDay() {
|
||||
return day;
|
||||
}
|
||||
|
||||
public DateLiteral plusDays(int days) {
|
||||
LocalDateTime dateTime = LocalDateTime.parse(getStringValue(), DATE_FORMATTER).plusDays(days);
|
||||
return new DateLiteral(dateTime.getYear(), dateTime.getMonthOfYear(), dateTime.getDayOfMonth());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -19,11 +19,8 @@ package org.apache.doris.nereids.trees.expressions.literal;
|
||||
|
||||
import org.apache.doris.analysis.LiteralExpr;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DateTimeType;
|
||||
import org.apache.doris.nereids.types.DateType;
|
||||
import org.apache.doris.nereids.util.DateUtils;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
@ -31,6 +28,8 @@ import org.apache.logging.log4j.Logger;
|
||||
import org.joda.time.LocalDateTime;
|
||||
import org.joda.time.format.DateTimeFormatter;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* date time literal.
|
||||
*/
|
||||
@ -39,7 +38,8 @@ public class DateTimeLiteral extends DateLiteral {
|
||||
|
||||
private static final int DATETIME_TO_MINUTE_STRING_LENGTH = 16;
|
||||
private static final int DATETIME_TO_HOUR_STRING_LENGTH = 13;
|
||||
|
||||
private static final int DATETIME_DEFAULT_STRING_LENGTH = 10;
|
||||
private static DateTimeFormatter DATE_TIME_DEFAULT_FORMATTER = null;
|
||||
private static DateTimeFormatter DATE_TIME_FORMATTER = null;
|
||||
private static DateTimeFormatter DATE_TIME_FORMATTER_TO_HOUR = null;
|
||||
private static DateTimeFormatter DATE_TIME_FORMATTER_TO_MINUTE = null;
|
||||
@ -55,6 +55,7 @@ public class DateTimeLiteral extends DateLiteral {
|
||||
DATE_TIME_FORMATTER_TO_HOUR = DateUtils.formatBuilder("%Y-%m-%d %H").toFormatter();
|
||||
DATE_TIME_FORMATTER_TO_MINUTE = DateUtils.formatBuilder("%Y-%m-%d %H:%i").toFormatter();
|
||||
DATE_TIME_FORMATTER_TWO_DIGIT = DateUtils.formatBuilder("%y-%m-%d %H:%i:%s").toFormatter();
|
||||
DATE_TIME_DEFAULT_FORMATTER = DateUtils.formatBuilder("%Y-%m-%d").toFormatter();
|
||||
} catch (AnalysisException e) {
|
||||
LOG.error("invalid date format", e);
|
||||
System.exit(-1);
|
||||
@ -89,6 +90,8 @@ public class DateTimeLiteral extends DateLiteral {
|
||||
dateTime = DATE_TIME_FORMATTER_TO_MINUTE.parseLocalDateTime(s);
|
||||
} else if (s.length() == DATETIME_TO_HOUR_STRING_LENGTH) {
|
||||
dateTime = DATE_TIME_FORMATTER_TO_HOUR.parseLocalDateTime(s);
|
||||
} else if (s.length() == DATETIME_DEFAULT_STRING_LENGTH) {
|
||||
dateTime = DATE_TIME_DEFAULT_FORMATTER.parseLocalDateTime(s);
|
||||
} else {
|
||||
dateTime = DATE_TIME_FORMATTER.parseLocalDateTime(s);
|
||||
}
|
||||
@ -104,27 +107,6 @@ public class DateTimeLiteral extends DateLiteral {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
|
||||
if (getDataType().equals(targetType)) {
|
||||
return this;
|
||||
}
|
||||
if (targetType.isDate()) {
|
||||
if (getDataType().equals(targetType)) {
|
||||
return this;
|
||||
}
|
||||
if (targetType.equals(DateType.INSTANCE)) {
|
||||
return new DateLiteral(this.year, this.month, this.day);
|
||||
} else if (targetType.equals(DateTimeType.INSTANCE)) {
|
||||
return new DateTimeLiteral(this.year, this.month, this.day, this.hour, this.minute, this.second);
|
||||
} else {
|
||||
throw new AnalysisException("Error date literal type");
|
||||
}
|
||||
}
|
||||
//todo other target type cast
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitDateTimeLiteral(this, context);
|
||||
@ -145,6 +127,47 @@ public class DateTimeLiteral extends DateLiteral {
|
||||
return String.format("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getStringValue() {
|
||||
return String.format("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second);
|
||||
}
|
||||
|
||||
public DateTimeLiteral plusYears(int years) {
|
||||
LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusYears(years);
|
||||
return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(),
|
||||
d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute());
|
||||
}
|
||||
|
||||
public DateTimeLiteral plusMonths(int months) {
|
||||
LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusMonths(months);
|
||||
return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(),
|
||||
d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute());
|
||||
}
|
||||
|
||||
public DateTimeLiteral plusDays(int days) {
|
||||
LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusDays(days);
|
||||
return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(),
|
||||
d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute());
|
||||
}
|
||||
|
||||
public DateTimeLiteral plusHours(int hours) {
|
||||
LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusHours(hours);
|
||||
return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(),
|
||||
d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute());
|
||||
}
|
||||
|
||||
public DateTimeLiteral plusMinutes(int minutes) {
|
||||
LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusMinutes(minutes);
|
||||
return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(),
|
||||
d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute());
|
||||
}
|
||||
|
||||
public DateTimeLiteral plusSeconds(int seconds) {
|
||||
LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusSeconds(seconds);
|
||||
return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(),
|
||||
d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute());
|
||||
}
|
||||
|
||||
@Override
|
||||
public LiteralExpr toLegacyLiteral() {
|
||||
return new org.apache.doris.analysis.DateLiteral(year, month, day, hour, minute, second);
|
||||
@ -161,4 +184,16 @@ public class DateTimeLiteral extends DateLiteral {
|
||||
public long getSecond() {
|
||||
return second;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
DateTimeLiteral other = (DateTimeLiteral) o;
|
||||
return Objects.equals(getValue(), other.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,13 +18,20 @@
|
||||
package org.apache.doris.nereids.trees.expressions.literal;
|
||||
|
||||
import org.apache.doris.analysis.LiteralExpr;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.LeafExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.CharType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.StringType;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.math.BigInteger;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
@ -51,17 +58,21 @@ public abstract class Literal extends Expression implements LeafExpression {
|
||||
if (value == null) {
|
||||
return new NullLiteral();
|
||||
} else if (value instanceof Byte) {
|
||||
return new TinyIntLiteral(((Byte) value).byteValue());
|
||||
return new TinyIntLiteral((byte) value);
|
||||
} else if (value instanceof Short) {
|
||||
return new SmallIntLiteral(((Short) value).shortValue());
|
||||
return new SmallIntLiteral((short) value);
|
||||
} else if (value instanceof Integer) {
|
||||
return new IntegerLiteral((Integer) value);
|
||||
return new IntegerLiteral((int) value);
|
||||
} else if (value instanceof Long) {
|
||||
return new BigIntLiteral(((Long) value).longValue());
|
||||
return new BigIntLiteral((long) value);
|
||||
} else if (value instanceof BigInteger) {
|
||||
return new LargeIntLiteral((BigInteger) value);
|
||||
} else if (value instanceof Float) {
|
||||
return new FloatLiteral((float) value);
|
||||
} else if (value instanceof Double) {
|
||||
return new DoubleLiteral((double) value);
|
||||
} else if (value instanceof Boolean) {
|
||||
return BooleanLiteral.of((Boolean) value);
|
||||
return BooleanLiteral.of((boolean) value);
|
||||
} else if (value instanceof String) {
|
||||
return new StringLiteral((String) value);
|
||||
} else {
|
||||
@ -71,6 +82,10 @@ public abstract class Literal extends Expression implements LeafExpression {
|
||||
|
||||
public abstract Object getValue();
|
||||
|
||||
public String getStringValue() {
|
||||
return String.valueOf(getValue());
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType() throws UnboundException {
|
||||
return dataType;
|
||||
@ -91,6 +106,89 @@ public abstract class Literal extends Expression implements LeafExpression {
|
||||
return visitor.visitLiteral(this, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* literal expr compare.
|
||||
*/
|
||||
public int compareTo(Literal other) {
|
||||
if (isNullLiteral() && other.isNullLiteral()) {
|
||||
return 0;
|
||||
} else if (isNullLiteral() || other.isNullLiteral()) {
|
||||
return isNullLiteral() ? -1 : 1;
|
||||
}
|
||||
|
||||
DataType oType = other.getDataType();
|
||||
DataType type = getDataType();
|
||||
if (!type.equals(oType)) {
|
||||
throw new RuntimeException("data type not equal!");
|
||||
} else if (type.isBooleanType()) {
|
||||
return Boolean.compare((boolean) getValue(), (boolean) other.getValue());
|
||||
} else if (type.isTinyIntType()) {
|
||||
return Byte.compare((byte) getValue(), (byte) other.getValue());
|
||||
} else if (type.isSmallIntType()) {
|
||||
return Short.compare((short) getValue(), (short) other.getValue());
|
||||
} else if (type.isIntType()) {
|
||||
return Integer.compare((int) getValue(), (int) other.getValue());
|
||||
} else if (type.isBigIntType()) {
|
||||
return Long.compare((long) getValue(), (long) other.getValue());
|
||||
} else if (type.isLargeIntType()) {
|
||||
return ((BigInteger) getValue()).compareTo((BigInteger) other.getValue());
|
||||
} else if (type.isFloatType()) {
|
||||
return Float.compare((float) getValue(), (float) other.getValue());
|
||||
} else if (type.isDoubleType()) {
|
||||
return Double.compare((double) getValue(), (double) other.getValue());
|
||||
} else if (type.isDecimalType()) {
|
||||
return Long.compare((Long) getValue(), (Long) other.getValue());
|
||||
} else if (type.isDateType()) {
|
||||
// todo process date
|
||||
} else if (type.isDecimalType()) {
|
||||
return ((BigDecimal) getValue()).compareTo((BigDecimal) other.getValue());
|
||||
} else if (type instanceof StringType) {
|
||||
return StringUtils.compare((String) getValue(), (String) other.getValue());
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
|
||||
String desc = getStringValue();
|
||||
if (targetType.isBooleanType()) {
|
||||
if ("0".equals(desc) || "false".equals(desc.toLowerCase(Locale.ROOT))) {
|
||||
return Literal.of(false);
|
||||
}
|
||||
if ("1".equals(desc) || "true".equals(desc.toLowerCase(Locale.ROOT))) {
|
||||
return Literal.of(true);
|
||||
}
|
||||
}
|
||||
if (targetType.isTinyIntType()) {
|
||||
return Literal.of(Double.valueOf(desc).byteValue());
|
||||
} else if (targetType.isSmallIntType()) {
|
||||
return Literal.of(Double.valueOf(desc).shortValue());
|
||||
} else if (targetType.isIntType()) {
|
||||
return Literal.of(Double.valueOf(desc).intValue());
|
||||
} else if (targetType.isBigIntType()) {
|
||||
return Literal.of(Double.valueOf(desc).longValue());
|
||||
} else if (targetType.isLargeIntType()) {
|
||||
return Literal.of(new BigInteger(desc));
|
||||
} else if (targetType.isFloatType()) {
|
||||
return Literal.of(Float.parseFloat(desc));
|
||||
} else if (targetType.isDoubleType()) {
|
||||
return Literal.of(Double.parseDouble(desc));
|
||||
} else if (targetType.isCharType()) {
|
||||
return new CharLiteral(desc, ((CharType) targetType).getLen());
|
||||
} else if (targetType.isVarcharType()) {
|
||||
return new VarcharLiteral(desc, desc.length());
|
||||
} else if (targetType.isStringType()) {
|
||||
return Literal.of(desc);
|
||||
} else if (targetType.isDate()) {
|
||||
return new DateLiteral(desc);
|
||||
} else if (targetType.isDateTime()) {
|
||||
return new DateTimeLiteral(desc);
|
||||
} else if (targetType.isDecimalType()) {
|
||||
return new DecimalLiteral(BigDecimal.valueOf(Double.parseDouble(desc)));
|
||||
}
|
||||
throw new AnalysisException("no support cast!");
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
@ -114,4 +212,5 @@ public abstract class Literal extends Expression implements LeafExpression {
|
||||
}
|
||||
|
||||
public abstract LiteralExpr toLegacyLiteral();
|
||||
|
||||
}
|
||||
|
||||
@ -18,10 +18,7 @@
|
||||
package org.apache.doris.nereids.trees.expressions.literal;
|
||||
|
||||
import org.apache.doris.analysis.LiteralExpr;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.StringType;
|
||||
|
||||
/**
|
||||
@ -56,30 +53,6 @@ public class StringLiteral extends Literal {
|
||||
return new org.apache.doris.analysis.StringLiteral(value);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
|
||||
if (getDataType().equals(targetType)) {
|
||||
return this;
|
||||
}
|
||||
if (targetType.isDateType()) {
|
||||
return convertToDate(targetType);
|
||||
} else if (targetType.isIntType()) {
|
||||
return new IntegerLiteral(Integer.parseInt(value));
|
||||
}
|
||||
//todo other target type cast
|
||||
return this;
|
||||
}
|
||||
|
||||
private DateLiteral convertToDate(DataType targetType) throws AnalysisException {
|
||||
DateLiteral dateLiteral = null;
|
||||
if (targetType.isDate()) {
|
||||
dateLiteral = new DateLiteral(value);
|
||||
} else if (targetType.isDateTime()) {
|
||||
dateLiteral = new DateTimeLiteral(value);
|
||||
}
|
||||
return dateLiteral;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "'" + value + "'";
|
||||
|
||||
@ -20,7 +20,6 @@ package org.apache.doris.nereids.trees.expressions.literal;
|
||||
import org.apache.doris.analysis.LiteralExpr;
|
||||
import org.apache.doris.analysis.StringLiteral;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.VarcharType;
|
||||
@ -68,20 +67,6 @@ public class VarcharLiteral extends Literal {
|
||||
return "'" + value + "'";
|
||||
}
|
||||
|
||||
// Temporary way to process type coercion in TimestampArithmetic, should be replaced by TypeCoercion rule.
|
||||
@Override
|
||||
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
|
||||
if (getDataType().equals(targetType)) {
|
||||
return this;
|
||||
}
|
||||
if (targetType.isDateType()) {
|
||||
return convertToDate(targetType);
|
||||
} else if (targetType.isIntType()) {
|
||||
return new IntegerLiteral(Integer.parseInt(value));
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
private DateLiteral convertToDate(DataType targetType) throws AnalysisException {
|
||||
DateLiteral dateLiteral = null;
|
||||
if (targetType.isDate()) {
|
||||
|
||||
@ -40,6 +40,7 @@ import org.apache.doris.nereids.trees.expressions.GreaterThan;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.InSubquery;
|
||||
import org.apache.doris.nereids.trees.expressions.IsNull;
|
||||
import org.apache.doris.nereids.trees.expressions.LessThan;
|
||||
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.Like;
|
||||
@ -277,6 +278,10 @@ public abstract class ExpressionVisitor<R, C> {
|
||||
return visit(inPredicate, context);
|
||||
}
|
||||
|
||||
public R visitIsNull(IsNull isNull, C context) {
|
||||
return visit(isNull, context);
|
||||
}
|
||||
|
||||
public R visitInSubquery(InSubquery in, C context) {
|
||||
return visitSubqueryExpr(in, context);
|
||||
}
|
||||
|
||||
@ -165,6 +165,12 @@ public abstract class DataType implements AbstractDataType {
|
||||
return FloatType.INSTANCE;
|
||||
case "double":
|
||||
return DoubleType.INSTANCE;
|
||||
case "decimal":
|
||||
return DecimalType.SYSTEM_DEFAULT;
|
||||
case "char":
|
||||
return CharType.INSTANCE;
|
||||
case "varchar":
|
||||
return VarcharType.SYSTEM_DEFAULT;
|
||||
case "text":
|
||||
case "string":
|
||||
return StringType.INSTANCE;
|
||||
@ -303,18 +309,50 @@ public abstract class DataType implements AbstractDataType {
|
||||
return 0;
|
||||
}
|
||||
|
||||
public boolean isDate() {
|
||||
return this instanceof DateType;
|
||||
public boolean isBooleanType() {
|
||||
return this instanceof BooleanType;
|
||||
}
|
||||
|
||||
public boolean isTinyIntType() {
|
||||
return this instanceof TinyIntType;
|
||||
}
|
||||
|
||||
public boolean isSmallIntType() {
|
||||
return this instanceof SmallIntType;
|
||||
}
|
||||
|
||||
public boolean isIntType() {
|
||||
return this instanceof IntegerType;
|
||||
}
|
||||
|
||||
public boolean isBigIntType() {
|
||||
return this instanceof BigIntType;
|
||||
}
|
||||
|
||||
public boolean isLargeIntType() {
|
||||
return this instanceof LargeIntType;
|
||||
}
|
||||
|
||||
public boolean isFloatType() {
|
||||
return this instanceof FloatType;
|
||||
}
|
||||
|
||||
public boolean isDoubleType() {
|
||||
return this instanceof DoubleType;
|
||||
}
|
||||
|
||||
public boolean isDecimalType() {
|
||||
return this instanceof DecimalType;
|
||||
}
|
||||
|
||||
public boolean isDateTime() {
|
||||
return this instanceof DateTimeType;
|
||||
}
|
||||
|
||||
public boolean isDate() {
|
||||
return this instanceof DateType;
|
||||
}
|
||||
|
||||
public boolean isDateType() {
|
||||
return isDate() || isDateTime();
|
||||
}
|
||||
@ -327,6 +365,14 @@ public abstract class DataType implements AbstractDataType {
|
||||
return this instanceof NumericType;
|
||||
}
|
||||
|
||||
public boolean isCharType() {
|
||||
return this instanceof CharType;
|
||||
}
|
||||
|
||||
public boolean isVarcharType() {
|
||||
return this instanceof VarcharType;
|
||||
}
|
||||
|
||||
public boolean isStringType() {
|
||||
return this instanceof CharacterType;
|
||||
}
|
||||
|
||||
@ -26,6 +26,8 @@ import org.apache.doris.nereids.trees.expressions.Or;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
@ -258,4 +260,20 @@ public class ExpressionUtils {
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
public static boolean isAllLiteral(Expression... children) {
|
||||
return Arrays.stream(children).allMatch(c -> c instanceof Literal);
|
||||
}
|
||||
|
||||
public static boolean isAllLiteral(List<Expression> children) {
|
||||
return children.stream().allMatch(c -> c instanceof Literal);
|
||||
}
|
||||
|
||||
public static boolean hasNullLiteral(List<Expression> children) {
|
||||
return children.stream().anyMatch(c -> c instanceof NullLiteral);
|
||||
}
|
||||
|
||||
public static boolean isAllNullLiteral(List<Expression> children) {
|
||||
return children.stream().allMatch(c -> c instanceof NullLiteral);
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.util;
|
||||
|
||||
import org.apache.doris.nereids.annotation.Developing;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
@ -123,6 +124,10 @@ public class TypeCoercionUtils {
|
||||
} else if (expected instanceof DateTimeType) {
|
||||
returnType = DateTimeType.INSTANCE;
|
||||
}
|
||||
} else if (input.isDate()) {
|
||||
if (expected instanceof DateTimeType) {
|
||||
returnType = expected.defaultConcreteType();
|
||||
}
|
||||
}
|
||||
|
||||
if (returnType == null && input instanceof PrimitiveType
|
||||
@ -202,6 +207,39 @@ public class TypeCoercionUtils {
|
||||
return Optional.ofNullable(tightestCommonType);
|
||||
}
|
||||
|
||||
/**
|
||||
* The type used for arithmetic operations.
|
||||
*/
|
||||
public static DataType getNumResultType(DataType type) {
|
||||
if (type.isTinyIntType() || type.isSmallIntType() || type.isIntType() || type.isBigIntType()) {
|
||||
return BigIntType.INSTANCE;
|
||||
} else if (type.isLargeIntType()) {
|
||||
return LargeIntType.INSTANCE;
|
||||
} else if (type.isFloatType() || type.isDoubleType() || type.isStringType()) {
|
||||
return DoubleType.INSTANCE;
|
||||
} else if (type.isDecimalType()) {
|
||||
return DecimalType.SYSTEM_DEFAULT;
|
||||
} else if (type.isNullType()) {
|
||||
return NullType.INSTANCE;
|
||||
}
|
||||
throw new AnalysisException("no found appropriate data type.");
|
||||
}
|
||||
|
||||
/**
|
||||
* The common type used by arithmetic operations.
|
||||
*/
|
||||
public static DataType findCommonNumericsType(DataType t1, DataType t2) {
|
||||
if (t1.isDoubleType() || t2.isDoubleType()) {
|
||||
return DoubleType.INSTANCE;
|
||||
} else if (t1.isDecimalType() || t2.isDecimalType()) {
|
||||
return DecimalType.SYSTEM_DEFAULT;
|
||||
} else if (t1.isLargeIntType() || t2.isLargeIntType()) {
|
||||
return LargeIntType.INSTANCE;
|
||||
} else {
|
||||
return BigIntType.INSTANCE;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* find wider common type for data type list.
|
||||
*/
|
||||
|
||||
@ -47,7 +47,9 @@ public class SSBJoinReorderTest extends SSBTestBase {
|
||||
"(lo_suppkey = s_suppkey)",
|
||||
"(lo_partkey = p_partkey)"
|
||||
),
|
||||
ImmutableList.of("(((c_region = 'AMERICA') AND (s_region = 'AMERICA')) AND ((p_mfgr = 'MFGR#1') OR (p_mfgr = 'MFGR#2')))")
|
||||
ImmutableList.of("(((CAST(c_region AS STRING) = CAST('AMERICA' AS STRING)) AND (CAST(s_region AS STRING) "
|
||||
+ "= CAST('AMERICA' AS STRING))) AND ((CAST(p_mfgr AS STRING) = CAST('MFGR#1' AS STRING)) "
|
||||
+ "OR (CAST(p_mfgr AS STRING) = CAST('MFGR#2' AS STRING))))")
|
||||
);
|
||||
}
|
||||
|
||||
@ -62,7 +64,10 @@ public class SSBJoinReorderTest extends SSBTestBase {
|
||||
"(lo_partkey = p_partkey)"
|
||||
),
|
||||
ImmutableList.of(
|
||||
"((((c_region = 'AMERICA') AND (s_region = 'AMERICA')) AND ((d_year = 1997) OR (d_year = 1998))) AND ((p_mfgr = 'MFGR#1') OR (p_mfgr = 'MFGR#2')))")
|
||||
"((((CAST(c_region AS STRING) = CAST('AMERICA' AS STRING)) AND (CAST(s_region AS STRING) "
|
||||
+ "= CAST('AMERICA' AS STRING))) AND ((d_year = 1997) OR (d_year = 1998))) "
|
||||
+ "AND ((CAST(p_mfgr AS STRING) = CAST('MFGR#1' AS STRING)) OR (CAST(p_mfgr AS STRING) "
|
||||
+ "= CAST('MFGR#2' AS STRING))))")
|
||||
);
|
||||
}
|
||||
|
||||
@ -76,7 +81,8 @@ public class SSBJoinReorderTest extends SSBTestBase {
|
||||
"(lo_suppkey = s_suppkey)",
|
||||
"(lo_partkey = p_partkey)"
|
||||
),
|
||||
ImmutableList.of("(((s_nation = 'UNITED STATES') AND ((d_year = 1997) OR (d_year = 1998))) AND (p_category = 'MFGR#14'))")
|
||||
ImmutableList.of("(((CAST(s_nation AS STRING) = CAST('UNITED STATES' AS STRING)) AND ((d_year = 1997) "
|
||||
+ "OR (d_year = 1998))) AND (CAST(p_category AS STRING) = CAST('MFGR#14' AS STRING)))")
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.And;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThan;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil;
|
||||
@ -34,6 +35,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.TinyIntType;
|
||||
import org.apache.doris.nereids.util.FieldChecker;
|
||||
import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
@ -168,7 +170,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(sumA2.toSlot(), new TinyIntLiteral((byte) 0))))
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(sumA2.toSlot(), Literal.of(0L))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
|
||||
NamedExpressionUtil.clear();
|
||||
|
||||
@ -181,7 +183,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(sumA2.toSlot(), new TinyIntLiteral((byte) 0))))));
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(sumA2.toSlot(), Literal.of(0L))))));
|
||||
NamedExpressionUtil.clear();
|
||||
|
||||
sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING SUM(a2) > 0";
|
||||
@ -201,7 +203,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))));
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(value.toSlot(), Literal.of(0L))))));
|
||||
NamedExpressionUtil.clear();
|
||||
|
||||
sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING value > 0";
|
||||
@ -212,7 +214,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))));
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(value.toSlot(), Literal.of(0L))))));
|
||||
NamedExpressionUtil.clear();
|
||||
|
||||
sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0";
|
||||
@ -237,7 +239,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK)))
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(minPK.toSlot(), new TinyIntLiteral((byte) 0))))
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(minPK.toSlot(), Literal.of((byte) 0))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
|
||||
NamedExpressionUtil.clear();
|
||||
|
||||
@ -250,7 +252,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(sumA1A2.toSlot(), new TinyIntLiteral((byte) 0))))));
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(sumA1A2.toSlot(), Literal.of(0L))))));
|
||||
NamedExpressionUtil.clear();
|
||||
|
||||
sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + 3) > 0";
|
||||
@ -263,7 +265,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23)))
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(sumA1A23.toSlot(), new TinyIntLiteral((byte) 0))))
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(sumA1A23.toSlot(), Literal.of(0L))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot()))));
|
||||
NamedExpressionUtil.clear();
|
||||
|
||||
@ -276,7 +278,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(countStar.toSlot(), new TinyIntLiteral((byte) 0))))
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(countStar.toSlot(), Literal.of(0L))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
|
||||
NamedExpressionUtil.clear();
|
||||
}
|
||||
@ -310,7 +312,8 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
)
|
||||
)
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1)))
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(a1, sumB1.toSlot())))
|
||||
).when(FieldChecker.check("predicates", new GreaterThan(new Cast(a1, BigIntType.INSTANCE),
|
||||
sumB1.toSlot())))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
|
||||
NamedExpressionUtil.clear();
|
||||
}
|
||||
@ -353,14 +356,14 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
ImmutableList.of("default_cluster:test_having", "t1")
|
||||
);
|
||||
SlotReference a2 = new SlotReference(
|
||||
new ExprId(3), "a1", TinyIntType.INSTANCE, true,
|
||||
new ExprId(3), "a2", TinyIntType.INSTANCE, true,
|
||||
ImmutableList.of("default_cluster:test_having", "t1")
|
||||
);
|
||||
Alias pk1 = new Alias(new ExprId(7), new Add(pk, Literal.of((byte) 1)), "(pk + 1)");
|
||||
Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)");
|
||||
Alias pk2 = new Alias(new ExprId(9), new Add(pk, Literal.of((byte) 2)), "(pk + 2)");
|
||||
Alias sumA1 = new Alias(new ExprId(10), new Sum(a1), "SUM(a1)");
|
||||
Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1, false), Literal.of((byte) 1)), "(COUNT(a1) + 1)");
|
||||
Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1, false), Literal.of(1L)), "(COUNT(a1) + 1)");
|
||||
Alias sumA1A2 = new Alias(new ExprId(12), new Sum(new Add(a1, a2)), "SUM((a1 + a2))");
|
||||
Alias v1 = new Alias(new ExprId(0), new Count(a2, false), "v1");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
@ -382,11 +385,11 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
new And(
|
||||
new And(
|
||||
new GreaterThan(pk.toSlot(), Literal.of((byte) 0)),
|
||||
new GreaterThan(countA11.toSlot(), Literal.of((byte) 0))),
|
||||
new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of((byte) 1)), Literal.of((byte) 0))),
|
||||
new GreaterThan(new Add(v1.toSlot(), Literal.of((byte) 1)), Literal.of((byte) 0))
|
||||
new GreaterThan(countA11.toSlot(), Literal.of(0L))),
|
||||
new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of(1L)), Literal.of(0L))),
|
||||
new GreaterThan(new Add(v1.toSlot(), Literal.of(1L)), Literal.of(0L))
|
||||
),
|
||||
new GreaterThan(v1.toSlot(), Literal.of((byte) 0))
|
||||
new GreaterThan(v1.toSlot(), Literal.of(0L))
|
||||
)
|
||||
))
|
||||
).when(FieldChecker.check(
|
||||
|
||||
@ -53,7 +53,7 @@ public class FunctionRegistryTest implements PatternMatchSupported {
|
||||
logicalOneRowRelation().when(r -> {
|
||||
Year year = (Year) r.getProjects().get(0).child(0);
|
||||
Assertions.assertEquals("2021-01-01",
|
||||
((Literal) year.getArguments().get(0)).getValue());
|
||||
((Literal) year.getArguments().get(0).child(0)).getValue());
|
||||
return true;
|
||||
})
|
||||
);
|
||||
@ -70,13 +70,13 @@ public class FunctionRegistryTest implements PatternMatchSupported {
|
||||
logicalOneRowRelation().when(r -> {
|
||||
Substring firstSubstring = (Substring) r.getProjects().get(0).child(0);
|
||||
Assertions.assertEquals("abc", ((Literal) firstSubstring.getSource()).getValue());
|
||||
Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition()).getValue());
|
||||
Assertions.assertEquals((byte) 2, ((Literal) firstSubstring.getLength().get()).getValue());
|
||||
Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition().child(0)).getValue());
|
||||
Assertions.assertEquals((byte) 2, ((Literal) firstSubstring.getLength().get().child(0)).getValue());
|
||||
|
||||
Substring secondSubstring = (Substring) r.getProjects().get(1).child(0);
|
||||
Assertions.assertTrue(secondSubstring.getSource() instanceof Substring);
|
||||
Assertions.assertEquals((byte) 1, ((Literal) secondSubstring.getPosition()).getValue());
|
||||
Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getLength().get()).getValue());
|
||||
Assertions.assertEquals((byte) 1, ((Literal) secondSubstring.getPosition().child(0)).getValue());
|
||||
Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getLength().get().child(0)).getValue());
|
||||
return true;
|
||||
})
|
||||
);
|
||||
@ -93,13 +93,13 @@ public class FunctionRegistryTest implements PatternMatchSupported {
|
||||
logicalOneRowRelation().when(r -> {
|
||||
Substring firstSubstring = (Substring) r.getProjects().get(0).child(0);
|
||||
Assertions.assertEquals("abc", ((Literal) firstSubstring.getSource()).getValue());
|
||||
Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition()).getValue());
|
||||
Assertions.assertFalse(firstSubstring.getLength().isPresent());
|
||||
Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition().child(0)).getValue());
|
||||
Assertions.assertTrue(firstSubstring.getLength().isPresent());
|
||||
|
||||
Substring secondSubstring = (Substring) r.getProjects().get(1).child(0);
|
||||
Assertions.assertEquals("def", ((Literal) secondSubstring.getSource()).getValue());
|
||||
Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getPosition()).getValue());
|
||||
Assertions.assertEquals((byte) 3, ((Literal) secondSubstring.getLength().get()).getValue());
|
||||
Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getPosition().child(0)).getValue());
|
||||
Assertions.assertEquals((byte) 3, ((Literal) secondSubstring.getLength().get().child(0)).getValue());
|
||||
return true;
|
||||
})
|
||||
);
|
||||
|
||||
@ -40,7 +40,7 @@ public class ExpressionRewriteTest {
|
||||
|
||||
@Test
|
||||
public void testNotRewrite() {
|
||||
executor = new ExpressionRuleExecutor(SimplifyNotExprRule.INSTANCE);
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyNotExprRule.INSTANCE));
|
||||
|
||||
assertRewrite("not x", "not x");
|
||||
assertRewrite("not not x", "x");
|
||||
@ -65,7 +65,7 @@ public class ExpressionRewriteTest {
|
||||
|
||||
@Test
|
||||
public void testNormalizeExpressionRewrite() {
|
||||
executor = new ExpressionRuleExecutor(NormalizeBinaryPredicatesRule.INSTANCE);
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(NormalizeBinaryPredicatesRule.INSTANCE));
|
||||
|
||||
assertRewrite("1 = 1", "1 = 1");
|
||||
assertRewrite("2 > x", "x < 2");
|
||||
@ -77,7 +77,7 @@ public class ExpressionRewriteTest {
|
||||
|
||||
@Test
|
||||
public void testDistinctPredicatesRewrite() {
|
||||
executor = new ExpressionRuleExecutor(DistinctPredicatesRule.INSTANCE);
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(DistinctPredicatesRule.INSTANCE));
|
||||
|
||||
assertRewrite("a = 1", "a = 1");
|
||||
assertRewrite("a = 1 and a = 1", "a = 1");
|
||||
@ -89,7 +89,7 @@ public class ExpressionRewriteTest {
|
||||
|
||||
@Test
|
||||
public void testExtractCommonFactorRewrite() {
|
||||
executor = new ExpressionRuleExecutor(ExtractCommonFactorRule.INSTANCE);
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(ExtractCommonFactorRule.INSTANCE));
|
||||
|
||||
assertRewrite("a", "a");
|
||||
|
||||
@ -142,7 +142,8 @@ public class ExpressionRewriteTest {
|
||||
|
||||
@Test
|
||||
public void testBetweenToCompoundRule() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(BetweenToCompoundRule.INSTANCE, SimplifyNotExprRule.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(BetweenToCompoundRule.INSTANCE,
|
||||
SimplifyNotExprRule.INSTANCE));
|
||||
|
||||
assertRewrite("a between c and d", "(a >= c) and (a <= d)");
|
||||
assertRewrite("a not between c and d)", "(a < c) or (a > d)");
|
||||
|
||||
@ -0,0 +1,341 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rewrite;
|
||||
|
||||
import org.apache.doris.analysis.ArithmeticExpr.Operator;
|
||||
import org.apache.doris.nereids.analyzer.UnboundSlot;
|
||||
import org.apache.doris.nereids.parser.NereidsParser;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.FoldConstantRuleOnFE;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThan;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntervalLiteral.TimeUnit;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.BooleanType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DateTimeType;
|
||||
import org.apache.doris.nereids.types.DoubleType;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.types.StringType;
|
||||
import org.apache.doris.nereids.types.TinyIntType;
|
||||
import org.apache.doris.nereids.types.VarcharType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
|
||||
public class FoldConstantTest {
|
||||
|
||||
private static final NereidsParser PARSER = new NereidsParser();
|
||||
private ExpressionRuleExecutor executor;
|
||||
|
||||
@Test
|
||||
public void testCaseWhenFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
|
||||
assertRewrite("case when 1 = 2 then 1 when '1' < 2 then 2 else 3 end", "2");
|
||||
assertRewrite("case when 1 = 2 then 1 when '1' > 2 then 2 end", "null");
|
||||
assertRewrite("case when (1 + 5) / 2 > 2 then 4 when '1' < 2 then 2 else 3 end", "4");
|
||||
assertRewrite("case when not 1 = 2 then 1 when '1' > 2 then 2 end", "1");
|
||||
assertRewrite("case when 1 = 2 then 1 when 3 in ('1',2 + 8 / 2,3,4) then 2 end", "2");
|
||||
assertRewrite("case when TA = 2 then 1 when 3 in ('1',2 + 8 / 2,3,4) then 2 end", "CASE WHEN (TA = 2) THEN 1 ELSE 2 END");
|
||||
assertRewrite("case when TA = 2 then 5 when 3 in (2,3,4) then 2 else 4 end", "CASE WHEN (TA = 2) THEN 5 ELSE 2 END");
|
||||
assertRewrite("case when TA = 2 then 1 when TB in (2,3,4) then 2 else 4 end", "CASE WHEN (TA = 2) THEN 1 WHEN TB IN (2, 3, 4) THEN 2 ELSE 4 END");
|
||||
assertRewrite("case when null = 2 then 1 when 3 in (2,3,4) then 2 else 4 end", "2");
|
||||
assertRewrite("case when null = 2 then 1 else 4 end", "4");
|
||||
assertRewrite("case when null = 2 then 1 end", "null");
|
||||
assertRewrite("case when TA = TB then 1 when TC is null then 2 end", "CASE WHEN (TA = TB) THEN 1 WHEN TC IS NULL THEN 2 ELSE NULL END");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
|
||||
assertRewrite("1 in (1,2,3,4)", "true");
|
||||
// Type Coercion trans all to string.
|
||||
assertRewrite("3 in ('1',2 + 8 / 2,3,4)", "true");
|
||||
assertRewrite("4 / 2 * 1 - (5/2) in ('1',2 + 8 / 2,3,4)", "false");
|
||||
assertRewrite("null in ('1',2 + 8 / 2,3,4)", "null");
|
||||
assertRewrite("3 in ('1',null,3,4)", "true");
|
||||
assertRewrite("TA in (1,null,3,4)", "TA in (1, null, 3, 4)");
|
||||
assertRewrite("IA in (IB,IC,null)", "IA in (IB,IC,null)");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLogicalFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
|
||||
assertRewrite("10 + 1 > 1 and 1 > 2", "false");
|
||||
assertRewrite("10 + 1 > 1 and 1 < 2", "true");
|
||||
assertRewrite("null + 1 > 1 and 1 < 2", "null");
|
||||
assertRewrite("10 < 3 and 1 > 2", "false");
|
||||
assertRewrite("6 / 2 - 10 * (6 + 1) > 2 and 10 > 3 and 1 > 2", "false");
|
||||
|
||||
assertRewrite("10 + 1 > 1 or 1 > 2", "true");
|
||||
assertRewrite("null + 1 > 1 or 1 > 2", "null");
|
||||
assertRewrite("6 / 2 - 10 * (6 + 1) > 2 or 10 > 3 or 1 > 2", "true");
|
||||
|
||||
assertRewrite("(1 > 5 and 8 < 10 or 1 = 3) or (1 > 8 + 9 / (10 * 2) or ( 10 = 3))", "false");
|
||||
assertRewrite("(TA > 1 and 8 < 10 or 1 = 3) or (1 > 3 or ( 10 = 3))", "TA > 1");
|
||||
|
||||
assertRewrite("false or false", "false");
|
||||
assertRewrite("false or true", "true");
|
||||
assertRewrite("true or false", "true");
|
||||
assertRewrite("true or true", "true");
|
||||
|
||||
assertRewrite("true and true", "true");
|
||||
assertRewrite("false and true", "false");
|
||||
assertRewrite("true and false", "false");
|
||||
assertRewrite("false and false", "false");
|
||||
|
||||
assertRewrite("true and null", "null");
|
||||
assertRewrite("false and null", "false");
|
||||
assertRewrite("true or null", "true");
|
||||
assertRewrite("false or null", "null");
|
||||
|
||||
assertRewrite("null and null", "null");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIsNullFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
|
||||
assertRewrite("100 is null", "false");
|
||||
assertRewrite("null is null", "true");
|
||||
assertRewrite("null is not null", "false");
|
||||
assertRewrite("100 is not null", "true");
|
||||
assertRewrite("IA is not null", "IA is not null");
|
||||
assertRewrite("IA is null", "IA is null");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNotFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
|
||||
assertRewrite("not 1 > 2", "true");
|
||||
assertRewrite("not null + 1 > 2", "null");
|
||||
assertRewrite("not (1 + 5) / 2 + (10 - 1) * 3 > 3 * 5 + 1", "false");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCastFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
|
||||
|
||||
// cast '1' as tinyint
|
||||
Cast c = new Cast(Literal.of("1"), TinyIntType.INSTANCE);
|
||||
Expression rewritten = executor.rewrite(c);
|
||||
Literal expected = Literal.of((byte) 1);
|
||||
Assertions.assertEquals(rewritten, expected);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCompareFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
|
||||
assertRewrite("'1' = 2", "false");
|
||||
assertRewrite("1 = 2", "false");
|
||||
assertRewrite("1 != 2", "true");
|
||||
assertRewrite("2 > 2", "false");
|
||||
assertRewrite("3 * 10 + 1 / 2 >= 2", "true");
|
||||
assertRewrite("3 < 2", "false");
|
||||
assertRewrite("3 <= 2", "false");
|
||||
assertRewrite("3 <= null", "null");
|
||||
assertRewrite("3 >= null", "null");
|
||||
assertRewrite("null <=> null", "true");
|
||||
assertRewrite("2 <=> null", "false");
|
||||
assertRewrite("2 <=> 2", "true");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testArithmeticFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
|
||||
assertRewrite("1 + 1", Literal.of((byte) 2));
|
||||
assertRewrite("1 - 1", Literal.of((byte) 0));
|
||||
assertRewrite("100 + 100", Literal.of((byte) 200));
|
||||
assertRewrite("1 - 2", Literal.of((byte) -1));
|
||||
|
||||
assertRewrite("1 - 2 > 1", "false");
|
||||
assertRewrite("1 - 2 + 1 > 1 + 1 - 100", "true");
|
||||
assertRewrite("10 * 2 / 1 + 1 > (1 + 1) - 100", "true");
|
||||
|
||||
// a + 1 > 2
|
||||
Slot a = SlotReference.of("a", IntegerType.INSTANCE);
|
||||
Expression e1 = new Add(a, Literal.of(1L));
|
||||
Expression e2 = new Add(new Cast(a, BigIntType.INSTANCE), Literal.of(1L));
|
||||
assertRewrite(e1, e2);
|
||||
|
||||
// a > (1 + 10) / 2 * (10 + 1)
|
||||
Expression e3 = PARSER.parseExpression("(1 + 10) / 2 * (10 + 1)");
|
||||
Expression e4 = new GreaterThan(a, e3);
|
||||
Expression e5 = new GreaterThan(new Cast(a, DoubleType.INSTANCE), Literal.of(60.5D));
|
||||
assertRewrite(e4, e5);
|
||||
|
||||
// a > 1
|
||||
Expression e6 = new GreaterThan(a, Literal.of(1));
|
||||
assertRewrite(e6, e6);
|
||||
assertRewrite(a, a);
|
||||
|
||||
// a
|
||||
assertRewrite(a, a);
|
||||
|
||||
// 1
|
||||
Literal one = Literal.of(1);
|
||||
assertRewrite(one, one);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTimestampFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
|
||||
String interval = "'1991-05-01' - interval 1 day";
|
||||
Expression e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
|
||||
Expression e8 = new DateTimeLiteral(1991, 4, 30, 0, 0, 0);
|
||||
assertRewrite(e7, e8);
|
||||
|
||||
interval = "'1991-05-01' + interval '1' day";
|
||||
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
|
||||
e8 = new DateTimeLiteral(1991, 5, 2, 0, 0, 0);
|
||||
assertRewrite(e7, e8);
|
||||
|
||||
interval = "'1991-05-01' + interval 1+1 day";
|
||||
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
|
||||
e8 = new DateTimeLiteral(1991, 5, 3, 0, 0, 0);
|
||||
assertRewrite(e7, e8);
|
||||
|
||||
interval = "date '1991-05-01' + interval 10 / 2 + 1 day";
|
||||
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
|
||||
e8 = new DateTimeLiteral(1991, 5, 7, 0, 0, 0);
|
||||
assertRewrite(e7, e8);
|
||||
|
||||
interval = "interval '1' day + '1991-05-01'";
|
||||
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
|
||||
e8 = new DateTimeLiteral(1991, 5, 2, 0, 0, 0);
|
||||
assertRewrite(e7, e8);
|
||||
|
||||
interval = "interval '3' month + '1991-05-01'";
|
||||
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
|
||||
e8 = new DateTimeLiteral(1991, 8, 1, 0, 0, 0);
|
||||
assertRewrite(e7, e8);
|
||||
|
||||
interval = "interval 3 + 1 month + '1991-05-01'";
|
||||
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
|
||||
e8 = new DateTimeLiteral(1991, 9, 1, 0, 0, 0);
|
||||
assertRewrite(e7, e8);
|
||||
|
||||
interval = "interval 3 + 1 year + '1991-05-01'";
|
||||
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
|
||||
e8 = new DateTimeLiteral(1995, 5, 1, 0, 0, 0);
|
||||
assertRewrite(e7, e8);
|
||||
|
||||
interval = "interval 3 + 3 / 2 hour + '1991-05-01 10:00:00'";
|
||||
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
|
||||
e8 = new DateTimeLiteral(1991, 5, 1, 14, 0, 0);
|
||||
assertRewrite(e7, e8);
|
||||
|
||||
interval = "interval 3 * 2 / 3 minute + '1991-05-01 10:00:00'";
|
||||
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
|
||||
e8 = new DateTimeLiteral(1991, 5, 1, 10, 2, 0);
|
||||
assertRewrite(e7, e8);
|
||||
|
||||
interval = "interval 3 / 2 + 1 second + '1991-05-01 10:00:00'";
|
||||
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
|
||||
e8 = new DateTimeLiteral(1991, 5, 1, 10, 0, 2);
|
||||
assertRewrite(e7, e8);
|
||||
|
||||
// a + interval 1 day
|
||||
Slot a = SlotReference.of("a", DateTimeType.INSTANCE);
|
||||
TimestampArithmetic arithmetic = new TimestampArithmetic(Operator.ADD, a, Literal.of(1), TimeUnit.DAY, false);
|
||||
Expression process = process(arithmetic);
|
||||
assertRewrite(process, process);
|
||||
}
|
||||
|
||||
public Expression process(TimestampArithmetic arithmetic) {
|
||||
String funcOpName;
|
||||
if (arithmetic.getFuncName() == null) {
|
||||
funcOpName = String.format("%sS_%s", arithmetic.getTimeUnit(),
|
||||
(arithmetic.getOp() == Operator.ADD) ? "ADD" : "SUB");
|
||||
} else {
|
||||
funcOpName = arithmetic.getFuncName();
|
||||
}
|
||||
return arithmetic.withFuncName(funcOpName.toLowerCase(Locale.ROOT));
|
||||
}
|
||||
|
||||
private void assertRewrite(String expression, String expected) {
|
||||
Map<String, Slot> mem = Maps.newHashMap();
|
||||
Expression needRewriteExpression = PARSER.parseExpression(expression);
|
||||
needRewriteExpression = replaceUnboundSlot(needRewriteExpression, mem);
|
||||
Expression expectedExpression = PARSER.parseExpression(expected);
|
||||
expectedExpression = replaceUnboundSlot(expectedExpression, mem);
|
||||
Expression rewrittenExpression = executor.rewrite(needRewriteExpression);
|
||||
Assertions.assertEquals(expectedExpression, rewrittenExpression);
|
||||
}
|
||||
|
||||
private void assertRewrite(String expression, Expression expectedExpression) {
|
||||
Expression needRewriteExpression = PARSER.parseExpression(expression);
|
||||
Expression rewrittenExpression = executor.rewrite(needRewriteExpression);
|
||||
Assertions.assertEquals(expectedExpression, rewrittenExpression);
|
||||
}
|
||||
|
||||
private void assertRewrite(Expression expression, Expression expectedExpression) {
|
||||
Expression rewrittenExpression = executor.rewrite(expression);
|
||||
Assertions.assertEquals(expectedExpression, rewrittenExpression);
|
||||
}
|
||||
|
||||
private Expression replaceUnboundSlot(Expression expression, Map<String, Slot> mem) {
|
||||
List<Expression> children = Lists.newArrayList();
|
||||
boolean hasNewChildren = false;
|
||||
for (Expression child : expression.children()) {
|
||||
Expression newChild = replaceUnboundSlot(child, mem);
|
||||
if (newChild != child) {
|
||||
hasNewChildren = true;
|
||||
}
|
||||
children.add(newChild);
|
||||
}
|
||||
if (expression instanceof UnboundSlot) {
|
||||
String name = ((UnboundSlot) expression).getName();
|
||||
mem.putIfAbsent(name, SlotReference.of(name, getType(name.charAt(0))));
|
||||
return mem.get(name);
|
||||
}
|
||||
return hasNewChildren ? expression.withChildren(children) : expression;
|
||||
}
|
||||
|
||||
private DataType getType(char t) {
|
||||
switch (t) {
|
||||
case 'T':
|
||||
return TinyIntType.INSTANCE;
|
||||
case 'I':
|
||||
return IntegerType.INSTANCE;
|
||||
case 'D':
|
||||
return DoubleType.INSTANCE;
|
||||
case 'S':
|
||||
return StringType.INSTANCE;
|
||||
case 'V':
|
||||
return VarcharType.INSTANCE;
|
||||
case 'B':
|
||||
return BooleanType.INSTANCE;
|
||||
default:
|
||||
return BigIntType.INSTANCE;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -36,6 +36,7 @@ import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
|
||||
@ -176,8 +177,8 @@ public class TypeCoercionTest {
|
||||
@Test
|
||||
public void testBinaryOperator() {
|
||||
Expression actual = new Divide(new SmallIntLiteral((short) 1), new BigIntLiteral(10L));
|
||||
Expression expected = new Divide(new BigIntLiteral(1L),
|
||||
new BigIntLiteral(10L));
|
||||
Expression expected = new Divide(new Cast(Literal.of((short) 1), DoubleType.INSTANCE),
|
||||
new Cast(Literal.of(10L), DoubleType.INSTANCE));
|
||||
assertRewrite(expected, actual);
|
||||
}
|
||||
|
||||
|
||||
@ -272,4 +272,13 @@ public class ExpressionParserTest extends ParserTestBase {
|
||||
String cast2 = "SELECT CAST(A AS INT) AS I FROM TEST;";
|
||||
assertSql(cast2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIsNull() {
|
||||
String e1 = "a is null";
|
||||
assertExpr(e1);
|
||||
|
||||
String e2 = "a is not null";
|
||||
assertExpr(e2);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user