[enhancement](nereids)remove useless cast for floatlike type (#23621)
convert cast(c1 AS double) > 2.0 to c1 >= 2 (c1 is integer like type)
This commit is contained in:
@ -31,13 +31,21 @@ import org.apache.doris.nereids.trees.expressions.IsNull;
|
||||
import org.apache.doris.nereids.trees.expressions.LessThan;
|
||||
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
|
||||
import org.apache.doris.nereids.types.BooleanType;
|
||||
import org.apache.doris.nereids.types.DateTimeType;
|
||||
import org.apache.doris.nereids.types.DateTimeV2Type;
|
||||
@ -46,9 +54,15 @@ import org.apache.doris.nereids.types.DateV2Type;
|
||||
import org.apache.doris.nereids.types.DecimalV3Type;
|
||||
import org.apache.doris.nereids.types.coercion.DateLikeType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.math.RoundingMode;
|
||||
|
||||
/**
|
||||
* simplify comparison
|
||||
* such as: cast(c1 as DateV2) >= DateV2Literal --> c1 >= DateLiteral
|
||||
* cast(c1 AS double) > 2.0 --> c1 >= 2 (c1 is integer like type)
|
||||
*/
|
||||
public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
|
||||
|
||||
@ -65,6 +79,11 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
|
||||
Expression left = rewrite(cp.left(), context);
|
||||
Expression right = rewrite(cp.right(), context);
|
||||
|
||||
// float like type: float, double
|
||||
if (left.getDataType().isFloatLikeType() && right.getDataType().isFloatLikeType()) {
|
||||
return processFloatLikeTypeCoercion(cp, left, right);
|
||||
}
|
||||
|
||||
// decimalv3 type
|
||||
if (left.getDataType() instanceof DecimalV3Type
|
||||
&& right.getDataType() instanceof DecimalV3Type) {
|
||||
@ -194,6 +213,26 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
|
||||
}
|
||||
}
|
||||
|
||||
private Expression processFloatLikeTypeCoercion(ComparisonPredicate comparisonPredicate,
|
||||
Expression left, Expression right) {
|
||||
if (left instanceof Literal) {
|
||||
comparisonPredicate = comparisonPredicate.commute();
|
||||
Expression temp = left;
|
||||
left = right;
|
||||
right = temp;
|
||||
}
|
||||
|
||||
if (left instanceof Cast && left.child(0).getDataType().isIntegerLikeType()
|
||||
&& (right instanceof DoubleLiteral || right instanceof FloatLiteral)) {
|
||||
Cast cast = (Cast) left;
|
||||
left = cast.child();
|
||||
BigDecimal literal = new BigDecimal(((Literal) right).getStringValue());
|
||||
return processIntegerDecimalLiteralComparison(comparisonPredicate, left, literal);
|
||||
} else {
|
||||
return comparisonPredicate;
|
||||
}
|
||||
}
|
||||
|
||||
private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPredicate,
|
||||
Expression left, Expression right) {
|
||||
if (left instanceof DecimalV3Literal) {
|
||||
@ -203,51 +242,113 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
|
||||
right = temp;
|
||||
}
|
||||
|
||||
if (left instanceof Cast && left.child(0).getDataType().isDecimalV3Type()
|
||||
&& right instanceof DecimalV3Literal) {
|
||||
if (left instanceof Cast && right instanceof DecimalV3Literal) {
|
||||
Cast cast = (Cast) left;
|
||||
left = cast.child();
|
||||
DecimalV3Literal literal = (DecimalV3Literal) right;
|
||||
if (((DecimalV3Type) left.getDataType())
|
||||
.getScale() < ((DecimalV3Type) literal.getDataType()).getScale()) {
|
||||
int toScale = ((DecimalV3Type) left.getDataType()).getScale();
|
||||
if (comparisonPredicate instanceof EqualTo) {
|
||||
try {
|
||||
return comparisonPredicate.withChildren(left, new DecimalV3Literal(
|
||||
(DecimalV3Type) left.getDataType(), literal.getValue().setScale(toScale)));
|
||||
} catch (ArithmeticException e) {
|
||||
if (left.nullable()) {
|
||||
// TODO: the ideal way is to return an If expr like:
|
||||
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
|
||||
// BooleanLiteral.of(false));
|
||||
// but current fold constant rule can't handle such complex expr with null literal
|
||||
// before supporting complex conjuncts with null literal folding rules,
|
||||
// we use a trick way like this:
|
||||
return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE));
|
||||
} else {
|
||||
if (left.getDataType().isDecimalV3Type()) {
|
||||
if (((DecimalV3Type) left.getDataType())
|
||||
.getScale() < ((DecimalV3Type) literal.getDataType()).getScale()) {
|
||||
int toScale = ((DecimalV3Type) left.getDataType()).getScale();
|
||||
if (comparisonPredicate instanceof EqualTo) {
|
||||
try {
|
||||
return comparisonPredicate.withChildren(left,
|
||||
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
|
||||
literal.getValue().setScale(toScale)));
|
||||
} catch (ArithmeticException e) {
|
||||
if (left.nullable()) {
|
||||
// TODO: the ideal way is to return an If expr like:
|
||||
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
|
||||
// BooleanLiteral.of(false));
|
||||
// but current fold constant rule can't handle such complex expr with null literal
|
||||
// before supporting complex conjuncts with null literal folding rules,
|
||||
// we use a trick way like this:
|
||||
return new And(new IsNull(left),
|
||||
new NullLiteral(BooleanType.INSTANCE));
|
||||
} else {
|
||||
return BooleanLiteral.of(false);
|
||||
}
|
||||
}
|
||||
} else if (comparisonPredicate instanceof NullSafeEqual) {
|
||||
try {
|
||||
return comparisonPredicate.withChildren(left,
|
||||
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
|
||||
literal.getValue().setScale(toScale)));
|
||||
} catch (ArithmeticException e) {
|
||||
return BooleanLiteral.of(false);
|
||||
}
|
||||
} else if (comparisonPredicate instanceof GreaterThan
|
||||
|| comparisonPredicate instanceof LessThanEqual) {
|
||||
return comparisonPredicate.withChildren(left, literal.roundFloor(toScale));
|
||||
} else if (comparisonPredicate instanceof LessThan
|
||||
|| comparisonPredicate instanceof GreaterThanEqual) {
|
||||
return comparisonPredicate.withChildren(left,
|
||||
literal.roundCeiling(toScale));
|
||||
}
|
||||
} else if (comparisonPredicate instanceof NullSafeEqual) {
|
||||
try {
|
||||
return comparisonPredicate.withChildren(left, new DecimalV3Literal(
|
||||
(DecimalV3Type) left.getDataType(), literal.getValue().setScale(toScale)));
|
||||
} catch (ArithmeticException e) {
|
||||
return BooleanLiteral.of(false);
|
||||
}
|
||||
} else if (comparisonPredicate instanceof GreaterThan
|
||||
|| comparisonPredicate instanceof LessThanEqual) {
|
||||
return comparisonPredicate.withChildren(left, literal.roundFloor(toScale));
|
||||
} else if (comparisonPredicate instanceof LessThan
|
||||
|| comparisonPredicate instanceof GreaterThanEqual) {
|
||||
return comparisonPredicate.withChildren(left, literal.roundCeiling(toScale));
|
||||
}
|
||||
} else if (left.getDataType().isIntegerLikeType()) {
|
||||
return processIntegerDecimalLiteralComparison(comparisonPredicate, left,
|
||||
literal.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
return comparisonPredicate;
|
||||
}
|
||||
|
||||
private Expression processIntegerDecimalLiteralComparison(
|
||||
ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) {
|
||||
// we only process isIntegerLikeType, which are tinyint, smallint, int, bigint
|
||||
if (literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) {
|
||||
if (literal.scale() > 0) {
|
||||
if (comparisonPredicate instanceof EqualTo) {
|
||||
if (left.nullable()) {
|
||||
// TODO: the ideal way is to return an If expr like:
|
||||
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
|
||||
// BooleanLiteral.of(false));
|
||||
// but current fold constant rule can't handle such complex expr with null literal
|
||||
// before supporting complex conjuncts with null literal folding rules,
|
||||
// we use a trick way like this:
|
||||
return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE));
|
||||
} else {
|
||||
return BooleanLiteral.of(false);
|
||||
}
|
||||
} else if (comparisonPredicate instanceof NullSafeEqual) {
|
||||
return BooleanLiteral.of(false);
|
||||
} else if (comparisonPredicate instanceof GreaterThan
|
||||
|| comparisonPredicate instanceof LessThanEqual) {
|
||||
return comparisonPredicate.withChildren(left,
|
||||
convertDecimalToIntegerLikeLiteral(
|
||||
literal.setScale(0, RoundingMode.FLOOR)));
|
||||
} else if (comparisonPredicate instanceof LessThan
|
||||
|| comparisonPredicate instanceof GreaterThanEqual) {
|
||||
return comparisonPredicate.withChildren(left,
|
||||
convertDecimalToIntegerLikeLiteral(
|
||||
literal.setScale(0, RoundingMode.CEILING)));
|
||||
}
|
||||
} else {
|
||||
return comparisonPredicate.withChildren(left,
|
||||
convertDecimalToIntegerLikeLiteral(literal));
|
||||
}
|
||||
}
|
||||
return comparisonPredicate;
|
||||
}
|
||||
|
||||
private IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) {
|
||||
Preconditions.checkArgument(
|
||||
decimal.scale() == 0 && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0,
|
||||
"decimal literal must have 0 scale and smaller than Long.MAX_VALUE");
|
||||
long val = decimal.longValue();
|
||||
if (val <= Byte.MAX_VALUE) {
|
||||
return new TinyIntLiteral((byte) val);
|
||||
} else if (val <= Short.MAX_VALUE) {
|
||||
return new SmallIntLiteral((short) val);
|
||||
} else if (val <= Integer.MAX_VALUE) {
|
||||
return new IntegerLiteral((int) val);
|
||||
} else {
|
||||
return new BigIntLiteral(val);
|
||||
}
|
||||
}
|
||||
|
||||
private Expression migrateCastToDateTime(Cast cast) {
|
||||
//cast( cast(v as date) as datetime) if v is datetime, set left = v
|
||||
if (cast.child() instanceof Cast
|
||||
|
||||
@ -72,4 +72,252 @@ suite("test_simplify_comparison") {
|
||||
}
|
||||
|
||||
sql "select cast('1234' as decimalv3(18,4)) > 2000;"
|
||||
|
||||
sql 'drop table if exists simple_test_table_t;'
|
||||
sql """CREATE TABLE IF NOT EXISTS `simple_test_table_t` (
|
||||
a tinyint,
|
||||
b smallint,
|
||||
c int,
|
||||
d bigint,
|
||||
e largeint
|
||||
) ENGINE=OLAP
|
||||
UNIQUE KEY (`a`)
|
||||
DISTRIBUTED BY HASH(`a`) BUCKETS 120
|
||||
PROPERTIES (
|
||||
"replication_num" = "1",
|
||||
"in_memory" = "false",
|
||||
"compression" = "LZ4"
|
||||
);"""
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a = cast(1.0 as double) and b = cast(1.0 as double) and c = cast(1.0 as double) and d = cast(1.0 as double);"
|
||||
notContains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e = cast(1.0 as double);"
|
||||
contains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a > cast(1.0 as double) and b > cast(1.0 as double) and c > cast(1.0 as double) and d > cast(1.0 as double);"
|
||||
notContains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e > cast(1.0 as double);"
|
||||
contains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a < cast(1.0 as double) and b < cast(1.0 as double) and c < cast(1.0 as double) and d < cast(1.0 as double);"
|
||||
notContains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e < cast(1.0 as double);"
|
||||
contains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a >= cast(1.0 as double) and b >= cast(1.0 as double) and c >= cast(1.0 as double) and d >= cast(1.0 as double);"
|
||||
notContains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e >= cast(1.0 as double);"
|
||||
contains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a <= cast(1.0 as double) and b <= cast(1.0 as double) and c <= cast(1.0 as double) and d <= cast(1.0 as double);"
|
||||
notContains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e <= cast(1.0 as double);"
|
||||
contains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a = cast(1.1 as double) and b = cast(1.1 as double) and c = cast(1.1 as double) and d = cast(1.1 as double);"
|
||||
contains "a[#0] IS NULL"
|
||||
contains "b[#1] IS NULL"
|
||||
contains "c[#2] IS NULL"
|
||||
contains "d[#3] IS NULL"
|
||||
contains "AND NULL"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e = cast(1.1 as double);"
|
||||
contains "CAST(e[#4] AS DOUBLE) = 1.1"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a > cast(1.1 as double) and b > cast(1.1 as double) and c > cast(1.1 as double) and d > cast(1.1 as double);"
|
||||
contains "a[#0] > 1"
|
||||
contains "b[#1] > 1"
|
||||
contains "c[#2] > 1"
|
||||
contains "d[#3] > 1"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e > cast(1.1 as double);"
|
||||
contains "CAST(e[#4] AS DOUBLE) > 1.1"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a < cast(1.1 as double) and b < cast(1.1 as double) and c < cast(1.1 as double) and d < cast(1.1 as double);"
|
||||
contains "a[#0] < 2"
|
||||
contains "b[#1] < 2"
|
||||
contains "c[#2] < 2"
|
||||
contains "d[#3] < 2"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e < cast(1.1 as double);"
|
||||
contains "CAST(e[#4] AS DOUBLE) < 1.1"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a >= cast(1.1 as double) and b >= cast(1.1 as double) and c >= cast(1.1 as double) and d >= cast(1.1 as double);"
|
||||
contains "a[#0] >= 2"
|
||||
contains "b[#1] >= 2"
|
||||
contains "c[#2] >= 2"
|
||||
contains "d[#3] >= 2"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e >= cast(1.1 as double);"
|
||||
contains "CAST(e[#4] AS DOUBLE) >= 1.1"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a <= cast(1.1 as double) and b <= cast(1.1 as double) and c <= cast(1.1 as double) and d <= cast(1.1 as double);"
|
||||
contains "a[#0] <= 1"
|
||||
contains "b[#1] <= 1"
|
||||
contains "c[#2] <= 1"
|
||||
contains "d[#3] <= 1"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e <= cast(1.1 as double);"
|
||||
contains "CAST(e[#4] AS DOUBLE) <= 1.1"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a = 1.0 and b = 1.0 and c = 1.0 and d = 1.0;"
|
||||
notContains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e = 1.0;"
|
||||
contains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a > 1.0 and b > 1.0 and c > 1.0 and d > 1.0;"
|
||||
notContains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e > 1.0;"
|
||||
contains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a < 1.0 and b < 1.0 and c < 1.0 and d < 1.0;"
|
||||
notContains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e < 1.0;"
|
||||
contains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a >= 1.0 and b >= 1.0 and c >= 1.0 and d >= 1.0;"
|
||||
notContains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e >= 1.0;"
|
||||
contains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a <= 1.0 and b <= 1.0 and c <= 1.0 and d <= 1.0;"
|
||||
notContains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e <= 1.0;"
|
||||
contains "CAST"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a = 1.1 and b = 1.1 and c = 1.1 and d = 1.1;"
|
||||
contains "a[#0] IS NULL"
|
||||
contains "b[#1] IS NULL"
|
||||
contains "c[#2] IS NULL"
|
||||
contains "d[#3] IS NULL"
|
||||
contains "AND NULL"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e = 1.1;"
|
||||
contains "CAST(e[#4] AS DOUBLE) = 1.1"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a > 1.1 and b > 1.1 and c > 1.1 and d > 1.1;"
|
||||
contains "a[#0] > 1"
|
||||
contains "b[#1] > 1"
|
||||
contains "c[#2] > 1"
|
||||
contains "d[#3] > 1"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e > 1.1;"
|
||||
contains "CAST(e[#4] AS DOUBLE) > 1.1"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a < 1.1 and b < 1.1 and c < 1.1 and d < 1.1;"
|
||||
contains "a[#0] < 2"
|
||||
contains "b[#1] < 2"
|
||||
contains "c[#2] < 2"
|
||||
contains "d[#3] < 2"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e < 1.1;"
|
||||
contains "CAST(e[#4] AS DOUBLE) < 1.1"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a >= 1.1 and b >= 1.1 and c >= 1.1 and d >= 1.1;"
|
||||
contains "a[#0] >= 2"
|
||||
contains "b[#1] >= 2"
|
||||
contains "c[#2] >= 2"
|
||||
contains "d[#3] >= 2"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e >= 1.1;"
|
||||
contains "CAST(e[#4] AS DOUBLE) >= 1.1"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where a <= 1.1 and b <= 1.1 and c <= 1.1 and d <= 1.1;"
|
||||
contains "a[#0] <= 1"
|
||||
contains "b[#1] <= 1"
|
||||
contains "c[#2] <= 1"
|
||||
contains "d[#3] <= 1"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql "verbose select * from simple_test_table_t where e <= 1.1;"
|
||||
contains "CAST(e[#4] AS DOUBLE) <= 1.1"
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user