[feature](Nereids): semi join transpose. (#12515)

* [feature](Nereids): semi join transpose.

* fix conditionChecker and check lasscom
This commit is contained in:
jakevin
2022-09-13 13:32:47 +08:00
committed by GitHub
parent d35a8a24a5
commit 5b4d3616a4
6 changed files with 451 additions and 0 deletions

View File

@ -98,6 +98,16 @@ public class Memo {
*/
public Plan copyOut(Group group, boolean includeGroupExpression) {
GroupExpression logicalExpression = group.getLogicalExpression();
return copyOut(logicalExpression, includeGroupExpression);
}
/**
* copyOut the logicalExpression.
* @param logicalExpression the logicalExpression what want to copyOut
* @param includeGroupExpression whether include group expression in the plan
* @return plan
*/
public Plan copyOut(GroupExpression logicalExpression, boolean includeGroupExpression) {
List<Plan> children = Lists.newArrayList();
for (Group child : logicalExpression.children()) {
children.add(copyOut(child, includeGroupExpression));

View File

@ -0,0 +1,100 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
import java.util.Set;
/**
* Planner rule that pushes a SemoJoin down in a tree past a LogicalJoin
* in order to trigger other rules that will convert {@code SemiJoin}s.
*
* <ul>
* <li>SemiJoin(LogicalJoin(X, Y), Z) -> LogicalJoin(SemiJoin(X, Z), Y)
* <li>SemiJoin(LogicalJoin(X, Y), Z) -> LogicalJoin(X, SemiJoin(Y, Z))
* </ul>
* <p>
* Whether this first or second conversion is applied depends on
* which operands actually participate in the semi-join.
*/
public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
@Override
public Rule build() {
return leftSemiLogicalJoin(logicalJoin(), group())
.when(this::conditionChecker)
.then(topSemiJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topSemiJoin.left();
GroupPlan a = bottomJoin.left();
GroupPlan b = bottomJoin.right();
GroupPlan c = topSemiJoin.right();
boolean lasscom = bottomJoin.getOutputSet().containsAll(a.getOutput());
if (lasscom) {
/*
* topSemiJoin newTopJoin
* / \ / \
* bottomJoin C --> newBottomSemiJoin B
* / \ / \
* A B A C
*/
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
topSemiJoin.getJoinType(),
topSemiJoin.getHashJoinConjuncts(), topSemiJoin.getOtherJoinCondition(), a, c);
return new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinCondition(), newBottomSemiJoin, b);
} else {
/*
* topSemiJoin newTopJoin
* / \ / \
* bottomJoin C --> A newBottomSemiJoin
* / \ / \
* A B B C
*/
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
topSemiJoin.getJoinType(),
topSemiJoin.getHashJoinConjuncts(), topSemiJoin.getOtherJoinCondition(), b, c);
return new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinCondition(), a, newBottomSemiJoin);
}
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}
// bottomJoin just return A OR B, else return false.
private boolean conditionChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin) {
Set<Slot> bottomOutputSet = topJoin.left().getOutputSet();
Set<Slot> aOutputSet = topJoin.left().left().getOutputSet();
Set<Slot> bOutputSet = topJoin.left().right().getOutputSet();
boolean isProjectA = !ExpressionUtils.isIntersecting(bottomOutputSet, aOutputSet);
boolean isProjectB = !ExpressionUtils.isIntersecting(bottomOutputSet, bOutputSet);
Preconditions.checkState(isProjectA || isProjectB, "join output must contain child");
return !(isProjectA && isProjectB);
}
}

View File

@ -0,0 +1,122 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Set;
/**
* Planner rule that pushes a SemoJoin down in a tree past a LogicalJoin
* in order to trigger other rules that will convert {@code SemiJoin}s.
*
* <ul>
* <li>SemiJoin(LogicalJoin(X, Y), Z) -> LogicalJoin(SemiJoin(X, Z), Y)
* <li>SemiJoin(LogicalJoin(X, Y), Z) -> LogicalJoin(X, SemiJoin(Y, Z))
* </ul>
* <p>
* Whether this first or second conversion is applied depends on
* which operands actually participate in the semi-join.
*/
public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFactory {
@Override
public Rule build() {
return leftSemiLogicalJoin(logicalProject(logicalJoin()), group())
.when(this::conditionChecker)
.then(topSemiJoin -> {
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topSemiJoin.left();
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = project.child();
GroupPlan a = bottomJoin.left();
GroupPlan b = bottomJoin.right();
GroupPlan c = topSemiJoin.right();
boolean lasscom = a.getOutputSet().containsAll(project.getOutput());
if (lasscom) {
/*-
* topSemiJoin newTopProject
* / \ |
* project C newTopJoin
* | -> / \
* bottomJoin newBottomSemiJoin B
* / \ / \
* A B aNewProject C
* |
* A
*/
List<NamedExpression> projects = project.getProjects();
LogicalProject<GroupPlan> aNewProject = new LogicalProject<>(projects, a);
LogicalJoin<LogicalProject<GroupPlan>, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
topSemiJoin.getOtherJoinCondition(), aNewProject, c);
LogicalJoin<LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>, GroupPlan> newTopJoin
= new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinCondition(), newBottomSemiJoin, b);
return new LogicalProject<>(projects, newTopJoin);
} else {
/*-
* topSemiJoin newTopProject
* / \ |
* project C newTopJoin
* | / \
* bottomJoin C --> A newBottomSemiJoin
* / \ / \
* A B bNewProject C
* |
* B
*/
List<NamedExpression> projects = project.getProjects();
LogicalProject<GroupPlan> bNewProject = new LogicalProject<>(projects, b);
LogicalJoin<LogicalProject<GroupPlan>, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
topSemiJoin.getOtherJoinCondition(), bNewProject, c);
LogicalJoin<GroupPlan, LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>> newTopJoin
= new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinCondition(), a, newBottomSemiJoin);
return new LogicalProject<>(projects, newTopJoin);
}
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}
// bottomJoin just return A OR B, else return false.
private boolean conditionChecker(
LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topJoin) {
Set<Slot> projectOutputSet = topJoin.left().getOutputSet();
Set<Slot> aOutputSet = topJoin.left().child().left().getOutputSet();
Set<Slot> bOutputSet = topJoin.left().child().right().getOutputSet();
boolean isProjectA = !ExpressionUtils.isIntersecting(projectOutputSet, aOutputSet);
boolean isProjectB = !ExpressionUtils.isIntersecting(projectOutputSet, bOutputSet);
Preconditions.checkState(isProjectA || isProjectB, "project must contain child");
return !(isProjectA && isProjectB);
}
}

