branch-2.1: [fix](Nereids) not generate duplicate exprid after convert outer to anti rule #52798 (#53901)

cherry picked from #52798
This commit is contained in:
morrySnow
2025-08-01 11:34:07 +08:00
committed by GitHub
parent 8523fdeba3
commit 4a062c4908
15 changed files with 446 additions and 69 deletions

View File

@ -47,6 +47,7 @@ import org.apache.doris.nereids.rules.rewrite.CollectCteConsumerOutput;
import org.apache.doris.nereids.rules.rewrite.CollectFilterAboveConsumer;
import org.apache.doris.nereids.rules.rewrite.ColumnPruning;
import org.apache.doris.nereids.rules.rewrite.ConvertInnerOrCrossJoin;
import org.apache.doris.nereids.rules.rewrite.ConvertOuterJoinToAntiJoin;
import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite;
import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite;
import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow;
@ -443,6 +444,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
ImmutableSet.of(LogicalCTEAnchor.class),
() -> jobs(
// after variant sub path pruning, we need do column pruning again
bottomUp(RuleSet.PUSH_DOWN_FILTERS),
custom(RuleType.COLUMN_PRUNING, ColumnPruning::new),
bottomUp(ImmutableList.of(
new PushDownFilterThroughProject(),
@ -531,6 +533,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
topic("rewrite cte sub-tree before sub path push down",
custom(RuleType.REWRITE_CTE_CHILDREN, () -> new RewriteCteChildren(beforePushDownJobs))
)));
rewriteJobs.addAll(jobs(topic("convert outer join to anti",
custom(RuleType.CONVERT_OUTER_JOIN_TO_ANTI, ConvertOuterJoinToAntiJoin::new))));
if (needOrExpansion) {
rewriteJobs.addAll(jobs(topic("or expansion",
custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE))));

View File

@ -96,7 +96,6 @@ import org.apache.doris.nereids.rules.implementation.LogicalTVFRelationToPhysica
import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
import org.apache.doris.nereids.rules.implementation.LogicalUnionToPhysicalUnion;
import org.apache.doris.nereids.rules.implementation.LogicalWindowToPhysicalWindow;
import org.apache.doris.nereids.rules.rewrite.ConvertOuterJoinToAntiJoin;
import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow;
import org.apache.doris.nereids.rules.rewrite.EliminateFilter;
import org.apache.doris.nereids.rules.rewrite.EliminateOuterJoin;
@ -162,7 +161,6 @@ public class RuleSet {
new PushDownFilterThroughGenerate(),
new PushDownProjectThroughLimit(),
new EliminateOuterJoin(),
new ConvertOuterJoinToAntiJoin(),
new MergeProjects(),
new MergeFilters(),
new MergeGenerates(),

View File

@ -213,6 +213,12 @@ public enum RuleType {
REWRITE_SORT_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_HAVING_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_REPEAT_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_OLAP_TABLE_SINK_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_SINK_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_WINDOW_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_SET_OPERATION_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_PARTITION_TOPN_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_TOPN_EXPRESSION(RuleTypeClass.REWRITE),
EXTRACT_FILTER_FROM_JOIN(RuleTypeClass.REWRITE),
REORDER_JOIN(RuleTypeClass.REWRITE),
MERGE_PERCENTILE_TO_ARRAY(RuleTypeClass.REWRITE),

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.expression;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.pattern.ExpressionPatternRules;
import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners;
import org.apache.doris.nereids.pattern.MatchingContext;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
@ -28,25 +29,39 @@ import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalSink;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
@ -79,7 +94,19 @@ public class ExpressionRewrite implements RewriteRuleFactory {
new JoinExpressionRewrite().build(),
new SortExpressionRewrite().build(),
new LogicalRepeatRewrite().build(),
new HavingExpressionRewrite().build());
new HavingExpressionRewrite().build(),
new LogicalPartitionTopNExpressionRewrite().build(),
new LogicalTopNExpressionRewrite().build(),
new LogicalSetOperationRewrite().build(),
new LogicalWindowRewrite().build(),
new LogicalCteConsumerRewrite().build(),
new LogicalResultSinkRewrite().build(),
new LogicalFileSinkRewrite().build(),
new LogicalHiveTableSinkRewrite().build(),
new LogicalIcebergTableSinkRewrite().build(),
new LogicalJdbcTableSinkRewrite().build(),
new LogicalOlapTableSinkRewrite().build(),
new LogicalDeferMaterializeResultSinkRewrite().build());
}
private class GenerateExpressionRewrite extends OneRewriteRuleFactory {
@ -264,7 +291,166 @@ public class ExpressionRewrite implements RewriteRuleFactory {
}
}
private class LogicalRepeatRewrite extends OneRewriteRuleFactory {
private class LogicalWindowRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalWindow().thenApply(ctx -> {
LogicalWindow<Plan> window = ctx.root;
List<NamedExpression> windowExpressions = window.getWindowExpressions();
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
List<NamedExpression> result = rewriteAll(windowExpressions, rewriter, context);
return window.withExpressionsAndChild(result, window.child());
})
.toRule(RuleType.REWRITE_WINDOW_EXPRESSION);
}
}
private class LogicalSetOperationRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalSetOperation().thenApply(ctx -> {
LogicalSetOperation setOperation = ctx.root;
List<List<SlotReference>> slotsList = setOperation.getRegularChildrenOutputs();
List<List<SlotReference>> newSlotsList = new ArrayList<>();
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
for (List<SlotReference> slots : slotsList) {
List<SlotReference> result = rewriteAll(slots, rewriter, context);
newSlotsList.add(result);
}
return setOperation.withChildrenAndTheirOutputs(setOperation.children(), newSlotsList);
})
.toRule(RuleType.REWRITE_SET_OPERATION_EXPRESSION);
}
}
private class LogicalTopNExpressionRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalTopN().thenApply(ctx -> {
LogicalTopN<Plan> topN = ctx.root;
List<OrderKey> orderKeys = topN.getOrderKeys();
ImmutableList.Builder<OrderKey> rewrittenOrderKeys
= ImmutableList.builderWithExpectedSize(orderKeys.size());
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
boolean changed = false;
for (OrderKey k : orderKeys) {
Expression expression = rewriter.rewrite(k.getExpr(), context);
changed |= expression != k.getExpr();
rewrittenOrderKeys.add(new OrderKey(expression, k.isAsc(), k.isNullFirst()));
}
return changed ? topN.withOrderKeys(rewrittenOrderKeys.build()) : topN;
}).toRule(RuleType.REWRITE_TOPN_EXPRESSION);
}
}
private class LogicalPartitionTopNExpressionRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalPartitionTopN().thenApply(ctx -> {
LogicalPartitionTopN<Plan> partitionTopN = ctx.root;
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
List<OrderExpression> newOrderExpressions = new ArrayList<>();
for (OrderExpression orderExpression : partitionTopN.getOrderKeys()) {
OrderKey orderKey = orderExpression.getOrderKey();
Expression expr = rewriter.rewrite(orderKey.getExpr(), context);
OrderKey newOrderKey = new OrderKey(expr, orderKey.isAsc(), orderKey.isNullFirst());
newOrderExpressions.add(new OrderExpression(newOrderKey));
}
List<Expression> result = rewriteAll(partitionTopN.getPartitionKeys(), rewriter, context);
return partitionTopN.withPartitionKeysAndOrderKeys(result, newOrderExpressions);
}).toRule(RuleType.REWRITE_PARTITION_TOPN_EXPRESSION);
}
}
private class LogicalCteConsumerRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalCTEConsumer().thenApply(ctx -> {
LogicalCTEConsumer consumer = ctx.root;
boolean changed = false;
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
ImmutableMap.Builder<Slot, Slot> cToPBuilder = ImmutableMap.builder();
ImmutableMultimap.Builder<Slot, Slot> pToCBuilder = ImmutableMultimap.builder();
for (Map.Entry<Slot, Slot> entry : consumer.getConsumerToProducerOutputMap().entrySet()) {
Slot key = (Slot) rewriter.rewrite(entry.getKey(), context);
Slot value = (Slot) rewriter.rewrite(entry.getValue(), context);
cToPBuilder.put(key, value);
pToCBuilder.put(value, key);
if (!key.equals(entry.getKey()) || !value.equals(entry.getValue())) {
changed = true;
}
}
return changed ? consumer.withTwoMaps(cToPBuilder.build(), pToCBuilder.build()) : consumer;
}).toRule(RuleType.REWRITE_TOPN_EXPRESSION);
}
}
private class LogicalResultSinkRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalResultSink().thenApply(ExpressionRewrite.this::applyRewriteToSink)
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
}
}
private class LogicalFileSinkRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalFileSink().thenApply(ExpressionRewrite.this::applyRewriteToSink)
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
}
}
private class LogicalHiveTableSinkRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalHiveTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink)
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
}
}
private class LogicalIcebergTableSinkRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalIcebergTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink)
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
}
}
private class LogicalJdbcTableSinkRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalJdbcTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink)
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
}
}
private class LogicalOlapTableSinkRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalOlapTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink)
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
}
}
private class LogicalDeferMaterializeResultSinkRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalDeferMaterializeResultSink().thenApply(ExpressionRewrite.this::applyRewriteToSink)
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
}
}
private LogicalSink<Plan> applyRewriteToSink(MatchingContext<? extends LogicalSink<Plan>> ctx) {
LogicalSink<Plan> sink = ctx.root;
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
List<NamedExpression> outputExprs = sink.getOutputExprs();
List<NamedExpression> result = rewriteAll(outputExprs, rewriter, context);
return sink.withOutputExprs(result);
}
/** LogicalRepeatRewrite */
public class LogicalRepeatRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalRepeat().thenApply(ctx -> {

View File

@ -17,9 +17,9 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
@ -28,9 +28,14 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.TypeUtils;
import java.util.List;
import com.google.common.collect.ImmutableList;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@ -42,18 +47,41 @@ import java.util.stream.Collectors;
* project(A.*)
* - LeftAntiJoin(A, B)
*/
public class ConvertOuterJoinToAntiJoin extends OneRewriteRuleFactory {
public class ConvertOuterJoinToAntiJoin extends DefaultPlanRewriter<Map<ExprId, ExprId>> implements CustomRewriter {
private ExprIdRewriter exprIdReplacer;
@Override
public Rule build() {
return logicalFilter(logicalJoin()
.when(join -> join.getJoinType().isOuterJoin()))
.then(this::toAntiJoin)
.toRule(RuleType.CONVERT_OUTER_JOIN_TO_ANTI);
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
if (!plan.containsType(LogicalJoin.class)) {
return plan;
}
Map<ExprId, ExprId> replaceMap = new HashMap<>();
ExprIdRewriter.ReplaceRule replaceRule = new ExprIdRewriter.ReplaceRule(replaceMap);
exprIdReplacer = new ExprIdRewriter(replaceRule, jobContext);
return plan.accept(this, replaceMap);
}
private Plan toAntiJoin(LogicalFilter<LogicalJoin<Plan, Plan>> filter) {
@Override
public Plan visit(Plan plan, Map<ExprId, ExprId> replaceMap) {
plan = visitChildren(this, plan, replaceMap);
plan = exprIdReplacer.rewriteExpr(plan, replaceMap);
return plan;
}
@Override
public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, Map<ExprId, ExprId> replaceMap) {
filter = (LogicalFilter<? extends Plan>) visit(filter, replaceMap);
if (!(filter.child() instanceof LogicalJoin)) {
return filter;
}
return toAntiJoin((LogicalFilter<LogicalJoin<Plan, Plan>>) filter, replaceMap);
}
private Plan toAntiJoin(LogicalFilter<LogicalJoin<Plan, Plan>> filter, Map<ExprId, ExprId> replaceMap) {
LogicalJoin<Plan, Plan> join = filter.child();
if (!join.getJoinType().isLeftOuterJoin() && !join.getJoinType().isRightOuterJoin()) {
return filter;
}
Set<Slot> alwaysNullSlots = filter.getConjuncts().stream()
.filter(p -> TypeUtils.isNull(p).isPresent())
@ -66,33 +94,37 @@ public class ConvertOuterJoinToAntiJoin extends OneRewriteRuleFactory {
.filter(s -> alwaysNullSlots.contains(s) && !s.nullable())
.collect(Collectors.toSet());
Plan newJoin = null;
Plan newChild = null;
if (join.getJoinType().isLeftOuterJoin() && !rightAlwaysNullSlots.isEmpty()) {
newJoin = join.withJoinTypeAndContext(JoinType.LEFT_ANTI_JOIN, join.getJoinReorderContext());
newChild = join.withJoinTypeAndContext(JoinType.LEFT_ANTI_JOIN, join.getJoinReorderContext());
}
if (join.getJoinType().isRightOuterJoin() && !leftAlwaysNullSlots.isEmpty()) {
newJoin = join.withJoinTypeAndContext(JoinType.RIGHT_ANTI_JOIN, join.getJoinReorderContext());
newChild = join.withJoinTypeAndContext(JoinType.RIGHT_ANTI_JOIN, join.getJoinReorderContext());
}
if (newJoin == null) {
return null;
if (newChild == null) {
return filter;
}
if (!newJoin.getOutputSet().containsAll(filter.getInputSlots())) {
if (!newChild.getOutputSet().containsAll(filter.getInputSlots())) {
// if there are slots that don't belong to join output, we use null alias to replace them
// such as:
// project(A.id, null as B.id)
// - (A left anti join B)
Set<Slot> joinOutput = newJoin.getOutputSet();
List<NamedExpression> projects = filter.getOutput().stream()
.map(s -> {
if (joinOutput.contains(s)) {
return s;
} else {
return new Alias(s.getExprId(), new NullLiteral(s.getDataType()), s.getName());
}
}).collect(Collectors.toList());
newJoin = new LogicalProject<>(projects, newJoin);
Set<Slot> joinOutputs = newChild.getOutputSet();
ImmutableList.Builder<NamedExpression> projectsBuilder = ImmutableList.builder();
for (NamedExpression e : filter.getOutput()) {
if (joinOutputs.contains(e)) {
projectsBuilder.add(e);
} else {
Alias newAlias = new Alias(new NullLiteral(e.getDataType()), e.getName(), e.getQualifier());
replaceMap.put(e.getExprId(), newAlias.getExprId());
projectsBuilder.add(newAlias);
}
}
newChild = new LogicalProject<>(projectsBuilder.build(), newChild);
return exprIdReplacer.rewriteExpr(filter.withChildren(newChild), replaceMap);
} else {
return filter.withChildren(newChild);
}
return filter.withChildren(newJoin);
}
}

View File

@ -0,0 +1,106 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
/** replace SlotReference ExprId in logical plans */
public class ExprIdRewriter extends ExpressionRewrite {
private final List<Rule> rules;
private final JobContext jobContext;
public ExprIdRewriter(ReplaceRule replaceRule, JobContext jobContext) {
super(new ExpressionRuleExecutor(ImmutableList.of(bottomUp(replaceRule))));
rules = buildRules();
this.jobContext = jobContext;
}
/**rewriteExpr*/
public Plan rewriteExpr(Plan plan, Map<ExprId, ExprId> replaceMap) {
if (replaceMap.isEmpty()) {
return plan;
}
for (Rule rule : rules) {
Pattern<Plan> pattern = (Pattern<Plan>) rule.getPattern();
if (pattern.matchPlanTree(plan)) {
List<Plan> newPlans = rule.transform(plan, jobContext.getCascadesContext());
Plan newPlan = newPlans.get(0);
if (!newPlan.deepEquals(plan)) {
return newPlan;
}
}
}
return plan;
}
/**
* Iteratively rewrites IDs using the replaceMap:
* 1. For a given SlotReference with initial ID, retrieve the corresponding value ID from the replaceMap.
* 2. If the value ID exists within the replaceMap, continue the lookup process using the value ID
* until it no longer appears in the replaceMap.
* 3. return SlotReference final value ID as the result of the rewrite.
* e.g. replaceMap:{0:3, 1:6, 6:7}
* SlotReference:a#0 -> a#3, a#1 -> a#7
* */
public static class ReplaceRule implements ExpressionPatternRuleFactory {
private final Map<ExprId, ExprId> replaceMap;
public ReplaceRule(Map<ExprId, ExprId> replaceMap) {
this.replaceMap = replaceMap;
}
@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesType(SlotReference.class).thenApply(ctx -> {
Slot slot = ctx.expr;
ExprId newId = replaceMap.get(slot.getExprId());
if (newId == null) {
return slot;
}
ExprId lastId = newId;
while (true) {
newId = replaceMap.get(lastId);
if (newId == null) {
return slot.withExprId(lastId);
} else {
lastId = newId;
}
}
})
);
}
}
}

View File

@ -81,6 +81,6 @@ public class StatementScopeIdGenerator {
if (ConnectContext.get() != null) {
ConnectContext.get().setStatementContext(new StatementContext());
}
statementContext = new StatementContext();
statementContext = new StatementContext(10000);
}
}

View File

@ -198,4 +198,21 @@ public class LogicalCTEConsumer extends LogicalRelation implements BlockFuncDeps
"relationId", relationId,
"name", name);
}
@Override
public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
LogicalCTEConsumer that = (LogicalCTEConsumer) o;
return Objects.equals(consumerToProducerOutputMap, that.consumerToProducerOutputMap);
}
@Override
public int hashCode() {
return super.hashCode();
}
}

