[feature](nereids)support subquey in join condition (#24598)
This commit is contained in:
@ -122,6 +122,7 @@ public enum RuleType {
|
||||
// subquery analyze
|
||||
FILTER_SUBQUERY_TO_APPLY(RuleTypeClass.REWRITE),
|
||||
PROJECT_SUBQUERY_TO_APPLY(RuleTypeClass.REWRITE),
|
||||
JOIN_SUBQUERY_TO_APPLY(RuleTypeClass.REWRITE),
|
||||
ONE_ROW_RELATION_SUBQUERY_TO_APPLY(RuleTypeClass.REWRITE),
|
||||
// subquery rewrite rule
|
||||
ELIMINATE_LIMIT_UNDER_APPLY(RuleTypeClass.REWRITE),
|
||||
|
||||
@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Exists;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.InSubquery;
|
||||
@ -30,6 +31,7 @@ import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Or;
|
||||
import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
@ -38,6 +40,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
@ -52,6 +55,7 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* SubqueryToApply. translate from subquery to LogicalApply.
|
||||
@ -68,7 +72,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
LogicalFilter<Plan> filter = ctx.root;
|
||||
|
||||
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = filter.getConjuncts().stream()
|
||||
.map(e -> (Set<SubqueryExpr>) e.collect(SubqueryExpr.class::isInstance))
|
||||
.<Set<SubqueryExpr>>map(e -> e.collect(SubqueryExpr.class::isInstance))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
if (subqueryExprsList.stream()
|
||||
.flatMap(Collection::stream).noneMatch(SubqueryExpr.class::isInstance)) {
|
||||
@ -117,7 +121,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
RuleType.PROJECT_SUBQUERY_TO_APPLY.build(logicalProject().thenApply(ctx -> {
|
||||
LogicalProject<Plan> project = ctx.root;
|
||||
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = project.getProjects().stream()
|
||||
.map(e -> (Set<SubqueryExpr>) e.collect(SubqueryExpr.class::isInstance))
|
||||
.<Set<SubqueryExpr>>map(e -> e.collect(SubqueryExpr.class::isInstance))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
if (subqueryExprsList.stream().flatMap(Collection::stream).count() == 0) {
|
||||
return project;
|
||||
@ -164,11 +168,84 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
oneRowRelation.withProjects(
|
||||
ImmutableList.of(new Alias(BooleanLiteral.of(true),
|
||||
ctx.statementContext.generateColumnName()))));
|
||||
}
|
||||
))
|
||||
})),
|
||||
RuleType.JOIN_SUBQUERY_TO_APPLY
|
||||
.build(logicalJoin()
|
||||
.when(join -> join.getHashJoinConjuncts().isEmpty() && !join.getOtherJoinConjuncts().isEmpty())
|
||||
.thenApply(ctx -> {
|
||||
LogicalJoin<Plan, Plan> join = ctx.root;
|
||||
Map<Boolean, List<Expression>> joinConjuncts = join.getOtherJoinConjuncts().stream()
|
||||
.collect(Collectors.groupingBy(conjunct -> conjunct.containsType(SubqueryExpr.class),
|
||||
Collectors.toList()));
|
||||
List<Expression> subqueryConjuncts = joinConjuncts.get(true);
|
||||
if (subqueryConjuncts == null || subqueryConjuncts.stream()
|
||||
.anyMatch(expr -> !isValidSubqueryConjunct(expr, join.left()))) {
|
||||
return join;
|
||||
}
|
||||
|
||||
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = subqueryConjuncts.stream()
|
||||
.<Set<SubqueryExpr>>map(e -> e.collect(SubqueryExpr.class::isInstance))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
ImmutableList.Builder<Expression> newConjuncts = new ImmutableList.Builder<>();
|
||||
LogicalPlan applyPlan = null;
|
||||
LogicalPlan leftChildPlan = (LogicalPlan) join.left();
|
||||
|
||||
// Subquery traversal with the conjunct of and as the granularity.
|
||||
for (int i = 0; i < subqueryExprsList.size(); ++i) {
|
||||
Set<SubqueryExpr> subqueryExprs = subqueryExprsList.get(i);
|
||||
if (subqueryExprs.size() > 1) {
|
||||
// only support the conjunct contains one subquery expr
|
||||
return join;
|
||||
}
|
||||
|
||||
// first step: Replace the subquery of predicate in LogicalFilter
|
||||
// second step: Replace subquery with LogicalApply
|
||||
ReplaceSubquery replaceSubquery = new ReplaceSubquery(ctx.statementContext, true);
|
||||
SubqueryContext context = new SubqueryContext(subqueryExprs);
|
||||
Expression conjunct = replaceSubquery.replace(subqueryConjuncts.get(i), context);
|
||||
|
||||
applyPlan = subqueryToApply(
|
||||
subqueryExprs.stream().collect(ImmutableList.toImmutableList()),
|
||||
leftChildPlan, context.getSubqueryToMarkJoinSlot(),
|
||||
ctx.cascadesContext, Optional.of(conjunct), false);
|
||||
leftChildPlan = applyPlan;
|
||||
newConjuncts.add(conjunct);
|
||||
}
|
||||
List<Expression> simpleConjuncts = joinConjuncts.get(false);
|
||||
if (simpleConjuncts != null) {
|
||||
newConjuncts.addAll(simpleConjuncts);
|
||||
}
|
||||
Plan newJoin = join.withConjunctsChildren(join.getHashJoinConjuncts(),
|
||||
newConjuncts.build(), applyPlan, join.right());
|
||||
return newJoin;
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
private static boolean isValidSubqueryConjunct(Expression expression, Plan leftChild) {
|
||||
// the subquery must be uncorrelated subquery or only correlated to the left child
|
||||
// currently only support the following 4 simple scenarios
|
||||
// 1. col ComparisonPredicate subquery
|
||||
// 2. col in (subquery)
|
||||
// 3. exists (subquery)
|
||||
// 4. col1 ComparisonPredicate subquery or xxx (no more subquery)
|
||||
List<Slot> slots = leftChild.getOutput();
|
||||
if (expression instanceof ComparisonPredicate && expression.child(1) instanceof ScalarSubquery) {
|
||||
ScalarSubquery subquery = (ScalarSubquery) expression.child(1);
|
||||
return slots.containsAll(subquery.getCorrelateSlots());
|
||||
} else if (expression instanceof InSubquery) {
|
||||
return slots.containsAll(((InSubquery) expression).getCorrelateSlots());
|
||||
} else if (expression instanceof Exists) {
|
||||
return slots.containsAll(((Exists) expression).getCorrelateSlots());
|
||||
} else {
|
||||
List<SubqueryExpr> subqueryExprs = expression.collectToList(SubqueryExpr.class::isInstance);
|
||||
if (subqueryExprs.size() == 1) {
|
||||
return slots.containsAll(subqueryExprs.get(0).getCorrelateSlots());
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private LogicalPlan subqueryToApply(List<SubqueryExpr> subqueryExprs, LogicalPlan childPlan,
|
||||
Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot,
|
||||
CascadesContext ctx,
|
||||
@ -252,12 +329,12 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
private final StatementContext statementContext;
|
||||
private boolean isMarkJoin;
|
||||
|
||||
private final boolean isProject;
|
||||
private final boolean shouldOutputMarkJoinSlot;
|
||||
|
||||
public ReplaceSubquery(StatementContext statementContext,
|
||||
boolean isProject) {
|
||||
boolean shouldOutputMarkJoinSlot) {
|
||||
this.statementContext = Objects.requireNonNull(statementContext, "statementContext can't be null");
|
||||
this.isProject = isProject;
|
||||
this.shouldOutputMarkJoinSlot = shouldOutputMarkJoinSlot;
|
||||
}
|
||||
|
||||
public Expression replace(Expression expression, SubqueryContext subqueryContext) {
|
||||
@ -269,7 +346,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
// The result set when NULL is specified in the subquery and still evaluates to TRUE by using EXISTS
|
||||
// When the number of rows returned is empty, agg will return null, so if there is more agg,
|
||||
// it will always consider the returned result to be true
|
||||
boolean needCreateMarkJoinSlot = isMarkJoin || isProject;
|
||||
boolean needCreateMarkJoinSlot = isMarkJoin || shouldOutputMarkJoinSlot;
|
||||
MarkJoinSlotReference markJoinSlotReference = null;
|
||||
if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance) && needCreateMarkJoinSlot) {
|
||||
markJoinSlotReference =
|
||||
@ -288,7 +365,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
public Expression visitInSubquery(InSubquery in, SubqueryContext context) {
|
||||
MarkJoinSlotReference markJoinSlotReference =
|
||||
new MarkJoinSlotReference(statementContext.generateColumnName());
|
||||
boolean needCreateMarkJoinSlot = isMarkJoin || isProject;
|
||||
boolean needCreateMarkJoinSlot = isMarkJoin || shouldOutputMarkJoinSlot;
|
||||
if (needCreateMarkJoinSlot) {
|
||||
context.setSubqueryToMarkJoinSlot(in, Optional.of(markJoinSlotReference));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user