[fix](nereids)support uncorrelated subquery in join condition (#26672)
The SQL `select * from t1 a join t1 b on b.id in (select 1) and a.id = b.id;` reports an error. This PR adds support for uncorrelated subqueries in join conditions to fix it.
This commit is contained in:
@ -23,7 +23,6 @@ import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Exists;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.InSubquery;
|
||||
@ -45,6 +44,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
@ -180,7 +180,13 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
Collectors.toList()));
|
||||
List<Expression> subqueryConjuncts = joinConjuncts.get(true);
|
||||
if (subqueryConjuncts == null || subqueryConjuncts.stream()
|
||||
.anyMatch(expr -> !isValidSubqueryConjunct(expr, join.left()))) {
|
||||
.anyMatch(expr -> !isValidSubqueryConjunct(expr))) {
|
||||
return join;
|
||||
}
|
||||
|
||||
List<RelatedInfo> relatedInfoList = collectRelatedInfo(
|
||||
subqueryConjuncts, join.left(), join.right());
|
||||
if (relatedInfoList.stream().anyMatch(info -> info == RelatedInfo.UnSupported)) {
|
||||
return join;
|
||||
}
|
||||
|
||||
@ -188,8 +194,9 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
.<Set<SubqueryExpr>>map(e -> e.collect(SubqueryExpr.class::isInstance))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
ImmutableList.Builder<Expression> newConjuncts = new ImmutableList.Builder<>();
|
||||
LogicalPlan applyPlan = null;
|
||||
LogicalPlan applyPlan;
|
||||
LogicalPlan leftChildPlan = (LogicalPlan) join.left();
|
||||
LogicalPlan rightChildPlan = (LogicalPlan) join.right();
|
||||
|
||||
// Subquery traversal with the conjunct of and as the granularity.
|
||||
for (int i = 0; i < subqueryExprsList.size(); ++i) {
|
||||
@ -207,9 +214,14 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
|
||||
applyPlan = subqueryToApply(
|
||||
subqueryExprs.stream().collect(ImmutableList.toImmutableList()),
|
||||
leftChildPlan, context.getSubqueryToMarkJoinSlot(),
|
||||
relatedInfoList.get(i) == RelatedInfo.RelatedToLeft ? leftChildPlan : rightChildPlan,
|
||||
context.getSubqueryToMarkJoinSlot(),
|
||||
ctx.cascadesContext, Optional.of(conjunct), false);
|
||||
leftChildPlan = applyPlan;
|
||||
if (relatedInfoList.get(i) == RelatedInfo.RelatedToLeft) {
|
||||
leftChildPlan = applyPlan;
|
||||
} else {
|
||||
rightChildPlan = applyPlan;
|
||||
}
|
||||
newConjuncts.add(conjunct);
|
||||
}
|
||||
List<Expression> simpleConjuncts = joinConjuncts.get(false);
|
||||
@ -217,34 +229,82 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
newConjuncts.addAll(simpleConjuncts);
|
||||
}
|
||||
Plan newJoin = join.withConjunctsChildren(join.getHashJoinConjuncts(),
|
||||
newConjuncts.build(), applyPlan, join.right());
|
||||
newConjuncts.build(), leftChildPlan, rightChildPlan);
|
||||
return newJoin;
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
private static boolean isValidSubqueryConjunct(Expression expression, Plan leftChild) {
|
||||
// the subquery must be uncorrelated subquery or only correlated to the left child
|
||||
// currently only support the following 4 simple scenarios
|
||||
// 1. col ComparisonPredicate subquery
|
||||
// 2. col in (subquery)
|
||||
// 3. exists (subquery)
|
||||
// 4. col1 ComparisonPredicate subquery or xxx (no more subquery)
|
||||
List<Slot> slots = leftChild.getOutput();
|
||||
if (expression instanceof ComparisonPredicate && expression.child(1) instanceof ScalarSubquery) {
|
||||
ScalarSubquery subquery = (ScalarSubquery) expression.child(1);
|
||||
return slots.containsAll(subquery.getCorrelateSlots());
|
||||
} else if (expression instanceof InSubquery) {
|
||||
return slots.containsAll(((InSubquery) expression).getCorrelateSlots());
|
||||
} else if (expression instanceof Exists) {
|
||||
return slots.containsAll(((Exists) expression).getCorrelateSlots());
|
||||
} else {
|
||||
private static boolean isValidSubqueryConjunct(Expression expression) {
|
||||
// only support 1 subquery expr in the expression
|
||||
// don't support expression like subquery1 or subquery2
|
||||
return expression.collectToList(SubqueryExpr.class::isInstance).size() == 1;
|
||||
}
|
||||
|
||||
// Classifies which join child a subquery conjunct depends on; values are produced by collectRelatedInfo.
private enum RelatedInfo {
|
||||
// neither the subquery nor the expression around it refers to any join child, e.g. (select sum(t.a) from t) > 1
|
||||
Unrelated,
|
||||
// the subquery and the expression around it refer only to the left child, as in the examples below:
|
||||
// tableLeft.a in (select t.a from t)
|
||||
// 3 in (select t.b from t where t.a = tableLeft.a)
// NOTE(review): collectRelatedInfo yields UnSupported for an IN-subquery whose compare expression
// has no input slots, so the "3 in (...)" example above does not appear to reach RelatedToLeft - confirm.
|
||||
// tableLeft.a > (select sum(t.a) from t where tableLeft.b = t.b)
|
||||
RelatedToLeft,
|
||||
// same as above, but related to the right child
|
||||
RelatedToRight,
|
||||
// a subquery related to both the left and the right child is not supported, e.g.:
|
||||
// tableLeft.a > (select sum(t.a) from t where t.b = tableRight.b)
// This is also the initial value collectRelatedInfo assigns before any branch matches.
|
||||
UnSupported
|
||||
}
|
||||
|
||||
/**
 * Computes one RelatedInfo per subquery conjunct, in the same order as the input list.
 *
 * Each conjunct starts as UnSupported and is only reclassified when it contains exactly one
 * SubqueryExpr that is a ScalarSubquery, an InSubquery or an Exists.
 *
 * @param subqueryConjuncts join conjuncts that contain subquery expressions
 * @param leftChild left child of the join; its output set is compared against the slots used
 * @param rightChild right child of the join; its output set is compared against the slots used
 * @return an immutable list with one RelatedInfo entry per conjunct
 */
private ImmutableList<RelatedInfo> collectRelatedInfo(List<Expression> subqueryConjuncts,
|
||||
Plan leftChild, Plan rightChild) {
|
||||
int size = subqueryConjuncts.size();
|
||||
ImmutableList.Builder<RelatedInfo> correlatedInfoList = new ImmutableList.Builder<>();
|
||||
Set<Slot> leftOutputSlots = leftChild.getOutputSet();
|
||||
Set<Slot> rightOutputSlots = rightChild.getOutputSet();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
Expression expression = subqueryConjuncts.get(i);
|
||||
List<SubqueryExpr> subqueryExprs = expression.collectToList(SubqueryExpr.class::isInstance);
|
||||
// default verdict; stays UnSupported unless one of the branches below reclassifies it
RelatedInfo relatedInfo = RelatedInfo.UnSupported;
|
||||
if (subqueryExprs.size() == 1) {
|
||||
// NOTE(review): the next statement references `slots`, which is not declared in this method;
// it looks like a removed line of the old isValidSubqueryConjunct kept by the diff rendering - confirm.
return slots.containsAll(subqueryExprs.get(0).getCorrelateSlots());
|
||||
SubqueryExpr subqueryExpr = subqueryExprs.get(0);
|
||||
List<Slot> correlatedSlots = subqueryExpr.getCorrelateSlots();
|
||||
if (subqueryExpr instanceof ScalarSubquery) {
|
||||
// scalar subquery: both the slots of the whole conjunct and the correlated slots must come from one side
Set<Slot> inputSlots = expression.getInputSlots();
|
||||
if (correlatedSlots.isEmpty() && inputSlots.isEmpty()) {
|
||||
relatedInfo = RelatedInfo.Unrelated;
|
||||
} else if (leftOutputSlots.containsAll(inputSlots)
|
||||
&& leftOutputSlots.containsAll(correlatedSlots)) {
|
||||
relatedInfo = RelatedInfo.RelatedToLeft;
|
||||
} else if (rightOutputSlots.containsAll(inputSlots)
|
||||
&& rightOutputSlots.containsAll(correlatedSlots)) {
|
||||
relatedInfo = RelatedInfo.RelatedToRight;
|
||||
}
|
||||
} else if (subqueryExpr instanceof InSubquery) {
|
||||
// IN subquery: the side is decided by the slots of the compare expression plus the correlated slots
InSubquery inSubquery = (InSubquery) subqueryExpr;
|
||||
Set<Slot> compareSlots = inSubquery.getCompareExpr().getInputSlots();
|
||||
if (compareSlots.isEmpty()) {
|
||||
// a compare expression without any slot is rejected (same value as the default)
relatedInfo = RelatedInfo.UnSupported;
|
||||
} else if (leftOutputSlots.containsAll(compareSlots)
|
||||
&& leftOutputSlots.containsAll(correlatedSlots)) {
|
||||
relatedInfo = RelatedInfo.RelatedToLeft;
|
||||
} else if (rightOutputSlots.containsAll(compareSlots)
|
||||
&& rightOutputSlots.containsAll(correlatedSlots)) {
|
||||
relatedInfo = RelatedInfo.RelatedToRight;
|
||||
}
|
||||
} else if (subqueryExpr instanceof Exists) {
|
||||
// EXISTS: only the correlated slots are checked here
if (correlatedSlots.isEmpty()) {
|
||||
relatedInfo = RelatedInfo.Unrelated;
|
||||
} else if (leftOutputSlots.containsAll(correlatedSlots)) {
|
||||
relatedInfo = RelatedInfo.RelatedToLeft;
|
||||
} else if (rightOutputSlots.containsAll(correlatedSlots)) {
|
||||
relatedInfo = RelatedInfo.RelatedToRight;
|
||||
}
|
||||
}
|
||||
}
|
||||
correlatedInfoList.add(relatedInfo);
|
||||
}
|
||||
// NOTE(review): `return false;` does not match this method's ImmutableList return type; it looks like
// another removed line of the old isValidSubqueryConjunct kept by the diff rendering - confirm.
return false;
|
||||
return correlatedInfoList.build();
|
||||
}
|
||||
|
||||
private LogicalPlan subqueryToApply(List<SubqueryExpr> subqueryExprs, LogicalPlan childPlan,
|
||||
@ -270,10 +330,17 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
/**
 * Tests whether the given subquery is an Exists that has no mark-join slot recorded for it
 * and whose query plan contains an aggregate without group-by keys.
 *
 * @param exists the subquery expression to test
 * @param subqueryToMarkJoinSlot map looked up with the subquery to find its optional mark-join slot
 * @return true when all conditions of the return expression hold
 */
private boolean nonMarkJoinExistsWithAgg(SubqueryExpr exists,
|
||||
Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot) {
|
||||
return exists instanceof Exists
|
||||
// NOTE(review): the condition below (any LogicalAggregate without group-by anywhere in the plan,
// ending with ';') and the one after it (hasTopLevelAggWithoutGroupBy) cannot both be part of one
// statement; presumably the first is the removed variant and the second the added one - confirm.
&& exists.getQueryPlan()
|
||||
.anyMatch(planTreeNode -> planTreeNode instanceof LogicalAggregate
|
||||
&& ((LogicalAggregate<?>) planTreeNode).getGroupByExpressions().isEmpty())
|
||||
&& !subqueryToMarkJoinSlot.get(exists).isPresent();
|
||||
&& !subqueryToMarkJoinSlot.get(exists).isPresent()
|
||||
// only an aggregate reachable through LogicalProject / LogicalSort nodes from the plan root counts
&& hasTopLevelAggWithoutGroupBy(exists.getQueryPlan());
|
||||
}
|
||||
|
||||
private boolean hasTopLevelAggWithoutGroupBy(Plan plan) {
|
||||
if (plan instanceof LogicalAggregate) {
|
||||
return ((LogicalAggregate) plan).getGroupByExpressions().isEmpty();
|
||||
} else if (plan instanceof LogicalProject || plan instanceof LogicalSort) {
|
||||
return hasTopLevelAggWithoutGroupBy(plan.child(0));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private LogicalPlan addApply(SubqueryExpr subquery, LogicalPlan childPlan,
|
||||
|
||||
Reference in New Issue
Block a user