From 4a062c49080b7cddeace76ff7ae3e2bf3294d751 Mon Sep 17 00:00:00 2001 From: morrySnow Date: Fri, 1 Aug 2025 11:34:07 +0800 Subject: [PATCH] branch-2.1: [fix](Nereids) not generate duplicate exprid after convert outer to anti rule #52798 (#53901) cherry picked from #52798 --- .../doris/nereids/jobs/executor/Rewriter.java | 4 + .../apache/doris/nereids/rules/RuleSet.java | 2 - .../apache/doris/nereids/rules/RuleType.java | 6 + .../rules/expression/ExpressionRewrite.java | 190 +++++++++++++++++- .../rewrite/ConvertOuterJoinToAntiJoin.java | 86 +++++--- .../nereids/rules/rewrite/ExprIdRewriter.java | 106 ++++++++++ .../StatementScopeIdGenerator.java | 2 +- .../plans/logical/LogicalCTEConsumer.java | 17 ++ .../exploration/join/OuterJoinAssocTest.java | 37 +++- .../ConvertOuterJoinToAntiJoinTest.java | 22 +- .../rules/rewrite/EliminateOuterJoinTest.java | 5 +- .../doris/nereids/util/PlanChecker.java | 9 + .../fold_constant/fold_constant_by_be.out | 15 -- .../fold_constant/fold_constant_by_be.groovy | 6 +- .../transform_outer_join_to_anti.groovy | 8 + 15 files changed, 446 insertions(+), 69 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 4c712e919b..c5b2ee1c88 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -47,6 +47,7 @@ import org.apache.doris.nereids.rules.rewrite.CollectCteConsumerOutput; import org.apache.doris.nereids.rules.rewrite.CollectFilterAboveConsumer; import org.apache.doris.nereids.rules.rewrite.ColumnPruning; import org.apache.doris.nereids.rules.rewrite.ConvertInnerOrCrossJoin; +import org.apache.doris.nereids.rules.rewrite.ConvertOuterJoinToAntiJoin; import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite; import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite; import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow; @@ -443,6 +444,7 @@ public class Rewriter extends AbstractBatchJobExecutor { ImmutableSet.of(LogicalCTEAnchor.class), () -> jobs( // after variant sub path pruning, we need do column pruning again + bottomUp(RuleSet.PUSH_DOWN_FILTERS), custom(RuleType.COLUMN_PRUNING, ColumnPruning::new), bottomUp(ImmutableList.of( new PushDownFilterThroughProject(), @@ -531,6 +533,8 @@ public class Rewriter extends AbstractBatchJobExecutor { topic("rewrite cte sub-tree before sub path push down", custom(RuleType.REWRITE_CTE_CHILDREN, () -> new RewriteCteChildren(beforePushDownJobs)) ))); + rewriteJobs.addAll(jobs(topic("convert outer join to anti", + custom(RuleType.CONVERT_OUTER_JOIN_TO_ANTI, ConvertOuterJoinToAntiJoin::new)))); if (needOrExpansion) { rewriteJobs.addAll(jobs(topic("or expansion", custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE)))); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index 06f9581400..6feed32754 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -96,7 +96,6 @@ import org.apache.doris.nereids.rules.implementation.LogicalTVFRelationToPhysica import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN; import org.apache.doris.nereids.rules.implementation.LogicalUnionToPhysicalUnion; import org.apache.doris.nereids.rules.implementation.LogicalWindowToPhysicalWindow; -import org.apache.doris.nereids.rules.rewrite.ConvertOuterJoinToAntiJoin; import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow; import org.apache.doris.nereids.rules.rewrite.EliminateFilter; import org.apache.doris.nereids.rules.rewrite.EliminateOuterJoin; @@ -162,7 +161,6 @@ public class RuleSet { new PushDownFilterThroughGenerate(), new PushDownProjectThroughLimit(), new EliminateOuterJoin(), - new ConvertOuterJoinToAntiJoin(), new MergeProjects(), new MergeFilters(), new MergeGenerates(), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 7f4b5cb3fd..5129fb8cf5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -213,6 +213,12 @@ public enum RuleType { REWRITE_SORT_EXPRESSION(RuleTypeClass.REWRITE), REWRITE_HAVING_EXPRESSION(RuleTypeClass.REWRITE), REWRITE_REPEAT_EXPRESSION(RuleTypeClass.REWRITE), + REWRITE_OLAP_TABLE_SINK_EXPRESSION(RuleTypeClass.REWRITE), + REWRITE_SINK_EXPRESSION(RuleTypeClass.REWRITE), + REWRITE_WINDOW_EXPRESSION(RuleTypeClass.REWRITE), + REWRITE_SET_OPERATION_EXPRESSION(RuleTypeClass.REWRITE), + REWRITE_PARTITION_TOPN_EXPRESSION(RuleTypeClass.REWRITE), + REWRITE_TOPN_EXPRESSION(RuleTypeClass.REWRITE), EXTRACT_FILTER_FROM_JOIN(RuleTypeClass.REWRITE), REORDER_JOIN(RuleTypeClass.REWRITE), MERGE_PERCENTILE_TO_ARRAY(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java index 0a790fd586..a212017ab1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.expression; import org.apache.doris.common.Pair; import org.apache.doris.nereids.pattern.ExpressionPatternRules; import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners; +import org.apache.doris.nereids.pattern.MatchingContext; import org.apache.doris.nereids.properties.OrderKey; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; @@ -28,25 +29,39 @@ import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.Function; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate; import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; +import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation; +import org.apache.doris.nereids.trees.plans.logical.LogicalSink; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; +import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; +import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; +import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; @@ -79,7 +94,19 @@ public class ExpressionRewrite implements RewriteRuleFactory { new JoinExpressionRewrite().build(), new SortExpressionRewrite().build(), new LogicalRepeatRewrite().build(), - new HavingExpressionRewrite().build()); + new HavingExpressionRewrite().build(), + new LogicalPartitionTopNExpressionRewrite().build(), + new LogicalTopNExpressionRewrite().build(), + new LogicalSetOperationRewrite().build(), + new LogicalWindowRewrite().build(), + new LogicalCteConsumerRewrite().build(), + new LogicalResultSinkRewrite().build(), + new LogicalFileSinkRewrite().build(), + new LogicalHiveTableSinkRewrite().build(), + new LogicalIcebergTableSinkRewrite().build(), + new LogicalJdbcTableSinkRewrite().build(), + new LogicalOlapTableSinkRewrite().build(), + new LogicalDeferMaterializeResultSinkRewrite().build()); } private class GenerateExpressionRewrite extends OneRewriteRuleFactory { @@ -264,7 +291,166 @@ public class ExpressionRewrite implements RewriteRuleFactory { } } - private class LogicalRepeatRewrite extends OneRewriteRuleFactory { + private class LogicalWindowRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalWindow().thenApply(ctx -> { + LogicalWindow window = ctx.root; + List windowExpressions = window.getWindowExpressions(); + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + List result = rewriteAll(windowExpressions, rewriter, context); + return window.withExpressionsAndChild(result, window.child()); + }) + .toRule(RuleType.REWRITE_WINDOW_EXPRESSION); + } + } + + private class LogicalSetOperationRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalSetOperation().thenApply(ctx -> { + LogicalSetOperation setOperation = ctx.root; + List> slotsList = setOperation.getRegularChildrenOutputs(); + List> newSlotsList = new ArrayList<>(); + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + for (List slots : slotsList) { + List result = rewriteAll(slots, rewriter, context); + newSlotsList.add(result); + } + return setOperation.withChildrenAndTheirOutputs(setOperation.children(), newSlotsList); + }) + .toRule(RuleType.REWRITE_SET_OPERATION_EXPRESSION); + } + } + + private class LogicalTopNExpressionRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalTopN().thenApply(ctx -> { + LogicalTopN topN = ctx.root; + List orderKeys = topN.getOrderKeys(); + ImmutableList.Builder rewrittenOrderKeys + = ImmutableList.builderWithExpectedSize(orderKeys.size()); + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + boolean changed = false; + for (OrderKey k : orderKeys) { + Expression expression = rewriter.rewrite(k.getExpr(), context); + changed |= expression != k.getExpr(); + rewrittenOrderKeys.add(new OrderKey(expression, k.isAsc(), k.isNullFirst())); + } + return changed ? topN.withOrderKeys(rewrittenOrderKeys.build()) : topN; + }).toRule(RuleType.REWRITE_TOPN_EXPRESSION); + } + } + + private class LogicalPartitionTopNExpressionRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalPartitionTopN().thenApply(ctx -> { + LogicalPartitionTopN partitionTopN = ctx.root; + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + List newOrderExpressions = new ArrayList<>(); + for (OrderExpression orderExpression : partitionTopN.getOrderKeys()) { + OrderKey orderKey = orderExpression.getOrderKey(); + Expression expr = rewriter.rewrite(orderKey.getExpr(), context); + OrderKey newOrderKey = new OrderKey(expr, orderKey.isAsc(), orderKey.isNullFirst()); + newOrderExpressions.add(new OrderExpression(newOrderKey)); + } + List result = rewriteAll(partitionTopN.getPartitionKeys(), rewriter, context); + return partitionTopN.withPartitionKeysAndOrderKeys(result, newOrderExpressions); + }).toRule(RuleType.REWRITE_PARTITION_TOPN_EXPRESSION); + } + } + + private class LogicalCteConsumerRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalCTEConsumer().thenApply(ctx -> { + LogicalCTEConsumer consumer = ctx.root; + boolean changed = false; + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ImmutableMap.Builder cToPBuilder = ImmutableMap.builder(); + ImmutableMultimap.Builder pToCBuilder = ImmutableMultimap.builder(); + for (Map.Entry entry : consumer.getConsumerToProducerOutputMap().entrySet()) { + Slot key = (Slot) rewriter.rewrite(entry.getKey(), context); + Slot value = (Slot) rewriter.rewrite(entry.getValue(), context); + cToPBuilder.put(key, value); + pToCBuilder.put(value, key); + if (!key.equals(entry.getKey()) || !value.equals(entry.getValue())) { + changed = true; + } + } + return changed ? consumer.withTwoMaps(cToPBuilder.build(), pToCBuilder.build()) : consumer; + }).toRule(RuleType.REWRITE_TOPN_EXPRESSION); + } + } + + private class LogicalResultSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalResultSink().thenApply(ExpressionRewrite.this::applyRewriteToSink) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalFileSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalFileSink().thenApply(ExpressionRewrite.this::applyRewriteToSink) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalHiveTableSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalHiveTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalIcebergTableSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalIcebergTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalJdbcTableSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalJdbcTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalOlapTableSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalOlapTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalDeferMaterializeResultSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalDeferMaterializeResultSink().thenApply(ExpressionRewrite.this::applyRewriteToSink) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private LogicalSink applyRewriteToSink(MatchingContext> ctx) { + LogicalSink sink = ctx.root; + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + List outputExprs = sink.getOutputExprs(); + List result = rewriteAll(outputExprs, rewriter, context); + return sink.withOutputExprs(result); + } + + /** LogicalRepeatRewrite */ + public class LogicalRepeatRewrite extends OneRewriteRuleFactory { @Override public Rule build() { return logicalRepeat().thenApply(ctx -> { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java index c9185fd1a3..4644557305 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java @@ -17,9 +17,9 @@ package org.apache.doris.nereids.rules.rewrite; -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; @@ -28,9 +28,14 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import org.apache.doris.nereids.util.TypeUtils; -import java.util.List; +import com.google.common.collect.ImmutableList; + +import java.util.HashMap; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -42,18 +47,41 @@ import java.util.stream.Collectors; * project(A.*) * - LeftAntiJoin(A, B) */ -public class ConvertOuterJoinToAntiJoin extends OneRewriteRuleFactory { +public class ConvertOuterJoinToAntiJoin extends DefaultPlanRewriter> implements CustomRewriter { + private ExprIdRewriter exprIdReplacer; @Override - public Rule build() { - return logicalFilter(logicalJoin() - .when(join -> join.getJoinType().isOuterJoin())) - .then(this::toAntiJoin) - .toRule(RuleType.CONVERT_OUTER_JOIN_TO_ANTI); + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + if (!plan.containsType(LogicalJoin.class)) { + return plan; + } + Map replaceMap = new HashMap<>(); + ExprIdRewriter.ReplaceRule replaceRule = new ExprIdRewriter.ReplaceRule(replaceMap); + exprIdReplacer = new ExprIdRewriter(replaceRule, jobContext); + return plan.accept(this, replaceMap); } - private Plan toAntiJoin(LogicalFilter> filter) { + @Override + public Plan visit(Plan plan, Map replaceMap) { + plan = visitChildren(this, plan, replaceMap); + plan = exprIdReplacer.rewriteExpr(plan, replaceMap); + return plan; + } + + @Override + public Plan visitLogicalFilter(LogicalFilter filter, Map replaceMap) { + filter = (LogicalFilter) visit(filter, replaceMap); + if (!(filter.child() instanceof LogicalJoin)) { + return filter; + } + return toAntiJoin((LogicalFilter>) filter, replaceMap); + } + + private Plan toAntiJoin(LogicalFilter> filter, Map replaceMap) { LogicalJoin join = filter.child(); + if (!join.getJoinType().isLeftOuterJoin() && !join.getJoinType().isRightOuterJoin()) { + return filter; + } Set alwaysNullSlots = filter.getConjuncts().stream() .filter(p -> TypeUtils.isNull(p).isPresent()) @@ -66,33 +94,37 @@ public class ConvertOuterJoinToAntiJoin extends OneRewriteRuleFactory { .filter(s -> alwaysNullSlots.contains(s) && !s.nullable()) .collect(Collectors.toSet()); - Plan newJoin = null; + Plan newChild = null; if (join.getJoinType().isLeftOuterJoin() && !rightAlwaysNullSlots.isEmpty()) { - newJoin = join.withJoinTypeAndContext(JoinType.LEFT_ANTI_JOIN, join.getJoinReorderContext()); + newChild = join.withJoinTypeAndContext(JoinType.LEFT_ANTI_JOIN, join.getJoinReorderContext()); } if (join.getJoinType().isRightOuterJoin() && !leftAlwaysNullSlots.isEmpty()) { - newJoin = join.withJoinTypeAndContext(JoinType.RIGHT_ANTI_JOIN, join.getJoinReorderContext()); + newChild = join.withJoinTypeAndContext(JoinType.RIGHT_ANTI_JOIN, join.getJoinReorderContext()); } - if (newJoin == null) { - return null; + if (newChild == null) { + return filter; } - if (!newJoin.getOutputSet().containsAll(filter.getInputSlots())) { + if (!newChild.getOutputSet().containsAll(filter.getInputSlots())) { // if there are slots that don't belong to join output, we use null alias to replace them // such as: // project(A.id, null as B.id) // - (A left anti join B) - Set joinOutput = newJoin.getOutputSet(); - List projects = filter.getOutput().stream() - .map(s -> { - if (joinOutput.contains(s)) { - return s; - } else { - return new Alias(s.getExprId(), new NullLiteral(s.getDataType()), s.getName()); - } - }).collect(Collectors.toList()); - newJoin = new LogicalProject<>(projects, newJoin); + Set joinOutputs = newChild.getOutputSet(); + ImmutableList.Builder projectsBuilder = ImmutableList.builder(); + for (NamedExpression e : filter.getOutput()) { + if (joinOutputs.contains(e)) { + projectsBuilder.add(e); + } else { + Alias newAlias = new Alias(new NullLiteral(e.getDataType()), e.getName(), e.getQualifier()); + replaceMap.put(e.getExprId(), newAlias.getExprId()); + projectsBuilder.add(newAlias); + } + } + newChild = new LogicalProject<>(projectsBuilder.build(), newChild); + return exprIdReplacer.rewriteExpr(filter.withChildren(newChild), replaceMap); + } else { + return filter.withChildren(newChild); } - return filter.withChildren(newJoin); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java new file mode 100644 index 0000000000..9e42491cac --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.pattern.Pattern; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRewrite; +import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.Plan; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Map; + +/** replace SlotReference ExprId in logical plans */ +public class ExprIdRewriter extends ExpressionRewrite { + private final List rules; + private final JobContext jobContext; + + public ExprIdRewriter(ReplaceRule replaceRule, JobContext jobContext) { + super(new ExpressionRuleExecutor(ImmutableList.of(bottomUp(replaceRule)))); + rules = buildRules(); + this.jobContext = jobContext; + } + + /**rewriteExpr*/ + public Plan rewriteExpr(Plan plan, Map replaceMap) { + if (replaceMap.isEmpty()) { + return plan; + } + for (Rule rule : rules) { + Pattern pattern = (Pattern) rule.getPattern(); + if (pattern.matchPlanTree(plan)) { + List newPlans = rule.transform(plan, jobContext.getCascadesContext()); + Plan newPlan = newPlans.get(0); + if (!newPlan.deepEquals(plan)) { + return newPlan; + } + } + } + return plan; + } + + /** + * Iteratively rewrites IDs using the replaceMap: + * 1. For a given SlotReference with initial ID, retrieve the corresponding value ID from the replaceMap. + * 2. If the value ID exists within the replaceMap, continue the lookup process using the value ID + * until it no longer appears in the replaceMap. + * 3. return SlotReference final value ID as the result of the rewrite. + * e.g. replaceMap:{0:3, 1:6, 6:7} + * SlotReference:a#0 -> a#3, a#1 -> a#7 + * */ + public static class ReplaceRule implements ExpressionPatternRuleFactory { + private final Map replaceMap; + + public ReplaceRule(Map replaceMap) { + this.replaceMap = replaceMap; + } + + @Override + public List> buildRules() { + return ImmutableList.of( + matchesType(SlotReference.class).thenApply(ctx -> { + Slot slot = ctx.expr; + + ExprId newId = replaceMap.get(slot.getExprId()); + if (newId == null) { + return slot; + } + ExprId lastId = newId; + while (true) { + newId = replaceMap.get(lastId); + if (newId == null) { + return slot.withExprId(lastId); + } else { + lastId = newId; + } + } + }) + ); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/StatementScopeIdGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/StatementScopeIdGenerator.java index df7ef2ab69..cf0ecc3cb9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/StatementScopeIdGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/StatementScopeIdGenerator.java @@ -81,6 +81,6 @@ public class StatementScopeIdGenerator { if (ConnectContext.get() != null) { ConnectContext.get().setStatementContext(new StatementContext()); } - statementContext = new StatementContext(); + statementContext = new StatementContext(10000); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCTEConsumer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCTEConsumer.java index 415fdddf80..6148f62378 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCTEConsumer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCTEConsumer.java @@ -198,4 +198,21 @@ public class LogicalCTEConsumer extends LogicalRelation implements BlockFuncDeps "relationId", relationId, "name", name); } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + LogicalCTEConsumer that = (LogicalCTEConsumer) o; + return Objects.equals(consumerToProducerOutputMap, that.consumerToProducerOutputMap); + } + + @Override + public int hashCode() { + return super.hashCode(); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java index c3beb8fc11..9f86f31eb0 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java @@ -28,6 +28,7 @@ import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Assertions; @@ -40,22 +41,22 @@ class OuterJoinAssocTest implements MemoPatternMatchSupported { LogicalOlapScan scan2; LogicalOlapScan scan3; - public OuterJoinAssocTest() throws Exception { - // clear id so that slot id keep consistent every running + @Test + public void testInnerLeft() throws Exception { + ConnectContext ctx = MemoTestUtils.createConnectContext(); StatementScopeIdGenerator.clear(); + + // clear id so that slot id keep consistent every running scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); - } - @Test - public void testInnerLeft() { LogicalPlan join = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id .join(scan3, JoinType.LEFT_OUTER_JOIN, Pair.of(2, 0)) // t2.id = t3.id .build(); - PlanChecker.from(MemoTestUtils.createConnectContext(), join) + PlanChecker.from(ctx, join) .applyExploration(OuterJoinAssoc.INSTANCE.build()) .matchesExploration( logicalJoin( @@ -66,13 +67,21 @@ class OuterJoinAssocTest implements MemoPatternMatchSupported { } @Test - public void testLeftLeft() { + public void testLeftLeft() throws Exception { + ConnectContext ctx = MemoTestUtils.createConnectContext(); + StatementScopeIdGenerator.clear(); + + // clear id so that slot id keep consistent every running + scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); + LogicalPlan join = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id .join(scan3, JoinType.LEFT_OUTER_JOIN, Pair.of(2, 0)) // t2.id = t3.id .build(); - PlanChecker.from(MemoTestUtils.createConnectContext(), join) + PlanChecker.from(ctx, join) .applyExploration(OuterJoinAssoc.INSTANCE.build()) .matchesExploration( logicalJoin( @@ -83,14 +92,22 @@ class OuterJoinAssocTest implements MemoPatternMatchSupported { } @Test - public void rejectNull() { + public void rejectNull() throws Exception { + ConnectContext ctx = MemoTestUtils.createConnectContext(); + StatementScopeIdGenerator.clear(); + + // clear id so that slot id keep consistent every running + scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); + IsNull isNull = new IsNull(scan3.getOutput().get(0)); LogicalPlan join = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id .join(scan3, JoinType.LEFT_OUTER_JOIN, ImmutableList.of(), ImmutableList.of(isNull)) // t3.id is not null .build(); - PlanChecker.from(MemoTestUtils.createConnectContext(), join) + PlanChecker.from(ctx, join) .applyExploration(OuterJoinAssoc.INSTANCE.build()) .checkMemo(memo -> Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size())); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java index 1159fc2a7c..b3166c2224 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java @@ -32,17 +32,21 @@ import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.Sets; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported { - private final LogicalOlapScan scan1; - private final LogicalOlapScan scan2; + private LogicalOlapScan scan1; + private LogicalOlapScan scan2; - public ConvertOuterJoinToAntiJoinTest() throws Exception { + @BeforeEach + void setUp() throws Exception { // clear id so that slot id keep consistent every running + ConnectContext.remove(); StatementScopeIdGenerator.clear(); scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); @@ -58,7 +62,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new InferFilterNotNull()) - .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .applyCustom(new ConvertOuterJoinToAntiJoin()) .printlnTree() .matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin())); } @@ -73,7 +77,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new InferFilterNotNull()) - .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .applyCustom(new ConvertOuterJoinToAntiJoin()) .printlnTree() .matches(logicalJoin().when(join -> join.getJoinType().isRightAntiJoin())); } @@ -91,7 +95,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new InferFilterNotNull()) - .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .applyCustom(new ConvertOuterJoinToAntiJoin()) .printlnTree() .matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin())); } @@ -109,7 +113,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new InferFilterNotNull()) - .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .applyCustom(new ConvertOuterJoinToAntiJoin()) .printlnTree() .matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin())); } @@ -127,7 +131,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new InferFilterNotNull()) - .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .applyCustom(new ConvertOuterJoinToAntiJoin()) .printlnTree() .matches(logicalJoin().when(join -> join.getJoinType().isLeftOuterJoin())); } @@ -146,7 +150,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new InferFilterNotNull()) - .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .applyCustom(new ConvertOuterJoinToAntiJoin()) .printlnTree() .matches(logicalJoin().when(join -> join.getJoinType().isLeftOuterJoin())); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoinTest.java index 255f1e82e0..f003441016 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoinTest.java @@ -21,6 +21,7 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; @@ -69,7 +70,7 @@ class EliminateOuterJoinTest implements MemoPatternMatchSupported { void testEliminateRight() { LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id - .filter(new GreaterThan(scan1.getOutput().get(0), Literal.of(1))) + .filter(new GreaterThan(scan1.getOutput().get(0), new IntegerLiteral(1))) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) @@ -81,7 +82,7 @@ class EliminateOuterJoinTest implements MemoPatternMatchSupported { logicalFilter( logicalJoin().when(join -> join.getJoinType().isInnerJoin()) ).when(filter -> filter.getConjuncts().size() == 1) - .when(filter -> Objects.equals(filter.getConjuncts().toString(), "[(id#0 > 1)]")) + .when(filter -> Objects.equals(filter.getConjuncts().iterator().next(), new GreaterThan(scan1.getOutput().get(0), new IntegerLiteral(1)))) ); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java index 71d0f0101b..6962572d07 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java @@ -32,6 +32,7 @@ import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob; import org.apache.doris.nereids.jobs.executor.Optimizer; import org.apache.doris.nereids.jobs.executor.Rewriter; import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob; +import org.apache.doris.nereids.jobs.rewrite.CustomRewriteJob; import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteBottomUpJob; import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteTopDownJob; import org.apache.doris.nereids.jobs.rewrite.RootPlanTreeRewriteJob; @@ -202,6 +203,14 @@ public class PlanChecker { return this; } + public PlanChecker applyCustom(CustomRewriter customRewriter) { + CustomRewriteJob customRewriteJob = new CustomRewriteJob(() -> customRewriter, RuleType.TEST_REWRITE); + customRewriteJob.execute(cascadesContext.getCurrentJobContext()); + cascadesContext.toMemo(); + MemoValidator.validate(cascadesContext.getMemo()); + return this; + } + /** * apply a top down rewrite rule if you not care the ruleId * diff --git a/regression-test/data/nereids_p0/expression/fold_constant/fold_constant_by_be.out b/regression-test/data/nereids_p0/expression/fold_constant/fold_constant_by_be.out index 13cfa81870..8d9d704684 100644 --- a/regression-test/data/nereids_p0/expression/fold_constant/fold_constant_by_be.out +++ b/regression-test/data/nereids_p0/expression/fold_constant/fold_constant_by_be.out @@ -8,21 +8,6 @@ C2BD89103557CCBF7ED97B51860225A0 -- !sql_1 -- 80000 --- !sql -- -PLAN FRAGMENT 0 - OUTPUT EXPRS: - sleep(100)[#0] - PARTITION: UNPARTITIONED - - HAS_COLO_PLAN_NODE: false - - VRESULT SINK - MYSQL_PROTOCAL - - 0:VUNION(32) - constant exprs: - sleep(100) - -- !sql -- true diff --git a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_by_be.groovy b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_by_be.groovy index 8d64ac671c..52aec52400 100644 --- a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_by_be.groovy +++ b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_by_be.groovy @@ -49,7 +49,11 @@ suite("fold_constant_by_be") { log.info("result: {}, {}", res1, res2) assertEquals(res1[0][0], res2[0][0]) - qt_sql "explain select sleep(sign(1)*100);" + explain { + sql "select sleep(sign(1)*100);" + contains "sleep(100)" + } + sql 'set query_timeout=12;' qt_sql "select sleep(sign(1)*10);" diff --git a/regression-test/suites/nereids_syntax_p0/transform_outer_join_to_anti.groovy b/regression-test/suites/nereids_syntax_p0/transform_outer_join_to_anti.groovy index ccbb8fd64a..f806f4ce5c 100644 --- a/regression-test/suites/nereids_syntax_p0/transform_outer_join_to_anti.groovy +++ b/regression-test/suites/nereids_syntax_p0/transform_outer_join_to_anti.groovy @@ -84,4 +84,12 @@ suite("transform_outer_join_to_anti") { sql("select * from eliminate_outer_join_A right outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_A.a is null and eliminate_outer_join_B.null_b is null and eliminate_outer_join_A.null_a is null") contains "ANTI JOIN" } + + explain { + sql """with temp as ( + select * from eliminate_outer_join_A left outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_B.b is null + ) + select * from temp t1 join temp t2""" + contains "ANTI JOIN" + } }