diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 93d5f2c443..a1242c4a84 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -122,6 +122,7 @@ public enum RuleType { // subquery analyze FILTER_SUBQUERY_TO_APPLY(RuleTypeClass.REWRITE), PROJECT_SUBQUERY_TO_APPLY(RuleTypeClass.REWRITE), + JOIN_SUBQUERY_TO_APPLY(RuleTypeClass.REWRITE), ONE_ROW_RELATION_SUBQUERY_TO_APPLY(RuleTypeClass.REWRITE), // subquery rewrite rule ELIMINATE_LIMIT_UNDER_APPLY(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java index de67fea935..7ea406677e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.BinaryOperator; +import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.Exists; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InSubquery; @@ -30,6 +31,7 @@ import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.ScalarSubquery; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import 
org.apache.doris.nereids.trees.expressions.SubqueryExpr; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; @@ -38,6 +40,7 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.algebra.Aggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalApply; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -52,6 +55,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; /** * SubqueryToApply. translate from subquery to LogicalApply. @@ -68,7 +72,7 @@ public class SubqueryToApply implements AnalysisRuleFactory { LogicalFilter filter = ctx.root; ImmutableList> subqueryExprsList = filter.getConjuncts().stream() - .map(e -> (Set) e.collect(SubqueryExpr.class::isInstance)) + .>map(e -> e.collect(SubqueryExpr.class::isInstance)) .collect(ImmutableList.toImmutableList()); if (subqueryExprsList.stream() .flatMap(Collection::stream).noneMatch(SubqueryExpr.class::isInstance)) { @@ -117,7 +121,7 @@ public class SubqueryToApply implements AnalysisRuleFactory { RuleType.PROJECT_SUBQUERY_TO_APPLY.build(logicalProject().thenApply(ctx -> { LogicalProject project = ctx.root; ImmutableList> subqueryExprsList = project.getProjects().stream() - .map(e -> (Set) e.collect(SubqueryExpr.class::isInstance)) + .>map(e -> e.collect(SubqueryExpr.class::isInstance)) .collect(ImmutableList.toImmutableList()); if (subqueryExprsList.stream().flatMap(Collection::stream).count() == 0) { return project; @@ -164,11 +168,84 @@ public class SubqueryToApply implements AnalysisRuleFactory { oneRowRelation.withProjects( ImmutableList.of(new 
Alias(BooleanLiteral.of(true), ctx.statementContext.generateColumnName())))); - } - )) + })), + RuleType.JOIN_SUBQUERY_TO_APPLY + .build(logicalJoin() + .when(join -> join.getHashJoinConjuncts().isEmpty() && !join.getOtherJoinConjuncts().isEmpty()) + .thenApply(ctx -> { + LogicalJoin join = ctx.root; + Map> joinConjuncts = join.getOtherJoinConjuncts().stream() + .collect(Collectors.groupingBy(conjunct -> conjunct.containsType(SubqueryExpr.class), + Collectors.toList())); + List subqueryConjuncts = joinConjuncts.get(true); + if (subqueryConjuncts == null || subqueryConjuncts.stream() + .anyMatch(expr -> !isValidSubqueryConjunct(expr, join.left()))) { + return join; + } + + ImmutableList> subqueryExprsList = subqueryConjuncts.stream() + .>map(e -> e.collect(SubqueryExpr.class::isInstance)) + .collect(ImmutableList.toImmutableList()); + ImmutableList.Builder newConjuncts = new ImmutableList.Builder<>(); + LogicalPlan applyPlan = null; + LogicalPlan leftChildPlan = (LogicalPlan) join.left(); + + // Traverse the subqueries one AND-conjunct at a time.
+ for (int i = 0; i < subqueryExprsList.size(); ++i) { + Set subqueryExprs = subqueryExprsList.get(i); + if (subqueryExprs.size() > 1) { + // only a conjunct that contains exactly one subquery expr is supported + return join; + } + + // first step: Replace the subquery inside this join conjunct + // second step: Replace subquery with LogicalApply + ReplaceSubquery replaceSubquery = new ReplaceSubquery(ctx.statementContext, true); + SubqueryContext context = new SubqueryContext(subqueryExprs); + Expression conjunct = replaceSubquery.replace(subqueryConjuncts.get(i), context); + + applyPlan = subqueryToApply( + subqueryExprs.stream().collect(ImmutableList.toImmutableList()), + leftChildPlan, context.getSubqueryToMarkJoinSlot(), + ctx.cascadesContext, Optional.of(conjunct), false); + leftChildPlan = applyPlan; + newConjuncts.add(conjunct); + } + List simpleConjuncts = joinConjuncts.get(false); + if (simpleConjuncts != null) { + newConjuncts.addAll(simpleConjuncts); + } + Plan newJoin = join.withConjunctsChildren(join.getHashJoinConjuncts(), + newConjuncts.build(), applyPlan, join.right()); + return newJoin; + })) ); } + private static boolean isValidSubqueryConjunct(Expression expression, Plan leftChild) { + // the subquery must be uncorrelated, or correlated only to the left child + // currently only the following 4 simple scenarios are supported + // 1. col ComparisonPredicate subquery + // 2. col in (subquery) + // 3. exists (subquery) + // 4.
col1 ComparisonPredicate subquery or xxx (the conjunct holds no other subquery) + List slots = leftChild.getOutput(); + if (expression instanceof ComparisonPredicate && expression.child(1) instanceof ScalarSubquery) { + ScalarSubquery subquery = (ScalarSubquery) expression.child(1); + return slots.containsAll(subquery.getCorrelateSlots()); + } else if (expression instanceof InSubquery) { + return slots.containsAll(((InSubquery) expression).getCorrelateSlots()); + } else if (expression instanceof Exists) { + return slots.containsAll(((Exists) expression).getCorrelateSlots()); + } else { + List subqueryExprs = expression.collectToList(SubqueryExpr.class::isInstance); + if (subqueryExprs.size() == 1) { + return slots.containsAll(subqueryExprs.get(0).getCorrelateSlots()); + } + } + return false; + } + private LogicalPlan subqueryToApply(List subqueryExprs, LogicalPlan childPlan, Map> subqueryToMarkJoinSlot, CascadesContext ctx, @@ -252,12 +329,12 @@ public class SubqueryToApply implements AnalysisRuleFactory { private final StatementContext statementContext; private boolean isMarkJoin; - private final boolean isProject; + private final boolean shouldOutputMarkJoinSlot; public ReplaceSubquery(StatementContext statementContext, - boolean isProject) { + boolean shouldOutputMarkJoinSlot) { this.statementContext = Objects.requireNonNull(statementContext, "statementContext can't be null"); - this.isProject = isProject; + this.shouldOutputMarkJoinSlot = shouldOutputMarkJoinSlot; } public Expression replace(Expression expression, SubqueryContext subqueryContext) { @@ -269,7 +346,7 @@ public class SubqueryToApply implements AnalysisRuleFactory { // The result set when NULL is specified in the subquery and still evaluates to TRUE by using EXISTS // When the number of rows returned is empty, agg will return null, so if there is more agg, // it will always consider the returned result to be true - boolean needCreateMarkJoinSlot = isMarkJoin
|| shouldOutputMarkJoinSlot; MarkJoinSlotReference markJoinSlotReference = null; if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance) && needCreateMarkJoinSlot) { markJoinSlotReference = @@ -288,7 +365,7 @@ public class SubqueryToApply implements AnalysisRuleFactory { public Expression visitInSubquery(InSubquery in, SubqueryContext context) { MarkJoinSlotReference markJoinSlotReference = new MarkJoinSlotReference(statementContext.generateColumnName()); - boolean needCreateMarkJoinSlot = isMarkJoin || isProject; + boolean needCreateMarkJoinSlot = isMarkJoin || shouldOutputMarkJoinSlot; if (needCreateMarkJoinSlot) { context.setSubqueryToMarkJoinSlot(in, Optional.of(markJoinSlotReference)); }