[opt](nereids) update CTEConsumer's stats when CTEProducer's stats are updated (#21469)

This commit is contained in:
AKIRA
2023-07-12 10:55:40 +08:00
committed by GitHub
parent 88c719233a
commit 56c2deadb1
5 changed files with 58 additions and 16 deletions

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids;
import org.apache.doris.catalog.Database;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Table;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.jobs.Job;
@ -34,6 +35,7 @@ import org.apache.doris.nereids.jobs.scheduler.JobScheduler;
import org.apache.doris.nereids.jobs.scheduler.JobStack;
import org.apache.doris.nereids.jobs.scheduler.ScheduleContext;
import org.apache.doris.nereids.jobs.scheduler.SimpleJobScheduler;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
import org.apache.doris.nereids.properties.PhysicalProperties;
@ -44,6 +46,7 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTE;
@ -54,6 +57,8 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -63,6 +68,7 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.Stack;
@ -103,6 +109,9 @@ public class CascadesContext implements ScheduleContext {
private Map<Integer, Set<Expression>> consumerIdToFilters = new HashMap<>();
private Map<CTEId, Set<Integer>> cteIdToConsumerUnderProjects = new HashMap<>();
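// For each CTE, the consumer groups registered so far, each paired with its producer-slot -> consumer-slot
// mapping. Used to refresh the consumers' stats when the producer's stats are updated.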
private Map<CTEId, List<Pair<Map<Slot, Slot>, Group>>> cteIdToConsumerGroup = new HashMap<>();
public CascadesContext(Plan plan, Memo memo, StatementContext statementContext,
PhysicalProperties requestProperties) {
this(plan, memo, statementContext, new CTEContext(), requestProperties);
@ -565,4 +574,25 @@ public class CascadesContext implements ScheduleContext {
Set<Integer> consumerIds = this.cteIdToConsumerUnderProjects.get(cteId);
return consumerIds.size() == this.cteIdToConsumers.get(cteId).size();
}
/**
 * Registers a CTE consumer's group under the given CTE id so its statistics can be
 * refreshed later by {@code updateConsumerStats}.
 *
 * @param cteId id of the CTE the consumer reads from
 * @param g group that owns the consumer
 * @param producerSlotToConsumerSlot mapping from the producer's slots to this consumer's slots
 */
public void addCTEConsumerGroup(CTEId cteId, Group g, Map<Slot, Slot> producerSlotToConsumerSlot) {
    List<Pair<Map<Slot, Slot>, Group>> registered = this.cteIdToConsumerGroup.get(cteId);
    if (registered == null) {
        // First consumer seen for this CTE: create its bucket.
        registered = new ArrayList<>();
        this.cteIdToConsumerGroup.put(cteId, registered);
    }
    Pair<Map<Slot, Slot>, Group> slotMapAndGroup = Pair.of(producerSlotToConsumerSlot, g);
    registered.add(slotMapAndGroup);
}
/**
 * Propagates a CTE producer's updated statistics to every consumer group registered
 * for the same CTE via {@code addCTEConsumerGroup}.
 *
 * <p>For each registered consumer, a new {@link Statistics} is built from the producer's
 * statistics, the producer's column statistics are re-added under the corresponding
 * consumer slots, and the result is set on the consumer's group.
 *
 * <p>Does nothing when no consumer has been registered for {@code cteId} yet.
 *
 * @param cteId id of the CTE whose producer statistics changed
 * @param statistics the producer's new statistics
 */
public void updateConsumerStats(CTEId cteId, Statistics statistics) {
    List<Pair<Map<Slot, Slot>, Group>> consumerGroups = this.cteIdToConsumerGroup.get(cteId);
    if (consumerGroups == null) {
        // No consumer registered for this CTE yet; iterating a null list would throw.
        return;
    }
    for (Pair<Map<Slot, Slot>, Group> p : consumerGroups) {
        Map<Slot, Slot> producerSlotToConsumerSlot = p.first;
        // NOTE(review): presumably the copy constructor carries over row count and column stats - confirm.
        Statistics updatedConsumerStats = new Statistics(statistics);
        for (Entry<Expression, ColumnStatistic> entry : statistics.columnStatistics().entrySet()) {
            Slot consumerSlot = producerSlotToConsumerSlot.get(entry.getKey());
            if (consumerSlot == null) {
                // Producer column has no counterpart in this consumer; do not register stats under a null key.
                continue;
            }
            updatedConsumerStats.addColumnStats(consumerSlot, entry.getValue());
        }
        p.value().setStatistics(updatedConsumerStats);
    }
}
}

View File

@ -245,7 +245,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
StatsCalculator statsCalculator = StatsCalculator.estimate(groupExpression,
context.getCascadesContext().getConnectContext().getSessionVariable().getForbidUnknownColStats(),
context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap(),
context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump());
context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump(),
context.getCascadesContext());
if (!context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump()
&& context.getCascadesContext().getConnectContext().getSessionVariable().isEnableMinidump()) {
context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap()

View File

@ -105,7 +105,8 @@ public class DeriveStatsJob extends Job {
context.getCascadesContext().getConnectContext().getSessionVariable().getForbidUnknownColStats(),
context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap(),
context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump(),
cteIdToStats);
cteIdToStats,
context.getCascadesContext());
STATS_STATE_TRACER.log(StatsStateEvent.of(groupExpression,
groupExpression.getOwnerGroup().getStatistics()));
if (ConnectContext.get().getSessionVariable().isEnableMinidump()

View File

@ -24,6 +24,7 @@ import org.apache.doris.catalog.SchemaTable;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.common.Config;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
@ -158,14 +159,17 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
private Map<CTEId, Statistics> cteIdToStats;
private CascadesContext cascadesContext;
private StatsCalculator(GroupExpression groupExpression, boolean forbidUnknownColStats,
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump,
Map<CTEId, Statistics> cteIdToStats) {
Map<CTEId, Statistics> cteIdToStats, CascadesContext context) {
this.groupExpression = groupExpression;
this.forbidUnknownColStats = forbidUnknownColStats;
this.totalColumnStatisticMap = columnStatisticMap;
this.isPlayNereidsDump = isPlayNereidsDump;
this.cteIdToStats = Objects.requireNonNull(cteIdToStats, "CTEIdToStats can't be null");
this.cascadesContext = context;
}
public Map<String, Histogram> getTotalHistogramMap() {
@ -189,25 +193,26 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
*/
public static StatsCalculator estimate(GroupExpression groupExpression, boolean forbidUnknownColStats,
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump,
Map<CTEId, Statistics> cteIdToStats) {
Map<CTEId, Statistics> cteIdToStats, CascadesContext context) {
StatsCalculator statsCalculator = new StatsCalculator(
groupExpression, forbidUnknownColStats, columnStatisticMap, isPlayNereidsDump, cteIdToStats);
groupExpression, forbidUnknownColStats, columnStatisticMap, isPlayNereidsDump, cteIdToStats, context);
statsCalculator.estimate();
return statsCalculator;
}
public static StatsCalculator estimate(GroupExpression groupExpression, boolean forbidUnknownColStats,
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump) {
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump, CascadesContext context) {
return StatsCalculator.estimate(groupExpression,
forbidUnknownColStats,
columnStatisticMap,
isPlayNereidsDump,
new HashMap<>());
new HashMap<>(), context);
}
public static void estimate(GroupExpression groupExpression) {
// For unit test only
public static void estimate(GroupExpression groupExpression, CascadesContext context) {
StatsCalculator statsCalculator = new StatsCalculator(groupExpression, false,
new HashMap<>(), false, Collections.EMPTY_MAP);
new HashMap<>(), false, Collections.EMPTY_MAP, context);
statsCalculator.estimate();
}
@ -999,6 +1004,8 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
@Override
public Statistics visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, Void context) {
CTEId cteId = cteConsumer.getCteId();
cascadesContext.addCTEConsumerGroup(cteConsumer.getCteId(), groupExpression.getOwnerGroup(),
cteConsumer.getProducerToConsumerOutputMap());
Statistics prodStats = cteIdToStats.get(cteId);
Preconditions.checkArgument(prodStats != null, String.format("Stats for CTE: %s not found", cteId));
Statistics consumerStats = new Statistics(prodStats.getRowCount(), new HashMap<>());
@ -1023,11 +1030,14 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
Void context) {
Statistics statistics = groupExpression.childStatistics(0);
cteIdToStats.put(cteProducer.getCteId(), statistics);
cascadesContext.updateConsumerStats(cteProducer.getCteId(), statistics);
return statistics;
}
@Override
public Statistics visitPhysicalCTEConsumer(PhysicalCTEConsumer cteConsumer, Void context) {
cascadesContext.addCTEConsumerGroup(cteConsumer.getCteId(), groupExpression.getOwnerGroup(),
cteConsumer.getProducerToConsumerSlotMap());
CTEId cteId = cteConsumer.getCteId();
Statistics prodStats = cteIdToStats.get(cteId);
if (prodStats == null) {

View File

@ -144,14 +144,14 @@ public class StatsCalculatorTest {
GroupExpression groupExpression = new GroupExpression(logicalFilter, ImmutableList.of(childGroup));
Group ownerGroup = newGroup();
groupExpression.setOwnerGroup(ownerGroup);
StatsCalculator.estimate(groupExpression);
StatsCalculator.estimate(groupExpression, null);
Assertions.assertEquals((10000 * 0.1 * 0.05), ownerGroup.getStatistics().getRowCount(), 0.001);
LogicalFilter<GroupPlan> logicalFilterOr = new LogicalFilter<>(or, groupPlan);
GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr, ImmutableList.of(childGroup));
Group ownerGroupOr = newGroup();
groupExpressionOr.setOwnerGroup(ownerGroupOr);
StatsCalculator.estimate(groupExpressionOr);
StatsCalculator.estimate(groupExpressionOr, null);
Assertions.assertEquals((long) (10000 * (0.1 + 0.05 - 0.1 * 0.05)),
ownerGroupOr.getStatistics().getRowCount(), 0.001);
}
@ -197,14 +197,14 @@ public class StatsCalculatorTest {
GroupExpression groupExpression = new GroupExpression(logicalFilter, ImmutableList.of(childGroup));
Group ownerGroup = newGroup();
groupExpression.setOwnerGroup(ownerGroup);
StatsCalculator.estimate(groupExpression);
StatsCalculator.estimate(groupExpression, null);
Assertions.assertEquals(0, ownerGroup.getStatistics().getRowCount(), 0.001);
LogicalFilter<GroupPlan> logicalFilterOr = new LogicalFilter<>(or, groupPlan);
GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr, ImmutableList.of(childGroup));
Group ownerGroupOr = newGroup();
groupExpressionOr.setOwnerGroup(ownerGroupOr);
StatsCalculator.estimate(groupExpressionOr);
StatsCalculator.estimate(groupExpressionOr, null);
Assertions.assertEquals(0, ownerGroupOr.getStatistics().getRowCount(), 0.001);
}
// TODO: temporary disable this test, until we could get column stats
@ -259,7 +259,7 @@ public class StatsCalculatorTest {
GroupExpression groupExpression = new GroupExpression(logicalOlapScan1, ImmutableList.of(childGroup));
Group ownerGroup = newGroup();
groupExpression.setOwnerGroup(ownerGroup);
StatsCalculator.estimate(groupExpression);
StatsCalculator.estimate(groupExpression, null);
Statistics stats = ownerGroup.getStatistics();
Assertions.assertEquals(1, stats.columnStatistics().size());
Assertions.assertNotNull(stats.columnStatistics().get(slot1));
@ -289,7 +289,7 @@ public class StatsCalculatorTest {
GroupExpression groupExpression = new GroupExpression(logicalLimit, ImmutableList.of(childGroup));
Group ownerGroup = newGroup();
ownerGroup.addGroupExpression(groupExpression);
StatsCalculator.estimate(groupExpression);
StatsCalculator.estimate(groupExpression, null);
Statistics limitStats = ownerGroup.getStatistics();
Assertions.assertEquals(1, limitStats.getRowCount());
ColumnStatistic slot1Stats = limitStats.columnStatistics().get(slot1);
@ -319,7 +319,7 @@ public class StatsCalculatorTest {
GroupExpression groupExpression = new GroupExpression(logicalTopN, ImmutableList.of(childGroup));
Group ownerGroup = newGroup();
ownerGroup.addGroupExpression(groupExpression);
StatsCalculator.estimate(groupExpression);
StatsCalculator.estimate(groupExpression, null);
Statistics topNStats = ownerGroup.getStatistics();
Assertions.assertEquals(1, topNStats.getRowCount());
ColumnStatistic slot1Stats = topNStats.columnStatistics().get(slot1);