[opt](nereids) enable runtime filter use cte as target #40815 (2.1) (#41090)

## Proposed changes
Cherry-pick of #40815 (backport to 2.1).
Issue Number: close #xxx

<!--Describe your changes.-->
This commit is contained in:
minghong
2024-09-23 22:34:03 +08:00
committed by GitHub
parent a6ef7e00e4
commit 5bcea1983d
90 changed files with 709 additions and 748 deletions

View File

@ -1226,9 +1226,6 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
multiCastPlanFragment.setOutputExprs(outputs);
context.getCteProduceFragments().put(cteId, multiCastPlanFragment);
context.getCteProduceMap().put(cteId, cteProducer);
if (context.getRuntimeTranslator().isPresent()) {
context.getRuntimeTranslator().get().getContext().getCteProduceMap().put(cteId, cteProducer);
}
context.getPlanFragments().add(multiCastPlanFragment);
return child;
}

View File

@ -19,7 +19,6 @@ package org.apache.doris.nereids.processor.post;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -28,7 +27,6 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEProducer;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation;
import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter;
@ -119,11 +117,6 @@ public class RuntimeFilterContext {
private final Map<Plan, EffectiveSrcType> effectiveSrcNodes = Maps.newHashMap();
private final Map<CTEId, PhysicalCTEProducer> cteProducerMap = Maps.newLinkedHashMap();
// cte whose runtime filter has been extracted
private final Set<CTEId> processedCTE = Sets.newHashSet();
private final SessionVariable sessionVariable;
private final FilterSizeLimits limits;
@ -160,10 +153,6 @@ public class RuntimeFilterContext {
this.limits = new FilterSizeLimits(sessionVariable);
}
/**
 * Stores {@code relations} in {@code relationsUsedByPlan} under the key {@code plan}.
 *
 * @param plan the plan node used as the key
 * @param relations the physical relations to associate with {@code plan}
 */
// NOTE(review): relationsUsedByPlan is declared outside this view; presumably a Map,
// in which case an existing entry for the same plan is overwritten — confirm.
public void setRelationsUsedByPlan(Plan plan, Set<PhysicalRelation> relations) {
relationsUsedByPlan.put(plan, relations);
}
/**
* return true, if the relation is in the subtree
*/
@ -185,14 +174,6 @@ public class RuntimeFilterContext {
return limits;
}
/**
 * Returns the map from CTE id to its {@link PhysicalCTEProducer}.
 *
 * <p>The internal {@code cteProducerMap} instance itself is returned (not a copy),
 * so entries added or removed by the caller are visible to this context.
 *
 * @return the CTE-id-to-producer map held by this context
 */
public Map<CTEId, PhysicalCTEProducer> getCteProduceMap() {
return cteProducerMap;
}
/**
 * Returns the set of CTE ids whose runtime filters have already been extracted.
 *
 * <p>The internal {@code processedCTE} set itself is returned (not a copy),
 * so callers mutate this context's state when they add to it.
 *
 * @return the set of already-processed CTE ids
 */
public Set<CTEId> getProcessedCTE() {
return processedCTE;
}
public void setTargetExprIdToFilter(ExprId id, RuntimeFilter filter) {
Preconditions.checkArgument(filter.getTargetSlots().stream().anyMatch(expr -> expr.getExprId() == id));
this.targetExprIdToFilter.computeIfAbsent(id, k -> Lists.newArrayList()).add(filter);

View File

@ -101,132 +101,131 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
@Override
public Plan processRoot(Plan plan, CascadesContext ctx) {
Plan result = plan.accept(this, ctx);
// cte rf
// try to push rf inside CTEProducer
// collect cteProducers
RuntimeFilterContext rfCtx = ctx.getRuntimeFilterContext();
int cteCount = rfCtx.getProcessedCTE().size();
if (cteCount != 0) {
Map<CTEId, Set<PhysicalCTEConsumer>> cteIdToConsumersWithRF = Maps.newHashMap();
Map<CTEId, List<RuntimeFilter>> cteToRFsMap = Maps.newHashMap();
Map<PhysicalCTEConsumer, Set<RuntimeFilter>> consumerToRFs = Maps.newHashMap();
Map<PhysicalCTEConsumer, Set<Expression>> consumerToSrcExpression = Maps.newHashMap();
List<RuntimeFilter> allRFs = rfCtx.getNereidsRuntimeFilter();
for (RuntimeFilter rf : allRFs) {
for (PhysicalRelation rel : rf.getTargetScans()) {
if (rel instanceof PhysicalCTEConsumer) {
PhysicalCTEConsumer consumer = (PhysicalCTEConsumer) rel;
CTEId cteId = consumer.getCteId();
cteToRFsMap.computeIfAbsent(cteId, key -> Lists.newArrayList()).add(rf);
cteIdToConsumersWithRF.computeIfAbsent(cteId, key -> Sets.newHashSet()).add(consumer);
consumerToRFs.computeIfAbsent(consumer, key -> Sets.newHashSet()).add(rf);
consumerToSrcExpression.computeIfAbsent(consumer, key -> Sets.newHashSet())
.add(rf.getSrcExpr());
}
Map<CTEId, PhysicalCTEProducer> cteProducerMap = plan.collect(PhysicalCTEProducer.class::isInstance)
.stream().collect(Collectors.toMap(p -> ((PhysicalCTEProducer) p).getCteId(),
p -> (PhysicalCTEProducer) p));
// collect cteConsumers which are RF targets
Map<CTEId, Set<PhysicalCTEConsumer>> cteIdToConsumersWithRF = Maps.newHashMap();
Map<PhysicalCTEConsumer, Set<RuntimeFilter>> consumerToRFs = Maps.newHashMap();
Map<PhysicalCTEConsumer, Set<Expression>> consumerToSrcExpression = Maps.newHashMap();
List<RuntimeFilter> allRFs = rfCtx.getNereidsRuntimeFilter();
for (RuntimeFilter rf : allRFs) {
for (PhysicalRelation rel : rf.getTargetScans()) {
if (rel instanceof PhysicalCTEConsumer) {
PhysicalCTEConsumer consumer = (PhysicalCTEConsumer) rel;
CTEId cteId = consumer.getCteId();
cteIdToConsumersWithRF.computeIfAbsent(cteId, key -> Sets.newHashSet()).add(consumer);
consumerToRFs.computeIfAbsent(consumer, key -> Sets.newHashSet()).add(rf);
consumerToSrcExpression.computeIfAbsent(consumer, key -> Sets.newHashSet())
.add(rf.getSrcExpr());
}
}
for (CTEId cteId : rfCtx.getCteProduceMap().keySet()) {
// if any consumer does not have RF, RF cannot be pushed down.
// cteIdToConsumersWithRF.get(cteId).size() can not be 1, o.w. this cte will be inlined.
if (cteIdToConsumersWithRF.get(cteId) != null
&& ctx.getCteIdToConsumers().get(cteId).size() == cteIdToConsumersWithRF.get(cteId).size()
&& cteIdToConsumersWithRF.get(cteId).size() >= 2) {
// check if there is a common srcExpr among all the consumers
Set<PhysicalCTEConsumer> consumers = cteIdToConsumersWithRF.get(cteId);
PhysicalCTEConsumer consumer0 = consumers.iterator().next();
Set<Expression> candidateSrcExpressions = consumerToSrcExpression.get(consumer0);
for (PhysicalCTEConsumer currentConsumer : consumers) {
Set<Expression> srcExpressionsOnCurrentConsumer = consumerToSrcExpression.get(currentConsumer);
candidateSrcExpressions.retainAll(srcExpressionsOnCurrentConsumer);
if (candidateSrcExpressions.isEmpty()) {
}
for (CTEId cteId : cteIdToConsumersWithRF.keySet()) {
// if any consumer does not have RF, RF cannot be pushed down.
// cteIdToConsumersWithRF.get(cteId).size() can not be 1, o.w. this cte will be inlined.
if (ctx.getCteIdToConsumers().get(cteId).size() == cteIdToConsumersWithRF.get(cteId).size()
&& cteIdToConsumersWithRF.get(cteId).size() >= 2) {
// check if there is a common srcExpr among all the consumers
Set<PhysicalCTEConsumer> consumers = cteIdToConsumersWithRF.get(cteId);
PhysicalCTEConsumer consumer0 = consumers.iterator().next();
Set<Expression> candidateSrcExpressions = consumerToSrcExpression.get(consumer0);
for (PhysicalCTEConsumer currentConsumer : consumers) {
Set<Expression> srcExpressionsOnCurrentConsumer = consumerToSrcExpression.get(currentConsumer);
candidateSrcExpressions.retainAll(srcExpressionsOnCurrentConsumer);
if (candidateSrcExpressions.isEmpty()) {
break;
}
}
if (!candidateSrcExpressions.isEmpty()) {
// find RFs to push down
for (Expression srcExpr : candidateSrcExpressions) {
List<RuntimeFilter> rfsToPushDown = Lists.newArrayList();
for (PhysicalCTEConsumer consumer : cteIdToConsumersWithRF.get(cteId)) {
for (RuntimeFilter rf : consumerToRFs.get(consumer)) {
if (rf.getSrcExpr().equals(srcExpr)) {
rfsToPushDown.add(rf);
}
}
}
if (rfsToPushDown.isEmpty()) {
break;
}
}
if (!candidateSrcExpressions.isEmpty()) {
// find RFs to push down
for (Expression srcExpr : candidateSrcExpressions) {
List<RuntimeFilter> rfsToPushDown = Lists.newArrayList();
for (PhysicalCTEConsumer consumer : cteIdToConsumersWithRF.get(cteId)) {
for (RuntimeFilter rf : consumerToRFs.get(consumer)) {
if (rf.getSrcExpr().equals(srcExpr)) {
rfsToPushDown.add(rf);
}
// the most right deep buildNode from rfsToPushDown is used as buildNode for pushDown rf
// since the srcExpr are the same, all buildNodes of rfToPushDown are in the same tree path
// the longest ancestors means its corresponding rf build node is the most right deep one.
List<RuntimeFilter> rightDeepRfs = Lists.newArrayList();
List<Plan> rightDeepAncestors = rfsToPushDown.get(0).getBuilderNode().getAncestors();
int rightDeepAncestorsSize = rightDeepAncestors.size();
RuntimeFilter leftTop = rfsToPushDown.get(0);
int leftTopAncestorsSize = rightDeepAncestorsSize;
for (RuntimeFilter rf : rfsToPushDown) {
List<Plan> ancestors = rf.getBuilderNode().getAncestors();
int currentAncestorsSize = ancestors.size();
if (currentAncestorsSize >= rightDeepAncestorsSize) {
if (currentAncestorsSize == rightDeepAncestorsSize) {
rightDeepRfs.add(rf);
} else {
rightDeepAncestorsSize = currentAncestorsSize;
rightDeepAncestors = ancestors;
rightDeepRfs.clear();
rightDeepRfs.add(rf);
}
}
if (rfsToPushDown.isEmpty()) {
if (currentAncestorsSize < leftTopAncestorsSize) {
leftTopAncestorsSize = currentAncestorsSize;
leftTop = rf;
}
}
Preconditions.checkArgument(rightDeepAncestors.contains(leftTop.getBuilderNode()));
// check nodes between right deep and left top are SPJ and not denied join and not mark join
boolean valid = true;
for (Plan cursor : rightDeepAncestors) {
if (cursor.equals(leftTop.getBuilderNode())) {
break;
}
// the most right deep buildNode from rfsToPushDown is used as buildNode for pushDown rf
// since the srcExpr are the same, all buildNodes of rfToPushDown are in the same tree path
// the longest ancestors means its corresponding rf build node is the most right deep one.
List<RuntimeFilter> rightDeepRfs = Lists.newArrayList();
List<Plan> rightDeepAncestors = rfsToPushDown.get(0).getBuilderNode().getAncestors();
int rightDeepAncestorsSize = rightDeepAncestors.size();
RuntimeFilter leftTop = rfsToPushDown.get(0);
int leftTopAncestorsSize = rightDeepAncestorsSize;
for (RuntimeFilter rf : rfsToPushDown) {
List<Plan> ancestors = rf.getBuilderNode().getAncestors();
int currentAncestorsSize = ancestors.size();
if (currentAncestorsSize >= rightDeepAncestorsSize) {
if (currentAncestorsSize == rightDeepAncestorsSize) {
rightDeepRfs.add(rf);
} else {
rightDeepAncestorsSize = currentAncestorsSize;
rightDeepAncestors = ancestors;
rightDeepRfs.clear();
rightDeepRfs.add(rf);
}
}
if (currentAncestorsSize < leftTopAncestorsSize) {
leftTopAncestorsSize = currentAncestorsSize;
leftTop = rf;
}
// valid = valid && SPJ_PLAN.contains(cursor.getClass());
if (cursor instanceof AbstractPhysicalJoin) {
AbstractPhysicalJoin cursorJoin = (AbstractPhysicalJoin) cursor;
valid = (!RuntimeFilterGenerator.DENIED_JOIN_TYPES
.contains(cursorJoin.getJoinType())
|| cursorJoin.isMarkJoin()) && valid;
}
Preconditions.checkArgument(rightDeepAncestors.contains(leftTop.getBuilderNode()));
// check nodes between right deep and left top are SPJ and not denied join and not mark join
boolean valid = true;
for (Plan cursor : rightDeepAncestors) {
if (cursor.equals(leftTop.getBuilderNode())) {
break;
}
// valid = valid && SPJ_PLAN.contains(cursor.getClass());
if (cursor instanceof AbstractPhysicalJoin) {
AbstractPhysicalJoin cursorJoin = (AbstractPhysicalJoin) cursor;
valid = (!RuntimeFilterGenerator.DENIED_JOIN_TYPES
.contains(cursorJoin.getJoinType())
|| cursorJoin.isMarkJoin()) && valid;
}
if (!valid) {
break;
}
}
if (!valid) {
break;
}
}
for (RuntimeFilter rfToPush : rightDeepRfs) {
Expression rightDeepTargetExpressionOnCTE = null;
int targetCount = rfToPush.getTargetExpressions().size();
for (int i = 0; i < targetCount; i++) {
PhysicalRelation rel = rfToPush.getTargetScans().get(i);
if (rel instanceof PhysicalCTEConsumer
&& ((PhysicalCTEConsumer) rel).getCteId().equals(cteId)) {
rightDeepTargetExpressionOnCTE = rfToPush.getTargetExpressions().get(i);
break;
}
if (!valid) {
break;
}
for (RuntimeFilter rfToPush : rightDeepRfs) {
Expression rightDeepTargetExpressionOnCTE = null;
int targetCount = rfToPush.getTargetExpressions().size();
for (int i = 0; i < targetCount; i++) {
PhysicalRelation rel = rfToPush.getTargetScans().get(i);
if (rel instanceof PhysicalCTEConsumer
&& ((PhysicalCTEConsumer) rel).getCteId().equals(cteId)) {
rightDeepTargetExpressionOnCTE = rfToPush.getTargetExpressions().get(i);
break;
}
}
boolean pushedDown = doPushDownIntoCTEProducerInternal(
boolean pushedDown = doPushDownIntoCTEProducerInternal(
rfToPush,
rightDeepTargetExpressionOnCTE,
rfCtx,
cteProducerMap.get(cteId)
);
if (pushedDown) {
rfCtx.removeFilter(
rfToPush,
rightDeepTargetExpressionOnCTE,
rfCtx,
rfCtx.getCteProduceMap().get(cteId)
);
if (pushedDown) {
rfCtx.removeFilter(
rfToPush,
rightDeepTargetExpressionOnCTE.getInputSlotExprIds().iterator().next());
}
rightDeepTargetExpressionOnCTE.getInputSlotExprIds().iterator().next());
}
}
}
@ -265,8 +264,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
.filter(type -> (type.getValue() & ctx.getSessionVariable().getRuntimeFilterType()) > 0)
.collect(Collectors.toList());
List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts().stream().collect(Collectors.toList());
boolean buildSideContainsConsumer = hasCTEConsumerDescendant((PhysicalPlan) join.right());
List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts();
for (int i = 0; i < hashJoinConjuncts.size(); i++) {
EqualPredicate equalTo = JoinUtils.swapEqualToForChildrenOrder(
(EqualPredicate) hashJoinConjuncts.get(i), join.left().getOutputSet());
@ -278,8 +276,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
long buildSideNdv = getBuildSideNdv(join, equalTo);
Pair<PhysicalRelation, Slot> pair = ctx.getAliasTransferMap().get(equalTo.right());
// CteConsumer is not allowed to generate RF in order to avoid RF cycle.
if ((pair == null && buildSideContainsConsumer)
|| (pair != null && pair.first instanceof PhysicalCTEConsumer)) {
if (pair == null) {
continue;
}
if (equalTo.left().getInputSlots().size() == 1) {
@ -306,20 +303,6 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
return scan;
}
/**
 * Visits a CTE producer: registers it in the runtime filter context and, the first
 * time a given CTE id is seen, runs this visitor over the producer's child plan.
 *
 * @param producer the CTE producer node being visited
 * @param context the cascades context that supplies the runtime filter context
 * @return the same {@code producer} instance that was passed in
 */
@Override
public PhysicalCTEProducer<? extends Plan> visitPhysicalCTEProducer(PhysicalCTEProducer<? extends Plan> producer,
CascadesContext context) {
CTEId cteId = producer.getCteId();
// Always (re-)register the producer under its CTE id, even if it was processed before.
context.getRuntimeFilterContext().getCteProduceMap().put(cteId, producer);
Set<CTEId> processedCTE = context.getRuntimeFilterContext().getProcessedCTE();
// Visit the producer's child at most once per CTE id.
if (!processedCTE.contains(cteId)) {
PhysicalPlan inputPlanNode = (PhysicalPlan) producer.child(0);
inputPlanNode.accept(this, context);
// The id is marked as processed only after the child visit has returned.
processedCTE.add(cteId);
}
return producer;
}
private void generateBitMapRuntimeFilterForNLJ(PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> join,
RuntimeFilterContext ctx) {
if (join.getJoinType() != JoinType.LEFT_SEMI_JOIN && join.getJoinType() != JoinType.CROSS_JOIN) {