[fix](Nereids) all slot in grouping sets in repeat node should be nullable (#15991)
according to be's code, all slot in grouping set should be nullable.
reference to be code (be3482e6d6/be/src/vec/exec/vrepeat_node.cpp (L113))
This commit is contained in:
@ -72,12 +72,12 @@ import org.apache.doris.planner.PlannerContext;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import org.apache.commons.lang.StringUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@ -923,7 +923,7 @@ public class BindSlotReference implements AnalysisRuleFactory {
|
||||
.filter(ExpressionTrait::nullable)
|
||||
.collect(Collectors.toSet());
|
||||
return projects.stream()
|
||||
.map(e -> e.accept(new RewriteNullableToTrue(childrenOutput), null))
|
||||
.map(e -> e.accept(RewriteNullableToTrue.INSTANCE, childrenOutput))
|
||||
.map(NamedExpression.class::cast)
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
@ -939,7 +939,7 @@ public class BindSlotReference implements AnalysisRuleFactory {
|
||||
.filter(ExpressionTrait::nullable)
|
||||
.collect(Collectors.toSet());
|
||||
return output.stream()
|
||||
.map(e -> e.accept(new RewriteNullableToTrue(childrenOutput), null))
|
||||
.map(e -> e.accept(RewriteNullableToTrue.INSTANCE, childrenOutput))
|
||||
.map(NamedExpression.class::cast)
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
@ -951,24 +951,21 @@ public class BindSlotReference implements AnalysisRuleFactory {
|
||||
List<List<Expression>> groupingSets,
|
||||
List<NamedExpression> output) {
|
||||
Set<Slot> groupingSetsSlots = groupingSets.stream()
|
||||
.flatMap(e -> e.stream())
|
||||
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
|
||||
.flatMap(Collection::stream)
|
||||
.map(Expression::getInputSlots)
|
||||
.flatMap(Set::stream)
|
||||
.collect(Collectors.toSet());
|
||||
return output.stream()
|
||||
.map(e -> e.accept(new RewriteNullableToTrue(groupingSetsSlots), null))
|
||||
.map(e -> e.accept(RewriteNullableToTrue.INSTANCE, groupingSetsSlots))
|
||||
.map(NamedExpression.class::cast)
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
|
||||
private static class RewriteNullableToTrue extends DefaultExpressionRewriter<PlannerContext> {
|
||||
private final Set<Slot> childrenOutput;
|
||||
|
||||
public RewriteNullableToTrue(Set<Slot> childrenOutput) {
|
||||
this.childrenOutput = ImmutableSet.copyOf(childrenOutput);
|
||||
}
|
||||
private static class RewriteNullableToTrue extends DefaultExpressionRewriter<Set<Slot>> {
|
||||
public static RewriteNullableToTrue INSTANCE = new RewriteNullableToTrue();
|
||||
|
||||
@Override
|
||||
public Expression visitSlotReference(SlotReference slotReference, PlannerContext context) {
|
||||
public Expression visitSlotReference(SlotReference slotReference, Set<Slot> childrenOutput) {
|
||||
if (childrenOutput.contains(slotReference)) {
|
||||
return slotReference.withNullable(true);
|
||||
}
|
||||
|
||||
@ -223,7 +223,6 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
|
||||
}
|
||||
|
||||
List<Expression> groupingSetExpressions = ExpressionUtils.flatExpressions(repeat.getGroupingSets());
|
||||
Set<Expression> commonGroupingSetExpressions = repeat.getCommonGroupingSetExpressions();
|
||||
|
||||
// nullable will be different from grouping set and output expressions,
|
||||
// so we can not use the slot in grouping set,but use the equivalent slot in output expressions.
|
||||
@ -236,9 +235,7 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
|
||||
expression = outputs.get(outputs.indexOf(expression));
|
||||
}
|
||||
if (groupingSetExpressions.contains(expression)) {
|
||||
boolean isCommonGroupingSetExpression = commonGroupingSetExpressions.contains(expression);
|
||||
pushDownTriplet = toGroupingSetExpressionPushDownTriplet(
|
||||
isCommonGroupingSetExpression, expression, existsAliasMap.get(expression));
|
||||
pushDownTriplet = toGroupingSetExpressionPushDownTriplet(expression, existsAliasMap.get(expression));
|
||||
} else {
|
||||
pushDownTriplet = Optional.of(
|
||||
NormalizeToSlotTriplet.toTriplet(expression, existsAliasMap.get(expression)));
|
||||
@ -252,10 +249,10 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
|
||||
}
|
||||
|
||||
private Optional<NormalizeToSlotTriplet> toGroupingSetExpressionPushDownTriplet(
|
||||
boolean isCommonGroupingSetExpression, Expression expression, @Nullable Alias existsAlias) {
|
||||
Expression expression, @Nullable Alias existsAlias) {
|
||||
NormalizeToSlotTriplet originTriplet = NormalizeToSlotTriplet.toTriplet(expression, existsAlias);
|
||||
SlotReference remainSlot = (SlotReference) originTriplet.remainExpr;
|
||||
Slot newSlot = remainSlot.withCommonGroupingSetExpression(isCommonGroupingSetExpression);
|
||||
Slot newSlot = remainSlot.withNullable(true);
|
||||
return Optional.of(new NormalizeToSlotTriplet(expression, newSlot, originTriplet.pushedExpr));
|
||||
}
|
||||
|
||||
|
||||
@ -186,12 +186,4 @@ public class SlotReference extends Slot {
|
||||
public Slot withName(String name) {
|
||||
return new SlotReference(exprId, name, dataType, nullable, qualifier, column);
|
||||
}
|
||||
|
||||
/** withCommonGroupingSetExpression */
|
||||
public Slot withCommonGroupingSetExpression(boolean isCommonGroupingSetExpression) {
|
||||
if (!isCommonGroupingSetExpression) {
|
||||
return withNullable(true);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
@ -76,7 +76,6 @@ public interface Repeat<CHILD_PLAN extends Plan> extends Aggregate<CHILD_PLAN> {
|
||||
*/
|
||||
default Set<Expression> getCommonGroupingSetExpressions() {
|
||||
List<List<Expression>> groupingSets = getGroupingSets();
|
||||
Sets.newLinkedHashSet();
|
||||
Iterator<List<Expression>> iterator = groupingSets.iterator();
|
||||
Set<Expression> commonGroupingExpressions = Sets.newLinkedHashSet(iterator.next());
|
||||
while (iterator.hasNext()) {
|
||||
|
||||
Reference in New Issue
Block a user