branch-2.1: [fix](Nereids) not generate duplicate exprid after convert outer to anti rule #52798 (#53901)
cherry picked from #52798
This commit is contained in:
@ -47,6 +47,7 @@ import org.apache.doris.nereids.rules.rewrite.CollectCteConsumerOutput;
|
||||
import org.apache.doris.nereids.rules.rewrite.CollectFilterAboveConsumer;
|
||||
import org.apache.doris.nereids.rules.rewrite.ColumnPruning;
|
||||
import org.apache.doris.nereids.rules.rewrite.ConvertInnerOrCrossJoin;
|
||||
import org.apache.doris.nereids.rules.rewrite.ConvertOuterJoinToAntiJoin;
|
||||
import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite;
|
||||
import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite;
|
||||
import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow;
|
||||
@ -443,6 +444,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
|
||||
ImmutableSet.of(LogicalCTEAnchor.class),
|
||||
() -> jobs(
|
||||
// after variant sub path pruning, we need do column pruning again
|
||||
bottomUp(RuleSet.PUSH_DOWN_FILTERS),
|
||||
custom(RuleType.COLUMN_PRUNING, ColumnPruning::new),
|
||||
bottomUp(ImmutableList.of(
|
||||
new PushDownFilterThroughProject(),
|
||||
@ -531,6 +533,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
|
||||
topic("rewrite cte sub-tree before sub path push down",
|
||||
custom(RuleType.REWRITE_CTE_CHILDREN, () -> new RewriteCteChildren(beforePushDownJobs))
|
||||
)));
|
||||
rewriteJobs.addAll(jobs(topic("convert outer join to anti",
|
||||
custom(RuleType.CONVERT_OUTER_JOIN_TO_ANTI, ConvertOuterJoinToAntiJoin::new))));
|
||||
if (needOrExpansion) {
|
||||
rewriteJobs.addAll(jobs(topic("or expansion",
|
||||
custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE))));
|
||||
|
||||
@ -96,7 +96,6 @@ import org.apache.doris.nereids.rules.implementation.LogicalTVFRelationToPhysica
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalUnionToPhysicalUnion;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalWindowToPhysicalWindow;
|
||||
import org.apache.doris.nereids.rules.rewrite.ConvertOuterJoinToAntiJoin;
|
||||
import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow;
|
||||
import org.apache.doris.nereids.rules.rewrite.EliminateFilter;
|
||||
import org.apache.doris.nereids.rules.rewrite.EliminateOuterJoin;
|
||||
@ -162,7 +161,6 @@ public class RuleSet {
|
||||
new PushDownFilterThroughGenerate(),
|
||||
new PushDownProjectThroughLimit(),
|
||||
new EliminateOuterJoin(),
|
||||
new ConvertOuterJoinToAntiJoin(),
|
||||
new MergeProjects(),
|
||||
new MergeFilters(),
|
||||
new MergeGenerates(),
|
||||
|
||||
@ -213,6 +213,12 @@ public enum RuleType {
|
||||
REWRITE_SORT_EXPRESSION(RuleTypeClass.REWRITE),
|
||||
REWRITE_HAVING_EXPRESSION(RuleTypeClass.REWRITE),
|
||||
REWRITE_REPEAT_EXPRESSION(RuleTypeClass.REWRITE),
|
||||
REWRITE_OLAP_TABLE_SINK_EXPRESSION(RuleTypeClass.REWRITE),
|
||||
REWRITE_SINK_EXPRESSION(RuleTypeClass.REWRITE),
|
||||
REWRITE_WINDOW_EXPRESSION(RuleTypeClass.REWRITE),
|
||||
REWRITE_SET_OPERATION_EXPRESSION(RuleTypeClass.REWRITE),
|
||||
REWRITE_PARTITION_TOPN_EXPRESSION(RuleTypeClass.REWRITE),
|
||||
REWRITE_TOPN_EXPRESSION(RuleTypeClass.REWRITE),
|
||||
EXTRACT_FILTER_FROM_JOIN(RuleTypeClass.REWRITE),
|
||||
REORDER_JOIN(RuleTypeClass.REWRITE),
|
||||
MERGE_PERCENTILE_TO_ARRAY(RuleTypeClass.REWRITE),
|
||||
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.expression;
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.pattern.ExpressionPatternRules;
|
||||
import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners;
|
||||
import org.apache.doris.nereids.pattern.MatchingContext;
|
||||
import org.apache.doris.nereids.properties.OrderKey;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
@ -28,25 +29,39 @@ import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.OrderExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Function;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalSink;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.ImmutableMultimap;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
@ -79,7 +94,19 @@ public class ExpressionRewrite implements RewriteRuleFactory {
|
||||
new JoinExpressionRewrite().build(),
|
||||
new SortExpressionRewrite().build(),
|
||||
new LogicalRepeatRewrite().build(),
|
||||
new HavingExpressionRewrite().build());
|
||||
new HavingExpressionRewrite().build(),
|
||||
new LogicalPartitionTopNExpressionRewrite().build(),
|
||||
new LogicalTopNExpressionRewrite().build(),
|
||||
new LogicalSetOperationRewrite().build(),
|
||||
new LogicalWindowRewrite().build(),
|
||||
new LogicalCteConsumerRewrite().build(),
|
||||
new LogicalResultSinkRewrite().build(),
|
||||
new LogicalFileSinkRewrite().build(),
|
||||
new LogicalHiveTableSinkRewrite().build(),
|
||||
new LogicalIcebergTableSinkRewrite().build(),
|
||||
new LogicalJdbcTableSinkRewrite().build(),
|
||||
new LogicalOlapTableSinkRewrite().build(),
|
||||
new LogicalDeferMaterializeResultSinkRewrite().build());
|
||||
}
|
||||
|
||||
private class GenerateExpressionRewrite extends OneRewriteRuleFactory {
|
||||
@ -264,7 +291,166 @@ public class ExpressionRewrite implements RewriteRuleFactory {
|
||||
}
|
||||
}
|
||||
|
||||
private class LogicalRepeatRewrite extends OneRewriteRuleFactory {
|
||||
private class LogicalWindowRewrite extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalWindow().thenApply(ctx -> {
|
||||
LogicalWindow<Plan> window = ctx.root;
|
||||
List<NamedExpression> windowExpressions = window.getWindowExpressions();
|
||||
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
|
||||
List<NamedExpression> result = rewriteAll(windowExpressions, rewriter, context);
|
||||
return window.withExpressionsAndChild(result, window.child());
|
||||
})
|
||||
.toRule(RuleType.REWRITE_WINDOW_EXPRESSION);
|
||||
}
|
||||
}
|
||||
|
||||
private class LogicalSetOperationRewrite extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalSetOperation().thenApply(ctx -> {
|
||||
LogicalSetOperation setOperation = ctx.root;
|
||||
List<List<SlotReference>> slotsList = setOperation.getRegularChildrenOutputs();
|
||||
List<List<SlotReference>> newSlotsList = new ArrayList<>();
|
||||
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
|
||||
for (List<SlotReference> slots : slotsList) {
|
||||
List<SlotReference> result = rewriteAll(slots, rewriter, context);
|
||||
newSlotsList.add(result);
|
||||
}
|
||||
return setOperation.withChildrenAndTheirOutputs(setOperation.children(), newSlotsList);
|
||||
})
|
||||
.toRule(RuleType.REWRITE_SET_OPERATION_EXPRESSION);
|
||||
}
|
||||
}
|
||||
|
||||
private class LogicalTopNExpressionRewrite extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalTopN().thenApply(ctx -> {
|
||||
LogicalTopN<Plan> topN = ctx.root;
|
||||
List<OrderKey> orderKeys = topN.getOrderKeys();
|
||||
ImmutableList.Builder<OrderKey> rewrittenOrderKeys
|
||||
= ImmutableList.builderWithExpectedSize(orderKeys.size());
|
||||
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
|
||||
boolean changed = false;
|
||||
for (OrderKey k : orderKeys) {
|
||||
Expression expression = rewriter.rewrite(k.getExpr(), context);
|
||||
changed |= expression != k.getExpr();
|
||||
rewrittenOrderKeys.add(new OrderKey(expression, k.isAsc(), k.isNullFirst()));
|
||||
}
|
||||
return changed ? topN.withOrderKeys(rewrittenOrderKeys.build()) : topN;
|
||||
}).toRule(RuleType.REWRITE_TOPN_EXPRESSION);
|
||||
}
|
||||
}
|
||||
|
||||
private class LogicalPartitionTopNExpressionRewrite extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalPartitionTopN().thenApply(ctx -> {
|
||||
LogicalPartitionTopN<Plan> partitionTopN = ctx.root;
|
||||
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
|
||||
List<OrderExpression> newOrderExpressions = new ArrayList<>();
|
||||
for (OrderExpression orderExpression : partitionTopN.getOrderKeys()) {
|
||||
OrderKey orderKey = orderExpression.getOrderKey();
|
||||
Expression expr = rewriter.rewrite(orderKey.getExpr(), context);
|
||||
OrderKey newOrderKey = new OrderKey(expr, orderKey.isAsc(), orderKey.isNullFirst());
|
||||
newOrderExpressions.add(new OrderExpression(newOrderKey));
|
||||
}
|
||||
List<Expression> result = rewriteAll(partitionTopN.getPartitionKeys(), rewriter, context);
|
||||
return partitionTopN.withPartitionKeysAndOrderKeys(result, newOrderExpressions);
|
||||
}).toRule(RuleType.REWRITE_PARTITION_TOPN_EXPRESSION);
|
||||
}
|
||||
}
|
||||
|
||||
private class LogicalCteConsumerRewrite extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalCTEConsumer().thenApply(ctx -> {
|
||||
LogicalCTEConsumer consumer = ctx.root;
|
||||
boolean changed = false;
|
||||
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
|
||||
ImmutableMap.Builder<Slot, Slot> cToPBuilder = ImmutableMap.builder();
|
||||
ImmutableMultimap.Builder<Slot, Slot> pToCBuilder = ImmutableMultimap.builder();
|
||||
for (Map.Entry<Slot, Slot> entry : consumer.getConsumerToProducerOutputMap().entrySet()) {
|
||||
Slot key = (Slot) rewriter.rewrite(entry.getKey(), context);
|
||||
Slot value = (Slot) rewriter.rewrite(entry.getValue(), context);
|
||||
cToPBuilder.put(key, value);
|
||||
pToCBuilder.put(value, key);
|
||||
if (!key.equals(entry.getKey()) || !value.equals(entry.getValue())) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
return changed ? consumer.withTwoMaps(cToPBuilder.build(), pToCBuilder.build()) : consumer;
|
||||
}).toRule(RuleType.REWRITE_TOPN_EXPRESSION);
|
||||
}
|
||||
}
|
||||
|
||||
private class LogicalResultSinkRewrite extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalResultSink().thenApply(ExpressionRewrite.this::applyRewriteToSink)
|
||||
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
|
||||
}
|
||||
}
|
||||
|
||||
private class LogicalFileSinkRewrite extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalFileSink().thenApply(ExpressionRewrite.this::applyRewriteToSink)
|
||||
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
|
||||
}
|
||||
}
|
||||
|
||||
private class LogicalHiveTableSinkRewrite extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalHiveTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink)
|
||||
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
|
||||
}
|
||||
}
|
||||
|
||||
private class LogicalIcebergTableSinkRewrite extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalIcebergTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink)
|
||||
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
|
||||
}
|
||||
}
|
||||
|
||||
private class LogicalJdbcTableSinkRewrite extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalJdbcTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink)
|
||||
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
|
||||
}
|
||||
}
|
||||
|
||||
private class LogicalOlapTableSinkRewrite extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalOlapTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink)
|
||||
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
|
||||
}
|
||||
}
|
||||
|
||||
private class LogicalDeferMaterializeResultSinkRewrite extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalDeferMaterializeResultSink().thenApply(ExpressionRewrite.this::applyRewriteToSink)
|
||||
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
|
||||
}
|
||||
}
|
||||
|
||||
private LogicalSink<Plan> applyRewriteToSink(MatchingContext<? extends LogicalSink<Plan>> ctx) {
|
||||
LogicalSink<Plan> sink = ctx.root;
|
||||
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
|
||||
List<NamedExpression> outputExprs = sink.getOutputExprs();
|
||||
List<NamedExpression> result = rewriteAll(outputExprs, rewriter, context);
|
||||
return sink.withOutputExprs(result);
|
||||
}
|
||||
|
||||
/** LogicalRepeatRewrite */
|
||||
public class LogicalRepeatRewrite extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalRepeat().thenApply(ctx -> {
|
||||
|
||||
@ -17,9 +17,9 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.jobs.JobContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
|
||||
@ -28,9 +28,14 @@ import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
|
||||
import org.apache.doris.nereids.util.TypeUtils;
|
||||
|
||||
import java.util.List;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@ -42,18 +47,41 @@ import java.util.stream.Collectors;
|
||||
* project(A.*)
|
||||
* - LeftAntiJoin(A, B)
|
||||
*/
|
||||
public class ConvertOuterJoinToAntiJoin extends OneRewriteRuleFactory {
|
||||
public class ConvertOuterJoinToAntiJoin extends DefaultPlanRewriter<Map<ExprId, ExprId>> implements CustomRewriter {
|
||||
private ExprIdRewriter exprIdReplacer;
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalFilter(logicalJoin()
|
||||
.when(join -> join.getJoinType().isOuterJoin()))
|
||||
.then(this::toAntiJoin)
|
||||
.toRule(RuleType.CONVERT_OUTER_JOIN_TO_ANTI);
|
||||
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
|
||||
if (!plan.containsType(LogicalJoin.class)) {
|
||||
return plan;
|
||||
}
|
||||
Map<ExprId, ExprId> replaceMap = new HashMap<>();
|
||||
ExprIdRewriter.ReplaceRule replaceRule = new ExprIdRewriter.ReplaceRule(replaceMap);
|
||||
exprIdReplacer = new ExprIdRewriter(replaceRule, jobContext);
|
||||
return plan.accept(this, replaceMap);
|
||||
}
|
||||
|
||||
private Plan toAntiJoin(LogicalFilter<LogicalJoin<Plan, Plan>> filter) {
|
||||
@Override
|
||||
public Plan visit(Plan plan, Map<ExprId, ExprId> replaceMap) {
|
||||
plan = visitChildren(this, plan, replaceMap);
|
||||
plan = exprIdReplacer.rewriteExpr(plan, replaceMap);
|
||||
return plan;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, Map<ExprId, ExprId> replaceMap) {
|
||||
filter = (LogicalFilter<? extends Plan>) visit(filter, replaceMap);
|
||||
if (!(filter.child() instanceof LogicalJoin)) {
|
||||
return filter;
|
||||
}
|
||||
return toAntiJoin((LogicalFilter<LogicalJoin<Plan, Plan>>) filter, replaceMap);
|
||||
}
|
||||
|
||||
private Plan toAntiJoin(LogicalFilter<LogicalJoin<Plan, Plan>> filter, Map<ExprId, ExprId> replaceMap) {
|
||||
LogicalJoin<Plan, Plan> join = filter.child();
|
||||
if (!join.getJoinType().isLeftOuterJoin() && !join.getJoinType().isRightOuterJoin()) {
|
||||
return filter;
|
||||
}
|
||||
|
||||
Set<Slot> alwaysNullSlots = filter.getConjuncts().stream()
|
||||
.filter(p -> TypeUtils.isNull(p).isPresent())
|
||||
@ -66,33 +94,37 @@ public class ConvertOuterJoinToAntiJoin extends OneRewriteRuleFactory {
|
||||
.filter(s -> alwaysNullSlots.contains(s) && !s.nullable())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
Plan newJoin = null;
|
||||
Plan newChild = null;
|
||||
if (join.getJoinType().isLeftOuterJoin() && !rightAlwaysNullSlots.isEmpty()) {
|
||||
newJoin = join.withJoinTypeAndContext(JoinType.LEFT_ANTI_JOIN, join.getJoinReorderContext());
|
||||
newChild = join.withJoinTypeAndContext(JoinType.LEFT_ANTI_JOIN, join.getJoinReorderContext());
|
||||
}
|
||||
if (join.getJoinType().isRightOuterJoin() && !leftAlwaysNullSlots.isEmpty()) {
|
||||
newJoin = join.withJoinTypeAndContext(JoinType.RIGHT_ANTI_JOIN, join.getJoinReorderContext());
|
||||
newChild = join.withJoinTypeAndContext(JoinType.RIGHT_ANTI_JOIN, join.getJoinReorderContext());
|
||||
}
|
||||
if (newJoin == null) {
|
||||
return null;
|
||||
if (newChild == null) {
|
||||
return filter;
|
||||
}
|
||||
|
||||
if (!newJoin.getOutputSet().containsAll(filter.getInputSlots())) {
|
||||
if (!newChild.getOutputSet().containsAll(filter.getInputSlots())) {
|
||||
// if there are slots that don't belong to join output, we use null alias to replace them
|
||||
// such as:
|
||||
// project(A.id, null as B.id)
|
||||
// - (A left anti join B)
|
||||
Set<Slot> joinOutput = newJoin.getOutputSet();
|
||||
List<NamedExpression> projects = filter.getOutput().stream()
|
||||
.map(s -> {
|
||||
if (joinOutput.contains(s)) {
|
||||
return s;
|
||||
} else {
|
||||
return new Alias(s.getExprId(), new NullLiteral(s.getDataType()), s.getName());
|
||||
}
|
||||
}).collect(Collectors.toList());
|
||||
newJoin = new LogicalProject<>(projects, newJoin);
|
||||
Set<Slot> joinOutputs = newChild.getOutputSet();
|
||||
ImmutableList.Builder<NamedExpression> projectsBuilder = ImmutableList.builder();
|
||||
for (NamedExpression e : filter.getOutput()) {
|
||||
if (joinOutputs.contains(e)) {
|
||||
projectsBuilder.add(e);
|
||||
} else {
|
||||
Alias newAlias = new Alias(new NullLiteral(e.getDataType()), e.getName(), e.getQualifier());
|
||||
replaceMap.put(e.getExprId(), newAlias.getExprId());
|
||||
projectsBuilder.add(newAlias);
|
||||
}
|
||||
}
|
||||
newChild = new LogicalProject<>(projectsBuilder.build(), newChild);
|
||||
return exprIdReplacer.rewriteExpr(filter.withChildren(newChild), replaceMap);
|
||||
} else {
|
||||
return filter.withChildren(newChild);
|
||||
}
|
||||
return filter.withChildren(newJoin);
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,106 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.jobs.JobContext;
|
||||
import org.apache.doris.nereids.pattern.Pattern;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/** replace SlotReference ExprId in logical plans */
|
||||
public class ExprIdRewriter extends ExpressionRewrite {
|
||||
private final List<Rule> rules;
|
||||
private final JobContext jobContext;
|
||||
|
||||
public ExprIdRewriter(ReplaceRule replaceRule, JobContext jobContext) {
|
||||
super(new ExpressionRuleExecutor(ImmutableList.of(bottomUp(replaceRule))));
|
||||
rules = buildRules();
|
||||
this.jobContext = jobContext;
|
||||
}
|
||||
|
||||
/**rewriteExpr*/
|
||||
public Plan rewriteExpr(Plan plan, Map<ExprId, ExprId> replaceMap) {
|
||||
if (replaceMap.isEmpty()) {
|
||||
return plan;
|
||||
}
|
||||
for (Rule rule : rules) {
|
||||
Pattern<Plan> pattern = (Pattern<Plan>) rule.getPattern();
|
||||
if (pattern.matchPlanTree(plan)) {
|
||||
List<Plan> newPlans = rule.transform(plan, jobContext.getCascadesContext());
|
||||
Plan newPlan = newPlans.get(0);
|
||||
if (!newPlan.deepEquals(plan)) {
|
||||
return newPlan;
|
||||
}
|
||||
}
|
||||
}
|
||||
return plan;
|
||||
}
|
||||
|
||||
/**
|
||||
* Iteratively rewrites IDs using the replaceMap:
|
||||
* 1. For a given SlotReference with initial ID, retrieve the corresponding value ID from the replaceMap.
|
||||
* 2. If the value ID exists within the replaceMap, continue the lookup process using the value ID
|
||||
* until it no longer appears in the replaceMap.
|
||||
* 3. return SlotReference final value ID as the result of the rewrite.
|
||||
* e.g. replaceMap:{0:3, 1:6, 6:7}
|
||||
* SlotReference:a#0 -> a#3, a#1 -> a#7
|
||||
* */
|
||||
public static class ReplaceRule implements ExpressionPatternRuleFactory {
|
||||
private final Map<ExprId, ExprId> replaceMap;
|
||||
|
||||
public ReplaceRule(Map<ExprId, ExprId> replaceMap) {
|
||||
this.replaceMap = replaceMap;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesType(SlotReference.class).thenApply(ctx -> {
|
||||
Slot slot = ctx.expr;
|
||||
|
||||
ExprId newId = replaceMap.get(slot.getExprId());
|
||||
if (newId == null) {
|
||||
return slot;
|
||||
}
|
||||
ExprId lastId = newId;
|
||||
while (true) {
|
||||
newId = replaceMap.get(lastId);
|
||||
if (newId == null) {
|
||||
return slot.withExprId(lastId);
|
||||
} else {
|
||||
lastId = newId;
|
||||
}
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -81,6 +81,6 @@ public class StatementScopeIdGenerator {
|
||||
if (ConnectContext.get() != null) {
|
||||
ConnectContext.get().setStatementContext(new StatementContext());
|
||||
}
|
||||
statementContext = new StatementContext();
|
||||
statementContext = new StatementContext(10000);
|
||||
}
|
||||
}
|
||||
|
||||
@ -198,4 +198,21 @@ public class LogicalCTEConsumer extends LogicalRelation implements BlockFuncDeps
|
||||
"relationId", relationId,
|
||||
"name", name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
if (!super.equals(o)) {
|
||||
return false;
|
||||
}
|
||||
LogicalCTEConsumer that = (LogicalCTEConsumer) o;
|
||||
return Objects.equals(consumerToProducerOutputMap, that.consumerToProducerOutputMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return super.hashCode();
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,6 +28,7 @@ import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
@ -40,22 +41,22 @@ class OuterJoinAssocTest implements MemoPatternMatchSupported {
|
||||
LogicalOlapScan scan2;
|
||||
LogicalOlapScan scan3;
|
||||
|
||||
public OuterJoinAssocTest() throws Exception {
|
||||
// clear id so that slot id keep consistent every running
|
||||
@Test
|
||||
public void testInnerLeft() throws Exception {
|
||||
ConnectContext ctx = MemoTestUtils.createConnectContext();
|
||||
StatementScopeIdGenerator.clear();
|
||||
|
||||
// clear id so that slot id keep consistent every running
|
||||
scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInnerLeft() {
|
||||
LogicalPlan join = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
|
||||
.join(scan3, JoinType.LEFT_OUTER_JOIN, Pair.of(2, 0)) // t2.id = t3.id
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
|
||||
PlanChecker.from(ctx, join)
|
||||
.applyExploration(OuterJoinAssoc.INSTANCE.build())
|
||||
.matchesExploration(
|
||||
logicalJoin(
|
||||
@ -66,13 +67,21 @@ class OuterJoinAssocTest implements MemoPatternMatchSupported {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLeftLeft() {
|
||||
public void testLeftLeft() throws Exception {
|
||||
ConnectContext ctx = MemoTestUtils.createConnectContext();
|
||||
StatementScopeIdGenerator.clear();
|
||||
|
||||
// clear id so that slot id keep consistent every running
|
||||
scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
|
||||
LogicalPlan join = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
|
||||
.join(scan3, JoinType.LEFT_OUTER_JOIN, Pair.of(2, 0)) // t2.id = t3.id
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
|
||||
PlanChecker.from(ctx, join)
|
||||
.applyExploration(OuterJoinAssoc.INSTANCE.build())
|
||||
.matchesExploration(
|
||||
logicalJoin(
|
||||
@ -83,14 +92,22 @@ class OuterJoinAssocTest implements MemoPatternMatchSupported {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void rejectNull() {
|
||||
public void rejectNull() throws Exception {
|
||||
ConnectContext ctx = MemoTestUtils.createConnectContext();
|
||||
StatementScopeIdGenerator.clear();
|
||||
|
||||
// clear id so that slot id keep consistent every running
|
||||
scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
|
||||
IsNull isNull = new IsNull(scan3.getOutput().get(0));
|
||||
LogicalPlan join = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
|
||||
.join(scan3, JoinType.LEFT_OUTER_JOIN, ImmutableList.of(), ImmutableList.of(isNull)) // t3.id is not null
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
|
||||
PlanChecker.from(ctx, join)
|
||||
.applyExploration(OuterJoinAssoc.INSTANCE.build())
|
||||
.checkMemo(memo -> Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size()));
|
||||
}
|
||||
|
||||
@ -32,17 +32,21 @@ import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Sets;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
|
||||
private final LogicalOlapScan scan1;
|
||||
private final LogicalOlapScan scan2;
|
||||
private LogicalOlapScan scan1;
|
||||
private LogicalOlapScan scan2;
|
||||
|
||||
public ConvertOuterJoinToAntiJoinTest() throws Exception {
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
// clear id so that slot id keep consistent every running
|
||||
ConnectContext.remove();
|
||||
StatementScopeIdGenerator.clear();
|
||||
scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
@ -58,7 +62,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyTopDown(new InferFilterNotNull())
|
||||
.applyTopDown(new ConvertOuterJoinToAntiJoin())
|
||||
.applyCustom(new ConvertOuterJoinToAntiJoin())
|
||||
.printlnTree()
|
||||
.matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin()));
|
||||
}
|
||||
@ -73,7 +77,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyTopDown(new InferFilterNotNull())
|
||||
.applyTopDown(new ConvertOuterJoinToAntiJoin())
|
||||
.applyCustom(new ConvertOuterJoinToAntiJoin())
|
||||
.printlnTree()
|
||||
.matches(logicalJoin().when(join -> join.getJoinType().isRightAntiJoin()));
|
||||
}
|
||||
@ -91,7 +95,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyTopDown(new InferFilterNotNull())
|
||||
.applyTopDown(new ConvertOuterJoinToAntiJoin())
|
||||
.applyCustom(new ConvertOuterJoinToAntiJoin())
|
||||
.printlnTree()
|
||||
.matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin()));
|
||||
}
|
||||
@ -109,7 +113,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyTopDown(new InferFilterNotNull())
|
||||
.applyTopDown(new ConvertOuterJoinToAntiJoin())
|
||||
.applyCustom(new ConvertOuterJoinToAntiJoin())
|
||||
.printlnTree()
|
||||
.matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin()));
|
||||
}
|
||||
@ -127,7 +131,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyTopDown(new InferFilterNotNull())
|
||||
.applyTopDown(new ConvertOuterJoinToAntiJoin())
|
||||
.applyCustom(new ConvertOuterJoinToAntiJoin())
|
||||
.printlnTree()
|
||||
.matches(logicalJoin().when(join -> join.getJoinType().isLeftOuterJoin()));
|
||||
}
|
||||
@ -146,7 +150,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyTopDown(new InferFilterNotNull())
|
||||
.applyTopDown(new ConvertOuterJoinToAntiJoin())
|
||||
.applyCustom(new ConvertOuterJoinToAntiJoin())
|
||||
.printlnTree()
|
||||
.matches(logicalJoin().when(join -> join.getJoinType().isLeftOuterJoin()));
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.expressions.And;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThan;
|
||||
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
@ -69,7 +70,7 @@ class EliminateOuterJoinTest implements MemoPatternMatchSupported {
|
||||
void testEliminateRight() {
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
|
||||
.filter(new GreaterThan(scan1.getOutput().get(0), Literal.of(1)))
|
||||
.filter(new GreaterThan(scan1.getOutput().get(0), new IntegerLiteral(1)))
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
@ -81,7 +82,7 @@ class EliminateOuterJoinTest implements MemoPatternMatchSupported {
|
||||
logicalFilter(
|
||||
logicalJoin().when(join -> join.getJoinType().isInnerJoin())
|
||||
).when(filter -> filter.getConjuncts().size() == 1)
|
||||
.when(filter -> Objects.equals(filter.getConjuncts().toString(), "[(id#0 > 1)]"))
|
||||
.when(filter -> Objects.equals(filter.getConjuncts().iterator().next(), new GreaterThan(scan1.getOutput().get(0), new IntegerLiteral(1))))
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@ -32,6 +32,7 @@ import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
|
||||
import org.apache.doris.nereids.jobs.executor.Optimizer;
|
||||
import org.apache.doris.nereids.jobs.executor.Rewriter;
|
||||
import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
|
||||
import org.apache.doris.nereids.jobs.rewrite.CustomRewriteJob;
|
||||
import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteBottomUpJob;
|
||||
import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteTopDownJob;
|
||||
import org.apache.doris.nereids.jobs.rewrite.RootPlanTreeRewriteJob;
|
||||
@ -202,6 +203,14 @@ public class PlanChecker {
|
||||
return this;
|
||||
}
|
||||
|
||||
public PlanChecker applyCustom(CustomRewriter customRewriter) {
|
||||
CustomRewriteJob customRewriteJob = new CustomRewriteJob(() -> customRewriter, RuleType.TEST_REWRITE);
|
||||
customRewriteJob.execute(cascadesContext.getCurrentJobContext());
|
||||
cascadesContext.toMemo();
|
||||
MemoValidator.validate(cascadesContext.getMemo());
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* apply a top down rewrite rule if you not care the ruleId
|
||||
*
|
||||
|
||||
@ -8,21 +8,6 @@ C2BD89103557CCBF7ED97B51860225A0
|
||||
-- !sql_1 --
|
||||
80000
|
||||
|
||||
-- !sql --
|
||||
PLAN FRAGMENT 0
|
||||
OUTPUT EXPRS:
|
||||
sleep(100)[#0]
|
||||
PARTITION: UNPARTITIONED
|
||||
|
||||
HAS_COLO_PLAN_NODE: false
|
||||
|
||||
VRESULT SINK
|
||||
MYSQL_PROTOCAL
|
||||
|
||||
0:VUNION(32)
|
||||
constant exprs:
|
||||
sleep(100)
|
||||
|
||||
-- !sql --
|
||||
true
|
||||
|
||||
|
||||
@ -49,7 +49,11 @@ suite("fold_constant_by_be") {
|
||||
log.info("result: {}, {}", res1, res2)
|
||||
assertEquals(res1[0][0], res2[0][0])
|
||||
|
||||
qt_sql "explain select sleep(sign(1)*100);"
|
||||
explain {
|
||||
sql "select sleep(sign(1)*100);"
|
||||
contains "sleep(100)"
|
||||
}
|
||||
|
||||
sql 'set query_timeout=12;'
|
||||
qt_sql "select sleep(sign(1)*10);"
|
||||
|
||||
|
||||
@ -84,4 +84,12 @@ suite("transform_outer_join_to_anti") {
|
||||
sql("select * from eliminate_outer_join_A right outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_A.a is null and eliminate_outer_join_B.null_b is null and eliminate_outer_join_A.null_a is null")
|
||||
contains "ANTI JOIN"
|
||||
}
|
||||
|
||||
explain {
|
||||
sql """with temp as (
|
||||
select * from eliminate_outer_join_A left outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_B.b is null
|
||||
)
|
||||
select * from temp t1 join temp t2"""
|
||||
contains "ANTI JOIN"
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user