[enhancement](Nereids) compare LogicalProperties with output set instead of output list (#12743)

We used the output list to compare two LogicalProperties before. Since join reorder changes the children order of a join plan, and therefore the order of its output list, the two join plans are no longer equal in the memo although they should be. So we had to add a project on top of the new join to keep the LogicalProperties the same.
This PR changes the equals and hashCode functions of LogicalProperties to use a set of outputs to compare two LogicalProperties. Then we no longer need to add the top project. This helps us keep the memo simple and efficient.
This commit is contained in:
morrySnow
2022-09-20 10:55:29 +08:00
committed by GitHub
parent d435f0de41
commit 954c44db39
14 changed files with 135 additions and 177 deletions

View File

@ -549,7 +549,6 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
.map(TupleDescriptor::getSlots)
.flatMap(Collection::stream)
.collect(Collectors.toList());
TupleDescriptor outputDescriptor = context.generateTupleDesc();
Map<ExprId, SlotReference> outputSlotReferenceMap = Maps.newHashMap();
hashJoin.getOutput().stream()
@ -579,6 +578,15 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
.map(SlotReference.class::cast)
.forEach(s -> hashOutputSlotReferenceMap.put(s.getExprId(), s));
Map<ExprId, SlotReference> leftChildOutputMap = Maps.newHashMap();
hashJoin.child(0).getOutput().stream()
.map(SlotReference.class::cast)
.forEach(s -> leftChildOutputMap.put(s.getExprId(), s));
Map<ExprId, SlotReference> rightChildOutputMap = Maps.newHashMap();
hashJoin.child(1).getOutput().stream()
.map(SlotReference.class::cast)
.forEach(s -> rightChildOutputMap.put(s.getExprId(), s));
//make intermediate tuple
List<SlotDescriptor> leftIntermediateSlotDescriptor = Lists.newArrayList();
List<SlotDescriptor> rightIntermediateSlotDescriptor = Lists.newArrayList();
@ -586,47 +594,43 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
if (!hashJoin.getOtherJoinCondition().isPresent()
&& (joinType == JoinType.LEFT_ANTI_JOIN || joinType == JoinType.LEFT_SEMI_JOIN)) {
leftIntermediateSlotDescriptor = hashJoin.child(0).getOutput().stream()
.map(SlotReference.class::cast)
.map(s -> context.createSlotDesc(intermediateDescriptor, s))
.collect(Collectors.toList());
for (SlotDescriptor leftSlotDescriptor : leftSlotDescriptors) {
SlotReference sf = leftChildOutputMap.get(context.findExprId(leftSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
leftIntermediateSlotDescriptor.add(sd);
}
} else if (!hashJoin.getOtherJoinCondition().isPresent()
&& (joinType == JoinType.RIGHT_ANTI_JOIN || joinType == JoinType.RIGHT_SEMI_JOIN)) {
rightIntermediateSlotDescriptor = hashJoin.child(1).getOutput().stream()
.map(SlotReference.class::cast)
.map(s -> context.createSlotDesc(intermediateDescriptor, s))
.collect(Collectors.toList());
for (SlotDescriptor rightSlotDescriptor : rightSlotDescriptors) {
SlotReference sf = rightChildOutputMap.get(context.findExprId(rightSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
rightIntermediateSlotDescriptor.add(sd);
}
} else {
for (int i = 0; i < hashJoin.child(0).getOutput().size(); i++) {
SlotReference sf = (SlotReference) hashJoin.child(0).getOutput().get(i);
for (SlotDescriptor leftSlotDescriptor : leftSlotDescriptors) {
SlotReference sf = leftChildOutputMap.get(context.findExprId(leftSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
if (hashOutputSlotReferenceMap.get(sf.getExprId()) != null) {
hashJoinNode.addSlotIdToHashOutputSlotIds(leftSlotDescriptors.get(i).getId());
hashJoinNode.addSlotIdToHashOutputSlotIds(leftSlotDescriptor.getId());
}
leftIntermediateSlotDescriptor.add(sd);
}
for (int i = 0; i < hashJoin.child(1).getOutput().size(); i++) {
SlotReference sf = (SlotReference) hashJoin.child(1).getOutput().get(i);
for (SlotDescriptor rightSlotDescriptor : rightSlotDescriptors) {
SlotReference sf = rightChildOutputMap.get(context.findExprId(rightSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
if (hashOutputSlotReferenceMap.get(sf.getExprId()) != null) {
hashJoinNode.addSlotIdToHashOutputSlotIds(rightSlotDescriptors.get(i).getId());
hashJoinNode.addSlotIdToHashOutputSlotIds(rightSlotDescriptor.getId());
}
rightIntermediateSlotDescriptor.add(sd);
}
}
//set slots as nullable for outer join
if (joinType == JoinType.FULL_OUTER_JOIN) {
rightIntermediateSlotDescriptor.stream()
.forEach(sd -> sd.setIsNullable(true));
leftIntermediateSlotDescriptor.stream()
.forEach(sd -> sd.setIsNullable(true));
} else if (joinType == JoinType.LEFT_OUTER_JOIN) {
rightIntermediateSlotDescriptor.stream()
.forEach(sd -> sd.setIsNullable(true));
} else if (joinType == JoinType.RIGHT_OUTER_JOIN) {
leftIntermediateSlotDescriptor.stream()
.forEach(sd -> sd.setIsNullable(true));
if (joinType == JoinType.LEFT_OUTER_JOIN || joinType == JoinType.FULL_OUTER_JOIN) {
rightIntermediateSlotDescriptor.forEach(sd -> sd.setIsNullable(true));
}
if (joinType == JoinType.RIGHT_OUTER_JOIN || joinType == JoinType.FULL_OUTER_JOIN) {
leftIntermediateSlotDescriptor.forEach(sd -> sd.setIsNullable(true));
}
List<Expr> otherJoinConjuncts = hashJoin.getOtherJoinCondition()
@ -658,7 +662,8 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
outputSlotReferences.stream().forEach(s -> context.createSlotDesc(outputDescriptor, s));
TupleDescriptor outputDescriptor = context.generateTupleDesc();
outputSlotReferences.forEach(s -> context.createSlotDesc(outputDescriptor, s));
hashJoinNode.setvOutputTupleDesc(outputDescriptor);
hashJoinNode.setvSrcToOutputSMap(srcToOutput);

View File

@ -489,9 +489,7 @@ public class Memo {
List<Plan> groupPlanChildren = childrenGroups.stream()
.map(GroupPlan::new)
.collect(ImmutableList.toImmutableList());
LogicalProperties logicalProperties = plan.getLogicalProperties();
return plan.withChildren(groupPlanChildren)
.withLogicalProperties(Optional.of(logicalProperties));
return plan.withChildren(groupPlanChildren);
}
/*

View File

@ -17,19 +17,24 @@
package org.apache.doris.nereids.properties;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* Logical properties used for analysis and optimize in Nereids.
*/
public class LogicalProperties {
protected Supplier<List<Slot>> outputSupplier;
protected final Supplier<List<Slot>> outputSupplier;
protected final Supplier<HashSet<ExprId>> outputSetSupplier;
/**
* constructor of LogicalProperties.
@ -41,6 +46,10 @@ public class LogicalProperties {
this.outputSupplier = Suppliers.memoize(
Objects.requireNonNull(outputSupplier, "outputSupplier can not be null")
);
this.outputSetSupplier = Suppliers.memoize(
() -> outputSupplier.get().stream().map(NamedExpression::getExprId)
.collect(Collectors.toCollection(HashSet::new))
);
}
public List<Slot> getOutput() {
@ -60,11 +69,11 @@ public class LogicalProperties {
return false;
}
LogicalProperties that = (LogicalProperties) o;
return Objects.equals(outputSupplier.get(), that.outputSupplier.get());
return Objects.equals(outputSetSupplier.get(), that.outputSetSupplier.get());
}
@Override
public int hashCode() {
return Objects.hash(outputSupplier.get());
return Objects.hash(outputSetSupplier.get());
}
}

View File

@ -20,7 +20,6 @@ package org.apache.doris.nereids.rules;
import org.apache.doris.nereids.rules.exploration.join.InnerJoinLAsscom;
import org.apache.doris.nereids.rules.exploration.join.InnerJoinLAsscomProject;
import org.apache.doris.nereids.rules.exploration.join.JoinCommute;
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscom;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscomProject;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTranspose;
@ -57,7 +56,6 @@ import java.util.List;
public class RuleSet {
public static final List<Rule> EXPLORATION_RULES = planRuleFactories()
.add(JoinCommute.LEFT_DEEP)
.add(JoinCommuteProject.LEFT_DEEP)
.add(InnerJoinLAsscom.INSTANCE)
.add(InnerJoinLAsscomProject.INSTANCE)
.add(OuterJoinLAsscom.INSTANCE)

View File

@ -23,9 +23,6 @@ import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.PlanUtils;
import java.util.ArrayList;
/**
* Join Commute
@ -58,7 +55,7 @@ public class JoinCommute extends OneExplorationRuleFactory {
newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
}
return PlanUtils.project(new ArrayList<>(join.getOutput()), newJoin).get();
return newJoin;
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATE);
}
}

View File

@ -1,68 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.PlanUtils;
import java.util.ArrayList;
/**
 * Project-Join Commute.
 * Matches a project whose child is a join, rebuilds that join with its two children
 * passed in the opposite order and its join type swapped, and puts the project's
 * original projections back on top of the rebuilt join.
 * This rule can prevent double JoinCommute cause dead-loop in Memo.
 */
public class JoinCommuteProject extends OneExplorationRuleFactory {

    public static final JoinCommuteProject LEFT_DEEP = new JoinCommuteProject(SwapType.LEFT_DEEP);
    public static final JoinCommuteProject ZIG_ZAG = new JoinCommuteProject(SwapType.ZIG_ZAG);
    public static final JoinCommuteProject BUSHY = new JoinCommuteProject(SwapType.BUSHY);

    /** Swap strategy handed to {@link JoinCommuteHelper#check} when matching. */
    private final SwapType swapType;

    public JoinCommuteProject(SwapType swapType) {
        this.swapType = swapType;
    }

    @Override
    public Rule build() {
        return logicalProject(logicalJoin())
                .when(project -> JoinCommuteHelper.check(swapType, project.child()))
                .then(project -> {
                    LogicalJoin<GroupPlan, GroupPlan> originalJoin = project.child();

                    // Mark the matched join so that the plain JoinCommute rule will not pick it up.
                    originalJoin.getGroupExpression().get().setApplied(RuleType.LOGICAL_JOIN_COMMUTATE);

                    // Same conditions and reorder context; right child first, left child second.
                    LogicalJoin<GroupPlan, GroupPlan> swappedJoin = new LogicalJoin<>(
                            originalJoin.getJoinType().swap(),
                            originalJoin.getHashJoinConjuncts(),
                            originalJoin.getOtherJoinCondition(),
                            originalJoin.right(), originalJoin.left(),
                            originalJoin.getJoinReorderContext());

                    swappedJoin.getJoinReorderContext().setHasCommute(true);
                    // The zig-zag flag is only recorded for ZIG_ZAG swaps of a non-bottom join.
                    if (swapType == SwapType.ZIG_ZAG) {
                        if (JoinCommuteHelper.isNotBottomJoin(originalJoin)) {
                            swappedJoin.getJoinReorderContext().setHasCommuteZigZag(true);
                        }
                    }

                    // Re-apply the original projections over the swapped join.
                    return PlanUtils.project(new ArrayList<>(project.getProjects()), swappedJoin).get();
                }).toRule(RuleType.LOGICAL_JOIN_COMMUTATE);
    }
}

View File

@ -96,6 +96,10 @@ class JoinLAsscomHelper extends ThreeJoinHelper {
topJoin.getJoinReorderContext());
newTopJoin.getJoinReorderContext().setHasLAsscom(true);
if (topJoin.getLogicalProperties().equals(newTopJoin.getLogicalProperties())) {
return newTopJoin;
}
return PlanUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin);
}

View File

@ -87,9 +87,9 @@ public class InnerJoinLAsscomProjectTest {
Assertions.assertEquals(2, root.getLogicalExpressions().size());
Assertions.assertTrue(root.logicalExpressionsAt(0).getPlan() instanceof LogicalJoin);
Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalProject);
Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalJoin);
GroupExpression newTopJoinGroupExpr = root.logicalExpressionsAt(1).child(0).getLogicalExpression();
GroupExpression newTopJoinGroupExpr = root.logicalExpressionsAt(1);
GroupExpression leftProjectGroupExpr = newTopJoinGroupExpr.child(0).getLogicalExpression();
GroupExpression rightProjectGroupExpr = newTopJoinGroupExpr.child(1).getLogicalExpression();
Plan leftProject = newTopJoinGroupExpr.child(0).getLogicalExpression().getPlan();

View File

@ -25,7 +25,6 @@ import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
@ -76,9 +75,9 @@ public class InnerJoinLAsscomTest {
Assertions.assertEquals(2, root.getLogicalExpressions().size());
Assertions.assertTrue(root.logicalExpressionsAt(0).getPlan() instanceof LogicalJoin);
Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalProject);
Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalJoin);
GroupExpression newTopJoinGroupExpr = root.logicalExpressionsAt(1).child(0).getLogicalExpression();
GroupExpression newTopJoinGroupExpr = root.logicalExpressionsAt(1);
GroupExpression newBottomJoinGroupExpr = newTopJoinGroupExpr.child(0).getLogicalExpression();
Plan bottomLeft = newBottomJoinGroupExpr.child(0).getLogicalExpression().getPlan();
Plan bottomRight = newBottomJoinGroupExpr.child(1).getLogicalExpression().getPlan();

View File

@ -25,7 +25,6 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
@ -51,9 +50,9 @@ public class JoinCommuteTest {
Assertions.assertEquals(2, root.getLogicalExpressions().size());
Assertions.assertTrue(root.logicalExpressionsAt(0).getPlan() instanceof LogicalJoin);
Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalProject);
Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalJoin);
GroupExpression newJoinGroupExpr = root.logicalExpressionsAt(1).child(0).getLogicalExpression();
GroupExpression newJoinGroupExpr = root.logicalExpressionsAt(1);
Plan left = newJoinGroupExpr.child(0).getLogicalExpression().getPlan();
Plan right = newJoinGroupExpr.child(1).getLogicalExpression().getPlan();
Assertions.assertTrue(left instanceof LogicalOlapScan);