View File

@ -28,6 +28,7 @@ import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
@ -40,22 +41,22 @@ class OuterJoinAssocTest implements MemoPatternMatchSupported {
LogicalOlapScan scan2;
LogicalOlapScan scan3;
public OuterJoinAssocTest() throws Exception {
// clear id so that slot id keep consistent every running
@Test
public void testInnerLeft() throws Exception {
ConnectContext ctx = MemoTestUtils.createConnectContext();
StatementScopeIdGenerator.clear();
// clear id so that slot id keep consistent every running
scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
}
@Test
public void testInnerLeft() {
LogicalPlan join = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
.join(scan3, JoinType.LEFT_OUTER_JOIN, Pair.of(2, 0)) // t2.id = t3.id
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
PlanChecker.from(ctx, join)
.applyExploration(OuterJoinAssoc.INSTANCE.build())
.matchesExploration(
logicalJoin(
@ -66,13 +67,21 @@ class OuterJoinAssocTest implements MemoPatternMatchSupported {
}
@Test
public void testLeftLeft() {
public void testLeftLeft() throws Exception {
ConnectContext ctx = MemoTestUtils.createConnectContext();
StatementScopeIdGenerator.clear();
// clear id so that slot id keep consistent every running
scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
LogicalPlan join = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
.join(scan3, JoinType.LEFT_OUTER_JOIN, Pair.of(2, 0)) // t2.id = t3.id
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
PlanChecker.from(ctx, join)
.applyExploration(OuterJoinAssoc.INSTANCE.build())
.matchesExploration(
logicalJoin(
@ -83,14 +92,22 @@ class OuterJoinAssocTest implements MemoPatternMatchSupported {
}
@Test
public void rejectNull() {
public void rejectNull() throws Exception {
ConnectContext ctx = MemoTestUtils.createConnectContext();
StatementScopeIdGenerator.clear();
// clear id so that slot id keep consistent every running
scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
IsNull isNull = new IsNull(scan3.getOutput().get(0));
LogicalPlan join = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
.join(scan3, JoinType.LEFT_OUTER_JOIN, ImmutableList.of(), ImmutableList.of(isNull)) // t3.id is not null
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
PlanChecker.from(ctx, join)
.applyExploration(OuterJoinAssoc.INSTANCE.build())
.checkMemo(memo -> Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size()));
}

View File

@ -32,17 +32,21 @@ import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
private final LogicalOlapScan scan1;
private final LogicalOlapScan scan2;
private LogicalOlapScan scan1;
private LogicalOlapScan scan2;
public ConvertOuterJoinToAntiJoinTest() throws Exception {
@BeforeEach
void setUp() throws Exception {
// clear id so that slot id keep consistent every running
ConnectContext.remove();
StatementScopeIdGenerator.clear();
scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
@ -58,7 +62,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new InferFilterNotNull())
.applyTopDown(new ConvertOuterJoinToAntiJoin())
.applyCustom(new ConvertOuterJoinToAntiJoin())
.printlnTree()
.matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin()));
}
@ -73,7 +77,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new InferFilterNotNull())
.applyTopDown(new ConvertOuterJoinToAntiJoin())
.applyCustom(new ConvertOuterJoinToAntiJoin())
.printlnTree()
.matches(logicalJoin().when(join -> join.getJoinType().isRightAntiJoin()));
}
@ -91,7 +95,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new InferFilterNotNull())
.applyTopDown(new ConvertOuterJoinToAntiJoin())
.applyCustom(new ConvertOuterJoinToAntiJoin())
.printlnTree()
.matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin()));
}
@ -109,7 +113,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new InferFilterNotNull())
.applyTopDown(new ConvertOuterJoinToAntiJoin())
.applyCustom(new ConvertOuterJoinToAntiJoin())
.printlnTree()
.matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin()));
}
@ -127,7 +131,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new InferFilterNotNull())
.applyTopDown(new ConvertOuterJoinToAntiJoin())
.applyCustom(new ConvertOuterJoinToAntiJoin())
.printlnTree()
.matches(logicalJoin().when(join -> join.getJoinType().isLeftOuterJoin()));
}
@ -146,7 +150,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new InferFilterNotNull())
.applyTopDown(new ConvertOuterJoinToAntiJoin())
.applyCustom(new ConvertOuterJoinToAntiJoin())
.printlnTree()
.matches(logicalJoin().when(join -> join.getJoinType().isLeftOuterJoin()));
}

