diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java new file mode 100644 index 0000000000..6d5bf8b75f --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java @@ -0,0 +1,186 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Multiply; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Count(*) + * Count(col) + */ +public class PushdownCountThroughJoin implements RewriteRuleFactory { + @Override + public List buildRules() { + return ImmutableList.of( + logicalAggregate(innerLogicalJoin()) + .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0) + .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) + .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot)) + .when(agg -> { + Set funcs = agg.getAggregateFunctions(); + return !funcs.isEmpty() && funcs.stream() + .allMatch(f -> f instanceof Count && f.child(0) instanceof Slot); + }) + .then(agg -> pushCount(agg, agg.child(), ImmutableList.of())) + .toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN), + logicalAggregate(logicalProject(innerLogicalJoin())) + .when(agg -> agg.child().isAllSlots()) + .when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0) + .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) + .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot)) + .when(agg -> { + Set funcs = agg.getAggregateFunctions(); + return !funcs.isEmpty() && funcs.stream() + .allMatch(f -> f instanceof Count && f.child(0) instanceof Slot); + }) + .then(agg -> pushCount(agg, agg.child().child(), agg.child().getProjects())) + .toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN) + ); + } + + private LogicalAggregate pushCount(LogicalAggregate agg, + LogicalJoin join, List projects) { + List leftOutput = join.left().getOutput(); + List rightOutput = join.right().getOutput(); + + List leftCounts = new ArrayList<>(); + List rightCounts = new ArrayList<>(); + for (AggregateFunction f : agg.getAggregateFunctions()) { + Count count = (Count) f; + if (count.isCountStar()) { + // TODO: handle Count(*) + return null; + } + Slot slot = (Slot) count.child(0); + if (leftOutput.contains(slot)) { + leftCounts.add(count); + } else if (rightOutput.contains(slot)) { + rightCounts.add(count); + } else { + throw new IllegalStateException("Slot " + slot + " not found in join output"); + } + } + + // TODO: empty GroupBy + Set leftGroupBy = new HashSet<>(); + Set rightGroupBy = new HashSet<>(); + for (Expression e : agg.getGroupByExpressions()) { + Slot slot = (Slot) e; + if (leftOutput.contains(slot)) { + leftGroupBy.add(slot); + } else if (rightOutput.contains(slot)) { + rightGroupBy.add(slot); + } else { + return null; + } + } + join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { + if (leftOutput.contains(slot)) { + leftGroupBy.add(slot); + } else if (rightOutput.contains(slot)) { + rightGroupBy.add(slot); + } else { + throw new IllegalStateException("Slot " + slot + " not found in join output"); + } + })); + + Alias leftCnt = null; + Alias rightCnt = null; + // left Count agg + Map leftCntSlotToOutput = new HashMap<>(); + Builder leftCntAggOutputBuilder = ImmutableList.builder() + .addAll(leftGroupBy); + leftCounts.forEach(func -> { + Alias alias = func.alias(func.getName()); + leftCntSlotToOutput.put((Slot) func.child(0), alias); + leftCntAggOutputBuilder.add(alias); + }); + if (!rightCounts.isEmpty()) { + leftCnt = new Count().alias("leftCntStar"); + leftCntAggOutputBuilder.add(leftCnt); + } + LogicalAggregate leftCntAgg = new LogicalAggregate<>( + ImmutableList.copyOf(leftGroupBy), leftCntAggOutputBuilder.build(), join.left()); + + // right Count agg + Map rightCntSlotToOutput = new HashMap<>(); + Builder rightCntAggOutputBuilder = ImmutableList.builder() + .addAll(rightGroupBy); + rightCounts.forEach(func -> { + Alias alias = func.alias(func.getName()); + rightCntSlotToOutput.put((Slot) func.child(0), alias); + rightCntAggOutputBuilder.add(alias); + }); + + if (!leftCounts.isEmpty()) { + rightCnt = new Count().alias("rightCntStar"); + rightCntAggOutputBuilder.add(rightCnt); + } + LogicalAggregate rightCntAgg = new LogicalAggregate<>( + ImmutableList.copyOf(rightGroupBy), rightCntAggOutputBuilder.build(), join.right()); + + Plan newJoin = join.withChildren(leftCntAgg, rightCntAgg); + + // top Sum agg + // count(slot) -> sum( count(slot) * cnt ) + List newOutputExprs = new ArrayList<>(); + for (NamedExpression ne : agg.getOutputExpressions()) { + if (ne instanceof Alias && ((Alias) ne).child() instanceof Count) { + Count oldTopCnt = (Count) ((Alias) ne).child(); + Slot slot = (Slot) oldTopCnt.child(0); + if (leftCntSlotToOutput.containsKey(slot)) { + Preconditions.checkState(rightCnt != null); + Expression expr = new Sum(new Multiply(leftCntSlotToOutput.get(slot).toSlot(), rightCnt.toSlot())); + newOutputExprs.add((NamedExpression) ne.withChildren(expr)); + } else if (rightCntSlotToOutput.containsKey(slot)) { + Preconditions.checkState(leftCnt != null); + Expression expr = new Sum(new Multiply(rightCntSlotToOutput.get(slot).toSlot(), leftCnt.toSlot())); + newOutputExprs.add((NamedExpression) ne.withChildren(expr)); + } else { + throw new IllegalStateException("Slot " + slot + " not found in join output"); + } + } else { + newOutputExprs.add(ne); + } + } + return agg.withAggOutputChild(newOutputExprs, newJoin); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java index 6212210816..1defb09d46 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindowAnalytic; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DataType; @@ -63,6 +64,12 @@ public class Count extends AggregateFunction this.isStar = false; } + public boolean isCountStar() { + return isStar + || children.size() == 0 + || (children.size() == 1 && child(0) instanceof Literal); + } + @Override public void checkLegalityBeforeTypeCoercion() { // for multiple exprs count must be qualified with distinct diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java new file mode 100644 index 0000000000..1a39a7c5ff --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +class PushdownCountThroughJoinTest implements MemoPatternMatchSupported { + private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + + @Test + void testSingleCount() { + Alias count = new Count(scan1.getOutput().get(0)).alias("count"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), count)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushdownCountThroughJoin()) + .printlnTree(); + } + + @Test + void testMultiCount() { + Alias leftCnt1 = new Count(scan1.getOutput().get(0)).alias("leftCnt1"); + Alias leftCnt2 = new Count(scan1.getOutput().get(1)).alias("leftCnt2"); + Alias rightCnt1 = new Count(scan2.getOutput().get(1)).alias("rightCnt1"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), + ImmutableList.of(scan1.getOutput().get(0), leftCnt1, leftCnt2, rightCnt1)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushdownCountThroughJoin()) + .printlnTree(); + } + +}