[feature](Nereids) constant expression folding (#12151)

This commit is contained in:
shee
2022-09-26 17:16:23 +08:00
committed by GitHub
parent 3902b2bfad
commit 7977bebfed
48 changed files with 2005 additions and 186 deletions

View File

@ -203,6 +203,7 @@ predicate
| NOT? kind=(LIKE | REGEXP) pattern=valueExpression
| NOT? kind=IN LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN
| NOT? kind=IN LEFT_PAREN query RIGHT_PAREN
| IS NOT? kind=NULL
;
valueExpression

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.jobs.batch.AnalyzeRulesJob;
import org.apache.doris.nereids.jobs.batch.AnalyzeSubqueryRulesJob;
import org.apache.doris.nereids.jobs.batch.CheckAnalysisJob;
import org.apache.doris.nereids.jobs.batch.FinalizeAnalyzeJob;
import org.apache.doris.nereids.jobs.batch.TypeCoercionJob;
import org.apache.doris.nereids.rules.analysis.Scope;
import java.util.Objects;
@ -44,9 +45,13 @@ public class NereidsAnalyzer {
this.outerScope = Objects.requireNonNull(outerScope, "outerScope can not be null");
}
/**
* nereids analyze sql.
*/
public void analyze() {
new AnalyzeRulesJob(cascadesContext, outerScope).execute();
new AnalyzeSubqueryRulesJob(cascadesContext).execute();
new TypeCoercionJob(cascadesContext).execute();
new FinalizeAnalyzeJob(cascadesContext).execute();
// check whether analyze result is meaningful
new CheckAnalysisJob(cascadesContext).execute();

View File

@ -37,7 +37,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
import com.google.common.collect.ImmutableList;
/**
* Apply rules to normalize expressions.
* Apply rules to optimize logical plan.
*/
public class NereidsRewriteJobExecutor extends BatchRulesJob {
@ -58,7 +58,7 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob {
*/
.addAll(new AdjustApplyFromCorrelatToUnCorrelatJob(cascadesContext).rulesJob)
.addAll(new ConvertApplyToJoinJob(cascadesContext).rulesJob)
.add(topDownBatch(ImmutableList.of(new ExpressionNormalization())))
.add(topDownBatch(ImmutableList.of(new ExpressionNormalization(cascadesContext.getConnectContext()))))
.add(topDownBatch(ImmutableList.of(new ExpressionOptimization())))
.add(topDownBatch(ImmutableList.of(new ExtractSingleTableExpressionFromDisjunction())))
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))

View File

@ -0,0 +1,41 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
import com.google.common.collect.ImmutableList;
/**
* type coercion job.
*/
public class TypeCoercionJob extends BatchRulesJob {
/**
* constructor.
*/
public TypeCoercionJob(CascadesContext cascadesContext) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
topDownBatch(ImmutableList.of(
new ExpressionNormalization(cascadesContext.getConnectContext(),
ImmutableList.of(TypeCoercion.INSTANCE)))
)));
}
}

View File

@ -95,6 +95,7 @@ import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Like;
@ -385,7 +386,6 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
public Expression visitPredicated(PredicatedContext ctx) {
return ParserUtils.withOrigin(ctx, () -> {
Expression e = getExpression(ctx.valueExpression());
// TODO: add predicate(is not null ...)
return ctx.predicate() == null ? e : withPredicate(e, ctx.predicate());
});
}
@ -948,6 +948,9 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
);
}
break;
case DorisParser.NULL:
outExpression = new IsNull(valueExpression);
break;
default:
throw new ParseException("Unsupported predicate type: " + ctx.kind.getText(), ctx);
}

View File

@ -36,13 +36,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.IntegerType;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
/**
@ -134,8 +132,12 @@ public class BindFunction implements AnalysisRuleFactory {
return builder.build(functionName, unboundFunction.getArguments());
}
/**
* gets the method for calculating the time.
* e.g. YEARS_ADD、YEARS_SUB、DAYS_ADD 、DAYS_SUB
*/
@Override
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Env env) {
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Env context) {
String funcOpName;
if (arithmetic.getFuncName() == null) {
// e.g. YEARS_ADD, MONTHS_SUB
@ -144,24 +146,7 @@ public class BindFunction implements AnalysisRuleFactory {
} else {
funcOpName = arithmetic.getFuncName();
}
Expression left = arithmetic.left();
Expression right = arithmetic.right();
if (!left.getDataType().isDateType()) {
try {
left = left.castTo(DateTimeType.INSTANCE);
} catch (Exception e) {
// ignore
}
if (!left.getDataType().isDateType() && !arithmetic.getTimeUnit().isDateTimeUnit()) {
left = arithmetic.left().castTo(DateType.INSTANCE);
}
}
if (!right.getDataType().isIntType()) {
right = right.castTo(IntegerType.INSTANCE);
}
return arithmetic.withFuncName(funcOpName).withChildren(ImmutableList.of(left, right));
return arithmetic.withFuncName(funcOpName.toLowerCase(Locale.ROOT));
}
}
}

View File

@ -18,11 +18,13 @@
package org.apache.doris.nereids.rules.expression.rewrite;
import org.apache.doris.nereids.rules.expression.rewrite.rules.BetweenToCompoundRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.FoldConstantRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.InPredicateToEqualToRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPredicatesRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyCastRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
@ -39,11 +41,16 @@ public class ExpressionNormalization extends ExpressionRewrite {
InPredicateToEqualToRule.INSTANCE,
SimplifyNotExprRule.INSTANCE,
SimplifyCastRule.INSTANCE,
TypeCoercion.INSTANCE
TypeCoercion.INSTANCE,
FoldConstantRule.INSTANCE
);
private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES);
public ExpressionNormalization() {
super(EXECUTOR);
public ExpressionNormalization(ConnectContext context) {
super(new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES, context));
}
public ExpressionNormalization(ConnectContext context, List<ExpressionRewriteRule> rules) {
super(new ExpressionRuleExecutor(rules, context));
}
}

View File

@ -17,8 +17,15 @@
package org.apache.doris.nereids.rules.expression.rewrite;
import org.apache.doris.qe.ConnectContext;
/**
* expression rewrite context.
*/
public class ExpressionRewriteContext {
public final ConnectContext connectContext;
public ExpressionRewriteContext(ConnectContext connectContext) {
this.connectContext = connectContext;
}
}

View File

