[feature](nereids) judge if the join is at bottom of join cluster (#29383)

This commit is contained in:
minghong
2024-01-06 17:15:19 +08:00
committed by GitHub
parent cc7b9480cf
commit 911635fac6
27 changed files with 502 additions and 484 deletions

View File

@ -64,6 +64,7 @@ import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.StatisticsBuilder;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -619,7 +620,7 @@ public class CascadesContext implements ScheduleContext {
List<Pair<Map<Slot, Slot>, Group>> consumerGroups = this.statementContext.getCteIdToConsumerGroup().get(cteId);
for (Pair<Map<Slot, Slot>, Group> p : consumerGroups) {
Map<Slot, Slot> producerSlotToConsumerSlot = p.first;
Statistics updatedConsumerStats = new Statistics(statistics);
Statistics updatedConsumerStats = new StatisticsBuilder(statistics).build();
for (Entry<Expression, ColumnStatistic> entry : statistics.columnStatistics().entrySet()) {
updatedConsumerStats.addColumnStats(producerSlotToConsumerSlot.get(entry.getKey()), entry.getValue());
}

View File

@ -238,6 +238,22 @@ public class RuntimeFilterContext {
return aliasTransferMap;
}
/**
 * Removes the alias-transfer entry keyed by {@code slot}.
 *
 * @param slot key whose mapping should be dropped
 * @return whatever {@code aliasTransferMap.remove(slot)} returns: the (relation, slot) pair
 *         previously mapped to {@code slot}, or {@code null} if there was no mapping
 */
public Pair<PhysicalRelation, Slot> aliasTransferMapRemove(NamedExpression slot) {
return aliasTransferMap.remove(slot);
}
/**
 * Looks up the alias-transfer entry keyed by {@code slot}.
 *
 * @param slot key to look up
 * @return the (relation, slot) pair mapped to {@code slot}, or {@code null} when the map has
 *         no entry for it; callers must null-check or test with
 *         {@code aliasTransferMapContains} before dereferencing
 */
public Pair<PhysicalRelation, Slot> getAliasTransferPair(NamedExpression slot) {
return aliasTransferMap.get(slot);
}
/**
 * Registers (or overwrites) the alias-transfer entry for {@code slot}.
 *
 * @param slot key to register
 * @param pair (relation, slot) pair to associate with the key
 * @return whatever {@code aliasTransferMap.put(slot, pair)} returns: the pair previously
 *         mapped to {@code slot}, or {@code null} if there was none
 */
public Pair<PhysicalRelation, Slot> aliasTransferMapPut(NamedExpression slot, Pair<PhysicalRelation, Slot> pair) {
return aliasTransferMap.put(slot, pair);
}
/**
 * Tells whether the alias-transfer map currently has an entry keyed by {@code slot}.
 *
 * @param slot key to test
 * @return {@code true} if {@code aliasTransferMap.containsKey(slot)}
 */
public boolean aliasTransferMapContains(NamedExpression slot) {
return aliasTransferMap.containsKey(slot);
}
/**
 * Returns the slot-to-{@code ScanNode} map held in {@code scanNodeOfLegacyRuntimeFilterTarget}.
 * The field itself is returned, not a copy, so mutations by the caller are visible to this context.
 */
public Map<Slot, ScanNode> getScanNodeOfLegacyRuntimeFilterTarget() {
return scanNodeOfLegacyRuntimeFilterTarget;
}

View File

@ -122,7 +122,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
join.left().accept(this, context);
if (RuntimeFilterGenerator.DENIED_JOIN_TYPES.contains(join.getJoinType()) || join.isMarkJoin()) {
join.right().getOutput().forEach(slot ->
context.getRuntimeFilterContext().getAliasTransferMap().remove(slot));
context.getRuntimeFilterContext().aliasTransferMapRemove(slot));
}
collectPushDownCTEInfos(join, context);
if (!getPushDownCTECandidates(ctx).isEmpty()) {
@ -136,7 +136,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
@Override
public PhysicalCTEConsumer visitPhysicalCTEConsumer(PhysicalCTEConsumer scan, CascadesContext context) {
RuntimeFilterContext ctx = context.getRuntimeFilterContext();
scan.getOutput().forEach(slot -> ctx.getAliasTransferMap().put(slot, Pair.of(scan, slot)));
scan.getOutput().forEach(slot -> ctx.aliasTransferMapPut(slot, Pair.of(scan, slot)));
return scan;
}
@ -158,7 +158,6 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
if (join.getJoinType() != JoinType.LEFT_SEMI_JOIN && join.getJoinType() != JoinType.CROSS_JOIN) {
return;
}
Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap = ctx.getAliasTransferMap();
List<Slot> leftSlots = join.left().getOutput();
List<Slot> rightSlots = join.right().getOutput();
List<Expression> bitmapRuntimeFilterConditions = JoinUtils.extractBitmapRuntimeFilterConditions(leftSlots,
@ -183,15 +182,15 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
if (!checkPushDownPreconditionsForJoin(join, ctx, targetSlot)) {
continue;
}
Slot scanSlot = aliasTransferMap.get(targetSlot).second;
PhysicalRelation scan = aliasTransferMap.get(targetSlot).first;
Slot scanSlot = ctx.getAliasTransferPair(targetSlot).second;
PhysicalRelation scan = ctx.getAliasTransferPair(targetSlot).first;
RuntimeFilter filter = new RuntimeFilter(generator.getNextId(),
bitmapContains.child(0), ImmutableList.of(scanSlot),
ImmutableList.of(bitmapContains.child(1)), type, i, join, isNot, -1L);
scan.addAppliedRuntimeFilter(filter);
ctx.addJoinToTargetMap(join, scanSlot.getExprId());
ctx.setTargetExprIdToFilter(scanSlot.getExprId(), filter);
ctx.setTargetsOnScanNode(aliasTransferMap.get(targetSlot).first,
ctx.setTargetsOnScanNode(ctx.getAliasTransferPair(targetSlot).first,
scanSlot);
join.addBitmapRuntimeFilterCondition(bitmapRuntimeFilterCondition);
}
@ -246,7 +245,6 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
*/
private void generateMinMaxRuntimeFilter(AbstractPhysicalJoin<? extends Plan, ? extends Plan> join,
RuntimeFilterContext ctx) {
Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap = ctx.getAliasTransferMap();
int hashCondionSize = join.getHashJoinConjuncts().size();
for (int idx = 0; idx < join.getOtherJoinConjuncts().size(); idx++) {
int exprOrder = idx + hashCondionSize;
@ -257,7 +255,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
if (unwrappedSlot == null) {
continue;
}
Pair<PhysicalRelation, Slot> pair = aliasTransferMap.get(unwrappedSlot);
Pair<PhysicalRelation, Slot> pair = ctx.getAliasTransferPair(unwrappedSlot);
if (pair == null) {
continue;
}
@ -286,7 +284,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
if (RuntimeFilterGenerator.DENIED_JOIN_TYPES.contains(join.getJoinType()) || join.isMarkJoin()) {
join.right().getOutput().forEach(slot ->
context.getRuntimeFilterContext().getAliasTransferMap().remove(slot));
context.getRuntimeFilterContext().aliasTransferMapRemove(slot));
return join;
}
RuntimeFilterContext ctx = context.getRuntimeFilterContext();
@ -310,8 +308,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
@Override
public PhysicalPlan visitPhysicalProject(PhysicalProject<? extends Plan> project, CascadesContext context) {
project.child().accept(this, context);
Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap
= context.getRuntimeFilterContext().getAliasTransferMap();
RuntimeFilterContext ctx = context.getRuntimeFilterContext();
// change key when encounter alias.
// TODO: same action will be taken for set operation
for (Expression expression : project.getProjects()) {
@ -319,10 +316,11 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
continue;
}
Expression expr = ExpressionUtils.getExpressionCoveredByCast(expression.child(0));
if (expr instanceof NamedExpression && aliasTransferMap.containsKey((NamedExpression) expr)) {
if (expr instanceof NamedExpression
&& ctx.aliasTransferMapContains((NamedExpression) expr)) {
if (expression instanceof Alias) {
Alias alias = ((Alias) expression);
aliasTransferMap.put(alias.toSlot(), aliasTransferMap.get(expr));
ctx.aliasTransferMapPut(alias.toSlot(), ctx.getAliasTransferPair((NamedExpression) expr));
}
}
}
@ -340,7 +338,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
public PhysicalRelation visitPhysicalRelation(PhysicalRelation relation, CascadesContext context) {
// add all the slots in map.
RuntimeFilterContext ctx = context.getRuntimeFilterContext();
relation.getOutput().forEach(slot -> ctx.getAliasTransferMap().put(slot, Pair.of(relation, slot)));
relation.getOutput().forEach(slot -> ctx.aliasTransferMapPut(slot, Pair.of(relation, slot)));
return relation;
}
@ -579,7 +577,6 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
private void doPushDownIntoCTEProducerInternal(PhysicalHashJoin<? extends Plan, ? extends Plan> join,
RuntimeFilterContext ctx, EqualTo equalTo, TRuntimeFilterType type, PhysicalCTEProducer cteProducer) {
Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap = ctx.getAliasTransferMap();
PhysicalPlan inputPlanNode = (PhysicalPlan) cteProducer.child(0);
Slot unwrappedSlot = checkTargetChild(equalTo.left());
// aliasTransMap doesn't contain the key, means that the path from the scan to the join
@ -587,8 +584,8 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
if (!checkPushDownPreconditionsForJoin(join, ctx, unwrappedSlot)) {
return;
}
Slot cteSlot = aliasTransferMap.get(unwrappedSlot).second;
PhysicalRelation cteNode = aliasTransferMap.get(unwrappedSlot).first;
Slot cteSlot = ctx.getAliasTransferPair(unwrappedSlot).second;
PhysicalRelation cteNode = ctx.getAliasTransferPair(unwrappedSlot).first;
long buildSideNdv = getBuildSideNdv(join, equalTo);
if (cteNode instanceof PhysicalCTEConsumer && inputPlanNode instanceof PhysicalProject) {
PhysicalProject project = (PhysicalProject) inputPlanNode;
@ -608,7 +605,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
return;
} else {
Map<Slot, PhysicalRelation> pushDownBasicTableInfos = getPushDownBasicTablesInfos(project,
(SlotReference) targetExpr, aliasTransferMap);
(SlotReference) targetExpr, ctx);
if (!pushDownBasicTableInfos.isEmpty()) {
List<Slot> targetList = new ArrayList<>();
List<PhysicalRelation> targetNodes = new ArrayList<>();
@ -642,7 +639,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
topN.child().accept(this, context);
PhysicalPlan child = (PhysicalPlan) topN.child();
for (Slot slot : child.getOutput()) {
context.getRuntimeFilterContext().getAliasTransferMap().remove(slot);
context.getRuntimeFilterContext().aliasTransferMapRemove(slot);
}
return topN;
}
@ -652,7 +649,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
window.child().accept(this, context);
Set<SlotReference> commonPartitionKeys = window.getCommonPartitionKeyFromWindowExpressions();
window.child().getOutput().stream().filter(slot -> !commonPartitionKeys.contains(slot)).forEach(
slot -> context.getRuntimeFilterContext().getAliasTransferMap().remove(slot)
slot -> context.getRuntimeFilterContext().aliasTransferMapRemove(slot)
);
return window;
}
@ -662,8 +659,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
*/
public static boolean checkPushDownPreconditionsForJoin(AbstractPhysicalJoin physicalJoin,
RuntimeFilterContext ctx, Slot slot) {
Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap = ctx.getAliasTransferMap();
if (slot == null || !aliasTransferMap.containsKey(slot)) {
if (slot == null || !ctx.aliasTransferMapContains(slot)) {
return false;
} else if (DENIED_JOIN_TYPES.contains(physicalJoin.getJoinType()) || physicalJoin.isMarkJoin()) {
return false;
@ -695,12 +691,12 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
}
private Map<Slot, PhysicalRelation> getPushDownBasicTablesInfos(PhysicalPlan root, SlotReference slot,
Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap) {
RuntimeFilterContext ctx) {
Map<Slot, PhysicalRelation> basicTableInfos = new HashMap<>();
Set<PhysicalHashJoin> joins = new HashSet<>();
ExprId exprId = slot.getExprId();
if (aliasTransferMap.get(slot) != null) {
basicTableInfos.put(slot, aliasTransferMap.get(slot).first);
if (ctx.getAliasTransferPair(slot) != null) {
basicTableInfos.put(slot, ctx.getAliasTransferPair(slot).first);
}
// try to find propagation condition from join
getAllJoinInfo(root, joins);
@ -710,13 +706,13 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
if (equalTo instanceof EqualTo) {
SlotReference leftSlot = (SlotReference) ((EqualTo) equalTo).left();
SlotReference rightSlot = (SlotReference) ((EqualTo) equalTo).right();
if (leftSlot.getExprId() == exprId && aliasTransferMap.get(rightSlot) != null) {
PhysicalRelation rightTable = aliasTransferMap.get(rightSlot).first;
if (leftSlot.getExprId() == exprId && ctx.getAliasTransferPair(rightSlot) != null) {
PhysicalRelation rightTable = ctx.getAliasTransferPair(rightSlot).first;
if (rightTable != null) {
basicTableInfos.put(rightSlot, rightTable);
}
} else if (rightSlot.getExprId() == exprId && aliasTransferMap.get(leftSlot) != null) {
PhysicalRelation leftTable = aliasTransferMap.get(leftSlot).first;
} else if (rightSlot.getExprId() == exprId && ctx.getAliasTransferPair(leftSlot) != null) {
PhysicalRelation leftTable = ctx.getAliasTransferPair(leftSlot).first;
if (leftTable != null) {
basicTableInfos.put(leftSlot, leftTable);
}

View File

@ -114,9 +114,13 @@ public class JoinCommute extends OneExplorationRuleFactory {
}
private static boolean containJoin(GroupPlan groupPlan) {
// TODO: tmp way to judge containJoin
List<Slot> output = groupPlan.getOutput();
return !output.stream().map(Slot::getQualifier).allMatch(output.get(0).getQualifier()::equals);
if (groupPlan.getGroup().getStatistics() != null) {
return groupPlan.getGroup().getStatistics().getWidthInJoinCluster() > 1;
} else {
// tmp way to judge containJoin, just used for test case where stats is null
List<Slot> output = groupPlan.getOutput();
return !output.stream().map(Slot::getQualifier).allMatch(output.get(0).getQualifier()::equals);
}
}
/**

View File

@ -332,7 +332,7 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
}
}
compareExprStatsBuilder.setNumNulls(0);
Statistics estimated = new Statistics(context.statistics);
Statistics estimated = new StatisticsBuilder(context.statistics).build();
ColumnStatistic stats = compareExprStatsBuilder.build();
selectivity = getNotNullSelectivity(stats, selectivity);
estimated = estimated.withSel(selectivity);

View File

@ -358,8 +358,12 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
@Override
public Statistics visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) {
return JoinEstimation.estimate(groupExpression.childStatistics(0),
Statistics joinStats = JoinEstimation.estimate(groupExpression.childStatistics(0),
groupExpression.childStatistics(1), join);
joinStats = new StatisticsBuilder(joinStats).setWidthInJoinCluster(
groupExpression.childStatistics(0).getWidthInJoinCluster()
+ groupExpression.childStatistics(1).getWidthInJoinCluster()).build();
return joinStats;
}
@Override
@ -555,6 +559,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
/**
 * Derives statistics for an assert-num-rows operator from its single child.
 *
 * <p>The row count is capped at 1 and the result is marked as a join-cluster leaf
 * (width 1). {@code Statistics} has a final row count and
 * {@code withRowCountAndEnforceValid} returns a new instance, so its result must be
 * captured; previously it was discarded and the child's row count leaked through unchanged.
 *
 * @param desiredNumOfRows currently unused; the cap is the constant 1 as before
 * @return a new statistics object with row count {@code min(1, childRowCount)} and width 1
 */
private Statistics computeAssertNumRows(long desiredNumOfRows) {
    Statistics childStats = groupExpression.childStatistics(0);
    // Capture the returned copy: the receiver itself is not modified by this call.
    Statistics cappedStats = childStats.withRowCountAndEnforceValid(Math.min(1, childStats.getRowCount()));
    return new StatisticsBuilder(cappedStats).setWidthInJoinCluster(1).build();
}
@ -764,7 +769,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
builder.setDataSize(rowCount * outputExpression.getDataType().width());
slotToColumnStats.put(outputExpression.toSlot(), columnStat);
}
return new Statistics(rowCount, slotToColumnStats);
return new Statistics(rowCount, 1, slotToColumnStats);
// TODO: Update ColumnStats properly, add new mapping from output slot to ColumnStats
}
@ -783,7 +788,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
.setDataSize(stats.dataSize < 0 ? stats.dataSize : stats.dataSize * groupingSetNum);
return Pair.of(kv.getKey(), columnStatisticBuilder.build());
}).collect(Collectors.toMap(Pair::key, Pair::value));
return new Statistics(rowCount < 0 ? rowCount : rowCount * groupingSetNum, columnStatisticMap);
return new Statistics(rowCount < 0 ? rowCount : rowCount * groupingSetNum, 1, columnStatisticMap);
}
private Statistics computeProject(Project project) {
@ -793,7 +798,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
ColumnStatistic columnStatistic = ExpressionEstimation.estimate(projection, childStats);
return new SimpleEntry<>(projection.toSlot(), columnStatistic);
}).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (item1, item2) -> item1));
return new Statistics(childStats.getRowCount(), columnsStats);
return new Statistics(childStats.getRowCount(), childStats.getWidthInJoinCluster(), columnsStats);
}
private Statistics computeOneRowRelation(List<NamedExpression> projects) {
@ -805,7 +810,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
})
.collect(Collectors.toMap(Pair::key, Pair::value));
int rowCount = 1;
return new Statistics(rowCount, columnStatsMap);
return new Statistics(rowCount, 1, columnStatsMap);
}
private Statistics computeEmptyRelation(EmptyRelation emptyRelation) {
@ -820,7 +825,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
})
.collect(Collectors.toMap(Pair::key, Pair::value));
int rowCount = 0;
return new Statistics(rowCount, columnStatsMap);
return new Statistics(rowCount, 1, columnStatsMap);
}
private Statistics computeUnion(Union union) {
@ -863,7 +868,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
statisticsBuilder.setRowCount(leftRowCount);
statisticsBuilder.putColumnStatistics(unionOutput.get(i), headStats.findColumnStatistics(headSlot));
}
return statisticsBuilder.build();
return statisticsBuilder.setWidthInJoinCluster(1).build();
}
private Statistics computeExcept(SetOperation setOperation) {
@ -876,7 +881,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
statisticsBuilder.putColumnStatistics(operatorOutput.get(i), columnStatistic);
}
statisticsBuilder.setRowCount(leftStats.getRowCount());
return statisticsBuilder.build();
return statisticsBuilder.setWidthInJoinCluster(1).build();
}
private Statistics computeIntersect(SetOperation setOperation) {
@ -903,7 +908,8 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
leftChildStats.addColumnStats(outputs.get(i),
leftChildStats.findColumnStatistics(leftChildOutputs.get(i)));
}
return leftChildStats.withRowCountAndEnforceValid(rowCount);
return new StatisticsBuilder(leftChildStats.withRowCountAndEnforceValid(rowCount))
.setWidthInJoinCluster(1).build();
}
private Statistics computeGenerate(Generate generate) {
@ -925,7 +931,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
.build();
columnStatsMap.put(output, columnStatistic);
}
return new Statistics(count, columnStatsMap);
return new Statistics(count, 1, columnStatsMap);
}
private Statistics computeWindow(Window windowOperator) {
@ -994,7 +1000,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
return Pair.of(expr.toSlot(), colStatsBuilder.build());
}).collect(Collectors.toMap(Pair::key, Pair::value));
columnStatisticMap.putAll(childColumnStats);
return new Statistics(childStats.getRowCount(), columnStatisticMap);
return new Statistics(childStats.getRowCount(), 1, columnStatisticMap);
}
private ColumnStatistic unionColumn(ColumnStatistic leftStats, double leftRowCount, ColumnStatistic rightStats,
@ -1033,7 +1039,8 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
@Override
public Statistics visitLogicalCTEProducer(LogicalCTEProducer<? extends Plan> cteProducer, Void context) {
Statistics statistics = groupExpression.childStatistics(0);
StatisticsBuilder builder = new StatisticsBuilder(groupExpression.childStatistics(0));
Statistics statistics = builder.setWidthInJoinCluster(1).build();
cteIdToStats.put(cteProducer.getCteId(), statistics);
return statistics;
}
@ -1045,7 +1052,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
cteConsumer.getProducerToConsumerOutputMap());
Statistics prodStats = cteIdToStats.get(cteId);
Preconditions.checkArgument(prodStats != null, String.format("Stats for CTE: %s not found", cteId));
Statistics consumerStats = new Statistics(prodStats.getRowCount(), new HashMap<>());
Statistics consumerStats = new Statistics(prodStats.getRowCount(), 1, new HashMap<>());
for (Slot slot : cteConsumer.getOutput()) {
Slot prodSlot = cteConsumer.getProducerSlot(slot);
ColumnStatistic colStats = prodStats.columnStatistics().get(prodSlot);
@ -1065,7 +1072,8 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
@Override
public Statistics visitPhysicalCTEProducer(PhysicalCTEProducer<? extends Plan> cteProducer,
Void context) {
Statistics statistics = groupExpression.childStatistics(0);
Statistics statistics = new StatisticsBuilder(groupExpression.childStatistics(0))
.setWidthInJoinCluster(1).build();
cteIdToStats.put(cteProducer.getCteId(), statistics);
cascadesContext.updateConsumerStats(cteProducer.getCteId(), statistics);
return statistics;
@ -1081,7 +1089,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
prodStats = groupExpression.getOwnerGroup().getStatistics();
}
Preconditions.checkArgument(prodStats != null, String.format("Stats for CTE: %s not found", cteId));
Statistics consumerStats = new Statistics(prodStats.getRowCount(), new HashMap<>());
Statistics consumerStats = new Statistics(prodStats.getRowCount(), 1, new HashMap<>());
for (Slot slot : cteConsumer.getOutput()) {
Slot prodSlot = cteConsumer.getProducerSlot(slot);
ColumnStatistic colStats = prodStats.columnStatistics().get(prodSlot);

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.plans.physical;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
@ -43,7 +42,6 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nullable;
@ -85,7 +83,6 @@ public abstract class AbstractPhysicalPlan extends AbstractPlan implements Physi
Expression src, Expression probeExpr,
TRuntimeFilterType type, long buildSideNdv, int exprOrder) {
RuntimeFilterContext ctx = context.getRuntimeFilterContext();
Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap = ctx.getAliasTransferMap();
// Currently we can rely on the two children corresponding to the two sides of the equal_to,
// so the right side may be an arbitrary expression while the left side is a slot.
Slot probeSlot = RuntimeFilterGenerator.checkTargetChild(probeExpr);
@ -106,8 +103,8 @@ public abstract class AbstractPhysicalPlan extends AbstractPlan implements Physi
return true;
}
Slot scanSlot = aliasTransferMap.get(probeSlot).second;
PhysicalRelation scan = aliasTransferMap.get(probeSlot).first;
Slot scanSlot = ctx.getAliasTransferPair(probeSlot).second;
PhysicalRelation scan = ctx.getAliasTransferPair(probeSlot).first;
if (!RuntimeFilterGenerator.checkPushDownPreconditionsForRelation(this, scan)) {
return false;
}
@ -127,14 +124,14 @@ public abstract class AbstractPhysicalPlan extends AbstractPlan implements Physi
filter.addTargetExpression(scanSlot);
ctx.addJoinToTargetMap(builderNode, scanSlot.getExprId());
ctx.setTargetExprIdToFilter(scanSlot.getExprId(), filter);
ctx.setTargetsOnScanNode(aliasTransferMap.get(probeExpr).first, scanSlot);
ctx.setTargetsOnScanNode(ctx.getAliasTransferPair((NamedExpression) probeExpr).first, scanSlot);
} else {
filter = new RuntimeFilter(generator.getNextId(),
src, ImmutableList.of(scanSlot), type, exprOrder, builderNode, buildSideNdv);
this.addAppliedRuntimeFilter(filter);
ctx.addJoinToTargetMap(builderNode, scanSlot.getExprId());
ctx.setTargetExprIdToFilter(scanSlot.getExprId(), filter);
ctx.setTargetsOnScanNode(aliasTransferMap.get(probeExpr).first, scanSlot);
ctx.setTargetsOnScanNode(ctx.getAliasTransferPair((NamedExpression) probeExpr).first, scanSlot);
ctx.setRuntimeFilterIdentityToFilter(src, type, builderNode, filter);
}
return true;

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.plans.physical;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
@ -27,7 +26,6 @@ import org.apache.doris.nereids.properties.DistributionSpec;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
@ -42,7 +40,6 @@ import com.google.common.collect.ImmutableList;
import org.json.JSONObject;
import java.util.List;
import java.util.Map;
import java.util.Optional;
/**
@ -134,7 +131,6 @@ public class PhysicalDistribute<CHILD_TYPE extends Plan> extends PhysicalUnary<C
AbstractPhysicalJoin<?, ?> builderNode, Expression src, Expression probeExpr,
TRuntimeFilterType type, long buildSideNdv, int exprOrder) {
RuntimeFilterContext ctx = context.getRuntimeFilterContext();
Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap = ctx.getAliasTransferMap();
// Currently we can rely on the two children corresponding to the two sides of the equal_to,
// so the right side may be an arbitrary expression while the left side is a slot.
Slot probeSlot = RuntimeFilterGenerator.checkTargetChild(probeExpr);
@ -144,7 +140,7 @@ public class PhysicalDistribute<CHILD_TYPE extends Plan> extends PhysicalUnary<C
if (!RuntimeFilterGenerator.checkPushDownPreconditionsForJoin(builderNode, ctx, probeSlot)) {
return false;
}
PhysicalRelation scan = aliasTransferMap.get(probeSlot).first;
PhysicalRelation scan = ctx.getAliasTransferPair(probeSlot).first;
if (!RuntimeFilterGenerator.checkPushDownPreconditionsForRelation(this, scan)) {
return false;
}

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.plans.physical;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
@ -46,7 +45,6 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
@ -298,7 +296,6 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan> extends PhysicalUnar
AbstractPhysicalJoin<?, ?> builderNode, Expression src, Expression probeExpr,
TRuntimeFilterType type, long buildSideNdv, int exprOrder) {
RuntimeFilterContext ctx = context.getRuntimeFilterContext();
Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap = ctx.getAliasTransferMap();
// Currently we can rely on the two children corresponding to the two sides of the equal_to,
// so the right side may be an arbitrary expression while the left side is a slot.
Slot probeSlot = RuntimeFilterGenerator.checkTargetChild(probeExpr);
@ -308,7 +305,7 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan> extends PhysicalUnar
if (!RuntimeFilterGenerator.checkPushDownPreconditionsForJoin(builderNode, ctx, probeSlot)) {
return false;
}
PhysicalRelation scan = aliasTransferMap.get(probeSlot).first;
PhysicalRelation scan = ctx.getAliasTransferPair(probeSlot).first;
if (!RuntimeFilterGenerator.checkPushDownPreconditionsForRelation(this, scan)) {
return false;
}

View File

@ -30,7 +30,6 @@ import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
@ -47,7 +46,6 @@ import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@ -188,7 +186,6 @@ public class PhysicalHashJoin<
}
}
RuntimeFilterContext ctx = context.getRuntimeFilterContext();
Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap = ctx.getAliasTransferMap();
// if rf built between plan nodes containing cte both, for example both src slot and target slot are from cte,
// or two sub-queries both containing cte, disable this rf since this kind of cross-cte rf will make one side
@ -239,7 +236,7 @@ public class PhysicalHashJoin<
if (!RuntimeFilterGenerator.checkPushDownPreconditionsForJoin(builderNode, ctx, probeSlot)) {
return false;
}
PhysicalRelation scan = aliasTransferMap.get(probeSlot).first;
PhysicalRelation scan = ctx.getAliasTransferPair(probeSlot).first;
if (!RuntimeFilterGenerator.checkPushDownPreconditionsForRelation(this, scan)) {
return false;
}

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.plans.physical;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
@ -42,7 +41,6 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
@ -162,7 +160,6 @@ public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHIL
AbstractPhysicalJoin<?, ?> builderNode, Expression src, Expression probeExpr,
TRuntimeFilterType type, long buildSideNdv, int exprOrder) {
RuntimeFilterContext ctx = context.getRuntimeFilterContext();
Map<NamedExpression, Pair<PhysicalRelation, Slot>> aliasTransferMap = ctx.getAliasTransferMap();
// Currently we can rely on the two children corresponding to the two sides of the equal_to,
// so the right side may be an arbitrary expression while the left side is a slot.
Slot probeSlot = RuntimeFilterGenerator.checkTargetChild(probeExpr);
@ -172,7 +169,7 @@ public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHIL
if (!RuntimeFilterGenerator.checkPushDownPreconditionsForJoin(builderNode, ctx, probeSlot)) {
return false;
}
PhysicalRelation scan = aliasTransferMap.get(probeSlot).first;
PhysicalRelation scan = ctx.getAliasTransferPair(probeSlot).first;
Preconditions.checkState(scan != null, "scan is null");
if (scan instanceof PhysicalCTEConsumer) {
// update the probeExpr
@ -196,7 +193,7 @@ public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHIL
if (!RuntimeFilterGenerator.checkPushDownPreconditionsForJoin(builderNode, ctx, newProbeSlot)) {
return false;
}
scan = aliasTransferMap.get(newProbeSlot).first;
scan = ctx.getAliasTransferPair(newProbeSlot).first;
probeExpr = newProbeExpr;
}
if (!RuntimeFilterGenerator.checkPushDownPreconditionsForRelation(this, scan)) {

View File

@ -33,18 +33,19 @@ public class Statistics {
private final double rowCount;
private final Map<Expression, ColumnStatistic> expressionToColumnStats;
private final int widthInJoinCluster;
// the byte size of one tuple
private double tupleSize;
public Statistics(Statistics another) {
this.rowCount = another.rowCount;
this.expressionToColumnStats = new HashMap<>(another.expressionToColumnStats);
this.tupleSize = another.tupleSize;
/**
 * Creates statistics with a {@code widthInJoinCluster} of 1 by delegating to the
 * three-argument constructor.
 *
 * @param rowCount estimated number of rows
 * @param expressionToColumnStats per-expression column statistics, passed through as-is
 */
