[fix](nereids) binding group by key on agg.output if output is slot (#15623)

case 1
`select count(1) from t1 join t2 on t1.a = t2.a group by a`
`group by a` is ambiguous

case 2
`select t1.a from t1 join t2 on t1.a = t2.a group by a`
`group by a` is bound on t1.a
This commit is contained in:
minghong
2023-01-12 16:34:56 +08:00
committed by GitHub
parent 0fbdf8e3e1
commit d23646793c
2 changed files with 93 additions and 2 deletions

View File

@ -74,6 +74,7 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.commons.lang.StringUtils;
import java.util.ArrayList;
@ -227,8 +228,20 @@ public class BindSlotReference implements AnalysisRuleFactory {
group by key cannot bind with agg func
plan:
agg(group_by v, output sum(k) as v)
throw AnalysisException
CASE 4
sql:
`select count(1) from t1 join t2 group by a`
we cannot bind `group by a`, because it is ambiguous (t1.a and t2.a)
CASE 5
following case 4, if t1.a is in agg.output, we can bind `group by a` to t1.a
sql
select t1.a
from t1 join t2 on t1.a = t2.a
group by a
group_by_key is bound on t1.a
*/
duplicatedSlotNames.stream().forEach(dup -> childOutputsToExpr.remove(dup));
Map<String, Expression> aliasNameToExpr = output.stream()
@ -261,8 +274,31 @@ public class BindSlotReference implements AnalysisRuleFactory {
}
return groupBy;
}).collect(Collectors.toList());
/*
according to case 4 and case 5, we construct boundSlots
*/
Set<String> outputSlotNames = Sets.newHashSet();
Set<Slot> outputSlots = output.stream()
.filter(SlotReference.class::isInstance)
.peek(slot -> outputSlotNames.add(slot.getName()))
.map(NamedExpression::toSlot).collect(
Collectors.toSet());
//suppose group by key is a.
// if both t1.a and t2.a are in agg.child.output, and t1.a in agg.output,
// bind group_by_key a with t1.a
// ` .filter(slot -> !outputSlotNames.contains(slot.getName()))`
// is used to avoid add t2.a into boundSlots
Set<Slot> boundSlots = agg.child().getOutputSet().stream()
.filter(slot -> !outputSlotNames.contains(slot.getName()))
.collect(Collectors.toSet());
boundSlots.addAll(outputSlots);
SlotBinder binder = new SlotBinder(toScope(Lists.newArrayList(boundSlots)), ctx.cascadesContext);
List<Expression> groupBy = replacedGroupBy.stream()
.map(expression -> binder.bind(expression))
.collect(Collectors.toList());
List<Expression> groupBy = bind(replacedGroupBy, agg.children(), agg, ctx.cascadesContext);
List<Expression> unboundGroupBys = Lists.newArrayList();
boolean hasUnbound = groupBy.stream().anyMatch(
expression -> {

View File

@ -19,17 +19,24 @@ package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@ -65,4 +72,52 @@ class BindSlotReferenceTest {
Assertions.assertTrue(exception.getMessage().contains("id#4"));
Assertions.assertTrue(exception.getMessage().contains("id#0"));
}
/*
select t1.id from student t1 join on student t2 on t1.di=t2.id group by id;
group_by_key bind on t1.id, not t2.id
*/
@Test
public void testGroupByOnJoin() {
LogicalOlapScan scan1 = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
LogicalSubQueryAlias sub1 = new LogicalSubQueryAlias("t1", scan1);
LogicalOlapScan scan2 = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
LogicalSubQueryAlias sub2 = new LogicalSubQueryAlias("t2", scan2);
LogicalJoin<LogicalSubQueryAlias<LogicalOlapScan>, LogicalSubQueryAlias<LogicalOlapScan>> join =
new LogicalJoin<>(JoinType.CROSS_JOIN, sub1, sub2);
LogicalAggregate<LogicalJoin> aggregate = new LogicalAggregate<>(
Lists.newArrayList(new UnboundSlot("id")), //group by
Lists.newArrayList(new UnboundSlot("t1", "id")), //output
join
);
PlanChecker checker = PlanChecker.from(MemoTestUtils.createConnectContext()).analyze(aggregate);
LogicalAggregate plan = (LogicalAggregate) checker.getCascadesContext().getMemo().copyOut();
SlotReference groupByKey = (SlotReference) plan.getGroupByExpressions().get(0);
SlotReference t1id = (SlotReference) ((LogicalJoin) plan.child()).left().getOutput().get(0);
SlotReference t2id = (SlotReference) ((LogicalJoin) plan.child()).right().getOutput().get(0);
Assertions.assertEquals(groupByKey.getExprId(), t1id.getExprId());
Assertions.assertNotEquals(t1id.getExprId(), t2id.getExprId());
}
/*
select count(1) from student t1 join on student t2 on t1.di=t2.id group by id;
group by key is ambiguous
*/
@Test
public void testGroupByOnJoinAmbiguous() {
LogicalOlapScan scan1 = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
LogicalSubQueryAlias sub1 = new LogicalSubQueryAlias("t1", scan1);
LogicalOlapScan scan2 = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
LogicalSubQueryAlias sub2 = new LogicalSubQueryAlias("t2", scan2);
LogicalJoin<LogicalSubQueryAlias<LogicalOlapScan>, LogicalSubQueryAlias<LogicalOlapScan>> join =
new LogicalJoin<>(JoinType.CROSS_JOIN, sub1, sub2);
LogicalAggregate<LogicalJoin> aggregate = new LogicalAggregate<>(
Lists.newArrayList(new UnboundSlot("id")), //group by
Lists.newArrayList(new Alias(new Count(new IntegerLiteral(1)), "count(1)")), //output
join
);
AnalysisException exception = Assertions.assertThrows(AnalysisException.class,
() -> PlanChecker.from(MemoTestUtils.createConnectContext()).analyze(aggregate));
Assertions.assertTrue(exception.getMessage().contains("id is ambiguous: "));
}
}