diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java index 860925d624..616d9d8aba 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; @@ -33,6 +34,7 @@ import com.google.common.collect.ImmutableSet; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -67,9 +69,15 @@ class JoinLAsscomHelper extends ThreeJoinHelper { * Create newTopJoin. */ public Plan newTopJoin() { - Pair, List> projectPair = splitProjectExprs(bOutput); - List newLeftProjectExpr = projectPair.second; - List newRightProjectExprs = projectPair.first; + // Split inside-project into two part. + Map> projectExprsMap = allProjects.stream() + .collect(Collectors.partitioningBy(projectExpr -> { + Set usedSlots = projectExpr.collect(Slot.class::isInstance); + return bOutput.containsAll(usedSlots); + })); + + List newLeftProjectExpr = projectExprsMap.get(Boolean.FALSE); + List newRightProjectExprs = projectExprsMap.get(Boolean.TRUE); // If add project to B, we should add all slotReference used by hashOnCondition. // TODO: Does nonHashOnCondition also need to be considered. diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/ThreeJoinHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/ThreeJoinHelper.java index fdf70f2c05..e93aa5f104 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/ThreeJoinHelper.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/ThreeJoinHelper.java @@ -17,7 +17,6 @@ package org.apache.doris.nereids.rules.exploration.join; -import org.apache.doris.common.Pair; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; @@ -27,7 +26,6 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.ExpressionUtils; -import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; @@ -46,9 +44,9 @@ abstract class ThreeJoinHelper { protected final GroupPlan b; protected final GroupPlan c; - protected final List aOutput; - protected final List bOutput; - protected final List cOutput; + protected final Set aOutput; + protected final Set bOutput; + protected final Set cOutput; protected final List allProjects = Lists.newArrayList(); @@ -72,9 +70,9 @@ abstract class ThreeJoinHelper { this.b = b; this.c = c; - aOutput = Utils.getOutputSlotReference(a); - bOutput = Utils.getOutputSlotReference(b); - cOutput = Utils.getOutputSlotReference(c); + aOutput = a.getOutputSet(); + bOutput = b.getOutputSet(); + cOutput = c.getOutputSet(); Preconditions.checkArgument(!topJoin.getHashJoinConjuncts().isEmpty(), "topJoin hashJoinConjuncts must exist."); Preconditions.checkArgument(!bottomJoin.getHashJoinConjuncts().isEmpty(), @@ -103,7 +101,7 @@ abstract class ThreeJoinHelper { // Join C = B + A for above example. // TODO: also need for otherJoinCondition for (Expression topJoinOnClauseConjunct : topJoin.getHashJoinConjuncts()) { - Set topJoinUsedSlot = topJoinOnClauseConjunct.collect(SlotReference.class::isInstance); + Set topJoinUsedSlot = topJoinOnClauseConjunct.collect(SlotReference.class::isInstance); if (ExpressionUtils.isIntersecting(topJoinUsedSlot, aOutput) && ExpressionUtils.isIntersecting( topJoinUsedSlot, bOutput) && ExpressionUtils.isIntersecting(topJoinUsedSlot, cOutput)) { return false; @@ -140,26 +138,4 @@ abstract class ThreeJoinHelper { return true; } - - /** - * Split inside-project into two part. - * - * @param topJoinChild output of topJoin groupPlan child. - */ - protected Pair, List> splitProjectExprs(List topJoinChild) { - List newTopJoinChildProjectExprs = Lists.newArrayList(); - List newBottomJoinProjectExprs = Lists.newArrayList(); - - HashSet topJoinOutputSlotsSet = new HashSet<>(topJoinChild); - - for (NamedExpression projectExpr : allProjects) { - Set usedSlotRefs = projectExpr.collect(SlotReference.class::isInstance); - if (topJoinOutputSlotsSet.containsAll(usedSlotRefs)) { - newTopJoinChildProjectExprs.add(projectExpr); - } else { - newBottomJoinProjectExprs.add(projectExpr); - } - } - return Pair.of(newTopJoinChildProjectExprs, newBottomJoinProjectExprs); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 79493ae1c2..9c1b72100c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -143,8 +143,20 @@ public class ExpressionUtils { /** * Check whether lhs and rhs are intersecting. */ - public static boolean isIntersecting(Set lhs, List rhs) { - for (SlotReference rh : rhs) { + public static boolean isIntersecting(Set lhs, List rhs) { + for (T rh : rhs) { + if (lhs.contains(rh)) { + return true; + } + } + return false; + } + + /** + * Check whether lhs and rhs are intersecting. + */ + public static boolean isIntersecting(Set lhs, Set rhs) { + for (T rh : rhs) { if (lhs.contains(rh)) { return true; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProjectTest.java new file mode 100644 index 0000000000..bab120f2aa --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProjectTest.java @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.Lists; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +public class JoinLAsscomProjectTest { + + private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); + + @Test + public void testStarJoinLAsscomProject() { + /* + * Star-Join + * t1 -- t2 + * | + * t3 + *

+ * t1.id=t3.id t1.id=t2.id + * topJoin newTopJoin + * / \ / \ + * project t3 project project + * t1.id=t2.id t1.id=t3.id t2 + * bottomJoin --> newBottomJoin + * / \ / \ + * t1 t2 t1 t3 + */ + + Expression bottomJoinOnCondition = new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0)); + Expression topJoinOnCondition = new EqualTo(scan1.getOutput().get(1), scan3.getOutput().get(1)); + + LogicalJoin bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN, + Lists.newArrayList(bottomJoinOnCondition), + Optional.empty(), scan1, scan2); + + List output = bottomJoin.getOutput(); + List projectExprs = output.subList(0, output.size() - 1).stream() + .map(NamedExpression.class::cast).collect(Collectors.toList()); + LogicalProject> project = new LogicalProject<>( + projectExprs, bottomJoin); + LogicalJoin>, LogicalOlapScan> + topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition), + Optional.empty(), project, scan3); + + PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) + .transform(JoinLAsscomProject.INNER.build()) + .checkMemo(memo -> { + Group root = memo.getRoot(); + Assertions.assertEquals(2, root.getLogicalExpressions().size()); + + Assertions.assertTrue(root.logicalExpressionsAt(0).getPlan() instanceof LogicalJoin); + Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalProject); + + GroupExpression newTopJoinGroupExpr = root.logicalExpressionsAt(1).child(0).getLogicalExpression(); + GroupExpression leftProjectGroupExpr = newTopJoinGroupExpr.child(0).getLogicalExpression(); + GroupExpression rightProjectGroupExpr = newTopJoinGroupExpr.child(1).getLogicalExpression(); + Plan leftProject = newTopJoinGroupExpr.child(0).getLogicalExpression().getPlan(); + Plan rightProject = newTopJoinGroupExpr.child(1).getLogicalExpression().getPlan(); + Assertions.assertEquals(4, ((LogicalProject) leftProject).getProjects().size()); + Assertions.assertEquals(1, ((LogicalProject) rightProject).getProjects().size()); + + Plan t2 = rightProjectGroupExpr.child(0).getLogicalExpression().getPlan(); + Plan t1 = leftProjectGroupExpr.child(0).getLogicalExpression().child(0).getLogicalExpression() + .getPlan(); + Plan t3 = leftProjectGroupExpr.child(0).getLogicalExpression().child(1).getLogicalExpression() + .getPlan(); + Assertions.assertEquals("t2", ((LogicalOlapScan) t2).getTable().getName()); + Assertions.assertEquals("t1", ((LogicalOlapScan) t1).getTable().getName()); + Assertions.assertEquals("t3", ((LogicalOlapScan) t3).getTable().getName()); + }); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java index 7d7f1d8b05..5bbee33b0a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java @@ -21,7 +21,6 @@ import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; @@ -30,13 +29,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; -import org.apache.doris.nereids.util.Utils; import com.google.common.collect.Lists; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.List; import java.util.Optional; public class JoinLAsscomTest { @@ -45,10 +42,6 @@ public class JoinLAsscomTest { private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); - private final List t1Output = Utils.getOutputSlotReference(scan1); - private final List t2Output = Utils.getOutputSlotReference(scan2); - private final List t3Output = Utils.getOutputSlotReference(scan3); - @Test public void testStarJoinLAsscom() { /* @@ -66,8 +59,8 @@ public class JoinLAsscomTest { * t1 t2 t1 t3 */ - Expression bottomJoinOnCondition = new EqualTo(t1Output.get(0), t2Output.get(0)); - Expression topJoinOnCondition = new EqualTo(t1Output.get(1), t3Output.get(1)); + Expression bottomJoinOnCondition = new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0)); + Expression topJoinOnCondition = new EqualTo(scan1.getOutput().get(1), scan3.getOutput().get(1)); LogicalJoin bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(bottomJoinOnCondition), @@ -112,8 +105,8 @@ public class JoinLAsscomTest { * t1 t2 t1 t3 */ - Expression bottomJoinOnCondition = new EqualTo(t1Output.get(0), t2Output.get(0)); - Expression topJoinOnCondition = new EqualTo(t2Output.get(0), t3Output.get(0)); + Expression bottomJoinOnCondition = new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0)); + Expression topJoinOnCondition = new EqualTo(scan2.getOutput().get(0), scan3.getOutput().get(0)); LogicalJoin bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(bottomJoinOnCondition), Optional.empty(), scan1, scan2); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java index 9d8a8bbc81..a405f0d58d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java @@ -23,6 +23,8 @@ import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator; import org.apache.doris.nereids.glue.translator.PlanTranslatorContext; import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleSet; import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnAgg; import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnProjectUnderAgg; import org.apache.doris.nereids.rules.rewrite.logical.ExistsApplyToJoin; @@ -45,6 +47,9 @@ import org.apache.doris.nereids.types.VarcharType; import org.apache.doris.utframe.TestWithFeService; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import mockit.Mock; +import mockit.MockUp; import org.junit.jupiter.api.Test; import java.util.List; @@ -112,6 +117,13 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte @Test public void testTranslateCase() throws Exception { + new MockUp() { + @Mock + public List getExplorationRules() { + return Lists.newArrayList(); + } + }; + for (String sql : testSql) { NamedExpressionUtil.clear(); StatementContext statementContext = MemoTestUtils.createStatementContext(connectContext, sql); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanUtilsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanUtilsTest.java new file mode 100644 index 0000000000..5218b57e3d --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanUtilsTest.java @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.util; + +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; + +import com.google.common.collect.Lists; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.List; + +class PlanUtilsTest { + + @Test + void projectOrSelf() { + LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + Plan self = PlanUtils.projectOrSelf(Lists.newArrayList(), scan); + Assertions.assertSame(scan, self); + + NamedExpression slot = scan.getOutput().get(0); + List projects = Lists.newArrayList(); + projects.add(slot); + Plan project = PlanUtils.projectOrSelf(projects, scan); + Assertions.assertTrue(project instanceof LogicalProject); + } + + @Test + void filterOrSelf() { + LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + Plan filterOrSelf = PlanUtils.filterOrSelf(Lists.newArrayList(), scan); + Assertions.assertSame(scan, filterOrSelf); + + List predicate = Lists.newArrayList(); + predicate.add(BooleanLiteral.TRUE); + Plan filter = PlanUtils.filterOrSelf(predicate, scan); + Assertions.assertTrue(filter instanceof LogicalFilter); + } +}