[fix](nereids) Fix query rewrite by mv fail when self join (#29227)
Fix the failure to rewrite a query by materialized view when the query contains a self join. After this fix, a query like the following can be rewritten:
def materialized view = """
select
a.o_orderkey,
count(distinct a.o_orderstatus) num1,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate = '2023-12-08' AND b.o_orderdate = '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num2,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate >= '2023-12-01' AND a.o_orderdate <= '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num3,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority in (1,2) AND a.o_orderdate >= '2023-12-08' AND b.o_orderdate <= '2023-12-09' THEN a.o_shippriority-b.o_custkey ELSE 0 END) num4,
AVG(a.o_totalprice) num5,
MAX(b.o_totalprice) num6,
MIN(a.o_totalprice) num7
from
orders a
left outer join orders b
on a.o_orderkey = b.o_orderkey
and a.o_custkey = b.o_custkey
group by a.o_orderkey;
"""
def query = """
select
a.o_orderkey,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate = '2023-12-08' AND b.o_orderdate = '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num2,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate >= '2023-12-01' AND a.o_orderdate <= '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num3,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority in (1,2) AND a.o_orderdate >= '2023-12-08' AND b.o_orderdate <= '2023-12-09' THEN a.o_shippriority-b.o_custkey ELSE 0 END) num4,
AVG(a.o_totalprice) num5,
MAX(b.o_totalprice) num6,
MIN(a.o_totalprice) num7
from
orders a
left outer join orders b
on a.o_orderkey = b.o_orderkey
and a.o_custkey = b.o_custkey
group by a.o_orderkey;
"""
This commit is contained in:
@ -22,11 +22,13 @@ import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
|
||||
|
||||
import com.google.common.collect.BiMap;
|
||||
import com.google.common.collect.HashBiMap;
|
||||
import com.google.common.collect.HashMultimap;
|
||||
import com.google.common.collect.ImmutableBiMap;
|
||||
import com.google.common.collect.ImmutableBiMap.Builder;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.LinkedListMultimap;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
@ -59,54 +61,85 @@ public class RelationMapping extends Mapping {
|
||||
*/
|
||||
public static List<RelationMapping> generate(List<CatalogRelation> sources, List<CatalogRelation> targets) {
|
||||
// Construct tmp map, key is the table qualifier, value is the corresponding catalog relations
|
||||
LinkedListMultimap<Long, MappedRelation> sourceTableRelationIdMap = LinkedListMultimap.create();
|
||||
HashMultimap<Long, MappedRelation> sourceTableRelationIdMap = HashMultimap.create();
|
||||
for (CatalogRelation relation : sources) {
|
||||
sourceTableRelationIdMap.put(getTableQualifier(relation.getTable()),
|
||||
MappedRelation.of(relation.getRelationId(), relation));
|
||||
}
|
||||
LinkedListMultimap<Long, MappedRelation> targetTableRelationIdMap = LinkedListMultimap.create();
|
||||
HashMultimap<Long, MappedRelation> targetTableRelationIdMap = HashMultimap.create();
|
||||
for (CatalogRelation relation : targets) {
|
||||
targetTableRelationIdMap.put(getTableQualifier(relation.getTable()),
|
||||
MappedRelation.of(relation.getRelationId(), relation));
|
||||
}
|
||||
Set<Long> sourceTableKeySet = sourceTableRelationIdMap.keySet();
|
||||
List<List<Pair<MappedRelation, MappedRelation>>> mappedRelations = new ArrayList<>();
|
||||
List<List<BiMap<MappedRelation, MappedRelation>>> mappedRelations = new ArrayList<>();
|
||||
|
||||
for (Long sourceTableQualifier : sourceTableKeySet) {
|
||||
List<MappedRelation> sourceMappedRelations = sourceTableRelationIdMap.get(sourceTableQualifier);
|
||||
List<MappedRelation> targetMappedRelations = targetTableRelationIdMap.get(sourceTableQualifier);
|
||||
for (Long sourceTableId : sourceTableKeySet) {
|
||||
Set<MappedRelation> sourceMappedRelations = sourceTableRelationIdMap.get(sourceTableId);
|
||||
Set<MappedRelation> targetMappedRelations = targetTableRelationIdMap.get(sourceTableId);
|
||||
if (targetMappedRelations.isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
// if source and target relation appear once, just map them
|
||||
if (targetMappedRelations.size() == 1 && sourceMappedRelations.size() == 1) {
|
||||
mappedRelations.add(ImmutableList.of(Pair.of(sourceMappedRelations.get(0),
|
||||
targetMappedRelations.get(0))));
|
||||
ImmutableBiMap.Builder<MappedRelation, MappedRelation> biMapBuilder = ImmutableBiMap.builder();
|
||||
mappedRelations.add(ImmutableList.of(
|
||||
biMapBuilder.put(sourceMappedRelations.iterator().next(),
|
||||
targetMappedRelations.iterator().next()).build()));
|
||||
continue;
|
||||
}
|
||||
// relation appear more than once, should cartesian them
|
||||
ImmutableList<Pair<MappedRelation, MappedRelation>> relationMapping = Lists.cartesianProduct(
|
||||
sourceTableRelationIdMap.get(sourceTableQualifier), targetMappedRelations)
|
||||
// the relation appears more than once, so build the cartesian product and then combine the pairs into valid mappings
|
||||
// if query is select * from tableA0, tableA1, materialized view is select * from tableA2, tableA3,
|
||||
// tableA is the same table used by both query and materialized view
|
||||
// relationMapping will be
|
||||
// tableA0 tableA2
|
||||
// tableA0 tableA3
|
||||
// tableA1 tableA2
|
||||
// tableA1 tableA3
|
||||
ImmutableList<Pair<MappedRelation, MappedRelation>> relationMapping = Sets.cartesianProduct(
|
||||
sourceMappedRelations, targetMappedRelations)
|
||||
.stream()
|
||||
.map(listPair -> Pair.of(listPair.get(0), listPair.get(1)))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
mappedRelations.add(relationMapping);
|
||||
}
|
||||
|
||||
int mappedRelationCount = mappedRelations.size();
|
||||
|
||||
return Lists.cartesianProduct(mappedRelations).stream()
|
||||
.map(mappedRelationList -> {
|
||||
Builder<MappedRelation, MappedRelation> mapBuilder = ImmutableBiMap.builder();
|
||||
for (int relationIndex = 0; relationIndex < mappedRelationCount; relationIndex++) {
|
||||
mapBuilder.put(mappedRelationList.get(relationIndex).key(),
|
||||
mappedRelationList.get(relationIndex).value());
|
||||
// each mapping in relationMappingPowerList must be one-to-one in both directions:
|
||||
// [
|
||||
// {tableA0 tableA2, tableA1 tableA3}
|
||||
// {tableA0 tableA3, tableA1 tableA2}
|
||||
// ]
|
||||
List<BiMap<MappedRelation, MappedRelation>> relationMappingPowerList = new ArrayList<>();
|
||||
int relationMappingSize = relationMapping.size();
|
||||
int relationMappingMinSize = Math.min(sourceMappedRelations.size(), targetMappedRelations.size());
|
||||
for (int i = 0; i < relationMappingSize; i++) {
|
||||
HashBiMap<MappedRelation, MappedRelation> relationBiMap = HashBiMap.create();
|
||||
relationBiMap.put(relationMapping.get(i).key(), relationMapping.get(i).value());
|
||||
for (int j = i + 1; j < relationMappingSize; j++) {
|
||||
if (!relationBiMap.containsKey(relationMapping.get(j).key())
|
||||
&& !relationBiMap.containsValue(relationMapping.get(j).value())) {
|
||||
relationBiMap.put(relationMapping.get(j).key(), relationMapping.get(j).value());
|
||||
}
|
||||
return RelationMapping.of(mapBuilder.build());
|
||||
})
|
||||
}
|
||||
// a mapping must cover at least min(source relation count, target relation count) relations
|
||||
if (relationBiMap.size() >= relationMappingMinSize) {
|
||||
relationMappingPowerList.add(relationBiMap);
|
||||
}
|
||||
}
|
||||
mappedRelations.add(relationMappingPowerList);
|
||||
}
|
||||
// mappedRelations product and merge into each relationMapping
|
||||
return Lists.cartesianProduct(mappedRelations).stream()
|
||||
.map(RelationMapping::merge)
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
|
||||
public static RelationMapping merge(List<BiMap<MappedRelation, MappedRelation>> relationMappings) {
|
||||
Builder<MappedRelation, MappedRelation> mappingBuilder = ImmutableBiMap.builder();
|
||||
for (BiMap<MappedRelation, MappedRelation> relationMapping : relationMappings) {
|
||||
relationMapping.forEach(mappingBuilder::put);
|
||||
}
|
||||
return RelationMapping.of(mappingBuilder.build());
|
||||
}
|
||||
|
||||
private static Long getTableQualifier(TableIf tableIf) {
|
||||
return tableIf.getId();
|
||||
}
|
||||
|
||||
@ -35,7 +35,9 @@ import org.junit.jupiter.api.Test;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**MappingTest*/
|
||||
/**
|
||||
* MappingTest
|
||||
*/
|
||||
public class MappingTest extends TestWithFeService {
|
||||
|
||||
@Override
|
||||
@ -275,6 +277,72 @@ public class MappingTest extends TestWithFeService {
|
||||
assertRelationMapping(generateRelationMapping.get(1), expectedRelationMapping, expectedSlotMapping);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGenerateMapping5() {
|
||||
Plan sourcePlan = PlanChecker.from(connectContext)
|
||||
.analyze("SELECT orders.*, l1.* "
|
||||
+ "FROM\n"
|
||||
+ " orders,\n"
|
||||
+ " lineitem l1,\n"
|
||||
+ " lineitem l2\n"
|
||||
+ "WHERE\n"
|
||||
+ " l1.l_orderkey = l2.l_orderkey\n"
|
||||
+ " AND l1.l_orderkey = o_orderkey")
|
||||
.getPlan();
|
||||
|
||||
Plan targetPlan = PlanChecker.from(connectContext)
|
||||
.analyze("SELECT orders.*, l1.* "
|
||||
+ "FROM\n"
|
||||
+ " lineitem l1,\n"
|
||||
+ " orders,\n"
|
||||
+ " lineitem l2\n"
|
||||
+ "WHERE\n"
|
||||
+ " l1.l_orderkey = l2.l_orderkey\n"
|
||||
+ " AND l2.l_orderkey = o_orderkey")
|
||||
.getPlan();
|
||||
List<CatalogRelation> sourceRelations = new ArrayList<>();
|
||||
sourcePlan.accept(RelationCollector.INSTANCE, sourceRelations);
|
||||
|
||||
List<CatalogRelation> targetRelations = new ArrayList<>();
|
||||
targetPlan.accept(RelationCollector.INSTANCE, targetRelations);
|
||||
|
||||
List<RelationMapping> generateRelationMapping = RelationMapping.generate(sourceRelations, targetRelations);
|
||||
Assertions.assertNotNull(generateRelationMapping);
|
||||
Assertions.assertEquals(2, generateRelationMapping.size());
|
||||
|
||||
// expected slot mapping
|
||||
BiMap<ExprId, ExprId> expectedSlotMapping = HashBiMap.create();
|
||||
expectedSlotMapping.put(new ExprId(0), new ExprId(2));
|
||||
expectedSlotMapping.put(new ExprId(1), new ExprId(3));
|
||||
expectedSlotMapping.put(new ExprId(2), new ExprId(4));
|
||||
expectedSlotMapping.put(new ExprId(3), new ExprId(0));
|
||||
expectedSlotMapping.put(new ExprId(4), new ExprId(1));
|
||||
expectedSlotMapping.put(new ExprId(5), new ExprId(5));
|
||||
expectedSlotMapping.put(new ExprId(6), new ExprId(6));
|
||||
// expected relation mapping
|
||||
BiMap<RelationId, RelationId> expectedRelationMapping = HashBiMap.create();
|
||||
expectedRelationMapping.put(new RelationId(0), new RelationId(1));
|
||||
expectedRelationMapping.put(new RelationId(1), new RelationId(0));
|
||||
expectedRelationMapping.put(new RelationId(2), new RelationId(2));
|
||||
assertRelationMapping(generateRelationMapping.get(1), expectedRelationMapping, expectedSlotMapping);
|
||||
|
||||
// expected slot mapping
|
||||
expectedSlotMapping = HashBiMap.create();
|
||||
expectedSlotMapping.put(new ExprId(0), new ExprId(2));
|
||||
expectedSlotMapping.put(new ExprId(1), new ExprId(3));
|
||||
expectedSlotMapping.put(new ExprId(2), new ExprId(4));
|
||||
expectedSlotMapping.put(new ExprId(3), new ExprId(5));
|
||||
expectedSlotMapping.put(new ExprId(4), new ExprId(6));
|
||||
expectedSlotMapping.put(new ExprId(5), new ExprId(0));
|
||||
expectedSlotMapping.put(new ExprId(6), new ExprId(1));
|
||||
// expected relation mapping
|
||||
expectedRelationMapping = HashBiMap.create();
|
||||
expectedRelationMapping.put(new RelationId(0), new RelationId(1));
|
||||
expectedRelationMapping.put(new RelationId(1), new RelationId(2));
|
||||
expectedRelationMapping.put(new RelationId(2), new RelationId(0));
|
||||
assertRelationMapping(generateRelationMapping.get(0), expectedRelationMapping, expectedSlotMapping);
|
||||
}
|
||||
|
||||
private void assertRelationMapping(RelationMapping relationMapping,
|
||||
BiMap<RelationId, RelationId> expectRelationMapping,
|
||||
BiMap<ExprId, ExprId> expectSlotMapping) {
|
||||
|
||||
@ -237,3 +237,17 @@
|
||||
-- !query7_0_after --
|
||||
3 3 2023-12-11
|
||||
|
||||
-- !query8_0_before --
|
||||
1 0 8 0 10.0000 10.50 9.50
|
||||
2 0 2 0 11.5000 11.50 11.50
|
||||
3 0 0 0 23.0000 33.50 12.50
|
||||
4 0 0 0 43.2000 43.20 43.20
|
||||
5 0 0 0 28.7000 56.20 1.20
|
||||
|
||||
-- !query8_0_after --
|
||||
1 0 8 0 10.0000 10.50 9.50
|
||||
2 0 2 0 11.5000 11.50 11.50
|
||||
3 0 0 0 23.0000 33.50 12.50
|
||||
4 0 0 0 43.2000 43.20 43.20
|
||||
5 0 0 0 28.7000 56.20 1.20
|
||||
|
||||
|
||||
@ -361,4 +361,43 @@ suite("outer_join") {
|
||||
order_qt_query7_0_after "${query7_0}"
|
||||
sql """ DROP MATERIALIZED VIEW IF EXISTS mv7_0"""
|
||||
|
||||
|
||||
// self join test
|
||||
def mv8_0 = """
|
||||
select
|
||||
a.o_orderkey,
|
||||
count(distinct a.o_orderstatus) num1,
|
||||
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate = '2023-12-08' AND b.o_orderdate = '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num2,
|
||||
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate >= '2023-12-01' AND a.o_orderdate <= '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num3,
|
||||
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority in (1,2) AND a.o_orderdate >= '2023-12-08' AND b.o_orderdate <= '2023-12-09' THEN a.o_shippriority-b.o_custkey ELSE 0 END) num4,
|
||||
AVG(a.o_totalprice) num5,
|
||||
MAX(b.o_totalprice) num6,
|
||||
MIN(a.o_totalprice) num7
|
||||
from
|
||||
orders a
|
||||
left outer join orders b
|
||||
on a.o_orderkey = b.o_orderkey
|
||||
and a.o_custkey = b.o_custkey
|
||||
group by a.o_orderkey;
|
||||
"""
|
||||
def query8_0 = """
|
||||
select
|
||||
a.o_orderkey,
|
||||
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate = '2023-12-08' AND b.o_orderdate = '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num2,
|
||||
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate >= '2023-12-01' AND a.o_orderdate <= '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num3,
|
||||
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority in (1,2) AND a.o_orderdate >= '2023-12-08' AND b.o_orderdate <= '2023-12-09' THEN a.o_shippriority-b.o_custkey ELSE 0 END) num4,
|
||||
AVG(a.o_totalprice) num5,
|
||||
MAX(b.o_totalprice) num6,
|
||||
MIN(a.o_totalprice) num7
|
||||
from
|
||||
orders a
|
||||
left outer join orders b
|
||||
on a.o_orderkey = b.o_orderkey
|
||||
and a.o_custkey = b.o_custkey
|
||||
group by a.o_orderkey;
|
||||
"""
|
||||
order_qt_query8_0_before "${query8_0}"
|
||||
check_rewrite(mv8_0, query8_0, "mv8_0")
|
||||
order_qt_query8_0_after "${query8_0}"
|
||||
sql """ DROP MATERIALIZED VIEW IF EXISTS mv8_0"""
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user