[fix](colocate join) fix wrong use of colocate join (#37361) (#37714)

cherry-pick from master #37361
This commit is contained in:
camby
2024-07-15 16:47:17 +08:00
committed by GitHub
parent e5339a4014
commit 57301920e3
5 changed files with 137 additions and 14 deletions

View File

@ -279,7 +279,7 @@ public class ChildOutputPropertyDeriver extends PlanVisitor<PhysicalProperties,
case RIGHT_SEMI_JOIN:
case RIGHT_ANTI_JOIN:
case RIGHT_OUTER_JOIN:
if (JoinUtils.couldColocateJoin(leftHashSpec, rightHashSpec)) {
if (JoinUtils.couldColocateJoin(leftHashSpec, rightHashSpec, hashJoin.getHashJoinConjuncts())) {
return new PhysicalProperties(rightHashSpec);
} else {
// retain the left shuffle type, since the coordinator uses the leftmost node to schedule the fragment

View File

@ -245,7 +245,7 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> {
Optional<PhysicalProperties> updatedForLeft = Optional.empty();
Optional<PhysicalProperties> updatedForRight = Optional.empty();
if (JoinUtils.couldColocateJoin(leftHashSpec, rightHashSpec)) {
if (JoinUtils.couldColocateJoin(leftHashSpec, rightHashSpec, hashJoin.getHashJoinConjuncts())) {
// check colocate join with scan
return true;
} else if (couldNotRightBucketShuffleJoin(hashJoin.getJoinType(), leftHashSpec, rightHashSpec)) {

View File

@ -83,8 +83,8 @@ public class LogicalOlapScanToPhysicalOlapScan extends OneImplementationRuleFact
List<Slot> output = olapScan.getOutput();
List<Slot> baseOutput = olapScan.getOutputByIndex(olapScan.getTable().getBaseIndexId());
List<ExprId> hashColumns = Lists.newArrayList();
for (Slot slot : output) {
for (Column column : hashDistributionInfo.getDistributionColumns()) {
for (Column column : hashDistributionInfo.getDistributionColumns()) {
for (Slot slot : output) {
if (((SlotReference) slot).getColumn().get().getNameWithoutMvPrefix()
.equals(column.getName())) {
hashColumns.add(slot.getExprId());
@ -92,8 +92,8 @@ public class LogicalOlapScanToPhysicalOlapScan extends OneImplementationRuleFact
}
}
if (hashColumns.size() != hashDistributionInfo.getDistributionColumns().size()) {
for (Slot slot : baseOutput) {
for (Column column : hashDistributionInfo.getDistributionColumns()) {
for (Column column : hashDistributionInfo.getDistributionColumns()) {
for (Slot slot : baseOutput) {
// If the length of the column in the bucket key changes after DDL, the length cannot be
// determined. As a result, some bucket fields are lost in the query execution plan.
// So here we use the column name to avoid this problem
@ -109,8 +109,8 @@ public class LogicalOlapScanToPhysicalOlapScan extends OneImplementationRuleFact
HashDistributionInfo hashDistributionInfo = (HashDistributionInfo) distributionInfo;
List<Slot> output = olapScan.getOutput();
List<ExprId> hashColumns = Lists.newArrayList();
for (Slot slot : output) {
for (Column column : hashDistributionInfo.getDistributionColumns()) {
for (Column column : hashDistributionInfo.getDistributionColumns()) {
for (Slot slot : output) {
// If the length of the column in the bucket key changes after DDL, the length cannot be
// determined. As a result, some bucket fields are lost in the query execution plan.
// So here we use the column name to avoid this problem

View File

@ -32,6 +32,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapContains;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
@ -54,6 +55,7 @@ import com.google.common.collect.Sets;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@ -257,13 +259,14 @@ public class JoinUtils {
return false;
}
return couldColocateJoin((DistributionSpecHash) leftDistributionSpec,
(DistributionSpecHash) rightDistributionSpec);
(DistributionSpecHash) rightDistributionSpec, join.getHashJoinConjuncts());
}
/**
* Checks whether a colocate join can be used, given the left and right child distribution specs.
*/
public static boolean couldColocateJoin(DistributionSpecHash leftHashSpec, DistributionSpecHash rightHashSpec) {
public static boolean couldColocateJoin(DistributionSpecHash leftHashSpec, DistributionSpecHash rightHashSpec,
List<Expression> conjuncts) {
if (ConnectContext.get() == null
|| ConnectContext.get().getSessionVariable().isDisableColocatePlan()) {
return false;
@ -285,12 +288,50 @@ public class JoinUtils {
boolean noNeedCheckColocateGroup = hitSameIndex && (leftTablePartitions.equals(rightTablePartitions))
&& (leftTablePartitions.size() <= 1);
ColocateTableIndex colocateIndex = Env.getCurrentColocateIndex();
if (noNeedCheckColocateGroup
|| (colocateIndex.isSameGroup(leftTableId, rightTableId)
&& !colocateIndex.isGroupUnstable(colocateIndex.getGroup(leftTableId)))) {
if (noNeedCheckColocateGroup) {
return true;
}
return false;
if (!colocateIndex.isSameGroup(leftTableId, rightTableId)
|| colocateIndex.isGroupUnstable(colocateIndex.getGroup(leftTableId))) {
return false;
}
Set<Integer> equalIndices = new HashSet<>();
for (Expression expr : conjuncts) {
// only simple equality predicates between two slot references can use colocate join
if (!(expr instanceof EqualPredicate)) {
return false;
}
Expression leftChild = ((EqualPredicate) expr).left();
Expression rightChild = ((EqualPredicate) expr).right();
if (!(leftChild instanceof SlotReference) || !(rightChild instanceof SlotReference)) {
return false;
}
SlotReference leftSlot = (SlotReference) leftChild;
SlotReference rightSlot = (SlotReference) rightChild;
Integer leftIndex = null;
Integer rightIndex = null;
if (leftSlot.getTable().isPresent() && leftSlot.getTable().get().getId() == leftHashSpec.getTableId()) {
leftIndex = leftHashSpec.getExprIdToEquivalenceSet().get(leftSlot.getExprId());
rightIndex = rightHashSpec.getExprIdToEquivalenceSet().get(rightSlot.getExprId());
} else {
leftIndex = rightHashSpec.getExprIdToEquivalenceSet().get(leftSlot.getExprId());
rightIndex = leftHashSpec.getExprIdToEquivalenceSet().get(rightSlot.getExprId());
}
if (!Objects.equals(leftIndex, rightIndex)) {
return false;
}
if (leftIndex != null) {
equalIndices.add(leftIndex);
}
}
// the join ON conditions must cover all distribution columns of the left hash spec
if (equalIndices.containsAll(leftHashSpec.getExprIdToEquivalenceSet().values())) {
return true;
} else {
return false;
}
}
public static Set<ExprId> getJoinOutputExprIdSet(Plan left, Plan right) {