[feat](Nereids): a new CBO rule: Eager Split/GroupByCount (#18556)
This commit is contained in:
@ -241,6 +241,10 @@ public enum RuleType {
|
||||
PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION),
|
||||
EAGER_COUNT(RuleTypeClass.EXPLORATION),
|
||||
EAGER_GROUP_BY(RuleTypeClass.EXPLORATION),
|
||||
EAGER_GROUP_BY_COUNT(RuleTypeClass.EXPLORATION),
|
||||
EAGER_SPLIT(RuleTypeClass.EXPLORATION),
|
||||
|
||||
EXPLORATION_SENTINEL(RuleTypeClass.EXPLORATION),
|
||||
|
||||
// implementation rules
|
||||
LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),
|
||||
|
||||
@ -49,7 +49,7 @@ import java.util.Set;
|
||||
* | *
|
||||
* (x)
|
||||
* ->
|
||||
* aggregate: SUM(x * cnt)
|
||||
* aggregate: SUM(x) * cnt
|
||||
* |
|
||||
* join
|
||||
* | \
|
||||
@ -62,7 +62,7 @@ public class EagerCount extends OneExplorationRuleFactory {
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalAggregate(logicalJoin())
|
||||
return logicalAggregate(innerLogicalJoin())
|
||||
.when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
|
||||
.when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
|
||||
.when(agg -> agg.getAggregateFunctions().stream()
|
||||
|
||||
@ -54,13 +54,15 @@ import java.util.stream.Collectors;
|
||||
* | *
|
||||
* aggregate: SUM(x) as sum1
|
||||
* </pre>
|
||||
* After Eager Group By, new plan also can apply `Eager Count`.
|
||||
* It's `Double Eager`.
|
||||
*/
|
||||
public class EagerGroupBy extends OneExplorationRuleFactory {
|
||||
public static final EagerGroupBy INSTANCE = new EagerGroupBy();
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalAggregate(logicalJoin())
|
||||
return logicalAggregate(innerLogicalJoin())
|
||||
.when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
|
||||
.when(agg -> agg.getAggregateFunctions().stream()
|
||||
.allMatch(f -> f instanceof Sum
|
||||
|
||||
@ -0,0 +1,138 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Multiply;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Related paper "Eager aggregation and lazy aggregation".
|
||||
* <pre>
|
||||
* aggregate: SUM(x), SUM(y)
|
||||
* |
|
||||
* join
|
||||
* | \
|
||||
* | (y)
|
||||
* (x)
|
||||
* ->
|
||||
* aggregate: SUM(sum1), SUM(y) * cnt
|
||||
* |
|
||||
* join
|
||||
* | \
|
||||
* | (y)
|
||||
* aggregate: SUM(x) as sum1 , COUNT as cnt
|
||||
* </pre>
|
||||
*/
|
||||
public class EagerGroupByCount extends OneExplorationRuleFactory {
|
||||
public static final EagerGroupByCount INSTANCE = new EagerGroupByCount();
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalAggregate(innerLogicalJoin())
|
||||
.when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
|
||||
.when(agg -> agg.getAggregateFunctions().stream()
|
||||
.allMatch(f -> f instanceof Sum && ((Sum) f).child() instanceof Slot))
|
||||
.then(agg -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> join = agg.child();
|
||||
List<Slot> leftOutput = join.left().getOutput();
|
||||
List<Sum> leftSums = new ArrayList<>();
|
||||
List<Sum> rightSums = new ArrayList<>();
|
||||
for (AggregateFunction f : agg.getAggregateFunctions()) {
|
||||
Sum sum = (Sum) f;
|
||||
if (leftOutput.contains((Slot) sum.child())) {
|
||||
leftSums.add(sum);
|
||||
} else {
|
||||
rightSums.add(sum);
|
||||
}
|
||||
}
|
||||
if (leftSums.size() == 0 || rightSums.size() == 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// left bottom agg
|
||||
Set<Slot> bottomAggGroupBy = new HashSet<>();
|
||||
agg.getGroupByExpressions().stream().map(e -> (Slot) e).filter(leftOutput::contains)
|
||||
.forEach(bottomAggGroupBy::add);
|
||||
join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> {
|
||||
if (leftOutput.contains(slot)) {
|
||||
bottomAggGroupBy.add(slot);
|
||||
}
|
||||
}));
|
||||
List<NamedExpression> bottomSums = new ArrayList<>();
|
||||
for (int i = 0; i < leftSums.size(); i++) {
|
||||
bottomSums.add(new Alias(new Sum(leftSums.get(i).child()), "sum" + i));
|
||||
}
|
||||
Alias cnt = new Alias(new Count(Literal.of(1)), "cnt");
|
||||
List<NamedExpression> bottomAggOutput = ImmutableList.<NamedExpression>builder()
|
||||
.addAll(bottomAggGroupBy).addAll(bottomSums).add(cnt).build();
|
||||
LogicalAggregate<GroupPlan> bottomAgg = new LogicalAggregate<>(
|
||||
ImmutableList.copyOf(bottomAggGroupBy), bottomAggOutput, join.left());
|
||||
Plan newJoin = join.withChildren(bottomAgg, join.right());
|
||||
|
||||
// top agg
|
||||
List<NamedExpression> newOutputExprs = new ArrayList<>();
|
||||
List<Alias> leftSumOutputExprs = new ArrayList<>();
|
||||
List<Alias> rightSumOutputExprs = new ArrayList<>();
|
||||
for (NamedExpression ne : agg.getOutputExpressions()) {
|
||||
if (ne instanceof Alias && ((Alias) ne).child() instanceof Sum) {
|
||||
Alias sumOutput = (Alias) ne;
|
||||
Slot child = (Slot) ((Sum) (sumOutput).child()).child();
|
||||
if (leftOutput.contains(child)) {
|
||||
leftSumOutputExprs.add(sumOutput);
|
||||
} else {
|
||||
rightSumOutputExprs.add(sumOutput);
|
||||
}
|
||||
} else {
|
||||
newOutputExprs.add(ne);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < leftSumOutputExprs.size(); i++) {
|
||||
Alias oldSum = leftSumOutputExprs.get(i);
|
||||
// sum in bottom Agg
|
||||
Slot bottomSum = bottomSums.get(i).toSlot();
|
||||
Alias newSum = new Alias(oldSum.getExprId(), new Sum(bottomSum), oldSum.getName());
|
||||
newOutputExprs.add(newSum);
|
||||
}
|
||||
for (Alias oldSum : rightSumOutputExprs) {
|
||||
Sum oldSumFunc = (Sum) oldSum.child();
|
||||
newOutputExprs.add(new Alias(oldSum.getExprId(), new Multiply(oldSumFunc, cnt.toSlot()),
|
||||
oldSum.getName()));
|
||||
}
|
||||
return agg.withAggOutput(newOutputExprs).withChildren(newJoin);
|
||||
}).toRule(RuleType.EAGER_GROUP_BY_COUNT);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,164 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Multiply;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Related paper "Eager aggregation and lazy aggregation".
|
||||
* <pre>
|
||||
* aggregate: SUM(x), SUM(y)
|
||||
* |
|
||||
* join
|
||||
* | \
|
||||
* | (y)
|
||||
* (x)
|
||||
* ->
|
||||
* aggregate: SUM(sum1) * cnt2, SUM(sum2) * cnt1
|
||||
* |
|
||||
* join
|
||||
* | \
|
||||
* | aggregate: SUM(y) as sum2, COUNT: cnt2
|
||||
* aggregate: SUM(x) as sum1, COUNT: cnt1
|
||||
* </pre>
|
||||
*/
|
||||
public class EagerSplit extends OneExplorationRuleFactory {
|
||||
public static final EagerSplit INSTANCE = new EagerSplit();
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalAggregate(innerLogicalJoin())
|
||||
.when(agg -> agg.getAggregateFunctions().stream()
|
||||
.allMatch(f -> f instanceof Sum && ((Sum) f).child() instanceof SlotReference))
|
||||
.then(agg -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> join = agg.child();
|
||||
List<Slot> leftOutput = join.left().getOutput();
|
||||
List<Slot> rightOutput = join.right().getOutput();
|
||||
List<Sum> leftSums = new ArrayList<>();
|
||||
List<Sum> rightSums = new ArrayList<>();
|
||||
for (AggregateFunction f : agg.getAggregateFunctions()) {
|
||||
Sum sum = (Sum) f;
|
||||
if (leftOutput.contains((Slot) sum.child())) {
|
||||
leftSums.add(sum);
|
||||
} else {
|
||||
rightSums.add(sum);
|
||||
}
|
||||
}
|
||||
if (leftSums.size() == 0 || rightSums.size() == 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// left bottom agg
|
||||
Set<Slot> leftBottomAggGroupBy = new HashSet<>();
|
||||
agg.getGroupByExpressions().stream().map(e -> (Slot) e).filter(leftOutput::contains)
|
||||
.forEach(leftBottomAggGroupBy::add);
|
||||
join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> {
|
||||
if (leftOutput.contains(slot)) {
|
||||
leftBottomAggGroupBy.add(slot);
|
||||
}
|
||||
}));
|
||||
List<NamedExpression> leftBottomSums = new ArrayList<>();
|
||||
for (int i = 0; i < leftSums.size(); i++) {
|
||||
leftBottomSums.add(new Alias(new Sum(leftSums.get(i).child()), "left_sum" + i));
|
||||
}
|
||||
Alias leftCnt = new Alias(new Count(Literal.of(1)), "left_cnt");
|
||||
List<NamedExpression> leftBottomAggOutput = ImmutableList.<NamedExpression>builder()
|
||||
.addAll(leftBottomAggGroupBy).addAll(leftBottomSums).add(leftCnt).build();
|
||||
LogicalAggregate<GroupPlan> leftBottomAgg = new LogicalAggregate<>(
|
||||
ImmutableList.copyOf(leftBottomAggGroupBy), leftBottomAggOutput, join.left());
|
||||
|
||||
// right bottom agg
|
||||
Set<Slot> rightBottomAggGroupBy = new HashSet<>();
|
||||
agg.getGroupByExpressions().stream().map(e -> (Slot) e).filter(rightOutput::contains)
|
||||
.forEach(rightBottomAggGroupBy::add);
|
||||
join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> {
|
||||
if (rightOutput.contains(slot)) {
|
||||
rightBottomAggGroupBy.add(slot);
|
||||
}
|
||||
}));
|
||||
List<NamedExpression> rightBottomSums = new ArrayList<>();
|
||||
for (int i = 0; i < rightSums.size(); i++) {
|
||||
rightBottomSums.add(new Alias(new Sum(rightSums.get(i).child()), "right_sum" + i));
|
||||
}
|
||||
Alias rightCnt = new Alias(new Count(Literal.of(1)), "right_cnt");
|
||||
List<NamedExpression> rightBottomAggOutput = ImmutableList.<NamedExpression>builder()
|
||||
.addAll(rightBottomAggGroupBy).addAll(rightBottomSums).add(rightCnt).build();
|
||||
LogicalAggregate<GroupPlan> rightBottomAgg = new LogicalAggregate<>(
|
||||
ImmutableList.copyOf(rightBottomAggGroupBy), rightBottomAggOutput, join.right());
|
||||
|
||||
Plan newJoin = join.withChildren(leftBottomAgg, rightBottomAgg);
|
||||
|
||||
// top agg
|
||||
List<NamedExpression> newOutputExprs = new ArrayList<>();
|
||||
List<Alias> leftSumOutputExprs = new ArrayList<>();
|
||||
List<Alias> rightSumOutputExprs = new ArrayList<>();
|
||||
for (NamedExpression ne : agg.getOutputExpressions()) {
|
||||
if (ne instanceof Alias && ((Alias) ne).child() instanceof Sum) {
|
||||
Alias sumOutput = (Alias) ne;
|
||||
Slot child = (Slot) ((Sum) (sumOutput).child()).child();
|
||||
if (leftOutput.contains(child)) {
|
||||
leftSumOutputExprs.add(sumOutput);
|
||||
} else {
|
||||
rightSumOutputExprs.add(sumOutput);
|
||||
}
|
||||
} else {
|
||||
newOutputExprs.add(ne);
|
||||
}
|
||||
}
|
||||
Preconditions.checkState(leftSumOutputExprs.size() == leftBottomSums.size());
|
||||
Preconditions.checkState(rightSumOutputExprs.size() == rightBottomSums.size());
|
||||
for (int i = 0; i < leftSumOutputExprs.size(); i++) {
|
||||
Alias oldSum = leftSumOutputExprs.get(i);
|
||||
Slot bottomSum = leftBottomSums.get(i).toSlot();
|
||||
Alias newSum = new Alias(oldSum.getExprId(),
|
||||
new Multiply(new Sum(bottomSum), rightCnt.toSlot()), oldSum.getName());
|
||||
newOutputExprs.add(newSum);
|
||||
}
|
||||
for (int i = 0; i < rightSumOutputExprs.size(); i++) {
|
||||
Alias oldSum = rightSumOutputExprs.get(i);
|
||||
Slot bottomSum = rightBottomSums.get(i).toSlot();
|
||||
Alias newSum = new Alias(oldSum.getExprId(),
|
||||
new Multiply(new Sum(bottomSum), leftCnt.toSlot()), oldSum.getName());
|
||||
newOutputExprs.add(newSum);
|
||||
}
|
||||
return agg.withAggOutput(newOutputExprs).withChildren(newJoin);
|
||||
}).toRule(RuleType.EAGER_SPLIT);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,101 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class EagerGroupByCountTest implements MemoPatternMatchSupported {
|
||||
|
||||
private final LogicalOlapScan scan1 = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
|
||||
PlanConstructor.student, ImmutableList.of(""));
|
||||
private final LogicalOlapScan scan2 = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
|
||||
PlanConstructor.score, ImmutableList.of(""));
|
||||
|
||||
@Test
|
||||
void singleSum() {
|
||||
LogicalPlan agg = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.aggGroupUsingIndex(ImmutableList.of(0, 4),
|
||||
ImmutableList.of(
|
||||
new Alias(new Sum(scan1.getOutput().get(3)), "lsum0"),
|
||||
new Alias(new Sum(scan2.getOutput().get(2)), "rsum0")
|
||||
))
|
||||
.build();
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
|
||||
.applyExploration(EagerGroupByCount.INSTANCE.build())
|
||||
.printlnOrigin()
|
||||
.printlnExploration()
|
||||
.matchesExploration(
|
||||
logicalAggregate(
|
||||
logicalJoin(
|
||||
logicalAggregate().when(
|
||||
bottomAgg -> bottomAgg.getOutputExprsSql().equals("id, sum(age) AS `sum0`, count(1) AS `cnt`")),
|
||||
logicalOlapScan()
|
||||
)
|
||||
).when(newAgg ->
|
||||
newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions())
|
||||
&& newAgg.getOutputExprsSql().equals("sum(sum0) AS `lsum0`, (sum(grade) * cnt) AS `rsum0`"))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void multiSum() {
|
||||
LogicalPlan agg = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.aggGroupUsingIndex(ImmutableList.of(0, 4),
|
||||
ImmutableList.of(
|
||||
new Alias(new Sum(scan1.getOutput().get(1)), "lsum0"),
|
||||
new Alias(new Sum(scan1.getOutput().get(2)), "lsum1"),
|
||||
new Alias(new Sum(scan1.getOutput().get(3)), "lsum2"),
|
||||
new Alias(new Sum(scan2.getOutput().get(1)), "rsum0"),
|
||||
new Alias(new Sum(scan2.getOutput().get(2)), "rsum1")
|
||||
))
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
|
||||
.applyExploration(EagerGroupByCount.INSTANCE.build())
|
||||
.printlnOrigin()
|
||||
.printlnExploration()
|
||||
.matchesExploration(
|
||||
logicalAggregate(
|
||||
logicalJoin(
|
||||
logicalAggregate().when(cntAgg -> cntAgg.getOutputExprsSql()
|
||||
.equals("id, sum(gender) AS `sum0`, sum(name) AS `sum1`, sum(age) AS `sum2`, count(1) AS `cnt`")),
|
||||
logicalOlapScan()
|
||||
)
|
||||
).when(newAgg ->
|
||||
newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions())
|
||||
&& newAgg.getOutputExprsSql()
|
||||
.equals("sum(sum0) AS `lsum0`, sum(sum1) AS `lsum1`, sum(sum2) AS `lsum2`, (sum(cid) * cnt) AS `rsum0`, (sum(grade) * cnt) AS `rsum1`"))
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,102 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class EagerSplitTest implements MemoPatternMatchSupported {
|
||||
|
||||
private final LogicalOlapScan scan1 = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
|
||||
PlanConstructor.student, ImmutableList.of(""));
|
||||
private final LogicalOlapScan scan2 = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
|
||||
PlanConstructor.score, ImmutableList.of(""));
|
||||
|
||||
@Test
|
||||
void singleSum() {
|
||||
LogicalPlan agg = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.aggGroupUsingIndex(ImmutableList.of(0, 4),
|
||||
ImmutableList.of(
|
||||
new Alias(new Sum(scan1.getOutput().get(3)), "lsum0"),
|
||||
new Alias(new Sum(scan2.getOutput().get(2)), "rsum0")
|
||||
))
|
||||
.build();
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
|
||||
.applyExploration(EagerSplit.INSTANCE.build())
|
||||
.printlnOrigin()
|
||||
.printlnExploration()
|
||||
.matchesExploration(
|
||||
logicalAggregate(
|
||||
logicalJoin(
|
||||
logicalAggregate().when(
|
||||
a -> a.getOutputExprsSql().equals("id, sum(age) AS `left_sum0`, count(1) AS `left_cnt`")),
|
||||
logicalAggregate().when(
|
||||
a -> a.getOutputExprsSql().equals("sid, sum(grade) AS `right_sum0`, count(1) AS `right_cnt`"))
|
||||
)
|
||||
).when(newAgg ->
|
||||
newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions())
|
||||
&& newAgg.getOutputExprsSql().equals("(sum(left_sum0) * right_cnt) AS `lsum0`, (sum(right_sum0) * left_cnt) AS `rsum0`"))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void multiSum() {
|
||||
LogicalPlan agg = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.aggGroupUsingIndex(ImmutableList.of(0, 4),
|
||||
ImmutableList.of(
|
||||
new Alias(new Sum(scan1.getOutput().get(1)), "lsum0"),
|
||||
new Alias(new Sum(scan1.getOutput().get(2)), "lsum1"),
|
||||
new Alias(new Sum(scan1.getOutput().get(3)), "lsum2"),
|
||||
new Alias(new Sum(scan2.getOutput().get(1)), "rsum0"),
|
||||
new Alias(new Sum(scan2.getOutput().get(2)), "rsum1")
|
||||
))
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
|
||||
.applyExploration(EagerSplit.INSTANCE.build())
|
||||
.printlnExploration()
|
||||
.matchesExploration(
|
||||
logicalAggregate(
|
||||
logicalJoin(
|
||||
logicalAggregate().when(a -> a.getOutputExprsSql()
|
||||
.equals("id, sum(gender) AS `left_sum0`, sum(name) AS `left_sum1`, sum(age) AS `left_sum2`, count(1) AS `left_cnt`")),
|
||||
logicalAggregate().when(a -> a.getOutputExprsSql()
|
||||
.equals("sid, sum(cid) AS `right_sum0`, sum(grade) AS `right_sum1`, count(1) AS `right_cnt`"))
|
||||
)
|
||||
).when(newAgg ->
|
||||
newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions())
|
||||
&& newAgg.getOutputExprsSql()
|
||||
.equals("(sum(left_sum0) * right_cnt) AS `lsum0`, (sum(left_sum1) * right_cnt) AS `lsum1`, (sum(left_sum2) * right_cnt) AS `lsum2`, (sum(right_sum0) * left_cnt) AS `rsum0`, (sum(right_sum1) * left_cnt) AS `rsum1`"))
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user