[feat](Nereids): use table map to construct struct info (#32058)

This commit is contained in:
谢健
2024-03-12 15:52:36 +08:00
committed by yiguolei
parent 45824d959c
commit 2da57526a3
3 changed files with 108 additions and 11 deletions

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.memo;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.exploration.mv.StructInfo;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
@ -37,11 +38,12 @@ import javax.annotation.Nullable;
* Representation for group in cascades optimizer.
*/
public class StructInfoMap {
private final Map<BitSet, GroupExpression> groupExpressionMap = new HashMap<>();
private final Map<BitSet, Pair<GroupExpression, List<BitSet>>> groupExpressionMap = new HashMap<>();
private final Map<BitSet, StructInfo> infoMap = new HashMap<>();
/**
* get struct info according to table map
*
* @param mvTableMap the original table map
* @param foldTableMap the fold table map
* @param group the group that the mv matched
@ -54,7 +56,10 @@ public class StructInfoMap {
refresh(group);
}
if (groupExpressionMap.containsKey(mvTableMap)) {
StructInfo structInfo = constructStructInfo(groupExpressionMap.get(mvTableMap));
Pair<GroupExpression, List<BitSet>> groupExpressionBitSetPair = getGroupExpressionWithChildren(
mvTableMap);
StructInfo structInfo = constructStructInfo(groupExpressionBitSetPair.first,
groupExpressionBitSetPair.second, mvTableMap);
infoMap.put(mvTableMap, structInfo);
}
}
@ -66,20 +71,39 @@ public class StructInfoMap {
return groupExpressionMap.keySet();
}
private StructInfo constructStructInfo(GroupExpression groupExpression) {
throw new RuntimeException("has not been implemented for" + groupExpression);
public Pair<GroupExpression, List<BitSet>> getGroupExpressionWithChildren(BitSet tableMap) {
return groupExpressionMap.get(tableMap);
}
private StructInfo constructStructInfo(GroupExpression groupExpression, List<BitSet> children, BitSet tableMap) {
Plan plan = constructPlan(groupExpression, children, tableMap);
return StructInfo.of(plan).get(0);
}
private Plan constructPlan(GroupExpression groupExpression, List<BitSet> children, BitSet tableMap) {
List<Plan> childrenPlan = new ArrayList<>();
for (int i = 0; i < children.size(); i++) {
StructInfoMap structInfoMap = groupExpression.child(i).getstructInfoMap();
BitSet childMap = children.get(i);
Pair<GroupExpression, List<BitSet>> groupExpressionBitSetPair
= structInfoMap.getGroupExpressionWithChildren(childMap);
childrenPlan.add(
constructPlan(groupExpressionBitSetPair.first, groupExpressionBitSetPair.second, childMap));
}
return groupExpression.getPlan().withChildren(childrenPlan);
}
/**
* refresh group expression map
*
* @param group the root group
*/
public void refresh(Group group) {
List<Set<BitSet>> childrenTableMap = new ArrayList<>();
Set<Group> refreshedGroup = new HashSet<>();
for (GroupExpression groupExpression : group.getLogicalExpressions()) {
List<Set<BitSet>> childrenTableMap = new ArrayList<>();
if (groupExpression.children().isEmpty()) {
groupExpressionMap.put(constructLeaf(groupExpression), groupExpression);
groupExpressionMap.put(constructLeaf(groupExpression), Pair.of(groupExpression, new ArrayList<>()));
continue;
}
for (Group child : groupExpression.children()) {
@ -90,9 +114,9 @@ public class StructInfoMap {
refreshedGroup.add(child);
childrenTableMap.add(child.getstructInfoMap().getTableMaps());
}
Set<BitSet> bitSets = cartesianProduct(childrenTableMap);
for (BitSet bitSet : bitSets) {
groupExpressionMap.put(bitSet, groupExpression);
Set<Pair<BitSet, List<BitSet>>> bitSetWithChildren = cartesianProduct(childrenTableMap);
for (Pair<BitSet, List<BitSet>> bitSetWithChild : bitSetWithChildren) {
groupExpressionMap.put(bitSetWithChild.first, Pair.of(groupExpression, bitSetWithChild.second));
}
}
}
@ -108,7 +132,7 @@ public class StructInfoMap {
return tableMap;
}
private Set<BitSet> cartesianProduct(List<Set<BitSet>> childrenTableMap) {
private Set<Pair<BitSet, List<BitSet>>> cartesianProduct(List<Set<BitSet>> childrenTableMap) {
return Sets.cartesianProduct(childrenTableMap)
.stream()
.map(bitSetList -> {
@ -116,7 +140,7 @@ public class StructInfoMap {
for (BitSet b : bitSetList) {
bitSet.or(b);
}
return bitSet;
return Pair.of(bitSet, bitSetList);
})
.collect(Collectors.toSet());
}