[enhancement](Nereids): optimize GroupExpressionMatching (#26196)
This commit is contained in:
@ -61,7 +61,7 @@ public class ApplyRuleJob extends Job {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute() throws AnalysisException {
|
||||
public final void execute() throws AnalysisException {
|
||||
if (groupExpression.hasApplied(rule)
|
||||
|| groupExpression.isUnused()) {
|
||||
return;
|
||||
|
||||
@ -55,6 +55,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
|
||||
public static class GroupExpressionIterator implements Iterator<Plan> {
|
||||
private final List<Plan> results = Lists.newArrayList();
|
||||
private int resultIndex = 0;
|
||||
private int resultsSize;
|
||||
|
||||
/**
|
||||
* Constructor.
|
||||
@ -103,7 +104,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
|
||||
// matching children group, one List<Plan> per child
|
||||
// first dimension is every child group's plan
|
||||
// second dimension is all matched plan in one group
|
||||
List<List<Plan>> childrenPlans = Lists.newArrayListWithCapacity(childrenGroupArity);
|
||||
List<Plan>[] childrenPlans = new List[childrenGroupArity];
|
||||
for (int i = 0; i < childrenGroupArity; ++i) {
|
||||
Group childGroup = groupExpression.child(i);
|
||||
List<Plan> childrenPlan = matchingChildGroup(pattern, childGroup, i);
|
||||
@ -116,7 +117,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
|
||||
return;
|
||||
}
|
||||
}
|
||||
childrenPlans.add(childrenPlan);
|
||||
childrenPlans[i] = childrenPlan;
|
||||
}
|
||||
assembleAllCombinationPlanTree(root, pattern, groupExpression, childrenPlans);
|
||||
} else if (patternArity == 1 && (pattern.hasMultiChild() || pattern.hasMultiGroupChild())) {
|
||||
@ -127,6 +128,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
|
||||
results.add(root);
|
||||
}
|
||||
}
|
||||
this.resultsSize = results.size();
|
||||
}
|
||||
|
||||
private List<Plan> matchingChildGroup(Pattern<? extends Plan> parentPattern,
|
||||
@ -154,38 +156,35 @@ public class GroupExpressionMatching implements Iterable<Plan> {
|
||||
}
|
||||
|
||||
private void assembleAllCombinationPlanTree(Plan root, Pattern<Plan> rootPattern,
|
||||
GroupExpression groupExpression,
|
||||
List<List<Plan>> childrenPlans) {
|
||||
int[] childrenPlanIndex = new int[childrenPlans.size()];
|
||||
GroupExpression groupExpression, List<Plan>[] childrenPlans) {
|
||||
int childrenPlansSize = childrenPlans.length;
|
||||
int[] childrenPlanIndex = new int[childrenPlansSize];
|
||||
int offset = 0;
|
||||
LogicalProperties logicalProperties = groupExpression.getOwnerGroup().getLogicalProperties();
|
||||
|
||||
// assemble all combination of plan tree by current root plan and children plan
|
||||
while (offset < childrenPlans.size()) {
|
||||
ImmutableList.Builder<Plan> childrenBuilder =
|
||||
ImmutableList.builderWithExpectedSize(childrenPlans.size());
|
||||
for (int i = 0; i < childrenPlans.size(); i++) {
|
||||
childrenBuilder.add(childrenPlans.get(i).get(childrenPlanIndex[i]));
|
||||
Optional<GroupExpression> groupExprOption = Optional.of(groupExpression);
|
||||
Optional<LogicalProperties> logicalPropOption = Optional.of(logicalProperties);
|
||||
while (offset < childrenPlansSize) {
|
||||
ImmutableList.Builder<Plan> childrenBuilder = ImmutableList.builderWithExpectedSize(childrenPlansSize);
|
||||
for (int i = 0; i < childrenPlansSize; i++) {
|
||||
childrenBuilder.add(childrenPlans[i].get(childrenPlanIndex[i]));
|
||||
}
|
||||
List<Plan> children = childrenBuilder.build();
|
||||
|
||||
// assemble children: replace GroupPlan to real plan,
|
||||
// withChildren will erase groupExpression, so we must
|
||||
// withGroupExpression too.
|
||||
Plan rootWithChildren = root.withGroupExprLogicalPropChildren(Optional.of(groupExpression),
|
||||
Optional.of(logicalProperties), children);
|
||||
Plan rootWithChildren = root.withGroupExprLogicalPropChildren(groupExprOption,
|
||||
logicalPropOption, children);
|
||||
if (rootPattern.matchPredicates(rootWithChildren)) {
|
||||
results.add(rootWithChildren);
|
||||
}
|
||||
offset = 0;
|
||||
while (true) {
|
||||
for (offset = 0; offset < childrenPlansSize; offset++) {
|
||||
childrenPlanIndex[offset]++;
|
||||
if (childrenPlanIndex[offset] == childrenPlans.get(offset).size()) {
|
||||
if (childrenPlanIndex[offset] == childrenPlans[offset].size()) {
|
||||
// Reset the index when it reaches the size of the current child plan list
|
||||
childrenPlanIndex[offset] = 0;
|
||||
offset++;
|
||||
if (offset == childrenPlans.size()) {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
@ -195,7 +194,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return resultIndex < results.size();
|
||||
return resultIndex < resultsSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -17,10 +17,6 @@
|
||||
|
||||
package org.apache.doris.nereids.trees;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
|
||||
import org.apache.doris.nereids.trees.plans.ObjectId;
|
||||
import org.apache.doris.planner.PlanNodeId;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
@ -33,7 +29,6 @@ import java.util.List;
|
||||
*/
|
||||
public abstract class AbstractTreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>>
|
||||
implements TreeNode<NODE_TYPE> {
|
||||
protected final ObjectId id = StatementScopeIdGenerator.newObjectId();
|
||||
protected final List<NODE_TYPE> children;
|
||||
// TODO: Maybe we should use a GroupPlan to avoid TreeNode hold the GroupExpression.
|
||||
// https://github.com/apache/doris/pull/9807#discussion_r884829067
|
||||
@ -59,12 +54,4 @@ public abstract class AbstractTreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>>
|
||||
public int arity() {
|
||||
return children.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* used for PhysicalPlanTranslator only
|
||||
* @return PlanNodeId
|
||||
*/
|
||||
public PlanNodeId translatePlanNodeId() {
|
||||
return id.toPlanNodeId();
|
||||
}
|
||||
}
|
||||
|
||||
@ -37,6 +37,7 @@ import org.apache.doris.nereids.types.coercion.AnyDataType;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@ -44,7 +45,6 @@ import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Abstract class for all Expression in Nereids.
|
||||
@ -247,8 +247,19 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
|
||||
return collect(Slot.class::isInstance);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all the input slot ids of the expression.
|
||||
* <p>
|
||||
* Note that the input slots of subquery's inner plan is not included.
|
||||
*/
|
||||
public final Set<ExprId> getInputSlotExprIds() {
|
||||
return getInputSlots().stream().map(NamedExpression::getExprId).collect(Collectors.toSet());
|
||||
ImmutableSet.Builder<ExprId> result = ImmutableSet.builder();
|
||||
foreach(node -> {
|
||||
if (node instanceof Slot) {
|
||||
result.add(((Slot) node).getExprId());
|
||||
}
|
||||
});
|
||||
return result.build();
|
||||
}
|
||||
|
||||
public boolean isLiteral() {
|
||||
|
||||
@ -24,9 +24,11 @@ import org.apache.doris.nereids.properties.UnboundLogicalProperties;
|
||||
import org.apache.doris.nereids.trees.AbstractTreeNode;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
|
||||
import org.apache.doris.nereids.util.MutableState;
|
||||
import org.apache.doris.nereids.util.MutableState.EmptyMutableState;
|
||||
import org.apache.doris.nereids.util.TreeStringUtils;
|
||||
import org.apache.doris.planner.PlanNodeId;
|
||||
import org.apache.doris.statistics.Statistics;
|
||||
|
||||
import com.google.common.base.Supplier;
|
||||
@ -45,6 +47,7 @@ import javax.annotation.Nullable;
|
||||
*/
|
||||
public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Plan {
|
||||
public static final String FRAGMENT_ID = "fragment";
|
||||
protected final ObjectId id = StatementScopeIdGenerator.newObjectId();
|
||||
|
||||
protected final Statistics statistics;
|
||||
protected final PlanType type;
|
||||
@ -168,4 +171,12 @@ public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Pla
|
||||
public void setMutableState(String key, Object state) {
|
||||
this.mutableState = this.mutableState.set(key, state);
|
||||
}
|
||||
|
||||
/**
|
||||
* used for PhysicalPlanTranslator only
|
||||
* @return PlanNodeId
|
||||
*/
|
||||
public PlanNodeId translatePlanNodeId() {
|
||||
return id.toPlanNodeId();
|
||||
}
|
||||
}
|
||||
|
||||
@ -43,9 +43,9 @@ import org.apache.doris.statistics.Statistics;
|
||||
import org.apache.doris.thrift.TRuntimeFilterType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
@ -113,22 +113,25 @@ public class PhysicalHashJoin<
|
||||
* Return pair of left used slots and right used slots.
|
||||
*/
|
||||
public Pair<List<ExprId>, List<ExprId>> getHashConjunctsExprIds() {
|
||||
List<ExprId> exprIds1 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size());
|
||||
List<ExprId> exprIds2 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size());
|
||||
int size = hashJoinConjuncts.size();
|
||||
|
||||
List<ExprId> exprIds1 = new ArrayList<>(size);
|
||||
List<ExprId> exprIds2 = new ArrayList<>(size);
|
||||
|
||||
Set<ExprId> leftExprIds = left().getOutputExprIdSet();
|
||||
Set<ExprId> rightExprIds = right().getOutputExprIdSet();
|
||||
|
||||
for (Expression expr : hashJoinConjuncts) {
|
||||
expr.getInputSlotExprIds().forEach(exprId -> {
|
||||
for (ExprId exprId : expr.getInputSlotExprIds()) {
|
||||
if (leftExprIds.contains(exprId)) {
|
||||
exprIds1.add(exprId);
|
||||
} else if (rightExprIds.contains(exprId)) {
|
||||
exprIds2.add(exprId);
|
||||
} else {
|
||||
throw new RuntimeException("Could not generate valid equal on clause slot pairs for join");
|
||||
throw new RuntimeException("Invalid ExprId found: " + exprId
|
||||
+ ". Cannot generate valid equal on clause slot pairs for join.");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
return Pair.of(exprIds1, exprIds2);
|
||||
}
|
||||
|
||||
@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.qe.SessionVariable;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
@ -66,9 +67,10 @@ public class JoinUtils {
|
||||
* check if the row count of the left child in the broadcast join is less than a threshold value.
|
||||
*/
|
||||
public static boolean checkBroadcastJoinStats(PhysicalHashJoin<? extends Plan, ? extends Plan> join) {
|
||||
double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte();
|
||||
double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit();
|
||||
double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage();
|
||||
SessionVariable sessionVariable = ConnectContext.get().getSessionVariable();
|
||||
double memLimit = sessionVariable.getMaxExecMemByte();
|
||||
double rowsLimit = sessionVariable.getBroadcastRowCountLimit();
|
||||
double brMemlimit = sessionVariable.getBroadcastHashtableMemLimitPercentage();
|
||||
double datasize = join.getGroupExpression().get().child(1).getStatistics().computeSize();
|
||||
double rowCount = join.getGroupExpression().get().child(1).getStatistics().getRowCount();
|
||||
return rowCount <= rowsLimit && datasize <= memLimit * brMemlimit;
|
||||
@ -114,12 +116,12 @@ public class JoinUtils {
|
||||
* @return true if the equal can be used as hash join condition
|
||||
*/
|
||||
public boolean isHashJoinCondition(EqualTo equalTo) {
|
||||
Set<Slot> equalLeft = equalTo.left().collect(Slot.class::isInstance);
|
||||
Set<Slot> equalLeft = equalTo.left().getInputSlots();
|
||||
if (equalLeft.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Set<Slot> equalRight = equalTo.right().collect(Slot.class::isInstance);
|
||||
Set<Slot> equalRight = equalTo.right().getInputSlots();
|
||||
if (equalRight.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -55,7 +55,7 @@ public class PlanUtils {
|
||||
* normalize comparison predicate on a binary plan to its two sides are corresponding to the child's output.
|
||||
*/
|
||||
public static ComparisonPredicate maybeCommuteComparisonPredicate(ComparisonPredicate expression, Plan left) {
|
||||
Set<Slot> slots = expression.left().collect(Slot.class::isInstance);
|
||||
Set<Slot> slots = expression.left().getInputSlots();
|
||||
Set<Slot> leftSlots = left.getOutputSet();
|
||||
Set<Slot> buffer = Sets.newHashSet(slots);
|
||||
buffer.removeAll(leftSlots);
|
||||
|
||||
@ -42,10 +42,10 @@ import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
public class GroupExpressionMatchingTest {
|
||||
class GroupExpressionMatchingTest {
|
||||
|
||||
@Test
|
||||
public void testLeafNode() {
|
||||
void testLeafNode() {
|
||||
Pattern pattern = new Pattern<>(PlanType.LOGICAL_UNBOUND_RELATION);
|
||||
|
||||
Memo memo = new Memo(null, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test")));
|
||||
@ -61,7 +61,7 @@ public class GroupExpressionMatchingTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDepth2() {
|
||||
void testDepth2() {
|
||||
Pattern pattern = new Pattern<>(PlanType.LOGICAL_PROJECT,
|
||||
new Pattern<>(PlanType.LOGICAL_UNBOUND_RELATION));
|
||||
|
||||
@ -93,7 +93,7 @@ public class GroupExpressionMatchingTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDepth2WithGroup() {
|
||||
void testDepth2WithGroup() {
|
||||
Pattern pattern = new Pattern<>(PlanType.LOGICAL_PROJECT, Pattern.GROUP);
|
||||
|
||||
Plan leaf = new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test"));
|
||||
@ -119,7 +119,7 @@ public class GroupExpressionMatchingTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLeafAny() {
|
||||
void testLeafAny() {
|
||||
Pattern pattern = Pattern.ANY;
|
||||
|
||||
Memo memo = new Memo(null, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test")));
|
||||
@ -135,7 +135,7 @@ public class GroupExpressionMatchingTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAnyWithChild() {
|
||||
void testAnyWithChild() {
|
||||
Plan root = new LogicalProject(
|
||||
ImmutableList.of(new SlotReference("name", StringType.INSTANCE, true,
|
||||
ImmutableList.of("test"))),
|
||||
@ -159,7 +159,7 @@ public class GroupExpressionMatchingTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInnerLogicalJoinMatch() {
|
||||
void testInnerLogicalJoinMatch() {
|
||||
Plan root = new LogicalJoin(JoinType.INNER_JOIN,
|
||||
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
|
||||
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))
|
||||
@ -181,7 +181,7 @@ public class GroupExpressionMatchingTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInnerLogicalJoinMismatch() {
|
||||
void testInnerLogicalJoinMismatch() {
|
||||
Plan root = new LogicalJoin(JoinType.LEFT_OUTER_JOIN,
|
||||
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
|
||||
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))
|
||||
@ -198,7 +198,7 @@ public class GroupExpressionMatchingTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTopMatchButChildrenNotMatch() {
|
||||
void testTopMatchButChildrenNotMatch() {
|
||||
Plan root = new LogicalJoin(JoinType.LEFT_OUTER_JOIN,
|
||||
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
|
||||
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))
|
||||
@ -216,12 +216,12 @@ public class GroupExpressionMatchingTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSubTreeMatch() {
|
||||
void testSubTreeMatch() {
|
||||
Plan root =
|
||||
new LogicalFilter(ImmutableSet.of(new EqualTo(new UnboundSlot(Lists.newArrayList("a", "id")),
|
||||
new LogicalFilter<>(ImmutableSet.of(new EqualTo(new UnboundSlot(Lists.newArrayList("a", "id")),
|
||||
new UnboundSlot(Lists.newArrayList("b", "id")))),
|
||||
new LogicalJoin(JoinType.INNER_JOIN,
|
||||
new LogicalJoin(JoinType.LEFT_OUTER_JOIN,
|
||||
new LogicalJoin<>(JoinType.INNER_JOIN,
|
||||
new LogicalJoin<>(JoinType.LEFT_OUTER_JOIN,
|
||||
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
|
||||
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))),
|
||||
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("c")))
|
||||
|
||||
Reference in New Issue
Block a user