[Fix](Nereids) Fix rewrite by materialized view fail when join input has agg (#30734)

The materialized view definition is as follows, and the query SQL is the same.
When the outer group by uses col1 produced by the inner aggregate, the query can be rewritten by the materialized view.

select
t1.o_orderdate,
t1.o_orderkey,
t1.col1
from
(
select
o_orderkey,
o_custkey,
o_orderstatus,
o_orderdate,
sum(o_shippriority) as col1
from
orders
group by
o_orderkey,
o_custkey,
o_orderstatus,
o_orderdate
) as t1
left join lineitem on lineitem.l_orderkey = t1.o_orderkey
group by
t1.o_orderdate,
t1.o_orderkey,
t1.col1
This commit is contained in:
seawinde
2024-02-03 20:19:53 +08:00
committed by yiguolei
parent 119615dc50
commit 5aed3abb8a
7 changed files with 215 additions and 106 deletions

View File

@ -18,10 +18,6 @@
package org.apache.doris.nereids.rules.exploration.mv;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.Alias;
@ -368,29 +364,11 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
}
// Check Aggregate is simple or not and check join is whether valid or not.
// Support join's input can not contain aggregate Only support project, filter, join, logical relation node and
// join condition should be slot reference equals currently
// Support project, filter, join, logical relation node and join condition should only contain
// slot reference equals currently.
@Override
protected boolean checkPattern(StructInfo structInfo) {
Plan topPlan = structInfo.getTopPlan();
Boolean valid = topPlan.accept(StructInfo.AGGREGATE_PATTERN_CHECKER, null);
if (!valid) {
return false;
}
HyperGraph hyperGraph = structInfo.getHyperGraph();
for (AbstractNode node : hyperGraph.getNodes()) {
StructInfoNode structInfoNode = (StructInfoNode) node;
if (!structInfoNode.getPlan().accept(StructInfo.JOIN_PATTERN_CHECKER, SUPPORTED_JOIN_TYPE_SET)) {
return false;
}
}
for (JoinEdge edge : hyperGraph.getJoinEdges()) {
if (!edge.getJoin().accept(StructInfo.JOIN_PATTERN_CHECKER, SUPPORTED_JOIN_TYPE_SET)) {
return false;
}
}
return true;
return structInfo.getTopPlan().accept(StructInfo.PLAN_PATTERN_CHECKER, SUPPORTED_JOIN_TYPE_SET);
}
/**

View File

@ -18,10 +18,6 @@
package org.apache.doris.nereids.rules.exploration.mv;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -79,26 +75,11 @@ public abstract class AbstractMaterializedViewJoinRule extends AbstractMateriali
}
/**
* Check join is whether valid or not. Support join's input can not contain aggregate
* Only support project, filter, join, logical relation node and
* join condition should be slot reference equals currently
* Check join is whether valid or not. Support join's input only support project, filter, join,
* logical relation node and join condition should be slot reference equals currently
*/
@Override
protected boolean checkPattern(StructInfo structInfo) {
HyperGraph hyperGraph = structInfo.getHyperGraph();
for (AbstractNode node : hyperGraph.getNodes()) {
StructInfoNode structInfoNode = (StructInfoNode) node;
if (!structInfoNode.getPlan().accept(StructInfo.JOIN_PATTERN_CHECKER,
SUPPORTED_JOIN_TYPE_SET)) {
return false;
}
}
for (JoinEdge edge : hyperGraph.getJoinEdges()) {
if (!edge.getJoin().accept(StructInfo.JOIN_PATTERN_CHECKER,
SUPPORTED_JOIN_TYPE_SET)) {
return false;
}
}
return true;
return structInfo.getTopPlan().accept(StructInfo.PLAN_PATTERN_CHECKER, SUPPORTED_JOIN_TYPE_SET);
}
}

View File

