[enhancement](nereids)remove useless cast for floatlike type (#23621)

convert cast(c1 AS double) > 2.0 to c1 >= 2 (c1 is integer like type)
This commit is contained in:
starocean999
2023-08-30 19:00:16 +08:00
committed by GitHub
parent f7caae08d5
commit e1743b70f2
2 changed files with 381 additions and 32 deletions

View File

@ -31,13 +31,21 @@ import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
@ -46,9 +54,15 @@ import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import com.google.common.base.Preconditions;
import java.math.BigDecimal;
import java.math.RoundingMode;
/**
* simplify comparison
* such as: cast(c1 as DateV2) >= DateV2Literal --> c1 >= DateLiteral
* cast(c1 AS double) > 2.0 --> c1 >= 2 (c1 is integer like type)
*/
public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
@ -65,6 +79,11 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
Expression left = rewrite(cp.left(), context);
Expression right = rewrite(cp.right(), context);
// float like type: float, double
if (left.getDataType().isFloatLikeType() && right.getDataType().isFloatLikeType()) {
return processFloatLikeTypeCoercion(cp, left, right);
}
// decimalv3 type
if (left.getDataType() instanceof DecimalV3Type
&& right.getDataType() instanceof DecimalV3Type) {
@ -194,6 +213,26 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
}
}
private Expression processFloatLikeTypeCoercion(ComparisonPredicate comparisonPredicate,
Expression left, Expression right) {
if (left instanceof Literal) {
comparisonPredicate = comparisonPredicate.commute();
Expression temp = left;
left = right;
right = temp;
}
if (left instanceof Cast && left.child(0).getDataType().isIntegerLikeType()
&& (right instanceof DoubleLiteral || right instanceof FloatLiteral)) {
Cast cast = (Cast) left;
left = cast.child();
BigDecimal literal = new BigDecimal(((Literal) right).getStringValue());
return processIntegerDecimalLiteralComparison(comparisonPredicate, left, literal);
} else {
return comparisonPredicate;
}
}
private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPredicate,
Expression left, Expression right) {
if (left instanceof DecimalV3Literal) {
@ -203,51 +242,113 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
right = temp;
}
if (left instanceof Cast && left.child(0).getDataType().isDecimalV3Type()
&& right instanceof DecimalV3Literal) {
if (left instanceof Cast && right instanceof DecimalV3Literal) {
Cast cast = (Cast) left;
left = cast.child();
DecimalV3Literal literal = (DecimalV3Literal) right;
if (((DecimalV3Type) left.getDataType())
.getScale() < ((DecimalV3Type) literal.getDataType()).getScale()) {
int toScale = ((DecimalV3Type) left.getDataType()).getScale();
if (comparisonPredicate instanceof EqualTo) {
try {
return comparisonPredicate.withChildren(left, new DecimalV3Literal(
(DecimalV3Type) left.getDataType(), literal.getValue().setScale(toScale)));
} catch (ArithmeticException e) {
if (left.nullable()) {
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
// but current fold constant rule can't handle such complex expr with null literal
// before supporting complex conjuncts with null literal folding rules,
// we use a trick way like this:
return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE));
} else {
if (left.getDataType().isDecimalV3Type()) {
if (((DecimalV3Type) left.getDataType())
.getScale() < ((DecimalV3Type) literal.getDataType()).getScale()) {
int toScale = ((DecimalV3Type) left.getDataType()).getScale();
if (comparisonPredicate instanceof EqualTo) {
try {
return comparisonPredicate.withChildren(left,
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale)));
} catch (ArithmeticException e) {
if (left.nullable()) {
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
// but current fold constant rule can't handle such complex expr with null literal
// before supporting complex conjuncts with null literal folding rules,
// we use a trick way like this:
return new And(new IsNull(left),
new NullLiteral(BooleanType.INSTANCE));
} else {
return BooleanLiteral.of(false);
}
}
} else if (comparisonPredicate instanceof NullSafeEqual) {
try {
return comparisonPredicate.withChildren(left,
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale)));
} catch (ArithmeticException e) {
return BooleanLiteral.of(false);
}
} else if (comparisonPredicate instanceof GreaterThan
|| comparisonPredicate instanceof LessThanEqual) {
return comparisonPredicate.withChildren(left, literal.roundFloor(toScale));
} else if (comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof GreaterThanEqual) {
return comparisonPredicate.withChildren(left,
literal.roundCeiling(toScale));
}
} else if (comparisonPredicate instanceof NullSafeEqual) {
try {
return comparisonPredicate.withChildren(left, new DecimalV3Literal(
(DecimalV3Type) left.getDataType(), literal.getValue().setScale(toScale)));
} catch (ArithmeticException e) {
return BooleanLiteral.of(false);
}
} else if (comparisonPredicate instanceof GreaterThan
|| comparisonPredicate instanceof LessThanEqual) {
return comparisonPredicate.withChildren(left, literal.roundFloor(toScale));
} else if (comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof GreaterThanEqual) {
return comparisonPredicate.withChildren(left, literal.roundCeiling(toScale));
}
} else if (left.getDataType().isIntegerLikeType()) {
return processIntegerDecimalLiteralComparison(comparisonPredicate, left,
literal.getValue());
}
}
return comparisonPredicate;
}
private Expression processIntegerDecimalLiteralComparison(
ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) {
// we only process isIntegerLikeType, which are tinyint, smallint, int, bigint
if (literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) {
if (literal.scale() > 0) {
if (comparisonPredicate instanceof EqualTo) {
if (left.nullable()) {
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
// but current fold constant rule can't handle such complex expr with null literal
// before supporting complex conjuncts with null literal folding rules,
// we use a trick way like this:
return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE));
} else {
return BooleanLiteral.of(false);
}
} else if (comparisonPredicate instanceof NullSafeEqual) {
return BooleanLiteral.of(false);
} else if (comparisonPredicate instanceof GreaterThan
|| comparisonPredicate instanceof LessThanEqual) {
return comparisonPredicate.withChildren(left,
convertDecimalToIntegerLikeLiteral(
literal.setScale(0, RoundingMode.FLOOR)));
} else if (comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof GreaterThanEqual) {
return comparisonPredicate.withChildren(left,
convertDecimalToIntegerLikeLiteral(
literal.setScale(0, RoundingMode.CEILING)));
}
} else {
return comparisonPredicate.withChildren(left,
convertDecimalToIntegerLikeLiteral(literal));
}
}
return comparisonPredicate;
}
private IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) {
Preconditions.checkArgument(
decimal.scale() == 0 && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0,
"decimal literal must have 0 scale and smaller than Long.MAX_VALUE");
long val = decimal.longValue();
if (val <= Byte.MAX_VALUE) {
return new TinyIntLiteral((byte) val);
} else if (val <= Short.MAX_VALUE) {
return new SmallIntLiteral((short) val);
} else if (val <= Integer.MAX_VALUE) {
return new IntegerLiteral((int) val);
} else {
return new BigIntLiteral(val);
}
}
private Expression migrateCastToDateTime(Cast cast) {
//cast( cast(v as date) as datetime) if v is datetime, set left = v
if (cast.child() instanceof Cast

View File

@ -72,4 +72,252 @@ suite("test_simplify_comparison") {
}
sql "select cast('1234' as decimalv3(18,4)) > 2000;"
sql 'drop table if exists simple_test_table_t;'
sql """CREATE TABLE IF NOT EXISTS `simple_test_table_t` (
a tinyint,
b smallint,
c int,
d bigint,
e largeint
) ENGINE=OLAP
UNIQUE KEY (`a`)
DISTRIBUTED BY HASH(`a`) BUCKETS 120
PROPERTIES (
"replication_num" = "1",
"in_memory" = "false",
"compression" = "LZ4"
);"""
explain {
sql "verbose select * from simple_test_table_t where a = cast(1.0 as double) and b = cast(1.0 as double) and c = cast(1.0 as double) and d = cast(1.0 as double);"
notContains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where e = cast(1.0 as double);"
contains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where a > cast(1.0 as double) and b > cast(1.0 as double) and c > cast(1.0 as double) and d > cast(1.0 as double);"
notContains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where e > cast(1.0 as double);"
contains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where a < cast(1.0 as double) and b < cast(1.0 as double) and c < cast(1.0 as double) and d < cast(1.0 as double);"
notContains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where e < cast(1.0 as double);"
contains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where a >= cast(1.0 as double) and b >= cast(1.0 as double) and c >= cast(1.0 as double) and d >= cast(1.0 as double);"
notContains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where e >= cast(1.0 as double);"
contains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where a <= cast(1.0 as double) and b <= cast(1.0 as double) and c <= cast(1.0 as double) and d <= cast(1.0 as double);"
notContains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where e <= cast(1.0 as double);"
contains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where a = cast(1.1 as double) and b = cast(1.1 as double) and c = cast(1.1 as double) and d = cast(1.1 as double);"
contains "a[#0] IS NULL"
contains "b[#1] IS NULL"
contains "c[#2] IS NULL"
contains "d[#3] IS NULL"
contains "AND NULL"
}
explain {
sql "verbose select * from simple_test_table_t where e = cast(1.1 as double);"
contains "CAST(e[#4] AS DOUBLE) = 1.1"
}
explain {
sql "verbose select * from simple_test_table_t where a > cast(1.1 as double) and b > cast(1.1 as double) and c > cast(1.1 as double) and d > cast(1.1 as double);"
contains "a[#0] > 1"
contains "b[#1] > 1"
contains "c[#2] > 1"
contains "d[#3] > 1"
}
explain {
sql "verbose select * from simple_test_table_t where e > cast(1.1 as double);"
contains "CAST(e[#4] AS DOUBLE) > 1.1"
}
explain {
sql "verbose select * from simple_test_table_t where a < cast(1.1 as double) and b < cast(1.1 as double) and c < cast(1.1 as double) and d < cast(1.1 as double);"
contains "a[#0] < 2"
contains "b[#1] < 2"
contains "c[#2] < 2"
contains "d[#3] < 2"
}
explain {
sql "verbose select * from simple_test_table_t where e < cast(1.1 as double);"
contains "CAST(e[#4] AS DOUBLE) < 1.1"
}
explain {
sql "verbose select * from simple_test_table_t where a >= cast(1.1 as double) and b >= cast(1.1 as double) and c >= cast(1.1 as double) and d >= cast(1.1 as double);"
contains "a[#0] >= 2"
contains "b[#1] >= 2"
contains "c[#2] >= 2"
contains "d[#3] >= 2"
}
explain {
sql "verbose select * from simple_test_table_t where e >= cast(1.1 as double);"
contains "CAST(e[#4] AS DOUBLE) >= 1.1"
}
explain {
sql "verbose select * from simple_test_table_t where a <= cast(1.1 as double) and b <= cast(1.1 as double) and c <= cast(1.1 as double) and d <= cast(1.1 as double);"
contains "a[#0] <= 1"
contains "b[#1] <= 1"
contains "c[#2] <= 1"
contains "d[#3] <= 1"
}
explain {
sql "verbose select * from simple_test_table_t where e <= cast(1.1 as double);"
contains "CAST(e[#4] AS DOUBLE) <= 1.1"
}
explain {
sql "verbose select * from simple_test_table_t where a = 1.0 and b = 1.0 and c = 1.0 and d = 1.0;"
notContains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where e = 1.0;"
contains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where a > 1.0 and b > 1.0 and c > 1.0 and d > 1.0;"
notContains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where e > 1.0;"
contains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where a < 1.0 and b < 1.0 and c < 1.0 and d < 1.0;"
notContains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where e < 1.0;"
contains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where a >= 1.0 and b >= 1.0 and c >= 1.0 and d >= 1.0;"
notContains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where e >= 1.0;"
contains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where a <= 1.0 and b <= 1.0 and c <= 1.0 and d <= 1.0;"
notContains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where e <= 1.0;"
contains "CAST"
}
explain {
sql "verbose select * from simple_test_table_t where a = 1.1 and b = 1.1 and c = 1.1 and d = 1.1;"
contains "a[#0] IS NULL"
contains "b[#1] IS NULL"
contains "c[#2] IS NULL"
contains "d[#3] IS NULL"
contains "AND NULL"
}
explain {
sql "verbose select * from simple_test_table_t where e = 1.1;"
contains "CAST(e[#4] AS DOUBLE) = 1.1"
}
explain {
sql "verbose select * from simple_test_table_t where a > 1.1 and b > 1.1 and c > 1.1 and d > 1.1;"
contains "a[#0] > 1"
contains "b[#1] > 1"
contains "c[#2] > 1"
contains "d[#3] > 1"
}
explain {
sql "verbose select * from simple_test_table_t where e > 1.1;"
contains "CAST(e[#4] AS DOUBLE) > 1.1"
}
explain {
sql "verbose select * from simple_test_table_t where a < 1.1 and b < 1.1 and c < 1.1 and d < 1.1;"
contains "a[#0] < 2"
contains "b[#1] < 2"
contains "c[#2] < 2"
contains "d[#3] < 2"
}
explain {
sql "verbose select * from simple_test_table_t where e < 1.1;"
contains "CAST(e[#4] AS DOUBLE) < 1.1"
}
explain {
sql "verbose select * from simple_test_table_t where a >= 1.1 and b >= 1.1 and c >= 1.1 and d >= 1.1;"
contains "a[#0] >= 2"
contains "b[#1] >= 2"
contains "c[#2] >= 2"
contains "d[#3] >= 2"
}
explain {
sql "verbose select * from simple_test_table_t where e >= 1.1;"
contains "CAST(e[#4] AS DOUBLE) >= 1.1"
}
explain {
sql "verbose select * from simple_test_table_t where a <= 1.1 and b <= 1.1 and c <= 1.1 and d <= 1.1;"
contains "a[#0] <= 1"
contains "b[#1] <= 1"
contains "c[#2] <= 1"
contains "d[#3] <= 1"
}
explain {
sql "verbose select * from simple_test_table_t where e <= 1.1;"
contains "CAST(e[#4] AS DOUBLE) <= 1.1"
}
}