@ -18,8 +18,7 @@
package org.apache.doris.nereids.rules.expression.rewrite;
import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.collect.Lists;
import org.apache.doris.qe.ConnectContext;
import java.util.List;
import java.util.Optional;
@ -33,14 +32,13 @@ public class ExpressionRuleExecutor {
private final ExpressionRewriteContext ctx;
private final List<ExpressionRewriteRule> rules;
public ExpressionRuleExecutor(List<ExpressionRewriteRule> rules) {
public ExpressionRuleExecutor(List<ExpressionRewriteRule> rules, ConnectContext context) {
this.rules = rules;
this.ctx = new ExpressionRewriteContext();
this.ctx = new ExpressionRewriteContext(context);
}
public ExpressionRuleExecutor(ExpressionRewriteRule rule) {
this.rules = Lists.newArrayList(rule);
this.ctx = new ExpressionRewriteContext();
public ExpressionRuleExecutor(List<ExpressionRewriteRule> rules) {
this(rules, null);
}
public List<Expression> rewrite(List<Expression> exprs) {

View File

@ -0,0 +1,39 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Expression;
/**
* Constant evaluation of an expression.
*/
public class FoldConstantRule extends AbstractExpressionRewriteRule {
public static final FoldConstantRule INSTANCE = new FoldConstantRule();
@Override
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
if (ctx.connectContext != null && ctx.connectContext.getSessionVariable().isEnableFoldConstantByBe()) {
return FoldConstantRuleOnBE.INSTANCE.rewrite(expr, ctx);
}
return FoldConstantRuleOnFE.INSTANCE.rewrite(expr, ctx);
}
}

View File

@ -0,0 +1,200 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.ExprId;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.common.UserException;
import org.apache.doris.common.util.TimeUtils;
import org.apache.doris.common.util.VectorizedUtil;
import org.apache.doris.nereids.glue.translator.ExpressionTranslator;
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Between;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.proto.InternalService;
import org.apache.doris.proto.InternalService.PConstantExprResult;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.rpc.BackendServiceProxy;
import org.apache.doris.system.Backend;
import org.apache.doris.thrift.TExpr;
import org.apache.doris.thrift.TFoldConstantParams;
import org.apache.doris.thrift.TNetworkAddress;
import org.apache.doris.thrift.TPrimitiveType;
import org.apache.doris.thrift.TQueryGlobals;
import com.google.common.collect.Maps;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
/**
* Constant evaluation of an expression.
*/
public class FoldConstantRuleOnBE extends AbstractExpressionRewriteRule {
public static final FoldConstantRuleOnBE INSTANCE = new FoldConstantRuleOnBE();
private static final Logger LOG = LogManager.getLogger(FoldConstantRuleOnBE.class);
private static final DateTimeFormatter DATE_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
private final IdGenerator<ExprId> idGenerator = ExprId.createGenerator();
@Override
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
Expression expression = FoldConstantRuleOnFE.INSTANCE.rewrite(expr, ctx);
return foldByBE(expression, ctx);
}
private Expression foldByBE(Expression root, ExpressionRewriteContext context) {
Map<String, Expression> constMap = Maps.newHashMap();
Map<String, TExpr> staleConstTExprMap = Maps.newHashMap();
collectConst(root, constMap, staleConstTExprMap);
if (constMap.isEmpty()) {
return root;
}
Map<String, Map<String, TExpr>> paramMap = new HashMap<>();
paramMap.put("0", staleConstTExprMap);
Map<String, Expression> resultMap = evalOnBE(paramMap, constMap, context.connectContext);
if (!resultMap.isEmpty()) {
return replace(root, constMap, resultMap);
}
return root;
}
private Expression replace(Expression root, Map<String, Expression> constMap, Map<String, Expression> resultMap) {
for (Entry<String, Expression> entry : constMap.entrySet()) {
if (entry.getValue().equals(root)) {
return resultMap.get(entry.getKey());
}
}
List<Expression> newChildren = new ArrayList<>();
boolean hasNewChildren = false;
for (Expression child : root.children()) {
Expression newChild = replace(child, constMap, resultMap);
if (newChild != child) {
hasNewChildren = true;
}
newChildren.add(newChild);
}
return hasNewChildren ? root.withChildren(newChildren) : root;
}
private void collectConst(Expression expr, Map<String, Expression> constMap, Map<String, TExpr> tExprMap) {
if (expr.isConstant()) {
// Do not constant fold cast(null as dataType) because we cannot preserve the
// cast-to-types and that can lead to query failures, e.g., CTAS
if (expr instanceof Cast) {
if (((Cast) expr).child().isNullLiteral()) {
return;
}
}
// skip literal expr
if (expr.isLiteral()) {
return;
}
// skip BetweenPredicate need to be rewrite to CompoundPredicate
if (expr instanceof Between) {
return;
}
String id = idGenerator.getNextId().toString();
constMap.put(id, expr);
Expr staleExpr = ExpressionTranslator.translate(expr, null);
tExprMap.put(id, staleExpr.treeToThrift());
} else {
for (int i = 0; i < expr.children().size(); i++) {
final Expression child = expr.children().get(i);
collectConst(child, constMap, tExprMap);
}
}
}
private Map<String, Expression> evalOnBE(Map<String, Map<String, TExpr>> paramMap,
Map<String, Expression> constMap, ConnectContext context) {
Map<String, Expression> resultMap = new HashMap<>();
try {
List<Long> backendIds = Env.getCurrentSystemInfo().getBackendIds(true);
if (backendIds.isEmpty()) {
throw new UserException("No alive backends");
}
Collections.shuffle(backendIds);
Backend be = Env.getCurrentSystemInfo().getBackend(backendIds.get(0));
TNetworkAddress brpcAddress = new TNetworkAddress(be.getHost(), be.getBrpcPort());
TQueryGlobals queryGlobals = new TQueryGlobals();
queryGlobals.setNowString(DATE_FORMAT.format(LocalDateTime.now()));
queryGlobals.setTimestampMs(System.currentTimeMillis());
queryGlobals.setTimeZone(TimeUtils.DEFAULT_TIME_ZONE);
if (context.getSessionVariable().getTimeZone().equals("CST")) {
queryGlobals.setTimeZone(TimeUtils.DEFAULT_TIME_ZONE);
} else {
queryGlobals.setTimeZone(context.getSessionVariable().getTimeZone());
}
TFoldConstantParams tParams = new TFoldConstantParams(paramMap, queryGlobals);
tParams.setVecExec(VectorizedUtil.isVectorized());
Future<PConstantExprResult> future =
BackendServiceProxy.getInstance().foldConstantExpr(brpcAddress, tParams);
PConstantExprResult result = future.get(5, TimeUnit.SECONDS);
if (result.getStatus().getStatusCode() == 0) {
for (Entry<String, InternalService.PExprResultMap> e : result.getExprResultMapMap().entrySet()) {
for (Entry<String, InternalService.PExprResult> e1 : e.getValue().getMapMap().entrySet()) {
Expression ret;
if (e1.getValue().getSuccess()) {
TPrimitiveType type = TPrimitiveType.findByValue(e1.getValue().getType().getType());
Type t = Type.fromPrimitiveType(PrimitiveType.fromThrift(Objects.requireNonNull(type)));
Expr staleExpr = LiteralExpr.create(e1.getValue().getContent(), Objects.requireNonNull(t));
// Nereids type
DataType t1 = DataType.convertFromString(staleExpr.getType().getPrimitiveType().toString());
ret = Literal.of(staleExpr.getStringValue()).castTo(t1);
} else {
ret = constMap.get(e.getKey());
}
resultMap.put(e.getKey(), ret);
}
}
} else {
LOG.warn("failed to get const expr value from be: {}", result.getStatus().getErrorMsgsList());
}
} catch (Exception e) {
LOG.warn("failed to get const expr value from be: {}", e.getMessage());
}
return resultMap;
}
}

View File

@ -0,0 +1,310 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.ExpressionEvaluator;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Like;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
/**
* evaluate an expression on fe.
*/
public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
public static final FoldConstantRuleOnFE INSTANCE = new FoldConstantRuleOnFE();
@Override
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
return process(expr, ctx);
}
@Override
public Expression visit(Expression expr, ExpressionRewriteContext context) {
return expr;
}
/**
* process constant expression.
*/
public Expression process(Expression expr, ExpressionRewriteContext ctx) {
if (expr instanceof PropagateNullable) {
List<Expression> children = expr.children()
.stream()
.map(child -> process(child, ctx))
.collect(Collectors.toList());
if (ExpressionUtils.hasNullLiteral(children)) {
return Literal.of(null);
}
if (!ExpressionUtils.isAllLiteral(children)) {
return expr.withChildren(children);
}
return expr.withChildren(children).accept(this, ctx);
} else {
return expr.accept(this, ctx);
}
}
@Override
public Expression visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) {
return BooleanLiteral.of(((Literal) equalTo.left()).compareTo((Literal) equalTo.right()) == 0);
}
@Override
public Expression visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteContext context) {
return BooleanLiteral.of(((Literal) greaterThan.left()).compareTo((Literal) greaterThan.right()) > 0);
}
@Override
public Expression visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, ExpressionRewriteContext context) {
return BooleanLiteral.of(((Literal) greaterThanEqual.left())
.compareTo((Literal) greaterThanEqual.right()) >= 0);
}
@Override
public Expression visitLessThan(LessThan lessThan, ExpressionRewriteContext context) {
return BooleanLiteral.of(((Literal) lessThan.left()).compareTo((Literal) lessThan.right()) < 0);
}
@Override
public Expression visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewriteContext context) {
return BooleanLiteral.of(((Literal) lessThanEqual.left()).compareTo((Literal) lessThanEqual.right()) <= 0);
}
@Override
public Expression visitNullSafeEqual(NullSafeEqual nullSafeEqual, ExpressionRewriteContext context) {
Expression left = process(nullSafeEqual.left(), context);
Expression right = process(nullSafeEqual.right(), context);
if (ExpressionUtils.isAllLiteral(left, right)) {
Literal l = (Literal) left;
Literal r = (Literal) right;
if (l.isNullLiteral() && r.isNullLiteral()) {
return BooleanLiteral.TRUE;
} else if (!l.isNullLiteral() && !r.isNullLiteral()) {
return BooleanLiteral.of(l.compareTo(r) == 0);
} else {
return BooleanLiteral.FALSE;
}
}
return nullSafeEqual.withChildren(left, right);
}
@Override
public Expression visitNot(Not not, ExpressionRewriteContext context) {
return BooleanLiteral.of(!((BooleanLiteral) not.child()).getValue());
}
@Override
public Expression visitSlot(Slot slot, ExpressionRewriteContext context) {
return slot;
}
@Override
public Expression visitLiteral(Literal literal, ExpressionRewriteContext context) {
return literal;
}
@Override
public Expression visitAnd(And and, ExpressionRewriteContext context) {
List<Expression> children = Lists.newArrayList();
for (Expression child : and.children()) {
Expression newChild = process(child, context);
if (newChild.equals(BooleanLiteral.FALSE)) {
return BooleanLiteral.FALSE;
}
if (!newChild.equals(BooleanLiteral.TRUE)) {
children.add(newChild);
}
}
if (children.isEmpty()) {
return BooleanLiteral.TRUE;
}
if (children.size() == 1) {
return children.get(0);
}
if (ExpressionUtils.isAllNullLiteral(children)) {
return Literal.of(null);
}
return and.withChildren(children);
}
@Override
public Expression visitOr(Or or, ExpressionRewriteContext context) {
List<Expression> children = Lists.newArrayList();
for (Expression child : or.children()) {
Expression newChild = process(child, context);
if (newChild.equals(BooleanLiteral.TRUE)) {
return BooleanLiteral.TRUE;
}
if (!newChild.equals(BooleanLiteral.FALSE)) {
children.add(newChild);
}
}
if (children.isEmpty()) {
return BooleanLiteral.FALSE;
}
if (children.size() == 1) {
return children.get(0);
}
if (ExpressionUtils.isAllNullLiteral(children)) {
return Literal.of(null);
}
return or.withChildren(children);
}
@Override
public Expression visitLike(Like like, ExpressionRewriteContext context) {
return like;
}
@Override
public Expression visitCast(Cast cast, ExpressionRewriteContext context) {
Expression child = process(cast.child(), context);
// todo: process other null case
if (child.isNullLiteral()) {
return Literal.of(null);
}
if (child.isLiteral()) {
return child.castTo(cast.getDataType());
}
return cast.withChildren(child);
}
@Override
public Expression visitBoundFunction(BoundFunction boundFunction, ExpressionRewriteContext context) {
List<Expression> newArgs = boundFunction.getArguments().stream().map(arg -> process(arg, context))
.collect(Collectors.toList());
if (ExpressionUtils.isAllLiteral(newArgs)) {
return ExpressionEvaluator.INSTANCE.eval(boundFunction.withChildren(newArgs));
}
return boundFunction.withChildren(newArgs);
}
@Override
public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, ExpressionRewriteContext context) {
return ExpressionEvaluator.INSTANCE.eval(binaryArithmetic);
}
@Override
public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) {
Expression newDefault = null;
boolean foundNewDefault = false;
List<WhenClause> whenClauses = new ArrayList<>();
for (WhenClause whenClause : caseWhen.getWhenClauses()) {
Expression whenOperand = process(whenClause.getOperand(), context);
if (!(whenOperand.isLiteral())) {
whenClauses.add(new WhenClause(whenOperand, process(whenClause.getResult(), context)));
} else if (BooleanLiteral.TRUE.equals(whenOperand)) {
foundNewDefault = true;
newDefault = process(whenClause.getResult(), context);
break;
}
}
Expression defaultResult;
if (foundNewDefault) {
defaultResult = newDefault;
} else {
defaultResult = process(caseWhen.getDefaultValue().orElse(Literal.of(null)), context);
}
if (whenClauses.isEmpty()) {
return defaultResult;
}
return new CaseWhen(whenClauses, defaultResult);
}
@Override
public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
Expression value = process(inPredicate.child(0), context);
List<Expression> children = Lists.newArrayList();
children.add(value);
if (value.isNullLiteral()) {
return Literal.of(null);
}
boolean hasNull = false;
boolean hasUnresolvedValue = !value.isLiteral();
for (int i = 1; i < inPredicate.children().size(); i++) {
Expression inValue = process(inPredicate.child(i), context);
children.add(inValue);
if (!inValue.isLiteral()) {
hasUnresolvedValue = true;
}
if (inValue.isNullLiteral()) {
hasNull = true;
}
if (inValue.isLiteral() && value.isLiteral() && ((Literal) value).compareTo((Literal) inValue) == 0) {
return Literal.of(true);
}
}
if (hasUnresolvedValue) {
return inPredicate.withChildren(children);
}
return hasNull ? Literal.of(null) : Literal.of(false);
}
@Override
public Expression visitIsNull(IsNull isNull, ExpressionRewriteContext context) {
Expression child = process(isNull.child(), context);
if (child.isNullLiteral()) {
return Literal.of(true);
} else if (!child.nullable()) {
return Literal.of(false);
}
return isNull.withChildren(child);
}
@Override
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, ExpressionRewriteContext context) {
return ExpressionEvaluator.INSTANCE.eval(arithmetic);
}
}

