[fix](Nereids): fix LAsscom project split. (#12506)

This commit is contained in:
jakevin
2022-09-13 12:12:39 +08:00
committed by GitHub
parent 6b52e47805
commit c3d7d4ce7a
7 changed files with 218 additions and 47 deletions

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
@ -33,6 +34,7 @@ import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@ -67,9 +69,15 @@ class JoinLAsscomHelper extends ThreeJoinHelper {
* Create newTopJoin.
*/
public Plan newTopJoin() {
Pair<List<NamedExpression>, List<NamedExpression>> projectPair = splitProjectExprs(bOutput);
List<NamedExpression> newLeftProjectExpr = projectPair.second;
List<NamedExpression> newRightProjectExprs = projectPair.first;
// Split the inside-project expressions into two parts.
Map<Boolean, List<NamedExpression>> projectExprsMap = allProjects.stream()
.collect(Collectors.partitioningBy(projectExpr -> {
Set<Slot> usedSlots = projectExpr.collect(Slot.class::isInstance);
return bOutput.containsAll(usedSlots);
}));
List<NamedExpression> newLeftProjectExpr = projectExprsMap.get(Boolean.FALSE);
List<NamedExpression> newRightProjectExprs = projectExprsMap.get(Boolean.TRUE);
// If we add a project on top of B, it must also output every slot reference used by hashOnCondition.
// TODO: Does nonHashOnCondition also need to be considered?

View File

@ -17,7 +17,6 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -27,7 +26,6 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
@ -46,9 +44,9 @@ abstract class ThreeJoinHelper {
protected final GroupPlan b;
protected final GroupPlan c;
protected final List<SlotReference> aOutput;
protected final List<SlotReference> bOutput;
protected final List<SlotReference> cOutput;
protected final Set<Slot> aOutput;
protected final Set<Slot> bOutput;
protected final Set<Slot> cOutput;
protected final List<NamedExpression> allProjects = Lists.newArrayList();
@ -72,9 +70,9 @@ abstract class ThreeJoinHelper {
this.b = b;
this.c = c;
aOutput = Utils.getOutputSlotReference(a);
bOutput = Utils.getOutputSlotReference(b);
cOutput = Utils.getOutputSlotReference(c);
aOutput = a.getOutputSet();
bOutput = b.getOutputSet();
cOutput = c.getOutputSet();
Preconditions.checkArgument(!topJoin.getHashJoinConjuncts().isEmpty(), "topJoin hashJoinConjuncts must exist.");
Preconditions.checkArgument(!bottomJoin.getHashJoinConjuncts().isEmpty(),
@ -103,7 +101,7 @@ abstract class ThreeJoinHelper {
// Join C = B + A for above example.
// TODO: also need for otherJoinCondition
for (Expression topJoinOnClauseConjunct : topJoin.getHashJoinConjuncts()) {
Set<SlotReference> topJoinUsedSlot = topJoinOnClauseConjunct.collect(SlotReference.class::isInstance);
Set<Slot> topJoinUsedSlot = topJoinOnClauseConjunct.collect(SlotReference.class::isInstance);
if (ExpressionUtils.isIntersecting(topJoinUsedSlot, aOutput) && ExpressionUtils.isIntersecting(
topJoinUsedSlot, bOutput) && ExpressionUtils.isIntersecting(topJoinUsedSlot, cOutput)) {
return false;
@ -140,26 +138,4 @@ abstract class ThreeJoinHelper {
return true;
}
/**
* Split inside-project into two part.
*
* @param topJoinChild output of topJoin groupPlan child.
*/
protected Pair<List<NamedExpression>, List<NamedExpression>> splitProjectExprs(List<SlotReference> topJoinChild) {
List<NamedExpression> newTopJoinChildProjectExprs = Lists.newArrayList();
List<NamedExpression> newBottomJoinProjectExprs = Lists.newArrayList();
HashSet<SlotReference> topJoinOutputSlotsSet = new HashSet<>(topJoinChild);
for (NamedExpression projectExpr : allProjects) {
Set<SlotReference> usedSlotRefs = projectExpr.collect(SlotReference.class::isInstance);
if (topJoinOutputSlotsSet.containsAll(usedSlotRefs)) {
newTopJoinChildProjectExprs.add(projectExpr);
} else {
newBottomJoinProjectExprs.add(projectExpr);
}
}
return Pair.of(newTopJoinChildProjectExprs, newBottomJoinProjectExprs);
}
}

View File

@ -143,8 +143,20 @@ public class ExpressionUtils {
/**
* Check whether lhs and rhs are intersecting.
*/
public static boolean isIntersecting(Set<SlotReference> lhs, List<SlotReference> rhs) {
for (SlotReference rh : rhs) {
public static <T> boolean isIntersecting(Set<T> lhs, List<T> rhs) {
    // Probe lhs once per element of rhs; anyMatch stops at the first hit and
    // yields false for an empty rhs. A lambda (not lhs::contains) is used so
    // lhs is only dereferenced when rhs actually has an element to test.
    return rhs.stream().anyMatch(candidate -> lhs.contains(candidate));
}
/**
* Check whether lhs and rhs are intersecting.
*/
public static <T> boolean isIntersecting(Set<T> lhs, Set<T> rhs) {
for (T rh : rhs) {
if (lhs.contains(rh)) {
return true;
}

View File

@ -0,0 +1,110 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
public class JoinLAsscomProjectTest {

    private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
    private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
    private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);

    /**
     * Applies {@code JoinLAsscomProject.INNER} to a star join with a project between the two joins
     * and checks the shape of the alternative plan recorded in the memo.
     */
    @Test
    public void testStarJoinLAsscomProject() {
        /*
         * Star-Join: t1 is joined with both t2 and t3.
         *
         *       topJoin                        project
         *       /     \                           |
         *   project    t3                     newTopJoin
         *      |                              /        \
         *  bottomJoin          -->        project     project
         *    /    \                          |           |
         *  t1      t2                  newBottomJoin     t2
         *                                 /    \
         *                               t1      t3
         *
         * bottomJoin compares the first output slot of t1 and t2;
         * topJoin compares the second output slot of t1 and t3.
         */
        Expression bottomCondition = new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0));
        Expression topCondition = new EqualTo(scan1.getOutput().get(1), scan3.getOutput().get(1));

        LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(
                JoinType.INNER_JOIN, Lists.newArrayList(bottomCondition), Optional.empty(), scan1, scan2);

        // The project keeps every output slot of the bottom join except the last one.
        List<Slot> bottomOutput = bottomJoin.getOutput();
        List<Slot> keptSlots = bottomOutput.subList(0, bottomOutput.size() - 1);
        List<NamedExpression> projectExprs = keptSlots.stream()
                .map(NamedExpression.class::cast)
                .collect(Collectors.toList());
        LogicalProject<LogicalJoin<LogicalOlapScan, LogicalOlapScan>> project
                = new LogicalProject<>(projectExprs, bottomJoin);

        LogicalJoin<LogicalProject<LogicalJoin<LogicalOlapScan, LogicalOlapScan>>, LogicalOlapScan> topJoin
                = new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(topCondition),
                        Optional.empty(), project, scan3);

        PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
                .transform(JoinLAsscomProject.INNER.build())
                .checkMemo(memo -> assertReorderedPlan(memo.getRoot()));
    }

    /**
     * Asserts that the root group holds the original join plus a project-rooted alternative whose
     * new top join has a project over join(t1, t3) on the left and a project over t2 on the right.
     */
    private void assertReorderedPlan(Group root) {
        Assertions.assertEquals(2, root.getLogicalExpressions().size());
        Assertions.assertTrue(root.logicalExpressionsAt(0).getPlan() instanceof LogicalJoin);
        Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalProject);

        // The new top join sits directly below the project in the second root expression.
        GroupExpression newTopJoinExpr = root.logicalExpressionsAt(1).child(0).getLogicalExpression();
        GroupExpression leftProjectExpr = newTopJoinExpr.child(0).getLogicalExpression();
        GroupExpression rightProjectExpr = newTopJoinExpr.child(1).getLogicalExpression();

        Plan leftProject = leftProjectExpr.getPlan();
        Plan rightProject = rightProjectExpr.getPlan();
        Assertions.assertEquals(4, ((LogicalProject<?>) leftProject).getProjects().size());
        Assertions.assertEquals(1, ((LogicalProject<?>) rightProject).getProjects().size());

        // Right project wraps a scan; left project wraps the new bottom join and its two scans.
        GroupExpression newBottomJoinExpr = leftProjectExpr.child(0).getLogicalExpression();
        Plan t2 = rightProjectExpr.child(0).getLogicalExpression().getPlan();
        Plan t1 = newBottomJoinExpr.child(0).getLogicalExpression().getPlan();
        Plan t3 = newBottomJoinExpr.child(1).getLogicalExpression().getPlan();

        Assertions.assertEquals("t2", ((LogicalOlapScan) t2).getTable().getName());
        Assertions.assertEquals("t1", ((LogicalOlapScan) t1).getTable().getName());
        Assertions.assertEquals("t3", ((LogicalOlapScan) t3).getTable().getName());
    }
}

View File

@ -21,7 +21,6 @@ import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
@ -30,13 +29,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Optional;
public class JoinLAsscomTest {
@ -45,10 +42,6 @@ public class JoinLAsscomTest {
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
private final List<SlotReference> t1Output = Utils.getOutputSlotReference(scan1);
private final List<SlotReference> t2Output = Utils.getOutputSlotReference(scan2);
private final List<SlotReference> t3Output = Utils.getOutputSlotReference(scan3);
@Test
public void testStarJoinLAsscom() {
/*
@ -66,8 +59,8 @@ public class JoinLAsscomTest {
* t1 t2 t1 t3
*/
Expression bottomJoinOnCondition = new EqualTo(t1Output.get(0), t2Output.get(0));
Expression topJoinOnCondition = new EqualTo(t1Output.get(1), t3Output.get(1));
Expression bottomJoinOnCondition = new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0));
Expression topJoinOnCondition = new EqualTo(scan1.getOutput().get(1), scan3.getOutput().get(1));
LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
Lists.newArrayList(bottomJoinOnCondition),
@ -112,8 +105,8 @@ public class JoinLAsscomTest {
* t1 t2 t1 t3
*/
Expression bottomJoinOnCondition = new EqualTo(t1Output.get(0), t2Output.get(0));
Expression topJoinOnCondition = new EqualTo(t2Output.get(0), t3Output.get(0));
Expression bottomJoinOnCondition = new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0));
Expression topJoinOnCondition = new EqualTo(scan2.getOutput().get(0), scan3.getOutput().get(0));
LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
Lists.newArrayList(bottomJoinOnCondition),
Optional.empty(), scan1, scan2);

View File

@ -23,6 +23,8 @@ import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnAgg;
import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnProjectUnderAgg;
import org.apache.doris.nereids.rules.rewrite.logical.ExistsApplyToJoin;
@ -45,6 +47,9 @@ import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import mockit.Mock;
import mockit.MockUp;
import org.junit.jupiter.api.Test;
import java.util.List;
@ -112,6 +117,13 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte
@Test
public void testTranslateCase() throws Exception {
new MockUp<RuleSet>() {
@Mock
public List<Rule> getExplorationRules() {
return Lists.newArrayList();
}
};
for (String sql : testSql) {
NamedExpressionUtil.clear();
StatementContext statementContext = MemoTestUtils.createStatementContext(connectContext, sql);

View File

@ -0,0 +1,60 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
class PlanUtilsTest {

    /**
     * Verifies that {@code PlanUtils.projectOrSelf} hands back the very same plan for an empty
     * project list and produces a {@link LogicalProject} when given one projection.
     */
    @Test
    void projectOrSelf() {
        LogicalOlapScan olapScan = PlanConstructor.newLogicalOlapScan(0, "t1", 0);

        // Empty project list: the input plan instance itself is expected back.
        Plan unchanged = PlanUtils.projectOrSelf(Lists.newArrayList(), olapScan);
        Assertions.assertSame(olapScan, unchanged);

        // One projection (the scan's first output slot): a LogicalProject is expected.
        NamedExpression firstSlot = olapScan.getOutput().get(0);
        List<NamedExpression> projections = Lists.newArrayList();
        projections.add(firstSlot);
        Plan projected = PlanUtils.projectOrSelf(projections, olapScan);
        Assertions.assertTrue(projected instanceof LogicalProject);
    }

    /**
     * Verifies that {@code PlanUtils.filterOrSelf} hands back the very same plan for an empty
     * predicate list and produces a {@link LogicalFilter} when given one predicate.
     */
    @Test
    void filterOrSelf() {
        LogicalOlapScan olapScan = PlanConstructor.newLogicalOlapScan(0, "t1", 0);

        // Empty predicate list: the input plan instance itself is expected back.
        Plan unchanged = PlanUtils.filterOrSelf(Lists.newArrayList(), olapScan);
        Assertions.assertSame(olapScan, unchanged);

        // One predicate (the TRUE literal): a LogicalFilter is expected.
        List<Expression> predicates = Lists.newArrayList();
        predicates.add(BooleanLiteral.TRUE);
        Plan filtered = PlanUtils.filterOrSelf(predicates, olapScan);
        Assertions.assertTrue(filtered instanceof LogicalFilter);
    }
}