[Fix](Nereids)fix group by binding error, resulting in incorrect results (#15328)

Original: group by is bound to the outputExpression of the current node.

Problem: When the name of the new reference of outputExpression is the same as the child's output column, the child's output column should be used for group by, but at this time, the new reference of the node's outputExpression will be used for group by, resulting in an error

Now: Give priority to the child's output for group by binding. If the child does not have a corresponding column, use the outputExpression of this node for binding
This commit is contained in:
zhengshiJ
2022-12-28 10:42:21 +08:00
committed by GitHub
parent 9f9651b2f2
commit 2af831de33
6 changed files with 133 additions and 6 deletions

View File

@ -167,18 +167,26 @@ public class BindSlotReference implements AnalysisRuleFactory {
LogicalAggregate<GroupPlan> agg = ctx.root;
List<NamedExpression> output =
bind(agg.getOutputExpressions(), agg.children(), agg, ctx.cascadesContext);
// The columns referenced in group by are first obtained from the child's output,
// and then from the node's output
Map<String, Expression> childOutputsToExpr = agg.child().getOutput().stream()
.collect(Collectors.toMap(Slot::getName, Slot::toSlot, (oldExpr, newExpr) -> oldExpr));
Map<String, Expression> aliasNameToExpr = output.stream()
.filter(ne -> ne instanceof Alias)
.map(Alias.class::cast)
.collect(Collectors.toMap(Alias::getName, UnaryNode::child, (oldExpr, newExpr) -> oldExpr));
aliasNameToExpr.entrySet().stream()
.forEach(e -> childOutputsToExpr.putIfAbsent(e.getKey(), e.getValue()));
List<Expression> replacedGroupBy = agg.getGroupByExpressions().stream()
.map(groupBy -> {
if (groupBy instanceof UnboundSlot) {
UnboundSlot unboundSlot = (UnboundSlot) groupBy;
if (unboundSlot.getNameParts().size() == 1) {
String name = unboundSlot.getNameParts().get(0);
if (aliasNameToExpr.containsKey(name)) {
return aliasNameToExpr.get(name);
if (childOutputsToExpr.containsKey(name)) {
return childOutputsToExpr.get(name);
}
}
}
@ -197,10 +205,17 @@ public class BindSlotReference implements AnalysisRuleFactory {
List<NamedExpression> output =
bind(repeat.getOutputExpressions(), repeat.children(), repeat, ctx.cascadesContext);
// The columns referenced in group by are first obtained from the child's output,
// and then from the node's output
Map<String, Expression> childOutputsToExpr = repeat.child().getOutput().stream()
.collect(Collectors.toMap(Slot::getName, Slot::toSlot, (oldExpr, newExpr) -> oldExpr));
Map<String, Expression> aliasNameToExpr = output.stream()
.filter(ne -> ne instanceof Alias)
.map(Alias.class::cast)
.collect(Collectors.toMap(Alias::getName, UnaryNode::child, (oldExpr, newExpr) -> oldExpr));
aliasNameToExpr.entrySet().stream()
.forEach(e -> childOutputsToExpr.putIfAbsent(e.getKey(), e.getValue()));
List<List<Expression>> replacedGroupingSets = repeat.getGroupingSets().stream()
.map(groupBy ->
groupBy.stream().map(expr -> {
@ -208,8 +223,8 @@ public class BindSlotReference implements AnalysisRuleFactory {
UnboundSlot unboundSlot = (UnboundSlot) expr;
if (unboundSlot.getNameParts().size() == 1) {
String name = unboundSlot.getNameParts().get(0);
if (aliasNameToExpr.containsKey(name)) {
return aliasNameToExpr.get(name);
if (childOutputsToExpr.containsKey(name)) {
return childOutputsToExpr.get(name);
}
}
}

View File

@ -200,7 +200,7 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
}
private Plan pushDownProject(Set<NamedExpression> pushedExprs, Plan originBottomPlan) {
if (!pushedExprs.equals(originBottomPlan.getOutputSet())) {
if (!pushedExprs.equals(originBottomPlan.getOutputSet()) && !pushedExprs.isEmpty()) {
return new LogicalProject<>(ImmutableList.copyOf(pushedExprs), originBottomPlan);
}
return originBottomPlan;

View File

@ -93,7 +93,8 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
aggregateFunctionToSlotContext.pushDownToNamedExpression(normalizedAggregateFunctions);
List<Slot> normalizedGroupBy =
(List) groupByAndArgumentToSlotContext.normalizeToUseSlotRef(aggregate.getGroupByExpressions());
(List) groupByAndArgumentToSlotContext
.normalizeToUseSlotRef(aggregate.getGroupByExpressions());
// we can safely add all groupBy and aggregate functions to output, because we will
// add a project on it, and the upper project can protect the scope of visible of slot

View File

@ -183,4 +183,53 @@ public class GroupingSetsTest extends TestWithFeService {
PlanChecker.from(connectContext)
.checkPlannerResult("select if(k1 = 1, 2, k1) k_if from t1");
}
@Test
public void test1() {
PlanChecker.from(connectContext)
.checkPlannerResult("select coalesce(col1, 'all') as col1, count(*) as cnt from"
+ " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());");
}
@Test
public void test1_1() {
PlanChecker.from(connectContext)
.checkPlannerResult("select coalesce(col1, 'all') as col2, count(*) as cnt from"
+ " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());");
}
@Test
public void test1_2() {
PlanChecker.from(connectContext)
.checkPlannerResult("select coalesce(col1, 'all') as col2, count(*) as cnt from"
+ " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),());");
}
@Test
public void test2() {
PlanChecker.from(connectContext)
.checkPlannerResult("select if(1 = null, 'all', 2) as col1, count(*) as cnt from"
+ " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());");
}
@Test
public void test2_1() {
PlanChecker.from(connectContext)
.checkPlannerResult("select if(col1 = null, 'all', 2) as col1, count(*) as cnt from"
+ " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());");
}
@Test
public void test2_2() {
PlanChecker.from(connectContext)
.checkPlannerResult("select if(col1 = null, 'all', 2) as col2, count(*) as cnt from"
+ " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());");
}
@Test
public void test2_3() {
PlanChecker.from(connectContext)
.checkPlannerResult("select if(col1 = null, 'all', 2) as col2, count(*) as cnt from"
+ " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),());");
}
}

View File

@ -224,3 +224,37 @@
3
4
-- !select1 --
a 1
all 1
all 2
-- !select2 --
a 1
all 1
all 2
-- !select3 --
\N 2
a 1
all 1
-- !select4 --
2 1
2 1
2 2
-- !select5 --
2 1
2 1
2 2
-- !select6 --
2 1
2 1
2 2
-- !select7 --
2 1
2 1
2 2

View File

@ -232,4 +232,32 @@ suite("test_nereids_grouping_sets") {
) T
) T2;
"""
order_qt_select1 """
select coalesce(col1, 'all') as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());
"""
order_qt_select2 """
select coalesce(col1, 'all') as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());
"""
order_qt_select3 """
select coalesce(col1, 'all') as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),());
"""
order_qt_select4 """
select if(1 = null, 'all', 2) as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());
"""
order_qt_select5 """
select if(col1 = null, 'all', 2) as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());
"""
order_qt_select6 """
select if(1 = null, 'all', 2) as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());
"""
order_qt_select7 """
select if(col1 = null, 'all', 2) as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());
"""
}