[Feature](Nereids) Support grouping set for materialized index. (#15383)

This PR adds support for materialized index selecting when the query has grouping sets.
This commit is contained in:
Shuo Wang
2022-12-29 23:17:02 +08:00
committed by GitHub
parent dda505487c
commit 6c847daba0
4 changed files with 224 additions and 14 deletions

View File

@ -98,14 +98,14 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob {
.add(topDownBatch(ImmutableList.of(new EliminateFilter())))
.add(topDownBatch(ImmutableList.of(new PruneOlapScanPartition())))
.add(topDownBatch(ImmutableList.of(new CountDistinctRewrite())))
.add(topDownBatch(ImmutableList.of(new SelectMaterializedIndexWithAggregate())))
.add(topDownBatch(ImmutableList.of(new SelectMaterializedIndexWithoutAggregate())))
.add(topDownBatch(ImmutableList.of(new PruneOlapScanTablet())))
// we need to execute this rule at the end of rewrite
// to avoid two consecutive same project appear when we do optimization.
.add(topDownBatch(ImmutableList.of(new EliminateGroupByConstant())))
.add(topDownBatch(ImmutableList.of(new EliminateOrderByConstant())))
.add(topDownBatch(ImmutableList.of(new EliminateUnnecessaryProject())))
.add(topDownBatch(ImmutableList.of(new SelectMaterializedIndexWithAggregate())))
.add(topDownBatch(ImmutableList.of(new SelectMaterializedIndexWithoutAggregate())))
.add(topDownBatch(ImmutableList.of(new PruneOlapScanTablet())))
.add(topDownBatch(ImmutableList.of(new EliminateAggregate())))
.add(bottomUpBatch(ImmutableList.of(new MergeSetOperations())))
.add(topDownBatch(ImmutableList.of(new LimitPushDown())))

View File

@ -154,6 +154,11 @@ public enum RuleType {
MATERIALIZED_INDEX_AGG_PROJECT_SCAN(RuleTypeClass.REWRITE),
MATERIALIZED_INDEX_AGG_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE),
MATERIALIZED_INDEX_AGG_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE),
MATERIALIZED_INDEX_AGG_REPEAT_SCAN(RuleTypeClass.REWRITE),
MATERIALIZED_INDEX_AGG_REPEAT_FILTER_SCAN(RuleTypeClass.REWRITE),
MATERIALIZED_INDEX_AGG_REPEAT_PROJECT_SCAN(RuleTypeClass.REWRITE),
MATERIALIZED_INDEX_AGG_REPEAT_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE),
MATERIALIZED_INDEX_AGG_REPEAT_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE),
MATERIALIZED_INDEX_SCAN(RuleTypeClass.REWRITE),
MATERIALIZED_INDEX_FILTER_SCAN(RuleTypeClass.REWRITE),
MATERIALIZED_INDEX_PROJECT_SCAN(RuleTypeClass.REWRITE),

View File

