[fix](nereids) uncorrelated subquery can't get the correct result (#12421)

Currently, executing a non-correlated subquery fails with an error reporting that the corresponding column cannot be found.
The cause is that the tuple IDs obtained for the children in visitPhysicalNestedLoopJoin are not consistent with the tuples the children actually output.

Non-correlated subqueries trigger this bug because they are translated into a cross join.
This commit also adds subquery regression tests for non-correlated and complex scenarios.

Co-authored-by: morrySnow <morrysnow@126.com>
This commit is contained in:
zhengshiJ
2022-09-09 18:08:34 +08:00
committed by GitHub
parent 554ba40b13
commit dc7e5ca039
10 changed files with 209 additions and 95 deletions

View File

@ -90,6 +90,7 @@ import com.google.common.collect.Sets;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
@ -463,18 +464,16 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
// but BE need left child's output must be before right child's output.
// So we need to swap the output order of left and right child if necessary.
// TODO: revert this after Nereids could ensure the output order is correct.
TupleDescriptor leftChildOutputTupleDesc = leftPlanRoot.getOutputTupleDesc();
TupleDescriptor leftTuple =
leftChildOutputTupleDesc != null ? leftChildOutputTupleDesc : context.getTupleDesc(leftPlanRoot);
TupleDescriptor rightChildOutputTupleDesc = rightPlanRoot.getOutputTupleDesc();
TupleDescriptor rightTuple =
rightChildOutputTupleDesc != null ? rightChildOutputTupleDesc : context.getTupleDesc(rightPlanRoot);
List<TupleDescriptor> leftTuples = context.getTupleDesc(leftPlanRoot);
List<TupleDescriptor> rightTuples = context.getTupleDesc(rightPlanRoot);
TupleDescriptor outputDescriptor = context.generateTupleDesc();
Map<ExprId, SlotReference> slotReferenceMap = Maps.newHashMap();
hashJoin.getOutput().stream()
.map(SlotReference.class::cast)
.forEach(s -> slotReferenceMap.put(s.getExprId(), s));
List<Expr> srcToOutput = Stream.concat(leftTuple.getSlots().stream(), rightTuple.getSlots().stream())
List<Expr> srcToOutput = Stream.concat(leftTuples.stream(), rightTuples.stream())
.map(TupleDescriptor::getSlots)
.flatMap(Collection::stream)
.map(sd -> context.findExprId(sd.getId()))
.map(slotReferenceMap::get)
.filter(Objects::nonNull)
@ -512,12 +511,19 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
PlanNode leftFragmentPlanRoot = leftFragment.getPlanRoot();
PlanNode rightFragmentPlanRoot = rightFragment.getPlanRoot();
if (JoinUtils.shouldNestedLoopJoin(nestedLoopJoin)) {
CrossJoinNode crossJoinNode =
new CrossJoinNode(context.nextPlanNodeId(), leftFragmentPlanRoot, rightFragmentPlanRoot, null);
List<TupleDescriptor> leftTuples = context.getTupleDesc(leftFragmentPlanRoot);
List<TupleDescriptor> rightTuples = context.getTupleDesc(rightFragmentPlanRoot);
List<TupleId> tupleIds = Stream.concat(leftTuples.stream(), rightTuples.stream())
.map(TupleDescriptor::getId)
.collect(Collectors.toList());
CrossJoinNode crossJoinNode = new CrossJoinNode(context.nextPlanNodeId(),
leftFragmentPlanRoot, rightFragmentPlanRoot, tupleIds);
rightFragment.getPlanRoot().setCompactData(false);
crossJoinNode.setChild(0, leftFragment.getPlanRoot());
connectChildFragment(crossJoinNode, 1, leftFragment, rightFragment, context);
leftFragment.setPlanRoot(crossJoinNode);
return leftFragment;
} else {
throw new RuntimeException("Physical nested loop join could not execute with equal join condition.");
@ -650,9 +656,9 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
PlanTranslatorContext context) {
PlanFragment inputFragment = assertNumRows.child(0).accept(this, context);
//create assertNode
PlanNode child = inputFragment.getPlanRoot();
AssertNumRowsNode assertNumRowsNode = new AssertNumRowsNode(context.nextPlanNodeId(),
child, ExpressionTranslator.translateAssert(assertNumRows.getAssertNumRowsElement()));
inputFragment.getPlanRoot(),
ExpressionTranslator.translateAssert(assertNumRows.getAssertNumRowsElement()));
PlanFragment mergeFragment = createParentFragment(inputFragment, DataPartition.UNPARTITIONED, context);
mergeFragment.addPlanRoot(assertNumRowsNode);
return mergeFragment;

View File

