[fix](nereids) fix distribution expr list (#39435)

pick from #39148
This commit is contained in:
xzj7019
2024-08-22 15:19:51 +08:00
committed by GitHub
parent 1c566253a8
commit 10f3e88f7a
4 changed files with 187 additions and 147 deletions

View File

@ -434,8 +434,7 @@ public class NereidsPlanner extends Planner {
// add groupExpression to plan so that we could print group id in plan.treeString()
plan = plan.withGroupExpression(Optional.of(groupExpression));
PhysicalPlan physicalPlan = ((PhysicalPlan) plan).withPhysicalPropertiesAndStats(
groupExpression.getOutputProperties(physicalProperties),
groupExpression.getOwnerGroup().getStatistics());
physicalProperties, groupExpression.getOwnerGroup().getStatistics());
return physicalPlan;
} catch (Exception e) {
if (e instanceof AnalysisException && e.getMessage().contains("Failed to choose best plan")) {

View File

@ -256,8 +256,25 @@ public class ChildOutputPropertyDeriver extends PlanVisitor<PhysicalProperties,
// broadcast
if (rightOutputProperty.getDistributionSpec() instanceof DistributionSpecReplicated) {
DistributionSpec parentDistributionSpec = leftOutputProperty.getDistributionSpec();
return new PhysicalProperties(parentDistributionSpec);
DistributionSpec leftDistributionSpec = leftOutputProperty.getDistributionSpec();
// if the left side is hash distributed and its keys can satisfy the join keys, then mock
// a right side hash spec from the corresponding join keys, so that the returned spec
// is filled with refined EquivalenceExprIds.
if (leftDistributionSpec instanceof DistributionSpecHash
&& !(hashJoin.isMarkJoin() && hashJoin.getHashJoinConjuncts().isEmpty())
&& !hashJoin.getHashConjunctsExprIds().first.isEmpty()
&& !hashJoin.getHashConjunctsExprIds().second.isEmpty()
&& hashJoin.getHashConjunctsExprIds().first.size()
== hashJoin.getHashConjunctsExprIds().second.size()
&& leftDistributionSpec.satisfy(
new DistributionSpecHash(hashJoin.getHashConjunctsExprIds().first, ShuffleType.REQUIRE))) {
DistributionSpecHash mockedRightHashSpec = mockAnotherSideSpecFromConjuncts(
hashJoin, (DistributionSpecHash) leftDistributionSpec);
return computeShuffleJoinOutputProperties(hashJoin,
(DistributionSpecHash) leftDistributionSpec, mockedRightHashSpec);
} else {
return new PhysicalProperties(leftDistributionSpec);
}
}
// shuffle
@ -265,33 +282,7 @@ public class ChildOutputPropertyDeriver extends PlanVisitor<PhysicalProperties,
&& rightOutputProperty.getDistributionSpec() instanceof DistributionSpecHash) {
DistributionSpecHash leftHashSpec = (DistributionSpecHash) leftOutputProperty.getDistributionSpec();
DistributionSpecHash rightHashSpec = (DistributionSpecHash) rightOutputProperty.getDistributionSpec();
switch (hashJoin.getJoinType()) {
case INNER_JOIN:
case CROSS_JOIN:
return new PhysicalProperties(DistributionSpecHash.merge(
leftHashSpec, rightHashSpec, leftHashSpec.getShuffleType()));
case LEFT_SEMI_JOIN:
case LEFT_ANTI_JOIN:
case NULL_AWARE_LEFT_ANTI_JOIN:
case LEFT_OUTER_JOIN:
return new PhysicalProperties(leftHashSpec);
case RIGHT_SEMI_JOIN:
case RIGHT_ANTI_JOIN:
case RIGHT_OUTER_JOIN:
if (JoinUtils.couldColocateJoin(leftHashSpec, rightHashSpec, hashJoin.getHashJoinConjuncts())) {
return new PhysicalProperties(rightHashSpec);
} else {
// retain left shuffle type, since coordinator use left most node to schedule fragment
// forbid colocate join, since right table already shuffle
return new PhysicalProperties(rightHashSpec.withShuffleTypeAndForbidColocateJoin(
leftHashSpec.getShuffleType()));
}
case FULL_OUTER_JOIN:
return PhysicalProperties.createAnyFromHash(leftHashSpec);
default:
throw new AnalysisException("unknown join type " + hashJoin.getJoinType());
}
return computeShuffleJoinOutputProperties(hashJoin, leftHashSpec, rightHashSpec);
}
throw new RuntimeException("Could not derive hash join's output properties. join: " + hashJoin);
@ -465,6 +456,61 @@ public class ChildOutputPropertyDeriver extends PlanVisitor<PhysicalProperties,
return childrenOutputProperties.get(0);
}
/**
 * Computes the output physical properties of a hash join from the hash distribution specs
 * of its left and right sides, depending on the join type.
 *
 * <ul>
 *   <li>inner / cross join: the left and right specs are merged, keeping the left shuffle type;</li>
 *   <li>left semi / left anti / null-aware left anti / left outer join: the left spec is returned;</li>
 *   <li>right semi / right anti / right outer join: the right spec is returned, either as is
 *       (when {@code JoinUtils.couldColocateJoin} accepts the pair) or re-tagged with the left
 *       shuffle type and with colocate join forbidden;</li>
 *   <li>full outer join: the result of {@code PhysicalProperties.createAnyFromHash(leftHashSpec)}.</li>
 * </ul>
 *
 * @param hashJoin the join whose type and hash conjuncts drive the decision
 * @param leftHashSpec hash distribution spec of the left side
 * @param rightHashSpec hash distribution spec of the right side
 * @return the derived output properties
 * @throws AnalysisException if the join type is none of the types listed above
 */