public Statistics(double rowCount, Map<Expression, ColumnStatistic> expressionToColumnStats) {
this(rowCount, 1, expressionToColumnStats);
}
public Statistics(double rowCount, Map<Expression, ColumnStatistic> expressionToColumnStats) {
/**
 * Creates statistics from explicit values.
 *
 * @param rowCount estimated number of rows
 * @param widthInJoinCluster value later reported by {@code getWidthInJoinCluster()}
 * @param expressionToColumnStats per-expression column statistics; the map is stored by
 *        reference (no defensive copy), so it is shared with the caller
 */
public Statistics(double rowCount, int widthInJoinCluster,
Map<Expression, ColumnStatistic> expressionToColumnStats) {
this.rowCount = rowCount;
this.widthInJoinCluster = widthInJoinCluster;
this.expressionToColumnStats = expressionToColumnStats;
}
@ -61,14 +62,14 @@ public class Statistics {
}
public Statistics withRowCount(double rowCount) {
return new Statistics(rowCount, new HashMap<>(expressionToColumnStats));
return new Statistics(rowCount, widthInJoinCluster, new HashMap<>(expressionToColumnStats));
}
/**
* Update by count.
*/
public Statistics withRowCountAndEnforceValid(double rowCount) {
Statistics statistics = new Statistics(rowCount, expressionToColumnStats);
Statistics statistics = new Statistics(rowCount, widthInJoinCluster, expressionToColumnStats);
statistics.enforceValid();
return statistics;
}
@ -99,7 +100,7 @@ public class Statistics {
return this;
}
double newCount = rowCount * sel;
return new Statistics(newCount, new HashMap<>(expressionToColumnStats));
return new Statistics(newCount, widthInJoinCluster, new HashMap<>(expressionToColumnStats));
}
public Statistics addColumnStats(Expression expression, ColumnStatistic columnStatistic) {
@ -146,7 +147,7 @@ public class Statistics {
return "-Infinite";
}
DecimalFormat format = new DecimalFormat("#,###.##");
return format.format(rowCount);
return format.format(rowCount) + " " + widthInJoinCluster;
}
public int getBENumber() {
@ -181,10 +182,14 @@ public class Statistics {
StringBuilder builder = new StringBuilder();
builder.append(prefix).append("rows=").append(rowCount).append("\n");
builder.append(prefix).append("tupleSize=").append(computeTupleSize()).append("\n");
builder.append(prefix).append("width=").append(widthInJoinCluster).append("\n");
for (Entry<Expression, ColumnStatistic> entry : expressionToColumnStats.entrySet()) {
builder.append(prefix).append(entry.getKey()).append(" -> ").append(entry.getValue()).append("\n");
}
return builder.toString();
}
/**
 * Returns the {@code widthInJoinCluster} this object was constructed with.
 *
 * <p>NOTE(review): judging from the callers visible in this change, join statistics carry
 * the sum of their two children's widths while other operators use 1 (or pass the child's
 * width through), and a value greater than 1 is treated as "contains a join" — confirm
 * against StatsCalculator and JoinCommute.
 */
public int getWidthInJoinCluster() {
return widthInJoinCluster;
}
}

