[feat](Nereids): eliminate inner join by pk fk when comparing mv (#30258)

This commit is contained in:
谢健
2024-01-25 14:01:34 +08:00
committed by yiguolei
parent 0a5c375068
commit 0f81ecf415
8 changed files with 441 additions and 282 deletions

View File

@ -303,7 +303,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
// this rule should be invoked after infer predicate and push down distinct, and before push down limit
topic("eliminate join according unique or foreign key",
custom(RuleType.ELIMINATE_JOIN_BY_FOREIGN_KEY, EliminateJoinByFK::new),
bottomUp(new EliminateJoinByFK()),
topDown(new EliminateJoinByUnique())
),

View File

@ -19,7 +19,6 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.node;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.properties.FunctionalDependencies;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
@ -29,11 +28,8 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
@ -51,36 +47,17 @@ import javax.annotation.Nullable;
public class StructInfoNode extends AbstractNode {
private final List<Set<Expression>> expressions;
private final Set<CatalogRelation> relationSet;
private final Supplier<Boolean> eliminateSupplier;
public StructInfoNode(int index, Plan plan, List<Edge> edges) {
super(extractPlan(plan), index, edges);
relationSet = plan.collect(CatalogRelation.class::isInstance);
expressions = collectExpressions(plan);
eliminateSupplier = Suppliers.memoize(this::computeElimination);
}
public StructInfoNode(int index, Plan plan) {
this(index, plan, new ArrayList<>());
}
private boolean computeElimination() {
if (getJoinEdges().isEmpty()) {
return false;
}
return getJoinEdges().stream().allMatch(e -> {
if (e.getRightExtendedNodes() == getNodeMap()) {
return JoinUtils.canEliminateByLeft(e.getJoin(), FunctionalDependencies.EMPTY_FUNC_DEPS,
plan.getLogicalProperties().getFunctionalDependencies());
}
return false;
});
}
public boolean canEliminate() {
return eliminateSupplier.get();
}
private @Nullable List<Set<Expression>> collectExpressions(Plan plan) {
Pair<Boolean, Builder<Set<Expression>>> collector = Pair.of(true, ImmutableList.builder());

View File

@ -29,7 +29,9 @@ import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
@ -143,10 +145,57 @@ public class HyperGraphComparator {
}
// Decides whether the view nodes recorded in eliminateViewNodesMap can be removed from the view hyper graph.
// Returns false when a filter edge or a join edge that touches those nodes blocks the elimination, and
// true only when the join-edge loop below finishes without returning.
private boolean tryEliminateNodesAndEdge() {
// NOTE(review): the next two lines open a loop and an if that are never closed inside this hunk; they look
// like the pre-change implementation carried over by the diff rendering -- confirm and drop them.
for (int i : LongBitmap.getIterator(eliminateViewNodesMap)) {
if (!((StructInfoNode) viewHyperGraph.getNode(i)).canEliminate()) {
// Only filter edges that reference exactly one node are considered here.
// NOTE(review): presumably isSubset(a, b) means "a is contained in b", i.e. the filter sits on an
// eliminated node -- confirm the argument order against LongBitmap.
boolean hasFilterEdgeAbove = viewHyperGraph.getFilterEdges().stream()
.filter(e -> LongBitmap.getCardinality(e.getReferenceNodes()) == 1)
.anyMatch(e -> LongBitmap.isSubset(e.getReferenceNodes(), eliminateViewNodesMap));
if (hasFilterEdgeAbove) {
// If there is some filter edge above the eliminated node, we should rebuild a plan
// Right now, just refuse it.
return false;
}
// NOTE(review): both branches below return on the first join edge that overlaps the eliminated nodes,
// so any later overlapping edge is never inspected -- confirm this is intended when several nodes
// are eliminated at once.
for (JoinEdge joinEdge : viewHyperGraph.getJoinEdges()) {
// Edges that do not touch any eliminated node are irrelevant.
if (!LongBitmap.isOverlap(joinEdge.getReferenceNodes(), eliminateViewNodesMap)) {
continue;
}
// eliminate by unique
if (joinEdge.getJoinType().isLeftOuterJoin()) {
// Exactly one eliminated node must be among the edge's right extended nodes.
long eliminatedRight =
LongBitmap.newBitmapIntersect(joinEdge.getRightExtendedNodes(), eliminateViewNodesMap);
if (LongBitmap.getCardinality(eliminatedRight) != 1) {
return false;
}
// NOTE(review): the plan is looked up through the lowest index of the edge's right extended nodes,
// not through eliminatedRight -- confirm both always denote the same node.
Plan rigthPlan = viewHyperGraph
.getNode(LongBitmap.lowestOneIndex(joinEdge.getRightExtendedNodes())).getPlan();
return JoinUtils.canEliminateByLeft(joinEdge.getJoin(),
rigthPlan.getLogicalProperties().getFunctionalDependencies());
}
// eliminate by pk fk
if (joinEdge.getJoinType().isInnerJoin()) {
// Non-simple edges are refused outright.
if (!joinEdge.isSimple()) {
return false;
}
long eliminatedLeft =
LongBitmap.newBitmapIntersect(joinEdge.getLeftExtendedNodes(), eliminateViewNodesMap);
long eliminatedRight =
LongBitmap.newBitmapIntersect(joinEdge.getRightExtendedNodes(), eliminateViewNodesMap);
// Exactly one side may hold exactly one eliminated node; that side is passed as the primary plan
// and the opposite side as the foreign plan.
if (LongBitmap.getCardinality(eliminatedLeft) == 0
&& LongBitmap.getCardinality(eliminatedRight) == 1) {
Plan foreign = viewHyperGraph
.getNode(LongBitmap.lowestOneIndex(joinEdge.getLeftExtendedNodes())).getPlan();
Plan primary = viewHyperGraph
.getNode(LongBitmap.lowestOneIndex(joinEdge.getRightExtendedNodes())).getPlan();
return JoinUtils.canEliminateByFk(joinEdge.getJoin(), primary, foreign);
} else if (LongBitmap.getCardinality(eliminatedLeft) == 1
&& LongBitmap.getCardinality(eliminatedRight) == 0) {
Plan foreign = viewHyperGraph
.getNode(LongBitmap.lowestOneIndex(joinEdge.getRightExtendedNodes())).getPlan();
Plan primary = viewHyperGraph
.getNode(LongBitmap.lowestOneIndex(joinEdge.getLeftExtendedNodes())).getPlan();
return JoinUtils.canEliminateByFk(joinEdge.getJoin(), primary, foreign);
}
// Any other distribution of eliminated nodes across the two sides is refused.
return false;
}
}
return true;
}

View File

@ -17,297 +17,106 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ImmutableEqualSet;
import org.apache.doris.nereids.util.JoinUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/**
* Eliminate join by foreign key.
*/
public class EliminateJoinByFK extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
public class EliminateJoinByFK extends OneRewriteRuleFactory {
// Right now we only support eliminate inner join, which should meet the following condition:
// 1. only contain null-reject equal condition, and which all meet fk-pk constraint
// 2. only output foreign table output or can be converted to foreign table output
// 3. if the foreign key is nullable, add an isNotNull predicate for the null-rejecting join condition
@Override
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
EliminateJoinByFKHelper helper = new EliminateJoinByFKHelper();
return helper.rewriteRoot(plan, jobContext);
}
private static class EliminateJoinByFKHelper
extends DefaultPlanRewriter<ForeignKeyContext> implements CustomRewriter {
@Override
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
return plan.accept(this, new ForeignKeyContext());
}
@Override
public Plan visit(Plan plan, ForeignKeyContext context) {
Plan newPlan = visitChildren(this, plan, context);
// always expire primary key except filter, project and join.
// always keep foreign key alive
context.expirePrimaryKey(plan);
return newPlan;
}
@Override
public Plan visitLogicalRelation(LogicalRelation relation, ForeignKeyContext context) {
if (!(relation instanceof LogicalCatalogRelation)) {
return relation;
}
context.putAllForeignKeys(((LogicalCatalogRelation) relation).getTable());
relation.getOutput().stream()
.filter(SlotReference.class::isInstance)
.map(SlotReference.class::cast)
.forEach(context::putSlot);
return relation;
}
private boolean canEliminate(LogicalJoin<?, ?> join, Map<Slot, Slot> primaryToForeign,
ForeignKeyContext context) {
if (!join.getOtherJoinConjuncts().isEmpty()) {
return false;
}
if (!join.getJoinType().isInnerJoin() && !join.getJoinType().isSemiJoin()) {
return false;
}
return context.satisfyConstraint(primaryToForeign, join);
}
private @Nullable Map<Expression, Expression> tryMapOutputToForeignPlan(Plan foreignPlan,
Set<Slot> output, Map<Slot, Slot> primaryToForeign) {
Set<Slot> residualPrimary = Sets.difference(output, foreignPlan.getOutputSet());
ImmutableMap.Builder<Expression, Expression> builder = new ImmutableMap.Builder<>();
for (Slot slot : residualPrimary) {
if (primaryToForeign.containsKey(slot)) {
builder.put(slot, primaryToForeign.get(slot));
} else {
return null;
}
}
return builder.build();
}
private Plan applyNullCompensationFilter(Plan child, Set<Slot> childSlots) {
Set<Expression> predicates = childSlots.stream()
.filter(ExpressionTrait::nullable)
.map(s -> new Not(new IsNull(s)))
.collect(ImmutableSet.toImmutableSet());
if (predicates.isEmpty()) {
return child;
}
return new LogicalFilter<>(predicates, child);
}
private Plan tryEliminatePrimaryPlan(LogicalProject<LogicalJoin<?, ?>> project,
Plan foreignPlan, Set<Slot> foreignKeys,
Map<Slot, Slot> primaryToForeign, ForeignKeyContext context) {
Set<Slot> output = project.getInputSlots();
Map<Expression, Expression> outputToForeign =
tryMapOutputToForeignPlan(foreignPlan, output, primaryToForeign);
if (outputToForeign != null && canEliminate(project.child(), primaryToForeign, context)) {
List<NamedExpression> newProjects = project.getProjects().stream()
.map(e -> outputToForeign.containsKey(e)
? new Alias(e.getExprId(), outputToForeign.get(e), e.toSql())
: (NamedExpression) e.rewriteUp(s -> outputToForeign.getOrDefault(s, s)))
.collect(ImmutableList.toImmutableList());
return project.withProjects(newProjects)
.withChildren(applyNullCompensationFilter(foreignPlan, foreignKeys));
}
return project;
}
private @Nullable Map<Slot, Slot> mapPrimaryToForeign(ImmutableEqualSet<Slot> equivalenceSet,
Set<Slot> foreignKeys) {
ImmutableMap.Builder<Slot, Slot> builder = new ImmutableMap.Builder<>();
for (Slot foreignSlot : foreignKeys) {
Set<Slot> primarySlots = equivalenceSet.calEqualSet(foreignSlot);
if (primarySlots.size() != 1) {
return null;
}
builder.put(primarySlots.iterator().next(), foreignSlot);
}
return builder.build();
}
// Right now we only support eliminate inner join, which should meet the following condition:
// 1. only contain null-reject equal condition, and which all meet fk-pk constraint
// 2. only output foreign table output or can be converted to foreign table output
// 4. if foreign key is null, add a isNotNull predicate for null-reject join condition
private Plan eliminateJoin(LogicalProject<LogicalJoin<?, ?>> project, ForeignKeyContext context) {
LogicalJoin<?, ?> join = project.child();
public Rule build() {
return logicalProject(
logicalJoin().when(join -> join.getJoinType().isInnerJoin())
).then(project -> {
LogicalJoin<Plan, Plan> join = project.child();
ImmutableEqualSet<Slot> equalSet = join.getEqualSlots();
Set<Slot> leftSlots = Sets.intersection(join.left().getOutputSet(), equalSet.getAllItemSet());
Set<Slot> rightSlots = Sets.intersection(join.right().getOutputSet(), equalSet.getAllItemSet());
if (context.isForeignKey(leftSlots) && context.isPrimaryKey(rightSlots)) {
Map<Slot, Slot> primaryToForeignSlot = mapPrimaryToForeign(equalSet, leftSlots);
if (primaryToForeignSlot != null) {
return tryEliminatePrimaryPlan(project, join.left(), leftSlots, primaryToForeignSlot, context);
}
} else if (context.isForeignKey(rightSlots) && context.isPrimaryKey(leftSlots)) {
Map<Slot, Slot> primaryToForeignSlot = mapPrimaryToForeign(equalSet, rightSlots);
if (primaryToForeignSlot != null) {
return tryEliminatePrimaryPlan(project, join.right(), rightSlots, primaryToForeignSlot, context);
}
Set<Slot> residualSlot = Sets.difference(project.getInputSlots(), equalSet.getAllItemSet());
Plan res = null;
if (join.left().getOutputSet().containsAll(residualSlot)) {
res = tryEliminatePrimary(project, equalSet, join.right(), join.left());
}
return project;
}
@Override
public Plan visitLogicalProject(LogicalProject<?> project, ForeignKeyContext context) {
project = visitChildren(this, project, context);
for (NamedExpression expression : project.getProjects()) {
if (expression instanceof Alias && expression.child(0) instanceof Slot) {
context.putAlias(expression.toSlot(), (Slot) expression.child(0));
}
if (res == null && join.right().getOutputSet().containsAll(residualSlot)) {
res = tryEliminatePrimary(project, equalSet, join.left(), join.right());
}
if (project.child() instanceof LogicalJoin<?, ?>) {
return eliminateJoin((LogicalProject<LogicalJoin<?, ?>>) project, context);
}
return project;
}
@Override
public Plan visitLogicalJoin(LogicalJoin<?, ?> join, ForeignKeyContext context) {
Plan plan = visitChildren(this, join, context);
context.addJoin(join);
return plan;
}
@Override
public Plan visitLogicalFilter(LogicalFilter<?> filter, ForeignKeyContext context) {
Plan plan = visitChildren(this, filter, context);
context.addFilter(filter);
return plan;
}
return res;
}).toRule(RuleType.ELIMINATE_JOIN_BY_UK);
}
private static class ForeignKeyContext {
Set<Map<Column, Column>> constraints = new HashSet<>();
Set<Column> foreignKeys = new HashSet<>();
Set<Column> primaryKeys = new HashSet<>();
Map<Slot, Column> slotToColumn = new HashMap<>();
Map<Slot, Set<LogicalJoin<?, ?>>> slotWithJoin = new HashMap<>();
Map<Slot, Set<Expression>> slotWithPredicates = new HashMap<>();
public void putAllForeignKeys(TableIf table) {
table.getForeignKeyConstraints().forEach(c -> {
Map<Column, Column> constraint = c.getForeignToPrimary(table);
constraints.add(c.getForeignToPrimary(table));
foreignKeys.addAll(constraint.keySet());
primaryKeys.addAll(constraint.values());
});
private @Nullable Plan tryEliminatePrimary(LogicalProject<LogicalJoin<Plan, Plan>> project,
ImmutableEqualSet<Slot> equalSet, Plan primary, Plan foreign) {
if (!JoinUtils.canEliminateByFk(project.child(), primary, foreign)) {
return null;
}
public boolean isForeignKey(Set<Slot> key) {
return foreignKeys.containsAll(
key.stream().map(s -> slotToColumn.get(s)).collect(Collectors.toSet()));
Set<Slot> output = project.getInputSlots();
Set<Slot> foreignKeys = Sets.intersection(foreign.getOutputSet(), equalSet.getAllItemSet());
Map<Expression, Expression> outputToForeign =
tryMapOutputToForeignPlan(foreign, output, equalSet);
if (outputToForeign != null) {
List<NamedExpression> newProjects = project.getProjects().stream()
.map(e -> outputToForeign.containsKey(e)
? new Alias(e.getExprId(), outputToForeign.get(e), e.toSql())
: (NamedExpression) e.rewriteUp(s -> outputToForeign.getOrDefault(s, s)))
.collect(ImmutableList.toImmutableList());
return project.withProjects(newProjects)
.withChildren(applyNullCompensationFilter(foreign, foreignKeys));
}
return project;
}
public boolean isPrimaryKey(Set<Slot> key) {
return primaryKeys.containsAll(
key.stream().map(s -> slotToColumn.get(s)).collect(Collectors.toSet()));
}
public void putSlot(SlotReference slot) {
if (!slot.getColumn().isPresent()) {
return;
private @Nullable Map<Expression, Expression> tryMapOutputToForeignPlan(Plan foreignPlan,
Set<Slot> output, ImmutableEqualSet<Slot> equalSet) {
Set<Slot> residualPrimary = Sets.difference(output, foreignPlan.getOutputSet());
ImmutableMap.Builder<Expression, Expression> builder = new ImmutableMap.Builder<>();
for (Slot primarySlot : residualPrimary) {
Optional<Slot> replacedForeign = equalSet.calEqualSet(primarySlot).stream()
.filter(foreignPlan.getOutputSet()::contains)
.findFirst();
if (!replacedForeign.isPresent()) {
return null;
}
Column c = slot.getColumn().get();
slotToColumn.put(slot, c);
builder.put(primarySlot, replacedForeign.get());
}
return builder.build();
}
public void putAlias(Slot newSlot, Slot originSlot) {
if (slotToColumn.containsKey(originSlot)) {
slotToColumn.put(newSlot, slotToColumn.get(originSlot));
}
}
public void addFilter(LogicalFilter<?> filter) {
filter.getOutput().stream()
.filter(slotToColumn::containsKey)
.forEach(slot -> {
slotWithPredicates.computeIfAbsent(slot, v -> new HashSet<>());
slotWithPredicates.get(slot).addAll(filter.getConjuncts());
});
}
public void addJoin(LogicalJoin<?, ?> join) {
join.getOutput().stream()
.filter(slotToColumn::containsKey)
.forEach(slot ->
slotWithJoin.computeIfAbsent(slot, v -> Sets.newHashSet((join))));
}
public void expirePrimaryKey(Plan plan) {
plan.getOutput().stream()
.filter(slotToColumn::containsKey)
.map(s -> slotToColumn.get(s))
.forEach(primaryKeys::remove);
}
public boolean satisfyConstraint(Map<Slot, Slot> primaryToForeign, LogicalJoin<?, ?> join) {
Map<Column, Column> foreignToPrimary = primaryToForeign.entrySet().stream()
.collect(ImmutableMap.toImmutableMap(
e -> slotToColumn.get(e.getValue()),
e -> slotToColumn.get(e.getKey())));
// The primary key can only contain join that may be eliminated
if (!primaryToForeign.keySet().stream().allMatch(p ->
slotWithJoin.get(p).size() == 1 && slotWithJoin.get(p).iterator().next() == join)) {
return false;
}
// The foreign key's filters must contain primary filters
if (!isPredicateCompatible(primaryToForeign)) {
return false;
}
return constraints.contains(foreignToPrimary);
}
// When predicates of foreign keys is a subset of that of primary keys
private boolean isPredicateCompatible(Map<Slot, Slot> primaryToForeign) {
return primaryToForeign.entrySet().stream().allMatch(pf -> {
// There is no predicate in primary key
if (!slotWithPredicates.containsKey(pf.getKey()) || slotWithPredicates.get(pf.getKey()).isEmpty()) {
return true;
}
// There are some predicates in primary key but there is no predicate in foreign key
if (slotWithPredicates.containsKey(pf.getValue()) && slotWithPredicates.get(pf.getValue()).isEmpty()) {
return false;
}
Set<Expression> primaryPredicates = slotWithPredicates.get(pf.getKey()).stream()
.map(e -> e.rewriteUp(
s -> s instanceof Slot ? primaryToForeign.getOrDefault(s, (Slot) s) : s))
.collect(Collectors.toSet());
return slotWithPredicates.get(pf.getValue()).containsAll(primaryPredicates);
});
private Plan applyNullCompensationFilter(Plan child, Set<Slot> childSlots) {
Set<Expression> predicates = childSlots.stream()
.filter(ExpressionTrait::nullable)
.map(s -> new Not(new IsNull(s)))
.collect(ImmutableSet.toImmutableSet());
if (predicates.isEmpty()) {
return child;
}
return new LogicalFilter<>(predicates, child);
}
}

View File

@ -37,7 +37,6 @@ public class EliminateJoinByUnique extends OneRewriteRuleFactory {
return project;
}
if (!JoinUtils.canEliminateByLeft(join,
join.left().getLogicalProperties().getFunctionalDependencies(),
join.right().getLogicalProperties().getFunctionalDependencies())) {
return project;
}

View File

@ -0,0 +1,184 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Records the foreign key / primary key information collected from one or more plans.
 *
 * <p>Call {@link #collectForeignKeyConstraint(Plan)} for every plan of interest, then query the
 * accumulated state with {@link #isForeignKey(Set)}, {@link #isPrimaryKey(Set)} and
 * {@link #satisfyConstraint(Map)}.
 */
public class ForeignKeyContext {
    // Every collected constraint, stored as a "foreign column -> primary column" map.
    Set<Map<Column, Column>> constraints = new HashSet<>();
    // Columns found on the foreign side of any collected constraint.
    Set<Column> foreignKeys = new HashSet<>();
    // Columns found on the primary side of any collected constraint; entries can be removed again
    // by expirePrimaryKey.
    Set<Column> primaryKeys = new HashSet<>();
    // Tracked slot -> the column it was registered with (directly or through an alias of a slot).
    Map<Slot, Column> slotToColumn = new HashMap<>();
    // Tracked slot -> conjuncts of every visited filter that outputs this slot.
    Map<Slot, Set<Expression>> slotWithPredicates = new HashMap<>();

    /**
     * Collect Foreign Key Constraint From this Plan
     *
     * @param plan root of the plan tree to scan
     * @return this context, so that several collections can be chained
     */
    public ForeignKeyContext collectForeignKeyConstraint(Plan plan) {
        plan.accept(new DefaultPlanVisitor<Void, ForeignKeyContext>() {
            @Override
            public Void visit(Plan plan, ForeignKeyContext context) {
                super.visit(plan, context);
                // Every node kind without a dedicated override below ends up here and expires the
                // primary keys it outputs; foreign keys are never expired.
                // NOTE(review): the original comment also exempted joins, but there is no
                // visitLogicalJoin override in this visitor -- confirm that joins inside the
                // collected plan are meant to expire primary keys.
                context.expirePrimaryKey(plan);
                return null;
            }

            @Override
            public Void visitLogicalRelation(LogicalRelation relation, ForeignKeyContext context) {
                if (relation instanceof LogicalCatalogRelation) {
                    context.putAllForeignKeys(((LogicalCatalogRelation) relation).getTable());
                    relation.getOutput().stream()
                            .filter(SlotReference.class::isInstance)
                            .map(SlotReference.class::cast)
                            .forEach(context::putSlot);
                }
                return null;
            }

            @Override
            public Void visitLogicalProject(LogicalProject<?> project, ForeignKeyContext context) {
                super.visit(project, context);
                // An alias directly over a slot keeps the column identity of that slot.
                for (NamedExpression expression : project.getProjects()) {
                    if (expression instanceof Alias && expression.child(0) instanceof Slot) {
                        context.putAlias(expression.toSlot(), (Slot) expression.child(0));
                    }
                }
                return null;
            }

            @Override
            public Void visitLogicalFilter(LogicalFilter<?> filter, ForeignKeyContext context) {
                super.visit(filter, context);
                context.addFilter(filter);
                return null;
            }
        }, this);
        return this;
    }

    /** Registers every foreign key constraint declared on {@code table}. */
    void putAllForeignKeys(TableIf table) {
        table.getForeignKeyConstraints().forEach(c -> {
            Map<Column, Column> constraint = c.getForeignToPrimary(table);
            constraints.add(constraint);
            foreignKeys.addAll(constraint.keySet());
            primaryKeys.addAll(constraint.values());
        });
    }

    /**
     * Check whether every slot in {@code key} maps to a column recorded as a foreign key.
     * A slot without a tracked column makes the answer false; an empty {@code key} yields true.
     */
    public boolean isForeignKey(Set<Slot> key) {
        return foreignKeys.containsAll(toColumns(key));
    }

    /**
     * Check whether every slot in {@code key} maps to a column still recorded as a primary key.
     * A slot without a tracked column makes the answer false; an empty {@code key} yields true.
     */
    public boolean isPrimaryKey(Set<Slot> key) {
        return primaryKeys.containsAll(toColumns(key));
    }

    // Untracked slots map to null, which is never contained in the key sets.
    private Set<Column> toColumns(Set<Slot> slots) {
        return slots.stream().map(slotToColumn::get).collect(Collectors.toSet());
    }

    /** Tracks {@code slot} when it carries a column; slots without a column are ignored. */
    void putSlot(SlotReference slot) {
        if (!slot.getColumn().isPresent()) {
            return;
        }
        slotToColumn.put(slot, slot.getColumn().get());
    }

    /** Lets {@code newSlot} share the column of {@code originSlot}, if the latter is tracked. */
    void putAlias(Slot newSlot, Slot originSlot) {
        if (slotToColumn.containsKey(originSlot)) {
            slotToColumn.put(newSlot, slotToColumn.get(originSlot));
        }
    }

    // Attaches all conjuncts of the filter to every tracked slot the filter outputs.
    private void addFilter(LogicalFilter<?> filter) {
        filter.getOutput().stream()
                .filter(slotToColumn::containsKey)
                .forEach(slot -> slotWithPredicates
                        .computeIfAbsent(slot, v -> new HashSet<>())
                        .addAll(filter.getConjuncts()));
    }

    // Removes the columns behind the plan's tracked output slots from the primary key set.
    private void expirePrimaryKey(Plan plan) {
        plan.getOutput().stream()
                .filter(slotToColumn::containsKey)
                .map(slotToColumn::get)
                .forEach(primaryKeys::remove);
    }

    /**
     * Check whether the given mapping relation satisfies any constraints
     *
     * @param primaryToForeign primary slot -> foreign slot pairs taken from the join condition
     * @return true when the mapping is non-empty, every slot is tracked, the predicates recorded
     *         on the primary slots are also present on the foreign slots, and the resulting
     *         "foreign column -> primary column" map equals one collected constraint
     */
    public boolean satisfyConstraint(Map<Slot, Slot> primaryToForeign) {
        if (primaryToForeign.isEmpty()) {
            return false;
        }
        // A slot without a tracked column cannot take part in any constraint; answering false here
        // also keeps the immutable map below from being fed a null key or value.
        boolean allTracked = primaryToForeign.entrySet().stream()
                .allMatch(e -> slotToColumn.containsKey(e.getKey()) && slotToColumn.containsKey(e.getValue()));
        if (!allTracked) {
            return false;
        }
        // The foreign key's filters must contain primary filters
        if (!isPredicateCompatible(primaryToForeign)) {
            return false;
        }
        Map<Column, Column> foreignToPrimary = primaryToForeign.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(
                        e -> slotToColumn.get(e.getValue()),
                        e -> slotToColumn.get(e.getKey())));
        return constraints.contains(foreignToPrimary);
    }

    // True when, for every pair, the predicates recorded on the primary slot (rewritten onto the
    // foreign slots) are a subset of the predicates recorded on the foreign slot.
    private boolean isPredicateCompatible(Map<Slot, Slot> primaryToForeign) {
        return primaryToForeign.entrySet().stream().allMatch(pf -> {
            // There is no predicate in primary key
            Set<Expression> primaryPredicates = slotWithPredicates.get(pf.getKey());
            if (primaryPredicates == null || primaryPredicates.isEmpty()) {
                return true;
            }
            // There are some predicates in primary key but there is no predicate in foreign key.
            // A foreign slot that was never filtered has no entry at all, so "missing" must be
            // treated like "empty" instead of being dereferenced.
            Set<Expression> foreignPredicates = slotWithPredicates.get(pf.getValue());
            if (foreignPredicates == null || foreignPredicates.isEmpty()) {
                return false;
            }
            Set<Expression> rewrittenPrimary = primaryPredicates.stream()
                    .map(e -> e.rewriteUp(
                            s -> s instanceof Slot ? primaryToForeign.getOrDefault(s, (Slot) s) : s))
                    .collect(Collectors.toSet());
            return foreignPredicates.containsAll(rewrittenPrimary);
        });
    }
}

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.properties.FunctionalDependencies;
import org.apache.doris.nereids.rules.rewrite.ForeignKeyContext;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -43,7 +44,9 @@ import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.HashSet;
import java.util.List;
@ -279,11 +282,46 @@ public class JoinUtils {
.collect(ImmutableList.toImmutableList());
}
/**
 * Pairs every foreign slot with the single slot {@code equivalenceSet} reports as equal to it.
 *
 * @return a "primary slot -> foreign slot" map, or an empty map as soon as one foreign slot does
 *         not have exactly one equal slot
 */
private static Map<Slot, Slot> mapPrimaryToForeign(ImmutableEqualSet<Slot> equivalenceSet,
        Set<Slot> foreignKeys) {
    ImmutableMap.Builder<Slot, Slot> primaryToForeign = ImmutableMap.builder();
    for (Slot foreign : foreignKeys) {
        Set<Slot> equalSlots = equivalenceSet.calEqualSet(foreign);
        if (equalSlots.size() == 1) {
            primaryToForeign.put(equalSlots.iterator().next(), foreign);
            continue;
        }
        // Zero or several candidates: the pairing is ambiguous, give up entirely.
        return ImmutableMap.of();
    }
    return primaryToForeign.build();
}
/**
 * Check whether the given join can be eliminated by pk-fk
 *
 * @param join the join to inspect; only inner, non-mark joins without other conjuncts qualify
 * @param primaryPlan the child expected to provide the primary key side
 * @param foreignPlan the child expected to provide the foreign key side
 * @return true when the equal slots of the join line up with a collected foreign key constraint
 */
public static boolean canEliminateByFk(LogicalJoin<?, ?> join, Plan primaryPlan, Plan foreignPlan) {
    boolean plainInnerEqualJoin = join.getJoinType().isInnerJoin()
            && join.getOtherJoinConjuncts().isEmpty()
            && !join.isMarkJoin();
    if (!plainInnerEqualJoin) {
        return false;
    }

    // Gather constraints, slot/column links and filters from both children, primary side first.
    ForeignKeyContext fkContext = new ForeignKeyContext()
            .collectForeignKeyConstraint(primaryPlan)
            .collectForeignKeyConstraint(foreignPlan);

    ImmutableEqualSet<Slot> equalSlots = join.getEqualSlots();
    Set<Slot> primarySide = Sets.intersection(equalSlots.getAllItemSet(), primaryPlan.getOutputSet());
    Set<Slot> foreignSide = Sets.intersection(equalSlots.getAllItemSet(), foreignPlan.getOutputSet());
    if (fkContext.isForeignKey(foreignSide) && fkContext.isPrimaryKey(primarySide)) {
        return fkContext.satisfyConstraint(mapPrimaryToForeign(equalSlots, foreignSide));
    }
    return false;
}
/**
* can this join be eliminated by its left child
*/
public static boolean canEliminateByLeft(LogicalJoin<?, ?> join, FunctionalDependencies leftFuncDeps,
FunctionalDependencies rightFuncDeps) {
public static boolean canEliminateByLeft(LogicalJoin<?, ?> join, FunctionalDependencies rightFuncDeps) {
if (join.getJoinType().isLeftOuterJoin()) {
Pair<Set<Slot>, Set<Slot>> njHashKeys = join.extractNullRejectHashKeys();
if (!join.getOtherJoinConjuncts().isEmpty() || njHashKeys == null) {

View File

@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.util.PlanChecker;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
class EliminateJoinTest extends SqlTestBase {
@ -55,6 +56,7 @@ class EliminateJoinTest extends SqlTestBase {
HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0);
ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2));
Assertions.assertTrue(!res.isInvalid());
Assertions.assertTrue(res.getViewExpressions().isEmpty());
}
@Test
@ -83,6 +85,107 @@ class EliminateJoinTest extends SqlTestBase {
HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0);
ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2));
Assertions.assertTrue(!res.isInvalid());
Assertions.assertTrue(res.getViewExpressions().isEmpty());
dropConstraint("alter table T2 drop constraint uk");
}
/**
 * A view that inner-joins T1 with T2 on a pk-fk pair must be logically compatible with a plain
 * scan of T1. (Despite the "LOJ" in the name, the view query below uses an inner join.)
 */
@Test
void testLOJWithPKFK() throws Exception {
    connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES");
    CascadesContext c1 = createCascadesContext(
            "select * from T1",
            connectContext
    );
    Plan p1 = PlanChecker.from(c1)
            .analyze()
            .rewrite()
            .getPlan().child(0);
    addConstraint("alter table T2 add constraint pk primary key (id)");
    addConstraint("alter table T1 add constraint fk foreign key (id) references T2(id)");
    // Clean up in finally so a failing assertion cannot leak the constraints into other tests.
    try {
        CascadesContext c2 = createCascadesContext(
                "select * from T1 inner join T2 "
                        + "on T1.id = T2.id ",
                connectContext
        );
        Plan p2 = PlanChecker.from(c2)
                .analyze()
                .rewrite()
                .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
                .getAllPlan().get(0).child(0);
        HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0);
        HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0);
        ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2));
        Assertions.assertFalse(res.isInvalid());
        Assertions.assertTrue(res.getViewExpressions().isEmpty());
    } finally {
        // NOTE(review): only the primary key is dropped, as in the original test; presumably that
        // also removes the foreign key referencing it -- confirm.
        dropConstraint("alter table T2 drop constraint pk");
    }
}
/**
 * View: (T1 left outer join T3 on a unique key) inner join T2 on a pk-fk pair; it is expected to
 * be logically compatible with a plain scan of T1. Currently disabled.
 */
@Disabled
@Test
void testLOJWithPKFKAndUK1() throws Exception {
    connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES");
    CascadesContext c1 = createCascadesContext(
            "select * from T1",
            connectContext
    );
    Plan p1 = PlanChecker.from(c1)
            .analyze()
            .rewrite()
            .getPlan().child(0);
    addConstraint("alter table T2 add constraint pk primary key (id)");
    addConstraint("alter table T1 add constraint fk foreign key (id) references T2(id)");
    addConstraint("alter table T3 add constraint uk unique (id)");
    // Clean up in finally so a failing assertion cannot leak the constraints into other tests.
    try {
        CascadesContext c2 = createCascadesContext(
                "select * from (select T1.*, T3.id as id3 from T1 left outer join T3 on T1.id = T3.id) T1 inner join T2 "
                        + "on T1.id = T2.id ",
                connectContext
        );
        Plan p2 = PlanChecker.from(c2)
                .analyze()
                .rewrite()
                .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
                .getAllPlan().get(0).child(0);
        HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0);
        HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0);
        ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2));
        Assertions.assertFalse(res.isInvalid());
        Assertions.assertTrue(res.getViewExpressions().isEmpty());
    } finally {
        dropConstraint("alter table T2 drop constraint pk");
        dropConstraint("alter table T3 drop constraint uk");
    }
}
/**
 * View: (T1 inner join T2 on a pk-fk pair) left outer join T3 on a unique key; it is expected to
 * be logically compatible with a plain scan of T1. Currently disabled.
 */
@Disabled
@Test
void testLOJWithPKFKAndUK2() throws Exception {
    connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES");
    CascadesContext c1 = createCascadesContext(
            "select * from T1",
            connectContext
    );
    Plan p1 = PlanChecker.from(c1)
            .analyze()
            .rewrite()
            .getPlan().child(0);
    addConstraint("alter table T2 add constraint pk primary key (id)");
    addConstraint("alter table T1 add constraint fk foreign key (id) references T2(id)");
    addConstraint("alter table T3 add constraint uk unique (id)");
    // Clean up in finally so a failing assertion cannot leak the constraints into other tests.
    try {
        CascadesContext c2 = createCascadesContext(
                "select * from (select T1.*, T2.id as id2 from T1 inner join T2 on T1.id = T2.id) T1 left outer join T3 "
                        + "on T1.id = T3.id ",
                connectContext
        );
        Plan p2 = PlanChecker.from(c2)
                .analyze()
                .rewrite()
                .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
                .getAllPlan().get(0).child(0);
        HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0);
        HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0);
        ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2));
        Assertions.assertFalse(res.isInvalid());
        Assertions.assertTrue(res.getViewExpressions().isEmpty());
    } finally {
        dropConstraint("alter table T2 drop constraint pk");
        dropConstraint("alter table T3 drop constraint uk");
    }
}
LogicalCompatibilityContext constructContext(Plan p1, Plan p2) {