From 0b90e37227b01edd56edcb24e69398f89b3065b6 Mon Sep 17 00:00:00 2001 From: wangqt Date: Fri, 24 May 2024 14:26:52 +0800 Subject: [PATCH] [fix](Nereids) string literal coercion of in predicate (#35337) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit pick from master #35200 Description: The sql execute much slow when the literal value with string format in `in predicate`; and the real data is integral type。 ``` mysql> set enable_nereids_planner = false; Query OK, 0 rows affected (0.03 sec) mysql> select id,sum(clicks) from a_table where id in ('787934713', '306960695') group by id limit 10; +------------+---------------+ | id | sum(`clicks`) | +------------+---------------+ | 787934713 | 2838 | | 306960695 | 339 | +------------+---------------+ 2 rows in set (1.81 sec) mysql> set enable_nereids_planner = true; Query OK, 0 rows affected (0.02 sec) mysql> select id,sum(clicks) from a_table where id in ('787934713', '306960695') group by id limit 10; +------------+-------------+ | id | sum(clicks) | +------------+-------------+ | 787934713 | 2838 | | 306960695 | 339 | +------------+-------------+ 2 rows in set (28.14 sec) ``` Reason: In legacy planner, the string literal with convert to integral value, but in the nereids planner do not do this convert and with do string matching in BE。 Solved: do process string literal with numeric in `in predicate` like in `comparison predicate`; test table: ``` create table a_table( k1 BIGINT NOT NULL, k2 VARCHAR(100) NOT NULL, v1 INT SUM NULL DEFAULT "0" ) ENGINE=OLAP AGGREGATE KEY(k1,k2) distributed BY hash(k1) buckets 2 properties("replication_num" = "1"); insert into a_table values (10, 'name1', 10),(20, 'name2', 10); explain plan select * from a_table where k1 in ('10', '20001'); ``` before optimize: ``` +--------------------------------------------------------------------------------------------------------------------------------------+ | Explain String(Nereids Planner) | +--------------------------------------------------------------------------------------------------------------------------------------+ | ========== PARSED PLAN (time: 1ms) ========== | | UnboundResultSink[4] ( ) | | +--LogicalProject[3] ( distinct=false, projects=[*], excepts=[] ) | | +--LogicalFilter[2] ( predicates='k1 IN ('10001', '20001') ) | | +--LogicalCheckPolicy ( ) | | +--UnboundRelation ( id=RelationId#0, nameParts=a_table ) | | | | ========== ANALYZED PLAN (time: 2ms) ========== | | LogicalResultSink[15] ( outputExprs=[k1#0, k2#1, v1#2] ) | | +--LogicalProject[13] ( distinct=false, projects=[k1#0, k2#1, v1#2], excepts=[] ) | | +--LogicalFilter[11] ( predicates=cast(k1#0 as TEXT) IN ('10001', '20001') ) | | +--LogicalOlapScan ( qualified=internal.db.a_table, indexName=, selectedIndexId=12003, preAgg=UNSET ) | | | | ========== REWRITTEN PLAN (time: 6ms) ========== | | LogicalResultSink[45] ( outputExprs=[k1#0, k2#1, v1#2] ) | | +--LogicalFilter[43] ( predicates=cast(k1#0 as TEXT) IN ('10001', '20001') ) | | +--LogicalOlapScan ( qualified=internal.db.a_table, indexName=a_table, selectedIndexId=12003, preAgg=OFF, No aggregate on scan. ) | | | | ========== OPTIMIZED PLAN (time: 6ms) ========== | | PhysicalResultSink[90] ( outputExprs=[k1#0, k2#1, v1#2] ) | | +--PhysicalDistribute[87]@1 ( stats=0.33, distributionSpec=DistributionSpecGather ) | | +--PhysicalFilter[84]@1 ( stats=0.33, predicates=cast(k1#0 as TEXT) IN ('10001', '20001') ) | | +--PhysicalOlapScan[a_table]@0 ( stats=1 ) | +--------------------------------------------------------------------------------------------------------------------------------------+ ``` after optimize: ``` +--------------------------------------------------------------------------------------------------------------------------------------+ | Explain String(Nereids Planner) | +--------------------------------------------------------------------------------------------------------------------------------------+ | ========== PARSED PLAN (time: 15ms) ========== | | UnboundResultSink[4] ( ) | | +--LogicalProject[3] ( distinct=false, projects=[*], excepts=[] ) | | +--LogicalFilter[2] ( predicates='k1 IN ('10001', '20001') ) | | +--LogicalCheckPolicy ( ) | | +--UnboundRelation ( id=RelationId#0, nameParts=a_table ) | | | | ========== ANALYZED PLAN (time: 11ms) ========== | | LogicalResultSink[15] ( outputExprs=[k1#0, k2#1, v1#2] ) | | +--LogicalProject[13] ( distinct=false, projects=[k1#0, k2#1, v1#2], excepts=[] ) | | +--LogicalFilter[11] ( predicates=k1#0 IN (10001, 20001) ) | | +--LogicalOlapScan ( qualified=internal.db.a_table, indexName=, selectedIndexId=12003, preAgg=UNSET ) | | | | ========== REWRITTEN PLAN (time: 12ms) ========== | | LogicalResultSink[45] ( outputExprs=[k1#0, k2#1, v1#2] ) | | +--LogicalFilter[43] ( predicates=k1#0 IN (10001, 20001) ) | | +--LogicalOlapScan ( qualified=internal.db.a_table, indexName=a_table, selectedIndexId=12003, preAgg=OFF, No aggregate on scan. ) | | | | ========== OPTIMIZED PLAN (time: 4ms) ========== | | PhysicalResultSink[90] ( outputExprs=[k1#0, k2#1, v1#2] ) | | +--PhysicalDistribute[87]@1 ( stats=0, distributionSpec=DistributionSpecGather ) | | +--PhysicalFilter[84]@1 ( stats=0, predicates=k1#0 IN (10001, 20001) ) | | +--PhysicalOlapScan[a_table]@0 ( stats=2 ) | +--------------------------------------------------------------------------------------------------------------------------------------+ ``` --- .../doris/nereids/util/TypeCoercionUtils.java | 34 +++++++++++++++---- .../nereids/util/TypeCoercionUtilsTest.java | 23 +++++++++++++ 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index 214fbc1804..e6dcca83a8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -118,7 +118,9 @@ import org.apache.logging.log4j.Logger; import java.math.BigDecimal; import java.math.BigInteger; +import java.util.ArrayList; import java.util.List; +import java.util.ListIterator; import java.util.Map; import java.util.Optional; import java.util.function.Function; @@ -975,8 +977,28 @@ public class TypeCoercionUtils { } return inPredicate; } + // process string literal with numeric + boolean hitString = false; + List newOptions = new ArrayList<>(inPredicate.getOptions()); + if (!(inPredicate.getCompareExpr().getDataType().isStringLikeType())) { + ListIterator iterator = newOptions.listIterator(); + while (iterator.hasNext()) { + Expression origOption = iterator.next(); + if (origOption instanceof Literal && ((Literal) origOption).isStringLikeLiteral()) { + Optional option = TypeCoercionUtils.characterLiteralTypeCoercion( + ((Literal) origOption).getStringValue(), inPredicate.getCompareExpr().getDataType()); + if (option.isPresent()) { + iterator.set(option.get()); + hitString = true; + } + } + } + } + final InPredicate fmtInPredicate = + hitString ? new InPredicate(inPredicate.getCompareExpr(), newOptions) : inPredicate; + Optional optionalCommonType = TypeCoercionUtils.findWiderCommonTypeForComparison( - inPredicate.children() + fmtInPredicate.children() .stream() .map(Expression::getDataType).collect(Collectors.toList()), true); @@ -999,18 +1021,18 @@ public class TypeCoercionUtils { if (optionalCommonType.isPresent()) { optionalCommonType = Optional.of(downgradeDecimalAndDateLikeType( optionalCommonType.get(), - inPredicate.getCompareExpr(), - inPredicate.getOptions().toArray(new Expression[0]))); + fmtInPredicate.getCompareExpr(), + fmtInPredicate.getOptions().toArray(new Expression[0]))); } return optionalCommonType .map(commonType -> { - List newChildren = inPredicate.children().stream() + List newChildren = fmtInPredicate.children().stream() .map(e -> TypeCoercionUtils.castIfNotSameType(e, commonType)) .collect(Collectors.toList()); - return inPredicate.withChildren(newChildren); + return fmtInPredicate.withChildren(newChildren); }) - .orElse(inPredicate); + .orElse(fmtInPredicate); } /** diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java index be465d9c37..8d32dfbb2a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java @@ -802,4 +802,27 @@ public class TypeCoercionUtilsTest { datetimeDowngrade = (EqualTo) TypeCoercionUtils.processComparisonPredicate(datetimeDowngrade); Assertions.assertEquals(DateTimeType.INSTANCE, datetimeDowngrade.left().getDataType()); } + + @Test + public void testProcessInStringCoercion() { + // BigInt slot vs String literal + InPredicate bigintString = new InPredicate( + new SlotReference("c1", BigIntType.INSTANCE), + ImmutableList.of( + new VarcharLiteral("200"), + new VarcharLiteral("922337203685477001"))); + bigintString = (InPredicate) TypeCoercionUtils.processInPredicate(bigintString); + Assertions.assertEquals(BigIntType.INSTANCE, bigintString.getCompareExpr().getDataType()); + Assertions.assertEquals(BigIntType.INSTANCE, bigintString.getOptions().get(0).getDataType()); + + // SmallInt slot vs String literal + InPredicate smallIntString = new InPredicate( + new SlotReference("c1", SmallIntType.INSTANCE), + ImmutableList.of( + new DecimalLiteral(new BigDecimal("987654.321")), + new VarcharLiteral("922337203685477001"))); + smallIntString = (InPredicate) TypeCoercionUtils.processInPredicate(smallIntString); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(23, 3), smallIntString.getCompareExpr().getDataType()); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(23, 3), smallIntString.getOptions().get(0).getDataType()); + } }