@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.nereids.trees.plans.Plan;
@ -37,12 +38,11 @@ import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.algebra.Filter;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.trees.plans.algebra.Project;
import org.apache.doris.nereids.trees.plans.algebra.Sort;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer;
import org.apache.doris.nereids.util.ExpressionUtils;
@ -57,7 +57,6 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
@ -68,8 +67,7 @@ import javax.annotation.Nullable;
* modify, if wanting to modify, should copy and then modify
*/
public class StructInfo {
public static final JoinPatternChecker JOIN_PATTERN_CHECKER = new JoinPatternChecker();
public static final AggregatePatternChecker AGGREGATE_PATTERN_CHECKER = new AggregatePatternChecker();
public static final PlanPatternChecker PLAN_PATTERN_CHECKER = new PlanPatternChecker();
// struct info splitter
public static final PlanSplitter PLAN_SPLITTER = new PlanSplitter();
private static final RelationCollector RELATION_COLLECTOR = new RelationCollector();
@ -189,7 +187,7 @@ public class StructInfo {
Lists.newArrayList(expression),
ImmutableSet.of(),
ImmutableSet.of());
topPlan.accept(ExpressionLineageReplacer.INSTANCE, replaceContext);
structInfoNode.getPlan().accept(ExpressionLineageReplacer.INSTANCE, replaceContext);
// Replace expressions by expression map
List<Expression> replacedExpressions = replaceContext.getReplacedExpressions();
shuttledHashConjunctsToConjunctsMap.put(replacedExpressions.get(0), expression);
@ -448,55 +446,58 @@ public class StructInfo {
}
/**
* JoinPatternChecker
* PlanPatternChecker, this is used to check the plan pattern is valid or not
*/
public static class JoinPatternChecker extends DefaultPlanVisitor<Boolean, Set<JoinType>> {
public static class PlanPatternChecker extends DefaultPlanVisitor<Boolean, Set<JoinType>> {
@Override
public Boolean visit(Plan plan, Set<JoinType> requiredJoinType) {
super.visit(plan, requiredJoinType);
if (!(plan instanceof Filter)
&& !(plan instanceof Project)
&& !(plan instanceof CatalogRelation)
&& !(plan instanceof Join)
&& !(plan instanceof Sort)
&& !(plan instanceof LogicalAggregate && !((LogicalAggregate) plan).getSourceRepeat()
.isPresent())) {
public Boolean visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join,
Set<JoinType> supportJoinTypes) {
if (!supportJoinTypes.contains(join.getJoinType())) {
return false;
}
if (plan instanceof Join) {
Join join = (Join) plan;
if (!requiredJoinType.contains(join.getJoinType())) {
return false;
}
if (!join.getOtherJoinConjuncts().isEmpty()) {
if (!join.getOtherJoinConjuncts().isEmpty()) {
return false;
}
return visit(join, supportJoinTypes);
}
@Override
public Boolean visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate,
Set<JoinType> supportJoinTypes) {
if (aggregate.getSourceRepeat().isPresent()) {
return false;
}
return visit(aggregate, supportJoinTypes);
}
@Override
public Boolean visitGroupPlan(GroupPlan groupPlan, Set<JoinType> supportJoinTypes) {
return groupPlan.getGroup().getLogicalExpressions().stream()
.anyMatch(logicalExpression -> logicalExpression.getPlan().accept(this, supportJoinTypes));
}
@Override
public Boolean visit(Plan plan, Set<JoinType> supportJoinTypes) {
if (plan instanceof Filter
|| plan instanceof Project
|| plan instanceof CatalogRelation
|| plan instanceof Join
|| plan instanceof LogicalSort
|| plan instanceof LogicalAggregate
|| plan instanceof GroupPlan) {
return doVisit(plan, supportJoinTypes);
}
return false;
}
private Boolean doVisit(Plan plan, Set<JoinType> supportJoinTypes) {
for (Plan child : plan.children()) {
boolean valid = child.accept(this, supportJoinTypes);
if (!valid) {
return false;
}
}
return true;
}
}
/**
* AggregatePatternChecker
*/
public static class AggregatePatternChecker extends DefaultPlanVisitor<Boolean, Void> {
@Override
public Boolean visit(Plan plan, Void context) {
if (plan instanceof LogicalAggregate) {
LogicalAggregate<Plan> aggregate = (LogicalAggregate<Plan>) plan;
Optional<LogicalRepeat<?>> sourceRepeat = aggregate.getSourceRepeat();
if (sourceRepeat.isPresent()) {
return false;
}
super.visit(aggregate, context);
return true;
}
if (plan instanceof Project || plan instanceof Filter || plan instanceof Sort) {
super.visit(plan, context);
return true;
}
super.visit(plan, context);
return false;
}
}
}

View File

@ -295,3 +295,17 @@
4 0 0 0 43.2000 43.20 43.20
5 0 0 0 28.7000 56.20 1.20
-- !query9_0_before --
2023-12-08 1 2
2023-12-09 2 1
2023-12-10 3 2
2023-12-11 4 2
2023-12-12 5 4
-- !query9_0_after --
2023-12-08 1 2
2023-12-09 2 1
2023-12-10 3 2
2023-12-11 4 2
2023-12-12 5 4

View File

@ -311,3 +311,17 @@
4 0 0 0 43.2000 43.20 43.20
5 0 0 0 28.7000 56.20 1.20
-- !query9_0_before --
2023-12-08 1 2
2023-12-09 2 1
2023-12-10 3 2
2023-12-11 4 2
2023-12-12 5 4
-- !query9_0_after --
2023-12-08 1 2
2023-12-09 2 1
2023-12-10 3 2
2023-12-11 4 2
2023-12-12 5 4

View File

@ -466,4 +466,64 @@ suite("outer_join_dphyp") {
check_rewrite(mv8_0, query8_0, "mv8_0")
order_qt_query8_0_after "${query8_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv8_0"""
// join input with simple agg, use aggregate function as outer group by
def mv9_0 = """
select
t1.o_orderdate,
t1.o_orderkey,
t1.col1
from
(
select
o_orderkey,
o_custkey,
o_orderstatus,
o_orderdate,
sum(o_shippriority) as col1
from
orders
group by
o_orderkey,
o_custkey,
o_orderstatus,
o_orderdate
) as t1
left join lineitem on lineitem.l_orderkey = t1.o_orderkey
group by
t1.o_orderdate,
t1.o_orderkey,
t1.col1
"""
def query9_0 = """
select
t1.o_orderdate,
t1.o_orderkey,
t1.col1
from
(
select
o_orderkey,
o_custkey,
o_orderstatus,
o_orderdate,
sum(o_shippriority) as col1
from
orders
group by
o_orderkey,
o_custkey,
o_orderstatus,
o_orderdate
) as t1
left join lineitem on lineitem.l_orderkey = t1.o_orderkey
group by
t1.o_orderdate,
t1.o_orderkey,
t1.col1
"""
order_qt_query9_0_before "${query9_0}"
check_rewrite(mv9_0, query9_0, "mv9_0")
order_qt_query9_0_after "${query9_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv9_0"""
}

View File

@ -476,14 +476,14 @@ suite("outer_join") {
def mv6_1 = """
select l_shipdate, t.o_orderdate, l_partkey, l_suppkey, t.o_orderkey
from lineitem_null
left join (select o_orderdate,o_orderkey from orders_null where o_orderdate = '2023-12-10' ) t
left join (select o_orderdate,o_orderkey from orders_null where o_orderdate = '2023-12-10' ) t
on l_orderkey = t.o_orderkey;
"""
def query6_1 = """
select l_shipdate, o_orderdate, l_partkey, l_suppkey, o_orderkey
from lineitem_null
left join orders_null
on l_orderkey = o_orderkey
select l_shipdate, o_orderdate, l_partkey, l_suppkey, o_orderkey
from lineitem_null
left join orders_null
on l_orderkey = o_orderkey
where l_shipdate = '2023-12-10' and o_orderdate = '2023-12-10';
"""
order_qt_query6_1_before "${query6_1}"
@ -494,16 +494,16 @@ suite("outer_join") {
// should compensate predicate o_orderdate = '2023-12-10' on mv
def mv6_2 = """
select l_shipdate, o_orderdate, l_partkey, l_suppkey, o_orderkey
from lineitem
left join (select * from orders where o_orderdate = '2023-12-10' ) t2
select l_shipdate, o_orderdate, l_partkey, l_suppkey, o_orderkey
from lineitem
left join (select * from orders where o_orderdate = '2023-12-10' ) t2
on lineitem.l_orderkey = t2.o_orderkey;
"""
def query6_2 = """
select l_shipdate, o_orderdate, l_partkey, l_suppkey, o_orderkey
from lineitem
left join orders
on lineitem.l_orderkey = orders.o_orderkey
select l_shipdate, o_orderdate, l_partkey, l_suppkey, o_orderkey
from lineitem
left join orders
on lineitem.l_orderkey = orders.o_orderkey
where o_orderdate = '2023-12-10' order by 1, 2, 3, 4, 5;
"""
order_qt_query6_2_before "${query6_2}"
@ -586,4 +586,65 @@ suite("outer_join") {
check_rewrite(mv8_0, query8_0, "mv8_0")
order_qt_query8_0_after "${query8_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv8_0"""
// join input with simple agg, use aggregate function as outer group by
def mv9_0 = """
select
t1.o_orderdate,
t1.o_orderkey,
t1.col1
from
(
select
o_orderkey,
o_custkey,
o_orderstatus,
o_orderdate,
sum(o_shippriority) as col1
from
orders
group by
o_orderkey,
o_custkey,
o_orderstatus,
o_orderdate
) as t1
left join lineitem on lineitem.l_orderkey = t1.o_orderkey
group by
t1.o_orderdate,
t1.o_orderkey,
t1.col1
"""
def query9_0 = """
select
t1.o_orderdate,
t1.o_orderkey,
t1.col1
from
(
select
o_orderkey,
o_custkey,
o_orderstatus,
o_orderdate,
sum(o_shippriority) as col1
from
orders
group by
o_orderkey,
o_custkey,
o_orderstatus,
o_orderdate
) as t1
left join lineitem on lineitem.l_orderkey = t1.o_orderkey
group by
t1.o_orderdate,
t1.o_orderkey,
t1.col1
"""
order_qt_query9_0_before "${query9_0}"
check_rewrite(mv9_0, query9_0, "mv9_0")
order_qt_query9_0_after "${query9_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv9_0"""
}