@ -32,6 +32,7 @@ import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
@ -52,6 +53,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
@ -160,7 +162,8 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
ImmutableSet.of(),
extractAggFunctionAndReplaceSlot(agg,
Optional.of(project)),
agg.getGroupByExpressions()
ExpressionUtils.replace(agg.getGroupByExpressions(),
project.getAliasToProducer())
);
if (result.exprRewriteMap.isEmpty()) {
@ -262,7 +265,200 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
filter.withChildren(newProject)
);
}
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_FILTER_PROJECT_SCAN)
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_FILTER_PROJECT_SCAN),
// only agg above scan
// Aggregate(Repeat(Scan))
logicalAggregate(logicalRepeat(logicalOlapScan().when(this::shouldSelectIndex))).then(agg -> {
LogicalRepeat<LogicalOlapScan> repeat = agg.child();
LogicalOlapScan scan = repeat.child();
SelectResult result = select(
scan,
agg.getInputSlots(),
ImmutableSet.of(),
extractAggFunctionAndReplaceSlot(agg, Optional.empty()),
nonVirtualGroupByExprs(agg));
if (result.exprRewriteMap.isEmpty()) {
return agg.withChildren(
repeat.withChildren(
scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId))
);
} else {
return new LogicalAggregate<>(
agg.getGroupByExpressions(),
replaceAggOutput(agg, Optional.empty(), Optional.empty(), result.exprRewriteMap),
agg.isNormalized(),
agg.getSourceRepeat(),
repeat.withChildren(
scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId))
);
}
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_SCAN),
// filter could push down scan.
// Aggregate(Repeat(Filter(Scan)))
logicalAggregate(logicalRepeat(logicalFilter(logicalOlapScan().when(this::shouldSelectIndex))))
.then(agg -> {
LogicalRepeat<LogicalFilter<LogicalOlapScan>> repeat = agg.child();
LogicalFilter<LogicalOlapScan> filter = repeat.child();
LogicalOlapScan scan = filter.child();
ImmutableSet<Slot> requiredSlots = ImmutableSet.<Slot>builder()
.addAll(agg.getInputSlots())
.addAll(filter.getInputSlots())
.build();
SelectResult result = select(
scan,
requiredSlots,
filter.getConjuncts(),
extractAggFunctionAndReplaceSlot(agg, Optional.empty()),
nonVirtualGroupByExprs(agg)
);
if (result.exprRewriteMap.isEmpty()) {
return agg.withChildren(
repeat.withChildren(
filter.withChildren(
scan.withMaterializedIndexSelected(result.preAggStatus,
result.indexId))
));
} else {
return new LogicalAggregate<>(
agg.getGroupByExpressions(),
replaceAggOutput(agg, Optional.empty(), Optional.empty(),
result.exprRewriteMap),
agg.isNormalized(),
agg.getSourceRepeat(),
// Not that no need to replace slots in the filter, because the slots to replace
// are value columns, which shouldn't appear in filters.
repeat.withChildren(filter.withChildren(
scan.withMaterializedIndexSelected(result.preAggStatus,
result.indexId)))
);
}
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_FILTER_SCAN),
// column pruning or other projections such as alias, etc.
// Aggregate(Repeat(Project(Scan)))
logicalAggregate(logicalRepeat(logicalProject(logicalOlapScan().when(this::shouldSelectIndex))))
.then(agg -> {
LogicalRepeat<LogicalProject<LogicalOlapScan>> repeat = agg.child();
LogicalProject<LogicalOlapScan> project = repeat.child();
LogicalOlapScan scan = project.child();
SelectResult result = select(
scan,
project.getInputSlots(),
ImmutableSet.of(),
extractAggFunctionAndReplaceSlot(agg,
Optional.of(project)),
ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
project.getAliasToProducer())
);
if (result.exprRewriteMap.isEmpty()) {
return agg.withChildren(
repeat.withChildren(
project.withChildren(
scan.withMaterializedIndexSelected(result.preAggStatus,
result.indexId)
))
);
} else {
List<NamedExpression> newProjectList = replaceProjectList(project,
result.exprRewriteMap.projectExprMap);
LogicalProject<LogicalOlapScan> newProject = new LogicalProject<>(
newProjectList,
scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId));
return new LogicalAggregate<>(
agg.getGroupByExpressions(),
replaceAggOutput(agg, Optional.of(project), Optional.of(newProject),
result.exprRewriteMap),
agg.isNormalized(),
agg.getSourceRepeat(),
repeat.withChildren(newProject)
);
}
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_PROJECT_SCAN),
// filter could push down and project.
// Aggregate(Repeat(Project(Filter(Scan))))
logicalAggregate(logicalRepeat(logicalProject(logicalFilter(logicalOlapScan()
.when(this::shouldSelectIndex))))).then(agg -> {
LogicalRepeat<LogicalProject<LogicalFilter<LogicalOlapScan>>> repeat = agg.child();
LogicalProject<LogicalFilter<LogicalOlapScan>> project = repeat.child();
LogicalFilter<LogicalOlapScan> filter = project.child();
LogicalOlapScan scan = filter.child();
Set<Slot> requiredSlots = Stream.concat(
project.getInputSlots().stream(), filter.getInputSlots().stream())
.collect(Collectors.toSet());
SelectResult result = select(
scan,
requiredSlots,
filter.getConjuncts(),
extractAggFunctionAndReplaceSlot(agg, Optional.of(project)),
ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
project.getAliasToProducer())
);
if (result.exprRewriteMap.isEmpty()) {
return agg.withChildren(repeat.withChildren(project.withChildren(filter.withChildren(
scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId))
)));
} else {
List<NamedExpression> newProjectList = replaceProjectList(project,
result.exprRewriteMap.projectExprMap);
LogicalProject<Plan> newProject = new LogicalProject<>(newProjectList,
filter.withChildren(scan.withMaterializedIndexSelected(result.preAggStatus,
result.indexId)));
return new LogicalAggregate<>(
agg.getGroupByExpressions(),
replaceAggOutput(agg, Optional.of(project), Optional.of(newProject),
result.exprRewriteMap),
agg.isNormalized(),
agg.getSourceRepeat(),
repeat.withChildren(newProject)
);
}
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_PROJECT_FILTER_SCAN),
// filter can't push down
// Aggregate(Repeat(Filter(Project(Scan))))
logicalAggregate(logicalRepeat(logicalFilter(logicalProject(logicalOlapScan()
.when(this::shouldSelectIndex))))).then(agg -> {
LogicalRepeat<LogicalFilter<LogicalProject<LogicalOlapScan>>> repeat = agg.child();
LogicalFilter<LogicalProject<LogicalOlapScan>> filter = repeat.child();
LogicalProject<LogicalOlapScan> project = filter.child();
LogicalOlapScan scan = project.child();
SelectResult result = select(
scan,
project.getInputSlots(),
ImmutableSet.of(),
extractAggFunctionAndReplaceSlot(agg, Optional.of(project)),
ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
project.getAliasToProducer())
);
if (result.exprRewriteMap.isEmpty()) {
return agg.withChildren(repeat.withChildren(filter.withChildren(project.withChildren(
scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId))
)));
} else {
List<NamedExpression> newProjectList = replaceProjectList(project,
result.exprRewriteMap.projectExprMap);
LogicalProject<Plan> newProject = new LogicalProject<>(newProjectList,
scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId));
return new LogicalAggregate<>(
agg.getGroupByExpressions(),
replaceAggOutput(agg, Optional.of(project), Optional.of(newProject),
result.exprRewriteMap),
agg.isNormalized(),
agg.getSourceRepeat(),
repeat.withChildren(filter.withChildren(newProject))
);
}
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_FILTER_PROJECT_SCAN)
);
}
@ -284,9 +480,13 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
Set<Expression> predicates,
List<AggregateFunction> aggregateFunctions,
List<Expression> groupingExprs) {
Preconditions.checkArgument(scan.getOutputSet().containsAll(requiredScanOutput),
// remove virtual slot for grouping sets.
Set<Slot> nonVirtualRequiredScanOutput = requiredScanOutput.stream()
.filter(slot -> !(slot instanceof VirtualSlotReference))
.collect(ImmutableSet.toImmutableSet());
Preconditions.checkArgument(scan.getOutputSet().containsAll(nonVirtualRequiredScanOutput),
String.format("Scan's output (%s) should contains all the input required scan output (%s).",
scan.getOutput(), requiredScanOutput));
scan.getOutput(), nonVirtualRequiredScanOutput));
OlapTable table = scan.getTable();
@ -303,7 +503,7 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
return new SelectResult(preAggStatus, scan.getTable().getBaseIndexId(), new ExprRewriteMap());
} else {
List<MaterializedIndex> rollupsWithAllRequiredCols = table.getVisibleIndex().stream()
.filter(index -> containAllRequiredColumns(index, scan, requiredScanOutput))
.filter(index -> containAllRequiredColumns(index, scan, nonVirtualRequiredScanOutput))
.collect(Collectors.toList());
return new SelectResult(preAggStatus, selectBestIndex(rollupsWithAllRequiredCols, scan, predicates),
new ExprRewriteMap());
@ -328,7 +528,8 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
ImmutableList.of())
.stream()
.filter(index -> !candidatesWithoutRewriting.contains(index))
.map(index -> rewriteAgg(index, scan, requiredScanOutput, predicates, aggregateFunctions,
.map(index -> rewriteAgg(index, scan, nonVirtualRequiredScanOutput, predicates,
aggregateFunctions,
groupingExprs))
.filter(aggRewriteResult -> checkPreAggStatus(scan, aggRewriteResult.index.getId(),
predicates,
@ -340,7 +541,7 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
List<MaterializedIndex> haveAllRequiredColumns = Streams.concat(
candidatesWithoutRewriting.stream()
.filter(index -> containAllRequiredColumns(index, scan, requiredScanOutput)),
.filter(index -> containAllRequiredColumns(index, scan, nonVirtualRequiredScanOutput)),
candidatesWithRewriting
.stream()
.filter(aggRewriteResult -> containAllRequiredColumns(aggRewriteResult.index, scan,
@ -995,4 +1196,10 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
.map(expr -> (NamedExpression) ExpressionUtils.replace(expr, projectMap))
.collect(ImmutableList.toImmutableList());
}
private List<Expression> nonVirtualGroupByExprs(LogicalAggregate<? extends Plan> agg) {
return agg.getGroupByExpressions().stream()
.filter(expr -> !(expr instanceof VirtualSlotReference))
.collect(ImmutableList.toImmutableList());
}
}

View File

@ -223,9 +223,8 @@ public class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implement
/**
* Aggregation query with groupSets at coarser level of aggregation than
* aggregation materialized view.
* TODO: enable this when group by rollup is supported.
*/
@Disabled
@Test
public void testGroupingSetQueryOnAggMV() throws Exception {
String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno, sum(salary) "
+ "from " + EMPS_TABLE_NAME + " group by empid, deptno;";
@ -271,9 +270,8 @@ public class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implement
/**
* Query with rollup and arithmetic expr
* TODO: enable this when group by rollup is supported.
*/
@Disabled
@Test
public void testAggQueryOnAggMV10() throws Exception {
String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, commission, sum(salary) "
+ "from " + EMPS_TABLE_NAME + " group by deptno, commission;";