[fix](nereids) Fix query rewrite by mv fail when self join (#29227)

Fix query rewrite by mv fail when self join, after fix query like following can be rewrited

def materialized view = """
    select 
    a.o_orderkey,
    count(distinct a.o_orderstatus) num1,
    SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate = '2023-12-08' AND b.o_orderdate = '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num2,
    SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate >= '2023-12-01' AND a.o_orderdate <= '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num3,
    SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority in (1,2) AND a.o_orderdate >= '2023-12-08' AND b.o_orderdate <= '2023-12-09' THEN a.o_shippriority-b.o_custkey ELSE 0 END) num4,
    AVG(a.o_totalprice) num5,
    MAX(b.o_totalprice) num6,
    MIN(a.o_totalprice) num7
    from
    orders a
    left outer join orders b
    on a.o_orderkey = b.o_orderkey
    and a.o_custkey = b.o_custkey
    group by a.o_orderkey;
"""

def query = """
    select 
    a.o_orderkey,
    SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate = '2023-12-08' AND b.o_orderdate = '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num2,
    SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate >= '2023-12-01' AND a.o_orderdate <= '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num3,
    SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority in (1,2) AND a.o_orderdate >= '2023-12-08' AND b.o_orderdate <= '2023-12-09' THEN a.o_shippriority-b.o_custkey ELSE 0 END) num4,
    AVG(a.o_totalprice) num5,
    MAX(b.o_totalprice) num6,
    MIN(a.o_totalprice) num7
    from
    orders a
    left outer join orders b
    on a.o_orderkey = b.o_orderkey
    and a.o_custkey = b.o_custkey
    group by a.o_orderkey;
"""
This commit is contained in:
seawinde
2023-12-29 13:45:33 +08:00
committed by GitHub
parent 2794427e7f
commit 9fc613de9c
4 changed files with 179 additions and 25 deletions

View File

@ -22,11 +22,13 @@ import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableBiMap.Builder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.List;
@ -59,54 +61,85 @@ public class RelationMapping extends Mapping {
*/
public static List<RelationMapping> generate(List<CatalogRelation> sources, List<CatalogRelation> targets) {
// Construct tmp map, key is the table qualifier, value is the corresponding catalog relations
LinkedListMultimap<Long, MappedRelation> sourceTableRelationIdMap = LinkedListMultimap.create();
HashMultimap<Long, MappedRelation> sourceTableRelationIdMap = HashMultimap.create();
for (CatalogRelation relation : sources) {
sourceTableRelationIdMap.put(getTableQualifier(relation.getTable()),
MappedRelation.of(relation.getRelationId(), relation));
}
LinkedListMultimap<Long, MappedRelation> targetTableRelationIdMap = LinkedListMultimap.create();
HashMultimap<Long, MappedRelation> targetTableRelationIdMap = HashMultimap.create();
for (CatalogRelation relation : targets) {
targetTableRelationIdMap.put(getTableQualifier(relation.getTable()),
MappedRelation.of(relation.getRelationId(), relation));
}
Set<Long> sourceTableKeySet = sourceTableRelationIdMap.keySet();
List<List<Pair<MappedRelation, MappedRelation>>> mappedRelations = new ArrayList<>();
List<List<BiMap<MappedRelation, MappedRelation>>> mappedRelations = new ArrayList<>();
for (Long sourceTableQualifier : sourceTableKeySet) {
List<MappedRelation> sourceMappedRelations = sourceTableRelationIdMap.get(sourceTableQualifier);
List<MappedRelation> targetMappedRelations = targetTableRelationIdMap.get(sourceTableQualifier);
for (Long sourceTableId : sourceTableKeySet) {
Set<MappedRelation> sourceMappedRelations = sourceTableRelationIdMap.get(sourceTableId);
Set<MappedRelation> targetMappedRelations = targetTableRelationIdMap.get(sourceTableId);
if (targetMappedRelations.isEmpty()) {
continue;
}
// if source and target relation appear once, just map them
if (targetMappedRelations.size() == 1 && sourceMappedRelations.size() == 1) {
mappedRelations.add(ImmutableList.of(Pair.of(sourceMappedRelations.get(0),
targetMappedRelations.get(0))));
ImmutableBiMap.Builder<MappedRelation, MappedRelation> biMapBuilder = ImmutableBiMap.builder();
mappedRelations.add(ImmutableList.of(
biMapBuilder.put(sourceMappedRelations.iterator().next(),
targetMappedRelations.iterator().next()).build()));
continue;
}
// relation appear more than once, should cartesian them
ImmutableList<Pair<MappedRelation, MappedRelation>> relationMapping = Lists.cartesianProduct(
sourceTableRelationIdMap.get(sourceTableQualifier), targetMappedRelations)
// relation appear more than once, should cartesian them and power set to correct combination
// if query is select * from tableA0, tableA1, materialized view is select * from tableA2, tableA3,
// tableA is the same table used by both query and materialized view
// relationMapping will be
// tableA0 tableA2
// tableA0 tableA3
// tableA1 tableA2
// tableA1 tableA3
ImmutableList<Pair<MappedRelation, MappedRelation>> relationMapping = Sets.cartesianProduct(
sourceMappedRelations, targetMappedRelations)
.stream()
.map(listPair -> Pair.of(listPair.get(0), listPair.get(1)))
.collect(ImmutableList.toImmutableList());
mappedRelations.add(relationMapping);
}
int mappedRelationCount = mappedRelations.size();
return Lists.cartesianProduct(mappedRelations).stream()
.map(mappedRelationList -> {
Builder<MappedRelation, MappedRelation> mapBuilder = ImmutableBiMap.builder();
for (int relationIndex = 0; relationIndex < mappedRelationCount; relationIndex++) {
mapBuilder.put(mappedRelationList.get(relationIndex).key(),
mappedRelationList.get(relationIndex).value());
// the mapping in relationMappingPowerList should be bi-direction
// [
// {tableA0 tableA2, tableA1 tableA3}
// {tableA0 tableA3, tableA1 tableA2}
// ]
List<BiMap<MappedRelation, MappedRelation>> relationMappingPowerList = new ArrayList<>();
int relationMappingSize = relationMapping.size();
int relationMappingMinSize = Math.min(sourceMappedRelations.size(), targetMappedRelations.size());
for (int i = 0; i < relationMappingSize; i++) {
HashBiMap<MappedRelation, MappedRelation> relationBiMap = HashBiMap.create();
relationBiMap.put(relationMapping.get(i).key(), relationMapping.get(i).value());
for (int j = i + 1; j < relationMappingSize; j++) {
if (!relationBiMap.containsKey(relationMapping.get(j).key())
&& !relationBiMap.containsValue(relationMapping.get(j).value())) {
relationBiMap.put(relationMapping.get(j).key(), relationMapping.get(j).value());
}
return RelationMapping.of(mapBuilder.build());
})
}
// mapping should contain min num of relation in source or target at least
if (relationBiMap.size() >= relationMappingMinSize) {
relationMappingPowerList.add(relationBiMap);
}
}
mappedRelations.add(relationMappingPowerList);
}
// mappedRelations product and merge into each relationMapping
return Lists.cartesianProduct(mappedRelations).stream()
.map(RelationMapping::merge)
.collect(ImmutableList.toImmutableList());
}
public static RelationMapping merge(List<BiMap<MappedRelation, MappedRelation>> relationMappings) {
Builder<MappedRelation, MappedRelation> mappingBuilder = ImmutableBiMap.builder();
for (BiMap<MappedRelation, MappedRelation> relationMapping : relationMappings) {
relationMapping.forEach(mappingBuilder::put);
}
return RelationMapping.of(mappingBuilder.build());
}
private static Long getTableQualifier(TableIf tableIf) {
return tableIf.getId();
}

View File

@ -35,7 +35,9 @@ import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
/**MappingTest*/
/**
* MappingTest
*/
public class MappingTest extends TestWithFeService {
@Override
@ -275,6 +277,72 @@ public class MappingTest extends TestWithFeService {
assertRelationMapping(generateRelationMapping.get(1), expectedRelationMapping, expectedSlotMapping);
}
@Test
public void testGenerateMapping5() {
Plan sourcePlan = PlanChecker.from(connectContext)
.analyze("SELECT orders.*, l1.* "
+ "FROM\n"
+ " orders,\n"
+ " lineitem l1,\n"
+ " lineitem l2\n"
+ "WHERE\n"
+ " l1.l_orderkey = l2.l_orderkey\n"
+ " AND l1.l_orderkey = o_orderkey")
.getPlan();
Plan targetPlan = PlanChecker.from(connectContext)
.analyze("SELECT orders.*, l1.* "
+ "FROM\n"
+ " lineitem l1,\n"
+ " orders,\n"
+ " lineitem l2\n"
+ "WHERE\n"
+ " l1.l_orderkey = l2.l_orderkey\n"
+ " AND l2.l_orderkey = o_orderkey")
.getPlan();
List<CatalogRelation> sourceRelations = new ArrayList<>();
sourcePlan.accept(RelationCollector.INSTANCE, sourceRelations);
List<CatalogRelation> targetRelations = new ArrayList<>();
targetPlan.accept(RelationCollector.INSTANCE, targetRelations);
List<RelationMapping> generateRelationMapping = RelationMapping.generate(sourceRelations, targetRelations);
Assertions.assertNotNull(generateRelationMapping);
Assertions.assertEquals(2, generateRelationMapping.size());
// expected slot mapping
BiMap<ExprId, ExprId> expectedSlotMapping = HashBiMap.create();
expectedSlotMapping.put(new ExprId(0), new ExprId(2));
expectedSlotMapping.put(new ExprId(1), new ExprId(3));
expectedSlotMapping.put(new ExprId(2), new ExprId(4));
expectedSlotMapping.put(new ExprId(3), new ExprId(0));
expectedSlotMapping.put(new ExprId(4), new ExprId(1));
expectedSlotMapping.put(new ExprId(5), new ExprId(5));
expectedSlotMapping.put(new ExprId(6), new ExprId(6));
// expected relation mapping
BiMap<RelationId, RelationId> expectedRelationMapping = HashBiMap.create();
expectedRelationMapping.put(new RelationId(0), new RelationId(1));
expectedRelationMapping.put(new RelationId(1), new RelationId(0));
expectedRelationMapping.put(new RelationId(2), new RelationId(2));
assertRelationMapping(generateRelationMapping.get(1), expectedRelationMapping, expectedSlotMapping);
// expected slot mapping
expectedSlotMapping = HashBiMap.create();
expectedSlotMapping.put(new ExprId(0), new ExprId(2));
expectedSlotMapping.put(new ExprId(1), new ExprId(3));
expectedSlotMapping.put(new ExprId(2), new ExprId(4));
expectedSlotMapping.put(new ExprId(3), new ExprId(5));
expectedSlotMapping.put(new ExprId(4), new ExprId(6));
expectedSlotMapping.put(new ExprId(5), new ExprId(0));
expectedSlotMapping.put(new ExprId(6), new ExprId(1));
// expected relation mapping
expectedRelationMapping = HashBiMap.create();
expectedRelationMapping.put(new RelationId(0), new RelationId(1));
expectedRelationMapping.put(new RelationId(1), new RelationId(2));
expectedRelationMapping.put(new RelationId(2), new RelationId(0));
assertRelationMapping(generateRelationMapping.get(0), expectedRelationMapping, expectedSlotMapping);
}
private void assertRelationMapping(RelationMapping relationMapping,
BiMap<RelationId, RelationId> expectRelationMapping,
BiMap<ExprId, ExprId> expectSlotMapping) {

View File

@ -237,3 +237,17 @@
-- !query7_0_after --
3 3 2023-12-11
-- !query8_0_before --
1 0 8 0 10.0000 10.50 9.50
2 0 2 0 11.5000 11.50 11.50
3 0 0 0 23.0000 33.50 12.50
4 0 0 0 43.2000 43.20 43.20
5 0 0 0 28.7000 56.20 1.20
-- !query8_0_after --
1 0 8 0 10.0000 10.50 9.50
2 0 2 0 11.5000 11.50 11.50
3 0 0 0 23.0000 33.50 12.50
4 0 0 0 43.2000 43.20 43.20
5 0 0 0 28.7000 56.20 1.20

View File

@ -361,4 +361,43 @@ suite("outer_join") {
order_qt_query7_0_after "${query7_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv7_0"""
// self join test
def mv8_0 = """
select
a.o_orderkey,
count(distinct a.o_orderstatus) num1,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate = '2023-12-08' AND b.o_orderdate = '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num2,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate >= '2023-12-01' AND a.o_orderdate <= '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num3,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority in (1,2) AND a.o_orderdate >= '2023-12-08' AND b.o_orderdate <= '2023-12-09' THEN a.o_shippriority-b.o_custkey ELSE 0 END) num4,
AVG(a.o_totalprice) num5,
MAX(b.o_totalprice) num6,
MIN(a.o_totalprice) num7
from
orders a
left outer join orders b
on a.o_orderkey = b.o_orderkey
and a.o_custkey = b.o_custkey
group by a.o_orderkey;
"""
def query8_0 = """
select
a.o_orderkey,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate = '2023-12-08' AND b.o_orderdate = '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num2,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate >= '2023-12-01' AND a.o_orderdate <= '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num3,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority in (1,2) AND a.o_orderdate >= '2023-12-08' AND b.o_orderdate <= '2023-12-09' THEN a.o_shippriority-b.o_custkey ELSE 0 END) num4,
AVG(a.o_totalprice) num5,
MAX(b.o_totalprice) num6,
MIN(a.o_totalprice) num7
from
orders a
left outer join orders b
on a.o_orderkey = b.o_orderkey
and a.o_custkey = b.o_custkey
group by a.o_orderkey;
"""
order_qt_query8_0_before "${query8_0}"
check_rewrite(mv8_0, query8_0, "mv8_0")
order_qt_query8_0_after "${query8_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv8_0"""
}