[feature](Nereids): pushdown distinct through join. (#21437)

This commit is contained in:
jakevin
2023-07-05 15:55:21 +08:00
committed by GitHub
parent 4d414c649a
commit 1121e7d0c3
8 changed files with 242 additions and 25 deletions

View File

@ -239,6 +239,10 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(new BuildAggForUnion())
),
// topic("Distinct",
// costBased(custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, PushdownDistinctThroughJoin::new))
// ),
topic("Window optimization",
topDown(
new PushdownLimit(),

View File

@ -147,6 +147,8 @@ public enum RuleType {
PUSHDOWN_FILTER_THROUGH_CTE(RuleTypeClass.REWRITE),
PUSHDOWN_FILTER_THROUGH_CTE_ANCHOR(RuleTypeClass.REWRITE),
PUSH_DOWN_DISTINCT_THROUGH_JOIN(RuleTypeClass.REWRITE),
COLUMN_PRUNING(RuleTypeClass.REWRITE),
PUSHDOWN_TOP_N_THROUGH_PROJECTION_WINDOW(RuleTypeClass.REWRITE),

View File

@ -47,7 +47,7 @@ public class ProjectWithDistinctToAggregate extends OneAnalysisRuleFactory {
logicalProject()
.when(LogicalProject::isDistinct)
.whenNot(project -> project.getProjects().stream().anyMatch(this::hasAggregateFunction))
.then(project -> new LogicalAggregate<>(project.getProjects(), project.child()))
.then(project -> new LogicalAggregate<>(project.getProjects(), false, project.child()))
);
}

View File

@ -50,7 +50,7 @@ public class InferSetOperatorDistinct extends OneRewriteRuleFactory {
}
List<Plan> newChildren = setOperation.children().stream()
.map(child -> new LogicalAggregate<>(ImmutableList.copyOf(child.getOutput()), child))
.map(child -> new LogicalAggregate<>(ImmutableList.copyOf(child.getOutput()), true, child))
.collect(ImmutableList.toImmutableList());
if (newChildren.equals(setOperation.children())) {
return null;

View File

@ -0,0 +1,92 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import com.google.common.collect.ImmutableList;
import java.util.function.Function;
/**
 * Custom rewriter that pushes a distinct aggregate down through the join beneath it.
 *
 * <p>When a distinct {@link LogicalAggregate} sits on a {@link LogicalJoin} (optionally with an all-slot
 * {@link LogicalProject} in between), each join child is wrapped in a new distinct aggregate over that
 * child's own output. The new aggregates are created with {@code generated = true} and
 * {@code hasPushed = true}, so this rewriter skips them when it meets them again.
 *
 * <p>If the original distinct was itself a generated one it is removed after the push; otherwise it is
 * kept on top of the rewritten join.
 *
 * <p>NOTE(review): a retained top aggregate is rebuilt with {@code withChildren}, which keeps its
 * {@code hasPushed} flag unchanged ({@code false}). Running this rewriter a second time over its own
 * output would therefore appear to push once more — confirm the rule is only applied once per plan.
 */
public class PushdownDistinctThroughJoin extends DefaultPlanRewriter<JobContext> implements CustomRewriter {

    /**
     * Entry point of the custom rewrite: walks the whole plan with this visitor.
     *
     * @param plan root of the plan to rewrite
     * @param context job context handed through to every visit method
     * @return the rewritten plan
     */
    @Override
    public Plan rewriteRoot(Plan plan, JobContext context) {
        return plan.accept(this, context);
    }

    /**
     * Rewrites the children first, then pushes this aggregate down when it is a distinct that has not
     * been pushed yet and whose child is (an all-slot project over) a join.
     *
     * @param agg aggregate being visited
     * @param context job context
     * @return the unchanged aggregate when nothing can be pushed; otherwise the rewritten sub-plan, with
     *         the original aggregate dropped if it was a generated one
     */
    @Override
    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, JobContext context) {
        agg = visitChildren(this, agg, context);
        if (agg.hasPushed() || !agg.isDistinct() || isLeaf(agg.child())) {
            return agg;
        }
        Plan pushed = skipProjectPushDistinct(agg.child());
        // A generated distinct is eliminated once it has been pushed down;
        // any other distinct stays on top of the rewritten child.
        if (agg.isGenerated()) {
            return pushed;
        }
        return agg.withChildren(pushed);
    }

    /**
     * Pushes distinct into the join found at {@code plan}, or directly below it when {@code plan} is a
     * project (the project is kept on top of the rewritten join).
     *
     * <p>Callers must have checked {@code !isLeaf(plan)} first: that guarantees the plan is either a
     * join or an all-slot project over a join, which is what makes the casts below safe.
     */
    private Plan skipProjectPushDistinct(Plan plan) {
        if (plan instanceof LogicalProject) {
            LogicalProject<?> project = (LogicalProject<?>) plan;
            Plan pushedJoin = pushDistinct((LogicalJoin<? extends Plan, ? extends Plan>) project.child());
            return project.withChildren(ImmutableList.of(pushedJoin));
        }
        return pushDistinct((LogicalJoin<? extends Plan, ? extends Plan>) plan);
    }

    /**
     * Rebuilds {@code join} with a distinct aggregate on top of each of its two children.
     */
    private Plan pushDistinct(LogicalJoin<? extends Plan, ? extends Plan> join) {
        // No statistics are available during the rewrite phase, so distinct is pushed through a single
        // join only: every child gets a distinct directly on top, even when that child is itself a
        // join that could be descended into.
        Function<Plan, Plan> pushChild = this::withDistinct;
        Plan left = pushChild.apply(join.left());
        Plan right = pushChild.apply(join.right());
        return join.withChildren(ImmutableList.of(left, right));
    }

    /**
     * Wraps {@code plan} in a distinct aggregate over all of its output, marked as generated and as
     * already pushed.
     */
    private Plan withDistinct(Plan plan) {
        return new LogicalAggregate<>(ImmutableList.copyOf(plan.getOutput()), true, true, plan);
    }

    /**
     * Tells whether there is no join to push into at {@code plan}.
     *
     * <p>An all-slot project is looked through; any other node that is not a join counts as a leaf.
     */
    private boolean isLeaf(Plan plan) {
        if (plan instanceof LogicalProject && ((LogicalProject<?>) plan).isAllSlots()) {
            plan = plan.child(0);
        }
        return !(plan instanceof LogicalJoin);
    }
}

View File

@ -56,14 +56,16 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHILD_TYPE>
implements Aggregate<CHILD_TYPE> {
private final boolean normalized;
private final List<Expression> groupByExpressions;
private final List<NamedExpression> outputExpressions;
// When there are grouping sets/rollup/cube, LogicalAgg is generated by LogicalRepeat.
private final Optional<LogicalRepeat> sourceRepeat;
private final boolean normalized;
private final boolean ordinalIsResolved;
private final boolean generated;
private final boolean hasPushed;
/**
* Desc: Constructor for LogicalAggregate.
@ -79,13 +81,20 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
/**
* Distinct Agg
*/
public LogicalAggregate(List<NamedExpression> namedExpressions, CHILD_TYPE child) {
this(ImmutableList.copyOf(namedExpressions), namedExpressions, false, Optional.empty(), child);
public LogicalAggregate(List<NamedExpression> namedExpressions, boolean generated, CHILD_TYPE child) {
this(ImmutableList.copyOf(namedExpressions), namedExpressions, false, true, generated, false, Optional.empty(),
Optional.empty(), Optional.empty(), child);
}
public LogicalAggregate(List<NamedExpression> namedExpressions, boolean generated, boolean hasPushed,
CHILD_TYPE child) {
this(ImmutableList.copyOf(namedExpressions), namedExpressions, false, true, generated, hasPushed,
Optional.empty(), Optional.empty(), Optional.empty(), child);
}
public LogicalAggregate(List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions, boolean ordinalIsResolved, CHILD_TYPE child) {
this(groupByExpressions, outputExpressions, false, ordinalIsResolved, Optional.empty(),
this(groupByExpressions, outputExpressions, false, ordinalIsResolved, false, false, Optional.empty(),
Optional.empty(), Optional.empty(), child);
}
@ -107,18 +116,20 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
boolean normalized,
Optional<LogicalRepeat> sourceRepeat,
CHILD_TYPE child) {
this(groupByExpressions, outputExpressions, normalized, false, sourceRepeat,
this(groupByExpressions, outputExpressions, normalized, false, false, false, sourceRepeat,
Optional.empty(), Optional.empty(), child);
}
/**
* Whole parameters constructor for LogicalAggregate.
*/
public LogicalAggregate(
private LogicalAggregate(
List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions,
boolean normalized,
boolean ordinalIsResolved,
boolean generated,
boolean hasPushed,
Optional<LogicalRepeat> sourceRepeat,
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties,
@ -128,6 +139,8 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
this.outputExpressions = ImmutableList.copyOf(outputExpressions);
this.normalized = normalized;
this.ordinalIsResolved = ordinalIsResolved;
this.generated = generated;
this.hasPushed = hasPushed;
this.sourceRepeat = Objects.requireNonNull(sourceRepeat, "sourceRepeat cannot be null");
}
@ -151,6 +164,18 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
return sourceRepeat.isPresent();
}
/**
 * Whether this aggregate is a plain distinct, i.e. its output expressions are equal to its group-by
 * expressions (the shape built by the distinct constructors, which pass the same list for both).
 */
public boolean isDistinct() {
    return Objects.equals(outputExpressions, groupByExpressions);
}
/**
 * Returns the {@code generated} flag this aggregate was constructed with.
 * PushdownDistinctThroughJoin removes a distinct aggregate carrying this flag once it has pushed it down.
 */
public boolean isGenerated() {
return generated;
}
/**
 * Returns the {@code hasPushed} flag this aggregate was constructed with.
 * PushdownDistinctThroughJoin sets it on the distinct aggregates it creates and leaves flagged
 * aggregates untouched.
 */
public boolean hasPushed() {
return hasPushed;
}
@Override
public String toString() {
return Utils.toSqlString("LogicalAggregate[" + id.asInt() + "]",
@ -203,6 +228,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
&& Objects.equals(outputExpressions, that.outputExpressions)
&& normalized == that.normalized
&& ordinalIsResolved == that.ordinalIsResolved
&& generated == that.generated
&& Objects.equals(sourceRepeat, that.sourceRepeat);
}
@ -214,28 +240,26 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
@Override
public LogicalAggregate<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
normalized, ordinalIsResolved, sourceRepeat, Optional.empty(), Optional.empty(), children.get(0));
return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, generated,
hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), children.get(0));
}
@Override
public LogicalAggregate<Plan> withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
normalized, ordinalIsResolved, sourceRepeat, groupExpression, Optional.of(getLogicalProperties()),
children.get(0));
return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, generated,
hasPushed, sourceRepeat, groupExpression, Optional.of(getLogicalProperties()), children.get(0));
}
@Override
public LogicalAggregate<Plan> withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
normalized, ordinalIsResolved, sourceRepeat,
Optional.empty(), logicalProperties, children.get(0));
return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, generated,
hasPushed, sourceRepeat, Optional.empty(), logicalProperties, children.get(0));
}
public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression> groupByExprList,
List<NamedExpression> outputExpressionList) {
return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved,
sourceRepeat, Optional.empty(), Optional.empty(), child());
return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved, generated,
hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), child());
}
@Override
@ -245,18 +269,18 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
@Override
public LogicalAggregate<CHILD_TYPE> withAggOutput(List<NamedExpression> newOutput) {
return new LogicalAggregate<>(groupByExpressions, newOutput, normalized, ordinalIsResolved,
sourceRepeat, Optional.empty(), Optional.empty(), child());
return new LogicalAggregate<>(groupByExpressions, newOutput, normalized, ordinalIsResolved, generated,
hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), child());
}
public LogicalAggregate<Plan> withAggOutputChild(List<NamedExpression> newOutput, Plan newChild) {
return new LogicalAggregate<>(groupByExpressions, newOutput, normalized, ordinalIsResolved,
sourceRepeat, Optional.empty(), Optional.empty(), newChild);
return new LogicalAggregate<>(groupByExpressions, newOutput, normalized, ordinalIsResolved, generated,
hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), newChild);
}
public LogicalAggregate<Plan> withNormalized(List<Expression> normalizedGroupBy,
List<NamedExpression> normalizedOutput, Plan normalizedChild) {
return new LogicalAggregate<>(normalizedGroupBy, normalizedOutput, true, ordinalIsResolved,
sourceRepeat, Optional.empty(), Optional.empty(), normalizedChild);
return new LogicalAggregate<>(normalizedGroupBy, normalizedOutput, true, ordinalIsResolved, generated,
hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), normalizedChild);
}
}

