[fix](nereids)AssertNumRow node's output should be nullable (#32136)

Co-authored-by: Co-Author Jerry Hu <mrhhsg@gmail.com>
This commit is contained in:
starocean999
2024-03-15 17:07:27 +08:00
committed by yiguolei
parent c0776c7c07
commit 97b35d6830
21 changed files with 392 additions and 324 deletions

View File

@ -1013,12 +1013,39 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
PlanTranslatorContext context) {
PlanFragment currentFragment = assertNumRows.child().accept(this, context);
List<List<Expr>> distributeExprLists = getDistributeExprs(assertNumRows.child());
// we need convert all columns to nullable in AssertNumRows node
// create a tuple for AssertNumRowsNode
TupleDescriptor tupleDescriptor = context.generateTupleDesc();
// create assertNode
AssertNumRowsNode assertNumRowsNode = new AssertNumRowsNode(context.nextPlanNodeId(),
currentFragment.getPlanRoot(),
ExpressionTranslator.translateAssert(assertNumRows.getAssertNumRowsElement()));
ExpressionTranslator.translateAssert(assertNumRows.getAssertNumRowsElement()), true, tupleDescriptor);
assertNumRowsNode.setChildrenDistributeExprLists(distributeExprLists);
assertNumRowsNode.setNereidsId(assertNumRows.getId());
// collect all child output slots
List<TupleDescriptor> childTuples = context.getTupleDesc(currentFragment.getPlanRoot());
List<SlotDescriptor> childSlotDescriptors = childTuples.stream()
.map(TupleDescriptor::getSlots)
.flatMap(Collection::stream)
.collect(Collectors.toList());
// create output slot based on child output
Map<ExprId, SlotReference> childOutputMap = Maps.newHashMap();
assertNumRows.child().getOutput().stream()
.map(SlotReference.class::cast)
.forEach(s -> childOutputMap.put(s.getExprId(), s));
List<SlotDescriptor> slotDescriptors = Lists.newArrayList();
for (SlotDescriptor slot : childSlotDescriptors) {
SlotReference sf = childOutputMap.get(context.findExprId(slot.getId()));
SlotDescriptor sd = context.createSlotDesc(tupleDescriptor, sf, slot.getParent().getTable());
slotDescriptors.add(sd);
}
// set all output slot nullable
slotDescriptors.forEach(sd -> sd.setIsNullable(true));
addPlanRoot(currentFragment, assertNumRowsNode, assertNumRows);
return currentFragment;
}

View File

