diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index dd967984c3..970858c4a1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -76,6 +76,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias; import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation; import org.apache.doris.nereids.trees.plans.logical.UsingJoin; import org.apache.doris.nereids.trees.plans.visitor.InferPlanOutputAlias; +import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; import org.apache.doris.qe.ConnectContext; @@ -158,6 +159,7 @@ public class BindExpression implements AnalysisRuleFactory { Set boundConjuncts = filter.getConjuncts().stream() .map(expr -> bindSlot(expr, filter.child(), ctx.cascadesContext)) .map(expr -> bindFunction(expr, ctx.root, ctx.cascadesContext)) + .map(expr -> TypeCoercionUtils.castIfNotSameType(expr, BooleanType.INSTANCE)) .collect(ImmutableSet.toImmutableSet()); return new LogicalFilter<>(boundConjuncts, filter.child()); }) @@ -211,10 +213,12 @@ public class BindExpression implements AnalysisRuleFactory { List cond = join.getOtherJoinConjuncts().stream() .map(expr -> bindSlot(expr, join.children(), ctx.cascadesContext)) .map(expr -> bindFunction(expr, ctx.root, ctx.cascadesContext)) + .map(expr -> TypeCoercionUtils.castIfNotSameType(expr, BooleanType.INSTANCE)) .collect(Collectors.toList()); List hashJoinConjuncts = join.getHashJoinConjuncts().stream() .map(expr -> bindSlot(expr, join.children(), ctx.cascadesContext)) .map(expr -> bindFunction(expr, ctx.root, ctx.cascadesContext)) + .map(expr -> TypeCoercionUtils.castIfNotSameType(expr, BooleanType.INSTANCE)) .collect(Collectors.toList()); return new LogicalJoin<>(join.getJoinType(), hashJoinConjuncts, cond, join.getHint(), join.getMarkJoinSlotReference(), @@ -476,6 +480,7 @@ public class BindExpression implements AnalysisRuleFactory { return bindSlot(expr, childPlan, ctx.cascadesContext, false); }) .map(expr -> bindFunction(expr, ctx.root, ctx.cascadesContext)) + .map(expr -> TypeCoercionUtils.castIfNotSameType(expr, BooleanType.INSTANCE)) .collect(Collectors.toSet()); checkIfOutputAliasNameDuplicatedForGroupBy(ImmutableList.copyOf(boundConjuncts), childPlan.getOutputExpressions()); @@ -492,6 +497,7 @@ public class BindExpression implements AnalysisRuleFactory { return bindSlot(expr, childPlan.children(), ctx.cascadesContext, false); }) .map(expr -> bindFunction(expr, ctx.root, ctx.cascadesContext)) + .map(expr -> TypeCoercionUtils.castIfNotSameType(expr, BooleanType.INSTANCE)) .collect(Collectors.toSet()); checkIfOutputAliasNameDuplicatedForGroupBy(ImmutableList.copyOf(boundConjuncts), childPlan.getOutput().stream().map(NamedExpression.class::cast) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustConjunctsReturnType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustConjunctsReturnType.java index 096799bec8..9c1a60fdf5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustConjunctsReturnType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustConjunctsReturnType.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import org.apache.doris.nereids.types.BooleanType; @@ -44,7 +43,7 @@ public class AdjustConjunctsReturnType extends DefaultPlanRewriter impleme @Override public Plan visit(Plan plan, Void context) { - return (LogicalPlan) super.visit(plan, context); + return super.visit(plan, context); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java index a181755543..cb5502d1b6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java @@ -63,7 +63,7 @@ public class LogicalFilter extends LogicalUnary getExpressions() { - return ImmutableList.of(getPredicate()); + return ImmutableList.copyOf(conjuncts); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java index 6537770b40..8524b35d82 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java @@ -61,7 +61,7 @@ public class LogicalHaving extends LogicalUnary getExpressions() { - return ImmutableList.of(getPredicate()); + return ImmutableList.copyOf(conjuncts); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index 3574051817..c940e08de1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -1028,19 +1028,9 @@ public class TypeCoercionUtils { public static Expression processCompoundPredicate(CompoundPredicate compoundPredicate) { // check compoundPredicate.checkLegalityBeforeTypeCoercion(); - - compoundPredicate.children().forEach(e -> { - if (!e.getDataType().isBooleanType() && !e.getDataType().isNullType() - && !(e instanceof SubqueryExpr)) { - throw new AnalysisException(String.format( - "Operand '%s' part of predicate " + "'%s' should return type 'BOOLEAN' but " - + "returns type '%s'.", - e.toSql(), compoundPredicate.toSql(), e.getDataType())); - } - } - ); List children = compoundPredicate.children().stream() - .map(e -> e.getDataType().isNullType() ? new NullLiteral(BooleanType.INSTANCE) : e) + .map(e -> e.getDataType().isNullType() ? new NullLiteral(BooleanType.INSTANCE) + : castIfNotSameType(e, BooleanType.INSTANCE)) .collect(Collectors.toList()); return compoundPredicate.withChildren(children); }