[fix](Nereids): fix LAsscom project split. (#12506)
This commit is contained in:
@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.exploration.join;
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
@ -33,6 +34,7 @@ import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@ -67,9 +69,15 @@ class JoinLAsscomHelper extends ThreeJoinHelper {
|
||||
* Create newTopJoin.
|
||||
*/
|
||||
public Plan newTopJoin() {
|
||||
Pair<List<NamedExpression>, List<NamedExpression>> projectPair = splitProjectExprs(bOutput);
|
||||
List<NamedExpression> newLeftProjectExpr = projectPair.second;
|
||||
List<NamedExpression> newRightProjectExprs = projectPair.first;
|
||||
// Split inside-project into two part.
|
||||
Map<Boolean, List<NamedExpression>> projectExprsMap = allProjects.stream()
|
||||
.collect(Collectors.partitioningBy(projectExpr -> {
|
||||
Set<Slot> usedSlots = projectExpr.collect(Slot.class::isInstance);
|
||||
return bOutput.containsAll(usedSlots);
|
||||
}));
|
||||
|
||||
List<NamedExpression> newLeftProjectExpr = projectExprsMap.get(Boolean.FALSE);
|
||||
List<NamedExpression> newRightProjectExprs = projectExprsMap.get(Boolean.TRUE);
|
||||
|
||||
// If add project to B, we should add all slotReference used by hashOnCondition.
|
||||
// TODO: Does nonHashOnCondition also need to be considered.
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
@ -27,7 +26,6 @@ import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
@ -46,9 +44,9 @@ abstract class ThreeJoinHelper {
|
||||
protected final GroupPlan b;
|
||||
protected final GroupPlan c;
|
||||
|
||||
protected final List<SlotReference> aOutput;
|
||||
protected final List<SlotReference> bOutput;
|
||||
protected final List<SlotReference> cOutput;
|
||||
protected final Set<Slot> aOutput;
|
||||
protected final Set<Slot> bOutput;
|
||||
protected final Set<Slot> cOutput;
|
||||
|
||||
protected final List<NamedExpression> allProjects = Lists.newArrayList();
|
||||
|
||||
@ -72,9 +70,9 @@ abstract class ThreeJoinHelper {
|
||||
this.b = b;
|
||||
this.c = c;
|
||||
|
||||
aOutput = Utils.getOutputSlotReference(a);
|
||||
bOutput = Utils.getOutputSlotReference(b);
|
||||
cOutput = Utils.getOutputSlotReference(c);
|
||||
aOutput = a.getOutputSet();
|
||||
bOutput = b.getOutputSet();
|
||||
cOutput = c.getOutputSet();
|
||||
|
||||
Preconditions.checkArgument(!topJoin.getHashJoinConjuncts().isEmpty(), "topJoin hashJoinConjuncts must exist.");
|
||||
Preconditions.checkArgument(!bottomJoin.getHashJoinConjuncts().isEmpty(),
|
||||
@ -103,7 +101,7 @@ abstract class ThreeJoinHelper {
|
||||
// Join C = B + A for above example.
|
||||
// TODO: also need for otherJoinCondition
|
||||
for (Expression topJoinOnClauseConjunct : topJoin.getHashJoinConjuncts()) {
|
||||
Set<SlotReference> topJoinUsedSlot = topJoinOnClauseConjunct.collect(SlotReference.class::isInstance);
|
||||
Set<Slot> topJoinUsedSlot = topJoinOnClauseConjunct.collect(SlotReference.class::isInstance);
|
||||
if (ExpressionUtils.isIntersecting(topJoinUsedSlot, aOutput) && ExpressionUtils.isIntersecting(
|
||||
topJoinUsedSlot, bOutput) && ExpressionUtils.isIntersecting(topJoinUsedSlot, cOutput)) {
|
||||
return false;
|
||||
@ -140,26 +138,4 @@ abstract class ThreeJoinHelper {
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Split inside-project into two part.
|
||||
*
|
||||
* @param topJoinChild output of topJoin groupPlan child.
|
||||
*/
|
||||
protected Pair<List<NamedExpression>, List<NamedExpression>> splitProjectExprs(List<SlotReference> topJoinChild) {
|
||||
List<NamedExpression> newTopJoinChildProjectExprs = Lists.newArrayList();
|
||||
List<NamedExpression> newBottomJoinProjectExprs = Lists.newArrayList();
|
||||
|
||||
HashSet<SlotReference> topJoinOutputSlotsSet = new HashSet<>(topJoinChild);
|
||||
|
||||
for (NamedExpression projectExpr : allProjects) {
|
||||
Set<SlotReference> usedSlotRefs = projectExpr.collect(SlotReference.class::isInstance);
|
||||
if (topJoinOutputSlotsSet.containsAll(usedSlotRefs)) {
|
||||
newTopJoinChildProjectExprs.add(projectExpr);
|
||||
} else {
|
||||
newBottomJoinProjectExprs.add(projectExpr);
|
||||
}
|
||||
}
|
||||
return Pair.of(newTopJoinChildProjectExprs, newBottomJoinProjectExprs);
|
||||
}
|
||||
}
|
||||
|
||||
@ -143,8 +143,20 @@ public class ExpressionUtils {
|
||||
/**
|
||||
* Check whether lhs and rhs are intersecting.
|
||||
*/
|
||||
public static boolean isIntersecting(Set<SlotReference> lhs, List<SlotReference> rhs) {
|
||||
for (SlotReference rh : rhs) {
|
||||
public static <T> boolean isIntersecting(Set<T> lhs, List<T> rhs) {
|
||||
for (T rh : rhs) {
|
||||
if (lhs.contains(rh)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check whether lhs and rhs are intersecting.
|
||||
*/
|
||||
public static <T> boolean isIntersecting(Set<T> lhs, Set<T> rhs) {
|
||||
for (T rh : rhs) {
|
||||
if (lhs.contains(rh)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -0,0 +1,110 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class JoinLAsscomProjectTest {
|
||||
|
||||
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
|
||||
@Test
|
||||
public void testStarJoinLAsscomProject() {
|
||||
/*
|
||||
* Star-Join
|
||||
* t1 -- t2
|
||||
* |
|
||||
* t3
|
||||
* <p>
|
||||
* t1.id=t3.id t1.id=t2.id
|
||||
* topJoin newTopJoin
|
||||
* / \ / \
|
||||
* project t3 project project
|
||||
* t1.id=t2.id t1.id=t3.id t2
|
||||
* bottomJoin --> newBottomJoin
|
||||
* / \ / \
|
||||
* t1 t2 t1 t3
|
||||
*/
|
||||
|
||||
Expression bottomJoinOnCondition = new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0));
|
||||
Expression topJoinOnCondition = new EqualTo(scan1.getOutput().get(1), scan3.getOutput().get(1));
|
||||
|
||||
LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
|
||||
Lists.newArrayList(bottomJoinOnCondition),
|
||||
Optional.empty(), scan1, scan2);
|
||||
|
||||
List<Slot> output = bottomJoin.getOutput();
|
||||
List<NamedExpression> projectExprs = output.subList(0, output.size() - 1).stream()
|
||||
.map(NamedExpression.class::cast).collect(Collectors.toList());
|
||||
LogicalProject<LogicalJoin<LogicalOlapScan, LogicalOlapScan>> project = new LogicalProject<>(
|
||||
projectExprs, bottomJoin);
|
||||
LogicalJoin<LogicalProject<LogicalJoin<LogicalOlapScan, LogicalOlapScan>>, LogicalOlapScan>
|
||||
topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition),
|
||||
Optional.empty(), project, scan3);
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
|
||||
.transform(JoinLAsscomProject.INNER.build())
|
||||
.checkMemo(memo -> {
|
||||
Group root = memo.getRoot();
|
||||
Assertions.assertEquals(2, root.getLogicalExpressions().size());
|
||||
|
||||
Assertions.assertTrue(root.logicalExpressionsAt(0).getPlan() instanceof LogicalJoin);
|
||||
Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalProject);
|
||||
|
||||
GroupExpression newTopJoinGroupExpr = root.logicalExpressionsAt(1).child(0).getLogicalExpression();
|
||||
GroupExpression leftProjectGroupExpr = newTopJoinGroupExpr.child(0).getLogicalExpression();
|
||||
GroupExpression rightProjectGroupExpr = newTopJoinGroupExpr.child(1).getLogicalExpression();
|
||||
Plan leftProject = newTopJoinGroupExpr.child(0).getLogicalExpression().getPlan();
|
||||
Plan rightProject = newTopJoinGroupExpr.child(1).getLogicalExpression().getPlan();
|
||||
Assertions.assertEquals(4, ((LogicalProject) leftProject).getProjects().size());
|
||||
Assertions.assertEquals(1, ((LogicalProject) rightProject).getProjects().size());
|
||||
|
||||
Plan t2 = rightProjectGroupExpr.child(0).getLogicalExpression().getPlan();
|
||||
Plan t1 = leftProjectGroupExpr.child(0).getLogicalExpression().child(0).getLogicalExpression()
|
||||
.getPlan();
|
||||
Plan t3 = leftProjectGroupExpr.child(0).getLogicalExpression().child(1).getLogicalExpression()
|
||||
.getPlan();
|
||||
Assertions.assertEquals("t2", ((LogicalOlapScan) t2).getTable().getName());
|
||||
Assertions.assertEquals("t1", ((LogicalOlapScan) t1).getTable().getName());
|
||||
Assertions.assertEquals("t3", ((LogicalOlapScan) t3).getTable().getName());
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -21,7 +21,6 @@ import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
@ -30,13 +29,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
public class JoinLAsscomTest {
|
||||
@ -45,10 +42,6 @@ public class JoinLAsscomTest {
|
||||
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
|
||||
private final List<SlotReference> t1Output = Utils.getOutputSlotReference(scan1);
|
||||
private final List<SlotReference> t2Output = Utils.getOutputSlotReference(scan2);
|
||||
private final List<SlotReference> t3Output = Utils.getOutputSlotReference(scan3);
|
||||
|
||||
@Test
|
||||
public void testStarJoinLAsscom() {
|
||||
/*
|
||||
@ -66,8 +59,8 @@ public class JoinLAsscomTest {
|
||||
* t1 t2 t1 t3
|
||||
*/
|
||||
|
||||
Expression bottomJoinOnCondition = new EqualTo(t1Output.get(0), t2Output.get(0));
|
||||
Expression topJoinOnCondition = new EqualTo(t1Output.get(1), t3Output.get(1));
|
||||
Expression bottomJoinOnCondition = new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0));
|
||||
Expression topJoinOnCondition = new EqualTo(scan1.getOutput().get(1), scan3.getOutput().get(1));
|
||||
|
||||
LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
|
||||
Lists.newArrayList(bottomJoinOnCondition),
|
||||
@ -112,8 +105,8 @@ public class JoinLAsscomTest {
|
||||
* t1 t2 t1 t3
|
||||
*/
|
||||
|
||||
Expression bottomJoinOnCondition = new EqualTo(t1Output.get(0), t2Output.get(0));
|
||||
Expression topJoinOnCondition = new EqualTo(t2Output.get(0), t3Output.get(0));
|
||||
Expression bottomJoinOnCondition = new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0));
|
||||
Expression topJoinOnCondition = new EqualTo(scan2.getOutput().get(0), scan3.getOutput().get(0));
|
||||
LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
|
||||
Lists.newArrayList(bottomJoinOnCondition),
|
||||
Optional.empty(), scan1, scan2);
|
||||
|
||||
@ -23,6 +23,8 @@ import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
|
||||
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
|
||||
import org.apache.doris.nereids.parser.NereidsParser;
|
||||
import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleSet;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnAgg;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnProjectUnderAgg;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.ExistsApplyToJoin;
|
||||
@ -45,6 +47,9 @@ import org.apache.doris.nereids.types.VarcharType;
|
||||
import org.apache.doris.utframe.TestWithFeService;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
import mockit.Mock;
|
||||
import mockit.MockUp;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
@ -112,6 +117,13 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte
|
||||
|
||||
@Test
|
||||
public void testTranslateCase() throws Exception {
|
||||
new MockUp<RuleSet>() {
|
||||
@Mock
|
||||
public List<Rule> getExplorationRules() {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
};
|
||||
|
||||
for (String sql : testSql) {
|
||||
NamedExpressionUtil.clear();
|
||||
StatementContext statementContext = MemoTestUtils.createStatementContext(connectContext, sql);
|
||||
|
||||
@ -0,0 +1,60 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.util;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
class PlanUtilsTest {
|
||||
|
||||
@Test
|
||||
void projectOrSelf() {
|
||||
LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
Plan self = PlanUtils.projectOrSelf(Lists.newArrayList(), scan);
|
||||
Assertions.assertSame(scan, self);
|
||||
|
||||
NamedExpression slot = scan.getOutput().get(0);
|
||||
List<NamedExpression> projects = Lists.newArrayList();
|
||||
projects.add(slot);
|
||||
Plan project = PlanUtils.projectOrSelf(projects, scan);
|
||||
Assertions.assertTrue(project instanceof LogicalProject);
|
||||
}
|
||||
|
||||
@Test
|
||||
void filterOrSelf() {
|
||||
LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
Plan filterOrSelf = PlanUtils.filterOrSelf(Lists.newArrayList(), scan);
|
||||
Assertions.assertSame(scan, filterOrSelf);
|
||||
|
||||
List<Expression> predicate = Lists.newArrayList();
|
||||
predicate.add(BooleanLiteral.TRUE);
|
||||
Plan filter = PlanUtils.filterOrSelf(predicate, scan);
|
||||
Assertions.assertTrue(filter instanceof LogicalFilter);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user