@ -60,19 +60,22 @@ public class EliminateAssertNumRows extends OneRewriteRuleFactory {
private boolean canEliminate(LogicalAssertNumRows<?> assertNumRows, Plan plan) {
long maxOutputRowcount;
AssertNumRowsElement assertNumRowsElement = assertNumRows.getAssertNumRowsElement();
Assertion assertion = assertNumRowsElement.getAssertion();
long assertNum = assertNumRowsElement.getDesiredNumOfRows();
// Don't need to consider TopN, because it's generated by Sort + Limit.
if (plan instanceof LogicalLimit) {
maxOutputRowcount = ((LogicalLimit<?>) plan).getLimit();
} else if (plan instanceof LogicalAggregate && ((LogicalAggregate<?>) plan).getGroupByExpressions().isEmpty()) {
maxOutputRowcount = 1;
if (assertion == Assertion.EQ && assertNum == 1) {
return true;
} else {
maxOutputRowcount = 1;
}
} else {
return false;
}
AssertNumRowsElement assertNumRowsElement = assertNumRows.getAssertNumRowsElement();
Assertion assertion = assertNumRowsElement.getAssertion();
long assertNum = assertNumRowsElement.getDesiredNumOfRows();
switch (assertion) {
case NE:
case LT:

View File

@ -54,11 +54,8 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory {
}
private Plan unCorrelatedToJoin(LogicalApply apply) {
LogicalAssertNumRows assertNumRows = new LogicalAssertNumRows<>(
new AssertNumRowsElement(
1, apply.getSubqueryExpr().toString(),
apply.isInProject()
? AssertNumRowsElement.Assertion.EQ : AssertNumRowsElement.Assertion.LE),
LogicalAssertNumRows assertNumRows = new LogicalAssertNumRows<>(new AssertNumRowsElement(1,
apply.getSubqueryExpr().toString(), AssertNumRowsElement.Assertion.EQ),
(LogicalPlan) apply.right());
return new LogicalJoin<>(JoinType.CROSS_JOIN,
ExpressionUtils.EMPTY_CONDITION,

View File

@ -34,6 +34,7 @@ import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* Assert num rows node is used to determine whether the number of rows is less than desired num of rows.
@ -115,8 +116,6 @@ public class LogicalAssertNumRows<CHILD_TYPE extends Plan> extends LogicalUnary<
@Override
public List<Slot> computeOutput() {
return ImmutableList.<Slot>builder()
.addAll(child().getOutput())
.build();
return child().getOutput().stream().map(o -> o.withNullable(true)).collect(Collectors.toList());
}
}

View File

@ -35,6 +35,7 @@ import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* Physical assertNumRows.
@ -59,9 +60,7 @@ public class PhysicalAssertNumRows<CHILD_TYPE extends Plan> extends PhysicalUnar
@Override
public List<Slot> computeOutput() {
return ImmutableList.<Slot>builder()
.addAll(child().getOutput())
.build();
return child().getOutput().stream().map(o -> o.withNullable(true)).collect(Collectors.toList());
}
public AssertNumRowsElement getAssertNumRowsElement() {

View File

@ -19,6 +19,7 @@ package org.apache.doris.planner;
import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.AssertNumRowsElement;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.common.UserException;
import org.apache.doris.statistics.StatisticalType;
import org.apache.doris.statistics.StatsRecursiveDerive;
@ -43,21 +44,35 @@ public class AssertNumRowsNode extends PlanNode {
private String subqueryString;
private AssertNumRowsElement.Assertion assertion;
private boolean shouldConvertOutputToNullable = false;
public AssertNumRowsNode(PlanNodeId id, PlanNode input, AssertNumRowsElement assertNumRowsElement) {
this(id, input, assertNumRowsElement, false, null);
}
public AssertNumRowsNode(PlanNodeId id, PlanNode input, AssertNumRowsElement assertNumRowsElement,
boolean convertToNullable, TupleDescriptor tupleDescriptor) {
super(id, "ASSERT NUMBER OF ROWS", StatisticalType.ASSERT_NUM_ROWS_NODE);
this.desiredNumOfRows = assertNumRowsElement.getDesiredNumOfRows();
this.subqueryString = assertNumRowsElement.getSubqueryString();
this.assertion = assertNumRowsElement.getAssertion();
this.children.add(input);
if (input.getOutputTupleDesc() != null) {
this.tupleIds.add(input.getOutputTupleDesc().getId());
if (tupleDescriptor != null) {
this.tupleIds.add(tupleDescriptor.getId());
} else {
this.tupleIds.addAll(input.getTupleIds());
if (input.getOutputTupleDesc() != null) {
this.tupleIds.add(input.getOutputTupleDesc().getId());
} else {
this.tupleIds.addAll(input.getTupleIds());
}
}
this.tblRefIds.addAll(input.getTblRefIds());
this.nullableTupleIds.addAll(input.getNullableTupleIds());
this.shouldConvertOutputToNullable = convertToNullable;
}
@Override
public void init(Analyzer analyzer) throws UserException {
super.init(analyzer);
@ -94,6 +109,7 @@ public class AssertNumRowsNode extends PlanNode {
msg.assert_num_rows_node.setDesiredNumRows(desiredNumOfRows);
msg.assert_num_rows_node.setSubqueryString(subqueryString);
msg.assert_num_rows_node.setAssertion(assertion.toThrift());
msg.assert_num_rows_node.setShouldConvertOutputToNullable(shouldConvertOutputToNullable);
}
@Override