[fix](nereids)AssertNumRow node's output should be nullable (#32136)
Co-authored-by: Co-Author Jerry Hu <mrhhsg@gmail.com>
This commit is contained in:
@ -1013,12 +1013,39 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
|
||||
PlanTranslatorContext context) {
|
||||
PlanFragment currentFragment = assertNumRows.child().accept(this, context);
|
||||
List<List<Expr>> distributeExprLists = getDistributeExprs(assertNumRows.child());
|
||||
|
||||
// we need convert all columns to nullable in AssertNumRows node
|
||||
// create a tuple for AssertNumRowsNode
|
||||
TupleDescriptor tupleDescriptor = context.generateTupleDesc();
|
||||
// create assertNode
|
||||
AssertNumRowsNode assertNumRowsNode = new AssertNumRowsNode(context.nextPlanNodeId(),
|
||||
currentFragment.getPlanRoot(),
|
||||
ExpressionTranslator.translateAssert(assertNumRows.getAssertNumRowsElement()));
|
||||
ExpressionTranslator.translateAssert(assertNumRows.getAssertNumRowsElement()), true, tupleDescriptor);
|
||||
assertNumRowsNode.setChildrenDistributeExprLists(distributeExprLists);
|
||||
assertNumRowsNode.setNereidsId(assertNumRows.getId());
|
||||
|
||||
// collect all child output slots
|
||||
List<TupleDescriptor> childTuples = context.getTupleDesc(currentFragment.getPlanRoot());
|
||||
List<SlotDescriptor> childSlotDescriptors = childTuples.stream()
|
||||
.map(TupleDescriptor::getSlots)
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// create output slot based on child output
|
||||
Map<ExprId, SlotReference> childOutputMap = Maps.newHashMap();
|
||||
assertNumRows.child().getOutput().stream()
|
||||
.map(SlotReference.class::cast)
|
||||
.forEach(s -> childOutputMap.put(s.getExprId(), s));
|
||||
List<SlotDescriptor> slotDescriptors = Lists.newArrayList();
|
||||
for (SlotDescriptor slot : childSlotDescriptors) {
|
||||
SlotReference sf = childOutputMap.get(context.findExprId(slot.getId()));
|
||||
SlotDescriptor sd = context.createSlotDesc(tupleDescriptor, sf, slot.getParent().getTable());
|
||||
slotDescriptors.add(sd);
|
||||
}
|
||||
|
||||
// set all output slot nullable
|
||||
slotDescriptors.forEach(sd -> sd.setIsNullable(true));
|
||||
|
||||
addPlanRoot(currentFragment, assertNumRowsNode, assertNumRows);
|
||||
return currentFragment;
|
||||
}
|
||||
|
||||
@ -60,19 +60,22 @@ public class EliminateAssertNumRows extends OneRewriteRuleFactory {
|
||||
|
||||
private boolean canEliminate(LogicalAssertNumRows<?> assertNumRows, Plan plan) {
|
||||
long maxOutputRowcount;
|
||||
AssertNumRowsElement assertNumRowsElement = assertNumRows.getAssertNumRowsElement();
|
||||
Assertion assertion = assertNumRowsElement.getAssertion();
|
||||
long assertNum = assertNumRowsElement.getDesiredNumOfRows();
|
||||
// Don't need to consider TopN, because it's generated by Sort + Limit.
|
||||
if (plan instanceof LogicalLimit) {
|
||||
maxOutputRowcount = ((LogicalLimit<?>) plan).getLimit();
|
||||
} else if (plan instanceof LogicalAggregate && ((LogicalAggregate<?>) plan).getGroupByExpressions().isEmpty()) {
|
||||
maxOutputRowcount = 1;
|
||||
if (assertion == Assertion.EQ && assertNum == 1) {
|
||||
return true;
|
||||
} else {
|
||||
maxOutputRowcount = 1;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
AssertNumRowsElement assertNumRowsElement = assertNumRows.getAssertNumRowsElement();
|
||||
Assertion assertion = assertNumRowsElement.getAssertion();
|
||||
long assertNum = assertNumRowsElement.getDesiredNumOfRows();
|
||||
|
||||
switch (assertion) {
|
||||
case NE:
|
||||
case LT:
|
||||
|
||||
@ -54,11 +54,8 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory {
|
||||
}
|
||||
|
||||
private Plan unCorrelatedToJoin(LogicalApply apply) {
|
||||
LogicalAssertNumRows assertNumRows = new LogicalAssertNumRows<>(
|
||||
new AssertNumRowsElement(
|
||||
1, apply.getSubqueryExpr().toString(),
|
||||
apply.isInProject()
|
||||
? AssertNumRowsElement.Assertion.EQ : AssertNumRowsElement.Assertion.LE),
|
||||
LogicalAssertNumRows assertNumRows = new LogicalAssertNumRows<>(new AssertNumRowsElement(1,
|
||||
apply.getSubqueryExpr().toString(), AssertNumRowsElement.Assertion.EQ),
|
||||
(LogicalPlan) apply.right());
|
||||
return new LogicalJoin<>(JoinType.CROSS_JOIN,
|
||||
ExpressionUtils.EMPTY_CONDITION,
|
||||
|
||||
@ -34,6 +34,7 @@ import com.google.common.collect.ImmutableList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Assert num rows node is used to determine whether the number of rows is less than desired num of rows.
|
||||
@ -115,8 +116,6 @@ public class LogicalAssertNumRows<CHILD_TYPE extends Plan> extends LogicalUnary<
|
||||
|
||||
@Override
|
||||
public List<Slot> computeOutput() {
|
||||
return ImmutableList.<Slot>builder()
|
||||
.addAll(child().getOutput())
|
||||
.build();
|
||||
return child().getOutput().stream().map(o -> o.withNullable(true)).collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
|
||||
@ -35,6 +35,7 @@ import com.google.common.collect.ImmutableList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Physical assertNumRows.
|
||||
@ -59,9 +60,7 @@ public class PhysicalAssertNumRows<CHILD_TYPE extends Plan> extends PhysicalUnar
|
||||
|
||||
@Override
|
||||
public List<Slot> computeOutput() {
|
||||
return ImmutableList.<Slot>builder()
|
||||
.addAll(child().getOutput())
|
||||
.build();
|
||||
return child().getOutput().stream().map(o -> o.withNullable(true)).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public AssertNumRowsElement getAssertNumRowsElement() {
|
||||
|
||||
@ -19,6 +19,7 @@ package org.apache.doris.planner;
|
||||
|
||||
import org.apache.doris.analysis.Analyzer;
|
||||
import org.apache.doris.analysis.AssertNumRowsElement;
|
||||
import org.apache.doris.analysis.TupleDescriptor;
|
||||
import org.apache.doris.common.UserException;
|
||||
import org.apache.doris.statistics.StatisticalType;
|
||||
import org.apache.doris.statistics.StatsRecursiveDerive;
|
||||
@ -43,21 +44,35 @@ public class AssertNumRowsNode extends PlanNode {
|
||||
private String subqueryString;
|
||||
private AssertNumRowsElement.Assertion assertion;
|
||||
|
||||
private boolean shouldConvertOutputToNullable = false;
|
||||
|
||||
public AssertNumRowsNode(PlanNodeId id, PlanNode input, AssertNumRowsElement assertNumRowsElement) {
|
||||
this(id, input, assertNumRowsElement, false, null);
|
||||
}
|
||||
|
||||
public AssertNumRowsNode(PlanNodeId id, PlanNode input, AssertNumRowsElement assertNumRowsElement,
|
||||
boolean convertToNullable, TupleDescriptor tupleDescriptor) {
|
||||
super(id, "ASSERT NUMBER OF ROWS", StatisticalType.ASSERT_NUM_ROWS_NODE);
|
||||
this.desiredNumOfRows = assertNumRowsElement.getDesiredNumOfRows();
|
||||
this.subqueryString = assertNumRowsElement.getSubqueryString();
|
||||
this.assertion = assertNumRowsElement.getAssertion();
|
||||
this.children.add(input);
|
||||
if (input.getOutputTupleDesc() != null) {
|
||||
this.tupleIds.add(input.getOutputTupleDesc().getId());
|
||||
if (tupleDescriptor != null) {
|
||||
this.tupleIds.add(tupleDescriptor.getId());
|
||||
} else {
|
||||
this.tupleIds.addAll(input.getTupleIds());
|
||||
if (input.getOutputTupleDesc() != null) {
|
||||
this.tupleIds.add(input.getOutputTupleDesc().getId());
|
||||
} else {
|
||||
this.tupleIds.addAll(input.getTupleIds());
|
||||
}
|
||||
}
|
||||
|
||||
this.tblRefIds.addAll(input.getTblRefIds());
|
||||
this.nullableTupleIds.addAll(input.getNullableTupleIds());
|
||||
this.shouldConvertOutputToNullable = convertToNullable;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void init(Analyzer analyzer) throws UserException {
|
||||
super.init(analyzer);
|
||||
@ -94,6 +109,7 @@ public class AssertNumRowsNode extends PlanNode {
|
||||
msg.assert_num_rows_node.setDesiredNumRows(desiredNumOfRows);
|
||||
msg.assert_num_rows_node.setSubqueryString(subqueryString);
|
||||
msg.assert_num_rows_node.setAssertion(assertion.toThrift());
|
||||
msg.assert_num_rows_node.setShouldConvertOutputToNullable(shouldConvertOutputToNullable);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
Reference in New Issue
Block a user