[enhancement](Nereids) let BinaryArithmetic's dataType and nullable match with BE (#13015)
Do type promotion for BinaryArithmetic: - Add - Subtract - Multiply Do always nullable for: - Mod
This commit is contained in:
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.analysis.ArithmeticExpr.Operator;
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
import org.apache.doris.nereids.types.coercion.NumericType;
|
||||
@ -41,6 +42,11 @@ public class Mod extends BinaryArithmetic {
|
||||
return new Mod(children.get(0), children.get(1));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean nullable() throws UnboundException {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitMod(this, context);
|
||||
|
||||
@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.analysis.ArithmeticExpr.Operator;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
import org.apache.doris.nereids.types.coercion.NumericType;
|
||||
|
||||
@ -41,6 +42,11 @@ public class Multiply extends BinaryArithmetic {
|
||||
return new Multiply(children.get(0), children.get(1));
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType() {
|
||||
return left().getDataType().promotion();
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitMultiply(this, context);
|
||||
|
||||
@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.analysis.ArithmeticExpr.Operator;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
import org.apache.doris.nereids.types.coercion.NumericType;
|
||||
|
||||
@ -41,6 +42,11 @@ public class Subtract extends BinaryArithmetic {
|
||||
return new Subtract(children.get(0), children.get(1));
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType() {
|
||||
return left().getDataType().promotion();
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitSubtract(this, context);
|
||||
|
||||
@ -80,24 +80,24 @@ public class ExecutableFunctions {
|
||||
return new DecimalLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "subtract", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT")
|
||||
public static TinyIntLiteral subtractTinyint(TinyIntLiteral first, TinyIntLiteral second) {
|
||||
byte result = (byte) Math.subtractExact(first.getValue(), second.getValue());
|
||||
return new TinyIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "subtract", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT")
|
||||
public static SmallIntLiteral subtractSmallint(SmallIntLiteral first, SmallIntLiteral second) {
|
||||
@ExecFunction(name = "subtract", argTypes = {"TINYINT", "TINYINT"}, returnType = "SMALLINT")
|
||||
public static SmallIntLiteral subtractTinyint(TinyIntLiteral first, TinyIntLiteral second) {
|
||||
short result = (short) Math.subtractExact(first.getValue(), second.getValue());
|
||||
return new SmallIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "subtract", argTypes = {"INT", "INT"}, returnType = "INT")
|
||||
public static IntegerLiteral subtractInt(IntegerLiteral first, IntegerLiteral second) {
|
||||
@ExecFunction(name = "subtract", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "INT")
|
||||
public static IntegerLiteral subtractSmallint(SmallIntLiteral first, SmallIntLiteral second) {
|
||||
int result = Math.subtractExact(first.getValue(), second.getValue());
|
||||
return new IntegerLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "subtract", argTypes = {"INT", "INT"}, returnType = "BIGINT")
|
||||
public static BigIntLiteral subtractInt(IntegerLiteral first, IntegerLiteral second) {
|
||||
long result = Math.subtractExact(first.getValue(), second.getValue());
|
||||
return new BigIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "subtract", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT")
|
||||
public static BigIntLiteral subtractBigint(BigIntLiteral first, BigIntLiteral second) {
|
||||
long result = Math.subtractExact(first.getValue(), second.getValue());
|
||||
@ -116,24 +116,24 @@ public class ExecutableFunctions {
|
||||
return new DecimalLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "multiply", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT")
|
||||
public static TinyIntLiteral multiplyTinyint(TinyIntLiteral first, TinyIntLiteral second) {
|
||||
byte result = (byte) Math.multiplyExact(first.getValue(), second.getValue());
|
||||
return new TinyIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "multiply", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT")
|
||||
public static SmallIntLiteral multiplySmallint(SmallIntLiteral first, SmallIntLiteral second) {
|
||||
@ExecFunction(name = "multiply", argTypes = {"TINYINT", "TINYINT"}, returnType = "SMALLINT")
|
||||
public static SmallIntLiteral multiplyTinyint(TinyIntLiteral first, TinyIntLiteral second) {
|
||||
short result = (short) Math.multiplyExact(first.getValue(), second.getValue());
|
||||
return new SmallIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "multiply", argTypes = {"INT", "INT"}, returnType = "INT")
|
||||
public static IntegerLiteral multiplyInt(IntegerLiteral first, IntegerLiteral second) {
|
||||
@ExecFunction(name = "multiply", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "INT")
|
||||
public static IntegerLiteral multiplySmallint(SmallIntLiteral first, SmallIntLiteral second) {
|
||||
int result = Math.multiplyExact(first.getValue(), second.getValue());
|
||||
return new IntegerLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "multiply", argTypes = {"INT", "INT"}, returnType = "BIGINT")
|
||||
public static BigIntLiteral multiplyInt(IntegerLiteral first, IntegerLiteral second) {
|
||||
long result = Math.multiplyExact(first.getValue(), second.getValue());
|
||||
return new BigIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "multiply", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT")
|
||||
public static BigIntLiteral multiplyBigint(BigIntLiteral first, BigIntLiteral second) {
|
||||
long result = Math.multiplyExact(first.getValue(), second.getValue());
|
||||
|
||||
@ -257,7 +257,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
NamedExpressionUtil.clear();
|
||||
|
||||
sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + 3) > 0";
|
||||
Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new SmallIntLiteral((byte) 3))),
|
||||
Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new SmallIntLiteral((short) 3))),
|
||||
"sum(((a1 + a2) + 3))");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matchesFromRoot(
|
||||
|
||||
@ -172,10 +172,9 @@ public class FoldConstantTest {
|
||||
public void testArithmeticFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
|
||||
assertRewrite("1 + 1", Literal.of((short) 2));
|
||||
assertRewrite("1 - 1", Literal.of((byte) 0));
|
||||
assertRewrite("1 - 1", Literal.of((short) 0));
|
||||
assertRewrite("100 + 100", Literal.of((short) 200));
|
||||
assertRewrite("1 - 2", Literal.of((byte) -1));
|
||||
|
||||
assertRewrite("1 - 2", Literal.of((short) -1));
|
||||
assertRewrite("1 - 2 > 1", "false");
|
||||
assertRewrite("1 - 2 + 1 > 1 + 1 - 100", "true");
|
||||
assertRewrite("10 * 2 / 1 + 1 > (1 + 1) - 100", "true");
|
||||
|
||||
Reference in New Issue
Block a user