private PhysicalProperties computeShuffleJoinOutputProperties(
        PhysicalHashJoin<? extends Plan, ? extends Plan> hashJoin,
        DistributionSpecHash leftHashSpec, DistributionSpecHash rightHashSpec) {
    switch (hashJoin.getJoinType()) {
        case LEFT_OUTER_JOIN:
        case LEFT_SEMI_JOIN:
        case LEFT_ANTI_JOIN:
        case NULL_AWARE_LEFT_ANTI_JOIN:
            // the left hash spec is returned unchanged
            return new PhysicalProperties(leftHashSpec);
        case RIGHT_OUTER_JOIN:
        case RIGHT_SEMI_JOIN:
        case RIGHT_ANTI_JOIN: {
            boolean colocatable = JoinUtils.couldColocateJoin(
                    leftHashSpec, rightHashSpec, hashJoin.getHashJoinConjuncts());
            if (!colocatable) {
                // Keep the left shuffle type, because the coordinator uses the leftmost node
                // to schedule the fragment. Colocate join is forbidden, because the right
                // table has already been shuffled.
                return new PhysicalProperties(rightHashSpec.withShuffleTypeAndForbidColocateJoin(
                        leftHashSpec.getShuffleType()));
            }
            return new PhysicalProperties(rightHashSpec);
        }
        case CROSS_JOIN:
        case INNER_JOIN:
            // both sides contribute: merge them under the left side's shuffle type
            return new PhysicalProperties(DistributionSpecHash.merge(
                    leftHashSpec, rightHashSpec, leftHashSpec.getShuffleType()));
        case FULL_OUTER_JOIN:
            return PhysicalProperties.createAnyFromHash(leftHashSpec);
        default:
            throw new AnalysisException("unknown join type " + hashJoin.getJoinType());
    }
}
/**
 * Builds a hash distribution spec for the other join side by translating the shuffled columns
 * of {@code oneSideSpec} through the join's hash conjuncts.
 *
 * <p>Each ordered shuffled column of {@code oneSideSpec} is located in the left-side key list of
 * the hash conjuncts, either directly or through one of the ids that {@code oneSideSpec} reports
 * as equivalent to it. The right-side key at the same position is then appended to the result.
 * Because the lookup is done against the left-side keys, {@code oneSideSpec} is expected to
 * describe the left side of the join.
 *
 * @param hashJoin the join providing the paired left/right hash conjunct expr ids
 * @param oneSideSpec the existing hash spec whose shuffled columns are translated
 * @return a new spec over the matching right-side expr ids, with the shuffle type of
 *         {@code oneSideSpec}
 */
