[fix](nereids) push more than one runtime filter into CTE (#30901)

* push runtime filters (RF) into the CTE producer; used by TPC-DS query 95
This commit is contained in:
minghong
2024-02-21 09:55:30 +08:00
committed by yiguolei
parent c734e79d14
commit cd7230885f
7 changed files with 58 additions and 59 deletions

View File

@ -1284,8 +1284,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
.forEach(s -> rightChildOutputMap.put(s.getExprId(), s));
// translate runtime filter
context.getRuntimeTranslator().ifPresent(runtimeFilterTranslator -> runtimeFilterTranslator
.getRuntimeFilterOfHashJoinNode(physicalHashJoin)
context.getRuntimeTranslator().ifPresent(runtimeFilterTranslator -> physicalHashJoin.getRuntimeFilters()
.forEach(filter -> runtimeFilterTranslator.createLegacyRuntimeFilter(filter, hashJoinNode, context)));
// make intermediate tuple
@ -1484,8 +1483,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
}
// translate runtime filter
context.getRuntimeTranslator().ifPresent(runtimeFilterTranslator -> {
Set<RuntimeFilter> filters = runtimeFilterTranslator
.getRuntimeFilterOfHashJoinNode(nestedLoopJoin);
List<RuntimeFilter> filters = nestedLoopJoin.getRuntimeFilters();
filters.forEach(filter -> runtimeFilterTranslator
.createLegacyRuntimeFilter(filter, nestedLoopJoinNode, context));
if (filters.stream().anyMatch(filter -> filter.getType() == TRuntimeFilterType.BITMAP)) {

View File

@ -28,7 +28,6 @@ import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin;
import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter;
import org.apache.doris.planner.CTEScanNode;
import org.apache.doris.planner.DataStreamSink;
@ -47,7 +46,6 @@ import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* translate runtime filter
@ -58,11 +56,6 @@ public class RuntimeFilterTranslator {
public RuntimeFilterTranslator(RuntimeFilterContext context) {
this.context = context;
context.generatePhysicalHashJoinToRuntimeFilter();
}
public Set<RuntimeFilter> getRuntimeFilterOfHashJoinNode(AbstractPhysicalJoin join) {
return context.getRuntimeFilterOnHashJoinNode(join);
}
public RuntimeFilterContext getContext() {

View File

@ -110,8 +110,6 @@ public class RuntimeFilterContext {
// exprId to olap scan node slotRef because the slotRef will be changed when translating.
private final Map<ExprId, SlotRef> exprIdToOlapScanNodeSlotRef = Maps.newHashMap();
private final Map<AbstractPhysicalJoin, Set<RuntimeFilter>> runtimeFilterOnHashJoinNode = Maps.newHashMap();
// alias -> alias's child. If an existing key is itself an alias's child, its entry is re-keyed:
// given Alias(A) = B, the map holds B -> A; on encountering Alias(B) -> C, the entry becomes C -> A.
// see the disjoint-set data structure for the details of this processing.
@ -191,19 +189,31 @@ public class RuntimeFilterContext {
public void removeFilter(ExprId targetId, PhysicalHashJoin builderNode) {
List<RuntimeFilter> filters = targetExprIdToFilter.get(targetId);
if (filters != null) {
Iterator<RuntimeFilter> iter = filters.iterator();
while (iter.hasNext()) {
RuntimeFilter rf = iter.next();
Iterator<RuntimeFilter> filterIter = filters.iterator();
while (filterIter.hasNext()) {
RuntimeFilter rf = filterIter.next();
if (rf.getBuilderNode().equals(builderNode)) {
builderNode.getRuntimeFilters().remove(rf);
for (int i = 0; i < rf.getTargetSlots().size(); i++) {
Slot targetSlot = rf.getTargetSlots().get(i);
Iterator<Slot> targetSlotIter = rf.getTargetSlots().listIterator();
Iterator<PhysicalRelation> targetScanIter = rf.getTargetScans().iterator();
Iterator<Expression> targetExpressionIter = rf.getTargetExpressions().iterator();
Slot targetSlot;
PhysicalRelation targetScan;
while (targetScanIter.hasNext() && targetSlotIter.hasNext() && targetExpressionIter.hasNext()) {
targetExpressionIter.next();
targetScan = targetScanIter.next();
targetSlot = targetSlotIter.next();
if (targetSlot.getExprId().equals(targetId)) {
rf.getTargetScans().get(i).removeAppliedRuntimeFilter(rf);
targetScan.removeAppliedRuntimeFilter(rf);
targetExpressionIter.remove();
targetScanIter.remove();
targetSlotIter.remove();
}
}
iter.remove();
prunedRF.add(rf);
if (rf.getTargetSlots().isEmpty()) {
builderNode.getRuntimeFilters().remove(rf);
filterIter.remove();
prunedRF.add(rf);
}
}
}
}
@ -255,15 +265,6 @@ public class RuntimeFilterContext {
return scanNodeOfLegacyRuntimeFilterTarget;
}
public Set<RuntimeFilter> getRuntimeFilterOnHashJoinNode(AbstractPhysicalJoin join) {
return runtimeFilterOnHashJoinNode.getOrDefault(join, Collections.emptySet());
}
public void generatePhysicalHashJoinToRuntimeFilter() {
targetExprIdToFilter.values().forEach(filters -> filters.forEach(filter -> runtimeFilterOnHashJoinNode
.computeIfAbsent(filter.getBuilderNode(), k -> Sets.newHashSet()).add(filter)));
}
public Map<ExprId, List<RuntimeFilter>> getTargetExprIdToFilter() {
return targetExprIdToFilter;
}

View File

@ -158,7 +158,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
// the right-deepest buildNode in rfsToPushDown is used as the buildNode for the pushed-down rf.
// since the srcExprs are the same, all buildNodes of rfsToPushDown are on the same tree path,
// so the longest ancestor list means the corresponding rf build node is the right-deepest one.
RuntimeFilter rightDeep = rfsToPushDown.get(0);
List<RuntimeFilter> rightDeepRfs = Lists.newArrayList();
List<Plan> rightDeepAncestors = rfsToPushDown.get(0).getBuilderNode().getAncestors();
int rightDeepAncestorsSize = rightDeepAncestors.size();
RuntimeFilter leftTop = rfsToPushDown.get(0);
@ -166,10 +166,15 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
for (RuntimeFilter rf : rfsToPushDown) {
List<Plan> ancestors = rf.getBuilderNode().getAncestors();
int currentAncestorsSize = ancestors.size();
if (currentAncestorsSize > rightDeepAncestorsSize) {
rightDeep = rf;
rightDeepAncestorsSize = currentAncestorsSize;
rightDeepAncestors = ancestors;
if (currentAncestorsSize >= rightDeepAncestorsSize) {
if (currentAncestorsSize == rightDeepAncestorsSize) {
rightDeepRfs.add(rf);
} else {
rightDeepAncestorsSize = currentAncestorsSize;
rightDeepAncestors = ancestors;
rightDeepRfs.clear();
rightDeepRfs.add(rf);
}
}
if (currentAncestorsSize < leftTopAncestorsSize) {
leftTopAncestorsSize = currentAncestorsSize;
@ -187,7 +192,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
if (cursor instanceof AbstractPhysicalJoin) {
AbstractPhysicalJoin cursorJoin = (AbstractPhysicalJoin) cursor;
valid = (!RuntimeFilterGenerator.DENIED_JOIN_TYPES
.contains(cursorJoin.getJoinType())
.contains(cursorJoin.getJoinType())
|| cursorJoin.isMarkJoin()) && valid;
}
if (!valid) {
@ -199,27 +204,29 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
break;
}
Expression rightDeepTargetExpressionOnCTE = null;
int targetCount = rightDeep.getTargetExpressions().size();
for (int i = 0; i < targetCount; i++) {
PhysicalRelation rel = rightDeep.getTargetScans().get(i);
if (rel instanceof PhysicalCTEConsumer
&& ((PhysicalCTEConsumer) rel).getCteId().equals(cteId)) {
rightDeepTargetExpressionOnCTE = rightDeep.getTargetExpressions().get(i);
break;
for (RuntimeFilter rfToPush : rightDeepRfs) {
Expression rightDeepTargetExpressionOnCTE = null;
int targetCount = rfToPush.getTargetExpressions().size();
for (int i = 0; i < targetCount; i++) {
PhysicalRelation rel = rfToPush.getTargetScans().get(i);
if (rel instanceof PhysicalCTEConsumer
&& ((PhysicalCTEConsumer) rel).getCteId().equals(cteId)) {
rightDeepTargetExpressionOnCTE = rfToPush.getTargetExpressions().get(i);
break;
}
}
}
boolean pushedDown = doPushDownIntoCTEProducerInternal(
rightDeep,
rightDeepTargetExpressionOnCTE,
rfCtx,
rfCtx.getCteProduceMap().get(cteId)
);
if (pushedDown) {
rfCtx.removeFilter(
rightDeepTargetExpressionOnCTE.getInputSlotExprIds().iterator().next(),
(PhysicalHashJoin) rightDeep.getBuilderNode());
boolean pushedDown = doPushDownIntoCTEProducerInternal(
rfToPush,
rightDeepTargetExpressionOnCTE,
rfCtx,
rfCtx.getCteProduceMap().get(cteId)
);
if (pushedDown) {
rfCtx.removeFilter(
rightDeepTargetExpressionOnCTE.getInputSlotExprIds().iterator().next(),
(PhysicalHashJoin) rfToPush.getBuilderNode());
}
}
}
}

View File

@ -295,7 +295,7 @@ public class RuntimeFilterTest extends SSBTestBase {
.rewrite()
.implement();
PhysicalPlan plan = checker.getPhysicalPlan();
new PlanPostProcessors(checker.getCascadesContext()).process(plan);
plan = new PlanPostProcessors(checker.getCascadesContext()).process(plan);
System.out.println(plan.treeString());
new PhysicalPlanTranslator(new PlanTranslatorContext(checker.getCascadesContext())).translatePlan(plan);
RuntimeFilterContext context = checker.getCascadesContext().getRuntimeFilterContext();

View File

@ -26,7 +26,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
----------------------hashJoin[INNER_JOIN] hashCondition=((ws1.ws_web_site_sk = web_site.web_site_sk)) otherCondition=() build RFs:RF5 web_site_sk->[ws_web_site_sk]
------------------------hashJoin[INNER_JOIN] hashCondition=((ws1.ws_ship_addr_sk = customer_address.ca_address_sk)) otherCondition=() build RFs:RF4 ca_address_sk->[ws_ship_addr_sk]
--------------------------hashJoin[INNER_JOIN] hashCondition=((ws1.ws_ship_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[ws_ship_date_sk]
----------------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((ws1.ws_order_number = web_returns.wr_order_number)) otherCondition=() build RFs:RF7 ws_order_number->[ws_order_number,ws_order_number]
----------------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((ws1.ws_order_number = web_returns.wr_order_number)) otherCondition=() build RFs:RF2 ws_order_number->[wr_order_number];RF7 ws_order_number->[ws_order_number,ws_order_number]
------------------------------PhysicalProject
--------------------------------hashJoin[INNER_JOIN] hashCondition=((web_returns.wr_order_number = ws_wh.ws_order_number)) otherCondition=()
----------------------------------PhysicalDistribute[DistributionSpecHash]

View File

@ -26,7 +26,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
----------------------hashJoin[INNER_JOIN] hashCondition=((ws1.ws_web_site_sk = web_site.web_site_sk)) otherCondition=() build RFs:RF5 web_site_sk->[ws_web_site_sk]
------------------------hashJoin[INNER_JOIN] hashCondition=((ws1.ws_ship_addr_sk = customer_address.ca_address_sk)) otherCondition=() build RFs:RF4 ca_address_sk->[ws_ship_addr_sk]
--------------------------hashJoin[INNER_JOIN] hashCondition=((ws1.ws_ship_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[ws_ship_date_sk]
----------------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((ws1.ws_order_number = web_returns.wr_order_number)) otherCondition=() build RFs:RF7 ws_order_number->[ws_order_number,ws_order_number]
----------------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((ws1.ws_order_number = web_returns.wr_order_number)) otherCondition=() build RFs:RF2 ws_order_number->[wr_order_number];RF7 ws_order_number->[ws_order_number,ws_order_number]
------------------------------PhysicalProject
--------------------------------hashJoin[INNER_JOIN] hashCondition=((web_returns.wr_order_number = ws_wh.ws_order_number)) otherCondition=() build RFs:RF1 wr_order_number->[ws_order_number]
----------------------------------PhysicalDistribute[DistributionSpecHash]