validForInfer(Expression expression, InferType inferType) {
if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
return Optional.empty();
}
if (expression instanceof SlotReference || expression.isConstant()) {
return Optional.of(expression);
}
+ if (!(expression instanceof Cast)) {
+ return Optional.empty();
+ }
+ Cast cast = (Cast) expression;
+ Expression child = cast.child();
+ DataType dataType = cast.getDataType();
+ DataType childType = child.getDataType();
if (inferType == InferType.INTEGRAL) {
- if (expression instanceof Cast) {
- // avoid cast from wider type to narrower type, such as cast(int as smallint)
- // IntegralType dataType = (IntegralType) expression.getDataType();
- // DataType childType = ((Cast) expression).child().getDataType();
- // if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
- // return validForInfer(((Cast) expression).child(), inferType);
- // }
- return validForInfer(((Cast) expression).child(), inferType);
- }
+ // avoid cast from wider type to narrower type, such as cast(int as smallint)
+ // IntegralType dataType = (IntegralType) expression.getDataType();
+ // DataType childType = ((Cast) expression).child().getDataType();
+ // if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
+ // return validForInfer(((Cast) expression).child(), inferType);
+ // }
+ return validForInfer(child, inferType);
} else if (inferType == InferType.DATE) {
- if (expression instanceof Cast) {
- DataType dataType = expression.getDataType();
- DataType childType = ((Cast) expression).child().getDataType();
- // avoid lost precision
- if (dataType instanceof DateType) {
- if (childType instanceof DateV2Type || childType instanceof DateType) {
- return validForInfer(((Cast) expression).child(), inferType);
- }
- } else if (dataType instanceof DateV2Type) {
- if (childType instanceof DateType || childType instanceof DateV2Type) {
- return validForInfer(((Cast) expression).child(), inferType);
- }
- } else if (dataType instanceof DateTimeType) {
- if (!(childType instanceof DateTimeV2Type)) {
- return validForInfer(((Cast) expression).child(), inferType);
- }
- } else if (dataType instanceof DateTimeV2Type) {
- return validForInfer(((Cast) expression).child(), inferType);
+ // avoid lost precision
+ if (dataType instanceof DateType) {
+ if (childType instanceof DateV2Type || childType instanceof DateType) {
+ return validForInfer(child, inferType);
}
+ } else if (dataType instanceof DateV2Type) {
+ if (childType instanceof DateType || childType instanceof DateV2Type) {
+ return validForInfer(child, inferType);
+ }
+ } else if (dataType instanceof DateTimeType) {
+ if (!(childType instanceof DateTimeV2Type)) {
+ return validForInfer(child, inferType);
+ }
+ } else if (dataType instanceof DateTimeV2Type) {
+ return validForInfer(child, inferType);
}
} else if (inferType == InferType.STRING) {
- if (expression instanceof Cast) {
- DataType dataType = expression.getDataType();
- DataType childType = ((Cast) expression).child().getDataType();
- // avoid substring cast such as cast(char(3) as char(2))
- if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
- return validForInfer(((Cast) expression).child(), inferType);
- }
+ // avoid substring cast such as cast(char(3) as char(2))
+ if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
+ return validForInfer(child, inferType);
}
- } else {
- return Optional.empty();
}
return Optional.empty();
}
- private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
+ private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
DataType leftType = comparisonPredicate.left().getDataType();
InferType inferType;
if (leftType instanceof CharacterType) {
@@ -223,25 +266,27 @@ public class PredicatePropagation {
if (!left.isPresent() || !right.isPresent()) {
inferType = InferType.NONE;
}
- return new ComparisonInferInfo(inferType, left, right, comparisonPredicate);
+ return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()),
+ right.orElse(comparisonPredicate.right()), comparisonPredicate);
}
/**
* Currently only equivalence derivation is supported
* and requires that the left and right sides of an expression must be slot
+ *
+ * TODO: NullSafeEqual
*/
- private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) {
+ private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) {
if (!(predicate instanceof EqualTo)) {
- return new ComparisonInferInfo(InferType.NONE,
- Optional.of(predicate.left()), Optional.of(predicate.right()), predicate);
+ return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate);
}
- ComparisonInferInfo info = inferInferInfo(predicate);
+ EqualInferInfo info = inferInferInfo(predicate);
if (info.inferType == InferType.NONE) {
return info;
}
- if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) {
+ if (info.left instanceof SlotReference && info.right instanceof SlotReference) {
return info;
}
- return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
+ return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
index 1a198c76ea..26e1358c2e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
@@ -47,7 +47,6 @@ import java.util.stream.Collectors;
*/
public class PullUpPredicates extends PlanVisitor, Void> {
- PredicatePropagation propagation = new PredicatePropagation();
Map> cache = new IdentityHashMap<>();
@Override
@@ -99,6 +98,7 @@ public class PullUpPredicates extends PlanVisitor, Void
public ImmutableSet visitLogicalAggregate(LogicalAggregate extends Plan> aggregate, Void context) {
return cacheOrElse(aggregate, () -> {
ImmutableSet childPredicates = aggregate.child().accept(this, context);
+ // TODO
Map expressionSlotMap = aggregate.getOutputExpressions()
.stream()
.filter(this::hasAgg)
@@ -130,7 +130,7 @@ public class PullUpPredicates extends PlanVisitor, Void
private ImmutableSet getAvailableExpressions(Collection predicates, Plan plan) {
Set expressions = Sets.newHashSet(predicates);
- expressions.addAll(propagation.infer(expressions));
+ expressions.addAll(PredicatePropagation.infer(expressions));
return expressions.stream()
.filter(p -> plan.getOutputSet().containsAll(p.getInputSlots()))
.collect(ImmutableSet.toImmutableSet());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
index 2704d44655..3e71b3b89a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
@@ -39,10 +39,6 @@ public class EqualTo extends EqualPredicate implements PropagateNullable {
super(ImmutableList.of(left, right), "=", inferred);
}
- private EqualTo(List children) {
- this(children, false);
- }
-
private EqualTo(List children, boolean inferred) {
super(children, "=", inferred);
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
index d839a1e906..c86a074dcf 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
@@ -48,6 +48,12 @@ public class InPredicate extends Expression {
this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null"));
}
+ public InPredicate(Expression compareExpr, List options, boolean inferred) {
+ super(new Builder().add(compareExpr).addAll(options).build(), inferred);
+ this.compareExpr = Objects.requireNonNull(compareExpr, "Compare Expr cannot be null");
+ this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null"));
+ }
+
public R accept(ExpressionVisitor visitor, C context) {
return visitor.visitInPredicate(this, context);
}
@@ -80,6 +86,11 @@ public class InPredicate extends Expression {
});
}
+ @Override
+ public Expression withInferred(boolean inferred) {
+ return new InPredicate(children.get(0), ImmutableList.copyOf(children).subList(1, children.size()), true);
+ }
+
@Override
public String toString() {
return compareExpr + " IN " + options.stream()
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
index c910e98fcd..0708ea3f17 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
@@ -25,7 +25,7 @@ import org.apache.doris.utframe.TestWithFeService;
import org.junit.jupiter.api.Test;
-public class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported {
+class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
@@ -77,7 +77,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest01() {
+ void inferPredicatesTest01() {
String sql = "select * from student join score on student.id = score.sid where student.id > 1";
PlanChecker.from(connectContext)
@@ -100,7 +100,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest02() {
+ void inferPredicatesTest02() {
String sql = "select * from student join score on student.id = score.sid";
PlanChecker.from(connectContext)
@@ -117,7 +117,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest03() {
+ void inferPredicatesTest03() {
String sql = "select * from student join score on student.id = score.sid where student.id in (1,2,3)";
PlanChecker.from(connectContext)
@@ -126,18 +126,17 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.matches(
logicalProject(
logicalJoin(
- logicalFilter(
- logicalOlapScan()
- ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
+ logicalFilter(logicalOlapScan()).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
- logicalOlapScan()
+ logicalFilter(logicalOlapScan()).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
+ & filter.getPredicate().toSql().contains("sid IN (1, 2, 3)"))
)
)
);
}
@Test
- public void inferPredicatesTest04() {
+ void inferPredicatesTest04() {
String sql = "select * from student join score on student.id = score.sid and student.id in (1,2,3)";
PlanChecker.from(connectContext)
@@ -146,18 +145,17 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.matches(
logicalProject(
logicalJoin(
- logicalFilter(
- logicalOlapScan()
- ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
+ logicalFilter(logicalOlapScan()).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
- logicalOlapScan()
+ logicalFilter(logicalOlapScan()).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
+ & filter.getPredicate().toSql().contains("sid IN (1, 2, 3)"))
)
)
);
}
@Test
- public void inferPredicatesTest05() {
+ void inferPredicatesTest05() {
String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id where student.id > 1";
PlanChecker.from(connectContext)
@@ -185,7 +183,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest06() {
+ void inferPredicatesTest06() {
String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id and score.sid > 1";
PlanChecker.from(connectContext)
@@ -213,7 +211,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest07() {
+ void inferPredicatesTest07() {
String sql = "select * from student left join score on student.id = score.sid where student.id > 1";
PlanChecker.from(connectContext)
@@ -236,7 +234,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest08() {
+ void inferPredicatesTest08() {
String sql = "select * from student left join score on student.id = score.sid and student.id > 1";
PlanChecker.from(connectContext)
@@ -256,7 +254,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest09() {
+ void inferPredicatesTest09() {
// convert left join to inner join
String sql = "select * from student left join score on student.id = score.sid where score.sid > 1";
@@ -280,7 +278,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest10() {
+ void inferPredicatesTest10() {
String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid where t.nid > 1";
PlanChecker.from(connectContext)
@@ -305,7 +303,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest11() {
+ void inferPredicatesTest11() {
String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid and t.nid > 1";
PlanChecker.from(connectContext)
@@ -327,7 +325,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest12() {
+ void inferPredicatesTest12() {
String sql = "select * from student left join (select sid as nid, sum(grade) from score group by sid) s on s.nid = student.id where student.id > 1";
PlanChecker.from(connectContext)
@@ -356,7 +354,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest13() {
+ void inferPredicatesTest13() {
String sql = "select * from (select id, name from student where id = 1) t left join score on t.id = score.sid";
PlanChecker.from(connectContext)
@@ -381,7 +379,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest14() {
+ void inferPredicatesTest14() {
String sql = "select * from student left semi join score on student.id = score.sid where student.id > 1";
PlanChecker.from(connectContext)
@@ -406,7 +404,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest15() {
+ void inferPredicatesTest15() {
String sql = "select * from student left semi join score on student.id = score.sid and student.id > 1";
PlanChecker.from(connectContext)
@@ -431,7 +429,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest16() {
+ void inferPredicatesTest16() {
String sql = "select * from student left anti join score on student.id = score.sid and student.id > 1";
PlanChecker.from(connectContext)
@@ -453,7 +451,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest17() {
+ void inferPredicatesTest17() {
String sql = "select * from student left anti join score on student.id = score.sid and score.sid > 1";
PlanChecker.from(connectContext)
@@ -475,7 +473,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest18() {
+ void inferPredicatesTest18() {
String sql = "select * from student left anti join score on student.id = score.sid where student.id > 1";
PlanChecker.from(connectContext)
@@ -500,7 +498,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest19() {
+ void inferPredicatesTest19() {
String sql = "select * from subquery1\n"
+ "left semi join (\n"
+ " select t1.k3\n"
@@ -564,7 +562,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest20() {
+ void inferPredicatesTest20() {
String sql = "select * from student left join score on student.id = score.sid and score.sid > 1 inner join course on course.id = score.sid";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
@@ -592,7 +590,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
- public void inferPredicatesTest21() {
+ void inferPredicatesTest21() {
String sql = "select * from student,score,course where student.id = score.sid and score.sid = course.id and score.sid > 1";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
@@ -623,7 +621,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
* test for #15310
*/
@Test
- public void inferPredicatesTest22() {
+ void inferPredicatesTest22() {
String sql = "select * from student join (select sid as id1, sid as id2, grade from score) s on student.id = s.id1 where s.id1 > 1";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
@@ -651,7 +649,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
* in this case, filter on relation s1 should not contain s1.id = 1.
*/
@Test
- public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
+ void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
String sql = "select * from student s1"
+ " left join (select sid as id1, sid as id2, grade from score) s2 on s1.id = s2.id1 and s1.id = 1"
+ " join (select sid as id1, sid as id2, grade from score) s3 on s1.id = s3.id1 where s1.id = 2";
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java
new file mode 100644
index 0000000000..b1aa25df1b
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java
@@ -0,0 +1,51 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.SmallIntType;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.junit.jupiter.api.Test;
+
+import java.util.Set;
+
+class PredicatePropagationTest {
+ private final SlotReference a = new SlotReference("a", SmallIntType.INSTANCE);
+ private final SlotReference b = new SlotReference("b", BigIntType.INSTANCE);
+
+ @Test
+ void equal() {
+ Set exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, Literal.of(1)));
+ Set inferExprs = PredicatePropagation.infer(exprs);
+ System.out.println(inferExprs);
+ }
+
+ @Test
+ void in() {
+ Set exprs = ImmutableSet.of(new EqualTo(a, b), new InPredicate(a, ImmutableList.of(Literal.of(1))));
+ Set inferExprs = PredicatePropagation.infer(exprs);
+ System.out.println(inferExprs);
+ }
+}
diff --git a/regression-test/data/nereids_p0/hint/fix_leading.out b/regression-test/data/nereids_p0/hint/fix_leading.out
index a3ca4f5411..58122945bb 100644
--- a/regression-test/data/nereids_p0/hint/fix_leading.out
+++ b/regression-test/data/nereids_p0/hint/fix_leading.out
@@ -9,7 +9,7 @@ PhysicalResultSink
----------PhysicalDistribute[DistributionSpecHash]
------------PhysicalOlapScan[t2]
--------PhysicalDistribute[DistributionSpecHash]
-----------NestedLoopJoin[CROSS_JOIN]
+----------NestedLoopJoin[CROSS_JOIN](t4.c4 = t3.c3)(t3.c3 = t4.c4)
------------PhysicalOlapScan[t3]
------------PhysicalDistribute[DistributionSpecReplicated]
--------------PhysicalOlapScan[t4]
diff --git a/regression-test/data/nereids_p0/hint/test_leading.out b/regression-test/data/nereids_p0/hint/test_leading.out
index d1bd8f8bd2..fe3831a9fc 100644
--- a/regression-test/data/nereids_p0/hint/test_leading.out
+++ b/regression-test/data/nereids_p0/hint/test_leading.out
@@ -2609,7 +2609,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecReplicated]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t1]
----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2631,7 +2631,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecReplicated]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t3]
----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2745,7 +2745,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecReplicated]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t1]
----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2767,7 +2767,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecReplicated]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t3]
----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2881,7 +2881,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecHash]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t1]
----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2903,7 +2903,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecHash]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t3]
----------------PhysicalDistribute[DistributionSpecReplicated]
diff --git a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
index c5942680ea..55645ed8ea 100644
--- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
+++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
@@ -41,7 +41,7 @@ suite("test_infer_predicate") {
explain {
sql "select * from infer_tb1 inner join infer_tb2 where cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;"
- contains "PREDICATES: k2"
+ contains "PREDICATES: CAST(k2"
}
explain {