[improvement](nereids) Support rf into cte (#21114)

Support pushing runtime filters (RF) down into the inside of CTE producers.
This commit is contained in:
xzj7019
2023-06-29 16:58:31 +08:00
committed by GitHub
parent 64e9eab0dd
commit 59198ed59e
7 changed files with 548 additions and 111 deletions

View File

@ -844,6 +844,9 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
cteProduce.setOutputExprs(outputs);
context.getCteProduceFragments().put(cteId, cteProduce);
context.getCteProduceMap().put(cteId, cteProducer);
if (context.getRuntimeTranslator().isPresent()) {
context.getRuntimeTranslator().get().getContext().getCteProduceMap().put(cteId, cteProducer);
}
context.getPlanFragments().add(cteProduce);
return child;
}

View File

@ -64,6 +64,10 @@ public class RuntimeFilterTranslator {
return context.getRuntimeFilterOnHashJoinNode(join);
}
/**
 * Exposes the runtime filter context held by this translator.
 *
 * @return the {@link RuntimeFilterContext} this translator reads from
 */
public RuntimeFilterContext getContext() {
    return this.context;
}
/**
 * Looks up the runtime filter target slots registered for a scan node.
 *
 * @param id object id of the scan node
 * @return the registered target slots, or an empty list when nothing is registered for {@code id}
 */
public List<Slot> getTargetOnScanNode(ObjectId id) {
    final List<Slot> noTargets = Collections.emptyList();
    return context.getTargetOnOlapScanNodeMap().getOrDefault(id, noTargets);
}

View File

