diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java index 9f9257f5f6..e6b6c524fb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java @@ -104,6 +104,31 @@ public class OrExpansion extends DefaultPlanRewriter implem return hasNewChildren ? plan.withChildren(newChildren) : plan; } + @Override + public Plan visitLogicalCTEAnchor( + LogicalCTEAnchor anchor, OrExpandsionContext ctx) { + Plan child1 = this.visit(anchor.child(0), ctx); + // Consumer's CTE must be child of the cteAnchor in this case: + // anchor + // +-producer1 + // +-agg(consumer1) join agg(consumer1) + // ------------> + // anchor + // +-producer1 + // +-anchor + // +--producer2(agg2(consumer1)) + // +--producer3(agg3(consumer1)) + // +-consumer2 join consumer3 + OrExpandsionContext consumerContext = + new OrExpandsionContext(ctx.statementContext, ctx.cascadesContext); + Plan child2 = this.visit(anchor.child(1), consumerContext); + for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) { + LogicalCTEProducer producer = consumerContext.cteProducerList.get(i); + child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, child2); + } + return anchor.withChildren(ImmutableList.of(child1, child2)); + } + @Override public Plan visitLogicalJoin(LogicalJoin join, OrExpandsionContext ctx) { join = (LogicalJoin) this.visit(join, ctx); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java index dfa881aa2e..197de0089a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java @@ -340,8 +340,14 @@ public class LogicalPlanDeepCopier extends DefaultPlanRewriter markJoinConjuncts = join.getMarkJoinConjuncts().stream() .map(c -> ExpressionDeepCopier.INSTANCE.deepCopy(c, context)) .collect(ImmutableList.toImmutableList()); + Optional markJoinSlotReference = Optional.empty(); + if (join.getMarkJoinSlotReference().isPresent()) { + markJoinSlotReference = Optional.of((MarkJoinSlotReference) ExpressionDeepCopier.INSTANCE + .deepCopy(join.getMarkJoinSlotReference().get(), context)); + + } return new LogicalJoin<>(join.getJoinType(), hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts, - join.getDistributeHint(), join.getMarkJoinSlotReference(), children, join.getJoinReorderContext()); + join.getDistributeHint(), markJoinSlotReference, children, join.getJoinReorderContext()); } @Override diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrExpansionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrExpansionTest.java index 9f8bd8bcc5..0f2d9418bf 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrExpansionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrExpansionTest.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; @@ -81,6 +82,9 @@ class OrExpansionTest extends TestWithFeService implements MemoPatternMatchSuppo Assertions.assertTrue(plan instanceof LogicalCTEAnchor); Assertions.assertTrue(plan.child(1) instanceof LogicalCTEAnchor); Assertions.assertTrue(plan.child(1).child(1) instanceof LogicalCTEAnchor); + Assertions.assertTrue(plan.child(1).child(1).anyMatch(x -> x instanceof LogicalCTEConsumer)); Assertions.assertTrue(plan.child(1).child(1).child(1) instanceof LogicalCTEAnchor); + Assertions.assertTrue(plan.child(1).child(1).child(1) + .anyMatch(x -> x instanceof LogicalCTEConsumer)); } }