[feature](nereids)support subquery in join condition (#24598)

This commit is contained in:
starocean999
2023-09-27 10:02:56 +08:00
committed by GitHub
parent 6d27a016b9
commit 0dd57b1982
2 changed files with 87 additions and 9 deletions

View File

@ -122,6 +122,7 @@ public enum RuleType {
// subquery analyze
FILTER_SUBQUERY_TO_APPLY(RuleTypeClass.REWRITE),
PROJECT_SUBQUERY_TO_APPLY(RuleTypeClass.REWRITE),
JOIN_SUBQUERY_TO_APPLY(RuleTypeClass.REWRITE),
ONE_ROW_RELATION_SUBQUERY_TO_APPLY(RuleTypeClass.REWRITE),
// subquery rewrite rule
ELIMINATE_LIMIT_UNDER_APPLY(RuleTypeClass.REWRITE),

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
@ -30,6 +31,7 @@ import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
@ -38,6 +40,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
@ -52,6 +55,7 @@ import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* SubqueryToApply. translate from subquery to LogicalApply.
@ -68,7 +72,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
LogicalFilter<Plan> filter = ctx.root;
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = filter.getConjuncts().stream()
.map(e -> (Set<SubqueryExpr>) e.collect(SubqueryExpr.class::isInstance))
.<Set<SubqueryExpr>>map(e -> e.collect(SubqueryExpr.class::isInstance))
.collect(ImmutableList.toImmutableList());
if (subqueryExprsList.stream()
.flatMap(Collection::stream).noneMatch(SubqueryExpr.class::isInstance)) {
@ -117,7 +121,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
RuleType.PROJECT_SUBQUERY_TO_APPLY.build(logicalProject().thenApply(ctx -> {
LogicalProject<Plan> project = ctx.root;
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = project.getProjects().stream()
.map(e -> (Set<SubqueryExpr>) e.collect(SubqueryExpr.class::isInstance))
.<Set<SubqueryExpr>>map(e -> e.collect(SubqueryExpr.class::isInstance))
.collect(ImmutableList.toImmutableList());
if (subqueryExprsList.stream().flatMap(Collection::stream).count() == 0) {
return project;
@ -164,11 +168,84 @@ public class SubqueryToApply implements AnalysisRuleFactory {
oneRowRelation.withProjects(
ImmutableList.of(new Alias(BooleanLiteral.of(true),
ctx.statementContext.generateColumnName()))));
}
))
})),
// Rule for subqueries that appear in the "other" (non-hash) conjuncts of a join:
// each supported subquery conjunct is rewritten and the plan returned by subqueryToApply
// replaces the join's left child.
RuleType.JOIN_SUBQUERY_TO_APPLY
.build(logicalJoin()
// only match joins that have no hash conjuncts and at least one other conjunct
.when(join -> join.getHashJoinConjuncts().isEmpty() && !join.getOtherJoinConjuncts().isEmpty())
.thenApply(ctx -> {
LogicalJoin<Plan, Plan> join = ctx.root;
// Split the other conjuncts by whether they contain a SubqueryExpr:
// key true -> conjuncts with a subquery, key false -> conjuncts without one.
// groupingBy only creates a key when some conjunct maps to it, so either lookup may be null.
Map<Boolean, List<Expression>> joinConjuncts = join.getOtherJoinConjuncts().stream()
.collect(Collectors.groupingBy(conjunct -> conjunct.containsType(SubqueryExpr.class),
Collectors.toList()));
List<Expression> subqueryConjuncts = joinConjuncts.get(true);
// Leave the join unchanged when there is no subquery conjunct, or when any subquery
// conjunct is rejected by isValidSubqueryConjunct against the left child.
if (subqueryConjuncts == null || subqueryConjuncts.stream()
.anyMatch(expr -> !isValidSubqueryConjunct(expr, join.left()))) {
return join;
}
// subqueryExprsList.get(i) holds the SubqueryExpr nodes collected from subqueryConjuncts.get(i)
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = subqueryConjuncts.stream()
.<Set<SubqueryExpr>>map(e -> e.collect(SubqueryExpr.class::isInstance))
.collect(ImmutableList.toImmutableList());
ImmutableList.Builder<Expression> newConjuncts = new ImmutableList.Builder<>();
LogicalPlan applyPlan = null;
LogicalPlan leftChildPlan = (LogicalPlan) join.left();
// Process the subqueries one conjunct at a time (each conjunct is the unit of rewriting).
for (int i = 0; i < subqueryExprsList.size(); ++i) {
Set<SubqueryExpr> subqueryExprs = subqueryExprsList.get(i);
if (subqueryExprs.size() > 1) {
// only a conjunct that contains a single subquery expr is supported;
// otherwise give up and return the original join
return join;
}
// first step: replace the subquery inside this join conjunct (ReplaceSubquery)
// second step: turn the subquery into a plan on top of the current left child (subqueryToApply)
// the second constructor argument is the shouldOutputMarkJoinSlot flag
ReplaceSubquery replaceSubquery = new ReplaceSubquery(ctx.statementContext, true);
SubqueryContext context = new SubqueryContext(subqueryExprs);
Expression conjunct = replaceSubquery.replace(subqueryConjuncts.get(i), context);
applyPlan = subqueryToApply(
subqueryExprs.stream().collect(ImmutableList.toImmutableList()),
leftChildPlan, context.getSubqueryToMarkJoinSlot(),
ctx.cascadesContext, Optional.of(conjunct), false);
// chain the rewrites: the plan just built is the input for the next subquery conjunct
leftChildPlan = applyPlan;
// the rewritten conjunct is also kept as a join conjunct
newConjuncts.add(conjunct);
}
// conjuncts without subqueries are appended unchanged, after the rewritten ones
List<Expression> simpleConjuncts = joinConjuncts.get(false);
if (simpleConjuncts != null) {
newConjuncts.addAll(simpleConjuncts);
}
// applyPlan is non-null here: subqueryConjuncts is a non-empty group, so the loop ran at least once.
// The hash conjuncts and the right child are passed through as they were.
Plan newJoin = join.withConjunctsChildren(join.getHashJoinConjuncts(),
newConjuncts.build(), applyPlan, join.right());
return newJoin;
}))
);
}
/**
 * Decide whether a join conjunct that contains a subquery can be rewritten by the
 * join-subquery rule.
 *
 * <p>The subquery must be uncorrelated, or correlated only to slots produced by the
 * left child of the join. Only these simple shapes are accepted:
 * <ol>
 *   <li>col ComparisonPredicate subquery (scalar subquery as the right operand)</li>
 *   <li>col in (subquery)</li>
 *   <li>exists (subquery)</li>
 *   <li>any other expression that contains exactly one subquery,
 *       e.g. col1 ComparisonPredicate subquery or xxx</li>
 * </ol>
 *
 * @param expression the join conjunct to check
 * @param leftChild the left child of the join
 * @return true if every correlated slot of the selected subquery is in the left child's output
 */
