[fix](Nereids) column pruning should prune map in cte consumer (#34079)

we save bi-map in cte consumer to get the maping between producer and consumer.
the consumer's output is decided by the map in it.
so, cte consumer should be output prunable, and should remove useless entry from map when do column pruning
This commit is contained in:
morrySnow
2024-04-26 12:37:08 +08:00
committed by yiguolei
parent b41a5339d3
commit b24ff9953d
73 changed files with 514 additions and 499 deletions

View File

@ -50,7 +50,6 @@ import org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver;
import org.apache.doris.nereids.rules.exploration.mv.MaterializationContext;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.plans.Plan;
@ -102,7 +101,7 @@ public class CascadesContext implements ScheduleContext {
private Optional<RootRewriteJobContext> currentRootRewriteJobContext;
// in optimize stage, the plan will storage in the memo
private Memo memo;
private StatementContext statementContext;
private final StatementContext statementContext;
private final CTEContext cteContext;
private final RuleSet ruleSet;
@ -616,16 +615,6 @@ public class CascadesContext implements ScheduleContext {
consumers.add(cteConsumer);
}
public void putCTEIdToProject(CTEId cteId, NamedExpression p) {
Set<NamedExpression> projects = this.statementContext.getCteIdToProjects()
.computeIfAbsent(cteId, k -> new HashSet<>());
projects.add(p);
}
public Set<NamedExpression> getProjectForProducer(CTEId cteId) {
return this.statementContext.getCteIdToProjects().get(cteId);
}
public Map<CTEId, Set<LogicalCTEConsumer>> getCteIdToConsumers() {
return this.statementContext.getCteIdToConsumers();
}
@ -639,17 +628,6 @@ public class CascadesContext implements ScheduleContext {
return this.statementContext.getConsumerIdToFilters();
}
public void markConsumerUnderProject(LogicalCTEConsumer cteConsumer) {
Set<RelationId> consumerIds = this.statementContext.getCteIdToConsumerUnderProjects()
.computeIfAbsent(cteConsumer.getCteId(), k -> new HashSet<>());
consumerIds.add(cteConsumer.getRelationId());
}
public boolean couldPruneColumnOnProducer(CTEId cteId) {
Set<RelationId> consumerIds = this.statementContext.getCteIdToConsumerUnderProjects().get(cteId);
return consumerIds.size() == this.statementContext.getCteIdToConsumers().get(cteId).size();
}
public void addCTEConsumerGroup(CTEId cteId, Group g, Map<Slot, Slot> producerSlotToConsumerSlot) {
List<Pair<Map<Slot, Slot>, Group>> consumerGroups =
this.statementContext.getCteIdToConsumerGroup().computeIfAbsent(cteId, k -> new ArrayList<>());
@ -746,7 +724,7 @@ public class CascadesContext implements ScheduleContext {
public static void printPlanProcess(List<PlanProcess> planProcesses) {
for (PlanProcess row : planProcesses) {
LOG.info("RULE: " + row.ruleName + "\nBEFORE:\n" + row.beforeShape + "\nafter:\n" + row.afterShape);
LOG.info("RULE: {}\nBEFORE:\n{}\nafter:\n{}", row.ruleName, row.beforeShape, row.afterShape);
}
}

View File

@ -27,7 +27,6 @@ import org.apache.doris.nereids.rules.analysis.ColumnAliasGenerator;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.ObjectId;
@ -54,7 +53,6 @@ import java.io.Closeable;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
@ -103,9 +101,8 @@ public class StatementContext implements Closeable {
private final IdGenerator<CTEId> cteIdGenerator = CTEId.createGenerator();
private final Map<CTEId, Set<LogicalCTEConsumer>> cteIdToConsumers = new HashMap<>();
private final Map<CTEId, Set<NamedExpression>> cteIdToProjects = new HashMap<>();
private final Map<CTEId, Set<Slot>> cteIdToOutputIds = new HashMap<>();
private final Map<RelationId, Set<Expression>> consumerIdToFilters = new HashMap<>();
private final Map<CTEId, Set<RelationId>> cteIdToConsumerUnderProjects = new HashMap<>();
// Used to update consumer's stats
private final Map<CTEId, List<Pair<Map<Slot, Slot>, Group>>> cteIdToConsumerGroup = new HashMap<>();
private final Map<CTEId, LogicalPlan> rewrittenCteProducer = new HashMap<>();
@ -134,12 +131,13 @@ public class StatementContext implements Closeable {
private BitSet disableRules;
// table locks
private Stack<CloseableResource> plannerResources = new Stack<>();
private final Stack<CloseableResource> plannerResources = new Stack<>();
// for create view support in nereids
// key is the start and end position of the sql substring that needs to be replaced,
// and value is the new string used for replacement.
private TreeMap<Pair<Integer, Integer>, String> indexInSqlToString = new TreeMap<>(new Pair.PairComparator<>());
private final TreeMap<Pair<Integer, Integer>, String> indexInSqlToString
= new TreeMap<>(new Pair.PairComparator<>());
public StatementContext() {
this(ConnectContext.get(), null, 0);
@ -216,10 +214,6 @@ public class StatementContext implements Closeable {
return Optional.ofNullable(sqlCacheContext);
}
public int getMaxContinuousJoin() {
return joinCount;
}
public Set<SlotReference> getAllPathsSlots() {
Set<SlotReference> allSlotReferences = Sets.newHashSet();
for (Map<List<String>, SlotReference> slotReferenceMap : subColumnSlotRefMap.values()) {
@ -240,19 +234,16 @@ public class StatementContext implements Closeable {
* Add a slot ref attached with paths in context to avoid duplicated slot
*/
public void addPathSlotRef(Slot root, List<String> paths, SlotReference slotRef, Expression originalExpr) {
subColumnSlotRefMap.computeIfAbsent(root, k -> Maps.newTreeMap(new Comparator<List<String>>() {
@Override
public int compare(List<String> lst1, List<String> lst2) {
Iterator<String> it1 = lst1.iterator();
Iterator<String> it2 = lst2.iterator();
while (it1.hasNext() && it2.hasNext()) {
int result = it1.next().compareTo(it2.next());
if (result != 0) {
return result;
}
subColumnSlotRefMap.computeIfAbsent(root, k -> Maps.newTreeMap((lst1, lst2) -> {
Iterator<String> it1 = lst1.iterator();
Iterator<String> it2 = lst2.iterator();
while (it1.hasNext() && it2.hasNext()) {
int result = it1.next().compareTo(it2.next());
if (result != 0) {
return result;
}
return Integer.compare(lst1.size(), lst2.size());
}
return Integer.compare(lst1.size(), lst2.size());
}));
subColumnSlotRefMap.get(root).put(paths, slotRef);
subColumnOriginalExprMap.put(slotRef, originalExpr);
@ -349,18 +340,14 @@ public class StatementContext implements Closeable {
return cteIdToConsumers;
}
public Map<CTEId, Set<NamedExpression>> getCteIdToProjects() {
return cteIdToProjects;
public Map<CTEId, Set<Slot>> getCteIdToOutputIds() {
return cteIdToOutputIds;
}
public Map<RelationId, Set<Expression>> getConsumerIdToFilters() {
return consumerIdToFilters;
}
public Map<CTEId, Set<RelationId>> getCteIdToConsumerUnderProjects() {
return cteIdToConsumerUnderProjects;
}
public Map<CTEId, List<Pair<Map<Slot, Slot>, Group>>> getCteIdToConsumerGroup() {
return cteIdToConsumerGroup;
}

View File

@ -1119,9 +1119,13 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
// update expr to slot mapping
TupleDescriptor tupleDescriptor = null;
for (Slot producerSlot : cteProducer.getOutput()) {
Slot consumerSlot = cteConsumer.getProducerToConsumerSlotMap().get(producerSlot);
SlotRef slotRef = context.findSlotRef(producerSlot.getExprId());
tupleDescriptor = slotRef.getDesc().getParent();
Slot consumerSlot = cteConsumer.getProducerToConsumerSlotMap().get(producerSlot);
// consumerSlot could be null if we prune partial consumers' columns
if (consumerSlot == null) {
continue;
}
context.addExprIdSlotRefPair(consumerSlot.getExprId(), slotRef);
}
CTEScanNode cteScanNode = new CTEScanNode(tupleDescriptor);

View File

@ -43,8 +43,8 @@ import org.apache.doris.nereids.rules.rewrite.CheckDataTypes;
import org.apache.doris.nereids.rules.rewrite.CheckMatchExpression;
import org.apache.doris.nereids.rules.rewrite.CheckMultiDistinct;
import org.apache.doris.nereids.rules.rewrite.CheckPrivileges;
import org.apache.doris.nereids.rules.rewrite.CollectCteConsumerOutput;
import org.apache.doris.nereids.rules.rewrite.CollectFilterAboveConsumer;
import org.apache.doris.nereids.rules.rewrite.CollectProjectAboveConsumer;
import org.apache.doris.nereids.rules.rewrite.ColumnPruning;
import org.apache.doris.nereids.rules.rewrite.ConvertInnerOrCrossJoin;
import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite;
@ -417,7 +417,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
topic("Push project and filter on cte consumer to cte producer",
topDown(
new CollectFilterAboveConsumer(),
new CollectProjectAboveConsumer()
new CollectCteConsumerOutput()
)
)
);

View File

@ -306,8 +306,7 @@ public enum RuleType {
COLLECT_FILTER(RuleTypeClass.REWRITE),
COLLECT_JOIN_CONSTRAINT(RuleTypeClass.REWRITE),
COLLECT_PROJECT_ABOVE_CTE_CONSUMER(RuleTypeClass.REWRITE),
COLLECT_PROJECT_ABOVE_FILTER_CTE_CONSUMER(RuleTypeClass.REWRITE),
COLLECT_CTE_CONSUMER_OUTPUT(RuleTypeClass.REWRITE),
LEADING_JOIN(RuleTypeClass.REWRITE),
REWRITE_SENTINEL(RuleTypeClass.REWRITE),

View File

@ -0,0 +1,41 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import java.util.HashSet;
import java.util.Set;
/**
* Collect outputs of CTE Consumer.
*/
public class CollectCteConsumerOutput extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalCTEConsumer().thenApply(ctx -> {
Set<Slot> producerOutputs = ctx.statementContext
.getCteIdToOutputIds().computeIfAbsent(ctx.root.getCteId(), k -> new HashSet<>());
producerOutputs.addAll(ctx.root.getProducerToConsumerOutputMap().keySet());
return null;
}).toRule(RuleType.COLLECT_CTE_CONSUMER_OUTPUT);
}
}

View File

@ -1,81 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
/**
* Collect Projects Above CTE Consumer.
*/
public class CollectProjectAboveConsumer implements RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(RuleType.COLLECT_PROJECT_ABOVE_CTE_CONSUMER
.build(logicalProject(logicalCTEConsumer()).thenApply(ctx -> {
LogicalProject<LogicalCTEConsumer> project = ctx.root;
List<NamedExpression> namedExpressions = project.getProjects();
LogicalCTEConsumer cteConsumer = project.child();
collectProject(ctx.cascadesContext, namedExpressions, cteConsumer);
return ctx.root;
})),
RuleType.COLLECT_PROJECT_ABOVE_FILTER_CTE_CONSUMER
.build(logicalProject(logicalFilter(logicalCTEConsumer())).thenApply(ctx -> {
LogicalProject<LogicalFilter<LogicalCTEConsumer>> project = ctx.root;
LogicalFilter<LogicalCTEConsumer> filter = project.child();
Set<Slot> filterSlots = filter.getInputSlots();
List<NamedExpression> namedExpressions = new ArrayList<>(project.getProjects());
for (Slot slot : filterSlots) {
if (!project.getOutput().contains(slot)) {
namedExpressions.add(slot);
}
}
collectProject(ctx.cascadesContext, namedExpressions, filter.child());
return ctx.root;
}))
);
}
private static void collectProject(CascadesContext ctx,
List<NamedExpression> namedExpressions, LogicalCTEConsumer cteConsumer) {
for (Expression expr : namedExpressions) {
expr.foreach(node -> {
if (!(node instanceof Slot)) {
return;
}
Slot slot = cteConsumer.getProducerSlot((Slot) node);
ctx.putCTEIdToProject(cteConsumer.getCteId(), slot);
ctx.markConsumerUnderProject(cteConsumer);
});
}
}
}

View File

@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalExcept;
import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect;
@ -200,13 +201,21 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
return pruneAggregate(repeat, context);
}
private Plan pruneAggregate(Aggregate agg, PruneContext context) {
@Override
public Plan visitLogicalCTEProducer(LogicalCTEProducer<? extends Plan> cteProducer, PruneContext context) {
return skipPruneThisAndFirstLevelChildren(cteProducer);
}
@Override
public Plan visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, PruneContext context) {
return super.visitLogicalCTEConsumer(cteConsumer, context);
}
private Plan pruneAggregate(Aggregate<?> agg, PruneContext context) {
// first try to prune group by and aggregate functions
Aggregate prunedOutputAgg = pruneOutput(agg, agg.getOutputs(), agg::pruneOutputs, context);
Aggregate fillUpAggr = fillUpGroupByAndOutput(prunedOutputAgg);
return pruneChildren(fillUpAggr);
Aggregate<? extends Plan> prunedOutputAgg = pruneOutput(agg, agg.getOutputs(), agg::pruneOutputs, context);
Aggregate<?> fillUpAggregate = fillUpGroupByAndOutput(prunedOutputAgg);
return pruneChildren(fillUpAggregate);
}
private Plan skipPruneThisAndFirstLevelChildren(Plan plan) {
@ -217,7 +226,7 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
return pruneChildren(plan, requireAllOutputOfChildren.build());
}
private static Aggregate<Plan> fillUpGroupByAndOutput(Aggregate<Plan> prunedOutputAgg) {
private static Aggregate<? extends Plan> fillUpGroupByAndOutput(Aggregate<? extends Plan> prunedOutputAgg) {
List<Expression> groupBy = prunedOutputAgg.getGroupByExpressions();
List<NamedExpression> output = prunedOutputAgg.getOutputExpressions();
@ -239,12 +248,11 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
ImmutableList.Builder<Expression> newGroupByExprList
= ImmutableList.builderWithExpectedSize(newOutputList.size());
for (NamedExpression e : newOutputList) {
if (!(aggregateFunctions.contains(e)
|| (e instanceof Alias && aggregateFunctions.contains(e.child(0))))) {
if (!(e instanceof Alias && aggregateFunctions.contains(e.child(0)))) {
newGroupByExprList.add(e);
}
}
return ((LogicalAggregate<Plan>) prunedOutputAgg).withGroupByAndOutput(
return ((LogicalAggregate<? extends Plan>) prunedOutputAgg).withGroupByAndOutput(
newGroupByExprList.build(), newOutputList);
}
@ -371,11 +379,6 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
return prunedChild;
}
@Override
public Plan visitLogicalCTEProducer(LogicalCTEProducer<? extends Plan> cteProducer, PruneContext context) {
return skipPruneThisAndFirstLevelChildren(cteProducer);
}
/** PruneContext */
public static class PruneContext {
public Set<Slot> requiredSlots;

View File

@ -26,6 +26,7 @@ import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
@ -41,7 +42,6 @@ import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.apache.commons.collections.CollectionUtils;
import java.util.HashSet;
import java.util.List;
@ -109,10 +109,17 @@ public class RewriteCteChildren extends DefaultPlanRewriter<CascadesContext> imp
} else {
child = (LogicalPlan) cteProducer.child();
child = tryToConstructFilter(cascadesContext, cteProducer.getCteId(), child);
Set<NamedExpression> projects = cascadesContext.getProjectForProducer(cteProducer.getCteId());
if (CollectionUtils.isNotEmpty(projects)
&& cascadesContext.couldPruneColumnOnProducer(cteProducer.getCteId())) {
child = new LogicalProject<>(ImmutableList.copyOf(projects), child);
Set<Slot> producerOutputs = cascadesContext.getStatementContext()
.getCteIdToOutputIds().get(cteProducer.getCteId());
if (producerOutputs.size() < child.getOutput().size()) {
ImmutableList.Builder<NamedExpression> projectsBuilder
= ImmutableList.builderWithExpectedSize(producerOutputs.size());
for (Slot slot : child.getOutput()) {
if (producerOutputs.contains(slot)) {
projectsBuilder.add(slot);
}
}
child = new LogicalProject<>(projectsBuilder.build(), child);
child = pushPlanUnderAnchor(child);
}
CascadesContext rewrittenCtx = CascadesContext.newSubtreeContext(

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.plans.logical;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
@ -36,6 +37,7 @@ import com.google.common.collect.ImmutableList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
@ -43,7 +45,7 @@ import java.util.Optional;
* LogicalCTEConsumer
*/
//TODO: find cte producer and propagate its functional dependencies
public class LogicalCTEConsumer extends LogicalRelation implements BlockFuncDepsPropagation {
public class LogicalCTEConsumer extends LogicalRelation implements BlockFuncDepsPropagation, OutputPrunable {
private final String name;
private final CTEId cteId;
@ -145,6 +147,24 @@ public class LogicalCTEConsumer extends LogicalRelation implements BlockFuncDeps
return ImmutableList.copyOf(producerToConsumerOutputMap.values());
}
@Override
public Plan pruneOutputs(List<NamedExpression> prunedOutputs) {
Map<Slot, Slot> consumerToProducerOutputMap = new LinkedHashMap<>(this.consumerToProducerOutputMap.size());
Map<Slot, Slot> producerToConsumerOutputMap = new LinkedHashMap<>(this.consumerToProducerOutputMap.size());
for (Entry<Slot, Slot> consumerToProducerSlot : this.consumerToProducerOutputMap.entrySet()) {
if (prunedOutputs.contains(consumerToProducerSlot.getKey())) {
consumerToProducerOutputMap.put(consumerToProducerSlot.getKey(), consumerToProducerSlot.getValue());
producerToConsumerOutputMap.put(consumerToProducerSlot.getValue(), consumerToProducerSlot.getKey());
}
}
return withTwoMaps(consumerToProducerOutputMap, producerToConsumerOutputMap);
}
@Override
public List<NamedExpression> getOutputs() {
return (List) this.getOutput();
}
public CTEId getCteId() {
return cteId;
}