View File

@ -0,0 +1,78 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import com.google.common.collect.ImmutableSet;
import java.util.Set;
/**
* Rule for SemiJoinTranspose.
* <p>
* LEFT-Semi/ANTI(LEFT-Semi/ANTI(X, Y), Z)
* ->
* LEFT-Semi/ANTI(X, LEFT-Semi/ANTI(Y, Z))
*/
public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory {
public static Set<Pair<JoinType, JoinType>> typeSet = ImmutableSet.of(
Pair.of(JoinType.LEFT_SEMI_JOIN, JoinType.LEFT_SEMI_JOIN),
Pair.of(JoinType.LEFT_ANTI_JOIN, JoinType.LEFT_ANTI_JOIN),
Pair.of(JoinType.LEFT_SEMI_JOIN, JoinType.LEFT_ANTI_JOIN),
Pair.of(JoinType.LEFT_ANTI_JOIN, JoinType.LEFT_SEMI_JOIN));
/*
* topJoin newTopJoin
* / \ / \
* bottomJoin C --> newBottomJoin B
* / \ / \
* A B A C
*/
@Override
public Rule build() {
return logicalJoin(logicalJoin(), group())
.when(this::typeChecker)
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
GroupPlan a = bottomJoin.left();
GroupPlan b = bottomJoin.right();
GroupPlan c = topJoin.right();
LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(),
topJoin.getHashJoinConjuncts(), topJoin.getOtherJoinCondition(), a, c);
LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> newTopJoin = new LogicalJoin<>(
bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinCondition(),
newBottomJoin, b);
return newTopJoin;
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}
private boolean typeChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin) {
return typeSet.contains(Pair.of(topJoin.getJoinType(), topJoin.left().getJoinType()));
}
}

View File

@ -0,0 +1,63 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
public class SemiJoinTransposeTest {
public static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
public static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
public static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
@Test
public void testSemiJoinLogicalTransposeCommute() {
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
.hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
.project(ImmutableList.of(0))
.hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
.transform((new SemiJoinLogicalJoinTransposeProject()).build())
.checkMemo(memo -> {
Group root = memo.getRoot();
Assertions.assertEquals(2, root.getLogicalExpressions().size());
Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false);
Plan join = plan.child(0);
Assertions.assertTrue(join instanceof LogicalJoin);
Assertions.assertEquals(JoinType.INNER_JOIN, ((LogicalJoin<?, ?>) join).getJoinType());
Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN,
((LogicalJoin<?, ?>) ((LogicalJoin<?, ?>) join).left()).getJoinType());
});
}
}

View File

@ -0,0 +1,78 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
public class LogicalPlanBuilder {
private final LogicalPlan plan;
public LogicalPlanBuilder(LogicalPlan plan) {
this.plan = plan;
}
public LogicalPlan build() {
return plan;
}
public LogicalPlanBuilder from(LogicalPlan plan) {
return new LogicalPlanBuilder(plan);
}
public LogicalPlanBuilder scan(long tableId, String tableName, int hashColumn) {
LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(tableId, tableName, hashColumn);
return from(scan);
}
public LogicalPlanBuilder projectWithExprs(List<NamedExpression> projectExprs) {
LogicalProject<LogicalPlan> project = new LogicalProject<>(projectExprs, this.plan);
return from(project);
}
public LogicalPlanBuilder project(List<Integer> slots) {
List<NamedExpression> projectExprs = Lists.newArrayList();
for (int i = 0; i < slots.size(); i++) {
projectExprs.add(this.plan.getOutput().get(i));
}
LogicalProject<LogicalPlan> project = new LogicalProject<>(projectExprs, this.plan);
return from(project);
}
public LogicalPlanBuilder hashJoinUsing(LogicalPlan right, JoinType joinType, Pair<Integer, Integer> hashOnSlots) {
ImmutableList<EqualTo> hashConjunts = ImmutableList.of(
new EqualTo(this.plan.getOutput().get(hashOnSlots.first), right.getOutput().get(hashOnSlots.second)));
LogicalJoin<LogicalPlan, LogicalPlan> join = new LogicalJoin<>(joinType, new ArrayList<>(hashConjunts),
Optional.empty(), this.plan, right);
return from(join);
}
}