[enhancement](Nereids) generate colocate join when property is different with require property (#15479)

1. When checking a HashProperty whose type is NATURAL, we only need to check whether the required properties contain all shuffle columns.
2. In ChildrenPropertiesRegulator.java, when colocate/bucket join is not allowed, we will enforce the required property.
This commit is contained in:
谢健
2023-01-05 11:41:18 +08:00
committed by GitHub
parent 4f2a36f032
commit 0dfa143140
8 changed files with 189 additions and 46 deletions

View File

@ -1552,9 +1552,10 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
DistributionSpecHash leftDistributionSpec
= (DistributionSpecHash) physicalHashJoin.left().getPhysicalProperties().getDistributionSpec();
Pair<List<ExprId>, List<ExprId>> onClauseUsedSlots = JoinUtils.getOnClauseUsedSlots(physicalHashJoin);
List<ExprId> rightPartitionExprIds = Lists.newArrayList(onClauseUsedSlots.second);
for (int i = 0; i < onClauseUsedSlots.first.size(); i++) {
int idx = leftDistributionSpec.getExprIdToEquivalenceSet().get(onClauseUsedSlots.first.get(i));
List<ExprId> rightPartitionExprIds = Lists.newArrayList(leftDistributionSpec.getOrderedShuffledColumns());
for (int i = 0; i < leftDistributionSpec.getOrderedShuffledColumns().size(); i++) {
int idx = leftDistributionSpec.getExprIdToEquivalenceSet()
.get(leftDistributionSpec.getOrderedShuffledColumns().get(i));
rightPartitionExprIds.set(idx, onClauseUsedSlots.second.get(i));
}
// assemble fragment

View File

@ -33,6 +33,7 @@ import org.apache.doris.nereids.stats.StatsCalculator;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
@ -56,6 +57,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// [ [Properties {"", ANY}, Properties {"", BROADCAST}],
// [Properties {"", SHUFFLE_JOIN}, Properties {"", SHUFFLE_JOIN}]]
private List<List<PhysicalProperties>> requestChildrenPropertiesList;
private List<List<PhysicalProperties>> outputChildrenPropertiesList = new ArrayList<>();
// index of List<request property to children>
private int requestPropertiesIndex = 0;
@ -113,6 +116,9 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// [Properties {"", SHUFFLE_JOIN}, Properties {"", SHUFFLE_JOIN}] ]
RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(context);
requestChildrenPropertiesList = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression);
for (List<PhysicalProperties> requestChildrenProperties : requestChildrenPropertiesList) {
outputChildrenPropertiesList.add(new ArrayList<>(requestChildrenProperties));
}
}
for (; requestPropertiesIndex < requestChildrenPropertiesList.size(); requestPropertiesIndex++) {
@ -120,7 +126,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// like: [ Properties {"", ANY}, Properties {"", BROADCAST} ],
List<PhysicalProperties> requestChildrenProperties
= requestChildrenPropertiesList.get(requestPropertiesIndex);
List<PhysicalProperties> outputChildrenProperties
= outputChildrenPropertiesList.get(requestPropertiesIndex);
// Calculate cost
if (curChildIndex == 0 && prevChildIndex == -1) {
curNodeCost = CostCalculator.calculateCost(groupExpression);
@ -164,11 +171,10 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
lowestCostChildren.add(lowestCostExpr);
PhysicalProperties outputProperties = lowestCostExpr.getOutputProperties(requestChildProperty);
// use child's outputProperties to reset the request properties, so no unnecessary enforce.
// this is safety because `childGroup.getLowestCostPlan(current plan's requestChildProperty).
// getOutputProperties(current plan's requestChildProperty) == child plan's outputProperties`,
// the outputProperties must satisfy the origin requestChildProperty
requestChildrenProperties.set(curChildIndex, outputProperties);
// record child's outputProperties. This is safe because `childGroup.getLowestCostPlan(current
// plan's requestChildProperty).getOutputProperties(current plan's requestChildProperty) == child
// plan's outputProperties`, so the outputProperties must satisfy the original requestChildProperty
outputChildrenProperties.set(curChildIndex, outputProperties);
curTotalCost += lowestCostExpr.getLowestCostTable().get(requestChildProperty).first;
if (curTotalCost > context.getCostUpperBound()) {
@ -182,7 +188,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// This mean that we successfully optimize all child groups.
// if break when running the loop above, the condition must be false.
if (curChildIndex == groupExpression.arity()) {
if (!calculateEnforce(requestChildrenProperties)) {
if (!calculateEnforce(requestChildrenProperties, outputChildrenProperties)) {
return; // if error exists, return
}
if (curTotalCost < context.getCostUpperBound()) {
@ -195,13 +201,15 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
/**
* calculate enforce
*
* @return false if error occurs, the caller will return.
*/
private boolean calculateEnforce(List<PhysicalProperties> requestChildrenProperties) {
private boolean calculateEnforce(List<PhysicalProperties> requestChildrenProperties,
List<PhysicalProperties> outputChildrenProperties) {
// to ensure distributionSpec has been added sufficiently.
// it's certain that lowestCostChildren is equals to arity().
ChildrenPropertiesRegulator regulator = new ChildrenPropertiesRegulator(groupExpression,
lowestCostChildren, requestChildrenProperties, requestChildrenProperties, context);
lowestCostChildren, outputChildrenProperties, requestChildrenProperties, context);
double enforceCost = regulator.adjustChildrenProperties();
if (enforceCost < 0) {
// invalid enforce, return.
@ -212,7 +220,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
// Not need to do pruning here because it has been done when we get the
// best expr from the child group
ChildOutputPropertyDeriver childOutputPropertyDeriver
= new ChildOutputPropertyDeriver(requestChildrenProperties);
= new ChildOutputPropertyDeriver(outputChildrenProperties);
// the physical properties the group expression support for its parent.
PhysicalProperties outputProperty = childOutputPropertyDeriver.getOutputProperties(groupExpression);
@ -232,22 +240,23 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
curTotalCost += curNodeCost;
// record map { outputProperty -> outputProperty }, { ANY -> outputProperty },
recordPropertyAndCost(groupExpression, outputProperty, PhysicalProperties.ANY, requestChildrenProperties);
recordPropertyAndCost(groupExpression, outputProperty, outputProperty, requestChildrenProperties);
enforce(outputProperty, requestChildrenProperties);
recordPropertyAndCost(groupExpression, outputProperty, PhysicalProperties.ANY, outputChildrenProperties);
recordPropertyAndCost(groupExpression, outputProperty, outputProperty, outputChildrenProperties);
enforce(outputProperty, outputChildrenProperties);
return true;
}
/**
* add enforce node
*
* @param outputProperty the group expression's out property
* @param requestChildrenProperty the group expression's request to its child.
* @param outputChildrenProperty the children's output properties of this group expression.
*/
private void enforce(PhysicalProperties outputProperty, List<PhysicalProperties> requestChildrenProperty) {
private void enforce(PhysicalProperties outputProperty, List<PhysicalProperties> outputChildrenProperty) {
PhysicalProperties requiredProperties = context.getRequiredProperties();
if (outputProperty.satisfy(requiredProperties)) {
if (!outputProperty.equals(requiredProperties)) {
recordPropertyAndCost(groupExpression, outputProperty, requiredProperties, requestChildrenProperty);
recordPropertyAndCost(groupExpression, outputProperty, requiredProperties, outputChildrenProperty);
}
return;
}
@ -266,6 +275,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
/**
* record property and cost
*
* @param groupExpression the target group expression
* @param outputProperty the child output physical corresponding to the required property of the group expression.
* @param requestProperty mentioned above

View File

@ -22,7 +22,9 @@ import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
@ -32,7 +34,9 @@ import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
/**
* ensure child add enough distribute.
@ -58,6 +62,7 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Double, Void> {
/**
* adjust children properties
*
* @return enforce cost.
*/
public double adjustChildrenProperties() {
@ -99,13 +104,13 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Double, Void> {
GroupExpression leftChild = children.get(0);
final Pair<Double, List<PhysicalProperties>> leftLowest
= leftChild.getLowestCostTable().get(requiredProperties.get(0));
PhysicalProperties leftOutput = leftChild.getOutputProperties(requiredProperties.get(0));
= leftChild.getLowestCostTable().get(childrenProperties.get(0));
PhysicalProperties leftOutput = leftChild.getOutputProperties(childrenProperties.get(0));
GroupExpression rightChild = children.get(1);
final Pair<Double, List<PhysicalProperties>> rightLowest
= rightChild.getLowestCostTable().get(requiredProperties.get(1));
PhysicalProperties rightOutput = rightChild.getOutputProperties(requiredProperties.get(1));
Pair<Double, List<PhysicalProperties>> rightLowest
= rightChild.getLowestCostTable().get(childrenProperties.get(1));
PhysicalProperties rightOutput = rightChild.getOutputProperties(childrenProperties.get(1));
// check colocate join
if (leftHashSpec.getShuffleType() == ShuffleType.NATURAL
@ -115,38 +120,80 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Double, Void> {
}
}
// check right hand must distribute
if (rightHashSpec.getShuffleType() != ShuffleType.ENFORCED) {
enforceCost += updateChildEnforceAndCost(rightChild, rightOutput,
rightHashSpec, rightLowest.first);
childrenProperties.set(1, new PhysicalProperties(
rightHashSpec.withShuffleType(ShuffleType.ENFORCED),
childrenProperties.get(1).getOrderSpec()));
}
// check bucket shuffle join
if (leftHashSpec.getShuffleType() != ShuffleType.ENFORCED) {
if (ConnectContext.get().getSessionVariable().isEnableBucketShuffleJoin()) {
// We need to recalculate the required property of right child,
// to make right child compatible with left child.
PhysicalProperties rightRequireProperties = calRightRequiredOfBucketShuffleJoin(leftHashSpec,
rightHashSpec);
if (!rightOutput.equals(rightRequireProperties)) {
enforceCost += updateChildEnforceAndCost(rightChild, rightOutput,
(DistributionSpecHash) rightRequireProperties.getDistributionSpec(), rightLowest.first);
}
childrenProperties.set(1, rightRequireProperties);
return enforceCost;
}
enforceCost += updateChildEnforceAndCost(leftChild, leftOutput,
leftHashSpec, leftLowest.first);
childrenProperties.set(0, new PhysicalProperties(
leftHashSpec.withShuffleType(ShuffleType.ENFORCED),
childrenProperties.get(0).getOrderSpec()));
(DistributionSpecHash) requiredProperties.get(0).getDistributionSpec(), leftLowest.first);
childrenProperties.set(0, requiredProperties.get(0));
}
// check right hand must distribute.
if (rightHashSpec.getShuffleType() != ShuffleType.ENFORCED) {
enforceCost += updateChildEnforceAndCost(rightChild, rightOutput,
(DistributionSpecHash) requiredProperties.get(1).getDistributionSpec(), rightLowest.first);
childrenProperties.set(1, requiredProperties.get(1));
}
return enforceCost;
}
/**
 * Compute the property the right child must provide so that it lines up with the left child's
 * current hash distribution.
 *
 * <p>For every column in {@code leftHashSpec}'s ordered shuffle columns, its position inside the
 * left required spec's ordered shuffle columns is located, and the column at the same position
 * of the right required spec is collected. The collected columns form an ENFORCED hash spec that
 * reuses {@code rightHashSpec}'s table id and partition ids.
 *
 * @param leftHashSpec the left child's hash spec; its shuffle type must not be ENFORCED
 * @param rightHashSpec the right child's hash spec, used only for its table id and partition ids
 * @return the physical properties wrapping the newly built ENFORCED hash spec
 * @throws IllegalArgumentException if {@code leftHashSpec} is ENFORCED, or if a left shuffle column
 *         (and each of its equivalent expr ids) is absent from the left required spec
 */
private PhysicalProperties calRightRequiredOfBucketShuffleJoin(DistributionSpecHash leftHashSpec,
        DistributionSpecHash rightHashSpec) {
    Preconditions.checkArgument(leftHashSpec.getShuffleType() != ShuffleType.ENFORCED);
    DistributionSpecHash requiredOfLeft = (DistributionSpecHash) requiredProperties.get(0).getDistributionSpec();
    DistributionSpecHash requiredOfRight = (DistributionSpecHash) requiredProperties.get(1).getDistributionSpec();

    List<ExprId> rightShuffleIds = new ArrayList<>();
    for (ExprId leftColumn : leftHashSpec.getOrderedShuffledColumns()) {
        final List<ExprId> requiredLeftColumns = requiredOfLeft.getOrderedShuffledColumns();
        int position = requiredLeftColumns.indexOf(leftColumn);
        if (position == -1) {
            // The left child's own shuffle column is not named in the left required spec, so fall
            // back to the first expr id that leftHashSpec reports as equivalent to it and that is.
            position = leftHashSpec.getEquivalenceExprIdsOf(leftColumn).stream()
                    .mapToInt(requiredLeftColumns::indexOf)
                    .filter(candidate -> candidate != -1)
                    .findFirst()
                    .orElse(-1);
        }
        Preconditions.checkArgument(position != -1);
        // The right column sits at the same position in the right required spec.
        rightShuffleIds.add(requiredOfRight.getOrderedShuffledColumns().get(position));
    }

    DistributionSpecHash enforcedRightSpec = new DistributionSpecHash(rightShuffleIds, ShuffleType.ENFORCED,
            rightHashSpec.getTableId(), rightHashSpec.getPartitionIds());
    return new PhysicalProperties(enforcedRightSpec);
}
private double updateChildEnforceAndCost(GroupExpression child, PhysicalProperties childOutput,
DistributionSpecHash required, double currentCost) {
double enforceCost = 0;
if (child.getPlan() instanceof PhysicalDistribute) {
//To avoid continuous distribute operator, we just enforce the child's child
childOutput = child.getInputPropertiesList(childOutput).get(0);
Pair<Double, GroupExpression> newChildAndCost
= child.getOwnerGroup().getLowestCostPlan(childOutput).get();
child = newChildAndCost.second;
enforceCost = newChildAndCost.first - currentCost;
currentCost = newChildAndCost.first;
}
DistributionSpec outputDistributionSpec;
outputDistributionSpec = required.withShuffleType(ShuffleType.ENFORCED);
PhysicalProperties newOutputProperty = new PhysicalProperties(outputDistributionSpec);
GroupExpression enforcer = outputDistributionSpec.addEnforcer(child.getOwnerGroup());
jobContext.getCascadesContext().getMemo().addEnforcerPlan(enforcer, child.getOwnerGroup());
double enforceCost = CostCalculator.calculateCost(enforcer);
enforceCost = Double.sum(enforceCost, CostCalculator.calculateCost(enforcer));
if (enforcer.updateLowestCostTable(newOutputProperty,
Lists.newArrayList(childOutput), enforceCost + currentCost)) {

View File

@ -28,6 +28,7 @@ import com.google.common.collect.Sets;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
@ -154,6 +155,13 @@ public class DistributionSpecHash extends DistributionSpec {
return exprIdToEquivalenceSet;
}
/**
 * Look up the equivalence set registered for the given expr id.
 *
 * @param exprId the expr id to look up
 * @return the entry of {@code equivalenceExprIds} selected by the value that
 *         {@code exprIdToEquivalenceSet} maps {@code exprId} to, or a new empty set
 *         when {@code exprId} has no mapping
 */
public Set<ExprId> getEquivalenceExprIdsOf(ExprId exprId) {
    if (!exprIdToEquivalenceSet.containsKey(exprId)) {
        // unknown expr id: hand back a fresh empty set rather than null
        return new HashSet<>();
    }
    return equivalenceExprIds.get(exprIdToEquivalenceSet.get(exprId));
}
@Override
public boolean satisfy(DistributionSpec required) {
if (required instanceof DistributionSpecAny) {
@ -170,12 +178,19 @@ public class DistributionSpecHash extends DistributionSpec {
return false;
}
if (requiredHash.shuffleType == ShuffleType.NATURAL && this.shuffleType != ShuffleType.NATURAL) {
// this shuffle type is not natural but require natural
return false;
}
if (requiredHash.shuffleType == ShuffleType.AGGREGATE) {
return containsSatisfy(requiredHash.getOrderedShuffledColumns());
}
if (requiredHash.shuffleType == ShuffleType.NATURAL && this.shuffleType != ShuffleType.NATURAL) {
return false;
// If the required property comes from a join and this property is not enforced, we only need to check
// containment. Further checking is done in ChildrenPropertiesRegulator.
if (requiredHash.shuffleType == shuffleType.JOIN && this.shuffleType != shuffleType.ENFORCED) {
return containsSatisfy(requiredHash.getOrderedShuffledColumns());
}
return equalsSatisfy(requiredHash.getOrderedShuffledColumns());

View File

@ -306,7 +306,7 @@ public class DistributionSpecHashTest {
// require is same order
Assertions.assertTrue(join1.satisfy(join2));
// require contains all sets but order is not same
Assertions.assertFalse(join1.satisfy(join3));
Assertions.assertTrue(join1.satisfy(join3));
// require slots is not contained by target
Assertions.assertFalse(join3.satisfy(join1));
// other shuffle type with same order
@ -314,8 +314,8 @@ public class DistributionSpecHashTest {
Assertions.assertTrue(aggregate.satisfy(join2));
Assertions.assertTrue(enforce.satisfy(join2));
// other shuffle type contain all set but order is not same
Assertions.assertFalse(natural.satisfy(join3));
Assertions.assertFalse(aggregate.satisfy(join3));
Assertions.assertTrue(natural.satisfy(join3));
Assertions.assertTrue(aggregate.satisfy(join3));
Assertions.assertFalse(enforce.satisfy(join3));
}

View File

@ -18,8 +18,11 @@
package org.apache.doris.nereids.sqltest;
import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.util.PlanChecker;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
public class JoinTest extends SqlTestBase {
@ -33,4 +36,26 @@ public class JoinTest extends SqlTestBase {
innerLogicalJoin().when(j -> j.getHashJoinConjuncts().size() == 1)
);
}
/**
 * Each query below is optimized and its best plan tree must not contain any PhysicalDistribute,
 * i.e. a colocate join plan is expected.
 */
@Test
void testColocatedJoin() {
    String[] colocateQueries = {
        "select * from T2 join T2 b on T2.id = b.id and T2.id = b.id;",
        "select * from T1 join T0 on T1.score = T0.score and T1.id = T0.id;"
    };
    for (String sql : colocateQueries) {
        PhysicalPlan plan = PlanChecker.from(connectContext)
                .analyze(sql)
                .rewrite()
                .deriveStats()
                .optimize()
                .getBestPlanTree();
        // generate colocate join plan without physicalDistribute
        Assertions.assertFalse(plan.anyMatch(PhysicalDistribute.class::isInstance));
    }
}
}

View File

@ -28,21 +28,32 @@ public abstract class SqlTestBase extends TestWithFeService implements PatternMa
connectContext.setDatabase("default_cluster:test");
createTables(
"CREATE TABLE IF NOT EXISTS T0 (\n"
+ " id bigint,\n"
+ " score bigint\n"
+ ")\n"
+ "DUPLICATE KEY(id)\n"
+ "DISTRIBUTED BY HASH(id, score) BUCKETS 10\n"
+ "PROPERTIES (\n"
+ " \"replication_num\" = \"1\", \n"
+ " \"colocate_with\" = \"T0\"\n"
+ ")\n",
"CREATE TABLE IF NOT EXISTS T1 (\n"
+ " id bigint,\n"
+ " score bigint\n"
+ ")\n"
+ "DUPLICATE KEY(id)\n"
+ "DISTRIBUTED BY HASH(id) BUCKETS 1\n"
+ "DISTRIBUTED BY HASH(id, score) BUCKETS 10\n"
+ "PROPERTIES (\n"
+ " \"replication_num\" = \"1\"\n"
+ " \"replication_num\" = \"1\", \n"
+ " \"colocate_with\" = \"T0\"\n"
+ ")\n",
"CREATE TABLE IF NOT EXISTS T2 (\n"
+ " id bigint,\n"
+ " score bigint\n"
+ ")\n"
+ "DUPLICATE KEY(id)\n"
+ "DISTRIBUTED BY HASH(id) BUCKETS 1\n"
+ "DISTRIBUTED BY HASH(id) BUCKETS 10\n"
+ "PROPERTIES (\n"
+ " \"replication_num\" = \"1\"\n"
+ ")\n",

View File

@ -21,6 +21,7 @@ import org.apache.doris.analysis.ExplainOptions;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.batch.NereidsRewriteJobExecutor;
@ -35,6 +36,7 @@ import org.apache.doris.nereids.pattern.GroupExpressionMatching;
import org.apache.doris.nereids.pattern.MatchingContext;
import org.apache.doris.nereids.pattern.PatternDescriptor;
import org.apache.doris.nereids.pattern.PatternMatcher;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.rules.RuleSet;
@ -452,6 +454,38 @@ public class PlanChecker {
return cascadesContext.getMemo().copyOut();
}
/**
 * Recursively assemble the lowest-cost physical plan of a group for the given properties.
 *
 * <p>The group's lowest-cost group expression for {@code physicalProperties} is picked, the same
 * procedure is applied to each child group with the input properties recorded on that group
 * expression, and the resulting plan is annotated with the group expression's output properties
 * and the owner group's statistics.
 *
 * @param rootGroup the group to pick the plan from
 * @param physicalProperties the properties the picked plan is looked up by
 * @return the assembled physical plan
 * @throws AnalysisException if the group has no lowest-cost plan for the properties, or the
 *         assembled plan is not a PhysicalPlan
 */
private PhysicalPlan chooseBestPlan(Group rootGroup, PhysicalProperties physicalProperties) {
    GroupExpression bestExpression = rootGroup.getLowestCostPlan(physicalProperties).orElseThrow(
            () -> new AnalysisException("lowestCostPlans with physicalProperties("
                    + physicalProperties + ") doesn't exist in root group")).second;

    // properties this group expression recorded for each of its children
    List<PhysicalProperties> childProperties = bestExpression.getInputPropertiesList(physicalProperties);
    List<Plan> bestChildren = Lists.newArrayList();
    int childCount = bestExpression.arity();
    for (int childIdx = 0; childIdx < childCount; childIdx++) {
        Group childGroup = bestExpression.child(childIdx);
        bestChildren.add(chooseBestPlan(childGroup, childProperties.get(childIdx)));
    }

    Plan assembled = bestExpression.getPlan().withChildren(bestChildren);
    if (assembled instanceof PhysicalPlan) {
        return ((PhysicalPlan) assembled).withPhysicalPropertiesAndStats(
                bestExpression.getOutputProperties(physicalProperties),
                bestExpression.getOwnerGroup().getStatistics());
    }
    throw new AnalysisException("Result plan must be PhysicalPlan");
}
/**
 * Assemble the best plan of the memo's root group, looked up by {@code PhysicalProperties.ANY}.
 *
 * @return the plan produced by {@code chooseBestPlan} for the memo root
 */
public PhysicalPlan getBestPlanTree() {
    Group memoRoot = cascadesContext.getMemo().getRoot();
    return chooseBestPlan(memoRoot, PhysicalProperties.ANY);
}
/**
 * Print the tree string of the memo root's best plan (looked up by {@code PhysicalProperties.ANY})
 * to standard output, followed by a separator line.
 *
 * @return this checker, so calls can be chained
 */
public PlanChecker printlnBestPlanTree() {
    PhysicalPlan bestPlan = chooseBestPlan(cascadesContext.getMemo().getRoot(), PhysicalProperties.ANY);
    System.out.println(bestPlan.treeString());
    System.out.println("-----------------------------");
    return this;
}
public PlanChecker printlnTree() {
System.out.println(cascadesContext.getMemo().copyOut().treeString());
System.out.println("-----------------------------");