[fix](nereids) binding group by key on agg.output if output is slot (#15623)
case 1 `select count(1) from t1 join t2 on t1.a = t2.a group by a` `group by a` is ambiguous case 2 `select t1.a from t1 join t2 on t1.a = t2.a group by a` `group by a` is bound on t1.a
This commit is contained in:
@ -74,6 +74,7 @@ import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import org.apache.commons.lang.StringUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@ -227,8 +228,20 @@ public class BindSlotReference implements AnalysisRuleFactory {
|
||||
group by key cannot bind with agg func
|
||||
plan:
|
||||
agg(group_by v, output sum(k) as v)
|
||||
|
||||
throw AnalysisException
|
||||
|
||||
CASE 4
|
||||
sql:
|
||||
`select count(1) from t1 join t2 group by a`
|
||||
we cannot bind `group by a`, because it is ambiguous (t1.a and t2.a)
|
||||
|
||||
CASE 5
|
||||
following case 4, if t1.a is in agg.output, we can bind `group by a` to t1.a
|
||||
sql
|
||||
select t1.a
|
||||
from t1 join t2 on t1.a = t2.a
|
||||
group by a
|
||||
group_by_key is bound on t1.a
|
||||
*/
|
||||
duplicatedSlotNames.stream().forEach(dup -> childOutputsToExpr.remove(dup));
|
||||
Map<String, Expression> aliasNameToExpr = output.stream()
|
||||
@ -261,8 +274,31 @@ public class BindSlotReference implements AnalysisRuleFactory {
|
||||
}
|
||||
return groupBy;
|
||||
}).collect(Collectors.toList());
|
||||
/*
|
||||
according to case 4 and case 5, we construct boundSlots
|
||||
*/
|
||||
Set<String> outputSlotNames = Sets.newHashSet();
|
||||
Set<Slot> outputSlots = output.stream()
|
||||
.filter(SlotReference.class::isInstance)
|
||||
.peek(slot -> outputSlotNames.add(slot.getName()))
|
||||
.map(NamedExpression::toSlot).collect(
|
||||
Collectors.toSet());
|
||||
//suppose group by key is a.
|
||||
// if both t1.a and t2.a are in agg.child.output, and t1.a in agg.output,
|
||||
// bind group_by_key a with t1.a
|
||||
// ` .filter(slot -> !outputSlotNames.contains(slot.getName()))`
|
||||
// is used to avoid add t2.a into boundSlots
|
||||
Set<Slot> boundSlots = agg.child().getOutputSet().stream()
|
||||
.filter(slot -> !outputSlotNames.contains(slot.getName()))
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
boundSlots.addAll(outputSlots);
|
||||
SlotBinder binder = new SlotBinder(toScope(Lists.newArrayList(boundSlots)), ctx.cascadesContext);
|
||||
|
||||
List<Expression> groupBy = replacedGroupBy.stream()
|
||||
.map(expression -> binder.bind(expression))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<Expression> groupBy = bind(replacedGroupBy, agg.children(), agg, ctx.cascadesContext);
|
||||
List<Expression> unboundGroupBys = Lists.newArrayList();
|
||||
boolean hasUnbound = groupBy.stream().anyMatch(
|
||||
expression -> {
|
||||
|
||||
@ -19,17 +19,24 @@ package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.nereids.analyzer.UnboundSlot;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
|
||||
import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@ -65,4 +72,52 @@ class BindSlotReferenceTest {
|
||||
Assertions.assertTrue(exception.getMessage().contains("id#4"));
|
||||
Assertions.assertTrue(exception.getMessage().contains("id#0"));
|
||||
}
|
||||
|
||||
/*
|
||||
select t1.id from student t1 join on student t2 on t1.di=t2.id group by id;
|
||||
group_by_key bind on t1.id, not t2.id
|
||||
*/
|
||||
@Test
|
||||
public void testGroupByOnJoin() {
|
||||
LogicalOlapScan scan1 = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
|
||||
LogicalSubQueryAlias sub1 = new LogicalSubQueryAlias("t1", scan1);
|
||||
LogicalOlapScan scan2 = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
|
||||
LogicalSubQueryAlias sub2 = new LogicalSubQueryAlias("t2", scan2);
|
||||
LogicalJoin<LogicalSubQueryAlias<LogicalOlapScan>, LogicalSubQueryAlias<LogicalOlapScan>> join =
|
||||
new LogicalJoin<>(JoinType.CROSS_JOIN, sub1, sub2);
|
||||
LogicalAggregate<LogicalJoin> aggregate = new LogicalAggregate<>(
|
||||
Lists.newArrayList(new UnboundSlot("id")), //group by
|
||||
Lists.newArrayList(new UnboundSlot("t1", "id")), //output
|
||||
join
|
||||
);
|
||||
PlanChecker checker = PlanChecker.from(MemoTestUtils.createConnectContext()).analyze(aggregate);
|
||||
LogicalAggregate plan = (LogicalAggregate) checker.getCascadesContext().getMemo().copyOut();
|
||||
SlotReference groupByKey = (SlotReference) plan.getGroupByExpressions().get(0);
|
||||
SlotReference t1id = (SlotReference) ((LogicalJoin) plan.child()).left().getOutput().get(0);
|
||||
SlotReference t2id = (SlotReference) ((LogicalJoin) plan.child()).right().getOutput().get(0);
|
||||
Assertions.assertEquals(groupByKey.getExprId(), t1id.getExprId());
|
||||
Assertions.assertNotEquals(t1id.getExprId(), t2id.getExprId());
|
||||
}
|
||||
|
||||
/*
|
||||
select count(1) from student t1 join on student t2 on t1.di=t2.id group by id;
|
||||
group by key is ambiguous
|
||||
*/
|
||||
@Test
|
||||
public void testGroupByOnJoinAmbiguous() {
|
||||
LogicalOlapScan scan1 = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
|
||||
LogicalSubQueryAlias sub1 = new LogicalSubQueryAlias("t1", scan1);
|
||||
LogicalOlapScan scan2 = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
|
||||
LogicalSubQueryAlias sub2 = new LogicalSubQueryAlias("t2", scan2);
|
||||
LogicalJoin<LogicalSubQueryAlias<LogicalOlapScan>, LogicalSubQueryAlias<LogicalOlapScan>> join =
|
||||
new LogicalJoin<>(JoinType.CROSS_JOIN, sub1, sub2);
|
||||
LogicalAggregate<LogicalJoin> aggregate = new LogicalAggregate<>(
|
||||
Lists.newArrayList(new UnboundSlot("id")), //group by
|
||||
Lists.newArrayList(new Alias(new Count(new IntegerLiteral(1)), "count(1)")), //output
|
||||
join
|
||||
);
|
||||
AnalysisException exception = Assertions.assertThrows(AnalysisException.class,
|
||||
() -> PlanChecker.from(MemoTestUtils.createConnectContext()).analyze(aggregate));
|
||||
Assertions.assertTrue(exception.getMessage().contains("id is ambiguous: "));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user