View File

@ -25,7 +25,7 @@ import java.util.Map;
public class StatisticsBuilder {
private double rowCount;
private int widthInJoinCluster;
private final Map<Expression, ColumnStatistic> expressionToColumnStats;
public StatisticsBuilder() {
@ -34,6 +34,7 @@ public class StatisticsBuilder {
/**
 * Seeds a builder from an existing {@code Statistics}: row count and join-cluster width
 * are taken over, and the column statistics are copied into a fresh map so later
 * changes to the builder do not touch the source object.
 *
 * @param statistics the statistics to copy from
 */
public StatisticsBuilder(Statistics statistics) {
    this.expressionToColumnStats = new HashMap<>(statistics.columnStatistics());
    this.widthInJoinCluster = statistics.getWidthInJoinCluster();
    this.rowCount = statistics.getRowCount();
}
@ -43,6 +44,11 @@ public class StatisticsBuilder {
return this;
}
/**
 * Sets the {@code widthInJoinCluster} that {@code build()} will pass to the resulting
 * {@code Statistics}.
 *
 * @param widthInJoinCluster the width value to record
 * @return this builder, for call chaining
 */
public StatisticsBuilder setWidthInJoinCluster(int widthInJoinCluster) {
this.widthInJoinCluster = widthInJoinCluster;
return this;
}
public StatisticsBuilder putColumnStatistics(
Map<Expression, ColumnStatistic> expressionToColumnStats) {
this.expressionToColumnStats.putAll(expressionToColumnStats);
@ -55,6 +61,6 @@ public class StatisticsBuilder {
}
public Statistics build() {
return new Statistics(rowCount, expressionToColumnStats);
return new Statistics(rowCount, widthInJoinCluster, expressionToColumnStats);
}
}