[fix](Nereids) could not run query with repeat node in cte (#26330)

ExpressionDeepCopier not process VirtualReference, so we generate inline
plan with mistake.
This commit is contained in:
morrySnow
2023-11-03 14:24:01 +08:00
committed by GitHub
parent 9243de1898
commit a89477e8b5
4 changed files with 50 additions and 12 deletions

View File

@ -26,9 +26,15 @@ import org.apache.doris.nereids.trees.expressions.ListQuery;
import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.algebra.Repeat.GroupingSetShapes;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import com.google.common.base.Function;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -75,6 +81,30 @@ public class ExpressionDeepCopier extends DefaultExpressionRewriter<DeepCopierCo
}
}
@Override
public Expression visitVirtualReference(VirtualSlotReference virtualSlotReference, DeepCopierContext context) {
Map<ExprId, ExprId> exprIdReplaceMap = context.exprIdReplaceMap;
ExprId newExprId;
if (exprIdReplaceMap.containsKey(virtualSlotReference.getExprId())) {
newExprId = exprIdReplaceMap.get(virtualSlotReference.getExprId());
} else {
newExprId = StatementScopeIdGenerator.newExprId();
}
// according to VirtualReference generating logic in Repeat.java
// generateVirtualGroupingIdSlot and generateVirtualSlotByFunction
Optional<GroupingScalarFunction> newOriginExpression = virtualSlotReference.getOriginExpression()
.map(func -> (GroupingScalarFunction) func.accept(this, context));
Function<GroupingSetShapes, List<Long>> newFunction = newOriginExpression
.<Function<GroupingSetShapes, List<Long>>>map(f -> f::computeVirtualSlotValue)
.orElseGet(() -> GroupingSetShapes::computeVirtualGroupingIdValue);
VirtualSlotReference newOne = new VirtualSlotReference(newExprId,
virtualSlotReference.getName(), virtualSlotReference.getDataType(),
virtualSlotReference.nullable(), virtualSlotReference.getQualifier(),
newOriginExpression, newFunction);
exprIdReplaceMap.put(virtualSlotReference.getExprId(), newOne.getExprId());
return newOne;
}
@Override
public Expression visitExistsSubquery(Exists exists, DeepCopierContext context) {
LogicalPlan logicalPlan = LogicalPlanDeepCopier.INSTANCE.deepCopy(exists.getQueryPlan(), context);

View File

@ -60,13 +60,13 @@ public interface Repeat<CHILD_PLAN extends Plan> extends Aggregate<CHILD_PLAN> {
static VirtualSlotReference generateVirtualGroupingIdSlot() {
return new VirtualSlotReference(COL_GROUPING_ID, BigIntType.INSTANCE, Optional.empty(),
shapes -> shapes.computeVirtualGroupingIdValue());
GroupingSetShapes::computeVirtualGroupingIdValue);
}
static VirtualSlotReference generateVirtualSlotByFunction(GroupingScalarFunction function) {
return new VirtualSlotReference(
generateVirtualSlotName(function), function.getDataType(), Optional.of(function),
shapes -> function.computeVirtualSlotValue(shapes));
function::computeVirtualSlotValue);
}
/**
@ -175,7 +175,7 @@ public interface Repeat<CHILD_PLAN extends Plan> extends Aggregate<CHILD_PLAN> {
if (index == null) {
throw new AnalysisException("Can not find grouping set expression in output: " + expression);
}
if (groupingSetsIndex.contains(index)) {
if (groupingSetIndex.contains(index)) {
throw new AnalysisException("expression duplicate in grouping set: " + expression);
}
groupingSetIndex.add(index);
@ -228,14 +228,6 @@ public interface Repeat<CHILD_PLAN extends Plan> extends Aggregate<CHILD_PLAN> {
this.shapes = ImmutableList.copyOf(shapes);
}
public GroupingSetShape getGroupingSetShape(int index) {
return shapes.get(index);
}
public Expression getExpression(int index) {
return flattenGroupingSetExpression.get(index);
}
// compute a long value that backend need to fill to the GROUPING_ID slot
public List<Long> computeVirtualGroupingIdValue() {
return shapes.stream()