[fix](Nereids) update immutable LogicalAggregate attribute by mistake (#13740)

This commit is contained in:
jakevin
2022-10-31 14:11:55 +08:00
committed by GitHub
parent 2fb218173e
commit ceb7b60a64
7 changed files with 145 additions and 185 deletions

View File

@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.statistics.StatsDeriveResult;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
@ -43,7 +44,7 @@ public class GroupExpression {
private double cost = 0.0;
private CostEstimate costEstimate = null;
private Group ownerGroup;
private List<Group> children;
private ImmutableList<Group> children;
private final Plan plan;
private final BitSet ruleMasks;
private boolean statDerived;
@ -66,7 +67,7 @@ public class GroupExpression {
public GroupExpression(Plan plan, List<Group> children) {
this.plan = Objects.requireNonNull(plan, "plan can not be null")
.withGroupExpression(Optional.of(this));
this.children = Lists.newArrayList(Objects.requireNonNull(children, "children can not be null"));
this.children = ImmutableList.copyOf(Objects.requireNonNull(children, "children can not be null"));
this.ruleMasks = new BitSet(RuleType.SENTINEL.ordinal());
this.statDerived = false;
this.lowestCostTable = Maps.newHashMap();
@ -84,10 +85,6 @@ public class GroupExpression {
return children.size();
}
public void addChild(Group child) {
children.add(child);
}
public Group getOwnerGroup() {
return ownerGroup;
}
@ -108,12 +105,13 @@ public class GroupExpression {
return children;
}
public void setChildren(List<Group> children) {
public void setChildren(ImmutableList<Group> children) {
this.children = children;
}
/**
* replaceChild.
*
* @param originChild origin child group
* @param newChild new child group
*/

View File

@ -203,7 +203,7 @@ public class Memo {
/**
* add or replace the plan into the target group.
*
* <p>
* the result truth table:
* <pre>
* +---------------------------------------+-----------------------------------+--------------------------------+
@ -296,8 +296,7 @@ public class Memo {
}
}
plan = replaceChildrenToGroupPlan(plan, childrenGroups);
GroupExpression newGroupExpression = new GroupExpression(plan);
newGroupExpression.setChildren(childrenGroups);
GroupExpression newGroupExpression = new GroupExpression(plan, childrenGroups);
return insertGroupExpression(newGroupExpression, targetGroup, plan.getLogicalProperties());
// TODO: need to derive logical property if generate new group. currently we not copy logical plan into
}
@ -388,13 +387,15 @@ public class Memo {
}
for (GroupExpression groupExpression : needReplaceChild) {
groupExpressions.remove(groupExpression);
List<Group> children = groupExpression.children();
List<Group> children = new ArrayList<>(groupExpression.children());
// TODO: use a better way to replace child, avoid traversing all groupExpression
for (int i = 0; i < children.size(); i++) {
if (children.get(i).equals(source)) {
children.set(i, destination);
}
}
groupExpression.setChildren(ImmutableList.copyOf(children));
GroupExpression that = groupExpressions.get(groupExpression);
if (that != null && that.getOwnerGroup() != null
&& !that.getOwnerGroup().equals(groupExpression.getOwnerGroup())) {
@ -487,14 +488,14 @@ public class Memo {
/**
* eliminate fromGroup, clear targetGroup, then move the logical group expressions in the fromGroup to the toGroup.
*
* <p>
* the scenario is:
* ```
* Group 1(project, the targetGroup) Group 1(logicalOlapScan, the targetGroup)
* | =>
* Group 0(logicalOlapScan, the fromGroup)
* ```
*
* <p>
* we should recycle the group 0, and recycle all group expressions in group 1, then move the logicalOlapScan to
* the group 1, and reset logical properties of the group 1.
*/

View File

@ -34,6 +34,7 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -147,7 +148,7 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
// +-----------+---------------------+-------------------------+--------------------------------+
// NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: ExprId x
// 2. collect local aggregate output expressions and local aggregate group by expression list
List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
List<Expression> localGroupByExprs = new ArrayList<>(aggregate.getGroupByExpressions());
List<NamedExpression> localOutputExprs = Lists.newArrayList();
for (Expression originGroupByExpr : originGroupByExprs) {
if (inputSubstitutionMap.containsKey(originGroupByExpr)) {

View File

@ -119,8 +119,8 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
Optional<LogicalProperties> logicalProperties,
CHILD_TYPE child) {
super(PlanType.LOGICAL_AGGREGATE, groupExpression, logicalProperties, child);
this.groupByExpressions = groupByExpressions;
this.outputExpressions = outputExpressions;
this.groupByExpressions = ImmutableList.copyOf(groupByExpressions);
this.outputExpressions = ImmutableList.copyOf(outputExpressions);
this.partitionExpressions = partitionExpressions;
this.disassembled = disassembled;
this.normalized = normalized;

View File

@ -71,7 +71,7 @@ class MemoTest implements PatternMatchSupported {
JoinType.INNER_JOIN, logicalJoinAB, PlanConstructor.newLogicalOlapScan(2, "C", 0));
@Test
void mergeGroup() throws Exception {
void mergeGroup() {
Memo memo = new Memo();
GroupId gid2 = new GroupId(2);
Group srcGroup = new Group(gid2, new GroupExpression(new FakePlan()), new LogicalProperties(ArrayList::new));
@ -85,13 +85,13 @@ class MemoTest implements PatternMatchSupported {
GroupExpression ge2 = new GroupExpression(d, Arrays.asList(dstGroup));
GroupId gid1 = new GroupId(1);
Group g2 = new Group(gid1, ge2, new LogicalProperties(ArrayList::new));
Map<GroupId, Group> groups = (Map<GroupId, Group>) Deencapsulation.getField(memo, "groups");
Map<GroupId, Group> groups = Deencapsulation.getField(memo, "groups");
groups.put(gid2, srcGroup);
groups.put(gid3, dstGroup);
groups.put(gid0, g1);
groups.put(gid1, g2);
Map<GroupExpression, GroupExpression> groupExpressions =
(Map<GroupExpression, GroupExpression>) Deencapsulation.getField(memo, "groupExpressions");
Deencapsulation.getField(memo, "groupExpressions");
groupExpressions.put(ge1, ge1);
groupExpressions.put(ge2, ge2);
memo.mergeGroup(srcGroup, dstGroup);

View File

@ -22,7 +22,6 @@ import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
@ -31,14 +30,13 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnary;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.PlanRewriter;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
@ -46,15 +44,17 @@ import org.junit.jupiter.api.TestInstance;
import java.util.List;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class AggregateDisassembleTest {
public class AggregateDisassembleTest implements PatternMatchSupported {
private Plan rStudent;
@BeforeAll
public final void beforeAll() {
rStudent = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, ImmutableList.of(""));
rStudent = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student,
ImmutableList.of(""));
}
/**
* <pre>
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [age, SUM(id) as sum], groupByExpr: [age])
* +--childPlan(id, name, age)
@ -62,6 +62,7 @@ public class AggregateDisassembleTest {
* Aggregate(phase: [GLOBAL], outputExpr: [a, SUM(b) as c], groupByExpr: [a])
* +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], groupByExpr: [age])
* +--childPlan(id, name, age)
* </pre>
*/
@Test
public void slotReferenceGroupBy() {
@ -70,50 +71,43 @@ public class AggregateDisassembleTest {
List<NamedExpression> outputExpressionList = Lists.newArrayList(
rStudent.getOutput().get(2).toSlot(),
new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum"));
Plan root = new LogicalAggregate(groupExpressionList, outputExpressionList, rStudent);
Plan after = rewrite(root);
Assertions.assertTrue(after instanceof LogicalUnary);
Assertions.assertTrue(after instanceof LogicalAggregate);
Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
LogicalAggregate<Plan> global = (LogicalAggregate) after;
LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0);
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
Expression localOutput0 = rStudent.getOutput().get(2).toSlot();
Expression localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot());
Expression localGroupBy = rStudent.getOutput().get(2).toSlot();
Assertions.assertEquals(2, local.getOutputExpressions().size());
Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof SlotReference);
Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0));
Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof Alias);
Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1).child(0));
Assertions.assertEquals(1, local.getGroupByExpressions().size());
Assertions.assertEquals(localGroupBy, local.getGroupByExpressions().get(0));
Expression globalOutput0 = local.getOutputExpressions().get(0).toSlot();
Expression globalOutput1 = new Sum(local.getOutputExpressions().get(1).toSlot());
Expression globalGroupBy = local.getOutputExpressions().get(0).toSlot();
Assertions.assertEquals(2, global.getOutputExpressions().size());
Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof SlotReference);
Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0));
Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof Alias);
Assertions.assertEquals(globalOutput1, global.getOutputExpressions().get(1).child(0));
Assertions.assertEquals(1, global.getGroupByExpressions().size());
Assertions.assertEquals(globalGroupBy, global.getGroupByExpressions().get(0));
// check id:
Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
global.getOutputExpressions().get(0).getExprId());
Assertions.assertEquals(outputExpressionList.get(1).getExprId(),
global.getOutputExpressions().get(1).getExprId());
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyTopDown(new AggregateDisassemble())
.printlnTree()
.matchesFromRoot(
logicalAggregate(
logicalAggregate()
.when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL))
.when(agg -> agg.getOutputExpressions().size() == 2)
.when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0))
.when(agg -> agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
.when(agg -> agg.getGroupByExpressions().size() == 1)
.when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy))
).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL))
.when(agg -> agg.getOutputExpressions().size() == 2)
.when(agg -> agg.getOutputExpressions().get(0)
.equals(agg.child().getOutputExpressions().get(0).toSlot()))
.when(agg -> agg.getOutputExpressions().get(1).child(0)
.equals(new Sum(agg.child().getOutputExpressions().get(1).toSlot())))
.when(agg -> agg.getGroupByExpressions().size() == 1)
.when(agg -> agg.getGroupByExpressions().get(0)
.equals(agg.child().getOutputExpressions().get(0).toSlot()))
// check id:
.when(agg -> agg.getOutputExpressions().get(0).getExprId()
.equals(outputExpressionList.get(0).getExprId()))
.when(agg -> agg.getOutputExpressions().get(1).getExprId()
.equals(outputExpressionList.get(1).getExprId()))
);
}
/**
* <pre>
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: [])
* +--childPlan(id, name, age)
@ -121,44 +115,41 @@ public class AggregateDisassembleTest {
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as b], groupByExpr: [])
* +--Aggregate(phase: [LOCAL], outputExpr: [SUM(id) as a], groupByExpr: [])
* +--childPlan(id, name, age)
* </pre>
*/
@Test
public void globalAggregate() {
List<Expression> groupExpressionList = Lists.newArrayList();
List<NamedExpression> outputExpressionList = Lists.newArrayList(
new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum"));
Plan root = new LogicalAggregate(groupExpressionList, outputExpressionList, rStudent);
Plan after = rewrite(root);
Assertions.assertTrue(after instanceof LogicalUnary);
Assertions.assertTrue(after instanceof LogicalAggregate);
Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
LogicalAggregate<Plan> global = (LogicalAggregate) after;
LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0);
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
new Alias(new Sum(rStudent.getOutput().get(0)), "sum"));
Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
Expression localOutput0 = new Sum(rStudent.getOutput().get(0).toSlot());
Assertions.assertEquals(1, local.getOutputExpressions().size());
Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof Alias);
Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0).child(0));
Assertions.assertEquals(0, local.getGroupByExpressions().size());
Expression globalOutput0 = new Sum(local.getOutputExpressions().get(0).toSlot());
Assertions.assertEquals(1, global.getOutputExpressions().size());
Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof Alias);
Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0).child(0));
Assertions.assertEquals(0, global.getGroupByExpressions().size());
// check id:
Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
global.getOutputExpressions().get(0).getExprId());
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyTopDown(new AggregateDisassemble())
.printlnTree()
.matchesFromRoot(
logicalAggregate(
logicalAggregate()
.when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL))
.when(agg -> agg.getOutputExpressions().size() == 1)
.when(agg -> agg.getOutputExpressions().get(0).child(0).equals(localOutput0))
.when(agg -> agg.getGroupByExpressions().size() == 0)
).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL))
.when(agg -> agg.getOutputExpressions().size() == 1)
.when(agg -> agg.getOutputExpressions().get(0) instanceof Alias)
.when(agg -> agg.getOutputExpressions().get(0).child(0)
.equals(new Sum(agg.child().getOutputExpressions().get(0).toSlot())))
.when(agg -> agg.getGroupByExpressions().size() == 0)
// check id:
.when(agg -> agg.getOutputExpressions().get(0).getExprId()
.equals(outputExpressionList.get(0).getExprId()))
);
}
/**
* <pre>
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: [age])
* +--childPlan(id, name, age)
@ -166,6 +157,7 @@ public class AggregateDisassembleTest {
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as c], groupByExpr: [a])
* +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], groupByExpr: [age])
* +--childPlan(id, name, age)
* </pre>
*/
@Test
public void groupExpressionNotInOutput() {
@ -173,45 +165,40 @@ public class AggregateDisassembleTest {
rStudent.getOutput().get(2).toSlot());
List<NamedExpression> outputExpressionList = Lists.newArrayList(
new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum"));
Plan root = new LogicalAggregate(groupExpressionList, outputExpressionList, rStudent);
Plan after = rewrite(root);
Assertions.assertTrue(after instanceof LogicalUnary);
Assertions.assertTrue(after instanceof LogicalAggregate);
Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
LogicalAggregate<Plan> global = (LogicalAggregate) after;
LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0);
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
Expression localOutput0 = rStudent.getOutput().get(2).toSlot();
Expression localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot());
Expression localGroupBy = rStudent.getOutput().get(2).toSlot();
Assertions.assertEquals(2, local.getOutputExpressions().size());
Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof SlotReference);
Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0));
Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof Alias);
Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1).child(0));
Assertions.assertEquals(1, local.getGroupByExpressions().size());
Assertions.assertEquals(localGroupBy, local.getGroupByExpressions().get(0));
Expression globalOutput0 = new Sum(local.getOutputExpressions().get(1).toSlot());
Expression globalGroupBy = local.getOutputExpressions().get(0).toSlot();
Assertions.assertEquals(1, global.getOutputExpressions().size());
Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof Alias);
Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0).child(0));
Assertions.assertEquals(1, global.getGroupByExpressions().size());
Assertions.assertEquals(globalGroupBy, global.getGroupByExpressions().get(0));
// check id:
Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
global.getOutputExpressions().get(0).getExprId());
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyTopDown(new AggregateDisassemble())
.printlnTree()
.matchesFromRoot(
logicalAggregate(
logicalAggregate()
.when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL))
.when(agg -> agg.getOutputExpressions().size() == 2)
.when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0))
.when(agg -> agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
.when(agg -> agg.getGroupByExpressions().size() == 1)
.when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy))
).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL))
.when(agg -> agg.getOutputExpressions().size() == 1)
.when(agg -> agg.getOutputExpressions().get(0) instanceof Alias)
.when(agg -> agg.getOutputExpressions().get(0).child(0)
.equals(new Sum(agg.child().getOutputExpressions().get(1).toSlot())))
.when(agg -> agg.getGroupByExpressions().size() == 1)
.when(agg -> agg.getGroupByExpressions().get(0)
.equals(agg.child().getOutputExpressions().get(0).toSlot()))
// check id:
.when(agg -> agg.getOutputExpressions().get(0).getExprId()
.equals(outputExpressionList.get(0).getExprId()))
);
}
/**
* <pre>
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age) + 2) as c], groupByExpr: [id])
* +-- childPlan(id, name, age)
@ -220,6 +207,7 @@ public class AggregateDisassembleTest {
* +-- Aggregate(phase: [GLOBAL], outputExpr: [id, age], groupByExpr: [id, age])
* +-- Aggregate(phase: [LOCAL], outputExpr: [id, age], groupByExpr: [id, age])
* +-- childPlan(id, name, age)
* </pre>
*/
@Test
public void distinctAggregateWithGroupBy() {
@ -229,68 +217,44 @@ public class AggregateDisassembleTest {
new IntegerLiteral(2)), "c"));
Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
Plan after = rewrite(root);
Assertions.assertTrue(after instanceof LogicalUnary);
Assertions.assertTrue(after instanceof LogicalAggregate);
Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
LogicalAggregate<Plan> distinctLocal = (LogicalAggregate) after;
LogicalAggregate<Plan> global = (LogicalAggregate) after.child(0);
LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0).child(0);
Assertions.assertEquals(AggPhase.DISTINCT_LOCAL, distinctLocal.getAggPhase());
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
// check local:
// id
Expression localOutput0 = rStudent.getOutput().get(0).toSlot();
Expression localOutput0 = rStudent.getOutput().get(0);
// age
Expression localOutput1 = rStudent.getOutput().get(2).toSlot();
Expression localOutput1 = rStudent.getOutput().get(2);
// id
Expression localGroupBy0 = rStudent.getOutput().get(0).toSlot();
Expression localGroupBy0 = rStudent.getOutput().get(0);
// age
Expression localGroupBy1 = rStudent.getOutput().get(2).toSlot();
Expression localGroupBy1 = rStudent.getOutput().get(2);
Assertions.assertEquals(2, local.getOutputExpressions().size());
Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof SlotReference);
Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0));
Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof SlotReference);
Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1));
Assertions.assertEquals(2, local.getGroupByExpressions().size());
Assertions.assertEquals(localGroupBy0, local.getGroupByExpressions().get(0));
Assertions.assertEquals(localGroupBy1, local.getGroupByExpressions().get(1));
// check global:
Expression globalOutput0 = local.getOutputExpressions().get(0).toSlot();
Expression globalOutput1 = local.getOutputExpressions().get(1).toSlot();
Expression globalGroupBy0 = local.getOutputExpressions().get(0).toSlot();
Expression globalGroupBy1 = local.getOutputExpressions().get(1).toSlot();
Assertions.assertEquals(2, global.getOutputExpressions().size());
Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof SlotReference);
Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0));
Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof SlotReference);
Assertions.assertEquals(globalOutput1, global.getOutputExpressions().get(1));
Assertions.assertEquals(2, global.getGroupByExpressions().size());
Assertions.assertEquals(globalGroupBy0, global.getGroupByExpressions().get(0));
Assertions.assertEquals(globalGroupBy1, global.getGroupByExpressions().get(1));
// check distinct local:
Expression distinctLocalOutput = new Add(new Count(local.getOutputExpressions().get(1).toSlot(), true),
new IntegerLiteral(2));
Expression distinctLocalGroupBy = local.getOutputExpressions().get(0).toSlot();
Assertions.assertEquals(1, distinctLocal.getOutputExpressions().size());
Assertions.assertTrue(distinctLocal.getOutputExpressions().get(0) instanceof Alias);
Assertions.assertEquals(distinctLocalOutput, distinctLocal.getOutputExpressions().get(0).child(0));
Assertions.assertEquals(1, distinctLocal.getGroupByExpressions().size());
Assertions.assertEquals(distinctLocalGroupBy, distinctLocal.getGroupByExpressions().get(0));
// check id:
Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
distinctLocal.getOutputExpressions().get(0).getExprId());
}
private Plan rewrite(Plan input) {
return PlanRewriter.topDownRewrite(input, new ConnectContext(), new AggregateDisassemble());
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyTopDown(new AggregateDisassemble())
.matchesFromRoot(
logicalAggregate(
logicalAggregate(
logicalAggregate()
.when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL))
.when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0))
.when(agg -> agg.getOutputExpressions().get(1).equals(localOutput1))
.when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy0))
.when(agg -> agg.getGroupByExpressions().get(1).equals(localGroupBy1))
).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL))
.when(agg -> agg.getOutputExpressions().get(0)
.equals(agg.child().getOutputExpressions().get(0)))
.when(agg -> agg.getOutputExpressions().get(1)
.equals(agg.child().getOutputExpressions().get(1)))
.when(agg -> agg.getGroupByExpressions().get(0)
.equals(agg.child().getOutputExpressions().get(0)))
.when(agg -> agg.getGroupByExpressions().get(1)
.equals(agg.child().getOutputExpressions().get(1)))
).when(agg -> agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL))
.when(agg -> agg.getOutputExpressions().size() == 1)
.when(agg -> agg.getOutputExpressions().get(0) instanceof Alias)
.when(agg -> agg.getOutputExpressions().get(0).child(0) instanceof Add)
.when(agg -> agg.getGroupByExpressions().get(0)
.equals(agg.child().child().getOutputExpressions().get(0)))
.when(agg -> agg.getOutputExpressions().get(0).getExprId() == outputExpressionList.get(
0).getExprId())
);
}
}

