[Performance](Nereids): optimize GroupExpressionMatching (#26084)

This commit is contained in:
jakevin
2023-10-30 14:05:25 +08:00
committed by GitHub
parent eb2cbae6e3
commit 0d956e90cf
7 changed files with 40 additions and 44 deletions

View File

@ -61,7 +61,7 @@ public class ApplyRuleJob extends Job {
}
@Override
public void execute() throws AnalysisException {
public final void execute() throws AnalysisException {
if (groupExpression.hasApplied(rule)
|| groupExpression.isUnused()) {
return;

View File

@ -55,6 +55,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
public static class GroupExpressionIterator implements Iterator<Plan> {
private final List<Plan> results = Lists.newArrayList();
private int resultIndex = 0;
private int resultsSize;
/**
* Constructor.
@ -103,7 +104,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
// match every child group, producing one List<Plan> per child:
// the first dimension indexes the child groups,
// the second dimension holds all matched plans within one group
List<List<Plan>> childrenPlans = Lists.newArrayListWithCapacity(childrenGroupArity);
List<Plan>[] childrenPlans = new List[childrenGroupArity];
for (int i = 0; i < childrenGroupArity; ++i) {
Group childGroup = groupExpression.child(i);
List<Plan> childrenPlan = matchingChildGroup(pattern, childGroup, i);
@ -116,7 +117,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
return;
}
}
childrenPlans.add(childrenPlan);
childrenPlans[i] = childrenPlan;
}
assembleAllCombinationPlanTree(root, pattern, groupExpression, childrenPlans);
} else if (patternArity == 1 && (pattern.hasMultiChild() || pattern.hasMultiGroupChild())) {
@ -127,6 +128,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
results.add(root);
}
}
this.resultsSize = results.size();
}
private List<Plan> matchingChildGroup(Pattern<? extends Plan> parentPattern,
@ -154,40 +156,37 @@ public class GroupExpressionMatching implements Iterable<Plan> {
}
private void assembleAllCombinationPlanTree(Plan root, Pattern<Plan> rootPattern,
GroupExpression groupExpression,
List<List<Plan>> childrenPlans) {
int[] childrenPlanIndex = new int[childrenPlans.size()];
GroupExpression groupExpression, List<Plan>[] childrenPlans) {
int childrenPlansSize = childrenPlans.length;
int[] childrenPlanIndex = new int[childrenPlansSize];
int offset = 0;
LogicalProperties logicalProperties = groupExpression.getOwnerGroup().getLogicalProperties();
// assemble every combination of plan trees from the current root plan and the children plans
while (offset < childrenPlans.size()) {
ImmutableList.Builder<Plan> childrenBuilder =
ImmutableList.builderWithExpectedSize(childrenPlans.size());
for (int i = 0; i < childrenPlans.size(); i++) {
childrenBuilder.add(childrenPlans.get(i).get(childrenPlanIndex[i]));
Optional<GroupExpression> groupExprOption = Optional.of(groupExpression);
Optional<LogicalProperties> logicalPropOption = Optional.of(logicalProperties);
while (offset < childrenPlansSize) {
ImmutableList.Builder<Plan> childrenBuilder = ImmutableList.builderWithExpectedSize(childrenPlansSize);
for (int i = 0; i < childrenPlansSize; i++) {
childrenBuilder.add(childrenPlans[i].get(childrenPlanIndex[i]));
}
List<Plan> children = childrenBuilder.build();
// assemble children: replace each GroupPlan with the real plan.
// withChildren would erase the groupExpression, so we must
// apply withGroupExpression as well.
Plan rootWithChildren = root.withGroupExprLogicalPropChildren(Optional.of(groupExpression),
Optional.of(logicalProperties), children);
Plan rootWithChildren = root.withGroupExprLogicalPropChildren(groupExprOption,
logicalPropOption, children);
if (rootPattern.matchPredicates(rootWithChildren)) {
results.add(rootWithChildren);
}
offset = 0;
while (true) {
for (offset = 0; offset < childrenPlansSize; offset++) {
childrenPlanIndex[offset]++;
if (childrenPlanIndex[offset] == childrenPlans.get(offset).size()) {
if (childrenPlanIndex[offset] == childrenPlans[offset].size()) {
// Reset the index when it reaches the size of the current child plan list
childrenPlanIndex[offset] = 0;
offset++;
if (offset == childrenPlans.size()) {
break;
}
} else {
break;
break; // Break the loop when the index is within the size of the current child plan list
}
}
}
@ -195,7 +194,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
@Override
public boolean hasNext() {
return resultIndex < results.size();
return resultIndex < resultsSize;
}
@Override

View File

@ -45,11 +45,12 @@ public class GroupMatching {
matchingPlans.add(plan);
}
}
for (GroupExpression groupExpression : group.getPhysicalExpressions()) {
for (Plan plan : new GroupExpressionMatching(pattern, groupExpression)) {
matchingPlans.add(plan);
}
}
// Jackwener: We don't need to match physical expressions.
// for (GroupExpression groupExpression : group.getPhysicalExpressions()) {
// for (Plan plan : new GroupExpressionMatching(pattern, groupExpression)) {
// matchingPlans.add(plan);
// }
// }
}
return matchingPlans;
}

View File

@ -38,13 +38,11 @@ import com.google.common.collect.Lists;
public class EnforceMissingPropertiesHelper {
private static final EventProducer ENFORCER_TRACER = new EventProducer(EnforcerEvent.class,
EventChannel.getDefaultChannel().addConsumers(new LogConsumer(EnforcerEvent.class, EventChannel.LOG)));
private final JobContext context;
private final GroupExpression groupExpression;
private Cost curTotalCost;
public EnforceMissingPropertiesHelper(JobContext context, GroupExpression groupExpression,
Cost curTotalCost) {
this.context = context;
this.groupExpression = groupExpression;
this.curTotalCost = curTotalCost;
}

View File

@ -17,10 +17,6 @@
package org.apache.doris.nereids.trees;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.planner.PlanNodeId;
import com.google.common.collect.ImmutableList;
import java.util.List;
@ -33,7 +29,6 @@ import java.util.List;
*/
public abstract class AbstractTreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>>
implements TreeNode<NODE_TYPE> {
protected final ObjectId id = StatementScopeIdGenerator.newObjectId();
protected final List<NODE_TYPE> children;
// TODO: Maybe we should use a GroupPlan to avoid TreeNode hold the GroupExpression.
// https://github.com/apache/doris/pull/9807#discussion_r884829067
@ -59,12 +54,4 @@ public abstract class AbstractTreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>>
public int arity() {
return children.size();
}
/**
* used for PhysicalPlanTranslator only
* @return PlanNodeId
*/
public PlanNodeId translatePlanNodeId() {
return id.toPlanNodeId();
}
}

View File

@ -24,9 +24,11 @@ import org.apache.doris.nereids.properties.UnboundLogicalProperties;
import org.apache.doris.nereids.trees.AbstractTreeNode;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.util.MutableState;
import org.apache.doris.nereids.util.MutableState.EmptyMutableState;
import org.apache.doris.nereids.util.TreeStringUtils;
import org.apache.doris.planner.PlanNodeId;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Supplier;
@ -45,6 +47,7 @@ import javax.annotation.Nullable;
*/
public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Plan {
public static final String FRAGMENT_ID = "fragment";
protected final ObjectId id = StatementScopeIdGenerator.newObjectId();
protected final Statistics statistics;
protected final PlanType type;
@ -168,4 +171,12 @@ public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Pla
public void setMutableState(String key, Object state) {
this.mutableState = this.mutableState.set(key, state);
}
/**
* Used by PhysicalPlanTranslator only.
* @return the PlanNodeId converted from this plan's ObjectId
*/
public PlanNodeId translatePlanNodeId() {
return id.toPlanNodeId();
}
}

View File

@ -31,7 +31,7 @@ suite("test_topn_to_max") {
group by k1;
'''
res = sql '''
explain rewritten plan select k1, max(k2)
explain rewritten plan select k1, topn(k2, 1)
from test_topn_to_max
group by k1;
'''
@ -42,7 +42,7 @@ suite("test_topn_to_max") {
from test_topn_to_max;
'''
res = sql '''
explain rewritten plan select max(k2)
explain rewritten plan select topn(k2, 1)
from test_topn_to_max;
'''
assertTrue(res.toString().contains("max"))