[enhancement](Nereids): optimize GroupExpressionMatching (#26196)

This commit is contained in:
jakevin
2023-11-01 19:05:08 +08:00
committed by GitHub
parent 502f5778f4
commit 6010be88bd
10 changed files with 76 additions and 63 deletions

View File

@ -61,7 +61,7 @@ public class ApplyRuleJob extends Job {
}
@Override
public void execute() throws AnalysisException {
public final void execute() throws AnalysisException {
if (groupExpression.hasApplied(rule)
|| groupExpression.isUnused()) {
return;

View File

@ -55,6 +55,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
public static class GroupExpressionIterator implements Iterator<Plan> {
private final List<Plan> results = Lists.newArrayList();
private int resultIndex = 0;
private int resultsSize;
/**
* Constructor.
@ -103,7 +104,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
// matching children group, one List<Plan> per child
// first dimension is every child group's plan
// second dimension is all matched plan in one group
List<List<Plan>> childrenPlans = Lists.newArrayListWithCapacity(childrenGroupArity);
List<Plan>[] childrenPlans = new List[childrenGroupArity];
for (int i = 0; i < childrenGroupArity; ++i) {
Group childGroup = groupExpression.child(i);
List<Plan> childrenPlan = matchingChildGroup(pattern, childGroup, i);
@ -116,7 +117,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
return;
}
}
childrenPlans.add(childrenPlan);
childrenPlans[i] = childrenPlan;
}
assembleAllCombinationPlanTree(root, pattern, groupExpression, childrenPlans);
} else if (patternArity == 1 && (pattern.hasMultiChild() || pattern.hasMultiGroupChild())) {
@ -127,6 +128,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
results.add(root);
}
}
this.resultsSize = results.size();
}
private List<Plan> matchingChildGroup(Pattern<? extends Plan> parentPattern,
@ -154,38 +156,35 @@ public class GroupExpressionMatching implements Iterable<Plan> {
}
private void assembleAllCombinationPlanTree(Plan root, Pattern<Plan> rootPattern,
GroupExpression groupExpression,
List<List<Plan>> childrenPlans) {
int[] childrenPlanIndex = new int[childrenPlans.size()];
GroupExpression groupExpression, List<Plan>[] childrenPlans) {
int childrenPlansSize = childrenPlans.length;
int[] childrenPlanIndex = new int[childrenPlansSize];
int offset = 0;
LogicalProperties logicalProperties = groupExpression.getOwnerGroup().getLogicalProperties();
// assemble all combination of plan tree by current root plan and children plan
while (offset < childrenPlans.size()) {
ImmutableList.Builder<Plan> childrenBuilder =
ImmutableList.builderWithExpectedSize(childrenPlans.size());
for (int i = 0; i < childrenPlans.size(); i++) {
childrenBuilder.add(childrenPlans.get(i).get(childrenPlanIndex[i]));
Optional<GroupExpression> groupExprOption = Optional.of(groupExpression);
Optional<LogicalProperties> logicalPropOption = Optional.of(logicalProperties);
while (offset < childrenPlansSize) {
ImmutableList.Builder<Plan> childrenBuilder = ImmutableList.builderWithExpectedSize(childrenPlansSize);
for (int i = 0; i < childrenPlansSize; i++) {
childrenBuilder.add(childrenPlans[i].get(childrenPlanIndex[i]));
}
List<Plan> children = childrenBuilder.build();
// assemble children: replace GroupPlan to real plan,
// withChildren will erase groupExpression, so we must
// withGroupExpression too.
Plan rootWithChildren = root.withGroupExprLogicalPropChildren(Optional.of(groupExpression),
Optional.of(logicalProperties), children);
Plan rootWithChildren = root.withGroupExprLogicalPropChildren(groupExprOption,
logicalPropOption, children);
if (rootPattern.matchPredicates(rootWithChildren)) {
results.add(rootWithChildren);
}
offset = 0;
while (true) {
for (offset = 0; offset < childrenPlansSize; offset++) {
childrenPlanIndex[offset]++;
if (childrenPlanIndex[offset] == childrenPlans.get(offset).size()) {
if (childrenPlanIndex[offset] == childrenPlans[offset].size()) {
// Reset the index when it reaches the size of the current child plan list
childrenPlanIndex[offset] = 0;
offset++;
if (offset == childrenPlans.size()) {
break;
}
} else {
break;
}
@ -195,7 +194,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
@Override
public boolean hasNext() {
return resultIndex < results.size();
return resultIndex < resultsSize;
}
@Override

View File

@ -17,10 +17,6 @@
package org.apache.doris.nereids.trees;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.planner.PlanNodeId;
import com.google.common.collect.ImmutableList;
import java.util.List;
@ -33,7 +29,6 @@ import java.util.List;
*/
public abstract class AbstractTreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>>
implements TreeNode<NODE_TYPE> {
protected final ObjectId id = StatementScopeIdGenerator.newObjectId();
protected final List<NODE_TYPE> children;
// TODO: Maybe we should use a GroupPlan to avoid TreeNode hold the GroupExpression.
// https://github.com/apache/doris/pull/9807#discussion_r884829067
@ -59,12 +54,4 @@ public abstract class AbstractTreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>>
public int arity() {
return children.size();
}
/**
* used for PhysicalPlanTranslator only
* @return PlanNodeId
*/
public PlanNodeId translatePlanNodeId() {
return id.toPlanNodeId();
}
}

View File

@ -37,6 +37,7 @@ import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.apache.commons.lang3.StringUtils;
@ -44,7 +45,6 @@ import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Abstract class for all Expression in Nereids.
@ -247,8 +247,19 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
return collect(Slot.class::isInstance);
}
/**
* Get all the input slot ids of the expression.
* <p>
* Note that the input slots of subquery's inner plan is not included.
*/
public final Set<ExprId> getInputSlotExprIds() {
return getInputSlots().stream().map(NamedExpression::getExprId).collect(Collectors.toSet());
ImmutableSet.Builder<ExprId> result = ImmutableSet.builder();
foreach(node -> {
if (node instanceof Slot) {
result.add(((Slot) node).getExprId());
}
});
return result.build();
}
public boolean isLiteral() {

View File

@ -24,9 +24,11 @@ import org.apache.doris.nereids.properties.UnboundLogicalProperties;
import org.apache.doris.nereids.trees.AbstractTreeNode;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.util.MutableState;
import org.apache.doris.nereids.util.MutableState.EmptyMutableState;
import org.apache.doris.nereids.util.TreeStringUtils;
import org.apache.doris.planner.PlanNodeId;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Supplier;
@ -45,6 +47,7 @@ import javax.annotation.Nullable;
*/
public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Plan {
public static final String FRAGMENT_ID = "fragment";
protected final ObjectId id = StatementScopeIdGenerator.newObjectId();
protected final Statistics statistics;
protected final PlanType type;
@ -168,4 +171,12 @@ public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Pla
public void setMutableState(String key, Object state) {
this.mutableState = this.mutableState.set(key, state);
}
/**
* used for PhysicalPlanTranslator only
* @return PlanNodeId
*/
public PlanNodeId translatePlanNodeId() {
return id.toPlanNodeId();
}
}

View File

@ -43,9 +43,9 @@ import org.apache.doris.statistics.Statistics;
import org.apache.doris.thrift.TRuntimeFilterType;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -113,22 +113,25 @@ public class PhysicalHashJoin<
* Return pair of left used slots and right used slots.
*/
public Pair<List<ExprId>, List<ExprId>> getHashConjunctsExprIds() {
List<ExprId> exprIds1 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size());
List<ExprId> exprIds2 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size());
int size = hashJoinConjuncts.size();
List<ExprId> exprIds1 = new ArrayList<>(size);
List<ExprId> exprIds2 = new ArrayList<>(size);
Set<ExprId> leftExprIds = left().getOutputExprIdSet();
Set<ExprId> rightExprIds = right().getOutputExprIdSet();
for (Expression expr : hashJoinConjuncts) {
expr.getInputSlotExprIds().forEach(exprId -> {
for (ExprId exprId : expr.getInputSlotExprIds()) {
if (leftExprIds.contains(exprId)) {
exprIds1.add(exprId);
} else if (rightExprIds.contains(exprId)) {
exprIds2.add(exprId);
} else {
throw new RuntimeException("Could not generate valid equal on clause slot pairs for join");
throw new RuntimeException("Invalid ExprId found: " + exprId
+ ". Cannot generate valid equal on clause slot pairs for join.");
}
});
}
}
return Pair.of(exprIds1, exprIds2);
}

View File

@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
@ -66,9 +67,10 @@ public class JoinUtils {
* check if the row count of the left child in the broadcast join is less than a threshold value.
*/
public static boolean checkBroadcastJoinStats(PhysicalHashJoin<? extends Plan, ? extends Plan> join) {
double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte();
double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit();
double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage();
SessionVariable sessionVariable = ConnectContext.get().getSessionVariable();
double memLimit = sessionVariable.getMaxExecMemByte();
double rowsLimit = sessionVariable.getBroadcastRowCountLimit();
double brMemlimit = sessionVariable.getBroadcastHashtableMemLimitPercentage();
double datasize = join.getGroupExpression().get().child(1).getStatistics().computeSize();
double rowCount = join.getGroupExpression().get().child(1).getStatistics().getRowCount();
return rowCount <= rowsLimit && datasize <= memLimit * brMemlimit;
@ -114,12 +116,12 @@ public class JoinUtils {
* @return true if the equal can be used as hash join condition
*/
public boolean isHashJoinCondition(EqualTo equalTo) {
Set<Slot> equalLeft = equalTo.left().collect(Slot.class::isInstance);
Set<Slot> equalLeft = equalTo.left().getInputSlots();
if (equalLeft.isEmpty()) {
return false;
}
Set<Slot> equalRight = equalTo.right().collect(Slot.class::isInstance);
Set<Slot> equalRight = equalTo.right().getInputSlots();
if (equalRight.isEmpty()) {
return false;
}

View File

@ -55,7 +55,7 @@ public class PlanUtils {
* normalize comparison predicate on a binary plan to its two sides are corresponding to the child's output.
*/
public static ComparisonPredicate maybeCommuteComparisonPredicate(ComparisonPredicate expression, Plan left) {
Set<Slot> slots = expression.left().collect(Slot.class::isInstance);
Set<Slot> slots = expression.left().getInputSlots();
Set<Slot> leftSlots = left.getOutputSet();
Set<Slot> buffer = Sets.newHashSet(slots);
buffer.removeAll(leftSlots);

View File

@ -42,10 +42,10 @@ import org.junit.jupiter.api.Test;
import java.util.Iterator;
public class GroupExpressionMatchingTest {
class GroupExpressionMatchingTest {
@Test
public void testLeafNode() {
void testLeafNode() {
Pattern pattern = new Pattern<>(PlanType.LOGICAL_UNBOUND_RELATION);
Memo memo = new Memo(null, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test")));
@ -61,7 +61,7 @@ public class GroupExpressionMatchingTest {
}
@Test
public void testDepth2() {
void testDepth2() {
Pattern pattern = new Pattern<>(PlanType.LOGICAL_PROJECT,
new Pattern<>(PlanType.LOGICAL_UNBOUND_RELATION));
@ -93,7 +93,7 @@ public class GroupExpressionMatchingTest {
}
@Test
public void testDepth2WithGroup() {
void testDepth2WithGroup() {
Pattern pattern = new Pattern<>(PlanType.LOGICAL_PROJECT, Pattern.GROUP);
Plan leaf = new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test"));
@ -119,7 +119,7 @@ public class GroupExpressionMatchingTest {
}
@Test
public void testLeafAny() {
void testLeafAny() {
Pattern pattern = Pattern.ANY;
Memo memo = new Memo(null, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test")));
@ -135,7 +135,7 @@ public class GroupExpressionMatchingTest {
}
@Test
public void testAnyWithChild() {
void testAnyWithChild() {
Plan root = new LogicalProject(
ImmutableList.of(new SlotReference("name", StringType.INSTANCE, true,
ImmutableList.of("test"))),
@ -159,7 +159,7 @@ public class GroupExpressionMatchingTest {
}
@Test
public void testInnerLogicalJoinMatch() {
void testInnerLogicalJoinMatch() {
Plan root = new LogicalJoin(JoinType.INNER_JOIN,
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))
@ -181,7 +181,7 @@ public class GroupExpressionMatchingTest {
}
@Test
public void testInnerLogicalJoinMismatch() {
void testInnerLogicalJoinMismatch() {
Plan root = new LogicalJoin(JoinType.LEFT_OUTER_JOIN,
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))
@ -198,7 +198,7 @@ public class GroupExpressionMatchingTest {
}
@Test
public void testTopMatchButChildrenNotMatch() {
void testTopMatchButChildrenNotMatch() {
Plan root = new LogicalJoin(JoinType.LEFT_OUTER_JOIN,
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))
@ -216,12 +216,12 @@ public class GroupExpressionMatchingTest {
}
@Test
public void testSubTreeMatch() {
void testSubTreeMatch() {
Plan root =
new LogicalFilter(ImmutableSet.of(new EqualTo(new UnboundSlot(Lists.newArrayList("a", "id")),
new LogicalFilter<>(ImmutableSet.of(new EqualTo(new UnboundSlot(Lists.newArrayList("a", "id")),
new UnboundSlot(Lists.newArrayList("b", "id")))),
new LogicalJoin(JoinType.INNER_JOIN,
new LogicalJoin(JoinType.LEFT_OUTER_JOIN,
new LogicalJoin<>(JoinType.INNER_JOIN,
new LogicalJoin<>(JoinType.LEFT_OUTER_JOIN,
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("c")))