[feat](Nereids): eliminate left outer join by unique when comparing mv (#30228)

This commit is contained in:
谢健
2024-01-24 14:05:38 +08:00
committed by yiguolei
parent cd70f45ce2
commit 7e1a986fa1
9 changed files with 221 additions and 19 deletions

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap;
import org.apache.doris.nereids.trees.plans.RelationId;
import java.util.BitSet;
import java.util.Collection;
import java.util.Set;
/**
@ -49,6 +50,14 @@ public class LongBitmap {
return 0;
}
public static long newBitmap(Collection<Integer> values) {
long res = 0;
for (int v : values) {
res = LongBitmap.set(res, v);
}
return res;
}
public static long clone(long bitmap) {
return bitmap;
}

View File

@ -19,9 +19,14 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.node;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.FilterEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
@ -30,17 +35,34 @@ import java.util.Set;
*/
public class AbstractNode {
protected final int index;
protected final List<Edge> edges;
protected final List<JoinEdge> joinEdges;
protected final List<FilterEdge> filterEdges;
protected final Plan plan;
protected AbstractNode(Plan plan, int index, List<Edge> edges) {
this.index = index;
this.edges = edges;
this.joinEdges = new ArrayList<>();
this.filterEdges = new ArrayList<>();
this.plan = plan;
edges.forEach(e -> {
if (e instanceof JoinEdge) {
joinEdges.add((JoinEdge) e);
} else if (e instanceof FilterEdge) {
filterEdges.add((FilterEdge) e);
}
});
}
public List<JoinEdge> getJoinEdges() {
return ImmutableList.copyOf(joinEdges);
}
public List<Edge> getEdges() {
return edges;
return ImmutableList
.<Edge>builder()
.addAll(joinEdges)
.addAll(filterEdges)
.build();
}
public int getIndex() {
@ -61,7 +83,11 @@ public class AbstractNode {
* @param edge the edge that references this node
*/
public void attachEdge(Edge edge) {
edges.add(edge);
if (edge instanceof JoinEdge) {
joinEdges.add((JoinEdge) edge);
} else if (edge instanceof FilterEdge) {
filterEdges.add((FilterEdge) edge);
}
}
/**
@ -70,7 +96,11 @@ public class AbstractNode {
* @param edge The edge should be removed
*/
public void removeEdge(Edge edge) {
edges.remove(edge);
if (edge instanceof JoinEdge) {
joinEdges.remove(edge);
} else if (edge instanceof FilterEdge) {
filterEdges.remove(edge);
}
}
public String getName() {

View File

@ -44,7 +44,7 @@ public class DPhyperNode extends AbstractNode {
}
public DPhyperNode withGroup(Group group) {
return new DPhyperNode(index, group, edges);
return new DPhyperNode(index, group, getEdges());
}
public Group getGroup() {

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.node;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.properties.FunctionalDependencies;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
@ -28,8 +29,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
@ -47,17 +51,36 @@ import javax.annotation.Nullable;
public class StructInfoNode extends AbstractNode {
private final List<Set<Expression>> expressions;
private final Set<CatalogRelation> relationSet;
private final Supplier<Boolean> eliminateSupplier;
public StructInfoNode(int index, Plan plan, List<Edge> edges) {
super(extractPlan(plan), index, edges);
relationSet = plan.collect(CatalogRelation.class::isInstance);
expressions = collectExpressions(plan);
eliminateSupplier = Suppliers.memoize(this::computeElimination);
}
public StructInfoNode(int index, Plan plan) {
this(index, plan, new ArrayList<>());
}
private boolean computeElimination() {
if (getJoinEdges().isEmpty()) {
return false;
}
return getJoinEdges().stream().allMatch(e -> {
if (e.getRightExtendedNodes() == getNodeMap()) {
return JoinUtils.canEliminateByLeft(e.getJoin(), FunctionalDependencies.EMPTY_FUNC_DEPS,
plan.getLogicalProperties().getFunctionalDependencies());
}
return false;
});
}
public boolean canEliminate() {
return eliminateSupplier.get();
}
private @Nullable List<Set<Expression>> collectExpressions(Plan plan) {
Pair<Boolean, Builder<Set<Expression>>> collector = Pair.of(true, ImmutableList.builder());

View File

@ -71,12 +71,19 @@ public class HyperGraphComparator {
private final Map<JoinEdge, Pair<JoinType, Set<Slot>>> inferredViewEdgeWithCond = new HashMap<>();
private List<JoinEdge> viewJoinEdgesAfterInferring;
private List<FilterEdge> viewFilterEdgesAfterInferring;
private final long eliminateViewNodesMap;
/**
* constructor
*/
public HyperGraphComparator(HyperGraph queryHyperGraph, HyperGraph viewHyperGraph,
LogicalCompatibilityContext logicalCompatibilityContext) {
this.queryHyperGraph = queryHyperGraph;
this.viewHyperGraph = viewHyperGraph;
this.logicalCompatibilityContext = logicalCompatibilityContext;
this.eliminateViewNodesMap = LongBitmap.newBitmapDiff(
viewHyperGraph.getNodesMap(),
LongBitmap.newBitmap(logicalCompatibilityContext.getQueryToViewNodeIDMapping().values()));
}
/**
@ -91,21 +98,26 @@ public class HyperGraphComparator {
}
private ComparisonResult isLogicCompatible() {
// 1 compare nodes
// 1 remove unused nodes
if (!tryEliminateNodesAndEdge()) {
return ComparisonResult.newInvalidResWithErrorMessage("Query and Mv has different nodes");
}
// 2 compare nodes
boolean nodeMatches = logicalCompatibilityContext.getQueryToViewNodeMapping().entrySet()
.stream().allMatch(e -> compareNodeWithExpr(e.getKey(), e.getValue()));
if (!nodeMatches) {
return ComparisonResult.newInvalidResWithErrorMessage("StructInfoNode are not compatible\n");
}
// 2 try to construct a map which can be mapped from edge to edge
// 3 try to construct a map which can be mapped from edge to edge
Map<Edge, Edge> queryToView = constructQueryToViewMapWithExpr();
if (!makeViewJoinCompatible(queryToView)) {
return ComparisonResult.newInvalidResWithErrorMessage("Join types are not compatible\n");
}
refreshViewEdges();
// 3. compare them by expression and nodes. Note compare edges after inferring for nodes
// 4 compare them by expression and nodes. Note compare edges after inferring for nodes
boolean matchNodes = queryToView.entrySet().stream()
.allMatch(e -> compareEdgeWithNode(e.getKey(), e.getValue()));
if (!matchNodes) {
@ -113,19 +125,32 @@ public class HyperGraphComparator {
}
queryToView.forEach(this::compareEdgeWithExpr);
// 1. process residual edges
// 5 process residual edges
Sets.difference(getQueryJoinEdgeSet(), queryToView.keySet())
.forEach(e -> pullUpQueryExprWithEdge.put(e, e.getExpressions()));
Sets.difference(getQueryFilterEdgeSet(), queryToView.keySet())
.forEach(e -> pullUpQueryExprWithEdge.put(e, e.getExpressions()));
Sets.difference(getViewJoinEdgeSet(), Sets.newHashSet(queryToView.values()))
.stream()
.filter(e -> !LongBitmap.isOverlap(e.getReferenceNodes(), eliminateViewNodesMap))
.forEach(e -> pullUpViewExprWithEdge.put(e, e.getExpressions()));
Sets.difference(getViewFilterEdgeSet(), Sets.newHashSet(queryToView.values()))
.stream()
.filter(e -> !LongBitmap.isOverlap(e.getReferenceNodes(), eliminateViewNodesMap))
.forEach(e -> pullUpViewExprWithEdge.put(e, e.getExpressions()));
return buildComparisonRes();
}
private boolean tryEliminateNodesAndEdge() {
for (int i : LongBitmap.getIterator(eliminateViewNodesMap)) {
if (!((StructInfoNode) viewHyperGraph.getNode(i)).canEliminate()) {
return false;
}
}
return true;
}
private boolean compareNodeWithExpr(StructInfoNode query, StructInfoNode view) {
List<Set<Expression>> queryExprSetList = query.getExprSetList();
List<Set<Expression>> viewExprSetList = view.getExprSetList();
@ -423,5 +448,4 @@ public class HyperGraphComparator {
pullUpQueryExprWithEdge.put(query, residualQueryExpr);
pullUpViewExprWithEdge.put(query, residualViewExpr);
}
}

View File

@ -19,13 +19,14 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.JoinUtils;
/**
* Eliminate outer join.
*/
public class EliminateJoinByUnique extends OneRewriteRuleFactory {
// TODO: support distinct -> LOJ
@Override
public Rule build() {
return logicalProject(
@ -35,11 +36,9 @@ public class EliminateJoinByUnique extends OneRewriteRuleFactory {
if (!join.left().getOutputSet().containsAll(project.getInputSlots())) {
return project;
}
if (join.getHashJoinConjuncts().stream().anyMatch(NullSafeEqual.class::isInstance)) {
// TODO: support null safe equals in fd and this
return project;
}
if (!project.getLogicalProperties().getFunctionalDependencies().isUnique(project.getOutputSet())) {
if (!JoinUtils.canEliminateByLeft(join,
join.left().getLogicalProperties().getFunctionalDependencies(),
join.right().getLogicalProperties().getFunctionalDependencies())) {
return project;
}
return project.withChildren(join.left());

View File

@ -385,7 +385,10 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
markJoinSlotReference, children);
}
private @Nullable Pair<Set<Slot>, Set<Slot>> extractHashKeys() {
/**
* extractNullRejectHashKeys
*/
public @Nullable Pair<Set<Slot>, Set<Slot>> extractNullRejectHashKeys() {
Set<Slot> leftKeys = new HashSet<>();
Set<Slot> rightKeys = new HashSet<>();
for (Expression expression : hashJoinConjuncts) {
@ -428,7 +431,7 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
return FunctionalDependencies.EMPTY_FUNC_DEPS;
}
Pair<Set<Slot>, Set<Slot>> keys = extractHashKeys();
Pair<Set<Slot>, Set<Slot>> keys = extractNullRejectHashKeys();
if (keys == null) {
return FunctionalDependencies.EMPTY_FUNC_DEPS;
}

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.properties.DistributionSpec;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.properties.FunctionalDependencies;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -33,6 +34,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapContain
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
@ -277,6 +279,21 @@ public class JoinUtils {
.collect(ImmutableList.toImmutableList());
}
/**
* can this join be eliminated by its left child
*/
public static boolean canEliminateByLeft(LogicalJoin<?, ?> join, FunctionalDependencies leftFuncDeps,
FunctionalDependencies rightFuncDeps) {
if (join.getJoinType().isLeftOuterJoin()) {
Pair<Set<Slot>, Set<Slot>> njHashKeys = join.extractNullRejectHashKeys();
if (!join.getOtherJoinConjuncts().isEmpty() || njHashKeys == null) {
return false;
}
return rightFuncDeps.isUnique(njHashKeys.second);
}
return false;
}
/**
* calculate the output slot of a join operator according join type and its children
*

View File

@ -0,0 +1,97 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.mv;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.exploration.mv.mapping.RelationMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.sqltest.SqlTestBase;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.util.PlanChecker;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
class EliminateJoinTest extends SqlTestBase {
@Test
void testLOJWithGroupBy() {
connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES");
CascadesContext c1 = createCascadesContext(
"select * from T1",
connectContext
);
Plan p1 = PlanChecker.from(c1)
.analyze()
.rewrite()
.getPlan().child(0);
CascadesContext c2 = createCascadesContext(
"select * from T1 left outer join (select id from T2 group by id) T2 "
+ "on T1.id = T2.id ",
connectContext
);
Plan p2 = PlanChecker.from(c2)
.analyze()
.rewrite()
.applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
.getAllPlan().get(0).child(0);
HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0);
HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0);
ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2));
Assertions.assertTrue(!res.isInvalid());
}
@Test
void testLOJWithUK() throws Exception {
connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES");
CascadesContext c1 = createCascadesContext(
"select * from T1",
connectContext
);
Plan p1 = PlanChecker.from(c1)
.analyze()
.rewrite()
.getPlan().child(0);
addConstraint("alter table T2 add constraint uk unique (id)");
CascadesContext c2 = createCascadesContext(
"select * from T1 left outer join T2 "
+ "on T1.id = T2.id ",
connectContext
);
Plan p2 = PlanChecker.from(c2)
.analyze()
.rewrite()
.applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
.getAllPlan().get(0).child(0);
HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0);
HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0);
ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2));
Assertions.assertTrue(!res.isInvalid());
}
LogicalCompatibilityContext constructContext(Plan p1, Plan p2) {
StructInfo st1 = MaterializedViewUtils.extractStructInfo(p1,
null).get(0);
StructInfo st2 = MaterializedViewUtils.extractStructInfo(p2,
null).get(0);
RelationMapping rm = RelationMapping.generate(st1.getRelations(), st2.getRelations()).get(0);
SlotMapping sm = SlotMapping.generate(rm);
return LogicalCompatibilityContext.from(rm, sm, st1, st2);
}
}