View File

@ -130,16 +130,14 @@ public class StatsCalculatorTest {
childGroup.setStatistics(childStats);
LogicalFilter<GroupPlan> logicalFilter = new LogicalFilter<>(and, groupPlan);
GroupExpression groupExpression = new GroupExpression(logicalFilter);
groupExpression.addChild(childGroup);
GroupExpression groupExpression = new GroupExpression(logicalFilter, ImmutableList.of(childGroup));
Group ownerGroup = new Group();
groupExpression.setOwnerGroup(ownerGroup);
StatsCalculator.estimate(groupExpression);
Assertions.assertEquals((long) (10000 * 0.1 * 0.05), ownerGroup.getStatistics().getRowCount(), 0.001);
LogicalFilter<GroupPlan> logicalFilterOr = new LogicalFilter<>(or, groupPlan);
GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr);
groupExpressionOr.addChild(childGroup);
GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr, ImmutableList.of(childGroup));
Group ownerGroupOr = new Group();
groupExpressionOr.setOwnerGroup(ownerGroupOr);
StatsCalculator.estimate(groupExpressionOr);
@ -243,8 +241,7 @@ public class StatsCalculatorTest {
childGroup.setStatistics(childStats);
LogicalLimit<GroupPlan> logicalLimit = new LogicalLimit<>(1, 2, groupPlan);
GroupExpression groupExpression = new GroupExpression(logicalLimit);
groupExpression.addChild(childGroup);
GroupExpression groupExpression = new GroupExpression(logicalLimit, ImmutableList.of(childGroup));
Group ownerGroup = new Group();
ownerGroup.addGroupExpression(groupExpression);
StatsCalculator.estimate(groupExpression);
@ -274,8 +271,7 @@ public class StatsCalculatorTest {
childGroup.setStatistics(childStats);
LogicalTopN<GroupPlan> logicalTopN = new LogicalTopN<>(Collections.emptyList(), 1, 2, groupPlan);
GroupExpression groupExpression = new GroupExpression(logicalTopN);
groupExpression.addChild(childGroup);
GroupExpression groupExpression = new GroupExpression(logicalTopN, ImmutableList.of(childGroup));
Group ownerGroup = new Group();
ownerGroup.addGroupExpression(groupExpression);
StatsCalculator.estimate(groupExpression);