View File

@ -17,16 +17,19 @@
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
@ -45,9 +48,8 @@ public class TypeCoercion extends AbstractExpressionRewriteRule {
// TODO:
// 1. DecimalPrecision Process
// 2. Divide process
// 3. String promote with numeric in binary arithmetic
// 4. Date and DateTime process
// 2. String promote with numeric in binary arithmetic
// 3. Date and DateTime process
public static final TypeCoercion INSTANCE = new TypeCoercion();
@ -82,6 +84,23 @@ public class TypeCoercion extends AbstractExpressionRewriteRule {
.orElse(binaryOperator.withChildren(left, right));
}
@Override
public Expression visitDivide(Divide divide, ExpressionRewriteContext context) {
Expression left = rewrite(divide.left(), context);
Expression right = rewrite(divide.right(), context);
DataType t1 = TypeCoercionUtils.getNumResultType(left.getDataType());
DataType t2 = TypeCoercionUtils.getNumResultType(right.getDataType());
DataType commonType = TypeCoercionUtils.findCommonNumericsType(t1, t2);
if (divide.getLegacyOperator() == Operator.DIVIDE) {
if (commonType.isBigIntType() || commonType.isLargeIntType()) {
commonType = DoubleType.INSTANCE;
}
}
Expression newLeft = TypeCoercionUtils.castIfNotSameType(left, commonType);
Expression newRight = TypeCoercionUtils.castIfNotSameType(right, commonType);
return divide.withChildren(newLeft, newRight);
}
@Override
public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) {
List<Expression> rewrittenChildren = caseWhen.children().stream()

View File

@ -19,13 +19,14 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
/**
* binary arithmetic operator. Such as +, -, *, /.
*/
public abstract class BinaryArithmetic extends BinaryOperator {
public abstract class BinaryArithmetic extends BinaryOperator implements PropagateNullable {
private final Operator legacyOperator;

View File

@ -71,5 +71,6 @@ public abstract class CompoundPredicate extends BinaryOperator {
public abstract CompoundPredicate flip(Expression left, Expression right);
public abstract Class<? extends CompoundPredicate> flipType();
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import com.google.common.base.Preconditions;
@ -27,7 +28,7 @@ import java.util.List;
/**
* Equal to expression: a = b.
*/
public class EqualTo extends ComparisonPredicate {
public class EqualTo extends ComparisonPredicate implements PropagateNullable {
public EqualTo(Expression left, Expression right) {
super(left, right, "=");

View File

@ -0,0 +1,46 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* nereids function annotation.
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface ExecFunction {
/**
* function name
*/
String name();
/**
* args type
*/
String[] argTypes();
/**
* return type
*/
String returnType();
}

View File

@ -0,0 +1,35 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* exec function list
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface ExecFunctionList {
/**
* exec functions.
*/
ExecFunction[] value();
}

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.AbstractTreeNode;
import org.apache.doris.nereids.trees.expressions.functions.ComputeNullable;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.shape.LeafExpression;
import org.apache.doris.nereids.trees.expressions.typecoercion.ExpectsInputTypes;
import org.apache.doris.nereids.trees.expressions.typecoercion.TypeCheckResult;
@ -139,6 +140,14 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
return collect(Slot.class::isInstance);
}
public boolean isLiteral() {
return this instanceof Literal;
}
public boolean isNullLiteral() {
return this instanceof NullLiteral;
}
@Override
public boolean equals(Object o) {
if (this == o) {

View File

@ -0,0 +1,208 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.trees.expressions.functions.ExecutableFunctions;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.DataType;
import com.google.common.collect.ImmutableMultimap;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
/**
* An expression evaluator that evaluates the value of an expression.
*/
public enum ExpressionEvaluator {
INSTANCE;
private ImmutableMultimap<String, FunctionInvoker> functions;
ExpressionEvaluator() {
registerFunctions();
}
/**
* Evaluate the value of the expression.
*/
public Expression eval(Expression expression) {
if (!expression.isConstant() || expression instanceof AggregateFunction) {
return expression;
}
String fnName = null;
DataType[] args = null;
if (expression instanceof BinaryArithmetic) {
BinaryArithmetic arithmetic = (BinaryArithmetic) expression;
fnName = arithmetic.getLegacyOperator().getName();
args = new DataType[]{arithmetic.left().getDataType(), arithmetic.right().getDataType()};
} else if (expression instanceof TimestampArithmetic) {
TimestampArithmetic arithmetic = (TimestampArithmetic) expression;
fnName = arithmetic.getFuncName();
args = new DataType[]{arithmetic.left().getDataType(), arithmetic.right().getDataType()};
}
if ((Env.getCurrentEnv().isNullResultWithOneNullParamFunction(fnName))) {
for (Expression e : expression.children()) {
if (e instanceof NullLiteral) {
return Literal.of(null);
}
}
}
return invoke(expression, fnName, args);
}
private Expression invoke(Expression expression, String fnName, DataType[] args) {
FunctionSignature signature = new FunctionSignature(fnName, args, null);
FunctionInvoker invoker = getFunction(signature);
if (invoker != null) {
try {
return invoker.invoke(expression.children());
} catch (AnalysisException e) {
return expression;
}
}
return expression;
}
private FunctionInvoker getFunction(FunctionSignature signature) {
Collection<FunctionInvoker> functionInvokers = functions.get(signature.getName());
if (functionInvokers == null) {
return null;
}
for (FunctionInvoker candidate : functionInvokers) {
DataType[] candidateTypes = candidate.getSignature().getArgTypes();
DataType[] expectedTypes = signature.getArgTypes();
if (candidateTypes.length != expectedTypes.length) {
continue;
}
boolean match = true;
for (int i = 0; i < candidateTypes.length; i++) {
if (!candidateTypes[i].equals(expectedTypes[i])) {
match = false;
break;
}
}
if (match) {
return candidate;
}
}
return null;
}
private void registerFunctions() {
if (functions != null) {
return;
}
ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder =
new ImmutableMultimap.Builder<String, FunctionInvoker>();
Class clazz = ExecutableFunctions.class;
for (Method method : clazz.getDeclaredMethods()) {
ExecFunctionList annotationList = method.getAnnotation(ExecFunctionList.class);
if (annotationList != null) {
for (ExecFunction f : annotationList.value()) {
registerFEFunction(mapBuilder, method, f);
}
}
registerFEFunction(mapBuilder, method, method.getAnnotation(ExecFunction.class));
}
this.functions = mapBuilder.build();
}
private void registerFEFunction(ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder,
Method method, ExecFunction annotation) {
if (annotation != null) {
String name = annotation.name();
DataType returnType = DataType.convertFromString(annotation.returnType());
List<DataType> argTypes = new ArrayList<>();
for (String type : annotation.argTypes()) {
argTypes.add(DataType.convertFromString(type));
}
FunctionSignature signature = new FunctionSignature(name,
argTypes.toArray(new DataType[argTypes.size()]), returnType);
mapBuilder.put(name, new FunctionInvoker(method, signature));
}
}
/**
* function invoker.
*/
public static class FunctionInvoker {
private final Method method;
private final FunctionSignature signature;
public FunctionInvoker(Method method, FunctionSignature signature) {
this.method = method;
this.signature = signature;
}
public Method getMethod() {
return method;
}
public FunctionSignature getSignature() {
return signature;
}
public Literal invoke(List<Expression> args) throws AnalysisException {
try {
return (Literal) method.invoke(null, args.toArray());
} catch (InvocationTargetException | IllegalAccessException | IllegalArgumentException e) {
throw new AnalysisException(e.getLocalizedMessage());
}
}
}
/**
* function signature.
*/
public static class FunctionSignature {
private final String name;
private final DataType[] argTypes;
private final DataType returnType;
public FunctionSignature(String name, DataType[] argTypes, DataType returnType) {
this.name = name;
this.argTypes = argTypes;
this.returnType = returnType;
}
public DataType[] getArgTypes() {
return argTypes;
}
public DataType getReturnType() {
return returnType;
}
public String getName() {
return name;
}
}
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import com.google.common.base.Preconditions;
@ -27,7 +28,7 @@ import java.util.List;
/**
* Greater than expression: a > b.
*/
public class GreaterThan extends ComparisonPredicate {
public class GreaterThan extends ComparisonPredicate implements PropagateNullable {
/**
* Constructor of Greater Than ComparisonPredicate.
*

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import com.google.common.base.Preconditions;
@ -27,7 +28,7 @@ import java.util.List;
/**
* Greater than and equal expression: a >= b.
*/
public class GreaterThanEqual extends ComparisonPredicate {
public class GreaterThanEqual extends ComparisonPredicate implements PropagateNullable {
/**
* Constructor of Greater Than And Equal.
*

View File

@ -0,0 +1,78 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Objects;
/**
* expr is null predicate.
*/
public class IsNull extends Expression implements UnaryExpression {
public IsNull(Expression e) {
super(e);
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitIsNull(this, context);
}
@Override
public boolean nullable() {
return false;
}
@Override
public IsNull withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new IsNull(children.get(0));
}
@Override
public String toSql() throws UnboundException {
return child().toSql() + " IS NULL";
}
@Override
public String toString() {
return toSql();
}
@Override
public boolean equals(Object o) {
if (!super.equals(o)) {
return false;
}
IsNull other = (IsNull) o;
return Objects.equals(child(), other.child());
}
@Override
public int hashCode() {
return child().hashCode();
}
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import com.google.common.base.Preconditions;
@ -27,7 +28,7 @@ import java.util.List;
/**
* Less than expression: a < b.
*/
public class LessThan extends ComparisonPredicate {
public class LessThan extends ComparisonPredicate implements PropagateNullable {
/**
* Constructor of Less Than Comparison Predicate.
*

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import com.google.common.base.Preconditions;
@ -27,7 +28,7 @@ import java.util.List;
/**
* Less than and equal expression: a <= b.
*/
public class LessThanEqual extends ComparisonPredicate {
public class LessThanEqual extends ComparisonPredicate implements PropagateNullable {
/**
* Constructor of Less Than And Equal.
*

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.typecoercion.ExpectsInputTypes;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
@ -34,7 +35,7 @@ import java.util.Objects;
/**
* Not expression: not a.
*/
public class Not extends Expression implements UnaryExpression, ExpectsInputTypes {
public class Not extends Expression implements UnaryExpression, ExpectsInputTypes, PropagateNullable {
public static final List<AbstractDataType> EXPECTS_INPUT_TYPES = ImmutableList.of(BooleanType.INSTANCE);

View File

@ -64,4 +64,5 @@ public class NullSafeEqual extends ComparisonPredicate {
public ComparisonPredicate commute() {
return new NullSafeEqual(right(), left());
}
}

View File

@ -79,6 +79,10 @@ public class SlotReference extends Slot {
this.column = column;
}
public static SlotReference of(String name, DataType type) {
return new SlotReference(name, type);
}
public static SlotReference fromColumn(Column column, List<String> qualifier) {
DataType dataType = DataType.convertFromCatalogDataType(column.getType());
return new SlotReference(NamedExpressionUtil.newExprId(), column.getName(), dataType,

View File

@ -22,16 +22,21 @@ import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.literal.IntervalLiteral.TimeUnit;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.List;
import java.util.Objects;
/**
* Describes the addition and subtraction of time units from timestamps.
@ -40,7 +45,14 @@ import java.util.List;
* Example: '1996-01-01' + INTERVAL '3' month;
* TODO: we need to rethink this, and maybe need to add a new type of Interval then implement IntervalLiteral as others
*/
public class TimestampArithmetic extends Expression implements BinaryExpression, PropagateNullable {
public class TimestampArithmetic extends Expression implements BinaryExpression, ImplicitCastInputTypes,
PropagateNullable {
private static final List<AbstractDataType> EXPECTED_INPUT_TYPES = ImmutableList.of(
DateTimeType.INSTANCE,
IntegerType.INSTANCE
);
private static final Logger LOG = LogManager.getLogger(TimestampArithmetic.class);
private final String funcName;
private final boolean intervalFirst;
@ -149,4 +161,22 @@ public class TimestampArithmetic extends Expression implements BinaryExpression,
}
return strBuilder.toString();
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
TimestampArithmetic other = (TimestampArithmetic) o;
return Objects.equals(funcName, other.funcName) && Objects.equals(timeUnit, other.timeUnit)
&& Objects.equals(left(), other.left()) && Objects.equals(right(), other.right());
}
@Override
public List<AbstractDataType> expectedInputTypes() {
return EXPECTED_INPUT_TYPES;
}
}

View File

@ -54,7 +54,7 @@ public class WhenClause extends Expression implements BinaryExpression, ExpectsI
@Override
public String toSql() {
return "WHEN " + left().toSql() + " THEN " + right().toSql();
return " WHEN " + left().toSql() + " THEN " + right().toSql();
}
@Override

View File

@ -0,0 +1,243 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.ExecFunction;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.math.BigDecimal;
/**
* functions that can be executed in FE.
*/
public class ExecutableFunctions {
public static final ExecutableFunctions INSTANCE = new ExecutableFunctions();
private static final Logger LOG = LogManager.getLogger(ExecutableFunctions.class);
/**
* Executable arithmetic functions
*/
@ExecFunction(name = "add", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT")
public static TinyIntLiteral addTinyint(TinyIntLiteral first, TinyIntLiteral second) {
byte result = (byte) Math.addExact(first.getValue(), second.getValue());
return new TinyIntLiteral(result);
}
@ExecFunction(name = "add", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT")
public static SmallIntLiteral addSmallint(SmallIntLiteral first, SmallIntLiteral second) {
short result = (short) Math.addExact(first.getValue(), second.getValue());
return new SmallIntLiteral(result);
}
@ExecFunction(name = "add", argTypes = {"INT", "INT"}, returnType = "INT")
public static IntegerLiteral addInt(IntegerLiteral first, IntegerLiteral second) {
int result = Math.addExact(first.getValue(), second.getValue());
return new IntegerLiteral(result);
}
@ExecFunction(name = "add", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT")
public static BigIntLiteral addBigint(BigIntLiteral first, BigIntLiteral second) {
long result = Math.addExact(first.getValue(), second.getValue());
return new BigIntLiteral(result);
}
@ExecFunction(name = "add", argTypes = {"DOUBLE", "DOUBLE"}, returnType = "DOUBLE")
public static DoubleLiteral addDouble(DoubleLiteral first, DoubleLiteral second) {
double result = first.getValue() + second.getValue();
return new DoubleLiteral(result);
}
@ExecFunction(name = "add", argTypes = {"DECIMAL", "DECIMAL"}, returnType = "DECIMAL")
public static DecimalLiteral addDecimal(DecimalLiteral first, DecimalLiteral second) {
BigDecimal result = first.getValue().add(second.getValue());
return new DecimalLiteral(result);
}
@ExecFunction(name = "subtract", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT")
public static TinyIntLiteral subtractTinyint(TinyIntLiteral first, TinyIntLiteral second) {
byte result = (byte) Math.subtractExact(first.getValue(), second.getValue());
return new TinyIntLiteral(result);
}
@ExecFunction(name = "subtract", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT")
public static SmallIntLiteral subtractSmallint(SmallIntLiteral first, SmallIntLiteral second) {
short result = (short) Math.subtractExact(first.getValue(), second.getValue());
return new SmallIntLiteral(result);
}
@ExecFunction(name = "subtract", argTypes = {"INT", "INT"}, returnType = "INT")
public static IntegerLiteral subtractInt(IntegerLiteral first, IntegerLiteral second) {
int result = Math.subtractExact(first.getValue(), second.getValue());
return new IntegerLiteral(result);
}
@ExecFunction(name = "subtract", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT")
public static BigIntLiteral subtractBigint(BigIntLiteral first, BigIntLiteral second) {
long result = Math.subtractExact(first.getValue(), second.getValue());
return new BigIntLiteral(result);
}
@ExecFunction(name = "subtract", argTypes = {"DOUBLE", "DOUBLE"}, returnType = "DOUBLE")
public static DoubleLiteral subtractDouble(DoubleLiteral first, DoubleLiteral second) {
double result = first.getValue() - second.getValue();
return new DoubleLiteral(result);
}
@ExecFunction(name = "subtract", argTypes = {"DECIMAL", "DECIMAL"}, returnType = "DECIMAL")
public static DecimalLiteral subtractDecimal(DecimalLiteral first, DecimalLiteral second) {
BigDecimal result = first.getValue().subtract(second.getValue());
return new DecimalLiteral(result);
}
@ExecFunction(name = "multiply", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT")
public static TinyIntLiteral multiplyTinyint(TinyIntLiteral first, TinyIntLiteral second) {
byte result = (byte) Math.multiplyExact(first.getValue(), second.getValue());
return new TinyIntLiteral(result);
}
@ExecFunction(name = "multiply", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT")
public static SmallIntLiteral multiplySmallint(SmallIntLiteral first, SmallIntLiteral second) {
short result = (short) Math.multiplyExact(first.getValue(), second.getValue());
return new SmallIntLiteral(result);
}
@ExecFunction(name = "multiply", argTypes = {"INT", "INT"}, returnType = "INT")
public static IntegerLiteral multiplyInt(IntegerLiteral first, IntegerLiteral second) {
int result = Math.multiplyExact(first.getValue(), second.getValue());
return new IntegerLiteral(result);
}
@ExecFunction(name = "multiply", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT")
public static BigIntLiteral multiplyBigint(BigIntLiteral first, BigIntLiteral second) {
long result = Math.multiplyExact(first.getValue(), second.getValue());
return new BigIntLiteral(result);
}
@ExecFunction(name = "multiply", argTypes = {"DOUBLE", "DOUBLE"}, returnType = "DOUBLE")
public static DoubleLiteral multiplyDouble(DoubleLiteral first, DoubleLiteral second) {
double result = first.getValue() * second.getValue();
return new DoubleLiteral(result);
}
@ExecFunction(name = "multiply", argTypes = {"DECIMAL", "DECIMAL"}, returnType = "DECIMAL")
public static DecimalLiteral multiplyDecimal(DecimalLiteral first, DecimalLiteral second) {
BigDecimal result = first.getValue().multiply(second.getValue());
return new DecimalLiteral(result);
}
@ExecFunction(name = "divide", argTypes = {"DOUBLE", "DOUBLE"}, returnType = "DOUBLE")
public static DoubleLiteral divideDouble(DoubleLiteral first, DoubleLiteral second) {
if (second.getValue() == 0.0) {
return null;
}
double result = first.getValue() / second.getValue();
return new DoubleLiteral(result);
}
@ExecFunction(name = "divide", argTypes = {"DECIMAL", "DECIMAL"}, returnType = "DECIMAL")
public static DecimalLiteral divideDecimal(DecimalLiteral first, DecimalLiteral second) {
if (first.getValue().compareTo(BigDecimal.ZERO) == 0) {
return null;
}
BigDecimal result = first.getValue().divide(second.getValue());
return new DecimalLiteral(result);
}
@ExecFunction(name = "date_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
public static DateTimeLiteral dateSub(DateTimeLiteral date, IntegerLiteral day) throws AnalysisException {
return dateAdd(date, new IntegerLiteral(-day.getValue()));
}
@ExecFunction(name = "date_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
public static DateTimeLiteral dateAdd(DateTimeLiteral date, IntegerLiteral day) throws AnalysisException {
return daysAdd(date, day);
}
@ExecFunction(name = "years_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
public static DateTimeLiteral yearsAdd(DateTimeLiteral date, IntegerLiteral year) throws AnalysisException {
return date.plusYears(year.getValue());
}
@ExecFunction(name = "months_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
public static DateTimeLiteral monthsAdd(DateTimeLiteral date, IntegerLiteral month) throws AnalysisException {
return date.plusMonths(month.getValue());
}
@ExecFunction(name = "days_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
public static DateTimeLiteral daysAdd(DateTimeLiteral date, IntegerLiteral day) throws AnalysisException {
return date.plusDays(day.getValue());
}
@ExecFunction(name = "hours_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
public static DateTimeLiteral hoursAdd(DateTimeLiteral date, IntegerLiteral hour) throws AnalysisException {
return date.plusHours(hour.getValue());
}
@ExecFunction(name = "minutes_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
public static DateTimeLiteral minutesAdd(DateTimeLiteral date, IntegerLiteral minute) throws AnalysisException {
return date.plusMinutes(minute.getValue());
}
@ExecFunction(name = "seconds_add", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
public static DateTimeLiteral secondsAdd(DateTimeLiteral date, IntegerLiteral second) throws AnalysisException {
return date.plusSeconds(second.getValue());
}
@ExecFunction(name = "years_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
public static DateTimeLiteral yearsSub(DateTimeLiteral date, IntegerLiteral year) throws AnalysisException {
return yearsAdd(date, new IntegerLiteral(-year.getValue()));
}
@ExecFunction(name = "months_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
public static DateTimeLiteral monthsSub(DateTimeLiteral date, IntegerLiteral month) throws AnalysisException {
return monthsAdd(date, new IntegerLiteral(-month.getValue()));
}
@ExecFunction(name = "days_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
public static DateTimeLiteral daysSub(DateTimeLiteral date, IntegerLiteral day) throws AnalysisException {
return daysAdd(date, new IntegerLiteral(-day.getValue()));
}
@ExecFunction(name = "hours_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
public static DateTimeLiteral hoursSub(DateTimeLiteral date, IntegerLiteral hour) throws AnalysisException {
return hoursAdd(date, new IntegerLiteral(-hour.getValue()));
}
@ExecFunction(name = "minutes_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
public static DateTimeLiteral minutesSub(DateTimeLiteral date, IntegerLiteral minute) throws AnalysisException {
return minutesAdd(date, new IntegerLiteral(-minute.getValue()));
}
@ExecFunction(name = "seconds_sub", argTypes = { "DATETIME", "INT" }, returnType = "DATETIME")
public static DateTimeLiteral secondsSub(DateTimeLiteral date, IntegerLiteral second) throws AnalysisException {
return secondsAdd(date, new IntegerLiteral(-second.getValue()));
}
}

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.IntegerType;
@ -50,7 +51,7 @@ public class Substring extends ScalarFunction implements ImplicitCastInputTypes,
}
public Substring(Expression str, Expression pos) {
super("substring", str, pos);
super("substring", str, pos, Literal.of(Integer.MAX_VALUE));
}
public Expression getSource() {

View File

@ -19,14 +19,11 @@ package org.apache.doris.nereids.trees.expressions.literal;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.util.DateUtils;
import com.google.common.base.Preconditions;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.joda.time.LocalDateTime;
@ -107,32 +104,6 @@ public class DateLiteral extends Literal {
}
}
@Override
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
if (getDataType().equals(targetType)) {
return this;
}
if (targetType.isDate()) {
if (getDataType().equals(targetType)) {
return this;
}
if (targetType.equals(DateType.INSTANCE)) {
return new DateLiteral(this.year, this.month, this.day);
} else if (targetType.equals(DateTimeType.INSTANCE)) {
return new DateTimeLiteral(this.year, this.month, this.day, 0, 0, 0);
} else {
throw new AnalysisException("Error date literal type");
}
}
//todo other target type cast
return this;
}
public DateLiteral withDataType(DataType type) {
Preconditions.checkArgument(type.isDate() || type.isDateTime());
return new DateLiteral(this, type);
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitDateLiteral(this, context);
@ -153,6 +124,11 @@ public class DateLiteral extends Literal {
return String.format("%04d-%02d-%02d", year, month, day);
}
@Override
public String getStringValue() {
return String.format("%04d-%02d-%02d", year, month, day);
}
@Override
public LiteralExpr toLegacyLiteral() {
return new org.apache.doris.analysis.DateLiteral(year, month, day);
@ -169,6 +145,11 @@ public class DateLiteral extends Literal {
public long getDay() {
return day;
}
public DateLiteral plusDays(int days) {
LocalDateTime dateTime = LocalDateTime.parse(getStringValue(), DATE_FORMATTER).plusDays(days);
return new DateLiteral(dateTime.getYear(), dateTime.getMonthOfYear(), dateTime.getDayOfMonth());
}
}

View File

@ -19,11 +19,8 @@ package org.apache.doris.nereids.trees.expressions.literal;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.util.DateUtils;
import org.apache.logging.log4j.LogManager;
@ -31,6 +28,8 @@ import org.apache.logging.log4j.Logger;
import org.joda.time.LocalDateTime;
import org.joda.time.format.DateTimeFormatter;
import java.util.Objects;
/**
* date time literal.
*/
@ -39,7 +38,8 @@ public class DateTimeLiteral extends DateLiteral {
private static final int DATETIME_TO_MINUTE_STRING_LENGTH = 16;
private static final int DATETIME_TO_HOUR_STRING_LENGTH = 13;
private static final int DATETIME_DEFAULT_STRING_LENGTH = 10;
private static DateTimeFormatter DATE_TIME_DEFAULT_FORMATTER = null;
private static DateTimeFormatter DATE_TIME_FORMATTER = null;
private static DateTimeFormatter DATE_TIME_FORMATTER_TO_HOUR = null;
private static DateTimeFormatter DATE_TIME_FORMATTER_TO_MINUTE = null;
@ -55,6 +55,7 @@ public class DateTimeLiteral extends DateLiteral {
DATE_TIME_FORMATTER_TO_HOUR = DateUtils.formatBuilder("%Y-%m-%d %H").toFormatter();
DATE_TIME_FORMATTER_TO_MINUTE = DateUtils.formatBuilder("%Y-%m-%d %H:%i").toFormatter();
DATE_TIME_FORMATTER_TWO_DIGIT = DateUtils.formatBuilder("%y-%m-%d %H:%i:%s").toFormatter();
DATE_TIME_DEFAULT_FORMATTER = DateUtils.formatBuilder("%Y-%m-%d").toFormatter();
} catch (AnalysisException e) {
LOG.error("invalid date format", e);
System.exit(-1);
@ -89,6 +90,8 @@ public class DateTimeLiteral extends DateLiteral {
dateTime = DATE_TIME_FORMATTER_TO_MINUTE.parseLocalDateTime(s);
} else if (s.length() == DATETIME_TO_HOUR_STRING_LENGTH) {
dateTime = DATE_TIME_FORMATTER_TO_HOUR.parseLocalDateTime(s);
} else if (s.length() == DATETIME_DEFAULT_STRING_LENGTH) {
dateTime = DATE_TIME_DEFAULT_FORMATTER.parseLocalDateTime(s);
} else {
dateTime = DATE_TIME_FORMATTER.parseLocalDateTime(s);
}
@ -104,27 +107,6 @@ public class DateTimeLiteral extends DateLiteral {
}
}
@Override
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
if (getDataType().equals(targetType)) {
return this;
}
if (targetType.isDate()) {
if (getDataType().equals(targetType)) {
return this;
}
if (targetType.equals(DateType.INSTANCE)) {
return new DateLiteral(this.year, this.month, this.day);
} else if (targetType.equals(DateTimeType.INSTANCE)) {
return new DateTimeLiteral(this.year, this.month, this.day, this.hour, this.minute, this.second);
} else {
throw new AnalysisException("Error date literal type");
}
}
//todo other target type cast
return this;
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitDateTimeLiteral(this, context);
@ -145,6 +127,47 @@ public class DateTimeLiteral extends DateLiteral {
return String.format("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second);
}
@Override
public String getStringValue() {
return String.format("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second);
}
public DateTimeLiteral plusYears(int years) {
LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusYears(years);
return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(),
d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute());
}
public DateTimeLiteral plusMonths(int months) {
LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusMonths(months);
return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(),
d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute());
}
public DateTimeLiteral plusDays(int days) {
LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusDays(days);
return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(),
d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute());
}
public DateTimeLiteral plusHours(int hours) {
LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusHours(hours);
return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(),
d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute());
}
public DateTimeLiteral plusMinutes(int minutes) {
LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusMinutes(minutes);
return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(),
d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute());
}
public DateTimeLiteral plusSeconds(int seconds) {
LocalDateTime d = LocalDateTime.parse(getStringValue(), DATE_TIME_FORMATTER).plusSeconds(seconds);
return new DateTimeLiteral(d.getYear(), d.getMonthOfYear(), d.getDayOfMonth(),
d.getHourOfDay(), d.getMinuteOfHour(), d.getSecondOfMinute());
}
@Override
public LiteralExpr toLegacyLiteral() {
return new org.apache.doris.analysis.DateLiteral(year, month, day, hour, minute, second);
@ -161,4 +184,16 @@ public class DateTimeLiteral extends DateLiteral {
public long getSecond() {
return second;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
DateTimeLiteral other = (DateTimeLiteral) o;
return Objects.equals(getValue(), other.getValue());
}
}

View File

@ -18,13 +18,20 @@
package org.apache.doris.nereids.trees.expressions.literal;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.shape.LeafExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.CharType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.StringType;
import org.apache.commons.lang3.StringUtils;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Locale;
import java.util.Objects;
/**
@ -51,17 +58,21 @@ public abstract class Literal extends Expression implements LeafExpression {
if (value == null) {
return new NullLiteral();
} else if (value instanceof Byte) {
return new TinyIntLiteral(((Byte) value).byteValue());
return new TinyIntLiteral((byte) value);
} else if (value instanceof Short) {
return new SmallIntLiteral(((Short) value).shortValue());
return new SmallIntLiteral((short) value);
} else if (value instanceof Integer) {
return new IntegerLiteral((Integer) value);
return new IntegerLiteral((int) value);
} else if (value instanceof Long) {
return new BigIntLiteral(((Long) value).longValue());
return new BigIntLiteral((long) value);
} else if (value instanceof BigInteger) {
return new LargeIntLiteral((BigInteger) value);
} else if (value instanceof Float) {
return new FloatLiteral((float) value);
} else if (value instanceof Double) {
return new DoubleLiteral((double) value);
} else if (value instanceof Boolean) {
return BooleanLiteral.of((Boolean) value);
return BooleanLiteral.of((boolean) value);
} else if (value instanceof String) {
return new StringLiteral((String) value);
} else {
@ -71,6 +82,10 @@ public abstract class Literal extends Expression implements LeafExpression {
public abstract Object getValue();
public String getStringValue() {
return String.valueOf(getValue());
}
@Override
public DataType getDataType() throws UnboundException {
return dataType;
@ -91,6 +106,89 @@ public abstract class Literal extends Expression implements LeafExpression {
return visitor.visitLiteral(this, context);
}
/**
* literal expr compare.
*/
public int compareTo(Literal other) {
if (isNullLiteral() && other.isNullLiteral()) {
return 0;
} else if (isNullLiteral() || other.isNullLiteral()) {
return isNullLiteral() ? -1 : 1;
}
DataType oType = other.getDataType();
DataType type = getDataType();
if (!type.equals(oType)) {
throw new RuntimeException("data type not equal!");
} else if (type.isBooleanType()) {
return Boolean.compare((boolean) getValue(), (boolean) other.getValue());
} else if (type.isTinyIntType()) {
return Byte.compare((byte) getValue(), (byte) other.getValue());
} else if (type.isSmallIntType()) {
return Short.compare((short) getValue(), (short) other.getValue());
} else if (type.isIntType()) {
return Integer.compare((int) getValue(), (int) other.getValue());
} else if (type.isBigIntType()) {
return Long.compare((long) getValue(), (long) other.getValue());
} else if (type.isLargeIntType()) {
return ((BigInteger) getValue()).compareTo((BigInteger) other.getValue());
} else if (type.isFloatType()) {
return Float.compare((float) getValue(), (float) other.getValue());
} else if (type.isDoubleType()) {
return Double.compare((double) getValue(), (double) other.getValue());
} else if (type.isDecimalType()) {
return Long.compare((Long) getValue(), (Long) other.getValue());
} else if (type.isDateType()) {
// todo process date
} else if (type.isDecimalType()) {
return ((BigDecimal) getValue()).compareTo((BigDecimal) other.getValue());
} else if (type instanceof StringType) {
return StringUtils.compare((String) getValue(), (String) other.getValue());
}
return -1;
}
@Override
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
String desc = getStringValue();
if (targetType.isBooleanType()) {
if ("0".equals(desc) || "false".equals(desc.toLowerCase(Locale.ROOT))) {
return Literal.of(false);
}
if ("1".equals(desc) || "true".equals(desc.toLowerCase(Locale.ROOT))) {
return Literal.of(true);
}
}
if (targetType.isTinyIntType()) {
return Literal.of(Double.valueOf(desc).byteValue());
} else if (targetType.isSmallIntType()) {
return Literal.of(Double.valueOf(desc).shortValue());
} else if (targetType.isIntType()) {
return Literal.of(Double.valueOf(desc).intValue());
} else if (targetType.isBigIntType()) {
return Literal.of(Double.valueOf(desc).longValue());
} else if (targetType.isLargeIntType()) {
return Literal.of(new BigInteger(desc));
} else if (targetType.isFloatType()) {
return Literal.of(Float.parseFloat(desc));
} else if (targetType.isDoubleType()) {
return Literal.of(Double.parseDouble(desc));
} else if (targetType.isCharType()) {
return new CharLiteral(desc, ((CharType) targetType).getLen());
} else if (targetType.isVarcharType()) {
return new VarcharLiteral(desc, desc.length());
} else if (targetType.isStringType()) {
return Literal.of(desc);
} else if (targetType.isDate()) {
return new DateLiteral(desc);
} else if (targetType.isDateTime()) {
return new DateTimeLiteral(desc);
} else if (targetType.isDecimalType()) {
return new DecimalLiteral(BigDecimal.valueOf(Double.parseDouble(desc)));
}
throw new AnalysisException("no support cast!");
}
@Override
public boolean equals(Object o) {
if (this == o) {
@ -114,4 +212,5 @@ public abstract class Literal extends Expression implements LeafExpression {
}
public abstract LiteralExpr toLegacyLiteral();
}

View File

@ -18,10 +18,7 @@
package org.apache.doris.nereids.trees.expressions.literal;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.StringType;
/**
@ -56,30 +53,6 @@ public class StringLiteral extends Literal {
return new org.apache.doris.analysis.StringLiteral(value);
}
@Override
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
if (getDataType().equals(targetType)) {
return this;
}
if (targetType.isDateType()) {
return convertToDate(targetType);
} else if (targetType.isIntType()) {
return new IntegerLiteral(Integer.parseInt(value));
}
//todo other target type cast
return this;
}
private DateLiteral convertToDate(DataType targetType) throws AnalysisException {
DateLiteral dateLiteral = null;
if (targetType.isDate()) {
dateLiteral = new DateLiteral(value);
} else if (targetType.isDateTime()) {
dateLiteral = new DateTimeLiteral(value);
}
return dateLiteral;
}
@Override
public String toString() {
return "'" + value + "'";

View File

@ -20,7 +20,6 @@ package org.apache.doris.nereids.trees.expressions.literal;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.analysis.StringLiteral;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.VarcharType;
@ -68,20 +67,6 @@ public class VarcharLiteral extends Literal {
return "'" + value + "'";
}
// Temporary way to process type coercion in TimestampArithmetic, should be replaced by TypeCoercion rule.
@Override
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
if (getDataType().equals(targetType)) {
return this;
}
if (targetType.isDateType()) {
return convertToDate(targetType);
} else if (targetType.isIntType()) {
return new IntegerLiteral(Integer.parseInt(value));
}
return this;
}
private DateLiteral convertToDate(DataType targetType) throws AnalysisException {
DateLiteral dateLiteral = null;
if (targetType.isDate()) {

View File

@ -40,6 +40,7 @@ import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Like;
@ -277,6 +278,10 @@ public abstract class ExpressionVisitor<R, C> {
return visit(inPredicate, context);
}
public R visitIsNull(IsNull isNull, C context) {
return visit(isNull, context);
}
public R visitInSubquery(InSubquery in, C context) {
return visitSubqueryExpr(in, context);
}

View File

@ -165,6 +165,12 @@ public abstract class DataType implements AbstractDataType {
return FloatType.INSTANCE;
case "double":
return DoubleType.INSTANCE;
case "decimal":
return DecimalType.SYSTEM_DEFAULT;
case "char":
return CharType.INSTANCE;
case "varchar":
return VarcharType.SYSTEM_DEFAULT;
case "text":
case "string":
return StringType.INSTANCE;
@ -303,18 +309,50 @@ public abstract class DataType implements AbstractDataType {
return 0;
}
public boolean isDate() {
return this instanceof DateType;
public boolean isBooleanType() {
return this instanceof BooleanType;
}
public boolean isTinyIntType() {
return this instanceof TinyIntType;
}
public boolean isSmallIntType() {
return this instanceof SmallIntType;
}
public boolean isIntType() {
return this instanceof IntegerType;
}
public boolean isBigIntType() {
return this instanceof BigIntType;
}
public boolean isLargeIntType() {
return this instanceof LargeIntType;
}
public boolean isFloatType() {
return this instanceof FloatType;
}
public boolean isDoubleType() {
return this instanceof DoubleType;
}
public boolean isDecimalType() {
return this instanceof DecimalType;
}
public boolean isDateTime() {
return this instanceof DateTimeType;
}
public boolean isDate() {
return this instanceof DateType;
}
public boolean isDateType() {
return isDate() || isDateTime();
}
@ -327,6 +365,14 @@ public abstract class DataType implements AbstractDataType {
return this instanceof NumericType;
}
public boolean isCharType() {
return this instanceof CharType;
}
public boolean isVarcharType() {
return this instanceof VarcharType;
}
public boolean isStringType() {
return this instanceof CharacterType;
}

View File

@ -26,6 +26,8 @@ import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import com.google.common.base.Preconditions;
@ -258,4 +260,20 @@ public class ExpressionUtils {
}
return builder.build();
}
public static boolean isAllLiteral(Expression... children) {
return Arrays.stream(children).allMatch(c -> c instanceof Literal);
}
public static boolean isAllLiteral(List<Expression> children) {
return children.stream().allMatch(c -> c instanceof Literal);
}
public static boolean hasNullLiteral(List<Expression> children) {
return children.stream().anyMatch(c -> c instanceof NullLiteral);
}
public static boolean isAllNullLiteral(List<Expression> children) {
return children.stream().allMatch(c -> c instanceof NullLiteral);
}
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
@ -123,6 +124,10 @@ public class TypeCoercionUtils {
} else if (expected instanceof DateTimeType) {
returnType = DateTimeType.INSTANCE;
}
} else if (input.isDate()) {
if (expected instanceof DateTimeType) {
returnType = expected.defaultConcreteType();
}
}
if (returnType == null && input instanceof PrimitiveType
@ -202,6 +207,39 @@ public class TypeCoercionUtils {
return Optional.ofNullable(tightestCommonType);
}
/**
* The type used for arithmetic operations.
*/
public static DataType getNumResultType(DataType type) {
if (type.isTinyIntType() || type.isSmallIntType() || type.isIntType() || type.isBigIntType()) {
return BigIntType.INSTANCE;
} else if (type.isLargeIntType()) {
return LargeIntType.INSTANCE;
} else if (type.isFloatType() || type.isDoubleType() || type.isStringType()) {
return DoubleType.INSTANCE;
} else if (type.isDecimalType()) {
return DecimalType.SYSTEM_DEFAULT;
} else if (type.isNullType()) {
return NullType.INSTANCE;
}
throw new AnalysisException("no found appropriate data type.");
}
/**
* The common type used by arithmetic operations.
*/
public static DataType findCommonNumericsType(DataType t1, DataType t2) {
if (t1.isDoubleType() || t2.isDoubleType()) {
return DoubleType.INSTANCE;
} else if (t1.isDecimalType() || t2.isDecimalType()) {
return DecimalType.SYSTEM_DEFAULT;
} else if (t1.isLargeIntType() || t2.isLargeIntType()) {
return LargeIntType.INSTANCE;
} else {
return BigIntType.INSTANCE;
}
}
/**
* find wider common type for data type list.
*/

View File

@ -47,7 +47,9 @@ public class SSBJoinReorderTest extends SSBTestBase {
"(lo_suppkey = s_suppkey)",
"(lo_partkey = p_partkey)"
),
ImmutableList.of("(((c_region = 'AMERICA') AND (s_region = 'AMERICA')) AND ((p_mfgr = 'MFGR#1') OR (p_mfgr = 'MFGR#2')))")
ImmutableList.of("(((CAST(c_region AS STRING) = CAST('AMERICA' AS STRING)) AND (CAST(s_region AS STRING) "
+ "= CAST('AMERICA' AS STRING))) AND ((CAST(p_mfgr AS STRING) = CAST('MFGR#1' AS STRING)) "
+ "OR (CAST(p_mfgr AS STRING) = CAST('MFGR#2' AS STRING))))")
);
}
@ -62,7 +64,10 @@ public class SSBJoinReorderTest extends SSBTestBase {
"(lo_partkey = p_partkey)"
),
ImmutableList.of(
"((((c_region = 'AMERICA') AND (s_region = 'AMERICA')) AND ((d_year = 1997) OR (d_year = 1998))) AND ((p_mfgr = 'MFGR#1') OR (p_mfgr = 'MFGR#2')))")
"((((CAST(c_region AS STRING) = CAST('AMERICA' AS STRING)) AND (CAST(s_region AS STRING) "
+ "= CAST('AMERICA' AS STRING))) AND ((d_year = 1997) OR (d_year = 1998))) "
+ "AND ((CAST(p_mfgr AS STRING) = CAST('MFGR#1' AS STRING)) OR (CAST(p_mfgr AS STRING) "
+ "= CAST('MFGR#2' AS STRING))))")
);
}
@ -76,7 +81,8 @@ public class SSBJoinReorderTest extends SSBTestBase {
"(lo_suppkey = s_suppkey)",
"(lo_partkey = p_partkey)"
),
ImmutableList.of("(((s_nation = 'UNITED STATES') AND ((d_year = 1997) OR (d_year = 1998))) AND (p_category = 'MFGR#14'))")
ImmutableList.of("(((CAST(s_nation AS STRING) = CAST('UNITED STATES' AS STRING)) AND ((d_year = 1997) "
+ "OR (d_year = 1998))) AND (CAST(p_category AS STRING) = CAST('MFGR#14' AS STRING)))")
);
}

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil;
@ -34,6 +35,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.util.FieldChecker;
import org.apache.doris.nereids.util.PatternMatchSupported;
@ -168,7 +170,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
).when(FieldChecker.check("predicates", new GreaterThan(sumA2.toSlot(), new TinyIntLiteral((byte) 0))))
).when(FieldChecker.check("predicates", new GreaterThan(sumA2.toSlot(), Literal.of(0L))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
NamedExpressionUtil.clear();
@ -181,7 +183,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
).when(FieldChecker.check("predicates", new GreaterThan(sumA2.toSlot(), new TinyIntLiteral((byte) 0))))));
).when(FieldChecker.check("predicates", new GreaterThan(sumA2.toSlot(), Literal.of(0L))))));
NamedExpressionUtil.clear();
sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING SUM(a2) > 0";
@ -201,7 +203,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
).when(FieldChecker.check("predicates", new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))));
).when(FieldChecker.check("predicates", new GreaterThan(value.toSlot(), Literal.of(0L))))));
NamedExpressionUtil.clear();
sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING value > 0";
@ -212,7 +214,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
).when(FieldChecker.check("predicates", new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))));
).when(FieldChecker.check("predicates", new GreaterThan(value.toSlot(), Literal.of(0L))))));
NamedExpressionUtil.clear();
sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0";
@ -237,7 +239,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK)))
).when(FieldChecker.check("predicates", new GreaterThan(minPK.toSlot(), new TinyIntLiteral((byte) 0))))
).when(FieldChecker.check("predicates", new GreaterThan(minPK.toSlot(), Literal.of((byte) 0))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
NamedExpressionUtil.clear();
@ -250,7 +252,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))
).when(FieldChecker.check("predicates", new GreaterThan(sumA1A2.toSlot(), new TinyIntLiteral((byte) 0))))));
).when(FieldChecker.check("predicates", new GreaterThan(sumA1A2.toSlot(), Literal.of(0L))))));
NamedExpressionUtil.clear();
sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + 3) > 0";
@ -263,7 +265,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23)))
).when(FieldChecker.check("predicates", new GreaterThan(sumA1A23.toSlot(), new TinyIntLiteral((byte) 0))))
).when(FieldChecker.check("predicates", new GreaterThan(sumA1A23.toSlot(), Literal.of(0L))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot()))));
NamedExpressionUtil.clear();
@ -276,7 +278,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))
).when(FieldChecker.check("predicates", new GreaterThan(countStar.toSlot(), new TinyIntLiteral((byte) 0))))
).when(FieldChecker.check("predicates", new GreaterThan(countStar.toSlot(), Literal.of(0L))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
NamedExpressionUtil.clear();
}
@ -310,7 +312,8 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
)
)
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1)))
).when(FieldChecker.check("predicates", new GreaterThan(a1, sumB1.toSlot())))
).when(FieldChecker.check("predicates", new GreaterThan(new Cast(a1, BigIntType.INSTANCE),
sumB1.toSlot())))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
NamedExpressionUtil.clear();
}
@ -353,14 +356,14 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
ImmutableList.of("default_cluster:test_having", "t1")
);
SlotReference a2 = new SlotReference(
new ExprId(3), "a1", TinyIntType.INSTANCE, true,
new ExprId(3), "a2", TinyIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test_having", "t1")
);
Alias pk1 = new Alias(new ExprId(7), new Add(pk, Literal.of((byte) 1)), "(pk + 1)");
Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)");
Alias pk2 = new Alias(new ExprId(9), new Add(pk, Literal.of((byte) 2)), "(pk + 2)");
Alias sumA1 = new Alias(new ExprId(10), new Sum(a1), "SUM(a1)");
Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1, false), Literal.of((byte) 1)), "(COUNT(a1) + 1)");
Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1, false), Literal.of(1L)), "(COUNT(a1) + 1)");
Alias sumA1A2 = new Alias(new ExprId(12), new Sum(new Add(a1, a2)), "SUM((a1 + a2))");
Alias v1 = new Alias(new ExprId(0), new Count(a2, false), "v1");
PlanChecker.from(connectContext).analyze(sql)
@ -382,11 +385,11 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
new And(
new And(
new GreaterThan(pk.toSlot(), Literal.of((byte) 0)),
new GreaterThan(countA11.toSlot(), Literal.of((byte) 0))),
new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of((byte) 1)), Literal.of((byte) 0))),
new GreaterThan(new Add(v1.toSlot(), Literal.of((byte) 1)), Literal.of((byte) 0))
new GreaterThan(countA11.toSlot(), Literal.of(0L))),
new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of(1L)), Literal.of(0L))),
new GreaterThan(new Add(v1.toSlot(), Literal.of(1L)), Literal.of(0L))
),
new GreaterThan(v1.toSlot(), Literal.of((byte) 0))
new GreaterThan(v1.toSlot(), Literal.of(0L))
)
))
).when(FieldChecker.check(

View File

@ -53,7 +53,7 @@ public class FunctionRegistryTest implements PatternMatchSupported {
logicalOneRowRelation().when(r -> {
Year year = (Year) r.getProjects().get(0).child(0);
Assertions.assertEquals("2021-01-01",
((Literal) year.getArguments().get(0)).getValue());
((Literal) year.getArguments().get(0).child(0)).getValue());
return true;
})
);
@ -70,13 +70,13 @@ public class FunctionRegistryTest implements PatternMatchSupported {
logicalOneRowRelation().when(r -> {
Substring firstSubstring = (Substring) r.getProjects().get(0).child(0);
Assertions.assertEquals("abc", ((Literal) firstSubstring.getSource()).getValue());
Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition()).getValue());
Assertions.assertEquals((byte) 2, ((Literal) firstSubstring.getLength().get()).getValue());
Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition().child(0)).getValue());
Assertions.assertEquals((byte) 2, ((Literal) firstSubstring.getLength().get().child(0)).getValue());
Substring secondSubstring = (Substring) r.getProjects().get(1).child(0);
Assertions.assertTrue(secondSubstring.getSource() instanceof Substring);
Assertions.assertEquals((byte) 1, ((Literal) secondSubstring.getPosition()).getValue());
Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getLength().get()).getValue());
Assertions.assertEquals((byte) 1, ((Literal) secondSubstring.getPosition().child(0)).getValue());
Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getLength().get().child(0)).getValue());
return true;
})
);
@ -93,13 +93,13 @@ public class FunctionRegistryTest implements PatternMatchSupported {
logicalOneRowRelation().when(r -> {
Substring firstSubstring = (Substring) r.getProjects().get(0).child(0);
Assertions.assertEquals("abc", ((Literal) firstSubstring.getSource()).getValue());
Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition()).getValue());
Assertions.assertFalse(firstSubstring.getLength().isPresent());
Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition().child(0)).getValue());
Assertions.assertTrue(firstSubstring.getLength().isPresent());
Substring secondSubstring = (Substring) r.getProjects().get(1).child(0);
Assertions.assertEquals("def", ((Literal) secondSubstring.getSource()).getValue());
Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getPosition()).getValue());
Assertions.assertEquals((byte) 3, ((Literal) secondSubstring.getLength().get()).getValue());
Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getPosition().child(0)).getValue());
Assertions.assertEquals((byte) 3, ((Literal) secondSubstring.getLength().get().child(0)).getValue());
return true;
})
);

