[opt](nereids) update CTEConsumer's stats when CTEProducer's stats updated (#21469)
This commit is contained in:
@ -20,6 +20,7 @@ package org.apache.doris.nereids;
|
||||
import org.apache.doris.catalog.Database;
|
||||
import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.catalog.Table;
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.analyzer.Scope;
|
||||
import org.apache.doris.nereids.analyzer.UnboundRelation;
|
||||
import org.apache.doris.nereids.jobs.Job;
|
||||
@ -34,6 +35,7 @@ import org.apache.doris.nereids.jobs.scheduler.JobScheduler;
|
||||
import org.apache.doris.nereids.jobs.scheduler.JobStack;
|
||||
import org.apache.doris.nereids.jobs.scheduler.ScheduleContext;
|
||||
import org.apache.doris.nereids.jobs.scheduler.SimpleJobScheduler;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.Memo;
|
||||
import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
|
||||
import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
@ -44,6 +46,7 @@ import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver;
|
||||
import org.apache.doris.nereids.trees.expressions.CTEId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalCTE;
|
||||
@ -54,6 +57,8 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.qe.SessionVariable;
|
||||
import org.apache.doris.statistics.ColumnStatistic;
|
||||
import org.apache.doris.statistics.Statistics;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -63,6 +68,7 @@ import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.Stack;
|
||||
@ -103,6 +109,9 @@ public class CascadesContext implements ScheduleContext {
|
||||
private Map<Integer, Set<Expression>> consumerIdToFilters = new HashMap<>();
|
||||
private Map<CTEId, Set<Integer>> cteIdToConsumerUnderProjects = new HashMap<>();
|
||||
|
||||
// Used to update consumer's stats
|
||||
private Map<CTEId, List<Pair<Map<Slot, Slot>, Group>>> cteIdToConsumerGroup = new HashMap<>();
|
||||
|
||||
public CascadesContext(Plan plan, Memo memo, StatementContext statementContext,
|
||||
PhysicalProperties requestProperties) {
|
||||
this(plan, memo, statementContext, new CTEContext(), requestProperties);
|
||||
@ -565,4 +574,25 @@ public class CascadesContext implements ScheduleContext {
|
||||
Set<Integer> consumerIds = this.cteIdToConsumerUnderProjects.get(cteId);
|
||||
return consumerIds.size() == this.cteIdToConsumers.get(cteId).size();
|
||||
}
|
||||
|
||||
public void addCTEConsumerGroup(CTEId cteId, Group g, Map<Slot, Slot> producerSlotToConsumerSlot) {
|
||||
List<Pair<Map<Slot, Slot>, Group>> consumerGroups =
|
||||
this.cteIdToConsumerGroup.computeIfAbsent(cteId, k -> new ArrayList<>());
|
||||
consumerGroups.add(Pair.of(producerSlotToConsumerSlot, g));
|
||||
}
|
||||
|
||||
/**
|
||||
* Update CTE consumer group as producer's stats update
|
||||
*/
|
||||
public void updateConsumerStats(CTEId cteId, Statistics statistics) {
|
||||
List<Pair<Map<Slot, Slot>, Group>> consumerGroups = this.cteIdToConsumerGroup.get(cteId);
|
||||
for (Pair<Map<Slot, Slot>, Group> p : consumerGroups) {
|
||||
Map<Slot, Slot> producerSlotToConsumerSlot = p.first;
|
||||
Statistics updatedConsumerStats = new Statistics(statistics);
|
||||
for (Entry<Expression, ColumnStatistic> entry : statistics.columnStatistics().entrySet()) {
|
||||
updatedConsumerStats.addColumnStats(producerSlotToConsumerSlot.get(entry.getKey()), entry.getValue());
|
||||
}
|
||||
p.value().setStatistics(updatedConsumerStats);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -245,7 +245,8 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
|
||||
StatsCalculator statsCalculator = StatsCalculator.estimate(groupExpression,
|
||||
context.getCascadesContext().getConnectContext().getSessionVariable().getForbidUnknownColStats(),
|
||||
context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap(),
|
||||
context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump());
|
||||
context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump(),
|
||||
context.getCascadesContext());
|
||||
if (!context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump()
|
||||
&& context.getCascadesContext().getConnectContext().getSessionVariable().isEnableMinidump()) {
|
||||
context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap()
|
||||
|
||||
@ -105,7 +105,8 @@ public class DeriveStatsJob extends Job {
|
||||
context.getCascadesContext().getConnectContext().getSessionVariable().getForbidUnknownColStats(),
|
||||
context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap(),
|
||||
context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump(),
|
||||
cteIdToStats);
|
||||
cteIdToStats,
|
||||
context.getCascadesContext());
|
||||
STATS_STATE_TRACER.log(StatsStateEvent.of(groupExpression,
|
||||
groupExpression.getOwnerGroup().getStatistics()));
|
||||
if (ConnectContext.get().getSessionVariable().isEnableMinidump()
|
||||
|
||||
@ -24,6 +24,7 @@ import org.apache.doris.catalog.SchemaTable;
|
||||
import org.apache.doris.catalog.TableIf;
|
||||
import org.apache.doris.common.Config;
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
@ -158,14 +159,17 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
|
||||
|
||||
private Map<CTEId, Statistics> cteIdToStats;
|
||||
|
||||
private CascadesContext cascadesContext;
|
||||
|
||||
private StatsCalculator(GroupExpression groupExpression, boolean forbidUnknownColStats,
|
||||
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump,
|
||||
Map<CTEId, Statistics> cteIdToStats) {
|
||||
Map<CTEId, Statistics> cteIdToStats, CascadesContext context) {
|
||||
this.groupExpression = groupExpression;
|
||||
this.forbidUnknownColStats = forbidUnknownColStats;
|
||||
this.totalColumnStatisticMap = columnStatisticMap;
|
||||
this.isPlayNereidsDump = isPlayNereidsDump;
|
||||
this.cteIdToStats = Objects.requireNonNull(cteIdToStats, "CTEIdToStats can't be null");
|
||||
this.cascadesContext = context;
|
||||
}
|
||||
|
||||
public Map<String, Histogram> getTotalHistogramMap() {
|
||||
@ -189,25 +193,26 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
|
||||
*/
|
||||
public static StatsCalculator estimate(GroupExpression groupExpression, boolean forbidUnknownColStats,
|
||||
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump,
|
||||
Map<CTEId, Statistics> cteIdToStats) {
|
||||
Map<CTEId, Statistics> cteIdToStats, CascadesContext context) {
|
||||
StatsCalculator statsCalculator = new StatsCalculator(
|
||||
groupExpression, forbidUnknownColStats, columnStatisticMap, isPlayNereidsDump, cteIdToStats);
|
||||
groupExpression, forbidUnknownColStats, columnStatisticMap, isPlayNereidsDump, cteIdToStats, context);
|
||||
statsCalculator.estimate();
|
||||
return statsCalculator;
|
||||
}
|
||||
|
||||
public static StatsCalculator estimate(GroupExpression groupExpression, boolean forbidUnknownColStats,
|
||||
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump) {
|
||||
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump, CascadesContext context) {
|
||||
return StatsCalculator.estimate(groupExpression,
|
||||
forbidUnknownColStats,
|
||||
columnStatisticMap,
|
||||
isPlayNereidsDump,
|
||||
new HashMap<>());
|
||||
new HashMap<>(), context);
|
||||
}
|
||||
|
||||
public static void estimate(GroupExpression groupExpression) {
|
||||
// For unit test only
|
||||
public static void estimate(GroupExpression groupExpression, CascadesContext context) {
|
||||
StatsCalculator statsCalculator = new StatsCalculator(groupExpression, false,
|
||||
new HashMap<>(), false, Collections.EMPTY_MAP);
|
||||
new HashMap<>(), false, Collections.EMPTY_MAP, context);
|
||||
statsCalculator.estimate();
|
||||
}
|
||||
|
||||
@ -999,6 +1004,8 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
|
||||
@Override
|
||||
public Statistics visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, Void context) {
|
||||
CTEId cteId = cteConsumer.getCteId();
|
||||
cascadesContext.addCTEConsumerGroup(cteConsumer.getCteId(), groupExpression.getOwnerGroup(),
|
||||
cteConsumer.getProducerToConsumerOutputMap());
|
||||
Statistics prodStats = cteIdToStats.get(cteId);
|
||||
Preconditions.checkArgument(prodStats != null, String.format("Stats for CTE: %s not found", cteId));
|
||||
Statistics consumerStats = new Statistics(prodStats.getRowCount(), new HashMap<>());
|
||||
@ -1023,11 +1030,14 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
|
||||
Void context) {
|
||||
Statistics statistics = groupExpression.childStatistics(0);
|
||||
cteIdToStats.put(cteProducer.getCteId(), statistics);
|
||||
cascadesContext.updateConsumerStats(cteProducer.getCteId(), statistics);
|
||||
return statistics;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Statistics visitPhysicalCTEConsumer(PhysicalCTEConsumer cteConsumer, Void context) {
|
||||
cascadesContext.addCTEConsumerGroup(cteConsumer.getCteId(), groupExpression.getOwnerGroup(),
|
||||
cteConsumer.getProducerToConsumerSlotMap());
|
||||
CTEId cteId = cteConsumer.getCteId();
|
||||
Statistics prodStats = cteIdToStats.get(cteId);
|
||||
if (prodStats == null) {
|
||||
|
||||
@ -144,14 +144,14 @@ public class StatsCalculatorTest {
|
||||
GroupExpression groupExpression = new GroupExpression(logicalFilter, ImmutableList.of(childGroup));
|
||||
Group ownerGroup = newGroup();
|
||||
groupExpression.setOwnerGroup(ownerGroup);
|
||||
StatsCalculator.estimate(groupExpression);
|
||||
StatsCalculator.estimate(groupExpression, null);
|
||||
Assertions.assertEquals((10000 * 0.1 * 0.05), ownerGroup.getStatistics().getRowCount(), 0.001);
|
||||
|
||||
LogicalFilter<GroupPlan> logicalFilterOr = new LogicalFilter<>(or, groupPlan);
|
||||
GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr, ImmutableList.of(childGroup));
|
||||
Group ownerGroupOr = newGroup();
|
||||
groupExpressionOr.setOwnerGroup(ownerGroupOr);
|
||||
StatsCalculator.estimate(groupExpressionOr);
|
||||
StatsCalculator.estimate(groupExpressionOr, null);
|
||||
Assertions.assertEquals((long) (10000 * (0.1 + 0.05 - 0.1 * 0.05)),
|
||||
ownerGroupOr.getStatistics().getRowCount(), 0.001);
|
||||
}
|
||||
@ -197,14 +197,14 @@ public class StatsCalculatorTest {
|
||||
GroupExpression groupExpression = new GroupExpression(logicalFilter, ImmutableList.of(childGroup));
|
||||
Group ownerGroup = newGroup();
|
||||
groupExpression.setOwnerGroup(ownerGroup);
|
||||
StatsCalculator.estimate(groupExpression);
|
||||
StatsCalculator.estimate(groupExpression, null);
|
||||
Assertions.assertEquals(0, ownerGroup.getStatistics().getRowCount(), 0.001);
|
||||
|
||||
LogicalFilter<GroupPlan> logicalFilterOr = new LogicalFilter<>(or, groupPlan);
|
||||
GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr, ImmutableList.of(childGroup));
|
||||
Group ownerGroupOr = newGroup();
|
||||
groupExpressionOr.setOwnerGroup(ownerGroupOr);
|
||||
StatsCalculator.estimate(groupExpressionOr);
|
||||
StatsCalculator.estimate(groupExpressionOr, null);
|
||||
Assertions.assertEquals(0, ownerGroupOr.getStatistics().getRowCount(), 0.001);
|
||||
}
|
||||
// TODO: temporarily disable this test until we can get column stats
|
||||
@ -259,7 +259,7 @@ public class StatsCalculatorTest {
|
||||
GroupExpression groupExpression = new GroupExpression(logicalOlapScan1, ImmutableList.of(childGroup));
|
||||
Group ownerGroup = newGroup();
|
||||
groupExpression.setOwnerGroup(ownerGroup);
|
||||
StatsCalculator.estimate(groupExpression);
|
||||
StatsCalculator.estimate(groupExpression, null);
|
||||
Statistics stats = ownerGroup.getStatistics();
|
||||
Assertions.assertEquals(1, stats.columnStatistics().size());
|
||||
Assertions.assertNotNull(stats.columnStatistics().get(slot1));
|
||||
@ -289,7 +289,7 @@ public class StatsCalculatorTest {
|
||||
GroupExpression groupExpression = new GroupExpression(logicalLimit, ImmutableList.of(childGroup));
|
||||
Group ownerGroup = newGroup();
|
||||
ownerGroup.addGroupExpression(groupExpression);
|
||||
StatsCalculator.estimate(groupExpression);
|
||||
StatsCalculator.estimate(groupExpression, null);
|
||||
Statistics limitStats = ownerGroup.getStatistics();
|
||||
Assertions.assertEquals(1, limitStats.getRowCount());
|
||||
ColumnStatistic slot1Stats = limitStats.columnStatistics().get(slot1);
|
||||
@ -319,7 +319,7 @@ public class StatsCalculatorTest {
|
||||
GroupExpression groupExpression = new GroupExpression(logicalTopN, ImmutableList.of(childGroup));
|
||||
Group ownerGroup = newGroup();
|
||||
ownerGroup.addGroupExpression(groupExpression);
|
||||
StatsCalculator.estimate(groupExpression);
|
||||
StatsCalculator.estimate(groupExpression, null);
|
||||
Statistics topNStats = ownerGroup.getStatistics();
|
||||
Assertions.assertEquals(1, topNStats.getRowCount());
|
||||
ColumnStatistic slot1Stats = topNStats.columnStatistics().get(slot1);
|
||||
|
||||
Reference in New Issue
Block a user