View File

@ -21,6 +21,7 @@ import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
@ -69,7 +70,7 @@ class EliminateOuterJoinTest implements MemoPatternMatchSupported {
void testEliminateRight() {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
.filter(new GreaterThan(scan1.getOutput().get(0), Literal.of(1)))
.filter(new GreaterThan(scan1.getOutput().get(0), new IntegerLiteral(1)))
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
@ -81,7 +82,7 @@ class EliminateOuterJoinTest implements MemoPatternMatchSupported {
logicalFilter(
logicalJoin().when(join -> join.getJoinType().isInnerJoin())
).when(filter -> filter.getConjuncts().size() == 1)
.when(filter -> Objects.equals(filter.getConjuncts().toString(), "[(id#0 > 1)]"))
.when(filter -> Objects.equals(filter.getConjuncts().iterator().next(), new GreaterThan(scan1.getOutput().get(0), new IntegerLiteral(1))))
);
}

View File

@ -32,6 +32,7 @@ import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.executor.Optimizer;
import org.apache.doris.nereids.jobs.executor.Rewriter;
import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
import org.apache.doris.nereids.jobs.rewrite.CustomRewriteJob;
import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteBottomUpJob;
import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteTopDownJob;
import org.apache.doris.nereids.jobs.rewrite.RootPlanTreeRewriteJob;
@ -202,6 +203,14 @@ public class PlanChecker {
return this;
}
public PlanChecker applyCustom(CustomRewriter customRewriter) {
CustomRewriteJob customRewriteJob = new CustomRewriteJob(() -> customRewriter, RuleType.TEST_REWRITE);
customRewriteJob.execute(cascadesContext.getCurrentJobContext());
cascadesContext.toMemo();
MemoValidator.validate(cascadesContext.getMemo());
return this;
}
/**
* apply a top down rewrite rule if you not care the ruleId
*

View File

@ -8,21 +8,6 @@ C2BD89103557CCBF7ED97B51860225A0
-- !sql_1 --
80000
-- !sql --
PLAN FRAGMENT 0
OUTPUT EXPRS:
sleep(100)[#0]
PARTITION: UNPARTITIONED
HAS_COLO_PLAN_NODE: false
VRESULT SINK
MYSQL_PROTOCAL
0:VUNION(32)
constant exprs:
sleep(100)
-- !sql --
true

View File

@ -49,7 +49,11 @@ suite("fold_constant_by_be") {
log.info("result: {}, {}", res1, res2)
assertEquals(res1[0][0], res2[0][0])
qt_sql "explain select sleep(sign(1)*100);"
explain {
sql "select sleep(sign(1)*100);"
contains "sleep(100)"
}
sql 'set query_timeout=12;'
qt_sql "select sleep(sign(1)*10);"

View File

@ -84,4 +84,12 @@ suite("transform_outer_join_to_anti") {
sql("select * from eliminate_outer_join_A right outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_A.a is null and eliminate_outer_join_B.null_b is null and eliminate_outer_join_A.null_a is null")
contains "ANTI JOIN"
}
explain {
sql """with temp as (
select * from eliminate_outer_join_A left outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_B.b is null
)
select * from temp t1 join temp t2"""
contains "ANTI JOIN"
}
}