[feature](nereids) Support basic aggregate rewrite and function rollup using materialized view (#28269)

Add aggregate MaterializedView rules for query rewrite.
It supports query rewrites such as the following:

    def mv = "select lineitem.L_LINENUMBER, orders.O_CUSTKEY, sum(O_TOTALPRICE) as sum_alias " +
            "from lineitem " +
            "inner join orders on lineitem.L_ORDERKEY = orders.O_ORDERKEY " +
            "group by lineitem.L_LINENUMBER, orders.O_CUSTKEY "
    def query = "select lineitem.L_LINENUMBER, sum(O_TOTALPRICE) as sum_alias " +
            "from lineitem " +
            "inner join orders on lineitem.L_ORDERKEY = orders.O_ORDERKEY " +
            "group by lineitem.L_LINENUMBER"
Author: seawinde
Date: 2023-12-15 11:30:02 +08:00
Committed by: GitHub
Parent: c4242ab69e
Commit: 4c51558f6b
28 changed files with 1031 additions and 223 deletions
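
For the example above, the rewrite is expected to answer the query from the materialized view by rolling up its pre-aggregated sum over the coarser group-by dimension. Conceptually (the mv table name and variable name below are illustrative, not part of this patch), the rewritten query is equivalent to:

    def rewrittenEquivalent = "select mv.L_LINENUMBER, sum(mv.sum_alias) as sum_alias " +
            "from mv " +
            "group by mv.L_LINENUMBER"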

View File

@ -82,6 +82,10 @@ public class HyperGraph {
return joinEdges;
}
public List<FilterEdge> getFilterEdges() {
return filterEdges;
}
public List<AbstractNode> getNodes() {
return nodes;
}
@ -589,7 +593,7 @@ public class HyperGraph {
*
* @param viewHG the compared hyper graph
* @return null represents not compatible, or return some expression which can
be pulled up from this hyper graph
*/
public @Nullable List<Expression> isLogicCompatible(HyperGraph viewHG, LogicalCompatibilityContext ctx) {
Map<Edge, Edge> queryToView = constructEdgeMap(viewHG, ctx.getQueryToViewEdgeExpressionMapping());
@ -661,14 +665,15 @@ public class HyperGraph {
long tRight = t.getRightExtendedNodes();
long oLeft = o.getLeftExtendedNodes();
long oRight = o.getRightExtendedNodes();
if (!t.getJoinType().equals(o.getJoinType())) {
if (!t.getJoinType().swap().equals(o.getJoinType())) {
return false;
}
oRight = o.getLeftExtendedNodes();
oLeft = o.getRightExtendedNodes();
if (!t.getJoinType().equals(o.getJoinType()) && !t.getJoinType().swap().equals(o.getJoinType())) {
return false;
}
return compareNodeMap(tLeft, oLeft, nodeMap) && compareNodeMap(tRight, oRight, nodeMap);
boolean matched = false;
if (t.getJoinType().swap().equals(o.getJoinType())) {
matched |= compareNodeMap(tRight, oLeft, nodeMap) && compareNodeMap(tLeft, oRight, nodeMap);
}
matched |= compareNodeMap(tLeft, oLeft, nodeMap) && compareNodeMap(tRight, oRight, nodeMap);
return matched;
}
private boolean compareNodeMap(long bitmap1, long bitmap2, Map<Integer, Integer> nodeIDMap) {

View File

@ -40,6 +40,8 @@ import org.apache.doris.nereids.rules.exploration.join.PushDownProjectThroughInn
import org.apache.doris.nereids.rules.exploration.join.PushDownProjectThroughSemiJoin;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTranspose;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTransposeProject;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewAggregateRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectAggregateRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectJoinRule;
import org.apache.doris.nereids.rules.implementation.AggregateStrategies;
import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
@ -223,6 +225,8 @@ public class RuleSet {
public static final List<Rule> MATERIALIZED_VIEW_RULES = planRuleFactories()
.add(MaterializedViewProjectJoinRule.INSTANCE)
.add(MaterializedViewAggregateRule.INSTANCE)
.add(MaterializedViewProjectAggregateRule.INSTANCE)
.build();
public List<Rule> getDPHypReorderRules() {

View File

@ -17,9 +17,293 @@
package org.apache.doris.nereids.rules.exploration.mv;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext;
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
* AbstractMaterializedViewAggregateRule
* This is responsible for common aggregate rewriting
* */
*/
public abstract class AbstractMaterializedViewAggregateRule extends AbstractMaterializedViewRule {
@Override
protected Plan rewriteQueryByView(MatchMode matchMode,
StructInfo queryStructInfo,
StructInfo viewStructInfo,
SlotMapping queryToViewSlotMapping,
Plan tempRewritedPlan,
MaterializationContext materializationContext) {
// get view and query aggregate and top plan correspondingly
Pair<Plan, LogicalAggregate<Plan>> viewTopPlanAndAggPair = splitToTopPlanAndAggregate(viewStructInfo);
if (viewTopPlanAndAggPair == null) {
return null;
}
Pair<Plan, LogicalAggregate<Plan>> queryTopPlanAndAggPair = splitToTopPlanAndAggregate(queryStructInfo);
if (queryTopPlanAndAggPair == null) {
return null;
}
// Firstly, handle query group by expression rewrite
LogicalAggregate<Plan> queryAggregate = queryTopPlanAndAggPair.value();
Plan queryTopPlan = queryTopPlanAndAggPair.key();
// query and view have the same dimension, try to rewrite rewrittenQueryGroupExpr
LogicalAggregate<Plan> viewAggregate = viewTopPlanAndAggPair.value();
Plan viewTopPlan = viewTopPlanAndAggPair.key();
boolean needRollUp =
queryAggregate.getGroupByExpressions().size() != viewAggregate.getGroupByExpressions().size();
if (queryAggregate.getGroupByExpressions().size() == viewAggregate.getGroupByExpressions().size()) {
List<? extends Expression> queryGroupShuttledExpression = ExpressionUtils.shuttleExpressionWithLineage(
queryAggregate.getGroupByExpressions(), queryTopPlan);
List<? extends Expression> viewGroupShuttledExpression = ExpressionUtils.shuttleExpressionWithLineage(
viewAggregate.getGroupByExpressions(), viewTopPlan)
.stream()
.map(expr -> ExpressionUtils.replace(expr, queryToViewSlotMapping.inverse().toSlotReferenceMap()))
.collect(Collectors.toList());
needRollUp = !queryGroupShuttledExpression.equals(viewGroupShuttledExpression);
}
if (!needRollUp) {
List<Expression> rewrittenQueryGroupExpr = rewriteExpression(queryTopPlan.getOutput(),
queryTopPlan,
materializationContext.getMvExprToMvScanExprMapping(),
queryToViewSlotMapping,
true);
if (rewrittenQueryGroupExpr == null) {
// can not rewrite, bail out.
return null;
}
return new LogicalProject<>(
rewrittenQueryGroupExpr.stream().map(NamedExpression.class::cast).collect(Collectors.toList()),
tempRewritedPlan);
}
// the dimension in query and view are different, try to roll up
// Split query aggregate to group expression and agg function
// Firstly, find the query top output rewrite function expr list which only use query aggregate function,
// This will be used to roll up
if (viewAggregate.getOutputExpressions().stream().anyMatch(
viewExpr -> viewExpr.anyMatch(expr -> expr instanceof AggregateFunction
&& ((AggregateFunction) expr).isDistinct()))) {
// if mv aggregate function contains distinct, can not roll up, bail out.
return null;
}
// split the query top plan expressions to group expressions and functions, if can not, bail out.
Pair<Set<? extends Expression>, Set<? extends Expression>> queryGroupAndFunctionPair
= topPlanSplitToGroupAndFunction(queryTopPlanAndAggPair);
if (queryGroupAndFunctionPair == null) {
return null;
}
// Secondly, try to roll up the agg functions
// this map will be used to rewrite expression
Multimap<Expression, Expression> needRollupExprMap = HashMultimap.create();
Multimap<Expression, Expression> groupRewrittenExprMap = HashMultimap.create();
Map<Expression, Expression> mvExprToMvScanExprQueryBased =
materializationContext.getMvExprToMvScanExprMapping().keyPermute(
queryToViewSlotMapping.inverse()).flattenMap().get(0);
Set<? extends Expression> queryTopPlanFunctionSet = queryGroupAndFunctionPair.value();
// try to rewrite, contains both roll up aggregate functions and aggregate group expression
List<NamedExpression> finalAggregateExpressions = new ArrayList<>();
List<Expression> finalGroupExpressions = new ArrayList<>();
for (Expression topExpression : queryTopPlan.getExpressions()) {
// is agg function, try to roll up and rewrite
if (queryTopPlanFunctionSet.contains(topExpression)) {
Expression needRollupShuttledExpr = ExpressionUtils.shuttleExpressionWithLineage(
topExpression,
queryTopPlan);
if (!mvExprToMvScanExprQueryBased.containsKey(needRollupShuttledExpr)) {
// function can not rewrite by view
return null;
}
// try to roll up
AggregateFunction needRollupAggFunction = (AggregateFunction) topExpression.firstMatch(
expr -> expr instanceof AggregateFunction);
AggregateFunction rollupAggregateFunction = rollup(needRollupAggFunction,
mvExprToMvScanExprQueryBased.get(needRollupShuttledExpr));
if (rollupAggregateFunction == null) {
return null;
}
// key is query need roll up expr, value is mv scan based roll up expr
needRollupExprMap.put(needRollupShuttledExpr, rollupAggregateFunction);
// rewrite query function expression by mv expression
Expression rewrittenFunctionExpression = rewriteExpression(topExpression,
queryTopPlan,
new ExpressionMapping(needRollupExprMap),
queryToViewSlotMapping,
false);
if (rewrittenFunctionExpression == null) {
return null;
}
finalAggregateExpressions.add((NamedExpression) rewrittenFunctionExpression);
} else {
// try to rewrite group expression
Expression queryGroupShuttledExpr =
ExpressionUtils.shuttleExpressionWithLineage(topExpression, queryTopPlan);
if (!mvExprToMvScanExprQueryBased.containsKey(queryGroupShuttledExpr)) {
// group expr can not rewrite by view
return null;
}
groupRewrittenExprMap.put(queryGroupShuttledExpr,
mvExprToMvScanExprQueryBased.get(queryGroupShuttledExpr));
// rewrite query group expression by mv expression
Expression rewrittenGroupExpression = rewriteExpression(
topExpression,
queryTopPlan,
new ExpressionMapping(groupRewrittenExprMap),
queryToViewSlotMapping,
true);
if (rewrittenGroupExpression == null) {
return null;
}
finalAggregateExpressions.add((NamedExpression) rewrittenGroupExpression);
finalGroupExpressions.add(rewrittenGroupExpression);
}
}
// add project to guarantee group by column ref is slot reference,
// this is necessary because physical createHash will need slotReference later
List<Expression> copiedFinalGroupExpressions = new ArrayList<>(finalGroupExpressions);
List<NamedExpression> projectsUnderAggregate = copiedFinalGroupExpressions.stream()
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
projectsUnderAggregate.addAll(tempRewritedPlan.getOutput());
LogicalProject<Plan> mvProject = new LogicalProject<>(projectsUnderAggregate, tempRewritedPlan);
// add agg rewrite
Map<ExprId, Slot> projectOutPutExprIdMap = mvProject.getOutput().stream()
.distinct()
.collect(Collectors.toMap(NamedExpression::getExprId, slot -> slot));
// make the expressions to re reference project output
finalGroupExpressions = finalGroupExpressions.stream()
.map(expr -> {
ExprId exprId = ((NamedExpression) expr).getExprId();
if (projectOutPutExprIdMap.containsKey(exprId)) {
return projectOutPutExprIdMap.get(exprId);
}
return (NamedExpression) expr;
})
.collect(Collectors.toList());
finalAggregateExpressions = finalAggregateExpressions.stream()
.map(expr -> {
ExprId exprId = expr.getExprId();
if (projectOutPutExprIdMap.containsKey(exprId)) {
return projectOutPutExprIdMap.get(exprId);
}
return expr;
})
.collect(Collectors.toList());
LogicalAggregate rewrittenAggregate = new LogicalAggregate(finalGroupExpressions,
finalAggregateExpressions, mvProject);
// record the group id in materializationContext, and when rewrite again in
// the same group, bail out quickly.
if (queryStructInfo.getOriginalPlan().getGroupExpression().isPresent()) {
materializationContext.addMatchedGroup(
queryStructInfo.getOriginalPlan().getGroupExpression().get().getOwnerGroup().getGroupId());
}
return rewrittenAggregate;
}
// only support sum roll up, support other agg functions later.
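// e.g. (illustrative): a query function sum(O_TOTALPRICE) whose shuttled form maps to the mv scan
// column holding the view's pre-computed sum_alias is rolled up to sum over that mv scan column,
// which is then evaluated on top of the materialized view scan.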
private AggregateFunction rollup(AggregateFunction originFunction,
Expression mappedExpression) {
Class<? extends AggregateFunction> rollupAggregateFunction = originFunction.getRollup();
if (rollupAggregateFunction == null) {
return null;
}
if (Sum.class.isAssignableFrom(rollupAggregateFunction)) {
return new Sum(originFunction.isDistinct(), mappedExpression);
}
// can not roll up, return null
return null;
}
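// Split the expressions of the query top plan into two sets: those that reference an aggregate
// function output (the roll-up candidates) and the remaining group by expressions.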
private Pair<Set<? extends Expression>, Set<? extends Expression>> topPlanSplitToGroupAndFunction(
Pair<Plan, LogicalAggregate<Plan>> topPlanAndAggPair) {
LogicalAggregate<Plan> queryAggregate = topPlanAndAggPair.value();
Set<Expression> queryAggGroupSet = new HashSet<>(queryAggregate.getGroupByExpressions());
Set<Expression> queryAggFunctionSet = queryAggregate.getOutputExpressions().stream()
.filter(expr -> !queryAggGroupSet.contains(expr))
.collect(Collectors.toSet());
Plan queryTopPlan = topPlanAndAggPair.key();
Set<Expression> topGroupByExpressions = new HashSet<>();
Set<Expression> topFunctionExpressions = new HashSet<>();
queryTopPlan.getExpressions().forEach(
expression -> {
if (expression.anyMatch(expr -> expr instanceof NamedExpression
&& queryAggFunctionSet.contains((NamedExpression) expr))) {
topFunctionExpressions.add(expression);
} else {
topGroupByExpressions.add(expression);
}
});
// only support to reference the aggregate function directly in top, will support expression later.
if (topFunctionExpressions.stream().anyMatch(
topAggFunc -> !(topAggFunc instanceof NamedExpression) && (!queryAggFunctionSet.contains(topAggFunc)
|| !queryAggFunctionSet.contains(topAggFunc.child(0))))) {
return null;
}
return Pair.of(topGroupByExpressions, topFunctionExpressions);
}
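// Locate the LogicalAggregate inside the struct info's top plan and return the top plan paired
// with that aggregate, or null if no aggregate is found.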
private Pair<Plan, LogicalAggregate<Plan>> splitToTopPlanAndAggregate(StructInfo structInfo) {
Plan topPlan = structInfo.getTopPlan();
PlanSplitContext splitContext = new PlanSplitContext(Sets.newHashSet(LogicalAggregate.class));
topPlan.accept(StructInfo.PLAN_SPLITTER, splitContext);
if (!(splitContext.getBottomPlan() instanceof LogicalAggregate)) {
return null;
} else {
return Pair.of(topPlan, (LogicalAggregate<Plan>) splitContext.getBottomPlan());
}
}
// Check whether the aggregate is simple and whether the join is valid.
// The join's input can not contain an aggregate. Only project, filter, join and logical relation nodes are
// supported, and the join condition must currently be slot reference equality.
@Override
protected boolean checkPattern(StructInfo structInfo) {
Plan topPlan = structInfo.getTopPlan();
Boolean valid = topPlan.accept(StructInfo.AGGREGATE_PATTERN_CHECKER, null);
if (!valid) {
return false;
}
HyperGraph hyperGraph = structInfo.getHyperGraph();
for (AbstractNode node : hyperGraph.getNodes()) {
StructInfoNode structInfoNode = (StructInfoNode) node;
if (!structInfoNode.getPlan().accept(StructInfo.JOIN_PATTERN_CHECKER,
SUPPORTED_JOIN_TYPE_SET)) {
return false;
}
for (JoinEdge edge : hyperGraph.getJoinEdges()) {
if (!edge.getJoin().accept(StructInfo.JOIN_PATTERN_CHECKER, SUPPORTED_JOIN_TYPE_SET)) {
return false;
}
}
}
return true;
}
}

View File

@ -24,14 +24,9 @@ import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.Sets;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
@ -40,28 +35,23 @@ import java.util.stream.Collectors;
* This is responsible for common join rewriting
*/
public abstract class AbstractMaterializedViewJoinRule extends AbstractMaterializedViewRule {
private static final HashSet<JoinType> SUPPORTED_JOIN_TYPE_SET =
Sets.newHashSet(JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN);
@Override
protected Plan rewriteQueryByView(MatchMode matchMode,
StructInfo queryStructInfo,
StructInfo viewStructInfo,
SlotMapping queryToViewSlotMappings,
SlotMapping queryToViewSlotMapping,
Plan tempRewritedPlan,
MaterializationContext materializationContext) {
List<? extends Expression> queryShuttleExpression = ExpressionUtils.shuttleExpressionWithLineage(
queryStructInfo.getExpressions(),
queryStructInfo.getOriginalPlan());
// Rewrite top projects, represent the query projects by view
List<Expression> expressionsRewritten = rewriteExpression(
queryShuttleExpression,
materializationContext.getViewExpressionIndexMapping(),
queryToViewSlotMappings
queryStructInfo.getExpressions(),
queryStructInfo.getOriginalPlan(),
materializationContext.getMvExprToMvScanExprMapping(),
queryToViewSlotMapping,
true
);
// Can not rewrite, bail out
if (expressionsRewritten == null
if (expressionsRewritten.isEmpty()
|| expressionsRewritten.stream().anyMatch(expr -> !(expr instanceof NamedExpression))) {
return null;
}

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.rules.exploration.mv.mapping.EquivalenceClassSet
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.RelationMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@ -31,11 +32,13 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import java.util.ArrayList;
@ -51,6 +54,9 @@ import java.util.stream.Collectors;
*/
public abstract class AbstractMaterializedViewRule {
public static final HashSet<JoinType> SUPPORTED_JOIN_TYPE_SET =
Sets.newHashSet(JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN);
/**
* The abstract template method for query rewrite, it contains the main logic and different query
* pattern should override the sub logic.
@ -105,9 +111,15 @@ public abstract class AbstractMaterializedViewRule {
LogicalCompatibilityContext.from(queryToViewTableMapping, queryToViewSlotMapping,
queryStructInfo, viewStructInfo);
// todo outer join compatibility check
if (StructInfo.isGraphLogicalEquals(queryStructInfo, viewStructInfo, compatibilityContext) == null) {
List<Expression> pulledUpExpressions = StructInfo.isGraphLogicalEquals(queryStructInfo, viewStructInfo,
compatibilityContext);
if (pulledUpExpressions == null) {
continue;
}
// set pulled up expression to queryStructInfo predicates and update related predicates
if (!pulledUpExpressions.isEmpty()) {
queryStructInfo.addPredicates(pulledUpExpressions);
}
SplitPredicate compensatePredicates = predicatesCompensate(queryStructInfo, viewStructInfo,
queryToViewSlotMapping);
// Can not compensate, bail out
@ -122,8 +134,10 @@ public abstract class AbstractMaterializedViewRule {
// Try to rewrite compensate predicates by using mv scan
List<Expression> rewriteCompensatePredicates = rewriteExpression(
compensatePredicates.toList(),
materializationContext.getViewExpressionIndexMapping(),
queryToViewSlotMapping);
queryPlan,
materializationContext.getMvExprToMvScanExprMapping(),
queryToViewSlotMapping,
true);
if (rewriteCompensatePredicates.isEmpty()) {
continue;
}
@ -151,22 +165,30 @@ public abstract class AbstractMaterializedViewRule {
protected Plan rewriteQueryByView(MatchMode matchMode,
StructInfo queryStructInfo,
StructInfo viewStructInfo,
SlotMapping queryToViewSlotMappings,
SlotMapping queryToViewSlotMapping,
Plan tempRewritedPlan,
MaterializationContext materializationContext) {
return tempRewritedPlan;
}
/**
* Use target output expression to represent the source expression
* Use target expression to represent the source expression. Visit the source expression,
* try to replace the source expression with target expression in targetExpressionMapping, if found then
* replace the source expression by target expression mapping value.
* Note: make the target expression map key to source based according to targetExpressionNeedSourceBased,
* if targetExpressionNeedSourceBased is true, we should make it source based.
* the key expressions in targetExpressionMapping should be shuttled with the method
* ExpressionUtils.shuttleExpressionWithLineage.
*/
protected List<Expression> rewriteExpression(
List<? extends Expression> sourceExpressionsToWrite,
ExpressionMapping mvExprToMvScanExprMapping,
SlotMapping sourceToTargetMapping) {
// Firstly, rewrite the target plan output expression using query with inverse mapping
// then try to use the mv expression to represent the query. if any of source expressions
// can not be represented by mv, return null
Plan sourcePlan,
ExpressionMapping targetExpressionMapping,
SlotMapping sourceToTargetMapping,
boolean targetExpressionNeedSourceBased) {
// Firstly, rewrite the target expression using source with inverse mapping
// then try to use the target expression to represent the query. if any of source expressions
// can not be represented by target expressions, return null.
//
// example as following:
// source target
@ -176,36 +198,58 @@ public abstract class AbstractMaterializedViewRule {
// transform source to:
// project(slot 2, 1)
// target
// generate mvSql to mvScan mvExprToMvScanExprMapping, and change mv sql expression to query based
ExpressionMapping mvExprToMvScanExprMappingKeySourceBased =
mvExprToMvScanExprMapping.keyPermute(sourceToTargetMapping.inverse());
List<Map<? extends Expression, ? extends Expression>> flattenExpressionMapping =
mvExprToMvScanExprMappingKeySourceBased.flattenMap();
// view to view scan expression is 1:1 so get first element
Map<? extends Expression, ? extends Expression> mvSqlToMvScanMappingQueryBased =
flattenExpressionMapping.get(0);
// generate target to target replacement expression mapping, and change target expression to source based
List<? extends Expression> sourceShuttledExpressions =
ExpressionUtils.shuttleExpressionWithLineage(sourceExpressionsToWrite, sourcePlan);
ExpressionMapping expressionMappingKeySourceBased = targetExpressionNeedSourceBased
? targetExpressionMapping.keyPermute(sourceToTargetMapping.inverse()) : targetExpressionMapping;
// target to target replacement expression mapping, because mv is 1:1 so get first element
List<Map<Expression, Expression>> flattenExpressionMap =
expressionMappingKeySourceBased.flattenMap();
Map<? extends Expression, ? extends Expression> targetToTargetReplacementMapping = flattenExpressionMap.get(0);
List<Expression> rewrittenExpressions = new ArrayList<>();
for (Expression expressionToRewrite : sourceExpressionsToWrite) {
for (int index = 0; index < sourceShuttledExpressions.size(); index++) {
Expression expressionToRewrite = sourceShuttledExpressions.get(index);
if (expressionToRewrite instanceof Literal) {
rewrittenExpressions.add(expressionToRewrite);
continue;
}
final Set<Object> slotsToRewrite =
expressionToRewrite.collectToSet(expression -> expression instanceof Slot);
boolean wiAlias = expressionToRewrite instanceof NamedExpression;
Expression replacedExpression = ExpressionUtils.replace(expressionToRewrite,
mvSqlToMvScanMappingQueryBased,
wiAlias);
targetToTargetReplacementMapping);
if (replacedExpression.anyMatch(slotsToRewrite::contains)) {
// if contains any slot to rewrite, which means can not be rewritten by target, bail out
return null;
return ImmutableList.of();
}
Expression sourceExpression = sourceExpressionsToWrite.get(index);
if (sourceExpression instanceof NamedExpression) {
NamedExpression sourceNamedExpression = (NamedExpression) sourceExpression;
replacedExpression = new Alias(sourceNamedExpression.getExprId(), replacedExpression,
sourceNamedExpression.getName());
}
rewrittenExpressions.add(replacedExpression);
}
return rewrittenExpressions;
}
protected Expression rewriteExpression(
Expression sourceExpressionsToWrite,
Plan sourcePlan,
ExpressionMapping targetExpressionMapping,
SlotMapping sourceToTargetMapping,
boolean targetExpressionNeedSourceBased) {
List<Expression> expressionToRewrite = new ArrayList<>();
expressionToRewrite.add(sourceExpressionsToWrite);
List<Expression> rewrittenExpressions = rewriteExpression(expressionToRewrite, sourcePlan,
targetExpressionMapping, sourceToTargetMapping, targetExpressionNeedSourceBased);
if (rewrittenExpressions.isEmpty()) {
return null;
}
return rewrittenExpressions.get(0);
}
/**
* Compensate mv predicates by query predicates, compensate predicate result is query based.
* Such as a > 5 in mv, and a > 10 in query, the compensatory predicate is a > 10.

View File

@ -30,6 +30,7 @@ import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.PlannerHook;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.TableCollector;
@ -90,19 +91,22 @@ public class InitMaterializationContextHook implements PlannerHook {
.getDbOrMetaException(mvBaseTableInfo.getDbId())
.getTableOrMetaException(mvBaseTableInfo.getTableId(), TableType.MATERIALIZED_VIEW);
String qualifiedName = materializedView.getQualifiedName();
// generate outside, maybe add partition filter in the future
Plan mvScan = new LogicalOlapScan(cascadesContext.getStatementContext().getNextRelationId(),
LogicalOlapScan mvScan = new LogicalOlapScan(
cascadesContext.getStatementContext().getNextRelationId(),
(OlapTable) materializedView,
ImmutableList.of(qualifiedName),
Lists.newArrayList(materializedView.getId()),
ImmutableList.of(materializedView.getQualifiedDbName()),
// this must be empty, or it will be used to sample
Lists.newArrayList(),
Lists.newArrayList(),
Optional.empty());
mvScan = mvScan.withMaterializedIndexSelected(PreAggStatus.on(), materializedView.getBaseIndexId());
List<NamedExpression> mvProjects = mvScan.getOutput().stream().map(NamedExpression.class::cast)
.collect(Collectors.toList());
mvScan = new LogicalProject<Plan>(mvProjects, mvScan);
// todo should force keep consistency to mv sql plan output
Plan projectScan = new LogicalProject<Plan>(mvProjects, mvScan);
cascadesContext.addMaterializationContext(
MaterializationContext.fromMaterializedView(materializedView, mvScan, cascadesContext));
MaterializationContext.fromMaterializedView(materializedView, projectScan, cascadesContext));
} catch (MetaNotFoundException metaNotFoundException) {
LOG.error(mvBaseTableInfo.toString() + " can not find corresponding materialized view.");
}

View File

@ -21,8 +21,11 @@ import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.rules.exploration.mv.mapping.Mapping.MappedRelation;
import org.apache.doris.nereids.rules.exploration.mv.mapping.RelationMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.util.ExpressionUtils;
@ -92,16 +95,39 @@ public class LogicalCompatibilityContext {
final Map<Expression, Expression> viewEdgeToConjunctsMapQueryBased = new HashMap<>();
viewShuttledExprToExprMap.forEach((shuttledExpr, expr) -> {
viewEdgeToConjunctsMapQueryBased.put(
ExpressionUtils.replace(shuttledExpr, viewToQuerySlotMapping),
orderSlotAsc(ExpressionUtils.replace(shuttledExpr, viewToQuerySlotMapping)),
expr);
});
BiMap<Expression, Expression> queryToViewEdgeMapping = HashBiMap.create();
queryShuttledExprToExprMap.forEach((exprSet, edge) -> {
Expression viewExpr = viewEdgeToConjunctsMapQueryBased.get(exprSet);
Expression viewExpr = viewEdgeToConjunctsMapQueryBased.get(orderSlotAsc(exprSet));
if (viewExpr != null) {
queryToViewEdgeMapping.put(edge, viewExpr);
}
});
return new LogicalCompatibilityContext(queryToViewNodeMapping, queryToViewEdgeMapping);
}
private static Expression orderSlotAsc(Expression expression) {
return expression.accept(ExpressionSlotOrder.INSTANCE, null);
}
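// Normalize equal-to conjuncts so the operand with the smaller ExprId is on the left; this makes
// query and view join conditions comparable even when the equality was written in the opposite order.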
private static final class ExpressionSlotOrder extends DefaultExpressionRewriter<Void> {
public static final ExpressionSlotOrder INSTANCE = new ExpressionSlotOrder();
@Override
public Expression visitEqualTo(EqualTo equalTo, Void context) {
if (!(equalTo.getArgument(0) instanceof NamedExpression)
|| !(equalTo.getArgument(1) instanceof NamedExpression)) {
return equalTo;
}
NamedExpression left = (NamedExpression) equalTo.getArgument(0);
NamedExpression right = (NamedExpression) equalTo.getArgument(1);
if (right.getExprId().asInt() < left.getExprId().asInt()) {
return new EqualTo(right, left);
} else {
return equalTo;
}
}
}
}

View File

@ -23,7 +23,6 @@ import org.apache.doris.mtmv.MVCache;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.memo.GroupId;
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.util.ExpressionUtils;
@ -32,7 +31,6 @@ import com.google.common.collect.ImmutableList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Maintain the context for query rewrite by materialized view
@ -47,7 +45,7 @@ public class MaterializationContext {
// Group ids that are rewritten by this mv to reduce rewrite times
private final Set<GroupId> matchedGroups = new HashSet<>();
// generate form mv scan plan
private ExpressionMapping viewExpressionMapping;
private ExpressionMapping mvExprToMvScanExprMapping;
/**
* MaterializationContext, this contains necessary info for query rewriting by mv
@ -67,14 +65,11 @@ public class MaterializationContext {
mvCache = MVCache.from(mtmv, cascadesContext.getConnectContext());
mtmv.setMvCache(mvCache);
}
List<NamedExpression> mvOutputExpressions = mvCache.getMvOutputExpressions();
// mv output expression shuttle, this will be used to expression rewrite
mvOutputExpressions =
ExpressionUtils.shuttleExpressionWithLineage(mvOutputExpressions, mvCache.getLogicalPlan()).stream()
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
this.viewExpressionMapping = ExpressionMapping.generate(
mvOutputExpressions,
this.mvExprToMvScanExprMapping = ExpressionMapping.generate(
ExpressionUtils.shuttleExpressionWithLineage(
mvCache.getMvOutputExpressions(),
mvCache.getLogicalPlan()),
mvScanPlan.getExpressions());
}
@ -106,8 +101,8 @@ public class MaterializationContext {
return baseViews;
}
public ExpressionMapping getViewExpressionIndexMapping() {
return viewExpressionMapping;
public ExpressionMapping getMvExprToMvScanExprMapping() {
return mvExprToMvScanExprMapping;
}
/**

View File

@ -18,7 +18,13 @@
package org.apache.doris.nereids.rules.exploration.mv;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RulePromise;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import com.google.common.collect.ImmutableList;
import java.util.List;
@ -26,8 +32,15 @@ import java.util.List;
* This is responsible for aggregate rewriting according to different pattern
* */
public class MaterializedViewAggregateRule extends AbstractMaterializedViewAggregateRule implements RewriteRuleFactory {
public static final MaterializedViewAggregateRule INSTANCE = new MaterializedViewAggregateRule();
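// matches a bare LogicalAggregate root; the project-over-aggregate shape is handled by
// MaterializedViewProjectAggregateRule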
@Override
public List<Rule> buildRules() {
return null;
return ImmutableList.of(
logicalAggregate(any()).thenApplyMulti(ctx -> {
LogicalAggregate<Plan> root = ctx.root;
return rewrite(root, ctx.cascadesContext);
}).toRule(RuleType.MATERIALIZED_VIEW_ONLY_AGGREGATE, RulePromise.EXPLORE));
}
}

View File

@ -0,0 +1,46 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.mv;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RulePromise;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**MaterializedViewProjectAggregateRule*/
public class MaterializedViewProjectAggregateRule extends AbstractMaterializedViewAggregateRule implements
RewriteRuleFactory {
public static final MaterializedViewProjectAggregateRule INSTANCE = new MaterializedViewProjectAggregateRule();
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
logicalProject(logicalAggregate(any())).thenApplyMulti(ctx -> {
LogicalProject<LogicalAggregate<Plan>> root = ctx.root;
return rewrite(root, ctx.cascadesContext);
}).toRule(RuleType.MATERIALIZED_VIEW_PROJECT_AGGREGATE, RulePromise.EXPLORE));
}
}

View File

@ -42,6 +42,6 @@ public class MaterializedViewProjectJoinRule extends AbstractMaterializedViewJoi
logicalProject(logicalJoin(any(), any())).thenApplyMulti(ctx -> {
LogicalProject<LogicalJoin<Plan, Plan>> root = ctx.root;
return rewrite(root, ctx.cascadesContext);
}).toRule(RuleType.MATERIALIZED_VIEW_ONLY_JOIN, RulePromise.EXPLORE));
}).toRule(RuleType.MATERIALIZED_VIEW_PROJECT_JOIN, RulePromise.EXPLORE));
}
}

View File

@ -24,7 +24,7 @@ import org.apache.doris.nereids.rules.exploration.mv.Predicates.SplitPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
@ -32,13 +32,14 @@ import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.algebra.Filter;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.trees.plans.algebra.Project;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
@ -47,15 +48,17 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/**
* StructInfo
* StructInfo for plan, this contains necessary info for query rewrite by materialized view
*/
public class StructInfo {
public static final JoinPatternChecker JOIN_PATTERN_CHECKER = new JoinPatternChecker();
public static final AggregatePatternChecker AGGREGATE_PATTERN_CHECKER = new AggregatePatternChecker();
// struct info splitter
public static final PlanSplitter PLAN_SPLITTER = new PlanSplitter();
private static final RelationCollector RELATION_COLLECTOR = new RelationCollector();
@ -72,7 +75,9 @@ public class StructInfo {
private final List<CatalogRelation> relations = new ArrayList<>();
// this is for LogicalCompatibilityContext later
private final Map<RelationId, StructInfoNode> relationIdStructInfoNodeMap = new HashMap<>();
// this records the predicates which can be pulled up, not shuttled
private Predicates predicates;
// split predicates is shuttled
private SplitPredicate splitPredicate;
private EquivalenceClass equivalenceClass;
// this is for LogicalCompatibilityContext later
@ -87,20 +92,28 @@ public class StructInfo {
}
private void init() {
// split the top plan to two parts by join node
if (topPlan == null || bottomPlan == null) {
PlanSplitContext planSplitContext = new PlanSplitContext(Sets.newHashSet(LogicalJoin.class));
originalPlan.accept(PLAN_SPLITTER, planSplitContext);
this.bottomPlan = planSplitContext.getBottomPlan();
this.topPlan = planSplitContext.getTopPlan();
}
collectStructInfoFromGraph();
initPredicates();
predicatesDerive();
}
this.predicates = Predicates.of();
// Collect predicate from join condition in hyper graph
public void addPredicates(List<Expression> canPulledUpExpressions) {
canPulledUpExpressions.forEach(this.predicates::addPredicate);
predicatesDerive();
}
private void collectStructInfoFromGraph() {
// Collect expression from join condition in hyper graph
this.hyperGraph.getJoinEdges().forEach(edge -> {
List<Expression> hashJoinConjuncts = edge.getHashJoinConjuncts();
hashJoinConjuncts.forEach(conjunctExpr -> {
predicates.addPredicate(conjunctExpr);
// shuttle expression in edge for LogicalCompatibilityContext later
shuttledHashConjunctsToConjunctsMap.put(
ExpressionUtils.shuttleExpressionWithLineage(
@ -115,8 +128,7 @@ public class StructInfo {
if (!this.isValid()) {
return;
}
// Collect predicate from filter node in hyper graph
// Collect relations from hyper graph which in the bottom plan
this.hyperGraph.getNodes().forEach(node -> {
// plan relation collector and set to map
Plan nodePlan = node.getPlan();
@ -125,28 +137,40 @@ public class StructInfo {
this.relations.addAll(nodeRelations);
// every node should only have one relation, this is for LogicalCompatibilityContext
relationIdStructInfoNodeMap.put(nodeRelations.get(0).getRelationId(), (StructInfoNode) node);
// if inner join add where condition
Set<Expression> predicates = new HashSet<>();
nodePlan.accept(PREDICATE_COLLECTOR, predicates);
predicates.forEach(this.predicates::addPredicate);
});
// Collect expression from where in hyper graph
this.hyperGraph.getFilterEdges().forEach(filterEdge -> {
List<? extends Expression> filterExpressions = filterEdge.getExpressions();
filterExpressions.forEach(predicate -> {
// this is used for LogicalCompatibilityContext
ExpressionUtils.extractConjunction(predicate).forEach(expr ->
shuttledHashConjunctsToConjunctsMap.put(
ExpressionUtils.shuttleExpressionWithLineage(predicate, topPlan), predicate));
});
});
}
// TODO Collect predicate from top plan not in hyper graph, should optimize, twice now
private void initPredicates() {
// Collect predicate from top plan which not in hyper graph
this.predicates = Predicates.of();
Set<Expression> topPlanPredicates = new HashSet<>();
topPlan.accept(PREDICATE_COLLECTOR, topPlanPredicates);
topPlanPredicates.forEach(this.predicates::addPredicate);
}
// derive some useful predicate by predicates
private void predicatesDerive() {
// construct equivalenceClass according to equals predicates
this.equivalenceClass = new EquivalenceClass();
List<Expression> shuttledExpression = ExpressionUtils.shuttleExpressionWithLineage(
this.predicates.getPulledUpPredicates(), originalPlan).stream()
.map(Expression.class::cast)
.collect(Collectors.toList());
SplitPredicate splitPredicate = Predicates.splitPredicates(ExpressionUtils.and(shuttledExpression));
this.splitPredicate = splitPredicate;
this.equivalenceClass = new EquivalenceClass();
for (Expression expression : ExpressionUtils.extractConjunction(splitPredicate.getEqualPredicate())) {
if (expression instanceof BooleanLiteral && ((BooleanLiteral) expression).getValue()) {
if (expression instanceof Literal) {
continue;
}
if (expression instanceof EqualTo) {
@ -166,6 +190,7 @@ public class StructInfo {
// TODO only consider the inner join currently, Should support outer join
// Split plan by the boundary which contains multi child
PlanSplitContext planSplitContext = new PlanSplitContext(Sets.newHashSet(LogicalJoin.class));
// if single table without join, the bottom is
originalPlan.accept(PLAN_SPLITTER, planSplitContext);
List<HyperGraph> structInfos = HyperGraph.toStructInfo(planSplitContext.getBottomPlan());
@ -240,9 +265,7 @@ public class StructInfo {
*/
public static @Nullable List<Expression> isGraphLogicalEquals(StructInfo queryStructInfo, StructInfo viewStructInfo,
LogicalCompatibilityContext compatibilityContext) {
// TODO: open it after supporting filter
// return queryStructInfo.hyperGraph.isLogicCompatible(viewStructInfo.hyperGraph, compatibilityContext);
return ImmutableList.of();
return queryStructInfo.hyperGraph.isLogicCompatible(viewStructInfo.hyperGraph, compatibilityContext);
}
private static class RelationCollector extends DefaultPlanVisitor<Void, List<CatalogRelation>> {
@ -258,8 +281,12 @@ public class StructInfo {
private static class PredicateCollector extends DefaultPlanVisitor<Void, Set<Expression>> {
@Override
public Void visit(Plan plan, Set<Expression> predicates) {
// Just collect the filter in top plan, if meet other node except project and filter, return
if (!(plan instanceof LogicalProject) && !(plan instanceof LogicalFilter)) {
return null;
}
if (plan instanceof LogicalFilter) {
predicates.add(((LogicalFilter) plan).getPredicate());
predicates.addAll(ExpressionUtils.extractConjunction(((LogicalFilter) plan).getPredicate()));
}
return super.visit(plan, predicates);
}
@ -267,7 +294,7 @@ public class StructInfo {
/**
* Split the plan into bottom and up, the boundary is given by context,
* the bottom contains the boundary.
* the bottom contains the boundary, and top plan doesn't contain the boundary.
*/
public static class PlanSplitter extends DefaultPlanVisitor<Void, PlanSplitContext> {
@Override
@ -275,6 +302,10 @@ public class StructInfo {
if (context.getTopPlan() == null) {
context.setTopPlan(plan);
}
if (plan.children().isEmpty() && context.getBottomPlan() == null) {
context.setBottomPlan(plan);
return null;
}
if (context.isBoundary(plan)) {
context.setBottomPlan(plan);
return null;
@ -349,4 +380,28 @@ public class StructInfo {
return true;
}
}
/**
* AggregatePatternChecker
*/
public static class AggregatePatternChecker extends DefaultPlanVisitor<Boolean, Void> {
@Override
public Boolean visit(Plan plan, Void context) {
if (plan instanceof LogicalAggregate) {
LogicalAggregate<Plan> aggregate = (LogicalAggregate<Plan>) plan;
Optional<LogicalRepeat<?>> sourceRepeat = aggregate.getSourceRepeat();
if (sourceRepeat.isPresent()) {
return false;
}
super.visit(aggregate, context);
return true;
}
if (plan instanceof LogicalProject) {
super.visit(plan, context);
return true;
}
super.visit(plan, context);
return false;
}
}
}

View File

@ -1,48 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.mv.mapping;
import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import java.util.List;
/**
* Expression and it's index mapping
*/
public class ExpressionIndexMapping extends Mapping {
private final Multimap<Expression, Integer> expressionIndexMapping;
public ExpressionIndexMapping(Multimap<Expression, Integer> expressionIndexMapping) {
this.expressionIndexMapping = expressionIndexMapping;
}
public Multimap<Expression, Integer> getExpressionIndexMapping() {
return expressionIndexMapping;
}
public static ExpressionIndexMapping generate(List<? extends Expression> expressions) {
Multimap<Expression, Integer> expressionIndexMapping = ArrayListMultimap.create();
for (int i = 0; i < expressions.size(); i++) {
expressionIndexMapping.put(expressions.get(i), i);
}
return new ExpressionIndexMapping(expressionIndexMapping);
}
}

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
@ -30,25 +31,26 @@ import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
/**
* Expression mapping, maybe one expression map to multi expression
*/
public class ExpressionMapping extends Mapping {
private final Multimap<? extends Expression, ? extends Expression> expressionMapping;
private final Multimap<Expression, Expression> expressionMapping;
public ExpressionMapping(Multimap<? extends Expression, ? extends Expression> expressionMapping) {
public ExpressionMapping(Multimap<Expression, Expression> expressionMapping) {
this.expressionMapping = expressionMapping;
}
public Multimap<? extends Expression, ? extends Expression> getExpressionMapping() {
public Multimap<Expression, Expression> getExpressionMapping() {
return expressionMapping;
}
/**
* ExpressionMapping flatten
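* e.g. (illustrative) {a -> [a1, a2], b -> [b1]} flattens to [{a -> a1, b -> b1}, {a -> a2, b -> b1}]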
*/
public List<Map<? extends Expression, ? extends Expression>> flattenMap() {
public List<Map<Expression, Expression>> flattenMap() {
List<List<Pair<Expression, Expression>>> tmpExpressionPairs = new ArrayList<>(this.expressionMapping.size());
Map<? extends Expression, ? extends Collection<? extends Expression>> expressionMappingMap =
expressionMapping.asMap();
@ -62,7 +64,7 @@ public class ExpressionMapping extends Mapping {
}
List<List<Pair<Expression, Expression>>> cartesianExpressionMap = Lists.cartesianProduct(tmpExpressionPairs);
final List<Map<? extends Expression, ? extends Expression>> flattenedMap = new ArrayList<>();
final List<Map<Expression, Expression>> flattenedMap = new ArrayList<>();
for (List<Pair<Expression, Expression>> listPair : cartesianExpressionMap) {
final Map<Expression, Expression> expressionMap = new HashMap<>();
listPair.forEach(pair -> expressionMap.put(pair.key(), pair.value()));
@ -71,7 +73,8 @@ public class ExpressionMapping extends Mapping {
return flattenedMap;
}
/**Permute the key of expression mapping. this is useful for expression rewrite, if permute key to query based
/**
* Permute the key of expression mapping. this is useful for expression rewrite, if permute key to query based
* then when expression rewrite success, we can get the mv scan expression directly.
*/
public ExpressionMapping keyPermute(SlotMapping slotMapping) {
@ -86,7 +89,9 @@ public class ExpressionMapping extends Mapping {
return new ExpressionMapping(permutedExpressionMapping);
}
/**ExpressionMapping generate*/
/**
* ExpressionMapping generate
*/
public static ExpressionMapping generate(
List<? extends Expression> sourceExpressions,
List<? extends Expression> targetExpressions) {
@ -97,4 +102,25 @@ public class ExpressionMapping extends Mapping {
}
return new ExpressionMapping(expressionMultiMap);
}
@Override
public Mapping chainedFold(Mapping target) {
ImmutableMultimap.Builder<Expression, Expression> foldedMappingBuilder =
ImmutableMultimap.builder();
Multimap<Expression, Expression> targetMapping
= ((ExpressionMapping) target).getExpressionMapping();
for (Entry<Expression, ? extends Collection<Expression>> exprMapping :
this.getExpressionMapping().asMap().entrySet()) {
Collection<? extends Expression> valueExpressions = exprMapping.getValue();
valueExpressions.forEach(valueExpr -> {
if (targetMapping.containsKey(valueExpr)) {
targetMapping.get(valueExpr).forEach(
targetValue -> foldedMappingBuilder.put(exprMapping.getKey(), targetValue));
}
});
}
return new ExpressionMapping(foldedMappingBuilder.build());
}
}

View File

@ -136,4 +136,12 @@ public abstract class Mapping {
return Objects.hash(exprId);
}
}
/** Chain fold two mappings, such as: this mapping is {[a -> b]}, the target mapping is
* {[b -> c]}; after chain fold, the result will be {[a -> c]}. If a value in this mapping
* can not be found as a key in the target mapping, that entry is lost.
*/
protected Mapping chainedFold(Mapping target) {
return null;
}
}

View File

@ -185,6 +185,23 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> {
return false;
}
/**
* Iterate top-down and test the predicate to find the first matching node.
* @param predicate predicate
* @return the first node which match the predicate
*/
default TreeNode<NODE_TYPE> firstMatch(Predicate<TreeNode<NODE_TYPE>> predicate) {
if (predicate.test(this)) {
return this;
}
for (NODE_TYPE child : children()) {
if (child.anyMatch(predicate)) {
return child;
}
}
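// no descendant subtree matched; fall back to returning this node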
return this;
}
/**
* Collect the nodes that satisfied the predicate.
*/

View File

@ -77,6 +77,10 @@ public abstract class AggregateFunction extends BoundFunction implements Expects
return distinct;
}
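// by default an aggregate function has no roll-up form; functions that can be rolled up
// (such as Sum) override this and return the function class used to aggregate pre-aggregated values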
public Class<? extends AggregateFunction> getRollup() {
return null;
}
@Override
public boolean equals(Object o) {
if (this == o) {

View File

@ -109,4 +109,9 @@ public class Sum extends NullableAggregateFunction
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
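// sum can be rolled up with sum: summing the view's partial sums yields the overall sum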
@Override
public Class<? extends AggregateFunction> getRollup() {
return Sum.class;
}
}

View File

@ -209,7 +209,7 @@ public class CreateMTMVInfo {
colNames.add(colName);
}
columns.add(new ColumnDefinition(
colName, slots.get(i).getDataType(), true,
colName, slots.get(i).getDataType(), slots.get(i).nullable(),
CollectionUtils.isEmpty(simpleColumnDefinitions) ? null
: simpleColumnDefinitions.get(i).getComment()));
}

View File

@ -210,7 +210,6 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
@Override
public List<? extends Expression> getExpressions() {
return new ImmutableList.Builder<Expression>()
.addAll(groupByExpressions)
.addAll(outputExpressions)
.build();
}

View File

@ -19,11 +19,10 @@ package org.apache.doris.nereids.trees.plans.visitor;
import org.apache.doris.catalog.TableIf.TableType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ArrayItemReference.ArrayItemSlot;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
@ -74,7 +73,8 @@ public class ExpressionLineageReplacer extends DefaultPlanVisitor<Expression, Ex
}
/**
* The Collector for target named expressions
* The Collector for target named expressions in the whole plan, and will be used to
* replace the target expression later
* TODO Collect named expression by targetTypes, tableIdentifiers
*/
public static class NamedExpressionCollector
@ -83,15 +83,9 @@ public class ExpressionLineageReplacer extends DefaultPlanVisitor<Expression, Ex
public static final NamedExpressionCollector INSTANCE = new NamedExpressionCollector();
@Override
public Void visitSlotReference(SlotReference slotReference, ExpressionReplaceContext context) {
context.getExprIdExpressionMap().put(slotReference.getExprId(), slotReference);
return super.visitSlotReference(slotReference, context);
}
@Override
public Void visitArrayItemSlot(ArrayItemSlot arrayItemSlot, ExpressionReplaceContext context) {
context.getExprIdExpressionMap().put(arrayItemSlot.getExprId(), arrayItemSlot);
return super.visitArrayItemSlot(arrayItemSlot, context);
public Void visitSlot(Slot slot, ExpressionReplaceContext context) {
context.getExprIdExpressionMap().put(slot.getExprId(), slot);
return super.visit(slot, context);
}
@Override
@ -121,7 +115,7 @@ public class ExpressionLineageReplacer extends DefaultPlanVisitor<Expression, Ex
this.targetExpressions = targetExpressions;
this.targetTypes = targetTypes;
this.tableIdentifiers = tableIdentifiers;
// collect only named expressions and replace them with linage identifier later
// collect the named expressions used in target expression and will be replaced later
this.exprIdExpressionMap = targetExpressions.stream()
.map(each -> each.collectToList(NamedExpression.class::isInstance))
.flatMap(Collection::stream)

View File

@ -207,6 +207,11 @@ public class ExpressionUtils {
.orElse(BooleanLiteral.of(type == And.class));
}
public static Expression shuttleExpressionWithLineage(Expression expression, Plan plan) {
return shuttleExpressionWithLineage(Lists.newArrayList(expression),
plan, ImmutableSet.of(), ImmutableSet.of()).get(0);
}
public static List<? extends Expression> shuttleExpressionWithLineage(List<? extends Expression> expressions,
Plan plan) {
return shuttleExpressionWithLineage(expressions, plan, ImmutableSet.of(), ImmutableSet.of());
@ -310,24 +315,7 @@ public class ExpressionUtils {
* </pre>
*/
public static Expression replace(Expression expr, Map<? extends Expression, ? extends Expression> replaceMap) {
return expr.accept(ExpressionReplacer.INSTANCE, ExpressionReplacerContext.of(replaceMap, false));
}
/**
* Replace expression node in the expression tree by `replaceMap` in top-down manner.
* if replaced, create alias
* For example.
* <pre>
* input expression: a > 1
* replaceMap: a -> b + c
*
* output:
* ((b + c) as a) > 1
* </pre>
*/
public static Expression replace(Expression expr, Map<? extends Expression, ? extends Expression> replaceMap,
boolean withAlias) {
return expr.accept(ExpressionReplacer.INSTANCE, ExpressionReplacerContext.of(replaceMap, true));
return expr.accept(ExpressionReplacer.INSTANCE, replaceMap);
}
/**
@ -335,8 +323,7 @@ public class ExpressionUtils {
*/
public static NamedExpression replace(NamedExpression expr,
Map<? extends Expression, ? extends Expression> replaceMap) {
Expression newExpr = expr.accept(ExpressionReplacer.INSTANCE,
ExpressionReplacerContext.of(replaceMap, false));
Expression newExpr = expr.accept(ExpressionReplacer.INSTANCE, replaceMap);
if (newExpr instanceof NamedExpression) {
return (NamedExpression) newExpr;
} else {
@ -366,54 +353,49 @@ public class ExpressionUtils {
}
private static class ExpressionReplacer
extends DefaultExpressionRewriter<ExpressionReplacerContext> {
extends DefaultExpressionRewriter<Map<? extends Expression, ? extends Expression>> {
public static final ExpressionReplacer INSTANCE = new ExpressionReplacer();
private ExpressionReplacer() {
}
@Override
public Expression visit(Expression expr, ExpressionReplacerContext replacerContext) {
Map<? extends Expression, ? extends Expression> replaceMap = replacerContext.getReplaceMap();
boolean isContained = replaceMap.containsKey(expr);
if (!isContained) {
return super.visit(expr, replacerContext);
}
boolean withAlias = replacerContext.isWithAlias();
if (!withAlias) {
return replaceMap.get(expr);
} else {
public Expression visit(Expression expr, Map<? extends Expression, ? extends Expression> replaceMap) {
if (replaceMap.containsKey(expr)) {
Expression replacedExpression = replaceMap.get(expr);
if (replacedExpression instanceof SlotReference) {
replacedExpression = ((SlotReference) (replacedExpression)).withNullable(expr.nullable());
if (replacedExpression instanceof SlotReference
&& replacedExpression.nullable() != expr.nullable()) {
replacedExpression = ((SlotReference) replacedExpression).withNullable(expr.nullable());
}
return new Alias(((NamedExpression) expr).getExprId(), replacedExpression,
((NamedExpression) expr).getName());
return replacedExpression;
}
return super.visit(expr, replaceMap);
}
}
private static class ExpressionReplacerContext {
private final Map<? extends Expression, ? extends Expression> replaceMap;
private final boolean withAlias;
// if the key of replaceMap is named expr and withAlias is true, we should
// add alias after replaced
private final boolean withAliasIfKeyNamed;
private ExpressionReplacerContext(Map<? extends Expression, ? extends Expression> replaceMap,
boolean withAlias) {
boolean withAliasIfKeyNamed) {
this.replaceMap = replaceMap;
this.withAlias = withAlias;
this.withAliasIfKeyNamed = withAliasIfKeyNamed;
}
public static ExpressionReplacerContext of(Map<? extends Expression, ? extends Expression> replaceMap,
boolean withAlias) {
return new ExpressionReplacerContext(replaceMap, withAlias);
boolean withAliasIfKeyNamed) {
return new ExpressionReplacerContext(replaceMap, withAliasIfKeyNamed);
}
public Map<? extends Expression, ? extends Expression> getReplaceMap() {
return replaceMap;
}
public boolean isWithAlias() {
return withAlias;
public boolean isWithAliasIfKeyNamed() {
return withAliasIfKeyNamed;
}
}

View File

@ -122,7 +122,6 @@ import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.minidump.MinidumpUtils;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.rules.exploration.mv.InitMaterializationContextHook;
import org.apache.doris.nereids.stats.StatsErrorEstimator;
import org.apache.doris.nereids.trees.plans.commands.BatchInsertIntoTableCommand;
import org.apache.doris.nereids.trees.plans.commands.Command;
import org.apache.doris.nereids.trees.plans.commands.CreateTableCommand;
@ -581,9 +580,6 @@ public class StmtExecutor {
}
} else {
context.getState().setIsQuery(true);
if (context.getSessionVariable().enableProfile) {
ConnectContext.get().setStatsErrorEstimator(new StatsErrorEstimator());
}
// create plan
planner = new NereidsPlanner(statementContext);
if (context.getSessionVariable().isEnableMaterializedViewRewrite()) {