From 59198ed59e5d33a27dcfc4368ec719ef00f5d64d Mon Sep 17 00:00:00 2001 From: xzj7019 <131111794+xzj7019@users.noreply.github.com> Date: Thu, 29 Jun 2023 16:58:31 +0800 Subject: [PATCH] [improvement](nereids) Support rf into cte (#21114) Support runtime filter pushing down into cte internal. --- .../translator/PhysicalPlanTranslator.java | 3 + .../translator/RuntimeFilterTranslator.java | 4 + .../processor/post/RuntimeFilterContext.java | 38 ++ .../post/RuntimeFilterGenerator.java | 585 ++++++++++++++---- .../plans/physical/PhysicalCTEConsumer.java | 18 +- .../doris/planner/MultiCastPlanFragment.java | 3 +- .../apache/doris/planner/PlanFragment.java | 8 + 7 files changed, 548 insertions(+), 111 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 4aec102644..8210c1a082 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -844,6 +844,9 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor getTargetOnScanNode(ObjectId id) { return context.getTargetOnOlapScanNodeMap().getOrDefault(id, Collections.emptyList()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterContext.java index b9d2ac301e..76fb3311fe 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterContext.java @@ -20,6 +20,8 @@ package org.apache.doris.nereids.processor.post; import org.apache.doris.analysis.SlotRef; import org.apache.doris.common.IdGenerator; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.CTEId; +import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; @@ -27,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.ObjectId; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin; +import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEProducer; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation; import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter; @@ -77,6 +80,21 @@ public class RuntimeFilterContext { private final Map scanNodeOfLegacyRuntimeFilterTarget = Maps.newHashMap(); private final Set effectiveSrcNodes = Sets.newHashSet(); + + // cte to related joins map which can extract common runtime filter to cte inside + private final Map> cteToJoinsMap = Maps.newHashMap(); + + // cte candidates which can be pushed into common runtime filter into from outside + private final Map> cteRFPushDownMap = Maps.newHashMap(); + + private final Map cteProducerMap = Maps.newHashMap(); + + // cte whose runtime filter has been extracted + private final Set processedCTE = Sets.newHashSet(); + + // cte whose outer runtime filter has been pushed down into + private final Set pushedDownCTE = Sets.newHashSet(); + private final SessionVariable sessionVariable; private final FilterSizeLimits limits; @@ -96,6 +114,26 @@ public class RuntimeFilterContext { return limits; } + public Map getCteProduceMap() { + return cteProducerMap; + } + + public Map> getCteRFPushDownMap() { + return cteRFPushDownMap; + } + + public Map> getCteToJoinsMap() { + return cteToJoinsMap; + } + + public Set getProcessedCTE() { + return processedCTE; + } + + public Set getPushedDownCTE() { + return pushedDownCTE; + } + public void setTargetExprIdToFilter(ExprId id, RuntimeFilter filter) { Preconditions.checkArgument(filter.getTargetExprs().stream().anyMatch(expr -> expr.getExprId() == id)); this.targetExprIdToFilter.computeIfAbsent(id, k -> Lists.newArrayList()).add(filter); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java index d042d9a083..19561f9ec9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java @@ -22,7 +22,9 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.stats.ExpressionEstimation; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.CTEId; import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Not; @@ -32,10 +34,16 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapContain import org.apache.doris.nereids.trees.plans.AbstractPlan; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEProducer; +import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; import org.apache.doris.nereids.trees.plans.physical.PhysicalExcept; +import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalIntersect; import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; +import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation; @@ -50,9 +58,13 @@ import org.apache.doris.thrift.TRuntimeFilterType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -70,6 +82,14 @@ public class RuntimeFilterGenerator extends PlanPostProcessor { JoinType.NULL_AWARE_LEFT_ANTI_JOIN ); + private static final Set> SPJ_PLAN = ImmutableSet.of( + PhysicalOlapScan.class, + PhysicalProject.class, + PhysicalFilter.class, + PhysicalDistribute.class, + PhysicalHashJoin.class + ); + private final IdGenerator generator = RuntimeFilterId.createGenerator(); /** @@ -98,121 +118,28 @@ public class RuntimeFilterGenerator extends PlanPostProcessor { Set slots = join.getOutputSet(); slots.forEach(aliasTransferMap::remove); } else { - List legalTypes = Arrays.stream(TRuntimeFilterType.values()) - .filter(type -> (type.getValue() & ctx.getSessionVariable().getRuntimeFilterType()) > 0) - .collect(Collectors.toList()); - // TODO: some complex situation cannot be handled now, see testPushDownThroughJoin. - // we will support it in later version. - for (int i = 0; i < join.getHashJoinConjuncts().size(); i++) { - EqualTo equalTo = ((EqualTo) JoinUtils.swapEqualToForChildrenOrder( - (EqualTo) join.getHashJoinConjuncts().get(i), join.left().getOutputSet())); - for (TRuntimeFilterType type : legalTypes) { - //bitmap rf is generated by nested loop join. - if (type == TRuntimeFilterType.BITMAP) { - continue; - } - if (join.left() instanceof PhysicalUnion - || join.left() instanceof PhysicalIntersect - || join.left() instanceof PhysicalExcept) { - List targetList = new ArrayList<>(); - int projIndex = -1; - for (int j = 0; j < join.left().children().size(); j++) { - PhysicalPlan child = (PhysicalPlan) join.left().child(j); - if (child instanceof PhysicalProject) { - PhysicalProject project = (PhysicalProject) child; - Slot leftSlot = checkTargetChild(equalTo.left()); - if (leftSlot == null) { - break; - } - for (int k = 0; projIndex < 0 && k < project.getProjects().size(); k++) { - NamedExpression expr = (NamedExpression) project.getProjects().get(k); - if (expr.getName().equals(leftSlot.getName())) { - projIndex = k; - break; - } - } - Preconditions.checkState(projIndex >= 0 - && projIndex < project.getProjects().size()); - - NamedExpression targetExpr = (NamedExpression) project.getProjects().get(projIndex); - - SlotReference origSlot = null; - if (targetExpr instanceof Alias) { - origSlot = (SlotReference) targetExpr.child(0); - } else { - origSlot = (SlotReference) targetExpr; - } - if (!aliasTransferMap.containsKey(origSlot)) { - continue; - } - Slot olapScanSlot = aliasTransferMap.get(origSlot).second; - PhysicalRelation scan = aliasTransferMap.get(origSlot).first; - if (type == TRuntimeFilterType.IN_OR_BLOOM - && ctx.getSessionVariable().enablePipelineEngine() - && hasRemoteTarget(join, scan)) { - type = TRuntimeFilterType.BLOOM; - } - targetList.add(olapScanSlot); - ctx.addJoinToTargetMap(join, olapScanSlot.getExprId()); - ctx.setTargetsOnScanNode(aliasTransferMap.get(origSlot).first.getId(), olapScanSlot); - } - } - if (!targetList.isEmpty()) { - long buildSideNdv = getBuildSideNdv(join, equalTo); - RuntimeFilter filter = new RuntimeFilter(generator.getNextId(), - equalTo.right(), targetList, type, i, join, buildSideNdv); - for (int j = 0; j < targetList.size(); j++) { - ctx.setTargetExprIdToFilter(targetList.get(j).getExprId(), filter); - } - } - } else { - // currently, we can ensure children in the two side are corresponding to the equal_to's. - // so right maybe an expression and left is a slot - Slot unwrappedSlot = checkTargetChild(equalTo.left()); - // aliasTransMap doesn't contain the key, means that the path from the olap scan to the join - // contains join with denied join type. for example: a left join b on a.id = b.id - if (unwrappedSlot == null || !aliasTransferMap.containsKey(unwrappedSlot)) { - continue; - } - Slot olapScanSlot = aliasTransferMap.get(unwrappedSlot).second; - PhysicalRelation scan = aliasTransferMap.get(unwrappedSlot).first; - // in-filter is not friendly to pipeline - if (type == TRuntimeFilterType.IN_OR_BLOOM - && ctx.getSessionVariable().enablePipelineEngine() - && hasRemoteTarget(join, scan)) { - type = TRuntimeFilterType.BLOOM; - } - long buildSideNdv = getBuildSideNdv(join, equalTo); - RuntimeFilter filter = new RuntimeFilter(generator.getNextId(), - equalTo.right(), ImmutableList.of(olapScanSlot), type, i, join, buildSideNdv); - ctx.addJoinToTargetMap(join, olapScanSlot.getExprId()); - ctx.setTargetExprIdToFilter(olapScanSlot.getExprId(), filter); - ctx.setTargetsOnScanNode(aliasTransferMap.get(unwrappedSlot).first.getId(), olapScanSlot); - } - } + collectPushDownCTEInfos(join, context); + if (!getPushDownCTECandidates(ctx).isEmpty()) { + pushDownRuntimeFilterIntoCTE(ctx); + } else { + pushDownRuntimeFilterCommon(join, context); } } return join; } - private boolean hasRemoteTarget(AbstractPlan join, AbstractPlan scan) { - Preconditions.checkArgument(join.getMutableState(AbstractPlan.FRAGMENT_ID).isPresent(), - "cannot find fragment id for Join node"); - Preconditions.checkArgument(scan.getMutableState(AbstractPlan.FRAGMENT_ID).isPresent(), - "cannot find fragment id for scan node"); - return join.getMutableState(AbstractPlan.FRAGMENT_ID).get() - != scan.getMutableState(AbstractPlan.FRAGMENT_ID).get(); + @Override + public PhysicalCTEConsumer visitPhysicalCTEConsumer(PhysicalCTEConsumer scan, CascadesContext context) { + RuntimeFilterContext ctx = context.getRuntimeFilterContext(); + scan.getOutput().forEach(slot -> ctx.getAliasTransferMap().put(slot, Pair.of(scan, slot))); + return scan; } - private long getBuildSideNdv(PhysicalHashJoin join, EqualTo equalTo) { - AbstractPlan right = (AbstractPlan) join.right(); - //make ut test friendly - if (right.getStats() == null) { - return -1L; - } - ExpressionEstimation estimator = new ExpressionEstimation(); - ColumnStatistic buildColStats = equalTo.right().accept(estimator, right.getStats()); - return buildColStats.isUnKnown ? -1 : Math.max(1, (long) buildColStats.ndv); + @Override + public PhysicalCTEProducer visitPhysicalCTEProducer(PhysicalCTEProducer producer, CascadesContext context) { + CTEId id = producer.getCteId(); + context.getRuntimeFilterContext().getCteProduceMap().put(id, producer); + return producer; } @Override @@ -297,8 +224,450 @@ public class RuntimeFilterGenerator extends PlanPostProcessor { return scan; } + private long getBuildSideNdv(PhysicalHashJoin join, EqualTo equalTo) { + AbstractPlan right = (AbstractPlan) join.right(); + //make ut test friendly + if (right.getStats() == null) { + return -1L; + } + ExpressionEstimation estimator = new ExpressionEstimation(); + ColumnStatistic buildColStats = equalTo.right().accept(estimator, right.getStats()); + return buildColStats.isUnKnown ? -1 : Math.max(1, (long) buildColStats.ndv); + } + private static Slot checkTargetChild(Expression leftChild) { Expression expression = ExpressionUtils.getExpressionCoveredByCast(leftChild); return expression instanceof Slot ? ((Slot) expression) : null; } + + private void pushDownRuntimeFilterCommon(PhysicalHashJoin join, + CascadesContext context) { + RuntimeFilterContext ctx = context.getRuntimeFilterContext(); + List legalTypes = Arrays.stream(TRuntimeFilterType.values()) + .filter(type -> (type.getValue() & ctx.getSessionVariable().getRuntimeFilterType()) > 0) + .collect(Collectors.toList()); + // TODO: some complex situation cannot be handled now, see testPushDownThroughJoin. + // we will support it in later version. + for (int i = 0; i < join.getHashJoinConjuncts().size(); i++) { + EqualTo equalTo = ((EqualTo) JoinUtils.swapEqualToForChildrenOrder( + (EqualTo) join.getHashJoinConjuncts().get(i), join.left().getOutputSet())); + for (TRuntimeFilterType type : legalTypes) { + //bitmap rf is generated by nested loop join. + if (type == TRuntimeFilterType.BITMAP) { + continue; + } + if (join.left() instanceof PhysicalUnion + || join.left() instanceof PhysicalIntersect + || join.left() instanceof PhysicalExcept) { + doPushDownIntoSetOperation(join, ctx, equalTo, type, i); + } else { + doPushDownBasic(join, context, ctx, equalTo, type, i); + } + } + } + } + + private void doPushDownBasic(PhysicalHashJoin join, CascadesContext context, + RuntimeFilterContext ctx, EqualTo equalTo, TRuntimeFilterType type, int exprOrder) { + Map> aliasTransferMap = ctx.getAliasTransferMap(); + // currently, we can ensure children in the two side are corresponding to the equal_to's. + // so right maybe an expression and left is a slot + Slot unwrappedSlot = checkTargetChild(equalTo.left()); + // aliasTransMap doesn't contain the key, means that the path from the olap scan to the join + // contains join with denied join type. for example: a left join b on a.id = b.id + if (unwrappedSlot == null || !aliasTransferMap.containsKey(unwrappedSlot)) { + return; + } + Slot olapScanSlot = aliasTransferMap.get(unwrappedSlot).second; + PhysicalRelation scan = aliasTransferMap.get(unwrappedSlot).first; + + Preconditions.checkState(olapScanSlot != null && scan != null); + + if (scan instanceof PhysicalCTEConsumer) { + Set processedCTE = context.getRuntimeFilterContext().getProcessedCTE(); + CTEId cteId = ((PhysicalCTEConsumer) scan).getCteId(); + if (!processedCTE.contains(cteId)) { + PhysicalCTEProducer cteProducer = context.getRuntimeFilterContext() + .getCteProduceMap().get(cteId); + PhysicalPlan inputPlanNode = (PhysicalPlan) cteProducer.child(0); + // process cte producer self recursively + inputPlanNode.accept(this, context); + processedCTE.add(cteId); + } + } else { + // in-filter is not friendly to pipeline + if (type == TRuntimeFilterType.IN_OR_BLOOM + && ctx.getSessionVariable().enablePipelineEngine() + && hasRemoteTarget(join, scan)) { + type = TRuntimeFilterType.BLOOM; + } + long buildSideNdv = getBuildSideNdv(join, equalTo); + RuntimeFilter filter = new RuntimeFilter(generator.getNextId(), + equalTo.right(), ImmutableList.of(olapScanSlot), type, exprOrder, join, buildSideNdv); + ctx.addJoinToTargetMap(join, olapScanSlot.getExprId()); + ctx.setTargetExprIdToFilter(olapScanSlot.getExprId(), filter); + ctx.setTargetsOnScanNode(aliasTransferMap.get(unwrappedSlot).first.getId(), olapScanSlot); + } + } + + private void doPushDownIntoSetOperation(PhysicalHashJoin join, + RuntimeFilterContext ctx, EqualTo equalTo, TRuntimeFilterType type, int exprOrder) { + Map> aliasTransferMap = ctx.getAliasTransferMap(); + List targetList = new ArrayList<>(); + int projIndex = -1; + for (int j = 0; j < join.left().children().size(); j++) { + PhysicalPlan child = (PhysicalPlan) join.left().child(j); + if (child instanceof PhysicalProject) { + PhysicalProject project = (PhysicalProject) child; + Slot leftSlot = checkTargetChild(equalTo.left()); + if (leftSlot == null) { + break; + } + for (int k = 0; projIndex < 0 && k < project.getProjects().size(); k++) { + NamedExpression expr = (NamedExpression) project.getProjects().get(k); + if (expr.getName().equals(leftSlot.getName())) { + projIndex = k; + break; + } + } + Preconditions.checkState(projIndex >= 0 + && projIndex < project.getProjects().size()); + + NamedExpression targetExpr = (NamedExpression) project.getProjects().get(projIndex); + + SlotReference origSlot = null; + if (targetExpr instanceof Alias) { + origSlot = (SlotReference) targetExpr.child(0); + } else { + origSlot = (SlotReference) targetExpr; + } + Slot olapScanSlot = aliasTransferMap.get(origSlot).second; + PhysicalRelation scan = aliasTransferMap.get(origSlot).first; + if (type == TRuntimeFilterType.IN_OR_BLOOM + && ctx.getSessionVariable().enablePipelineEngine() + && hasRemoteTarget(join, scan)) { + type = TRuntimeFilterType.BLOOM; + } + targetList.add(olapScanSlot); + ctx.addJoinToTargetMap(join, olapScanSlot.getExprId()); + ctx.setTargetsOnScanNode(aliasTransferMap.get(origSlot).first.getId(), olapScanSlot); + } + } + if (!targetList.isEmpty()) { + long buildSideNdv = getBuildSideNdv(join, equalTo); + RuntimeFilter filter = new RuntimeFilter(generator.getNextId(), + equalTo.right(), targetList, type, exprOrder, join, buildSideNdv); + for (int j = 0; j < targetList.size(); j++) { + ctx.setTargetExprIdToFilter(targetList.get(j).getExprId(), filter); + } + } + } + + private void collectPushDownCTEInfos(PhysicalHashJoin join, + CascadesContext context) { + RuntimeFilterContext ctx = context.getRuntimeFilterContext(); + Set cteIds = new HashSet<>(); + PhysicalPlan leftChild = (PhysicalPlan) join.left(); + PhysicalPlan rightChild = (PhysicalPlan) join.right(); + + Preconditions.checkState(leftChild != null && rightChild != null); + + boolean leftHasCTE = hasCTEConsumerUnderJoin(leftChild, cteIds); + boolean rightHasCTE = hasCTEConsumerUnderJoin(rightChild, cteIds); + // only support single cte in join currently + if ((leftHasCTE && !rightHasCTE) || (!leftHasCTE && rightHasCTE)) { + for (CTEId id : cteIds) { + if (ctx.getCteToJoinsMap().get(id) == null) { + Set newJoin = new HashSet<>(); + newJoin.add(join); + ctx.getCteToJoinsMap().put(id, newJoin); + } else { + ctx.getCteToJoinsMap().get(id).add(join); + } + } + } + if (!ctx.getCteToJoinsMap().isEmpty()) { + analyzeRuntimeFilterPushDownIntoCTEInfos(join, context); + } + } + + private List getPushDownCTECandidates(RuntimeFilterContext ctx) { + List candidates = new ArrayList<>(); + Map> cteRFPushDownMap = ctx.getCteRFPushDownMap(); + for (Map.Entry> entry : cteRFPushDownMap.entrySet()) { + CTEId cteId = entry.getKey().getCteId(); + if (ctx.getPushedDownCTE().contains(cteId)) { + continue; + } + candidates.add(cteId); + } + return candidates; + } + + private boolean hasCTEConsumerUnderJoin(PhysicalPlan root, Set cteIds) { + if (root instanceof PhysicalCTEConsumer) { + cteIds.add(((PhysicalCTEConsumer) root).getCteId()); + return true; + } else if (root.children().size() != 1) { + // only collect cte in one side + return false; + } else if (root instanceof PhysicalDistribute + || root instanceof PhysicalFilter + || root instanceof PhysicalProject) { + // only collect cte as single child node under join + return hasCTEConsumerUnderJoin((PhysicalPlan) root.child(0), cteIds); + } else { + return false; + } + } + + private void analyzeRuntimeFilterPushDownIntoCTEInfos(PhysicalHashJoin curJoin, + CascadesContext context) { + RuntimeFilterContext ctx = context.getRuntimeFilterContext(); + Map> cteToJoinsMap = ctx.getCteToJoinsMap(); + for (Map.Entry> entry : cteToJoinsMap.entrySet()) { + CTEId cteId = entry.getKey(); + Set joinSet = entry.getValue(); + if (joinSet.contains(curJoin)) { + // skip current join + continue; + } + Set cteSet = context.getCteIdToConsumers().get(cteId); + Preconditions.checkState(!cteSet.isEmpty()); + String cteName = cteSet.iterator().next().getName(); + // preconditions for rf pushing into cte producer: + // multiple joins whose join condition is on the same cte's column of the same cte + // the other side of these join conditions are the same column of the same table, or + // they in the same equal sets, such as under an equal join condition + // case 1: two joins with t1.c1 = cte1_consumer1.c1 and t1.c1 = cte1_consumer2.c1 conditions + // rf of t1.c1 can be pushed down into cte1 producer. + // ----------------------hashJoin(t1.c1 = cte2_consumer1.c1) + // ----------------------------CteConsumer[cteId= ( CTEId#1=] ) + // ----------------------------PhysicalOlapScan[t1] + // ----------------------hashJoin(t1.c1 = cte2_consumer2.c1) + // ----------------------------CteConsumer[cteId= ( CTEId#1=] ) + // ----------------------------PhysicalOlapScan[t1] + // case 2: two joins with t1.c1 = cte2_consumer1.c1 and t2.c2 = cte2_consumer2.c1 and another equal join + // condition t1.c1 = t2.c2, which means t1.c1 and t2.c2 are in the same equal set. + // rf of t1.c1 and t2.c2 can be pushed down into cte2 producer. + // --------------------hashJoin(t1.c1 = t2.c2) + // ----------------------hashJoin(t2.c2 = cte2_consumer1.c1) + // ----------------------------CteConsumer[cteId= ( CTEId#1=] ) + // ----------------------------PhysicalOlapScan[t2] + // ----------------------hashJoin(t1.c1 = cte2_consumer2.c1) + // ----------------------------CteConsumer[cteId= ( CTEId#1=] ) + // ----------------------------PhysicalOlapScan[t1] + if (joinSet.size() != cteSet.size()) { + continue; + } + List equalTos = new ArrayList<>(); + Map equalCondToJoinMap = new LinkedHashMap<>(); + for (PhysicalHashJoin join : joinSet) { + // precondition: + // 1. no non-equal join condition + // 2. only equalTo and slotReference both sides + // 3. only support one join condition (will be refined further) + if (join.getOtherJoinConjuncts().size() > 1 + || join.getHashJoinConjuncts().size() != 1 + || !(join.getHashJoinConjuncts().get(0) instanceof EqualTo)) { + break; + } else { + EqualTo equalTo = (EqualTo) join.getHashJoinConjuncts().get(0); + equalTos.add(equalTo); + equalCondToJoinMap.put(equalTo, join); + } + } + if (joinSet.size() == equalTos.size()) { + int matchNum = 0; + Set cteNameSet = new HashSet<>(); + Set anotherSideSlotSet = new HashSet<>(); + for (EqualTo equalTo : equalTos) { + SlotReference left = (SlotReference) equalTo.left(); + SlotReference right = (SlotReference) equalTo.right(); + if (left.getQualifier().size() == 1 && left.getQualifier().get(0).equals(cteName)) { + matchNum += 1; + anotherSideSlotSet.add(right); + cteNameSet.add(left.getQualifiedName()); + } else if (right.getQualifier().size() == 1 && right.getQualifier().get(0).equals(cteName)) { + matchNum += 1; + anotherSideSlotSet.add(left); + cteNameSet.add(right.getQualifiedName()); + } + } + if (matchNum == equalTos.size() && cteNameSet.size() == 1) { + // means all join condition points to the same cte on the same cte column. + // collect the other side columns besides cte column side. + Preconditions.checkState(equalTos.size() == equalCondToJoinMap.size(), + "equalTos.size() != equalCondToJoinMap.size()"); + + PhysicalCTEProducer cteProducer = context.getRuntimeFilterContext().getCteProduceMap().get(cteId); + if (anotherSideSlotSet.size() == 1) { + // meet requirement for pushing down into cte producer + ctx.getCteRFPushDownMap().put(cteProducer, equalCondToJoinMap); + } else { + // check further whether the join upper side can bring equal set, which + // indicating actually the same runtime filter build side + // see above case 2 for reference + List conditions = curJoin.getHashJoinConjuncts(); + boolean inSameEqualSet = false; + for (Expression e : conditions) { + if (e instanceof EqualTo) { + SlotReference oneSide = (SlotReference) ((EqualTo) e).left(); + SlotReference anotherSide = (SlotReference) ((EqualTo) e).right(); + if (anotherSideSlotSet.contains(oneSide) && anotherSideSlotSet.contains(anotherSide)) { + inSameEqualSet = true; + break; + } + } + } + if (inSameEqualSet) { + ctx.getCteRFPushDownMap().put(cteProducer, equalCondToJoinMap); + } + } + } + } + } + } + + private void pushDownRuntimeFilterIntoCTE(RuntimeFilterContext ctx) { + Map> cteRFPushDownMap = ctx.getCteRFPushDownMap(); + for (Map.Entry> entry : cteRFPushDownMap.entrySet()) { + PhysicalCTEProducer cteProducer = entry.getKey(); + Preconditions.checkState(cteProducer != null); + if (ctx.getPushedDownCTE().contains(cteProducer.getCteId())) { + continue; + } + Map equalCondToJoinMap = entry.getValue(); + int exprOrder = 0; + for (Map.Entry innerEntry : equalCondToJoinMap.entrySet()) { + EqualTo equalTo = innerEntry.getKey(); + PhysicalHashJoin join = innerEntry.getValue(); + Preconditions.checkState(join != null); + TRuntimeFilterType type = TRuntimeFilterType.IN_OR_BLOOM; + if (ctx.getSessionVariable().enablePipelineEngine()) { + type = TRuntimeFilterType.BLOOM; + } + EqualTo newEqualTo = ((EqualTo) JoinUtils.swapEqualToForChildrenOrder( + equalTo, join.child(0).getOutputSet())); + doPushDownIntoCTEProducerInternal(join, ctx, newEqualTo, type, exprOrder++, cteProducer); + } + ctx.getPushedDownCTE().add(cteProducer.getCteId()); + } + } + + private void doPushDownIntoCTEProducerInternal(PhysicalHashJoin join, + RuntimeFilterContext ctx, EqualTo equalTo, TRuntimeFilterType type, int exprOrder, + PhysicalCTEProducer cteProducer) { + Map> aliasTransferMap = ctx.getAliasTransferMap(); + PhysicalPlan inputPlanNode = (PhysicalPlan) cteProducer.child(0); + Slot unwrappedSlot = checkTargetChild(equalTo.left()); + // aliasTransMap doesn't contain the key, means that the path from the olap scan to the join + // contains join with denied join type. for example: a left join b on a.id = b.id + if (unwrappedSlot == null || !aliasTransferMap.containsKey(unwrappedSlot)) { + return; + } + Slot cteSlot = aliasTransferMap.get(unwrappedSlot).second; + PhysicalRelation cteNode = aliasTransferMap.get(unwrappedSlot).first; + long buildSideNdv = getBuildSideNdv(join, equalTo); + if (cteNode instanceof PhysicalCTEConsumer && inputPlanNode instanceof PhysicalProject) { + PhysicalProject project = (PhysicalProject) inputPlanNode; + NamedExpression targetExpr = null; + for (Object column : project.getProjects()) { + NamedExpression alias = (NamedExpression) column; + if (cteSlot.getName().equals(alias.getName())) { + targetExpr = alias; + break; + } + } + Preconditions.checkState(targetExpr != null); + if (!(targetExpr instanceof SlotReference)) { + // if not SlotReference, skip the push down + return; + } else if (!checkCanPushDownIntoBasicTable(project)) { + return; + } else { + Map pushDownBasicTableInfos = getPushDownBasicTablesInfos(project, + (SlotReference) targetExpr, aliasTransferMap); + if (!pushDownBasicTableInfos.isEmpty()) { + List targetList = new ArrayList<>(); + for (Map.Entry entry : pushDownBasicTableInfos.entrySet()) { + Slot targetSlot = entry.getKey(); + PhysicalOlapScan scan = entry.getValue(); + targetList.add(targetSlot); + ctx.addJoinToTargetMap(join, targetSlot.getExprId()); + ctx.setTargetsOnScanNode(scan.getId(), targetSlot); + } + // build multi-target runtime filter + RuntimeFilter filter = new RuntimeFilter(generator.getNextId(), + equalTo.right(), targetList, type, exprOrder, join, buildSideNdv); + for (Slot slot : targetList) { + ctx.setTargetExprIdToFilter(slot.getExprId(), filter); + } + } + } + } + } + + private boolean checkCanPushDownIntoBasicTable(PhysicalPlan root) { + // only support spj currently + List plans = Lists.newArrayList(); + plans.addAll(root.collect(PhysicalPlan.class::isInstance)); + return plans.stream().allMatch(p -> SPJ_PLAN.stream().anyMatch(c -> c.isInstance(p))); + } + + private Map getPushDownBasicTablesInfos(PhysicalPlan root, SlotReference slot, + Map> aliasTransferMap) { + Map basicTableInfos = new HashMap<>(); + Set joins = new HashSet<>(); + ExprId exprId = slot.getExprId(); + if (aliasTransferMap.get(slot) != null && aliasTransferMap.get(slot).first instanceof PhysicalOlapScan) { + basicTableInfos.put(slot, (PhysicalOlapScan) aliasTransferMap.get(slot).first); + } + // try to find propagation condition from join + getAllJoinInfo(root, joins); + for (PhysicalHashJoin join : joins) { + List conditions = join.getHashJoinConjuncts(); + for (Expression equalTo : conditions) { + if (equalTo instanceof EqualTo) { + SlotReference leftSlot = (SlotReference) ((EqualTo) equalTo).left(); + SlotReference rightSlot = (SlotReference) ((EqualTo) equalTo).right(); + if (leftSlot.getExprId() == exprId) { + PhysicalOlapScan rightTable = (PhysicalOlapScan) aliasTransferMap.get(rightSlot).first; + if (rightTable != null) { + basicTableInfos.put(rightSlot, rightTable); + } + } else if (rightSlot.getExprId() == exprId) { + PhysicalOlapScan leftTable = (PhysicalOlapScan) aliasTransferMap.get(leftSlot).first; + if (leftTable != null) { + basicTableInfos.put(leftSlot, leftTable); + } + } + } + } + } + return basicTableInfos; + } + + private void getAllJoinInfo(PhysicalPlan root, Set joins) { + if (root instanceof PhysicalHashJoin) { + joins.add((PhysicalHashJoin) root); + } else { + for (Object child : root.children()) { + getAllJoinInfo((PhysicalPlan) child, joins); + } + } + } + + private boolean hasRemoteTarget(AbstractPlan join, AbstractPlan scan) { + if (scan instanceof PhysicalCTEConsumer) { + return true; + } else { + Preconditions.checkArgument(join.getMutableState(AbstractPlan.FRAGMENT_ID).isPresent(), + "cannot find fragment id for Join node"); + Preconditions.checkArgument(scan.getMutableState(AbstractPlan.FRAGMENT_ID).isPresent(), + "cannot find fragment id for scan node"); + return join.getMutableState(AbstractPlan.FRAGMENT_ID).get() + != scan.getMutableState(AbstractPlan.FRAGMENT_ID).get(); + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalCTEConsumer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalCTEConsumer.java index b9215e70fb..b75b932cd8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalCTEConsumer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalCTEConsumer.java @@ -17,6 +17,8 @@ package org.apache.doris.nereids.trees.plans.physical; +import org.apache.doris.catalog.OlapTable; +import org.apache.doris.nereids.exceptions.TransformException; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; @@ -26,6 +28,7 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; +import org.apache.doris.nereids.util.RelationUtil; import org.apache.doris.nereids.util.Utils; import org.apache.doris.statistics.Statistics; @@ -41,7 +44,7 @@ import java.util.Optional; /** * Physical CTE consumer. */ -public class PhysicalCTEConsumer extends PhysicalLeaf { +public class PhysicalCTEConsumer extends PhysicalRelation { private final CTEId cteId; private final Map producerToConsumerSlotMap; @@ -77,12 +80,23 @@ public class PhysicalCTEConsumer extends PhysicalLeaf { LogicalProperties logicalProperties, PhysicalProperties physicalProperties, Statistics statistics) { - super(PlanType.PHYSICAL_CTE_CONSUME, groupExpression, logicalProperties, physicalProperties, statistics); + super(RelationUtil.newRelationId(), PlanType.PHYSICAL_CTE_CONSUME, ImmutableList.of(), groupExpression, + logicalProperties, physicalProperties, statistics); this.cteId = cteId; this.consumerToProducerSlotMap = ImmutableMap.copyOf(consumerToProducerSlotMap); this.producerToConsumerSlotMap = ImmutableMap.copyOf(producerToConsumerSlotMap); } + @Override + public OlapTable getTable() { + throw new TransformException("should not reach here"); + } + + @Override + public List getQualifier() { + throw new TransformException("should not reach here"); + } + public CTEId getCteId() { return cteId; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/MultiCastPlanFragment.java b/fe/fe-core/src/main/java/org/apache/doris/planner/MultiCastPlanFragment.java index b49906fecb..0d5b54b269 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/MultiCastPlanFragment.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/MultiCastPlanFragment.java @@ -29,7 +29,8 @@ public class MultiCastPlanFragment extends PlanFragment { private final List destNodeList = Lists.newArrayList(); public MultiCastPlanFragment(PlanFragment planFragment) { - super(planFragment.getFragmentId(), planFragment.getPlanRoot(), planFragment.getDataPartition()); + super(planFragment.getFragmentId(), planFragment.getPlanRoot(), planFragment.getDataPartition(), + planFragment.getBuilderRuntimeFilterIds(), planFragment.getTargetRuntimeFilterIds()); this.outputPartition = DataPartition.RANDOM; this.children.addAll(planFragment.getChildren()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanFragment.java b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanFragment.java index 3511ccf6ff..54903beae5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanFragment.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanFragment.java @@ -33,6 +33,7 @@ import org.apache.doris.thrift.TPartitionType; import org.apache.doris.thrift.TPlanFragment; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import org.apache.commons.collections.CollectionUtils; import org.apache.logging.log4j.LogManager; @@ -167,6 +168,13 @@ public class PlanFragment extends TreeNode { this.dataPartitionForThrift = partitionForThrift; } + public PlanFragment(PlanFragmentId id, PlanNode root, DataPartition partition, + Set builderRuntimeFilterIds, Set targetRuntimeFilterIds) { + this(id, root, partition); + this.builderRuntimeFilterIds = ImmutableSet.copyOf(builderRuntimeFilterIds); + this.targetRuntimeFilterIds = ImmutableSet.copyOf(targetRuntimeFilterIds); + } + /** * Assigns 'this' as fragment of all PlanNodes in the plan tree rooted at node. * Does not traverse the children of ExchangeNodes because those must belong to a