[opt](Nereids)Join cluster connectivity (#27833)
* estimation join stats by connectivity
This commit is contained in:
@ -88,4 +88,8 @@ public class PlanContext {
|
||||
public List<Statistics> getChildrenStatistics() {
|
||||
return childrenStats;
|
||||
}
|
||||
|
||||
public StatementContext getStatementContext() {
|
||||
return connectContext.getStatementContext();
|
||||
}
|
||||
}
|
||||
|
||||
@ -43,6 +43,7 @@ import com.google.common.collect.Maps;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@ -88,6 +89,9 @@ public class StatementContext {
|
||||
private final Map<CTEId, LogicalPlan> rewrittenCteConsumer = new HashMap<>();
|
||||
private final Set<String> viewDdlSqlSet = Sets.newHashSet();
|
||||
|
||||
// collect all hash join conditions to compute node connectivity in join graph
|
||||
private final List<Expression> joinFilters = new ArrayList<>();
|
||||
|
||||
private final List<Hint> hints = new ArrayList<>();
|
||||
|
||||
public StatementContext() {
|
||||
@ -242,4 +246,12 @@ public class StatementContext {
|
||||
public List<Hint> getHints() {
|
||||
return ImmutableList.copyOf(hints);
|
||||
}
|
||||
|
||||
public List<Expression> getJoinFilters() {
|
||||
return joinFilters;
|
||||
}
|
||||
|
||||
public void addJoinFilters(Collection<Expression> newJoinFilters) {
|
||||
this.joinFilters.addAll(newJoinFilters);
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,6 +22,7 @@ import org.apache.doris.nereids.properties.DistributionSpec;
|
||||
import org.apache.doris.nereids.properties.DistributionSpecGather;
|
||||
import org.apache.doris.nereids.properties.DistributionSpecHash;
|
||||
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeOlapScan;
|
||||
@ -48,6 +49,8 @@ import org.apache.doris.statistics.Statistics;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import java.util.Collections;
|
||||
|
||||
class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
|
||||
|
||||
// for a join, skew = leftRowCount/rightRowCount
|
||||
@ -262,6 +265,17 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
|
||||
|
||||
double leftRowCount = probeStats.getRowCount();
|
||||
double rightRowCount = buildStats.getRowCount();
|
||||
if (leftRowCount == rightRowCount
|
||||
&& physicalHashJoin.getGroupExpression().isPresent()
|
||||
&& physicalHashJoin.getGroupExpression().get().getOwnerGroup() != null
|
||||
&& !physicalHashJoin.getGroupExpression().get().getOwnerGroup().isStatsReliable()) {
|
||||
int leftConnectivity = computeConnectivity(physicalHashJoin.left(), context);
|
||||
int rightConnectivity = computeConnectivity(physicalHashJoin.right(), context);
|
||||
if (rightConnectivity < leftConnectivity) {
|
||||
leftRowCount += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
pattern1: L join1 (Agg1() join2 Agg2())
|
||||
result number of join2 may much less than Agg1.
|
||||
@ -310,6 +324,20 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
|
||||
);
|
||||
}
|
||||
|
||||
/*
|
||||
in a join cluster graph, if a node has higher connectivity, it is more likely to be reduced
|
||||
by runtime filters, and it is also more likely to produce effective runtime filters.
|
||||
Thus, we prefer to put the node with higher connectivity on the join right side.
|
||||
*/
|
||||
private int computeConnectivity(
|
||||
Plan plan, PlanContext context) {
|
||||
int connectCount = 0;
|
||||
for (Expression expr : context.getStatementContext().getJoinFilters()) {
|
||||
connectCount += Collections.disjoint(expr.getInputSlots(), plan.getOutputSet()) ? 0 : 1;
|
||||
}
|
||||
return connectCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Cost visitPhysicalNestedLoopJoin(
|
||||
PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> nestedLoopJoin,
|
||||
|
||||
@ -93,7 +93,7 @@ public class JoinCommute extends OneExplorationRuleFactory {
|
||||
if (swapType == SwapType.LEFT_ZIG_ZAG) {
|
||||
double leftRows = join.left().getGroup().getStatistics().getRowCount();
|
||||
double rightRows = join.right().getGroup().getStatistics().getRowCount();
|
||||
return leftRows < rightRows && isZigZagJoin(join);
|
||||
return leftRows <= rightRows && isZigZagJoin(join);
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
@ -87,6 +87,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
|
||||
Plan plan = joinToMultiJoin(filter, planToHintType);
|
||||
Preconditions.checkState(plan instanceof MultiJoin);
|
||||
MultiJoin multiJoin = (MultiJoin) plan;
|
||||
ctx.statementContext.addJoinFilters(multiJoin.getJoinFilter());
|
||||
ctx.statementContext.setMaxNAryInnerJoin(multiJoin.children().size());
|
||||
Plan after = multiJoinToJoin(multiJoin, planToHintType);
|
||||
return after;
|
||||
|
||||
Reference in New Issue
Block a user