[enhancement](Nereids): optimize GroupExpressionMatching (#26130)
1. A pattern can't be a SubTreePattern in the CBO phase, so the SubTreePattern special cases are removed from GroupExpressionMatching. 2. Optimize getInputSlotExprIds().
This commit is contained in:
@ -20,7 +20,6 @@ package org.apache.doris.nereids.pattern;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.properties.LogicalProperties;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -70,21 +69,19 @@ public class GroupExpressionMatching implements Iterable<Plan> {
|
||||
|
||||
int childrenGroupArity = groupExpression.arity();
|
||||
int patternArity = pattern.arity();
|
||||
if (!(pattern instanceof SubTreePattern)) {
|
||||
// (logicalFilter(), multi()) match (logicalFilter()),
|
||||
// but (logicalFilter(), logicalFilter(), multi()) not match (logicalFilter())
|
||||
boolean extraMulti = patternArity == childrenGroupArity + 1
|
||||
&& (pattern.hasMultiChild() || pattern.hasMultiGroupChild());
|
||||
if (patternArity > childrenGroupArity && !extraMulti) {
|
||||
return;
|
||||
}
|
||||
// (logicalFilter(), multi()) match (logicalFilter()),
|
||||
// but (logicalFilter(), logicalFilter(), multi()) not match (logicalFilter())
|
||||
boolean extraMulti = patternArity == childrenGroupArity + 1
|
||||
&& (pattern.hasMultiChild() || pattern.hasMultiGroupChild());
|
||||
if (patternArity > childrenGroupArity && !extraMulti) {
|
||||
return;
|
||||
}
|
||||
|
||||
// (multi()) match (logicalFilter(), logicalFilter()),
|
||||
// but (logicalFilter()) not match (logicalFilter(), logicalFilter())
|
||||
if (!pattern.isAny() && patternArity < childrenGroupArity
|
||||
&& !pattern.hasMultiChild() && !pattern.hasMultiGroupChild()) {
|
||||
return;
|
||||
}
|
||||
// (multi()) match (logicalFilter(), logicalFilter()),
|
||||
// but (logicalFilter()) not match (logicalFilter(), logicalFilter())
|
||||
if (!pattern.isAny() && patternArity < childrenGroupArity
|
||||
&& !pattern.hasMultiChild() && !pattern.hasMultiGroupChild()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Pattern.GROUP / Pattern.MULTI / Pattern.MULTI_GROUP can not match GroupExpression
|
||||
@ -94,7 +91,7 @@ public class GroupExpressionMatching implements Iterable<Plan> {
|
||||
|
||||
// getPlan return the plan with GroupPlan as children
|
||||
Plan root = groupExpression.getPlan();
|
||||
if (patternArity == 0 && !(pattern instanceof SubTreePattern)) {
|
||||
if (patternArity == 0) {
|
||||
if (pattern.matchPredicates(root)) {
|
||||
// if no children pattern, we treat all children as GROUP. e.g. Pattern.ANY.
|
||||
// leaf plan will enter this branch too, e.g. logicalRelation().
|
||||
@ -110,12 +107,8 @@ public class GroupExpressionMatching implements Iterable<Plan> {
|
||||
List<Plan> childrenPlan = matchingChildGroup(pattern, childGroup, i);
|
||||
|
||||
if (childrenPlan.isEmpty()) {
|
||||
if (pattern instanceof SubTreePattern) {
|
||||
childrenPlan = ImmutableList.of(new GroupPlan(childGroup));
|
||||
} else {
|
||||
// current pattern is match but children patterns not match
|
||||
return;
|
||||
}
|
||||
// current pattern is match but children patterns not match
|
||||
return;
|
||||
}
|
||||
childrenPlans[i] = childrenPlan;
|
||||
}
|
||||
@ -134,20 +127,16 @@ public class GroupExpressionMatching implements Iterable<Plan> {
|
||||
private List<Plan> matchingChildGroup(Pattern<? extends Plan> parentPattern,
|
||||
Group childGroup, int childIndex) {
|
||||
Pattern<? extends Plan> childPattern;
|
||||
if (parentPattern instanceof SubTreePattern) {
|
||||
childPattern = parentPattern;
|
||||
} else {
|
||||
boolean isLastPattern = childIndex + 1 >= parentPattern.arity();
|
||||
int patternChildIndex = isLastPattern ? parentPattern.arity() - 1 : childIndex;
|
||||
boolean isLastPattern = childIndex + 1 >= parentPattern.arity();
|
||||
int patternChildIndex = isLastPattern ? parentPattern.arity() - 1 : childIndex;
|
||||
|
||||
childPattern = parentPattern.child(patternChildIndex);
|
||||
// translate MULTI and MULTI_GROUP to ANY and GROUP
|
||||
if (isLastPattern) {
|
||||
if (childPattern.isMulti()) {
|
||||
childPattern = Pattern.ANY;
|
||||
} else if (childPattern.isMultiGroup()) {
|
||||
childPattern = Pattern.GROUP;
|
||||
}
|
||||
childPattern = parentPattern.child(patternChildIndex);
|
||||
// translate MULTI and MULTI_GROUP to ANY and GROUP
|
||||
if (isLastPattern) {
|
||||
if (childPattern.isMulti()) {
|
||||
childPattern = Pattern.ANY;
|
||||
} else if (childPattern.isMultiGroup()) {
|
||||
childPattern = Pattern.GROUP;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -45,12 +45,6 @@ public class GroupMatching {
|
||||
matchingPlans.add(plan);
|
||||
}
|
||||
}
|
||||
// Jackwener: We don't need to match physical expressions.
|
||||
// for (GroupExpression groupExpression : group.getPhysicalExpressions()) {
|
||||
// for (Plan plan : new GroupExpressionMatching(pattern, groupExpression)) {
|
||||
// matchingPlans.add(plan);
|
||||
// }
|
||||
// }
|
||||
}
|
||||
return matchingPlans;
|
||||
}
|
||||
|
||||
@ -37,6 +37,7 @@ import org.apache.doris.nereids.types.coercion.AnyDataType;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@ -44,7 +45,6 @@ import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Abstract class for all Expression in Nereids.
|
||||
@ -247,8 +247,19 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
|
||||
return collect(Slot.class::isInstance);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all the input slot ids of the expression.
|
||||
* <p>
|
||||
* Note that the input slots of subquery's inner plan is not included.
|
||||
*/
|
||||
public final Set<ExprId> getInputSlotExprIds() {
|
||||
return getInputSlots().stream().map(NamedExpression::getExprId).collect(Collectors.toSet());
|
||||
ImmutableSet.Builder<ExprId> result = ImmutableSet.builder();
|
||||
foreach(node -> {
|
||||
if (node instanceof Slot) {
|
||||
result.add(((Slot) node).getExprId());
|
||||
}
|
||||
});
|
||||
return result.build();
|
||||
}
|
||||
|
||||
public boolean isLiteral() {
|
||||
|
||||
@ -43,9 +43,9 @@ import org.apache.doris.statistics.Statistics;
|
||||
import org.apache.doris.thrift.TRuntimeFilterType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
@ -113,22 +113,25 @@ public class PhysicalHashJoin<
|
||||
* Return pair of left used slots and right used slots.
|
||||
*/
|
||||
public Pair<List<ExprId>, List<ExprId>> getHashConjunctsExprIds() {
|
||||
List<ExprId> exprIds1 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size());
|
||||
List<ExprId> exprIds2 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size());
|
||||
int size = hashJoinConjuncts.size();
|
||||
|
||||
List<ExprId> exprIds1 = new ArrayList<>(size);
|
||||
List<ExprId> exprIds2 = new ArrayList<>(size);
|
||||
|
||||
Set<ExprId> leftExprIds = left().getOutputExprIdSet();
|
||||
Set<ExprId> rightExprIds = right().getOutputExprIdSet();
|
||||
|
||||
for (Expression expr : hashJoinConjuncts) {
|
||||
expr.getInputSlotExprIds().forEach(exprId -> {
|
||||
for (ExprId exprId : expr.getInputSlotExprIds()) {
|
||||
if (leftExprIds.contains(exprId)) {
|
||||
exprIds1.add(exprId);
|
||||
} else if (rightExprIds.contains(exprId)) {
|
||||
exprIds2.add(exprId);
|
||||
} else {
|
||||
throw new RuntimeException("Could not generate valid equal on clause slot pairs for join");
|
||||
throw new RuntimeException("Invalid ExprId found: " + exprId
|
||||
+ ". Cannot generate valid equal on clause slot pairs for join.");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
return Pair.of(exprIds1, exprIds2);
|
||||
}
|
||||
|
||||
@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.qe.SessionVariable;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
@ -66,9 +67,10 @@ public class JoinUtils {
|
||||
* check if the row count of the left child in the broadcast join is less than a threshold value.
|
||||
*/
|
||||
public static boolean checkBroadcastJoinStats(PhysicalHashJoin<? extends Plan, ? extends Plan> join) {
|
||||
double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte();
|
||||
double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit();
|
||||
double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage();
|
||||
SessionVariable sessionVariable = ConnectContext.get().getSessionVariable();
|
||||
double memLimit = sessionVariable.getMaxExecMemByte();
|
||||
double rowsLimit = sessionVariable.getBroadcastRowCountLimit();
|
||||
double brMemlimit = sessionVariable.getBroadcastHashtableMemLimitPercentage();
|
||||
double datasize = join.getGroupExpression().get().child(1).getStatistics().computeSize();
|
||||
double rowCount = join.getGroupExpression().get().child(1).getStatistics().getRowCount();
|
||||
return rowCount <= rowsLimit && datasize <= memLimit * brMemlimit;
|
||||
@ -114,12 +116,12 @@ public class JoinUtils {
|
||||
* @return true if the equal can be used as hash join condition
|
||||
*/
|
||||
public boolean isHashJoinCondition(EqualTo equalTo) {
|
||||
Set<Slot> equalLeft = equalTo.left().collect(Slot.class::isInstance);
|
||||
Set<Slot> equalLeft = equalTo.left().getInputSlots();
|
||||
if (equalLeft.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Set<Slot> equalRight = equalTo.right().collect(Slot.class::isInstance);
|
||||
Set<Slot> equalRight = equalTo.right().getInputSlots();
|
||||
if (equalRight.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -55,7 +55,7 @@ public class PlanUtils {
|
||||
* normalize comparison predicate on a binary plan to its two sides are corresponding to the child's output.
|
||||
*/
|
||||
public static ComparisonPredicate maybeCommuteComparisonPredicate(ComparisonPredicate expression, Plan left) {
|
||||
Set<Slot> slots = expression.left().collect(Slot.class::isInstance);
|
||||
Set<Slot> slots = expression.left().getInputSlots();
|
||||
Set<Slot> leftSlots = left.getOutputSet();
|
||||
Set<Slot> buffer = Sets.newHashSet(slots);
|
||||
buffer.removeAll(leftSlots);
|
||||
|
||||
Reference in New Issue
Block a user