private DistributionSpecHash mockAnotherSideSpecFromConjuncts(
        PhysicalHashJoin<? extends Plan, ? extends Plan> hashJoin, DistributionSpecHash oneSideSpec) {
    List<ExprId> leftKeys = hashJoin.getHashConjunctsExprIds().first;
    List<ExprId> rightKeys = hashJoin.getHashConjunctsExprIds().second;
    // the two key lists must be non-empty and pair up one to one
    Preconditions.checkState(!(leftKeys.isEmpty() || rightKeys.isEmpty())
            && leftKeys.size() == rightKeys.size(), "invalid hash join conjuncts");

    List<ExprId> mockedColumns = Lists.newArrayList();
    for (ExprId shuffledColumn : oneSideSpec.getOrderedShuffledColumns()) {
        int position = leftKeys.indexOf(shuffledColumn);
        if (position < 0) {
            // not a join key itself: fall back to the ids equivalent to this column
            for (ExprId candidate : oneSideSpec.getEquivalenceExprIdsOf(shuffledColumn)) {
                position = leftKeys.indexOf(candidate);
                if (position != -1) {
                    break;
                }
            }
            Preconditions.checkState(position >= 0, "can't find exprId in equivalence set");
        }
        mockedColumns.add(rightKeys.get(position));
    }
    return new DistributionSpecHash(mockedColumns, oneSideSpec.getShuffleType());
}
private boolean isSameHashValue(DataType originType, DataType castType) {
if (originType.isStringLikeType() && (castType.isVarcharType() || castType.isStringType())
&& (castType.width() >= originType.width() || castType.width() < 0)) {

View File

@ -19,14 +19,17 @@ package org.apache.doris.nereids.properties;
import org.apache.doris.catalog.ColocateTableIndex;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.hint.DistributeHint;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.GroupId;
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.plans.AggMode;
@ -63,6 +66,7 @@ import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@ -529,21 +533,25 @@ class ChildOutputPropertyDeriverTest {
}
@Test
void testBroadcastJoin() {
new MockUp<JoinUtils>() {
@Mock
Pair<List<ExprId>, List<ExprId>> getOnClauseUsedSlots(
AbstractPhysicalJoin<? extends Plan, ? extends Plan> join) {
return Pair.of(Lists.newArrayList(new ExprId(0)), Lists.newArrayList(new ExprId(2)));
}
};
void testBroadcastJoin(@Injectable LogicalProperties p1, @Injectable GroupPlan p2) {
SlotReference leftSlot = new SlotReference(new ExprId(0), "left", IntegerType.INSTANCE, false, Collections.emptyList());
SlotReference rightSlot = new SlotReference(new ExprId(2), "right", IntegerType.INSTANCE, false, Collections.emptyList());
List<Slot> leftOutput = new ArrayList<>();
List<Slot> rightOutput = new ArrayList<>();
leftOutput.add(leftSlot);
rightOutput.add(rightSlot);
LogicalProperties leftProperties = new LogicalProperties(() -> leftOutput, () -> FunctionalDependencies.EMPTY_FUNC_DEPS);
LogicalProperties rightProperties = new LogicalProperties(() -> rightOutput, () -> FunctionalDependencies.EMPTY_FUNC_DEPS);
IdGenerator<GroupId> idGenerator = GroupId.createGenerator();
GroupPlan leftGroupPlan = new GroupPlan(new Group(idGenerator.getNextId(), leftProperties));
GroupPlan rightGroupPlan = new GroupPlan(new Group(idGenerator.getNextId(), rightProperties));
PhysicalHashJoin<GroupPlan, GroupPlan> join = new PhysicalHashJoin<>(JoinType.INNER_JOIN,
Lists.newArrayList(new EqualTo(
new SlotReference(new ExprId(0), "left", IntegerType.INSTANCE, false, Collections.emptyList()),
new SlotReference(new ExprId(2), "right", IntegerType.INSTANCE, false,
Collections.emptyList()))),
ExpressionUtils.EMPTY_CONDITION, new DistributeHint(DistributeType.NONE), Optional.empty(), logicalProperties, groupPlan, groupPlan);
leftSlot, rightSlot
)),
ExpressionUtils.EMPTY_CONDITION, new DistributeHint(DistributeType.NONE),
Optional.empty(), logicalProperties, leftGroupPlan, rightGroupPlan);
GroupExpression groupExpression = new GroupExpression(join);
new Group(null, groupExpression, null);
@ -572,7 +580,7 @@ class ChildOutputPropertyDeriverTest {
DistributionSpecHash actual = (DistributionSpecHash) result.getDistributionSpec();
Assertions.assertEquals(ShuffleType.NATURAL, actual.getShuffleType());
// check merged
Assertions.assertEquals(2, actual.getExprIdToEquivalenceSet().size());
Assertions.assertEquals(3, actual.getExprIdToEquivalenceSet().size());
}
@Test