[enhancement](Nereids) let BinaryArithmetic's dataType and nullable match with BE (#13015)

Do type promotion for BinaryArithmetic:
- Add
- Subtract
- Multiply

Do always nullable for:
- Mod
This commit is contained in:
morrySnow
2022-09-28 20:02:27 +08:00
committed by GitHub
parent cd549d8a8f
commit 7019166469
6 changed files with 41 additions and 24 deletions

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.types.coercion.NumericType;
@ -41,6 +42,11 @@ public class Mod extends BinaryArithmetic {
return new Mod(children.get(0), children.get(1));
}
@Override
public boolean nullable() throws UnboundException {
return true;
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitMod(this, context);

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.types.coercion.NumericType;
@ -41,6 +42,11 @@ public class Multiply extends BinaryArithmetic {
return new Multiply(children.get(0), children.get(1));
}
@Override
public DataType getDataType() {
return left().getDataType().promotion();
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitMultiply(this, context);

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.types.coercion.NumericType;
@ -41,6 +42,11 @@ public class Subtract extends BinaryArithmetic {
return new Subtract(children.get(0), children.get(1));
}
@Override
public DataType getDataType() {
return left().getDataType().promotion();
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitSubtract(this, context);

View File

@ -80,24 +80,24 @@ public class ExecutableFunctions {
return new DecimalLiteral(result);
}
@ExecFunction(name = "subtract", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT")
public static TinyIntLiteral subtractTinyint(TinyIntLiteral first, TinyIntLiteral second) {
byte result = (byte) Math.subtractExact(first.getValue(), second.getValue());
return new TinyIntLiteral(result);
}
@ExecFunction(name = "subtract", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT")
public static SmallIntLiteral subtractSmallint(SmallIntLiteral first, SmallIntLiteral second) {
@ExecFunction(name = "subtract", argTypes = {"TINYINT", "TINYINT"}, returnType = "SMALLINT")
public static SmallIntLiteral subtractTinyint(TinyIntLiteral first, TinyIntLiteral second) {
short result = (short) Math.subtractExact(first.getValue(), second.getValue());
return new SmallIntLiteral(result);
}
@ExecFunction(name = "subtract", argTypes = {"INT", "INT"}, returnType = "INT")
public static IntegerLiteral subtractInt(IntegerLiteral first, IntegerLiteral second) {
@ExecFunction(name = "subtract", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "INT")
public static IntegerLiteral subtractSmallint(SmallIntLiteral first, SmallIntLiteral second) {
int result = Math.subtractExact(first.getValue(), second.getValue());
return new IntegerLiteral(result);
}
@ExecFunction(name = "subtract", argTypes = {"INT", "INT"}, returnType = "BIGINT")
public static BigIntLiteral subtractInt(IntegerLiteral first, IntegerLiteral second) {
long result = Math.subtractExact(first.getValue(), second.getValue());
return new BigIntLiteral(result);
}
@ExecFunction(name = "subtract", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT")
public static BigIntLiteral subtractBigint(BigIntLiteral first, BigIntLiteral second) {
long result = Math.subtractExact(first.getValue(), second.getValue());
@ -116,24 +116,24 @@ public class ExecutableFunctions {
return new DecimalLiteral(result);
}
@ExecFunction(name = "multiply", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT")
public static TinyIntLiteral multiplyTinyint(TinyIntLiteral first, TinyIntLiteral second) {
byte result = (byte) Math.multiplyExact(first.getValue(), second.getValue());
return new TinyIntLiteral(result);
}
@ExecFunction(name = "multiply", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT")
public static SmallIntLiteral multiplySmallint(SmallIntLiteral first, SmallIntLiteral second) {
@ExecFunction(name = "multiply", argTypes = {"TINYINT", "TINYINT"}, returnType = "SMALLINT")
public static SmallIntLiteral multiplyTinyint(TinyIntLiteral first, TinyIntLiteral second) {
short result = (short) Math.multiplyExact(first.getValue(), second.getValue());
return new SmallIntLiteral(result);
}
@ExecFunction(name = "multiply", argTypes = {"INT", "INT"}, returnType = "INT")
public static IntegerLiteral multiplyInt(IntegerLiteral first, IntegerLiteral second) {
@ExecFunction(name = "multiply", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "INT")
public static IntegerLiteral multiplySmallint(SmallIntLiteral first, SmallIntLiteral second) {
int result = Math.multiplyExact(first.getValue(), second.getValue());
return new IntegerLiteral(result);
}
@ExecFunction(name = "multiply", argTypes = {"INT", "INT"}, returnType = "BIGINT")
public static BigIntLiteral multiplyInt(IntegerLiteral first, IntegerLiteral second) {
long result = Math.multiplyExact(first.getValue(), second.getValue());
return new BigIntLiteral(result);
}
@ExecFunction(name = "multiply", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT")
public static BigIntLiteral multiplyBigint(BigIntLiteral first, BigIntLiteral second) {
long result = Math.multiplyExact(first.getValue(), second.getValue());

View File

@ -257,7 +257,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
NamedExpressionUtil.clear();
sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + 3) > 0";
Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new SmallIntLiteral((byte) 3))),
Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new SmallIntLiteral((short) 3))),
"sum(((a1 + a2) + 3))");
PlanChecker.from(connectContext).analyze(sql)
.matchesFromRoot(

View File

@ -172,10 +172,9 @@ public class FoldConstantTest {
public void testArithmeticFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
assertRewrite("1 + 1", Literal.of((short) 2));
assertRewrite("1 - 1", Literal.of((byte) 0));
assertRewrite("1 - 1", Literal.of((short) 0));
assertRewrite("100 + 100", Literal.of((short) 200));
assertRewrite("1 - 2", Literal.of((byte) -1));
assertRewrite("1 - 2", Literal.of((short) -1));
assertRewrite("1 - 2 > 1", "false");
assertRewrite("1 - 2 + 1 > 1 + 1 - 100", "true");
assertRewrite("10 * 2 / 1 + 1 > (1 + 1) - 100", "true");