[feature](Nereids): pushdown distinct through join. (#21437)
This commit is contained in:
@ -239,6 +239,10 @@ public class Rewriter extends AbstractBatchJobExecutor {
|
||||
topDown(new BuildAggForUnion())
|
||||
),
|
||||
|
||||
// topic("Distinct",
|
||||
// costBased(custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, PushdownDistinctThroughJoin::new))
|
||||
// ),
|
||||
|
||||
topic("Window optimization",
|
||||
topDown(
|
||||
new PushdownLimit(),
|
||||
|
||||
@ -147,6 +147,8 @@ public enum RuleType {
|
||||
PUSHDOWN_FILTER_THROUGH_CTE(RuleTypeClass.REWRITE),
|
||||
PUSHDOWN_FILTER_THROUGH_CTE_ANCHOR(RuleTypeClass.REWRITE),
|
||||
|
||||
PUSH_DOWN_DISTINCT_THROUGH_JOIN(RuleTypeClass.REWRITE),
|
||||
|
||||
COLUMN_PRUNING(RuleTypeClass.REWRITE),
|
||||
|
||||
PUSHDOWN_TOP_N_THROUGH_PROJECTION_WINDOW(RuleTypeClass.REWRITE),
|
||||
|
||||
@ -47,7 +47,7 @@ public class ProjectWithDistinctToAggregate extends OneAnalysisRuleFactory {
|
||||
logicalProject()
|
||||
.when(LogicalProject::isDistinct)
|
||||
.whenNot(project -> project.getProjects().stream().anyMatch(this::hasAggregateFunction))
|
||||
.then(project -> new LogicalAggregate<>(project.getProjects(), project.child()))
|
||||
.then(project -> new LogicalAggregate<>(project.getProjects(), false, project.child()))
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@ -50,7 +50,7 @@ public class InferSetOperatorDistinct extends OneRewriteRuleFactory {
|
||||
}
|
||||
|
||||
List<Plan> newChildren = setOperation.children().stream()
|
||||
.map(child -> new LogicalAggregate<>(ImmutableList.copyOf(child.getOutput()), child))
|
||||
.map(child -> new LogicalAggregate<>(ImmutableList.copyOf(child.getOutput()), true, child))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
if (newChildren.equals(setOperation.children())) {
|
||||
return null;
|
||||
|
||||
@ -0,0 +1,92 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.jobs.JobContext;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.function.Function;
|
||||
|
||||
/**
|
||||
* PushdownDistinctThroughJoin
|
||||
*/
|
||||
public class PushdownDistinctThroughJoin extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
|
||||
@Override
|
||||
public Plan rewriteRoot(Plan plan, JobContext context) {
|
||||
return plan.accept(this, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, JobContext context) {
|
||||
agg = visitChildren(this, agg, context);
|
||||
if (agg.hasPushed() || !agg.isDistinct() || isLeaf(agg.child())) {
|
||||
return agg;
|
||||
}
|
||||
|
||||
// After we push down distinct, if this distinct is generated, we will eliminate this distinct
|
||||
if (agg.isGenerated()) {
|
||||
return skipProjectPushDistinct(agg.child());
|
||||
} else {
|
||||
return agg.withChildren(skipProjectPushDistinct(agg.child()));
|
||||
}
|
||||
}
|
||||
|
||||
private Plan skipProjectPushDistinct(Plan plan) {
|
||||
if (plan instanceof LogicalProject) {
|
||||
LogicalProject project = (LogicalProject) plan;
|
||||
Plan pushJoin = pushDistinct((LogicalJoin<? extends Plan, ? extends Plan>) project.child());
|
||||
return project.withChildren(ImmutableList.of(pushJoin));
|
||||
} else {
|
||||
Plan pushJoin = pushDistinct((LogicalJoin<? extends Plan, ? extends Plan>) plan);
|
||||
return pushJoin;
|
||||
}
|
||||
}
|
||||
|
||||
private Plan pushDistinct(LogicalJoin<? extends Plan, ? extends Plan> join) {
|
||||
Function<Plan, Plan> pushChild = (Plan plan) -> {
|
||||
if (isLeaf(plan)) {
|
||||
return withDistinct(plan);
|
||||
} else {
|
||||
// Due to there isn't statistics during Rewrite, so we just push down through 1 join.
|
||||
// return skipProjectPushDistinct(plan);
|
||||
return withDistinct(plan);
|
||||
}
|
||||
};
|
||||
Plan left = pushChild.apply(join.left());
|
||||
Plan right = pushChild.apply(join.right());
|
||||
return join.withChildren(ImmutableList.of(left, right));
|
||||
}
|
||||
|
||||
private Plan withDistinct(Plan plan) {
|
||||
return new LogicalAggregate<>(ImmutableList.copyOf(plan.getOutput()), true, true, plan);
|
||||
}
|
||||
|
||||
private boolean isLeaf(Plan plan) {
|
||||
if (plan instanceof LogicalProject && ((LogicalProject<?>) plan).isAllSlots()) {
|
||||
plan = plan.child(0);
|
||||
}
|
||||
return !(plan instanceof LogicalJoin);
|
||||
}
|
||||
}
|
||||
@ -56,14 +56,16 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
|
||||
extends LogicalUnary<CHILD_TYPE>
|
||||
implements Aggregate<CHILD_TYPE> {
|
||||
|
||||
private final boolean normalized;
|
||||
private final List<Expression> groupByExpressions;
|
||||
private final List<NamedExpression> outputExpressions;
|
||||
|
||||
// When there are grouping sets/rollup/cube, LogicalAgg is generated by LogicalRepeat.
|
||||
private final Optional<LogicalRepeat> sourceRepeat;
|
||||
|
||||
private final boolean normalized;
|
||||
private final boolean ordinalIsResolved;
|
||||
private final boolean generated;
|
||||
private final boolean hasPushed;
|
||||
|
||||
/**
|
||||
* Desc: Constructor for LogicalAggregate.
|
||||
@ -79,13 +81,20 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
|
||||
/**
|
||||
* Distinct Agg
|
||||
*/
|
||||
public LogicalAggregate(List<NamedExpression> namedExpressions, CHILD_TYPE child) {
|
||||
this(ImmutableList.copyOf(namedExpressions), namedExpressions, false, Optional.empty(), child);
|
||||
public LogicalAggregate(List<NamedExpression> namedExpressions, boolean generated, CHILD_TYPE child) {
|
||||
this(ImmutableList.copyOf(namedExpressions), namedExpressions, false, true, generated, false, Optional.empty(),
|
||||
Optional.empty(), Optional.empty(), child);
|
||||
}
|
||||
|
||||
/**
 * Distinct Agg with explicit {@code generated} and {@code hasPushed} flags.
 *
 * <p>{@code namedExpressions} is used both as the group-by list and as the output list.
 * The delegation passes {@code normalized = false}, {@code ordinalIsResolved = true} and
 * empty optionals for the source repeat, group expression and logical properties.
 *
 * @param namedExpressions expressions used as group-by keys and as output
 * @param generated value reported by {@code isGenerated()}
 * @param hasPushed value reported by {@code hasPushed()}
 * @param child the input plan
 */
public LogicalAggregate(List<NamedExpression> namedExpressions, boolean generated, boolean hasPushed,
|
||||
CHILD_TYPE child) {
|
||||
this(ImmutableList.copyOf(namedExpressions), namedExpressions, false, true, generated, hasPushed,
|
||||
Optional.empty(), Optional.empty(), Optional.empty(), child);
|
||||
}
|
||||
|
||||
public LogicalAggregate(List<Expression> groupByExpressions,
|
||||
List<NamedExpression> outputExpressions, boolean ordinalIsResolved, CHILD_TYPE child) {
|
||||
this(groupByExpressions, outputExpressions, false, ordinalIsResolved, Optional.empty(),
|
||||
this(groupByExpressions, outputExpressions, false, ordinalIsResolved, false, false, Optional.empty(),
|
||||
Optional.empty(), Optional.empty(), child);
|
||||
}
|
||||
|
||||
@ -107,18 +116,20 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
|
||||
boolean normalized,
|
||||
Optional<LogicalRepeat> sourceRepeat,
|
||||
CHILD_TYPE child) {
|
||||
this(groupByExpressions, outputExpressions, normalized, false, sourceRepeat,
|
||||
this(groupByExpressions, outputExpressions, normalized, false, false, false, sourceRepeat,
|
||||
Optional.empty(), Optional.empty(), child);
|
||||
}
|
||||
|
||||
/**
|
||||
* Whole parameters constructor for LogicalAggregate.
|
||||
*/
|
||||
public LogicalAggregate(
|
||||
private LogicalAggregate(
|
||||
List<Expression> groupByExpressions,
|
||||
List<NamedExpression> outputExpressions,
|
||||
boolean normalized,
|
||||
boolean ordinalIsResolved,
|
||||
boolean generated,
|
||||
boolean hasPushed,
|
||||
Optional<LogicalRepeat> sourceRepeat,
|
||||
Optional<GroupExpression> groupExpression,
|
||||
Optional<LogicalProperties> logicalProperties,
|
||||
@ -128,6 +139,8 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
|
||||
this.outputExpressions = ImmutableList.copyOf(outputExpressions);
|
||||
this.normalized = normalized;
|
||||
this.ordinalIsResolved = ordinalIsResolved;
|
||||
this.generated = generated;
|
||||
this.hasPushed = hasPushed;
|
||||
this.sourceRepeat = Objects.requireNonNull(sourceRepeat, "sourceRepeat cannot be null");
|
||||
}
|
||||
|
||||
@ -151,6 +164,18 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
|
||||
return sourceRepeat.isPresent();
|
||||
}
|
||||
|
||||
public boolean isDistinct() {
|
||||
return outputExpressions.equals(groupByExpressions);
|
||||
}
|
||||
|
||||
public boolean isGenerated() {
|
||||
return generated;
|
||||
}
|
||||
|
||||
public boolean hasPushed() {
|
||||
return hasPushed;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Utils.toSqlString("LogicalAggregate[" + id.asInt() + "]",
|
||||
@ -203,6 +228,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
|
||||
&& Objects.equals(outputExpressions, that.outputExpressions)
|
||||
&& normalized == that.normalized
|
||||
&& ordinalIsResolved == that.ordinalIsResolved
|
||||
&& generated == that.generated
|
||||
&& Objects.equals(sourceRepeat, that.sourceRepeat);
|
||||
}
|
||||
|
||||
@ -214,28 +240,26 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
|
||||
@Override
|
||||
public LogicalAggregate<Plan> withChildren(List<Plan> children) {
|
||||
Preconditions.checkArgument(children.size() == 1);
|
||||
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
|
||||
normalized, ordinalIsResolved, sourceRepeat, Optional.empty(), Optional.empty(), children.get(0));
|
||||
return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, generated,
|
||||
hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public LogicalAggregate<Plan> withGroupExpression(Optional<GroupExpression> groupExpression) {
|
||||
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
|
||||
normalized, ordinalIsResolved, sourceRepeat, groupExpression, Optional.of(getLogicalProperties()),
|
||||
children.get(0));
|
||||
return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, generated,
|
||||
hasPushed, sourceRepeat, groupExpression, Optional.of(getLogicalProperties()), children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public LogicalAggregate<Plan> withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
|
||||
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
|
||||
normalized, ordinalIsResolved, sourceRepeat,
|
||||
Optional.empty(), logicalProperties, children.get(0));
|
||||
return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, generated,
|
||||
hasPushed, sourceRepeat, Optional.empty(), logicalProperties, children.get(0));
|
||||
}
|
||||
|
||||
public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression> groupByExprList,
|
||||
List<NamedExpression> outputExpressionList) {
|
||||
return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved,
|
||||
sourceRepeat, Optional.empty(), Optional.empty(), child());
|
||||
return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved, generated,
|
||||
hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), child());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -245,18 +269,18 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
|
||||
|
||||
@Override
|
||||
public LogicalAggregate<CHILD_TYPE> withAggOutput(List<NamedExpression> newOutput) {
|
||||
return new LogicalAggregate<>(groupByExpressions, newOutput, normalized, ordinalIsResolved,
|
||||
sourceRepeat, Optional.empty(), Optional.empty(), child());
|
||||
return new LogicalAggregate<>(groupByExpressions, newOutput, normalized, ordinalIsResolved, generated,
|
||||
hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), child());
|
||||
}
|
||||
|
||||
public LogicalAggregate<Plan> withAggOutputChild(List<NamedExpression> newOutput, Plan newChild) {
|
||||
return new LogicalAggregate<>(groupByExpressions, newOutput, normalized, ordinalIsResolved,
|
||||
sourceRepeat, Optional.empty(), Optional.empty(), newChild);
|
||||
return new LogicalAggregate<>(groupByExpressions, newOutput, normalized, ordinalIsResolved, generated,
|
||||
hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), newChild);
|
||||
}
|
||||
|
||||
public LogicalAggregate<Plan> withNormalized(List<Expression> normalizedGroupBy,
|
||||
List<NamedExpression> normalizedOutput, Plan normalizedChild) {
|
||||
return new LogicalAggregate<>(normalizedGroupBy, normalizedOutput, true, ordinalIsResolved,
|
||||
sourceRepeat, Optional.empty(), Optional.empty(), normalizedChild);
|
||||
return new LogicalAggregate<>(normalizedGroupBy, normalizedOutput, true, ordinalIsResolved, generated,
|
||||
hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), normalizedChild);
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,84 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class PushdownDistinctThroughJoinTest implements MemoPatternMatchSupported {
|
||||
private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
private static final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(3, "t4", 0);
|
||||
|
||||
@Test
|
||||
void testPushdownJoin() {
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.join(scan4, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.distinct(ImmutableList.of(1, 3, 5, 7))
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyTopDown(new PushdownDistinctThroughJoin())
|
||||
.matches(
|
||||
logicalAggregate(
|
||||
logicalJoin(
|
||||
logicalAggregate(logicalJoin()),
|
||||
logicalAggregate(logicalOlapScan())
|
||||
)
|
||||
)
|
||||
)
|
||||
.printlnTree();
|
||||
}
|
||||
|
||||
@Test
|
||||
void testPushdownProjectJoin() {
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.project(ImmutableList.of(0, 2))
|
||||
.join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.project(ImmutableList.of(0, 2, 3))
|
||||
.join(scan4, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.distinct(ImmutableList.of(1, 2, 3))
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyTopDown(new PushdownDistinctThroughJoin())
|
||||
.matches(
|
||||
logicalAggregate(
|
||||
logicalJoin(
|
||||
logicalAggregate(logicalProject(logicalJoin())),
|
||||
logicalAggregate(logicalOlapScan())
|
||||
)
|
||||
)
|
||||
)
|
||||
.printlnTree();
|
||||
}
|
||||
}
|
||||
@ -171,12 +171,23 @@ public class LogicalPlanBuilder {
|
||||
for (Integer index : groupByKeysIndex) {
|
||||
groupByBuilder.add(this.plan.getOutput().get(index));
|
||||
}
|
||||
ImmutableList<Expression> groupByKeys = groupByBuilder.build();
|
||||
List<Expression> groupByKeys = groupByBuilder.build();
|
||||
|
||||
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExprsList, this.plan);
|
||||
return from(agg);
|
||||
}
|
||||
|
||||
public LogicalPlanBuilder distinct(List<Integer> groupByKeysIndex) {
|
||||
Builder<NamedExpression> groupByBuilder = ImmutableList.builder();
|
||||
for (Integer index : groupByKeysIndex) {
|
||||
groupByBuilder.add(this.plan.getOutput().get(index));
|
||||
}
|
||||
List<NamedExpression> groupByKeys = groupByBuilder.build();
|
||||
|
||||
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, false, this.plan);
|
||||
return from(agg);
|
||||
}
|
||||
|
||||
public LogicalPlanBuilder agg(List<Expression> groupByKeys, List<NamedExpression> outputExprsList) {
|
||||
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExprsList, this.plan);
|
||||
return from(agg);
|
||||
|
||||
Reference in New Issue
Block a user