View File

@ -40,7 +40,7 @@ public class ExpressionRewriteTest {
@Test
public void testNotRewrite() {
executor = new ExpressionRuleExecutor(SimplifyNotExprRule.INSTANCE);
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyNotExprRule.INSTANCE));
assertRewrite("not x", "not x");
assertRewrite("not not x", "x");
@ -65,7 +65,7 @@ public class ExpressionRewriteTest {
@Test
public void testNormalizeExpressionRewrite() {
executor = new ExpressionRuleExecutor(NormalizeBinaryPredicatesRule.INSTANCE);
executor = new ExpressionRuleExecutor(ImmutableList.of(NormalizeBinaryPredicatesRule.INSTANCE));
assertRewrite("1 = 1", "1 = 1");
assertRewrite("2 > x", "x < 2");
@ -77,7 +77,7 @@ public class ExpressionRewriteTest {
@Test
public void testDistinctPredicatesRewrite() {
executor = new ExpressionRuleExecutor(DistinctPredicatesRule.INSTANCE);
executor = new ExpressionRuleExecutor(ImmutableList.of(DistinctPredicatesRule.INSTANCE));
assertRewrite("a = 1", "a = 1");
assertRewrite("a = 1 and a = 1", "a = 1");
@ -89,7 +89,7 @@ public class ExpressionRewriteTest {
@Test
public void testExtractCommonFactorRewrite() {
executor = new ExpressionRuleExecutor(ExtractCommonFactorRule.INSTANCE);
executor = new ExpressionRuleExecutor(ImmutableList.of(ExtractCommonFactorRule.INSTANCE));
assertRewrite("a", "a");
@ -142,7 +142,8 @@ public class ExpressionRewriteTest {
@Test
public void testBetweenToCompoundRule() {
executor = new ExpressionRuleExecutor(ImmutableList.of(BetweenToCompoundRule.INSTANCE, SimplifyNotExprRule.INSTANCE));
executor = new ExpressionRuleExecutor(ImmutableList.of(BetweenToCompoundRule.INSTANCE,
SimplifyNotExprRule.INSTANCE));
assertRewrite("a between c and d", "(a >= c) and (a <= d)");
assertRewrite("a not between c and d)", "(a < c) or (a > d)");

View File

@ -0,0 +1,341 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.rules.expression.rewrite.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntervalLiteral.TimeUnit;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.types.VarcharType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Locale;
import java.util.Map;
public class FoldConstantTest {
private static final NereidsParser PARSER = new NereidsParser();
private ExpressionRuleExecutor executor;
@Test
public void testCaseWhenFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
assertRewrite("case when 1 = 2 then 1 when '1' < 2 then 2 else 3 end", "2");
assertRewrite("case when 1 = 2 then 1 when '1' > 2 then 2 end", "null");
assertRewrite("case when (1 + 5) / 2 > 2 then 4 when '1' < 2 then 2 else 3 end", "4");
assertRewrite("case when not 1 = 2 then 1 when '1' > 2 then 2 end", "1");
assertRewrite("case when 1 = 2 then 1 when 3 in ('1',2 + 8 / 2,3,4) then 2 end", "2");
assertRewrite("case when TA = 2 then 1 when 3 in ('1',2 + 8 / 2,3,4) then 2 end", "CASE WHEN (TA = 2) THEN 1 ELSE 2 END");
assertRewrite("case when TA = 2 then 5 when 3 in (2,3,4) then 2 else 4 end", "CASE WHEN (TA = 2) THEN 5 ELSE 2 END");
assertRewrite("case when TA = 2 then 1 when TB in (2,3,4) then 2 else 4 end", "CASE WHEN (TA = 2) THEN 1 WHEN TB IN (2, 3, 4) THEN 2 ELSE 4 END");
assertRewrite("case when null = 2 then 1 when 3 in (2,3,4) then 2 else 4 end", "2");
assertRewrite("case when null = 2 then 1 else 4 end", "4");
assertRewrite("case when null = 2 then 1 end", "null");
assertRewrite("case when TA = TB then 1 when TC is null then 2 end", "CASE WHEN (TA = TB) THEN 1 WHEN TC IS NULL THEN 2 ELSE NULL END");
}
@Test
public void testInFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
assertRewrite("1 in (1,2,3,4)", "true");
// Type Coercion trans all to string.
assertRewrite("3 in ('1',2 + 8 / 2,3,4)", "true");
assertRewrite("4 / 2 * 1 - (5/2) in ('1',2 + 8 / 2,3,4)", "false");
assertRewrite("null in ('1',2 + 8 / 2,3,4)", "null");
assertRewrite("3 in ('1',null,3,4)", "true");
assertRewrite("TA in (1,null,3,4)", "TA in (1, null, 3, 4)");
assertRewrite("IA in (IB,IC,null)", "IA in (IB,IC,null)");
}
@Test
public void testLogicalFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
assertRewrite("10 + 1 > 1 and 1 > 2", "false");
assertRewrite("10 + 1 > 1 and 1 < 2", "true");
assertRewrite("null + 1 > 1 and 1 < 2", "null");
assertRewrite("10 < 3 and 1 > 2", "false");
assertRewrite("6 / 2 - 10 * (6 + 1) > 2 and 10 > 3 and 1 > 2", "false");
assertRewrite("10 + 1 > 1 or 1 > 2", "true");
assertRewrite("null + 1 > 1 or 1 > 2", "null");
assertRewrite("6 / 2 - 10 * (6 + 1) > 2 or 10 > 3 or 1 > 2", "true");
assertRewrite("(1 > 5 and 8 < 10 or 1 = 3) or (1 > 8 + 9 / (10 * 2) or ( 10 = 3))", "false");
assertRewrite("(TA > 1 and 8 < 10 or 1 = 3) or (1 > 3 or ( 10 = 3))", "TA > 1");
assertRewrite("false or false", "false");
assertRewrite("false or true", "true");
assertRewrite("true or false", "true");
assertRewrite("true or true", "true");
assertRewrite("true and true", "true");
assertRewrite("false and true", "false");
assertRewrite("true and false", "false");
assertRewrite("false and false", "false");
assertRewrite("true and null", "null");
assertRewrite("false and null", "false");
assertRewrite("true or null", "true");
assertRewrite("false or null", "null");
assertRewrite("null and null", "null");
}
@Test
public void testIsNullFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
assertRewrite("100 is null", "false");
assertRewrite("null is null", "true");
assertRewrite("null is not null", "false");
assertRewrite("100 is not null", "true");
assertRewrite("IA is not null", "IA is not null");
assertRewrite("IA is null", "IA is null");
}
@Test
public void testNotFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
assertRewrite("not 1 > 2", "true");
assertRewrite("not null + 1 > 2", "null");
assertRewrite("not (1 + 5) / 2 + (10 - 1) * 3 > 3 * 5 + 1", "false");
}
@Test
public void testCastFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
// cast '1' as tinyint
Cast c = new Cast(Literal.of("1"), TinyIntType.INSTANCE);
Expression rewritten = executor.rewrite(c);
Literal expected = Literal.of((byte) 1);
Assertions.assertEquals(rewritten, expected);
}
@Test
public void testCompareFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
assertRewrite("'1' = 2", "false");
assertRewrite("1 = 2", "false");
assertRewrite("1 != 2", "true");
assertRewrite("2 > 2", "false");
assertRewrite("3 * 10 + 1 / 2 >= 2", "true");
assertRewrite("3 < 2", "false");
assertRewrite("3 <= 2", "false");
assertRewrite("3 <= null", "null");
assertRewrite("3 >= null", "null");
assertRewrite("null <=> null", "true");
assertRewrite("2 <=> null", "false");
assertRewrite("2 <=> 2", "true");
}
@Test
public void testArithmeticFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
assertRewrite("1 + 1", Literal.of((byte) 2));
assertRewrite("1 - 1", Literal.of((byte) 0));
assertRewrite("100 + 100", Literal.of((byte) 200));
assertRewrite("1 - 2", Literal.of((byte) -1));
assertRewrite("1 - 2 > 1", "false");
assertRewrite("1 - 2 + 1 > 1 + 1 - 100", "true");
assertRewrite("10 * 2 / 1 + 1 > (1 + 1) - 100", "true");
// a + 1 > 2
Slot a = SlotReference.of("a", IntegerType.INSTANCE);
Expression e1 = new Add(a, Literal.of(1L));
Expression e2 = new Add(new Cast(a, BigIntType.INSTANCE), Literal.of(1L));
assertRewrite(e1, e2);
// a > (1 + 10) / 2 * (10 + 1)
Expression e3 = PARSER.parseExpression("(1 + 10) / 2 * (10 + 1)");
Expression e4 = new GreaterThan(a, e3);
Expression e5 = new GreaterThan(new Cast(a, DoubleType.INSTANCE), Literal.of(60.5D));
assertRewrite(e4, e5);
// a > 1
Expression e6 = new GreaterThan(a, Literal.of(1));
assertRewrite(e6, e6);
assertRewrite(a, a);
// a
assertRewrite(a, a);
// 1
Literal one = Literal.of(1);
assertRewrite(one, one);
}
@Test
public void testTimestampFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
String interval = "'1991-05-01' - interval 1 day";
Expression e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
Expression e8 = new DateTimeLiteral(1991, 4, 30, 0, 0, 0);
assertRewrite(e7, e8);
interval = "'1991-05-01' + interval '1' day";
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
e8 = new DateTimeLiteral(1991, 5, 2, 0, 0, 0);
assertRewrite(e7, e8);
interval = "'1991-05-01' + interval 1+1 day";
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
e8 = new DateTimeLiteral(1991, 5, 3, 0, 0, 0);
assertRewrite(e7, e8);
interval = "date '1991-05-01' + interval 10 / 2 + 1 day";
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
e8 = new DateTimeLiteral(1991, 5, 7, 0, 0, 0);
assertRewrite(e7, e8);
interval = "interval '1' day + '1991-05-01'";
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
e8 = new DateTimeLiteral(1991, 5, 2, 0, 0, 0);
assertRewrite(e7, e8);
interval = "interval '3' month + '1991-05-01'";
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
e8 = new DateTimeLiteral(1991, 8, 1, 0, 0, 0);
assertRewrite(e7, e8);
interval = "interval 3 + 1 month + '1991-05-01'";
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
e8 = new DateTimeLiteral(1991, 9, 1, 0, 0, 0);
assertRewrite(e7, e8);
interval = "interval 3 + 1 year + '1991-05-01'";
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
e8 = new DateTimeLiteral(1995, 5, 1, 0, 0, 0);
assertRewrite(e7, e8);
interval = "interval 3 + 3 / 2 hour + '1991-05-01 10:00:00'";
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
e8 = new DateTimeLiteral(1991, 5, 1, 14, 0, 0);
assertRewrite(e7, e8);
interval = "interval 3 * 2 / 3 minute + '1991-05-01 10:00:00'";
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
e8 = new DateTimeLiteral(1991, 5, 1, 10, 2, 0);
assertRewrite(e7, e8);
interval = "interval 3 / 2 + 1 second + '1991-05-01 10:00:00'";
e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
e8 = new DateTimeLiteral(1991, 5, 1, 10, 0, 2);
assertRewrite(e7, e8);
// a + interval 1 day
Slot a = SlotReference.of("a", DateTimeType.INSTANCE);
TimestampArithmetic arithmetic = new TimestampArithmetic(Operator.ADD, a, Literal.of(1), TimeUnit.DAY, false);
Expression process = process(arithmetic);
assertRewrite(process, process);
}
public Expression process(TimestampArithmetic arithmetic) {
String funcOpName;
if (arithmetic.getFuncName() == null) {
funcOpName = String.format("%sS_%s", arithmetic.getTimeUnit(),
(arithmetic.getOp() == Operator.ADD) ? "ADD" : "SUB");
} else {
funcOpName = arithmetic.getFuncName();
}
return arithmetic.withFuncName(funcOpName.toLowerCase(Locale.ROOT));
}
private void assertRewrite(String expression, String expected) {
Map<String, Slot> mem = Maps.newHashMap();
Expression needRewriteExpression = PARSER.parseExpression(expression);
needRewriteExpression = replaceUnboundSlot(needRewriteExpression, mem);
Expression expectedExpression = PARSER.parseExpression(expected);
expectedExpression = replaceUnboundSlot(expectedExpression, mem);
Expression rewrittenExpression = executor.rewrite(needRewriteExpression);
Assertions.assertEquals(expectedExpression, rewrittenExpression);
}
private void assertRewrite(String expression, Expression expectedExpression) {
Expression needRewriteExpression = PARSER.parseExpression(expression);
Expression rewrittenExpression = executor.rewrite(needRewriteExpression);
Assertions.assertEquals(expectedExpression, rewrittenExpression);
}
private void assertRewrite(Expression expression, Expression expectedExpression) {
Expression rewrittenExpression = executor.rewrite(expression);
Assertions.assertEquals(expectedExpression, rewrittenExpression);
}
private Expression replaceUnboundSlot(Expression expression, Map<String, Slot> mem) {
List<Expression> children = Lists.newArrayList();
boolean hasNewChildren = false;
for (Expression child : expression.children()) {
Expression newChild = replaceUnboundSlot(child, mem);
if (newChild != child) {
hasNewChildren = true;
}
children.add(newChild);
}
if (expression instanceof UnboundSlot) {
String name = ((UnboundSlot) expression).getName();
mem.putIfAbsent(name, SlotReference.of(name, getType(name.charAt(0))));
return mem.get(name);
}
return hasNewChildren ? expression.withChildren(children) : expression;
}
private DataType getType(char t) {
switch (t) {
case 'T':
return TinyIntType.INSTANCE;
case 'I':
return IntegerType.INSTANCE;
case 'D':
return DoubleType.INSTANCE;
case 'S':
return StringType.INSTANCE;
case 'V':
return VarcharType.INSTANCE;
case 'B':
return BooleanType.INSTANCE;
default:
return BigIntType.INSTANCE;
}
}
}

View File

@ -36,6 +36,7 @@ import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
@ -176,8 +177,8 @@ public class TypeCoercionTest {
@Test
public void testBinaryOperator() {
Expression actual = new Divide(new SmallIntLiteral((short) 1), new BigIntLiteral(10L));
Expression expected = new Divide(new BigIntLiteral(1L),
new BigIntLiteral(10L));
Expression expected = new Divide(new Cast(Literal.of((short) 1), DoubleType.INSTANCE),
new Cast(Literal.of(10L), DoubleType.INSTANCE));
assertRewrite(expected, actual);
}

View File

@ -272,4 +272,13 @@ public class ExpressionParserTest extends ParserTestBase {
String cast2 = "SELECT CAST(A AS INT) AS I FROM TEST;";
assertSql(cast2);
}
@Test
public void testIsNull() {
String e1 = "a is null";
assertExpr(e1);
String e2 = "a is not null";
assertExpr(e2);
}
}