@ -39,6 +39,7 @@ import com.google.common.collect.Maps;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* Context of physical plan.
@ -132,15 +133,11 @@ public class PlanTranslatorContext {
return slotDescriptor;
}
/**
* in Nereids, all node only has one TupleDescriptor, so we can use the first one.
*
* @param planNode the node to get the TupleDescriptor
*
* @return plan node's tuple descriptor
*/
public TupleDescriptor getTupleDesc(PlanNode planNode) {
return descTable.getTupleDesc(planNode.getOutputTupleIds().get(0));
public List<TupleDescriptor> getTupleDesc(PlanNode planNode) {
if (planNode.getOutputTupleDesc() != null) {
return Lists.newArrayList(planNode.getOutputTupleDesc());
}
return planNode.getOutputTupleIds().stream().map(this::getTupleDesc).collect(Collectors.toList());
}
public TupleDescriptor getTupleDesc(TupleId tupleId) {

View File

@ -26,7 +26,6 @@ import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.Lists;
@ -54,15 +53,13 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
apply.right().getOutput().get(0));
}
LogicalJoin newJoin;
if (((InSubquery) apply.getSubqueryExpr()).isNot()) {
newJoin = new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, Lists.newArrayList(), Optional.of(predicate),
return new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, Lists.newArrayList(), Optional.of(predicate),
apply.left(), apply.right());
} else {
newJoin = new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(), Optional.of(predicate),
return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(), Optional.of(predicate),
apply.left(), apply.right());
}
return new LogicalProject(apply.left().getOutput(), newJoin);
}).toRule(RuleType.IN_APPLY_TO_JOIN);
}
}

View File

@ -21,19 +21,15 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.Lists;
import java.util.List;
/**
* Convert scalarApply to LogicalJoin.
*
@ -58,10 +54,8 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory {
1, apply.getSubqueryExpr().toString(),
AssertNumRowsElement.Assertion.EQ),
(LogicalPlan) apply.right());
LogicalJoin newJoin = new LogicalJoin<>(JoinType.CROSS_JOIN,
return new LogicalJoin<>(JoinType.CROSS_JOIN,
(LogicalPlan) apply.left(), assertNumRows);
List<Slot> projects = ((LogicalPlan) apply.left()).getOutput();
return new LogicalProject(projects, newJoin);
}
private Plan correlatedToJoin(LogicalApply apply) {

View File

@ -73,9 +73,10 @@ public class AssertNumRowsElement extends Expression implements LeafExpression {
@Override
public String toString() {
return Utils.toSqlString("desiredNumOfRows: ",
return Utils.toSqlString("AssertNumRowsElement",
"desiredNumOfRows: ",
Long.toString(desiredNumOfRows),
"assertion: " + assertion);
"assertion: ", assertion);
}
@Override

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -59,7 +60,8 @@ public class LogicalAssertNumRows<CHILD_TYPE extends Plan> extends LogicalUnary<
@Override
public String toString() {
return "LogicalAssertNumRows (" + assertNumRowsElement + ")";
return Utils.toSqlString("LogicalAssertNumRows",
"assertNumRowsElement", assertNumRowsElement);
}
@Override

View File

@ -49,7 +49,11 @@ public class AssertNumRowsNode extends PlanNode {
this.subqueryString = assertNumRowsElement.getSubqueryString();
this.assertion = assertNumRowsElement.getAssertion();
this.children.add(input);
this.tupleIds.addAll(input.getTupleIds());
if (input.getOutputTupleDesc() != null) {
this.tupleIds.add(input.getOutputTupleDesc().getId());
} else {
this.tupleIds.addAll(input.getTupleIds());
}
this.tblRefIds.addAll(input.getTblRefIds());
this.nullableTupleIds.addAll(input.getNullableTupleIds());
}

View File

@ -19,6 +19,7 @@ package org.apache.doris.planner;
import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.TableRef;
import org.apache.doris.analysis.TupleId;
import org.apache.doris.common.UserException;
import org.apache.doris.statistics.StatisticalType;
import org.apache.doris.statistics.StatsRecursiveDerive;
@ -30,6 +31,8 @@ import com.google.common.base.MoreObjects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.List;
/**
* Cross join between left child and right child.
*/
@ -57,6 +60,22 @@ public class CrossJoinNode extends PlanNode {
nullableTupleIds.addAll(inner.getNullableTupleIds());
}
/**
 * Only for Nereids.
 *
 * Builds a cross join node whose tuple ids are supplied explicitly by the caller.
 * No inner TableRef is associated with a node created through this constructor.
 *
 * @param id plan node id of this node
 * @param outer outer child, added as the first child
 * @param inner inner child, added as the second child
 * @param tupleIds tuple ids to register on this node; all of them are added to this node's tupleIds
 */
public CrossJoinNode(PlanNodeId id, PlanNode outer, PlanNode inner, List<TupleId> tupleIds) {
super(id, "CROSS JOIN", StatisticalType.CROSS_JOIN_NODE);
// There is no TableRef for the inner side on this construction path.
this.innerRef = null;
this.tupleIds.addAll(tupleIds);
children.add(outer);
children.add(inner);
// Inherit every nullable tuple id from both children.
nullableTupleIds.addAll(outer.getNullableTupleIds());
nullableTupleIds.addAll(inner.getNullableTupleIds());
}
public TableRef getInnerRef() {
return innerRef;
}