View File

@ -0,0 +1,84 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
/**
 * Plan-shape tests for {@link PushdownDistinctThroughJoin}.
 */
class PushdownDistinctThroughJoinTest implements MemoPatternMatchSupported {
    private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
    private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
    private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
    private static final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(3, "t4", 0);

    /**
     * Distinct over a chain of three inner joins: the top join ends up with a distinct on each input,
     * the left input being the remaining join and the right input the t4 scan.
     */
    @Test
    void testPushdownJoin() {
        LogicalPlanBuilder builder = new LogicalPlanBuilder(scan1);
        LogicalPlan distinctOverJoins = builder
                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
                .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0))
                .join(scan4, JoinType.INNER_JOIN, Pair.of(0, 0))
                .distinct(ImmutableList.of(1, 3, 5, 7))
                .build();

        PlanChecker checker = PlanChecker.from(MemoTestUtils.createConnectContext(), distinctOverJoins);
        checker.applyTopDown(new PushdownDistinctThroughJoin())
                .matches(
                        logicalAggregate(
                                logicalJoin(
                                        logicalAggregate(logicalJoin()),
                                        logicalAggregate(logicalOlapScan())
                                )
                        )
                )
                .printlnTree();
    }

    /**
     * Same as above but with projects between the joins: the distinct pushed to the left input sits
     * above the project that covers the lower join.
     */
    @Test
    void testPushdownProjectJoin() {
        LogicalPlanBuilder builder = new LogicalPlanBuilder(scan1);
        LogicalPlan distinctOverProjectedJoins = builder
                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
                .project(ImmutableList.of(0, 2))
                .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0))
                .project(ImmutableList.of(0, 2, 3))
                .join(scan4, JoinType.INNER_JOIN, Pair.of(0, 0))
                .distinct(ImmutableList.of(1, 2, 3))
                .build();

        PlanChecker checker = PlanChecker.from(MemoTestUtils.createConnectContext(), distinctOverProjectedJoins);
        checker.applyTopDown(new PushdownDistinctThroughJoin())
                .matches(
                        logicalAggregate(
                                logicalJoin(
                                        logicalAggregate(logicalProject(logicalJoin())),
                                        logicalAggregate(logicalOlapScan())
                                )
                        )
                )
                .printlnTree();
    }
}

View File

@ -171,12 +171,23 @@ public class LogicalPlanBuilder {
for (Integer index : groupByKeysIndex) {
groupByBuilder.add(this.plan.getOutput().get(index));
}
ImmutableList<Expression> groupByKeys = groupByBuilder.build();
List<Expression> groupByKeys = groupByBuilder.build();
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExprsList, this.plan);
return from(agg);
}
/**
 * Puts a distinct aggregate on top of the current plan.
 *
 * @param groupByKeysIndex positions, within the current plan's output, of the columns to keep distinct
 * @return a builder whose plan is the new (non-generated) distinct aggregate
 */
public LogicalPlanBuilder distinct(List<Integer> groupByKeysIndex) {
    Builder<NamedExpression> distinctKeys = ImmutableList.builder();
    // Resolve every requested position against the current plan's output.
    groupByKeysIndex.forEach(index -> distinctKeys.add(this.plan.getOutput().get(index)));
    LogicalAggregate<Plan> distinctAgg = new LogicalAggregate<>(distinctKeys.build(), false, this.plan);
    return from(distinctAgg);
}
public LogicalPlanBuilder agg(List<Expression> groupByKeys, List<NamedExpression> outputExprsList) {
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExprsList, this.plan);
return from(agg);