[enhancement](Nereids): optimize GroupExpressionMatching (#26130)

1. A pattern can't be a SubTreePattern in the CBO phase, so remove the SubTreePattern special cases from GroupExpressionMatching.
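2. Optimize getInputSlotExprIds(): collect the ExprIds directly while traversing the expression instead of streaming over the result of getInputSlots().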
This commit is contained in:
jakevin
2023-10-31 15:47:13 +08:00
committed by GitHub
parent 111b8e2b4f
commit 19122b55cd
6 changed files with 54 additions and 55 deletions

View File

@ -20,7 +20,6 @@ package org.apache.doris.nereids.pattern;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import com.google.common.collect.ImmutableList;
@ -70,21 +69,19 @@ public class GroupExpressionMatching implements Iterable<Plan> {
int childrenGroupArity = groupExpression.arity();
int patternArity = pattern.arity();
if (!(pattern instanceof SubTreePattern)) {
// (logicalFilter(), multi()) match (logicalFilter()),
// but (logicalFilter(), logicalFilter(), multi()) not match (logicalFilter())
boolean extraMulti = patternArity == childrenGroupArity + 1
&& (pattern.hasMultiChild() || pattern.hasMultiGroupChild());
if (patternArity > childrenGroupArity && !extraMulti) {
return;
}
// (logicalFilter(), multi()) match (logicalFilter()),
// but (logicalFilter(), logicalFilter(), multi()) not match (logicalFilter())
boolean extraMulti = patternArity == childrenGroupArity + 1
&& (pattern.hasMultiChild() || pattern.hasMultiGroupChild());
if (patternArity > childrenGroupArity && !extraMulti) {
return;
}
// (multi()) match (logicalFilter(), logicalFilter()),
// but (logicalFilter()) not match (logicalFilter(), logicalFilter())
if (!pattern.isAny() && patternArity < childrenGroupArity
&& !pattern.hasMultiChild() && !pattern.hasMultiGroupChild()) {
return;
}
// (multi()) match (logicalFilter(), logicalFilter()),
// but (logicalFilter()) not match (logicalFilter(), logicalFilter())
if (!pattern.isAny() && patternArity < childrenGroupArity
&& !pattern.hasMultiChild() && !pattern.hasMultiGroupChild()) {
return;
}
// Pattern.GROUP / Pattern.MULTI / Pattern.MULTI_GROUP can not match GroupExpression
@ -94,7 +91,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
// getPlan return the plan with GroupPlan as children
Plan root = groupExpression.getPlan();
if (patternArity == 0 && !(pattern instanceof SubTreePattern)) {
if (patternArity == 0) {
if (pattern.matchPredicates(root)) {
// if no children pattern, we treat all children as GROUP. e.g. Pattern.ANY.
// leaf plan will enter this branch too, e.g. logicalRelation().
@ -110,12 +107,8 @@ public class GroupExpressionMatching implements Iterable<Plan> {
List<Plan> childrenPlan = matchingChildGroup(pattern, childGroup, i);
if (childrenPlan.isEmpty()) {
if (pattern instanceof SubTreePattern) {
childrenPlan = ImmutableList.of(new GroupPlan(childGroup));
} else {
// current pattern is match but children patterns not match
return;
}
// current pattern is match but children patterns not match
return;
}
childrenPlans[i] = childrenPlan;
}
@ -134,20 +127,16 @@ public class GroupExpressionMatching implements Iterable<Plan> {
private List<Plan> matchingChildGroup(Pattern<? extends Plan> parentPattern,
Group childGroup, int childIndex) {
Pattern<? extends Plan> childPattern;
if (parentPattern instanceof SubTreePattern) {
childPattern = parentPattern;
} else {
boolean isLastPattern = childIndex + 1 >= parentPattern.arity();
int patternChildIndex = isLastPattern ? parentPattern.arity() - 1 : childIndex;
boolean isLastPattern = childIndex + 1 >= parentPattern.arity();
int patternChildIndex = isLastPattern ? parentPattern.arity() - 1 : childIndex;
childPattern = parentPattern.child(patternChildIndex);
// translate MULTI and MULTI_GROUP to ANY and GROUP
if (isLastPattern) {
if (childPattern.isMulti()) {
childPattern = Pattern.ANY;
} else if (childPattern.isMultiGroup()) {
childPattern = Pattern.GROUP;
}
childPattern = parentPattern.child(patternChildIndex);
// translate MULTI and MULTI_GROUP to ANY and GROUP
if (isLastPattern) {
if (childPattern.isMulti()) {
childPattern = Pattern.ANY;
} else if (childPattern.isMultiGroup()) {
childPattern = Pattern.GROUP;
}
}

View File

@ -45,12 +45,6 @@ public class GroupMatching {
matchingPlans.add(plan);
}
}
// Jackwener: We don't need to match physical expressions.
// for (GroupExpression groupExpression : group.getPhysicalExpressions()) {
// for (Plan plan : new GroupExpressionMatching(pattern, groupExpression)) {
// matchingPlans.add(plan);
// }
// }
}
return matchingPlans;
}

View File

@ -37,6 +37,7 @@ import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.apache.commons.lang3.StringUtils;
@ -44,7 +45,6 @@ import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Abstract class for all Expression in Nereids.
@ -247,8 +247,19 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
return collect(Slot.class::isInstance);
}
/**
* Get all the input slot ids of the expression.
* <p>
* Note that the input slots of subquery's inner plan is not included.
*/
public final Set<ExprId> getInputSlotExprIds() {
return getInputSlots().stream().map(NamedExpression::getExprId).collect(Collectors.toSet());
ImmutableSet.Builder<ExprId> result = ImmutableSet.builder();
foreach(node -> {
if (node instanceof Slot) {
result.add(((Slot) node).getExprId());
}
});
return result.build();
}
public boolean isLiteral() {

View File

@ -43,9 +43,9 @@ import org.apache.doris.statistics.Statistics;
import org.apache.doris.thrift.TRuntimeFilterType;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -113,22 +113,25 @@ public class PhysicalHashJoin<
* Return pair of left used slots and right used slots.
*/
public Pair<List<ExprId>, List<ExprId>> getHashConjunctsExprIds() {
List<ExprId> exprIds1 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size());
List<ExprId> exprIds2 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size());
int size = hashJoinConjuncts.size();
List<ExprId> exprIds1 = new ArrayList<>(size);
List<ExprId> exprIds2 = new ArrayList<>(size);
Set<ExprId> leftExprIds = left().getOutputExprIdSet();
Set<ExprId> rightExprIds = right().getOutputExprIdSet();
for (Expression expr : hashJoinConjuncts) {
expr.getInputSlotExprIds().forEach(exprId -> {
for (ExprId exprId : expr.getInputSlotExprIds()) {
if (leftExprIds.contains(exprId)) {
exprIds1.add(exprId);
} else if (rightExprIds.contains(exprId)) {
exprIds2.add(exprId);
} else {
throw new RuntimeException("Could not generate valid equal on clause slot pairs for join");
throw new RuntimeException("Invalid ExprId found: " + exprId
+ ". Cannot generate valid equal on clause slot pairs for join.");
}
});
}
}
return Pair.of(exprIds1, exprIds2);
}

View File

@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
@ -66,9 +67,10 @@ public class JoinUtils {
* check if the row count of the left child in the broadcast join is less than a threshold value.
*/
public static boolean checkBroadcastJoinStats(PhysicalHashJoin<? extends Plan, ? extends Plan> join) {
double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte();
double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit();
double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage();
SessionVariable sessionVariable = ConnectContext.get().getSessionVariable();
double memLimit = sessionVariable.getMaxExecMemByte();
double rowsLimit = sessionVariable.getBroadcastRowCountLimit();
double brMemlimit = sessionVariable.getBroadcastHashtableMemLimitPercentage();
double datasize = join.getGroupExpression().get().child(1).getStatistics().computeSize();
double rowCount = join.getGroupExpression().get().child(1).getStatistics().getRowCount();
return rowCount <= rowsLimit && datasize <= memLimit * brMemlimit;
@ -114,12 +116,12 @@ public class JoinUtils {
* @return true if the equal can be used as hash join condition
*/
public boolean isHashJoinCondition(EqualTo equalTo) {
Set<Slot> equalLeft = equalTo.left().collect(Slot.class::isInstance);
Set<Slot> equalLeft = equalTo.left().getInputSlots();
if (equalLeft.isEmpty()) {
return false;
}
Set<Slot> equalRight = equalTo.right().collect(Slot.class::isInstance);
Set<Slot> equalRight = equalTo.right().getInputSlots();
if (equalRight.isEmpty()) {
return false;
}

View File

@ -55,7 +55,7 @@ public class PlanUtils {
* normalize comparison predicate on a binary plan to its two sides are corresponding to the child's output.
*/
public static ComparisonPredicate maybeCommuteComparisonPredicate(ComparisonPredicate expression, Plan left) {
Set<Slot> slots = expression.left().collect(Slot.class::isInstance);
Set<Slot> slots = expression.left().getInputSlots();
Set<Slot> leftSlots = left.getOutputSet();
Set<Slot> buffer = Sets.newHashSet(slots);
buffer.removeAll(leftSlots);