[fix](Nereids) update immutable LogicalAggregate attribute by mistake (#13740)
This commit is contained in:
@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.statistics.StatsDeriveResult;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
|
||||
@ -43,7 +44,7 @@ public class GroupExpression {
|
||||
private double cost = 0.0;
|
||||
private CostEstimate costEstimate = null;
|
||||
private Group ownerGroup;
|
||||
private List<Group> children;
|
||||
private ImmutableList<Group> children;
|
||||
private final Plan plan;
|
||||
private final BitSet ruleMasks;
|
||||
private boolean statDerived;
|
||||
@ -66,7 +67,7 @@ public class GroupExpression {
|
||||
public GroupExpression(Plan plan, List<Group> children) {
|
||||
this.plan = Objects.requireNonNull(plan, "plan can not be null")
|
||||
.withGroupExpression(Optional.of(this));
|
||||
this.children = Lists.newArrayList(Objects.requireNonNull(children, "children can not be null"));
|
||||
this.children = ImmutableList.copyOf(Objects.requireNonNull(children, "children can not be null"));
|
||||
this.ruleMasks = new BitSet(RuleType.SENTINEL.ordinal());
|
||||
this.statDerived = false;
|
||||
this.lowestCostTable = Maps.newHashMap();
|
||||
@ -84,10 +85,6 @@ public class GroupExpression {
|
||||
return children.size();
|
||||
}
|
||||
|
||||
public void addChild(Group child) {
|
||||
children.add(child);
|
||||
}
|
||||
|
||||
public Group getOwnerGroup() {
|
||||
return ownerGroup;
|
||||
}
|
||||
@ -108,12 +105,13 @@ public class GroupExpression {
|
||||
return children;
|
||||
}
|
||||
|
||||
public void setChildren(List<Group> children) {
|
||||
public void setChildren(ImmutableList<Group> children) {
|
||||
this.children = children;
|
||||
}
|
||||
|
||||
/**
|
||||
* replaceChild.
|
||||
*
|
||||
* @param originChild origin child group
|
||||
* @param newChild new child group
|
||||
*/
|
||||
|
||||
@ -203,7 +203,7 @@ public class Memo {
|
||||
|
||||
/**
|
||||
* add or replace the plan into the target group.
|
||||
*
|
||||
* <p>
|
||||
* the result truth table:
|
||||
* <pre>
|
||||
* +---------------------------------------+-----------------------------------+--------------------------------+
|
||||
@ -296,8 +296,7 @@ public class Memo {
|
||||
}
|
||||
}
|
||||
plan = replaceChildrenToGroupPlan(plan, childrenGroups);
|
||||
GroupExpression newGroupExpression = new GroupExpression(plan);
|
||||
newGroupExpression.setChildren(childrenGroups);
|
||||
GroupExpression newGroupExpression = new GroupExpression(plan, childrenGroups);
|
||||
return insertGroupExpression(newGroupExpression, targetGroup, plan.getLogicalProperties());
|
||||
// TODO: need to derive logical property if generate new group. currently we not copy logical plan into
|
||||
}
|
||||
@ -388,13 +387,15 @@ public class Memo {
|
||||
}
|
||||
for (GroupExpression groupExpression : needReplaceChild) {
|
||||
groupExpressions.remove(groupExpression);
|
||||
List<Group> children = groupExpression.children();
|
||||
List<Group> children = new ArrayList<>(groupExpression.children());
|
||||
// TODO: use a better way to replace child, avoid traversing all groupExpression
|
||||
for (int i = 0; i < children.size(); i++) {
|
||||
if (children.get(i).equals(source)) {
|
||||
children.set(i, destination);
|
||||
}
|
||||
}
|
||||
groupExpression.setChildren(ImmutableList.copyOf(children));
|
||||
|
||||
GroupExpression that = groupExpressions.get(groupExpression);
|
||||
if (that != null && that.getOwnerGroup() != null
|
||||
&& !that.getOwnerGroup().equals(groupExpression.getOwnerGroup())) {
|
||||
@ -487,14 +488,14 @@ public class Memo {
|
||||
|
||||
/**
|
||||
* eliminate fromGroup, clear targetGroup, then move the logical group expressions in the fromGroup to the toGroup.
|
||||
*
|
||||
* <p>
|
||||
* the scenario is:
|
||||
* ```
|
||||
* Group 1(project, the targetGroup) Group 1(logicalOlapScan, the targetGroup)
|
||||
* | =>
|
||||
* Group 0(logicalOlapScan, the fromGroup)
|
||||
* ```
|
||||
*
|
||||
* <p>
|
||||
* we should recycle the group 0, and recycle all group expressions in group 1, then move the logicalOlapScan to
|
||||
* the group 1, and reset logical properties of the group 1.
|
||||
*/
|
||||
|
||||
@ -34,6 +34,7 @@ import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
@ -147,7 +148,7 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
|
||||
// +-----------+---------------------+-------------------------+--------------------------------+
|
||||
// NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: ExprId x
|
||||
// 2. collect local aggregate output expressions and local aggregate group by expression list
|
||||
List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
|
||||
List<Expression> localGroupByExprs = new ArrayList<>(aggregate.getGroupByExpressions());
|
||||
List<NamedExpression> localOutputExprs = Lists.newArrayList();
|
||||
for (Expression originGroupByExpr : originGroupByExprs) {
|
||||
if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
|
||||
|
||||
@ -119,8 +119,8 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
|
||||
Optional<LogicalProperties> logicalProperties,
|
||||
CHILD_TYPE child) {
|
||||
super(PlanType.LOGICAL_AGGREGATE, groupExpression, logicalProperties, child);
|
||||
this.groupByExpressions = groupByExpressions;
|
||||
this.outputExpressions = outputExpressions;
|
||||
this.groupByExpressions = ImmutableList.copyOf(groupByExpressions);
|
||||
this.outputExpressions = ImmutableList.copyOf(outputExpressions);
|
||||
this.partitionExpressions = partitionExpressions;
|
||||
this.disassembled = disassembled;
|
||||
this.normalized = normalized;
|
||||
|
||||
@ -71,7 +71,7 @@ class MemoTest implements PatternMatchSupported {
|
||||
JoinType.INNER_JOIN, logicalJoinAB, PlanConstructor.newLogicalOlapScan(2, "C", 0));
|
||||
|
||||
@Test
|
||||
void mergeGroup() throws Exception {
|
||||
void mergeGroup() {
|
||||
Memo memo = new Memo();
|
||||
GroupId gid2 = new GroupId(2);
|
||||
Group srcGroup = new Group(gid2, new GroupExpression(new FakePlan()), new LogicalProperties(ArrayList::new));
|
||||
@ -85,13 +85,13 @@ class MemoTest implements PatternMatchSupported {
|
||||
GroupExpression ge2 = new GroupExpression(d, Arrays.asList(dstGroup));
|
||||
GroupId gid1 = new GroupId(1);
|
||||
Group g2 = new Group(gid1, ge2, new LogicalProperties(ArrayList::new));
|
||||
Map<GroupId, Group> groups = (Map<GroupId, Group>) Deencapsulation.getField(memo, "groups");
|
||||
Map<GroupId, Group> groups = Deencapsulation.getField(memo, "groups");
|
||||
groups.put(gid2, srcGroup);
|
||||
groups.put(gid3, dstGroup);
|
||||
groups.put(gid0, g1);
|
||||
groups.put(gid1, g2);
|
||||
Map<GroupExpression, GroupExpression> groupExpressions =
|
||||
(Map<GroupExpression, GroupExpression>) Deencapsulation.getField(memo, "groupExpressions");
|
||||
Deencapsulation.getField(memo, "groupExpressions");
|
||||
groupExpressions.put(ge1, ge1);
|
||||
groupExpressions.put(ge2, ge2);
|
||||
memo.mergeGroup(srcGroup, dstGroup);
|
||||
|
||||
@ -22,7 +22,6 @@ import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
@ -31,14 +30,13 @@ import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.RelationId;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalUnary;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.nereids.util.PlanRewriter;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInstance;
|
||||
@ -46,15 +44,17 @@ import org.junit.jupiter.api.TestInstance;
|
||||
import java.util.List;
|
||||
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
public class AggregateDisassembleTest {
|
||||
public class AggregateDisassembleTest implements PatternMatchSupported {
|
||||
private Plan rStudent;
|
||||
|
||||
@BeforeAll
|
||||
public final void beforeAll() {
|
||||
rStudent = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, ImmutableList.of(""));
|
||||
rStudent = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student,
|
||||
ImmutableList.of(""));
|
||||
}
|
||||
|
||||
/**
|
||||
* <pre>
|
||||
* the initial plan is:
|
||||
* Aggregate(phase: [GLOBAL], outputExpr: [age, SUM(id) as sum], groupByExpr: [age])
|
||||
* +--childPlan(id, name, age)
|
||||
@ -62,6 +62,7 @@ public class AggregateDisassembleTest {
|
||||
* Aggregate(phase: [GLOBAL], outputExpr: [a, SUM(b) as c], groupByExpr: [a])
|
||||
* +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], groupByExpr: [age])
|
||||
* +--childPlan(id, name, age)
|
||||
* </pre>
|
||||
*/
|
||||
@Test
|
||||
public void slotReferenceGroupBy() {
|
||||
@ -70,50 +71,43 @@ public class AggregateDisassembleTest {
|
||||
List<NamedExpression> outputExpressionList = Lists.newArrayList(
|
||||
rStudent.getOutput().get(2).toSlot(),
|
||||
new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum"));
|
||||
Plan root = new LogicalAggregate(groupExpressionList, outputExpressionList, rStudent);
|
||||
|
||||
Plan after = rewrite(root);
|
||||
|
||||
Assertions.assertTrue(after instanceof LogicalUnary);
|
||||
Assertions.assertTrue(after instanceof LogicalAggregate);
|
||||
Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
|
||||
LogicalAggregate<Plan> global = (LogicalAggregate) after;
|
||||
LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0);
|
||||
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
|
||||
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
|
||||
Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
|
||||
|
||||
Expression localOutput0 = rStudent.getOutput().get(2).toSlot();
|
||||
Expression localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot());
|
||||
Expression localGroupBy = rStudent.getOutput().get(2).toSlot();
|
||||
|
||||
Assertions.assertEquals(2, local.getOutputExpressions().size());
|
||||
Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof SlotReference);
|
||||
Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0));
|
||||
Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof Alias);
|
||||
Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1).child(0));
|
||||
Assertions.assertEquals(1, local.getGroupByExpressions().size());
|
||||
Assertions.assertEquals(localGroupBy, local.getGroupByExpressions().get(0));
|
||||
|
||||
Expression globalOutput0 = local.getOutputExpressions().get(0).toSlot();
|
||||
Expression globalOutput1 = new Sum(local.getOutputExpressions().get(1).toSlot());
|
||||
Expression globalGroupBy = local.getOutputExpressions().get(0).toSlot();
|
||||
|
||||
Assertions.assertEquals(2, global.getOutputExpressions().size());
|
||||
Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof SlotReference);
|
||||
Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0));
|
||||
Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof Alias);
|
||||
Assertions.assertEquals(globalOutput1, global.getOutputExpressions().get(1).child(0));
|
||||
Assertions.assertEquals(1, global.getGroupByExpressions().size());
|
||||
Assertions.assertEquals(globalGroupBy, global.getGroupByExpressions().get(0));
|
||||
|
||||
// check id:
|
||||
Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
|
||||
global.getOutputExpressions().get(0).getExprId());
|
||||
Assertions.assertEquals(outputExpressionList.get(1).getExprId(),
|
||||
global.getOutputExpressions().get(1).getExprId());
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
|
||||
.applyTopDown(new AggregateDisassemble())
|
||||
.printlnTree()
|
||||
.matchesFromRoot(
|
||||
logicalAggregate(
|
||||
logicalAggregate()
|
||||
.when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL))
|
||||
.when(agg -> agg.getOutputExpressions().size() == 2)
|
||||
.when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0))
|
||||
.when(agg -> agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
|
||||
.when(agg -> agg.getGroupByExpressions().size() == 1)
|
||||
.when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy))
|
||||
).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL))
|
||||
.when(agg -> agg.getOutputExpressions().size() == 2)
|
||||
.when(agg -> agg.getOutputExpressions().get(0)
|
||||
.equals(agg.child().getOutputExpressions().get(0).toSlot()))
|
||||
.when(agg -> agg.getOutputExpressions().get(1).child(0)
|
||||
.equals(new Sum(agg.child().getOutputExpressions().get(1).toSlot())))
|
||||
.when(agg -> agg.getGroupByExpressions().size() == 1)
|
||||
.when(agg -> agg.getGroupByExpressions().get(0)
|
||||
.equals(agg.child().getOutputExpressions().get(0).toSlot()))
|
||||
// check id:
|
||||
.when(agg -> agg.getOutputExpressions().get(0).getExprId()
|
||||
.equals(outputExpressionList.get(0).getExprId()))
|
||||
.when(agg -> agg.getOutputExpressions().get(1).getExprId()
|
||||
.equals(outputExpressionList.get(1).getExprId()))
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* <pre>
|
||||
* the initial plan is:
|
||||
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: [])
|
||||
* +--childPlan(id, name, age)
|
||||
@ -121,44 +115,41 @@ public class AggregateDisassembleTest {
|
||||
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as b], groupByExpr: [])
|
||||
* +--Aggregate(phase: [LOCAL], outputExpr: [SUM(id) as a], groupByExpr: [])
|
||||
* +--childPlan(id, name, age)
|
||||
* </pre>
|
||||
*/
|
||||
@Test
|
||||
public void globalAggregate() {
|
||||
List<Expression> groupExpressionList = Lists.newArrayList();
|
||||
List<NamedExpression> outputExpressionList = Lists.newArrayList(
|
||||
new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum"));
|
||||
Plan root = new LogicalAggregate(groupExpressionList, outputExpressionList, rStudent);
|
||||
|
||||
Plan after = rewrite(root);
|
||||
|
||||
Assertions.assertTrue(after instanceof LogicalUnary);
|
||||
Assertions.assertTrue(after instanceof LogicalAggregate);
|
||||
Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
|
||||
LogicalAggregate<Plan> global = (LogicalAggregate) after;
|
||||
LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0);
|
||||
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
|
||||
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
|
||||
new Alias(new Sum(rStudent.getOutput().get(0)), "sum"));
|
||||
Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
|
||||
|
||||
Expression localOutput0 = new Sum(rStudent.getOutput().get(0).toSlot());
|
||||
|
||||
Assertions.assertEquals(1, local.getOutputExpressions().size());
|
||||
Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof Alias);
|
||||
Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0).child(0));
|
||||
Assertions.assertEquals(0, local.getGroupByExpressions().size());
|
||||
|
||||
Expression globalOutput0 = new Sum(local.getOutputExpressions().get(0).toSlot());
|
||||
|
||||
Assertions.assertEquals(1, global.getOutputExpressions().size());
|
||||
Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof Alias);
|
||||
Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0).child(0));
|
||||
Assertions.assertEquals(0, global.getGroupByExpressions().size());
|
||||
|
||||
// check id:
|
||||
Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
|
||||
global.getOutputExpressions().get(0).getExprId());
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
|
||||
.applyTopDown(new AggregateDisassemble())
|
||||
.printlnTree()
|
||||
.matchesFromRoot(
|
||||
logicalAggregate(
|
||||
logicalAggregate()
|
||||
.when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL))
|
||||
.when(agg -> agg.getOutputExpressions().size() == 1)
|
||||
.when(agg -> agg.getOutputExpressions().get(0).child(0).equals(localOutput0))
|
||||
.when(agg -> agg.getGroupByExpressions().size() == 0)
|
||||
).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL))
|
||||
.when(agg -> agg.getOutputExpressions().size() == 1)
|
||||
.when(agg -> agg.getOutputExpressions().get(0) instanceof Alias)
|
||||
.when(agg -> agg.getOutputExpressions().get(0).child(0)
|
||||
.equals(new Sum(agg.child().getOutputExpressions().get(0).toSlot())))
|
||||
.when(agg -> agg.getGroupByExpressions().size() == 0)
|
||||
// check id:
|
||||
.when(agg -> agg.getOutputExpressions().get(0).getExprId()
|
||||
.equals(outputExpressionList.get(0).getExprId()))
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* <pre>
|
||||
* the initial plan is:
|
||||
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: [age])
|
||||
* +--childPlan(id, name, age)
|
||||
@ -166,6 +157,7 @@ public class AggregateDisassembleTest {
|
||||
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as c], groupByExpr: [a])
|
||||
* +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], groupByExpr: [age])
|
||||
* +--childPlan(id, name, age)
|
||||
* </pre>
|
||||
*/
|
||||
@Test
|
||||
public void groupExpressionNotInOutput() {
|
||||
@ -173,45 +165,40 @@ public class AggregateDisassembleTest {
|
||||
rStudent.getOutput().get(2).toSlot());
|
||||
List<NamedExpression> outputExpressionList = Lists.newArrayList(
|
||||
new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum"));
|
||||
Plan root = new LogicalAggregate(groupExpressionList, outputExpressionList, rStudent);
|
||||
|
||||
Plan after = rewrite(root);
|
||||
|
||||
Assertions.assertTrue(after instanceof LogicalUnary);
|
||||
Assertions.assertTrue(after instanceof LogicalAggregate);
|
||||
Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
|
||||
LogicalAggregate<Plan> global = (LogicalAggregate) after;
|
||||
LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0);
|
||||
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
|
||||
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
|
||||
Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
|
||||
|
||||
Expression localOutput0 = rStudent.getOutput().get(2).toSlot();
|
||||
Expression localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot());
|
||||
Expression localGroupBy = rStudent.getOutput().get(2).toSlot();
|
||||
|
||||
Assertions.assertEquals(2, local.getOutputExpressions().size());
|
||||
Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof SlotReference);
|
||||
Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0));
|
||||
Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof Alias);
|
||||
Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1).child(0));
|
||||
Assertions.assertEquals(1, local.getGroupByExpressions().size());
|
||||
Assertions.assertEquals(localGroupBy, local.getGroupByExpressions().get(0));
|
||||
|
||||
Expression globalOutput0 = new Sum(local.getOutputExpressions().get(1).toSlot());
|
||||
Expression globalGroupBy = local.getOutputExpressions().get(0).toSlot();
|
||||
|
||||
Assertions.assertEquals(1, global.getOutputExpressions().size());
|
||||
Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof Alias);
|
||||
Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0).child(0));
|
||||
Assertions.assertEquals(1, global.getGroupByExpressions().size());
|
||||
Assertions.assertEquals(globalGroupBy, global.getGroupByExpressions().get(0));
|
||||
|
||||
// check id:
|
||||
Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
|
||||
global.getOutputExpressions().get(0).getExprId());
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
|
||||
.applyTopDown(new AggregateDisassemble())
|
||||
.printlnTree()
|
||||
.matchesFromRoot(
|
||||
logicalAggregate(
|
||||
logicalAggregate()
|
||||
.when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL))
|
||||
.when(agg -> agg.getOutputExpressions().size() == 2)
|
||||
.when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0))
|
||||
.when(agg -> agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
|
||||
.when(agg -> agg.getGroupByExpressions().size() == 1)
|
||||
.when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy))
|
||||
).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL))
|
||||
.when(agg -> agg.getOutputExpressions().size() == 1)
|
||||
.when(agg -> agg.getOutputExpressions().get(0) instanceof Alias)
|
||||
.when(agg -> agg.getOutputExpressions().get(0).child(0)
|
||||
.equals(new Sum(agg.child().getOutputExpressions().get(1).toSlot())))
|
||||
.when(agg -> agg.getGroupByExpressions().size() == 1)
|
||||
.when(agg -> agg.getGroupByExpressions().get(0)
|
||||
.equals(agg.child().getOutputExpressions().get(0).toSlot()))
|
||||
// check id:
|
||||
.when(agg -> agg.getOutputExpressions().get(0).getExprId()
|
||||
.equals(outputExpressionList.get(0).getExprId()))
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* <pre>
|
||||
* the initial plan is:
|
||||
* Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age) + 2) as c], groupByExpr: [id])
|
||||
* +-- childPlan(id, name, age)
|
||||
@ -220,6 +207,7 @@ public class AggregateDisassembleTest {
|
||||
* +-- Aggregate(phase: [GLOBAL], outputExpr: [id, age], groupByExpr: [id, age])
|
||||
* +-- Aggregate(phase: [LOCAL], outputExpr: [id, age], groupByExpr: [id, age])
|
||||
* +-- childPlan(id, name, age)
|
||||
* </pre>
|
||||
*/
|
||||
@Test
|
||||
public void distinctAggregateWithGroupBy() {
|
||||
@ -229,68 +217,44 @@ public class AggregateDisassembleTest {
|
||||
new IntegerLiteral(2)), "c"));
|
||||
Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
|
||||
|
||||
Plan after = rewrite(root);
|
||||
|
||||
Assertions.assertTrue(after instanceof LogicalUnary);
|
||||
Assertions.assertTrue(after instanceof LogicalAggregate);
|
||||
Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
|
||||
LogicalAggregate<Plan> distinctLocal = (LogicalAggregate) after;
|
||||
LogicalAggregate<Plan> global = (LogicalAggregate) after.child(0);
|
||||
LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0).child(0);
|
||||
Assertions.assertEquals(AggPhase.DISTINCT_LOCAL, distinctLocal.getAggPhase());
|
||||
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
|
||||
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
|
||||
// check local:
|
||||
// id
|
||||
Expression localOutput0 = rStudent.getOutput().get(0).toSlot();
|
||||
Expression localOutput0 = rStudent.getOutput().get(0);
|
||||
// age
|
||||
Expression localOutput1 = rStudent.getOutput().get(2).toSlot();
|
||||
Expression localOutput1 = rStudent.getOutput().get(2);
|
||||
// id
|
||||
Expression localGroupBy0 = rStudent.getOutput().get(0).toSlot();
|
||||
Expression localGroupBy0 = rStudent.getOutput().get(0);
|
||||
// age
|
||||
Expression localGroupBy1 = rStudent.getOutput().get(2).toSlot();
|
||||
Expression localGroupBy1 = rStudent.getOutput().get(2);
|
||||
|
||||
Assertions.assertEquals(2, local.getOutputExpressions().size());
|
||||
Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof SlotReference);
|
||||
Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0));
|
||||
Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof SlotReference);
|
||||
Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1));
|
||||
Assertions.assertEquals(2, local.getGroupByExpressions().size());
|
||||
Assertions.assertEquals(localGroupBy0, local.getGroupByExpressions().get(0));
|
||||
Assertions.assertEquals(localGroupBy1, local.getGroupByExpressions().get(1));
|
||||
|
||||
// check global:
|
||||
Expression globalOutput0 = local.getOutputExpressions().get(0).toSlot();
|
||||
Expression globalOutput1 = local.getOutputExpressions().get(1).toSlot();
|
||||
Expression globalGroupBy0 = local.getOutputExpressions().get(0).toSlot();
|
||||
Expression globalGroupBy1 = local.getOutputExpressions().get(1).toSlot();
|
||||
|
||||
Assertions.assertEquals(2, global.getOutputExpressions().size());
|
||||
Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof SlotReference);
|
||||
Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0));
|
||||
Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof SlotReference);
|
||||
Assertions.assertEquals(globalOutput1, global.getOutputExpressions().get(1));
|
||||
Assertions.assertEquals(2, global.getGroupByExpressions().size());
|
||||
Assertions.assertEquals(globalGroupBy0, global.getGroupByExpressions().get(0));
|
||||
Assertions.assertEquals(globalGroupBy1, global.getGroupByExpressions().get(1));
|
||||
|
||||
// check distinct local:
|
||||
Expression distinctLocalOutput = new Add(new Count(local.getOutputExpressions().get(1).toSlot(), true),
|
||||
new IntegerLiteral(2));
|
||||
Expression distinctLocalGroupBy = local.getOutputExpressions().get(0).toSlot();
|
||||
|
||||
Assertions.assertEquals(1, distinctLocal.getOutputExpressions().size());
|
||||
Assertions.assertTrue(distinctLocal.getOutputExpressions().get(0) instanceof Alias);
|
||||
Assertions.assertEquals(distinctLocalOutput, distinctLocal.getOutputExpressions().get(0).child(0));
|
||||
Assertions.assertEquals(1, distinctLocal.getGroupByExpressions().size());
|
||||
Assertions.assertEquals(distinctLocalGroupBy, distinctLocal.getGroupByExpressions().get(0));
|
||||
|
||||
// check id:
|
||||
Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
|
||||
distinctLocal.getOutputExpressions().get(0).getExprId());
|
||||
}
|
||||
|
||||
private Plan rewrite(Plan input) {
|
||||
return PlanRewriter.topDownRewrite(input, new ConnectContext(), new AggregateDisassemble());
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
|
||||
.applyTopDown(new AggregateDisassemble())
|
||||
.matchesFromRoot(
|
||||
logicalAggregate(
|
||||
logicalAggregate(
|
||||
logicalAggregate()
|
||||
.when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL))
|
||||
.when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0))
|
||||
.when(agg -> agg.getOutputExpressions().get(1).equals(localOutput1))
|
||||
.when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy0))
|
||||
.when(agg -> agg.getGroupByExpressions().get(1).equals(localGroupBy1))
|
||||
).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL))
|
||||
.when(agg -> agg.getOutputExpressions().get(0)
|
||||
.equals(agg.child().getOutputExpressions().get(0)))
|
||||
.when(agg -> agg.getOutputExpressions().get(1)
|
||||
.equals(agg.child().getOutputExpressions().get(1)))
|
||||
.when(agg -> agg.getGroupByExpressions().get(0)
|
||||
.equals(agg.child().getOutputExpressions().get(0)))
|
||||
.when(agg -> agg.getGroupByExpressions().get(1)
|
||||
.equals(agg.child().getOutputExpressions().get(1)))
|
||||
).when(agg -> agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL))
|
||||
.when(agg -> agg.getOutputExpressions().size() == 1)
|
||||
.when(agg -> agg.getOutputExpressions().get(0) instanceof Alias)
|
||||
.when(agg -> agg.getOutputExpressions().get(0).child(0) instanceof Add)
|
||||
.when(agg -> agg.getGroupByExpressions().get(0)
|
||||
.equals(agg.child().child().getOutputExpressions().get(0)))
|
||||
.when(agg -> agg.getOutputExpressions().get(0).getExprId() == outputExpressionList.get(
|
||||
0).getExprId())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -130,16 +130,14 @@ public class StatsCalculatorTest {
|
||||
childGroup.setStatistics(childStats);
|
||||
|
||||
LogicalFilter<GroupPlan> logicalFilter = new LogicalFilter<>(and, groupPlan);
|
||||
GroupExpression groupExpression = new GroupExpression(logicalFilter);
|
||||
groupExpression.addChild(childGroup);
|
||||
GroupExpression groupExpression = new GroupExpression(logicalFilter, ImmutableList.of(childGroup));
|
||||
Group ownerGroup = new Group();
|
||||
groupExpression.setOwnerGroup(ownerGroup);
|
||||
StatsCalculator.estimate(groupExpression);
|
||||
Assertions.assertEquals((long) (10000 * 0.1 * 0.05), ownerGroup.getStatistics().getRowCount(), 0.001);
|
||||
|
||||
LogicalFilter<GroupPlan> logicalFilterOr = new LogicalFilter<>(or, groupPlan);
|
||||
GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr);
|
||||
groupExpressionOr.addChild(childGroup);
|
||||
GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr, ImmutableList.of(childGroup));
|
||||
Group ownerGroupOr = new Group();
|
||||
groupExpressionOr.setOwnerGroup(ownerGroupOr);
|
||||
StatsCalculator.estimate(groupExpressionOr);
|
||||
@ -243,8 +241,7 @@ public class StatsCalculatorTest {
|
||||
childGroup.setStatistics(childStats);
|
||||
|
||||
LogicalLimit<GroupPlan> logicalLimit = new LogicalLimit<>(1, 2, groupPlan);
|
||||
GroupExpression groupExpression = new GroupExpression(logicalLimit);
|
||||
groupExpression.addChild(childGroup);
|
||||
GroupExpression groupExpression = new GroupExpression(logicalLimit, ImmutableList.of(childGroup));
|
||||
Group ownerGroup = new Group();
|
||||
ownerGroup.addGroupExpression(groupExpression);
|
||||
StatsCalculator.estimate(groupExpression);
|
||||
@ -274,8 +271,7 @@ public class StatsCalculatorTest {
|
||||
childGroup.setStatistics(childStats);
|
||||
|
||||
LogicalTopN<GroupPlan> logicalTopN = new LogicalTopN<>(Collections.emptyList(), 1, 2, groupPlan);
|
||||
GroupExpression groupExpression = new GroupExpression(logicalTopN);
|
||||
groupExpression.addChild(childGroup);
|
||||
GroupExpression groupExpression = new GroupExpression(logicalTopN, ImmutableList.of(childGroup));
|
||||
Group ownerGroup = new Group();
|
||||
ownerGroup.addGroupExpression(groupExpression);
|
||||
StatsCalculator.estimate(groupExpression);
|
||||
|
||||
Reference in New Issue
Block a user