diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java index e8c1c82b12..6bb77fd379 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java @@ -305,7 +305,7 @@ public class SlotRef extends Expr { if ((col == null) != (other.col == null)) { return false; } - if (col != null && !col.toLowerCase().equals(other.col.toLowerCase())) { + if (col != null && !col.equalsIgnoreCase(other.col)) { return false; } return true; diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/InferFiltersRule.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/InferFiltersRule.java index e0c4e4d628..50d123f99c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/rewrite/InferFiltersRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/InferFiltersRule.java @@ -43,6 +43,8 @@ import java.util.Set; /** * The function of this rule is to derive a new predicate based on the current predicate. + * + *
  * eg.
  * t1.id = t2.id and t2.id = t3.id and t3.id = 100;
  * -->
@@ -51,8 +53,9 @@ import java.util.Set;
  * 1. Register a new rule InferFiltersRule and add it to GlobalState.
  * 2. Traverse Conjunct to construct on/where equivalence connection, numerical connection and isNullPredicate.
  * 3. Use Warshall to infer all equivalence connections.
- *    details:https://en.wikipedia.org/wiki/Floyd%E2%80%93Warshall_algorithm
+ *    details: https://en.wikipedia.org/wiki/Floyd%E2%80%93Warshall_algorithm
  * 4. Construct additional numerical connections and isNullPredicate.
+ * 
*/ public class InferFiltersRule implements ExprRewriteRule { private static final Logger LOG = LogManager.getLogger(InferFiltersRule.class); @@ -127,10 +130,10 @@ public class InferFiltersRule implements ExprRewriteRule { if (!newExprWithState.isEmpty()) { Expr rewriteExpr = expr; - for (int index = 0; index < newExprWithState.size(); index++) { - if (newExprWithState.get(index).second) { - rewriteExpr = new CompoundPredicate(CompoundPredicate.Operator.AND, - rewriteExpr, newExprWithState.get(index).first); + for (Pair exprBooleanPair : newExprWithState) { + if (exprBooleanPair.second) { + rewriteExpr = new CompoundPredicate(CompoundPredicate.Operator.AND, rewriteExpr, + exprBooleanPair.first); } } return rewriteExpr; @@ -171,7 +174,7 @@ public class InferFiltersRule implements ExprRewriteRule { if (conjunct instanceof BinaryPredicate && conjunct.getChild(0) != null && conjunct.getChild(1) != null) { - if (conjunct.getChild(0).unwrapSlotRef() instanceof SlotRef + if (conjunct.getChild(0).unwrapSlotRef() != null && conjunct.getChild(1) instanceof LiteralExpr) { Pair pair = new Pair<>(conjunct.getChild(0).unwrapSlotRef(), conjunct.getChild(1)); if (!slotToLiteralDeDuplication.contains(pair)) { @@ -184,8 +187,8 @@ public class InferFiltersRule implements ExprRewriteRule { analyzer.registerGlobalSlotToLiteralDeDuplication(pair); } } else if (((BinaryPredicate) conjunct).getOp().isEquivalence() - && conjunct.getChild(0).unwrapSlotRef() instanceof SlotRef - && conjunct.getChild(1).unwrapSlotRef() instanceof SlotRef) { + && conjunct.getChild(0).unwrapSlotRef() != null + && conjunct.getChild(1).unwrapSlotRef() != null) { Pair pair = new Pair<>(conjunct.getChild(0).unwrapSlotRef(), conjunct.getChild(1).unwrapSlotRef()); Pair eqPair = new Pair<>(conjunct.getChild(1).unwrapSlotRef(), @@ -202,7 +205,7 @@ public class InferFiltersRule implements ExprRewriteRule { } } else if (conjunct instanceof IsNullPredicate && conjunct.getChild(0) != null - && conjunct.getChild(0).unwrapSlotRef() instanceof SlotRef) { + && conjunct.getChild(0).unwrapSlotRef() != null) { if (!isNullDeDuplication.contains(conjunct.getChild(0).unwrapSlotRef()) && ((IsNullPredicate) conjunct).isNotNull()) { isNullDeDuplication.add(conjunct.getChild(0).unwrapSlotRef()); @@ -214,7 +217,7 @@ public class InferFiltersRule implements ExprRewriteRule { } } else if (conjunct instanceof InPredicate && conjunct.getChild(0) != null - && conjunct.getChild(0).unwrapSlotRef() instanceof SlotRef) { + && conjunct.getChild(0).unwrapSlotRef() != null) { if (!inDeDuplication.contains(conjunct.getChild(0).unwrapSlotRef())) { inDeDuplication.add(conjunct.getChild(0).unwrapSlotRef()); inExpr.add(conjunct); @@ -234,10 +237,10 @@ public class InferFiltersRule implements ExprRewriteRule { * old expr:t1.id = t2.id and t2.id = t3.id and t3.id = t4.id * new expr:t1.id = t2.id and t2.id = t3.id and t3.id = t4.id and t1.id = t3.id and t1.id = t4.id and t2.id = t4.id * - * @param slotEqSlotExpr - * @param slotEqSlotDeDuplication + * @param slotEqSlotExpr slot to slot exprs + * @param slotEqSlotDeDuplication set pairs in slot = slot exprs * @param exprToWarshallArraySubscript: A Map the key is Expr, the value is int - * @param warshallArraySubscriptToExpr: A Map the key is int, the value is exper + * @param warshallArraySubscriptToExpr: A Map the key is int, the value is expr */ private void genNewSlotEqSlotPredicate(List slotEqSlotExpr, Set> slotEqSlotDeDuplication, @@ -268,9 +271,9 @@ public class InferFiltersRule implements ExprRewriteRule { * * @param warshall: Two-dimensional array * @param arrayMaxSize: slotEqSlotExpr.size() * 2 - * @param slotEqSlotExpr - * @param exprToWarshallArraySubscript - * @param warshallArraySubscriptToExpr + * @param slotEqSlotExpr slot to slot exprs + * @param exprToWarshallArraySubscript expr to offset in Warshall array + * @param warshallArraySubscriptToExpr offset in Warshall array to expr * @return needGenWarshallArray. True:needGen; False:don't needGen */ private boolean initWarshallArray(int[][] warshall, @@ -281,8 +284,8 @@ public class InferFiltersRule implements ExprRewriteRule { boolean needGenWarshallArray = false; int index = 0; for (Expr slotEqSlot : slotEqSlotExpr) { - int row = 0; - int column = 0; + int row; + int column; if (!exprToWarshallArraySubscript.containsKey(slotEqSlot.getChild(0))) { exprToWarshallArraySubscript.put(slotEqSlot.getChild(0), index); warshallArraySubscriptToExpr.put(index, slotEqSlot.getChild(0)); @@ -303,7 +306,7 @@ public class InferFiltersRule implements ExprRewriteRule { if (row >= arrayMaxSize || column >= arrayMaxSize) { - LOG.debug("Error row or column", row, column, arrayMaxSize); + LOG.debug("Error row {} or column {}, but max size is {}.", row, column, arrayMaxSize); needGenWarshallArray = false; break; } else { @@ -391,18 +394,18 @@ public class InferFiltersRule implements ExprRewriteRule { Analyzer analyzer, ExprRewriter.ClauseType clauseType) { SlotRef checkSlot = slotToLiteral.getChild(0).unwrapSlotRef(); - if (checkSlot instanceof SlotRef) { + if (checkSlot != null) { for (Expr conjunct : slotEqSlotExpr) { SlotRef leftSlot = conjunct.getChild(0).unwrapSlotRef(); SlotRef rightSlot = conjunct.getChild(1).unwrapSlotRef(); - if (leftSlot instanceof SlotRef && rightSlot instanceof SlotRef) { - if (checkSlot.notCheckDescIdEquals(leftSlot)) { + if (leftSlot != null && rightSlot != null) { + if (checkSlot.equals(leftSlot)) { addNewBinaryPredicate(genNewBinaryPredicate(slotToLiteral, rightSlot), slotToLiteralDeDuplication, newExprWithState, isNeedInfer(rightSlot, leftSlot, analyzer, clauseType), analyzer, clauseType); - } else if (checkSlot.notCheckDescIdEquals(rightSlot)) { + } else if (checkSlot.equals(rightSlot)) { addNewBinaryPredicate(genNewBinaryPredicate(slotToLiteral, leftSlot), slotToLiteralDeDuplication, newExprWithState, isNeedInfer(leftSlot, rightSlot, analyzer, clauseType), @@ -426,17 +429,15 @@ public class InferFiltersRule implements ExprRewriteRule { boolean ret = false; TupleId newTid = newSlot.getDesc().getParent().getRef().getId(); TupleId checkTid = checkSlot.getDesc().getParent().getRef().getId(); - boolean needChange = false; Pair tids = new Pair<>(newTid, checkTid); if (analyzer.isContainTupleIds(tids)) { JoinOperator joinOperator = analyzer.getAnyTwoTablesJoinOp(tids); - ret = checkNeedInfer(joinOperator, needChange, clauseType); + ret = checkNeedInfer(joinOperator, false, clauseType); } else { Pair changeTids = new Pair<>(checkTid, newTid); if (analyzer.isContainTupleIds(changeTids)) { - needChange = true; JoinOperator joinOperator = analyzer.getAnyTwoTablesJoinOp(changeTids); - ret = checkNeedInfer(joinOperator, needChange, clauseType); + ret = checkNeedInfer(joinOperator, true, clauseType); } } return ret; @@ -474,8 +475,7 @@ public class InferFiltersRule implements ExprRewriteRule { private Expr genNewBinaryPredicate(Expr oldExpr, Expr newSlot) { if (oldExpr instanceof BinaryPredicate) { BinaryPredicate oldBP = (BinaryPredicate) oldExpr; - BinaryPredicate newBP = new BinaryPredicate(oldBP.getOp(), newSlot, oldBP.getChild(1)); - return newBP; + return new BinaryPredicate(oldBP.getOp(), newSlot, oldBP.getChild(1)); } return oldExpr; } @@ -534,16 +534,16 @@ public class InferFiltersRule implements ExprRewriteRule { if (expr instanceof IsNullPredicate) { IsNullPredicate isNullPredicate = (IsNullPredicate) expr; SlotRef checkSlot = isNullPredicate.getChild(0).unwrapSlotRef(); - if (checkSlot instanceof SlotRef) { + if (checkSlot != null) { for (Expr conjunct : slotEqSlotExpr) { SlotRef leftSlot = conjunct.getChild(0).unwrapSlotRef(); SlotRef rightSlot = conjunct.getChild(1).unwrapSlotRef(); - if (leftSlot instanceof SlotRef && rightSlot instanceof SlotRef) { - if (checkSlot.notCheckDescIdEquals(leftSlot) && isNullPredicate.isNotNull()) { + if (leftSlot != null && rightSlot != null) { + if (checkSlot.equals(leftSlot) && isNullPredicate.isNotNull()) { addNewIsNotNullPredicate(genNewIsNotNullPredicate(isNullPredicate, rightSlot), isNullDeDuplication, newExprWithState, analyzer, clauseType); - } else if (checkSlot.notCheckDescIdEquals(rightSlot)) { + } else if (checkSlot.equals(rightSlot)) { addNewIsNotNullPredicate(genNewIsNotNullPredicate(isNullPredicate, leftSlot), isNullDeDuplication, newExprWithState, analyzer, clauseType); } @@ -558,11 +558,7 @@ public class InferFiltersRule implements ExprRewriteRule { * @return new IsNullPredicate. */ private Expr genNewIsNotNullPredicate(IsNullPredicate oldExpr, Expr newSlot) { - if (oldExpr instanceof IsNullPredicate) { - IsNullPredicate newExpr = new IsNullPredicate(newSlot, oldExpr.isNotNull()); - return newExpr; - } - return oldExpr; + return oldExpr != null ? new IsNullPredicate(newSlot, oldExpr.isNotNull()) : null; } /** @@ -614,18 +610,18 @@ public class InferFiltersRule implements ExprRewriteRule { if (inExpr instanceof InPredicate) { InPredicate inpredicate = (InPredicate) inExpr; SlotRef checkSlot = inpredicate.getChild(0).unwrapSlotRef(); - if (checkSlot instanceof SlotRef) { + if (checkSlot != null) { for (Expr conjunct : slotEqSlotExpr) { SlotRef leftSlot = conjunct.getChild(0).unwrapSlotRef(); SlotRef rightSlot = conjunct.getChild(1).unwrapSlotRef(); - if (leftSlot instanceof SlotRef && rightSlot instanceof SlotRef) { - if (checkSlot.notCheckDescIdEquals(leftSlot)) { + if (leftSlot != null && rightSlot != null) { + if (checkSlot.equals(leftSlot)) { addNewInPredicate(genNewInPredicate(inpredicate, rightSlot), inDeDuplication, newExprWithState, isNeedInfer(rightSlot, leftSlot, analyzer, clauseType), analyzer, clauseType); - } else if (checkSlot.notCheckDescIdEquals(rightSlot)) { + } else if (checkSlot.equals(rightSlot)) { addNewInPredicate(genNewInPredicate(inpredicate, leftSlot), inDeDuplication, newExprWithState, isNeedInfer(leftSlot, rightSlot, analyzer, clauseType), @@ -644,8 +640,7 @@ public class InferFiltersRule implements ExprRewriteRule { private Expr genNewInPredicate(Expr oldExpr, Expr newSlot) { if (oldExpr instanceof InPredicate) { InPredicate oldBP = (InPredicate) oldExpr; - InPredicate newBP = new InPredicate(newSlot, oldBP.getListChildren(), oldBP.isNotIn()); - return newBP; + return new InPredicate(newSlot, oldBP.getListChildren(), oldBP.isNotIn()); } return oldExpr; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/rewrite/InferFiltersRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/rewrite/InferFiltersRuleTest.java index 92367c4b2d..d8889d8d7a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/rewrite/InferFiltersRuleTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/rewrite/InferFiltersRuleTest.java @@ -30,8 +30,8 @@ import org.junit.Test; import java.util.UUID; public class InferFiltersRuleTest { - private static String baseDir = "fe"; - private static String runningDir = baseDir + "/mocked/InferFiltersRuleTest/" + private static final String baseDir = "fe"; + private static final String runningDir = baseDir + "/mocked/InferFiltersRuleTest/" + UUID.randomUUID() + "/"; private static DorisAssert dorisAssert; private static final String DB_NAME = "db1"; @@ -55,7 +55,8 @@ public class InferFiltersRuleTest { + "distributed by hash(k1) buckets 3 properties('replication_num' = '1');"; dorisAssert.withTable(createTableSQL); createTableSQL = "create table " + DB_NAME + "." + TABLE_NAME_3 - + " (k1 tinyint, k2 smallint, k3 int, k4 bigint, k5 largeint, k6 date, k7 datetime, k8 float, k9 double) " + + " (k1 tinyint, k2 smallint, k3 int, k4 bigint," + + " k5 largeint, k6 date, k7 datetime, k8 float, k9 double) " + "distributed by hash(k1) buckets 3 properties('replication_num' = '1');"; dorisAssert.withTable(createTableSQL); } @@ -101,7 +102,8 @@ public class InferFiltersRuleTest { SessionVariable sessionVariable = dorisAssert.getSessionVariable(); sessionVariable.setEnableInferPredicate(true); Assert.assertTrue(sessionVariable.isEnableInferPredicate()); - String query = "select * from tb1 inner join tb2 inner join tb3 where tb1.k1 = tb2.k1 and tb2.k1 = tb3.k1 and tb3.k1 = 1"; + String query = "select * from tb1 inner join tb2 inner join tb3" + + " where tb1.k1 = tb2.k1 and tb2.k1 = tb3.k1 and tb3.k1 = 1"; String planString = dorisAssert.query(query).explainQuery(); Assert.assertTrue(planString.contains("`tb2`.`k1` = 1")); Assert.assertTrue(planString.contains("`tb1`.`k1` = 1")); @@ -152,7 +154,8 @@ public class InferFiltersRuleTest { SessionVariable sessionVariable = dorisAssert.getSessionVariable(); sessionVariable.setEnableInferPredicate(true); Assert.assertTrue(sessionVariable.isEnableInferPredicate()); - String query = "select * from tb1 inner join tb2 on tb1.k1 = tb2.k1 right outer join tb3 on tb2.k1 = tb3.k1 and tb2.k1 = 1"; + String query = "select * from tb1 inner join tb2 on tb1.k1 = tb2.k1" + + " right outer join tb3 on tb2.k1 = tb3.k1 and tb2.k1 = 1"; String planString = dorisAssert.query(query).explainQuery(); Assert.assertTrue(planString.contains("`tb1`.`k1` = 1")); Assert.assertFalse(planString.contains("`tb3`.`k1` = 1")); @@ -163,7 +166,8 @@ public class InferFiltersRuleTest { SessionVariable sessionVariable = dorisAssert.getSessionVariable(); sessionVariable.setEnableInferPredicate(true); Assert.assertTrue(sessionVariable.isEnableInferPredicate()); - String query = "select * from tb1 inner join tb2 on tb1.k1 = tb2.k1 right outer join tb3 on tb2.k1 = tb3.k1 and tb3.k1 = 1"; + String query = "select * from tb1 inner join tb2 on tb1.k1 = tb2.k1" + + " right outer join tb3 on tb2.k1 = tb3.k1 and tb3.k1 = 1"; String planString = dorisAssert.query(query).explainQuery(); Assert.assertTrue(planString.contains("`tb2`.`k1` = 1")); Assert.assertTrue(planString.contains("`tb1`.`k1` = 1")); @@ -174,7 +178,8 @@ public class InferFiltersRuleTest { SessionVariable sessionVariable = dorisAssert.getSessionVariable(); sessionVariable.setEnableInferPredicate(true); Assert.assertTrue(sessionVariable.isEnableInferPredicate()); - String query = "select * from tb1 inner join tb2 on tb1.k1 = tb2.k1 right outer join tb3 on tb2.k1 = tb3.k1 where tb1.k1 = tb2.k1 and tb2.k1 = tb3.k1 and tb2.k1 = 1"; + String query = "select * from tb1 inner join tb2 on tb1.k1 = tb2.k1 right outer join tb3 on tb2.k1 = tb3.k1" + + " where tb1.k1 = tb2.k1 and tb2.k1 = tb3.k1 and tb2.k1 = 1"; String planString = dorisAssert.query(query).explainQuery(); Assert.assertTrue(planString.contains("`tb1`.`k1` = 1")); Assert.assertTrue(planString.contains("`tb3`.`k1` = 1")); @@ -217,7 +222,8 @@ public class InferFiltersRuleTest { SessionVariable sessionVariable = dorisAssert.getSessionVariable(); sessionVariable.setEnableInferPredicate(true); Assert.assertTrue(sessionVariable.isEnableInferPredicate()); - String query = "select * from tb1 inner join tb2 inner join tb3 where tb1.k1 = tb2.k1 and tb2.k1 = tb3.k1 and tb3.k1 = 1"; + String query = "select * from tb1 inner join tb2 inner join tb3" + + " where tb1.k1 = tb2.k1 and tb2.k1 = tb3.k1 and tb3.k1 = 1"; String planString = dorisAssert.query(query).explainQuery(); Assert.assertTrue(planString.contains("`tb2`.`k1` = 1")); Assert.assertTrue(planString.contains("`tb1`.`k1` = 1")); @@ -228,7 +234,8 @@ public class InferFiltersRuleTest { SessionVariable sessionVariable = dorisAssert.getSessionVariable(); sessionVariable.setEnableInferPredicate(true); Assert.assertTrue(sessionVariable.isEnableInferPredicate()); - String query = "select * from tb1 inner join tb2 left outer join tb3 on tb3.k1 = tb2.k1 where tb1.k1 = tb2.k1 and tb2.k1 = tb3.k1 and tb3.k1 = 1"; + String query = "select * from tb1 inner join tb2 left outer join tb3 on tb3.k1 = tb2.k1" + + " where tb1.k1 = tb2.k1 and tb2.k1 = tb3.k1 and tb3.k1 = 1"; String planString = dorisAssert.query(query).explainQuery(); Assert.assertTrue(planString.contains("`tb2`.`k1` = 1")); Assert.assertTrue(planString.contains("`tb1`.`k1` = 1")); @@ -239,7 +246,8 @@ public class InferFiltersRuleTest { SessionVariable sessionVariable = dorisAssert.getSessionVariable(); sessionVariable.setEnableInferPredicate(true); Assert.assertTrue(sessionVariable.isEnableInferPredicate()); - String query = "select * from tb1 inner join tb2 left outer join tb3 on tb3.k1 = tb2.k1 where tb1.k1 = tb2.k1 and tb2.k1 = tb3.k1 and tb2.k1 = 1"; + String query = "select * from tb1 inner join tb2 left outer join tb3 on tb3.k1 = tb2.k1" + + " where tb1.k1 = tb2.k1 and tb2.k1 = tb3.k1 and tb2.k1 = 1"; String planString = dorisAssert.query(query).explainQuery(); Assert.assertTrue(planString.contains("`tb1`.`k1` = 1")); Assert.assertFalse(planString.contains("`tb3`.`k1` = 1")); @@ -250,7 +258,8 @@ public class InferFiltersRuleTest { SessionVariable sessionVariable = dorisAssert.getSessionVariable(); sessionVariable.setEnableInferPredicate(true); Assert.assertTrue(sessionVariable.isEnableInferPredicate()); - String query = "select * from tb1 inner join tb2 on tb1.k1 = tb2.k1 right outer join tb3 on tb2.k1 = tb3.k1 where tb1.k1 = tb2.k1 and tb2.k1 = tb3.k1 and tb2.k1 = 1"; + String query = "select * from tb1 inner join tb2 on tb1.k1 = tb2.k1 right outer join tb3 on tb2.k1 = tb3.k1" + + " where tb1.k1 = tb2.k1 and tb2.k1 = tb3.k1 and tb2.k1 = 1"; String planString = dorisAssert.query(query).explainQuery(); Assert.assertTrue(planString.contains("`tb1`.`k1` = 1")); Assert.assertTrue(planString.contains("`tb3`.`k1` = 1")); @@ -261,7 +270,8 @@ public class InferFiltersRuleTest { SessionVariable sessionVariable = dorisAssert.getSessionVariable(); sessionVariable.setEnableInferPredicate(true); Assert.assertTrue(sessionVariable.isEnableInferPredicate()); - String query = "select * from tb1 inner join tb2 on tb1.k1 = tb2.k1 right outer join tb3 on tb2.k1 = tb3.k1 where tb1.k1 = tb2.k1 and tb2.k1 = tb3.k1 and tb3.k1 = 1"; + String query = "select * from tb1 inner join tb2 on tb1.k1 = tb2.k1 right outer join tb3 on tb2.k1 = tb3.k1" + + " where tb1.k1 = tb2.k1 and tb2.k1 = tb3.k1 and tb3.k1 = 1"; String planString = dorisAssert.query(query).explainQuery(); Assert.assertFalse(planString.contains("`tb2`.`k1` = 1")); Assert.assertFalse(planString.contains("`tb1`.`k1` = 1")); @@ -272,7 +282,8 @@ public class InferFiltersRuleTest { SessionVariable sessionVariable = dorisAssert.getSessionVariable(); sessionVariable.setEnableInferPredicate(true); Assert.assertTrue(sessionVariable.isEnableInferPredicate()); - String query = "select * from tb1 inner join tb2 inner join tb3 where tb1.k1 = tb3.k1 and tb2.k1 = tb3.k1 and tb1.k1 is not null"; + String query = "select * from tb1 inner join tb2 inner join tb3" + + " where tb1.k1 = tb3.k1 and tb2.k1 = tb3.k1 and tb1.k1 is not null"; String planString = dorisAssert.query(query).explainQuery(); Assert.assertTrue(planString.contains("`tb3`.`k1` IS NOT NULL")); Assert.assertTrue(planString.contains("`tb2`.`k1` IS NOT NULL")); @@ -338,4 +349,37 @@ public class InferFiltersRuleTest { String planString = dorisAssert.query(query).explainQuery(); Assert.assertTrue(planString.contains("`tb2`.`k1` = 1")); } + + @Test + public void testSameAliasWithSlotEqualToLiteralInDifferentUnionChildren() throws Exception { + SessionVariable sessionVariable = dorisAssert.getSessionVariable(); + sessionVariable.setEnableInferPredicate(true); + Assert.assertTrue(sessionVariable.isEnableInferPredicate()); + String query = "select * from tb1 inner join tb2 on tb1.k1 = tb2.k1" + + " union select * from tb1 inner join tb2 on tb1.k2 = tb2.k2 where tb1.k1 = 3"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertFalse(planString.contains("`tb2`.`k1` = 3")); + } + + @Test + public void testSameAliasWithSlotInPredicateInDifferentUnionChildren() throws Exception { + SessionVariable sessionVariable = dorisAssert.getSessionVariable(); + sessionVariable.setEnableInferPredicate(true); + Assert.assertTrue(sessionVariable.isEnableInferPredicate()); + String query = "select * from tb1 inner join tb2 on tb1.k1 = tb2.k1" + + " union select * from tb1 inner join tb2 on tb1.k2 = tb2.k2 where tb1.k1 in (3, 4, 5)"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertFalse(planString.contains("`tb2`.`k1` IN (3, 4, 5)")); + } + + @Test + public void testSameAliasWithSlotIsNullInDifferentUnionChildren() throws Exception { + SessionVariable sessionVariable = dorisAssert.getSessionVariable(); + sessionVariable.setEnableInferPredicate(true); + Assert.assertTrue(sessionVariable.isEnableInferPredicate()); + String query = "select * from tb1 inner join tb2 on tb1.k1 = tb2.k1" + + " union select * from tb1 inner join tb2 on tb1.k2 = tb2.k2 where tb1.k1 is not null"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertFalse(planString.contains("`tb2`.`k1` IS NOT NULL")); + } }