@ -20,6 +20,8 @@ package org.apache.doris.nereids.processor.post;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -27,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEProducer;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation;
import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter;
@ -77,6 +80,21 @@ public class RuntimeFilterContext {
private final Map<Slot, ScanNode> scanNodeOfLegacyRuntimeFilterTarget = Maps.newHashMap();
private final Set<Plan> effectiveSrcNodes = Sets.newHashSet();
// maps each cte to the joins that consume it; a common runtime filter may be extracted from them into the cte
private final Map<CTEId, Set<PhysicalHashJoin>> cteToJoinsMap = Maps.newHashMap();
// cte producers that are candidates for having a common runtime filter pushed into them from outside
private final Map<PhysicalCTEProducer, Map<EqualTo, PhysicalHashJoin>> cteRFPushDownMap = Maps.newHashMap();
private final Map<CTEId, PhysicalCTEProducer> cteProducerMap = Maps.newHashMap();
// cte whose runtime filter has been extracted
private final Set<CTEId> processedCTE = Sets.newHashSet();
// cte whose outer runtime filter has been pushed down into
private final Set<CTEId> pushedDownCTE = Sets.newHashSet();
private final SessionVariable sessionVariable;
private final FilterSizeLimits limits;
@ -96,6 +114,26 @@ public class RuntimeFilterContext {
return limits;
}
/**
 * Returns the live (not copied) map from cte id to its physical producer.
 *
 * @return the mutable cte-id-to-producer map owned by this context
 */
public Map<CTEId, PhysicalCTEProducer> getCteProduceMap() {
    return this.cteProducerMap;
}
/**
 * Returns the live map of cte producers to the (join condition -> join) pairs
 * recorded for them as runtime filter push-down candidates.
 *
 * @return the mutable push-down candidate map owned by this context
 */
public Map<PhysicalCTEProducer, Map<EqualTo, PhysicalHashJoin>> getCteRFPushDownMap() {
    return this.cteRFPushDownMap;
}
/**
 * Returns the live map from cte id to the hash joins recorded for that cte.
 *
 * @return the mutable cte-id-to-joins map owned by this context
 */
public Map<CTEId, Set<PhysicalHashJoin>> getCteToJoinsMap() {
    return this.cteToJoinsMap;
}
/**
 * Returns the live set of cte ids that have been marked as processed.
 *
 * @return the mutable set of processed cte ids owned by this context
 */
public Set<CTEId> getProcessedCTE() {
    return this.processedCTE;
}
/**
 * Returns the live set of cte ids into which an outer runtime filter has already been pushed.
 *
 * @return the mutable set of pushed-down cte ids owned by this context
 */
public Set<CTEId> getPushedDownCTE() {
    return this.pushedDownCTE;
}
public void setTargetExprIdToFilter(ExprId id, RuntimeFilter filter) {
Preconditions.checkArgument(filter.getTargetExprs().stream().anyMatch(expr -> expr.getExprId() == id));
this.targetExprIdToFilter.computeIfAbsent(id, k -> Lists.newArrayList()).add(filter);

View File

@ -22,7 +22,9 @@ import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.stats.ExpressionEstimation;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
@ -32,10 +34,16 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapContain
import org.apache.doris.nereids.trees.plans.AbstractPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEProducer;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalExcept;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalIntersect;
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation;
@ -50,9 +58,13 @@ import org.apache.doris.thrift.TRuntimeFilterType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -70,6 +82,14 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
JoinType.NULL_AWARE_LEFT_ANTI_JOIN
);
/**
 * Physical plan node classes accepted when checking whether a plan consists only of
 * scan / project / filter / distribute / hash-join nodes (see checkCanPushDownIntoBasicTable).
 */
private static final Set<Class<? extends PhysicalPlan>> SPJ_PLAN =
        ImmutableSet.<Class<? extends PhysicalPlan>>builder()
                .add(PhysicalOlapScan.class)
                .add(PhysicalProject.class)
                .add(PhysicalFilter.class)
                .add(PhysicalDistribute.class)
                .add(PhysicalHashJoin.class)
                .build();
private final IdGenerator<RuntimeFilterId> generator = RuntimeFilterId.createGenerator();
/**
@ -98,121 +118,28 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
Set<Slot> slots = join.getOutputSet();
slots.forEach(aliasTransferMap::remove);
} else {
List<TRuntimeFilterType> legalTypes = Arrays.stream(TRuntimeFilterType.values())
.filter(type -> (type.getValue() & ctx.getSessionVariable().getRuntimeFilterType()) > 0)
.collect(Collectors.toList());
// TODO: some complex situation cannot be handled now, see testPushDownThroughJoin.
// we will support it in later version.
for (int i = 0; i < join.getHashJoinConjuncts().size(); i++) {
EqualTo equalTo = ((EqualTo) JoinUtils.swapEqualToForChildrenOrder(
(EqualTo) join.getHashJoinConjuncts().get(i), join.left().getOutputSet()));
for (TRuntimeFilterType type : legalTypes) {
//bitmap rf is generated by nested loop join.
if (type == TRuntimeFilterType.BITMAP) {
continue;
}
if (join.left() instanceof PhysicalUnion
|| join.left() instanceof PhysicalIntersect
|| join.left() instanceof PhysicalExcept) {
List<Slot> targetList = new ArrayList<>();
int projIndex = -1;
for (int j = 0; j < join.left().children().size(); j++) {
PhysicalPlan child = (PhysicalPlan) join.left().child(j);
if (child instanceof PhysicalProject) {
PhysicalProject project = (PhysicalProject) child;
Slot leftSlot = checkTargetChild(equalTo.left());
if (leftSlot == null) {
break;
}
for (int k = 0; projIndex < 0 && k < project.getProjects().size(); k++) {
NamedExpression expr = (NamedExpression) project.getProjects().get(k);
if (expr.getName().equals(leftSlot.getName())) {
projIndex = k;
break;
}
}
Preconditions.checkState(projIndex >= 0
&& projIndex < project.getProjects().size());
NamedExpression targetExpr = (NamedExpression) project.getProjects().get(projIndex);
SlotReference origSlot = null;
if (targetExpr instanceof Alias) {
origSlot = (SlotReference) targetExpr.child(0);
} else {
origSlot = (SlotReference) targetExpr;
}
if (!aliasTransferMap.containsKey(origSlot)) {
continue;
}
Slot olapScanSlot = aliasTransferMap.get(origSlot).second;
PhysicalRelation scan = aliasTransferMap.get(origSlot).first;
if (type == TRuntimeFilterType.IN_OR_BLOOM
&& ctx.getSessionVariable().enablePipelineEngine()
&& hasRemoteTarget(join, scan)) {
type = TRuntimeFilterType.BLOOM;
}
targetList.add(olapScanSlot);
ctx.addJoinToTargetMap(join, olapScanSlot.getExprId());
ctx.setTargetsOnScanNode(aliasTransferMap.get(origSlot).first.getId(), olapScanSlot);
}
}
if (!targetList.isEmpty()) {
long buildSideNdv = getBuildSideNdv(join, equalTo);
RuntimeFilter filter = new RuntimeFilter(generator.getNextId(),
equalTo.right(), targetList, type, i, join, buildSideNdv);
for (int j = 0; j < targetList.size(); j++) {
ctx.setTargetExprIdToFilter(targetList.get(j).getExprId(), filter);
}
}
} else {
// currently, we can ensure children in the two side are corresponding to the equal_to's.
// so right maybe an expression and left is a slot
Slot unwrappedSlot = checkTargetChild(equalTo.left());
// aliasTransMap doesn't contain the key, means that the path from the olap scan to the join
// contains join with denied join type. for example: a left join b on a.id = b.id
if (unwrappedSlot == null || !aliasTransferMap.containsKey(unwrappedSlot)) {
continue;
}
Slot olapScanSlot = aliasTransferMap.get(unwrappedSlot).second;
PhysicalRelation scan = aliasTransferMap.get(unwrappedSlot).first;
// in-filter is not friendly to pipeline
if (type == TRuntimeFilterType.IN_OR_BLOOM
&& ctx.getSessionVariable().enablePipelineEngine()
&& hasRemoteTarget(join, scan)) {
type = TRuntimeFilterType.BLOOM;
}
long buildSideNdv = getBuildSideNdv(join, equalTo);
RuntimeFilter filter = new RuntimeFilter(generator.getNextId(),
equalTo.right(), ImmutableList.of(olapScanSlot), type, i, join, buildSideNdv);
ctx.addJoinToTargetMap(join, olapScanSlot.getExprId());
ctx.setTargetExprIdToFilter(olapScanSlot.getExprId(), filter);
ctx.setTargetsOnScanNode(aliasTransferMap.get(unwrappedSlot).first.getId(), olapScanSlot);
}
}
collectPushDownCTEInfos(join, context);
if (!getPushDownCTECandidates(ctx).isEmpty()) {
pushDownRuntimeFilterIntoCTE(ctx);
} else {
pushDownRuntimeFilterCommon(join, context);
}
}
return join;
}
private boolean hasRemoteTarget(AbstractPlan join, AbstractPlan scan) {
Preconditions.checkArgument(join.getMutableState(AbstractPlan.FRAGMENT_ID).isPresent(),
"cannot find fragment id for Join node");
Preconditions.checkArgument(scan.getMutableState(AbstractPlan.FRAGMENT_ID).isPresent(),
"cannot find fragment id for scan node");
return join.getMutableState(AbstractPlan.FRAGMENT_ID).get()
!= scan.getMutableState(AbstractPlan.FRAGMENT_ID).get();
/**
 * Registers every output slot of a cte consumer in the alias transfer map, mapping the slot
 * to the pair (consumer, slot) so that joins above can resolve the slot back to this consumer.
 *
 * @param scan the cte consumer being visited
 * @param context cascades context providing the runtime filter context
 * @return the unchanged consumer
 */
@Override
public PhysicalCTEConsumer visitPhysicalCTEConsumer(PhysicalCTEConsumer scan, CascadesContext context) {
    RuntimeFilterContext rfContext = context.getRuntimeFilterContext();
    for (Slot outputSlot : scan.getOutput()) {
        rfContext.getAliasTransferMap().put(outputSlot, Pair.of(scan, outputSlot));
    }
    return scan;
}
private long getBuildSideNdv(PhysicalHashJoin<? extends Plan, ? extends Plan> join, EqualTo equalTo) {
AbstractPlan right = (AbstractPlan) join.right();
//make ut test friendly
if (right.getStats() == null) {
return -1L;
}
ExpressionEstimation estimator = new ExpressionEstimation();
ColumnStatistic buildColStats = equalTo.right().accept(estimator, right.getStats());
return buildColStats.isUnKnown ? -1 : Math.max(1, (long) buildColStats.ndv);
/**
 * Records the visited cte producer in the runtime filter context, keyed by its cte id.
 *
 * @param producer the cte producer being visited
 * @param context cascades context providing the runtime filter context
 * @return the unchanged producer
 */
@Override
public PhysicalCTEProducer visitPhysicalCTEProducer(PhysicalCTEProducer producer, CascadesContext context) {
    final CTEId producerId = producer.getCteId();
    Map<CTEId, PhysicalCTEProducer> producerRegistry = context.getRuntimeFilterContext().getCteProduceMap();
    producerRegistry.put(producerId, producer);
    return producer;
}
@Override
@ -297,8 +224,450 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
return scan;
}
/**
 * Estimates the number of distinct values of the join condition's right-hand expression,
 * using the statistics attached to the join's right child.
 *
 * @param join the hash join whose right child supplies the statistics
 * @param equalTo the join condition; its right side is the estimated expression
 * @return -1 when statistics are missing or the column statistic is unknown,
 *         otherwise the estimated ndv truncated to a long and clamped to at least 1
 */
private long getBuildSideNdv(PhysicalHashJoin<? extends Plan, ? extends Plan> join, EqualTo equalTo) {
    AbstractPlan buildSide = (AbstractPlan) join.right();
    // statistics may be absent (the original code notes this keeps unit tests friendly)
    if (buildSide.getStats() == null) {
        return -1L;
    }
    ExpressionEstimation ndvEstimator = new ExpressionEstimation();
    ColumnStatistic buildKeyStats = equalTo.right().accept(ndvEstimator, buildSide.getStats());
    if (buildKeyStats.isUnKnown) {
        return -1;
    }
    return Math.max(1, (long) buildKeyStats.ndv);
}
/**
 * Resolves the probe-side expression of a join condition to a slot, if possible.
 * The expression is first passed through ExpressionUtils.getExpressionCoveredByCast
 * (presumably stripping cast wrappers; confirm against that helper).
 *
 * @param leftChild the left-hand expression of the join condition
 * @return the resulting {@link Slot}, or {@code null} when the result is not a slot
 */
private static Slot checkTargetChild(Expression leftChild) {
    Expression unwrapped = ExpressionUtils.getExpressionCoveredByCast(leftChild);
    if (unwrapped instanceof Slot) {
        return (Slot) unwrapped;
    }
    return null;
}
/**
 * Generates runtime filters for a hash join in the regular (non-cte-push-down) way.
 * For every hash join condition and every runtime filter type enabled in the session
 * variable (BITMAP excluded), the work is delegated either to the set-operation variant
 * (left child is a union / intersect / except) or to the basic variant.
 *
 * @param join the hash join to generate runtime filters for
 * @param context cascades context providing the runtime filter context
 */
private void pushDownRuntimeFilterCommon(PhysicalHashJoin<? extends Plan, ? extends Plan> join,
        CascadesContext context) {
    RuntimeFilterContext ctx = context.getRuntimeFilterContext();

    // keep only the filter types whose bit is set in the session's runtime filter type mask
    List<TRuntimeFilterType> enabledTypes = new ArrayList<>();
    for (TRuntimeFilterType candidate : TRuntimeFilterType.values()) {
        if ((candidate.getValue() & ctx.getSessionVariable().getRuntimeFilterType()) > 0) {
            enabledTypes.add(candidate);
        }
    }

    boolean probeSideIsSetOperation = join.left() instanceof PhysicalUnion
            || join.left() instanceof PhysicalIntersect
            || join.left() instanceof PhysicalExcept;

    // TODO: some complex situation cannot be handled now, see testPushDownThroughJoin.
    //   we will support it in later version.
    for (int conjunctIdx = 0; conjunctIdx < join.getHashJoinConjuncts().size(); conjunctIdx++) {
        EqualTo rawCondition = (EqualTo) join.getHashJoinConjuncts().get(conjunctIdx);
        EqualTo equalTo = (EqualTo) JoinUtils.swapEqualToForChildrenOrder(
                rawCondition, join.left().getOutputSet());
        for (TRuntimeFilterType type : enabledTypes) {
            // bitmap rf is generated by nested loop join.
            if (type == TRuntimeFilterType.BITMAP) {
                continue;
            }
            if (probeSideIsSetOperation) {
                doPushDownIntoSetOperation(join, ctx, equalTo, type, conjunctIdx);
            } else {
                doPushDownBasic(join, context, ctx, equalTo, type, conjunctIdx);
            }
        }
    }
}
/**
 * Handles one (join condition, filter type) pair for a join whose probe side is not a set operation.
 * If the probe slot resolves to a cte consumer, the producer's input plan is visited once per cte
 * (tracked through the processed-cte set) and no filter is created here. Otherwise a single-target
 * runtime filter is created and registered in the context.
 *
 * @param join the hash join that builds the filter
 * @param context cascades context, used when recursing into a cte producer
 * @param ctx runtime filter context receiving the registrations
 * @param equalTo join condition with the probe side on the left
 * @param type requested filter type; IN_OR_BLOOM becomes BLOOM for remote targets under the pipeline engine
 * @param exprOrder index of the condition among the join's hash conjuncts
 */
private void doPushDownBasic(PhysicalHashJoin<? extends Plan, ? extends Plan> join, CascadesContext context,
        RuntimeFilterContext ctx, EqualTo equalTo, TRuntimeFilterType type, int exprOrder) {
    Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap = ctx.getAliasTransferMap();
    // currently, we can ensure children in the two side are corresponding to the equal_to's.
    // so right maybe an expression and left is a slot
    Slot probeSlot = checkTargetChild(equalTo.left());
    // aliasTransMap doesn't contain the key, means that the path from the olap scan to the join
    // contains join with denied join type. for example: a left join b on a.id = b.id
    if (probeSlot == null || !aliasTransferMap.containsKey(probeSlot)) {
        return;
    }
    Pair<PhysicalRelation, Slot> target = aliasTransferMap.get(probeSlot);
    Slot targetSlot = target.second;
    PhysicalRelation targetRelation = target.first;
    Preconditions.checkState(targetSlot != null && targetRelation != null);

    if (targetRelation instanceof PhysicalCTEConsumer) {
        Set<CTEId> processedCTE = context.getRuntimeFilterContext().getProcessedCTE();
        CTEId cteId = ((PhysicalCTEConsumer) targetRelation).getCteId();
        if (processedCTE.contains(cteId)) {
            return;
        }
        PhysicalCTEProducer cteProducer = context.getRuntimeFilterContext()
                .getCteProduceMap().get(cteId);
        PhysicalPlan producerInput = (PhysicalPlan) cteProducer.child(0);
        // process cte producer self recursively
        producerInput.accept(this, context);
        processedCTE.add(cteId);
        return;
    }

    // in-filter is not friendly to pipeline
    TRuntimeFilterType effectiveType = type;
    if (effectiveType == TRuntimeFilterType.IN_OR_BLOOM
            && ctx.getSessionVariable().enablePipelineEngine()
            && hasRemoteTarget(join, targetRelation)) {
        effectiveType = TRuntimeFilterType.BLOOM;
    }
    long buildSideNdv = getBuildSideNdv(join, equalTo);
    RuntimeFilter filter = new RuntimeFilter(generator.getNextId(),
            equalTo.right(), ImmutableList.of(targetSlot), effectiveType, exprOrder, join, buildSideNdv);
    ctx.addJoinToTargetMap(join, targetSlot.getExprId());
    ctx.setTargetExprIdToFilter(targetSlot.getExprId(), filter);
    ctx.setTargetsOnScanNode(targetRelation.getId(), targetSlot);
}
/**
 * Handles one (join condition, filter type) pair for a join whose probe side is a set operation
 * (union / intersect / except). For every project child of the set operation, the projection whose
 * name matches the probe slot is resolved through the alias transfer map to a scan slot; all resolved
 * slots become targets of one multi-target runtime filter.
 *
 * <p>Children whose origin slot is not present in the alias transfer map are skipped instead of
 * failing with a NullPointerException.
 *
 * @param join the hash join that builds the filter
 * @param ctx runtime filter context receiving the registrations
 * @param equalTo join condition with the probe side on the left
 * @param type requested filter type; IN_OR_BLOOM becomes BLOOM once a remote target is seen
 *             under the pipeline engine
 * @param exprOrder index of the condition among the join's hash conjuncts
 */
private void doPushDownIntoSetOperation(PhysicalHashJoin<? extends Plan, ? extends Plan> join,
        RuntimeFilterContext ctx, EqualTo equalTo, TRuntimeFilterType type, int exprOrder) {
    Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap = ctx.getAliasTransferMap();
    List<Slot> targetList = new ArrayList<>();
    // position of the matching projection; resolved on the first project child and reused afterwards
    int projIndex = -1;
    for (int j = 0; j < join.left().children().size(); j++) {
        PhysicalPlan child = (PhysicalPlan) join.left().child(j);
        if (child instanceof PhysicalProject) {
            PhysicalProject project = (PhysicalProject) child;
            Slot leftSlot = checkTargetChild(equalTo.left());
            if (leftSlot == null) {
                break;
            }
            for (int k = 0; projIndex < 0 && k < project.getProjects().size(); k++) {
                NamedExpression expr = (NamedExpression) project.getProjects().get(k);
                if (expr.getName().equals(leftSlot.getName())) {
                    projIndex = k;
                    break;
                }
            }
            Preconditions.checkState(projIndex >= 0
                    && projIndex < project.getProjects().size());
            NamedExpression targetExpr = (NamedExpression) project.getProjects().get(projIndex);
            SlotReference origSlot = null;
            if (targetExpr instanceof Alias) {
                origSlot = (SlotReference) targetExpr.child(0);
            } else {
                origSlot = (SlotReference) targetExpr;
            }
            // the slot may be missing from the map (e.g. it was removed because of a denied
            // join type below); skip this child rather than dereferencing a null entry
            Pair<PhysicalRelation, Slot> target = aliasTransferMap.get(origSlot);
            if (target == null) {
                continue;
            }
            Slot olapScanSlot = target.second;
            PhysicalRelation scan = target.first;
            if (type == TRuntimeFilterType.IN_OR_BLOOM
                    && ctx.getSessionVariable().enablePipelineEngine()
                    && hasRemoteTarget(join, scan)) {
                type = TRuntimeFilterType.BLOOM;
            }
            targetList.add(olapScanSlot);
            ctx.addJoinToTargetMap(join, olapScanSlot.getExprId());
            ctx.setTargetsOnScanNode(scan.getId(), olapScanSlot);
        }
    }
    if (!targetList.isEmpty()) {
        long buildSideNdv = getBuildSideNdv(join, equalTo);
        RuntimeFilter filter = new RuntimeFilter(generator.getNextId(),
                equalTo.right(), targetList, type, exprOrder, join, buildSideNdv);
        for (Slot targetSlot : targetList) {
            ctx.setTargetExprIdToFilter(targetSlot.getExprId(), filter);
        }
    }
}
/**
 * Records the given join under every cte id found beneath exactly one of its two children, then,
 * if any cte-to-joins entry exists, re-runs the analysis that decides which cte producers can
 * receive a pushed-down runtime filter.
 *
 * @param join the hash join being visited
 * @param context cascades context providing the runtime filter context
 */
private void collectPushDownCTEInfos(PhysicalHashJoin<? extends Plan, ? extends Plan> join,
        CascadesContext context) {
    RuntimeFilterContext ctx = context.getRuntimeFilterContext();
    PhysicalPlan leftChild = (PhysicalPlan) join.left();
    PhysicalPlan rightChild = (PhysicalPlan) join.right();
    Preconditions.checkState(leftChild != null && rightChild != null);

    Set<CTEId> foundCteIds = new HashSet<>();
    boolean leftHasCTE = hasCTEConsumerUnderJoin(leftChild, foundCteIds);
    boolean rightHasCTE = hasCTEConsumerUnderJoin(rightChild, foundCteIds);
    // only support single cte in join currently: exactly one side may contain a cte consumer
    if (leftHasCTE != rightHasCTE) {
        Map<CTEId, Set<PhysicalHashJoin>> cteToJoins = ctx.getCteToJoinsMap();
        for (CTEId id : foundCteIds) {
            cteToJoins.computeIfAbsent(id, unused -> new HashSet<>()).add(join);
        }
    }
    if (!ctx.getCteToJoinsMap().isEmpty()) {
        analyzeRuntimeFilterPushDownIntoCTEInfos(join, context);
    }
}
/**
 * Lists the ctes that are recorded as push-down candidates but have not been pushed into yet.
 *
 * @param ctx runtime filter context holding the candidate map and the pushed-down set
 * @return cte ids of candidate producers that are absent from the pushed-down set
 */
private List<CTEId> getPushDownCTECandidates(RuntimeFilterContext ctx) {
    List<CTEId> pending = new ArrayList<>();
    for (PhysicalCTEProducer producer : ctx.getCteRFPushDownMap().keySet()) {
        CTEId cteId = producer.getCteId();
        if (!ctx.getPushedDownCTE().contains(cteId)) {
            pending.add(cteId);
        }
    }
    return pending;
}
/**
 * Walks down a chain of single-child distribute / filter / project nodes starting at {@code root}
 * and reports whether the chain ends in a cte consumer.
 *
 * @param root the plan node to start from (a direct child of a join in the caller)
 * @param cteIds output parameter; receives the consumer's cte id when one is found
 * @return {@code true} if a cte consumer was reached, {@code false} if the walk hit a node that
 *         does not have exactly one child or is not a distribute / filter / project
 */
private boolean hasCTEConsumerUnderJoin(PhysicalPlan root, Set<CTEId> cteIds) {
    PhysicalPlan current = root;
    while (true) {
        if (current instanceof PhysicalCTEConsumer) {
            cteIds.add(((PhysicalCTEConsumer) current).getCteId());
            return true;
        }
        // only collect cte in one side
        if (current.children().size() != 1) {
            return false;
        }
        // only collect cte as single child node under join
        boolean passThrough = current instanceof PhysicalDistribute
                || current instanceof PhysicalFilter
                || current instanceof PhysicalProject;
        if (!passThrough) {
            return false;
        }
        current = (PhysicalPlan) current.child(0);
    }
}
/**
 * Decides, for every cte recorded in the cte-to-joins map, whether the joins consuming it qualify
 * for pushing a common runtime filter into the cte producer, and records qualifying producers in
 * the cte push-down map together with their (join condition -> join) pairs.
 *
 * Entries whose join set contains {@code curJoin} are skipped; {@code curJoin} is only used as the
 * "upper" join whose conditions may place two different build-side slots into one equal set.
 *
 * @param curJoin the join currently being visited
 * @param context cascades context providing the runtime filter context and the cte consumers
 */
private void analyzeRuntimeFilterPushDownIntoCTEInfos(PhysicalHashJoin<? extends Plan, ? extends Plan> curJoin,
CascadesContext context) {
RuntimeFilterContext ctx = context.getRuntimeFilterContext();
Map<CTEId, Set<PhysicalHashJoin>> cteToJoinsMap = ctx.getCteToJoinsMap();
for (Map.Entry<CTEId, Set<PhysicalHashJoin>> entry : cteToJoinsMap.entrySet()) {
CTEId cteId = entry.getKey();
Set<PhysicalHashJoin> joinSet = entry.getValue();
if (joinSet.contains(curJoin)) {
// skip current join
continue;
}
// NOTE(review): cteSet is dereferenced without a null check; this assumes every recorded
// cte id has an entry in getCteIdToConsumers() -- confirm.
Set<LogicalCTEConsumer> cteSet = context.getCteIdToConsumers().get(cteId);
Preconditions.checkState(!cteSet.isEmpty());
String cteName = cteSet.iterator().next().getName();
// preconditions for rf pushing into cte producer:
// multiple joins whose join conditions are on the same column of the same cte;
// the other sides of these join conditions are the same column of the same table, or
// they are in the same equal set, such as under an equal join condition
// case 1: two joins with t1.c1 = cte1_consumer1.c1 and t1.c1 = cte1_consumer2.c1 conditions
// rf of t1.c1 can be pushed down into cte1 producer.
// ----------------------hashJoin(t1.c1 = cte1_consumer1.c1)
// ----------------------------CteConsumer[cteId= ( CTEId#1=] )
// ----------------------------PhysicalOlapScan[t1]
// ----------------------hashJoin(t1.c1 = cte1_consumer2.c1)
// ----------------------------CteConsumer[cteId= ( CTEId#1=] )
// ----------------------------PhysicalOlapScan[t1]
// case 2: two joins with t1.c1 = cte2_consumer1.c1 and t2.c2 = cte2_consumer2.c1 and another equal join
// condition t1.c1 = t2.c2, which means t1.c1 and t2.c2 are in the same equal set.
// rf of t1.c1 and t2.c2 can be pushed down into cte2 producer.
// --------------------hashJoin(t1.c1 = t2.c2)
// ----------------------hashJoin(t2.c2 = cte2_consumer1.c1)
// ----------------------------CteConsumer[cteId= ( CTEId#1=] )
// ----------------------------PhysicalOlapScan[t2]
// ----------------------hashJoin(t1.c1 = cte2_consumer2.c1)
// ----------------------------CteConsumer[cteId= ( CTEId#1=] )
// ----------------------------PhysicalOlapScan[t1]
// every consumer of the cte must be covered by a recorded join, otherwise skip this cte
if (joinSet.size() != cteSet.size()) {
continue;
}
List<EqualTo> equalTos = new ArrayList<>();
Map<EqualTo, PhysicalHashJoin> equalCondToJoinMap = new LinkedHashMap<>();
for (PhysicalHashJoin join : joinSet) {
// precondition:
// 1. no non-equal join condition
// 2. only equalTo and slotReference both sides
// 3. only support one join condition (will be refined further)
// NOTE(review): the check below tolerates one other-join conjunct (size() > 1), which
// looks looser than precondition 1 above -- confirm which one is intended.
if (join.getOtherJoinConjuncts().size() > 1
|| join.getHashJoinConjuncts().size() != 1
|| !(join.getHashJoinConjuncts().get(0) instanceof EqualTo)) {
break;
} else {
EqualTo equalTo = (EqualTo) join.getHashJoinConjuncts().get(0);
equalTos.add(equalTo);
equalCondToJoinMap.put(equalTo, join);
}
}
// only continue when every join of the set passed the precondition check above
if (joinSet.size() == equalTos.size()) {
int matchNum = 0;
Set<String> cteNameSet = new HashSet<>();
Set<SlotReference> anotherSideSlotSet = new HashSet<>();
for (EqualTo equalTo : equalTos) {
// NOTE(review): both sides are cast to SlotReference without an instanceof check
// (precondition 2 is not verified in code); a non-slot operand would throw here.
SlotReference left = (SlotReference) equalTo.left();
SlotReference right = (SlotReference) equalTo.right();
// the side whose single qualifier equals the cte name is the cte column; the opposite
// side is collected as the build-side slot
if (left.getQualifier().size() == 1 && left.getQualifier().get(0).equals(cteName)) {
matchNum += 1;
anotherSideSlotSet.add(right);
cteNameSet.add(left.getQualifiedName());
} else if (right.getQualifier().size() == 1 && right.getQualifier().get(0).equals(cteName)) {
matchNum += 1;
anotherSideSlotSet.add(left);
cteNameSet.add(right.getQualifiedName());
}
}
if (matchNum == equalTos.size() && cteNameSet.size() == 1) {
// means all join conditions point to the same column of the same cte.
// collect the other side columns besides cte column side.
Preconditions.checkState(equalTos.size() == equalCondToJoinMap.size(),
"equalTos.size() != equalCondToJoinMap.size()");
// NOTE(review): cteProducer is not null-checked before being used as a map key; a later
// consumer of the push-down map asserts the key is non-null.
PhysicalCTEProducer cteProducer = context.getRuntimeFilterContext().getCteProduceMap().get(cteId);
if (anotherSideSlotSet.size() == 1) {
// meet requirement for pushing down into cte producer
ctx.getCteRFPushDownMap().put(cteProducer, equalCondToJoinMap);
} else {
// check further whether the upper join can bring an equal set, which
// indicates the runtime filter build sides are actually the same
// see above case 2 for reference
List<Expression> conditions = curJoin.getHashJoinConjuncts();
boolean inSameEqualSet = false;
for (Expression e : conditions) {
if (e instanceof EqualTo) {
SlotReference oneSide = (SlotReference) ((EqualTo) e).left();
SlotReference anotherSide = (SlotReference) ((EqualTo) e).right();
if (anotherSideSlotSet.contains(oneSide) && anotherSideSlotSet.contains(anotherSide)) {
inSameEqualSet = true;
break;
}
}
}
if (inSameEqualSet) {
ctx.getCteRFPushDownMap().put(cteProducer, equalCondToJoinMap);
}
}
}
}
}
}
/**
 * Pushes runtime filters into every candidate cte producer that has not been handled yet.
 * For each recorded (join condition -> join) pair of a producer, the condition is re-ordered
 * against the join's first child and handed to the producer-internal push-down; afterwards the
 * cte is marked as pushed down.
 *
 * @param ctx runtime filter context holding the candidates and the pushed-down set
 */
private void pushDownRuntimeFilterIntoCTE(RuntimeFilterContext ctx) {
    for (Map.Entry<PhysicalCTEProducer, Map<EqualTo, PhysicalHashJoin>> candidate
            : ctx.getCteRFPushDownMap().entrySet()) {
        PhysicalCTEProducer cteProducer = candidate.getKey();
        Preconditions.checkState(cteProducer != null);
        if (ctx.getPushedDownCTE().contains(cteProducer.getCteId())) {
            continue;
        }
        int exprOrder = 0;
        for (Map.Entry<EqualTo, PhysicalHashJoin> conditionAndJoin : candidate.getValue().entrySet()) {
            PhysicalHashJoin join = conditionAndJoin.getValue();
            Preconditions.checkState(join != null);
            // BLOOM under the pipeline engine, IN_OR_BLOOM otherwise
            TRuntimeFilterType type = ctx.getSessionVariable().enablePipelineEngine()
                    ? TRuntimeFilterType.BLOOM
                    : TRuntimeFilterType.IN_OR_BLOOM;
            EqualTo orderedEqualTo = ((EqualTo) JoinUtils.swapEqualToForChildrenOrder(
                    conditionAndJoin.getKey(), join.child(0).getOutputSet()));
            doPushDownIntoCTEProducerInternal(join, ctx, orderedEqualTo, type, exprOrder, cteProducer);
            exprOrder++;
        }
        ctx.getPushedDownCTE().add(cteProducer.getCteId());
    }
}
/**
 * Pushes one runtime filter into the plan below a cte producer. The probe slot of the condition
 * must resolve (through the alias transfer map) to a cte consumer and the producer's input must be
 * a project; the project column with the same name as the consumer slot is then traced to olap
 * scan slots, which become the targets of one multi-target runtime filter.
 *
 * @param join the hash join that builds the filter
 * @param ctx runtime filter context receiving the registrations
 * @param equalTo join condition with the probe side on the left
 * @param type runtime filter type to create
 * @param exprOrder order of the condition passed on to the created filter
 * @param cteProducer the producer whose input plan receives the filter
 */
private void doPushDownIntoCTEProducerInternal(PhysicalHashJoin<? extends Plan, ? extends Plan> join,
        RuntimeFilterContext ctx, EqualTo equalTo, TRuntimeFilterType type, int exprOrder,
        PhysicalCTEProducer cteProducer) {
    Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap = ctx.getAliasTransferMap();
    PhysicalPlan producerInput = (PhysicalPlan) cteProducer.child(0);
    Slot probeSlot = checkTargetChild(equalTo.left());
    // aliasTransMap doesn't contain the key, means that the path from the olap scan to the join
    // contains join with denied join type. for example: a left join b on a.id = b.id
    if (probeSlot == null || !aliasTransferMap.containsKey(probeSlot)) {
        return;
    }
    Pair<PhysicalRelation, Slot> consumerInfo = aliasTransferMap.get(probeSlot);
    Slot cteSlot = consumerInfo.second;
    PhysicalRelation cteNode = consumerInfo.first;
    long buildSideNdv = getBuildSideNdv(join, equalTo);
    if (!(cteNode instanceof PhysicalCTEConsumer) || !(producerInput instanceof PhysicalProject)) {
        return;
    }

    // find the producer-side projection carrying the same name as the consumer slot
    PhysicalProject project = (PhysicalProject) producerInput;
    NamedExpression targetExpr = null;
    for (Object column : project.getProjects()) {
        NamedExpression projection = (NamedExpression) column;
        if (cteSlot.getName().equals(projection.getName())) {
            targetExpr = projection;
            break;
        }
    }
    Preconditions.checkState(targetExpr != null);
    // if not SlotReference, skip the push down
    if (!(targetExpr instanceof SlotReference)) {
        return;
    }
    if (!checkCanPushDownIntoBasicTable(project)) {
        return;
    }

    Map<Slot, PhysicalOlapScan> pushDownBasicTableInfos = getPushDownBasicTablesInfos(project,
            (SlotReference) targetExpr, aliasTransferMap);
    if (pushDownBasicTableInfos.isEmpty()) {
        return;
    }
    List<Slot> targetList = new ArrayList<>();
    for (Map.Entry<Slot, PhysicalOlapScan> slotAndScan : pushDownBasicTableInfos.entrySet()) {
        Slot targetSlot = slotAndScan.getKey();
        targetList.add(targetSlot);
        ctx.addJoinToTargetMap(join, targetSlot.getExprId());
        ctx.setTargetsOnScanNode(slotAndScan.getValue().getId(), targetSlot);
    }
    // build multi-target runtime filter
    RuntimeFilter filter = new RuntimeFilter(generator.getNextId(),
            equalTo.right(), targetList, type, exprOrder, join, buildSideNdv);
    for (Slot targetSlot : targetList) {
        ctx.setTargetExprIdToFilter(targetSlot.getExprId(), filter);
    }
}
/**
 * Checks that every physical plan node collected from {@code root} is an instance of one of the
 * classes listed in SPJ_PLAN.
 *
 * @param root root of the plan to inspect
 * @return {@code true} when all collected nodes are of a supported class (also when none are collected)
 */
private boolean checkCanPushDownIntoBasicTable(PhysicalPlan root) {
    // only support spj currently
    List<PhysicalPlan> collectedNodes = Lists.newArrayList();
    collectedNodes.addAll(root.collect(PhysicalPlan.class::isInstance));
    for (PhysicalPlan node : collectedNodes) {
        boolean supported = false;
        for (Class<? extends PhysicalPlan> allowedClass : SPJ_PLAN) {
            if (allowedClass.isInstance(node)) {
                supported = true;
                break;
            }
        }
        if (!supported) {
            return false;
        }
    }
    return true;
}
/**
 * Collects the olap scan slots a runtime filter on {@code slot} can target inside the plan
 * rooted at {@code root}: the slot itself when it maps to an olap scan, plus the opposite slot of
 * every hash join condition (among the joins returned by getAllJoinInfo) that has {@code slot}
 * on one side and whose opposite slot maps to an olap scan.
 *
 * <p>Conditions whose operands are not plain slot references, slots missing from the alias transfer
 * map, and slots mapped to a relation other than an olap scan are skipped rather than causing a
 * ClassCastException or NullPointerException.
 *
 * @param root root of the cte producer's input plan
 * @param slot the producer-side slot the filter is aimed at
 * @param aliasTransferMap map from slot to the (relation, relation slot) pair it originates from
 * @return map from target slot to the olap scan it belongs to; empty when nothing qualifies
 */
private Map<Slot, PhysicalOlapScan> getPushDownBasicTablesInfos(PhysicalPlan root, SlotReference slot,
        Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap) {
    Map<Slot, PhysicalOlapScan> basicTableInfos = new HashMap<>();
    Set<PhysicalHashJoin> joins = new HashSet<>();
    ExprId exprId = slot.getExprId();
    Pair<PhysicalRelation, Slot> directTarget = aliasTransferMap.get(slot);
    if (directTarget != null && directTarget.first instanceof PhysicalOlapScan) {
        basicTableInfos.put(slot, (PhysicalOlapScan) directTarget.first);
    }
    // try to find propagation condition from join
    getAllJoinInfo(root, joins);
    for (PhysicalHashJoin join : joins) {
        List<Expression> conditions = join.getHashJoinConjuncts();
        for (Expression condition : conditions) {
            if (!(condition instanceof EqualTo)) {
                continue;
            }
            Expression left = ((EqualTo) condition).left();
            Expression right = ((EqualTo) condition).right();
            if (!(left instanceof SlotReference) || !(right instanceof SlotReference)) {
                // only slot = slot conditions can propagate the filter
                continue;
            }
            SlotReference leftSlot = (SlotReference) left;
            SlotReference rightSlot = (SlotReference) right;
            // pick the slot on the opposite side of the one we are tracing
            SlotReference oppositeSlot;
            if (exprId.equals(leftSlot.getExprId())) {
                oppositeSlot = rightSlot;
            } else if (exprId.equals(rightSlot.getExprId())) {
                oppositeSlot = leftSlot;
            } else {
                continue;
            }
            Pair<PhysicalRelation, Slot> oppositeTarget = aliasTransferMap.get(oppositeSlot);
            if (oppositeTarget != null && oppositeTarget.first instanceof PhysicalOlapScan) {
                basicTableInfos.put(oppositeSlot, (PhysicalOlapScan) oppositeTarget.first);
            }
        }
    }
    return basicTableInfos;
}
/**
 * Collects hash joins from the plan rooted at {@code root}. The search stops at the first hash
 * join found on each path: a hash join is added to {@code joins} and its children are not visited.
 * NOTE(review): joins nested below another hash join are therefore not collected -- confirm
 * this is intended given the method name.
 *
 * @param root plan node to search from
 * @param joins output parameter receiving the hash joins found
 */
private void getAllJoinInfo(PhysicalPlan root, Set<PhysicalHashJoin> joins) {
    if (root instanceof PhysicalHashJoin) {
        joins.add((PhysicalHashJoin) root);
        return;
    }
    root.children().forEach(child -> getAllJoinInfo((PhysicalPlan) child, joins));
}
/**
 * Tells whether the filter target is considered remote with respect to the join.
 * A cte consumer is always treated as remote; otherwise the two nodes are remote when the
 * fragment ids stored in their mutable state are not the same object.
 *
 * @param join the join building the runtime filter
 * @param scan the relation targeted by the runtime filter
 * @return {@code true} when the target is a cte consumer or lives in a different fragment
 * @throws IllegalArgumentException if either node has no fragment id in its mutable state
 */
private boolean hasRemoteTarget(AbstractPlan join, AbstractPlan scan) {
    if (scan instanceof PhysicalCTEConsumer) {
        return true;
    }
    Preconditions.checkArgument(join.getMutableState(AbstractPlan.FRAGMENT_ID).isPresent(),
            "cannot find fragment id for Join node");
    Preconditions.checkArgument(scan.getMutableState(AbstractPlan.FRAGMENT_ID).isPresent(),
            "cannot find fragment id for scan node");
    boolean sameFragment = join.getMutableState(AbstractPlan.FRAGMENT_ID).get()
            == scan.getMutableState(AbstractPlan.FRAGMENT_ID).get();
    return !sameFragment;
}
}

