[fix](nereids) Fix query rewrite by materialized view failing on self joins (#29227)

Fix the failure to rewrite a query by materialized view when the query contains a self join. After this fix, a query like the following can be rewritten:

def materialized view = """
    select 
    a.o_orderkey,
    count(distinct a.o_orderstatus) num1,
    SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate = '2023-12-08' AND b.o_orderdate = '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num2,
    SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate >= '2023-12-01' AND a.o_orderdate <= '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num3,
    SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority in (1,2) AND a.o_orderdate >= '2023-12-08' AND b.o_orderdate <= '2023-12-09' THEN a.o_shippriority-b.o_custkey ELSE 0 END) num4,
    AVG(a.o_totalprice) num5,
    MAX(b.o_totalprice) num6,
    MIN(a.o_totalprice) num7
    from
    orders a
    left outer join orders b
    on a.o_orderkey = b.o_orderkey
    and a.o_custkey = b.o_custkey
    group by a.o_orderkey;
"""

def query = """
    select 
    a.o_orderkey,
    SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate = '2023-12-08' AND b.o_orderdate = '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num2,
    SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate >= '2023-12-01' AND a.o_orderdate <= '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num3,
    SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority in (1,2) AND a.o_orderdate >= '2023-12-08' AND b.o_orderdate <= '2023-12-09' THEN a.o_shippriority-b.o_custkey ELSE 0 END) num4,
    AVG(a.o_totalprice) num5,
    MAX(b.o_totalprice) num6,
    MIN(a.o_totalprice) num7
    from
    orders a
    left outer join orders b
    on a.o_orderkey = b.o_orderkey
    and a.o_custkey = b.o_custkey
    group by a.o_orderkey;
"""
This commit is contained in:
seawinde
2023-12-29 13:45:33 +08:00
committed by GitHub
parent 2794427e7f
commit 9fc613de9c
4 changed files with 179 additions and 25 deletions

View File

@ -22,11 +22,13 @@ import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableBiMap.Builder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.List;
@ -59,54 +61,85 @@ public class RelationMapping extends Mapping {
*/
public static List<RelationMapping> generate(List<CatalogRelation> sources, List<CatalogRelation> targets) {
// Construct tmp map, key is the table qualifier, value is the corresponding catalog relations
LinkedListMultimap<Long, MappedRelation> sourceTableRelationIdMap = LinkedListMultimap.create();
HashMultimap<Long, MappedRelation> sourceTableRelationIdMap = HashMultimap.create();
for (CatalogRelation relation : sources) {
sourceTableRelationIdMap.put(getTableQualifier(relation.getTable()),
MappedRelation.of(relation.getRelationId(), relation));
}
LinkedListMultimap<Long, MappedRelation> targetTableRelationIdMap = LinkedListMultimap.create();
HashMultimap<Long, MappedRelation> targetTableRelationIdMap = HashMultimap.create();
for (CatalogRelation relation : targets) {
targetTableRelationIdMap.put(getTableQualifier(relation.getTable()),
MappedRelation.of(relation.getRelationId(), relation));
}
Set<Long> sourceTableKeySet = sourceTableRelationIdMap.keySet();
List<List<Pair<MappedRelation, MappedRelation>>> mappedRelations = new ArrayList<>();
List<List<BiMap<MappedRelation, MappedRelation>>> mappedRelations = new ArrayList<>();
for (Long sourceTableQualifier : sourceTableKeySet) {
List<MappedRelation> sourceMappedRelations = sourceTableRelationIdMap.get(sourceTableQualifier);
List<MappedRelation> targetMappedRelations = targetTableRelationIdMap.get(sourceTableQualifier);
for (Long sourceTableId : sourceTableKeySet) {
Set<MappedRelation> sourceMappedRelations = sourceTableRelationIdMap.get(sourceTableId);
Set<MappedRelation> targetMappedRelations = targetTableRelationIdMap.get(sourceTableId);
if (targetMappedRelations.isEmpty()) {
continue;
}
// if source and target relation appear once, just map them
if (targetMappedRelations.size() == 1 && sourceMappedRelations.size() == 1) {
mappedRelations.add(ImmutableList.of(Pair.of(sourceMappedRelations.get(0),
targetMappedRelations.get(0))));
ImmutableBiMap.Builder<MappedRelation, MappedRelation> biMapBuilder = ImmutableBiMap.builder();
mappedRelations.add(ImmutableList.of(
biMapBuilder.put(sourceMappedRelations.iterator().next(),
targetMappedRelations.iterator().next()).build()));
continue;
}
// relation appear more than once, should cartesian them
ImmutableList<Pair<MappedRelation, MappedRelation>> relationMapping = Lists.cartesianProduct(
sourceTableRelationIdMap.get(sourceTableQualifier), targetMappedRelations)
// relation appear more than once, should cartesian them and power set to correct combination
// if query is select * from tableA0, tableA1, materialized view is select * from tableA2, tableA3,
// tableA is the same table used by both query and materialized view
// relationMapping will be
// tableA0 tableA2
// tableA0 tableA3
// tableA1 tableA2
// tableA1 tableA3
ImmutableList<Pair<MappedRelation, MappedRelation>> relationMapping = Sets.cartesianProduct(
sourceMappedRelations, targetMappedRelations)
.stream()
.map(listPair -> Pair.of(listPair.get(0), listPair.get(1)))
.collect(ImmutableList.toImmutableList());
mappedRelations.add(relationMapping);
}
int mappedRelationCount = mappedRelations.size();
return Lists.cartesianProduct(mappedRelations).stream()
.map(mappedRelationList -> {
Builder<MappedRelation, MappedRelation> mapBuilder = ImmutableBiMap.builder();
for (int relationIndex = 0; relationIndex < mappedRelationCount; relationIndex++) {
mapBuilder.put(mappedRelationList.get(relationIndex).key(),
mappedRelationList.get(relationIndex).value());
// the mapping in relationMappingPowerList should be bi-direction
// [
// {tableA0 tableA2, tableA1 tableA3}
// {tableA0 tableA3, tableA1 tableA2}
// ]
List<BiMap<MappedRelation, MappedRelation>> relationMappingPowerList = new ArrayList<>();
int relationMappingSize = relationMapping.size();
int relationMappingMinSize = Math.min(sourceMappedRelations.size(), targetMappedRelations.size());
for (int i = 0; i < relationMappingSize; i++) {
HashBiMap<MappedRelation, MappedRelation> relationBiMap = HashBiMap.create();
relationBiMap.put(relationMapping.get(i).key(), relationMapping.get(i).value());
for (int j = i + 1; j < relationMappingSize; j++) {
if (!relationBiMap.containsKey(relationMapping.get(j).key())
&& !relationBiMap.containsValue(relationMapping.get(j).value())) {
relationBiMap.put(relationMapping.get(j).key(), relationMapping.get(j).value());
}
return RelationMapping.of(mapBuilder.build());
})
}
// mapping should contain min num of relation in source or target at least
if (relationBiMap.size() >= relationMappingMinSize) {
relationMappingPowerList.add(relationBiMap);
}
}
mappedRelations.add(relationMappingPowerList);
}
// mappedRelations product and merge into each relationMapping
return Lists.cartesianProduct(mappedRelations).stream()
.map(RelationMapping::merge)
.collect(ImmutableList.toImmutableList());
}
/**
 * Merge several relation mappings into one relation mapping.
 *
 * @param relationMappings the mappings to combine, all of their entries are copied into the result
 * @return a relation mapping that contains every entry of the given mappings
 */
public static RelationMapping merge(List<BiMap<MappedRelation, MappedRelation>> relationMappings) {
    ImmutableBiMap.Builder<MappedRelation, MappedRelation> merged = ImmutableBiMap.builder();
    for (BiMap<MappedRelation, MappedRelation> perTableMapping : relationMappings) {
        // presumably the mappings never share a key or a value, verify against callers
        merged.putAll(perTableMapping);
    }
    return RelationMapping.of(merged.build());
}
/**
 * Return the key under which relations of the same table are grouped, currently the table id.
 */
private static Long getTableQualifier(TableIf tableIf) {
    Long tableId = tableIf.getId();
    return tableId;
}