private static boolean isValidSubqueryConjunct(Expression expression, Plan leftChild) {
    List<Slot> leftOutput = leftChild.getOutput();

    // pick the single subquery whose correlated slots have to be checked
    SubqueryExpr target;
    if (expression instanceof ComparisonPredicate && expression.child(1) instanceof ScalarSubquery) {
        target = (ScalarSubquery) expression.child(1);
    } else if (expression instanceof InSubquery) {
        target = (InSubquery) expression;
    } else if (expression instanceof Exists) {
        target = (Exists) expression;
    } else {
        List<SubqueryExpr> nested = expression.collectToList(SubqueryExpr.class::isInstance);
        if (nested.size() != 1) {
            // zero or several subqueries in one conjunct: not supported
            return false;
        }
        target = nested.get(0);
    }
    return leftOutput.containsAll(target.getCorrelateSlots());
}
private LogicalPlan subqueryToApply(List<SubqueryExpr> subqueryExprs, LogicalPlan childPlan,
Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot,
CascadesContext ctx,
@ -252,12 +329,12 @@ public class SubqueryToApply implements AnalysisRuleFactory {
private final StatementContext statementContext;
private boolean isMarkJoin;
private final boolean isProject;
private final boolean shouldOutputMarkJoinSlot;
public ReplaceSubquery(StatementContext statementContext,
boolean isProject) {
boolean shouldOutputMarkJoinSlot) {
this.statementContext = Objects.requireNonNull(statementContext, "statementContext can't be null");
this.isProject = isProject;
this.shouldOutputMarkJoinSlot = shouldOutputMarkJoinSlot;
}
public Expression replace(Expression expression, SubqueryContext subqueryContext) {
@ -269,7 +346,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
// The result set when NULL is specified in the subquery and still evaluates to TRUE by using EXISTS
// When the number of rows returned is empty, agg will return null, so if there is more agg,
// it will always consider the returned result to be true
boolean needCreateMarkJoinSlot = isMarkJoin || isProject;
boolean needCreateMarkJoinSlot = isMarkJoin || shouldOutputMarkJoinSlot;
MarkJoinSlotReference markJoinSlotReference = null;
if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance) && needCreateMarkJoinSlot) {
markJoinSlotReference =
@ -288,7 +365,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
public Expression visitInSubquery(InSubquery in, SubqueryContext context) {
MarkJoinSlotReference markJoinSlotReference =
new MarkJoinSlotReference(statementContext.generateColumnName());
boolean needCreateMarkJoinSlot = isMarkJoin || isProject;
boolean needCreateMarkJoinSlot = isMarkJoin || shouldOutputMarkJoinSlot;
if (needCreateMarkJoinSlot) {
context.setSubqueryToMarkJoinSlot(in, Optional.of(markJoinSlotReference));
}