View File

@ -17,6 +17,8 @@
package org.apache.doris.nereids.trees.plans.physical;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.nereids.exceptions.TransformException;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
@ -26,6 +28,7 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.RelationUtil;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.statistics.Statistics;
@ -41,7 +44,7 @@ import java.util.Optional;
/**
* Physical CTE consumer.
*/
public class PhysicalCTEConsumer extends PhysicalLeaf {
public class PhysicalCTEConsumer extends PhysicalRelation {
private final CTEId cteId;
private final Map<Slot, Slot> producerToConsumerSlotMap;
@ -77,12 +80,23 @@ public class PhysicalCTEConsumer extends PhysicalLeaf {
LogicalProperties logicalProperties,
PhysicalProperties physicalProperties,
Statistics statistics) {
super(PlanType.PHYSICAL_CTE_CONSUME, groupExpression, logicalProperties, physicalProperties, statistics);
super(RelationUtil.newRelationId(), PlanType.PHYSICAL_CTE_CONSUME, ImmutableList.of(), groupExpression,
logicalProperties, physicalProperties, statistics);
this.cteId = cteId;
this.consumerToProducerSlotMap = ImmutableMap.copyOf(consumerToProducerSlotMap);
this.producerToConsumerSlotMap = ImmutableMap.copyOf(producerToConsumerSlotMap);
}
/**
 * Not supported for a cte consumer: always throws.
 *
 * @throws TransformException on every call
 */
@Override
public OlapTable getTable() {
throw new TransformException("should not reach here");
}
/**
 * Not supported for a cte consumer: always throws.
 *
 * @throws TransformException on every call
 */
@Override
public List<String> getQualifier() {
throw new TransformException("should not reach here");
}
/**
 * Returns the id of the cte this consumer reads from.
 *
 * @return the cte id given at construction time
 */
public CTEId getCteId() {
    return this.cteId;
}

View File

@ -29,7 +29,8 @@ public class MultiCastPlanFragment extends PlanFragment {
private final List<ExchangeNode> destNodeList = Lists.newArrayList();
public MultiCastPlanFragment(PlanFragment planFragment) {
super(planFragment.getFragmentId(), planFragment.getPlanRoot(), planFragment.getDataPartition());
super(planFragment.getFragmentId(), planFragment.getPlanRoot(), planFragment.getDataPartition(),
planFragment.getBuilderRuntimeFilterIds(), planFragment.getTargetRuntimeFilterIds());
this.outputPartition = DataPartition.RANDOM;
this.children.addAll(planFragment.getChildren());
}

View File

@ -33,6 +33,7 @@ import org.apache.doris.thrift.TPartitionType;
import org.apache.doris.thrift.TPlanFragment;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.apache.commons.collections.CollectionUtils;
import org.apache.logging.log4j.LogManager;
@ -167,6 +168,13 @@ public class PlanFragment extends TreeNode<PlanFragment> {
this.dataPartitionForThrift = partitionForThrift;
}
public PlanFragment(PlanFragmentId id, PlanNode root, DataPartition partition,
Set<RuntimeFilterId> builderRuntimeFilterIds, Set<RuntimeFilterId> targetRuntimeFilterIds) {
this(id, root, partition);
this.builderRuntimeFilterIds = ImmutableSet.copyOf(builderRuntimeFilterIds);
this.targetRuntimeFilterIds = ImmutableSet.copyOf(targetRuntimeFilterIds);
}
/**
* Assigns 'this' as fragment of all PlanNodes in the plan tree rooted at node.
* Does not traverse the children of ExchangeNodes because those must belong to a