diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java index e2e99594a3..3302c6ff06 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java @@ -260,10 +260,11 @@ public class NereidsRewriter extends BatchRewriteJob { ), // this rule batch must keep at the end of rewrite to do some plan check - topic("Final rewrite and check", bottomUp( - new AdjustNullable(), - new ExpressionRewrite(CheckLegalityAfterRewrite.INSTANCE), - new CheckAfterRewrite() + topic("Final rewrite and check", + custom(RuleType.ADJUST_NULLABLE, AdjustNullable::new), + bottomUp( + new ExpressionRewrite(CheckLegalityAfterRewrite.INSTANCE), + new CheckAfterRewrite() )) ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index b376824e0d..c910eb4381 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -206,17 +206,7 @@ public enum RuleType { PUSH_LIMIT_THROUGH_UNION(RuleTypeClass.REWRITE), PUSH_LIMIT_INTO_SORT(RuleTypeClass.REWRITE), // adjust nullable - ADJUST_NULLABLE_ON_AGGREGATE(RuleTypeClass.REWRITE), - ADJUST_NULLABLE_ON_ASSERT_NUM_ROWS(RuleTypeClass.REWRITE), - ADJUST_NULLABLE_ON_FILTER(RuleTypeClass.REWRITE), - ADJUST_NULLABLE_ON_GENERATE(RuleTypeClass.REWRITE), - ADJUST_NULLABLE_ON_JOIN(RuleTypeClass.REWRITE), - ADJUST_NULLABLE_ON_LIMIT(RuleTypeClass.REWRITE), - ADJUST_NULLABLE_ON_PROJECT(RuleTypeClass.REWRITE), - ADJUST_NULLABLE_ON_REPEAT(RuleTypeClass.REWRITE), - ADJUST_NULLABLE_ON_SET_OPERATION(RuleTypeClass.REWRITE), - ADJUST_NULLABLE_ON_SORT(RuleTypeClass.REWRITE), - ADJUST_NULLABLE_ON_TOP_N(RuleTypeClass.REWRITE), + ADJUST_NULLABLE(RuleTypeClass.REWRITE), REWRITE_SENTINEL(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AdjustNullable.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AdjustNullable.java index dac174720d..55c5c2be08 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AdjustNullable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AdjustNullable.java @@ -17,10 +17,8 @@ package org.apache.doris.nereids.rules.rewrite.logical; +import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.properties.OrderKey; -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -30,8 +28,19 @@ import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; import org.apache.doris.nereids.trees.expressions.functions.Function; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; +import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation; +import org.apache.doris.nereids.trees.plans.logical.LogicalSort; +import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; +import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -46,87 +55,126 @@ import java.util.stream.Collectors; * because some rule could change output's nullable. * So, we need add a rule to adjust all expression's nullable attribute after rewrite. */ -public class AdjustNullable implements RewriteRuleFactory { +public class AdjustNullable extends DefaultPlanRewriter implements CustomRewriter { + @Override - public List buildRules() { - return ImmutableList.of( - RuleType.ADJUST_NULLABLE_ON_AGGREGATE.build(logicalAggregate().then(aggregate -> { - Map exprIdSlotMap = collectChildrenOutputMap(aggregate); - List newOutputs - = updateExpressions(aggregate.getOutputExpressions(), exprIdSlotMap); - List newGroupExpressions - = updateExpressions(aggregate.getGroupByExpressions(), exprIdSlotMap); - return aggregate.withGroupByAndOutput(newGroupExpressions, newOutputs).recomputeLogicalProperties(); - })), - RuleType.ADJUST_NULLABLE_ON_ASSERT_NUM_ROWS.build(logicalAssertNumRows().then( - LogicalPlan::recomputeLogicalProperties)), - RuleType.ADJUST_NULLABLE_ON_FILTER.build(logicalFilter().then(filter -> { - Map exprIdSlotMap = collectChildrenOutputMap(filter); - Set conjuncts = updateExpressions(filter.getConjuncts(), exprIdSlotMap); - return new LogicalFilter<>(conjuncts, filter.child()); - })), - RuleType.ADJUST_NULLABLE_ON_GENERATE.build(logicalGenerate().then(generate -> { - Map exprIdSlotMap = collectChildrenOutputMap(generate); - List newGenerators = updateExpressions(generate.getGenerators(), exprIdSlotMap); - return generate.withGenerators(newGenerators).recomputeLogicalProperties(); - })), - RuleType.ADJUST_NULLABLE_ON_JOIN.build(logicalJoin().then(join -> { - Map exprIdSlotMap = collectChildrenOutputMap(join); - List hashConjuncts = updateExpressions(join.getHashJoinConjuncts(), exprIdSlotMap); - List otherConjuncts = updateExpressions(join.getOtherJoinConjuncts(), exprIdSlotMap); - return join.withJoinConjuncts(hashConjuncts, otherConjuncts).recomputeLogicalProperties(); - })), - RuleType.ADJUST_NULLABLE_ON_LIMIT.build(logicalLimit().then(LogicalPlan::recomputeLogicalProperties)), - RuleType.ADJUST_NULLABLE_ON_PROJECT.build(logicalProject().then(project -> { - Map exprIdSlotMap = collectChildrenOutputMap(project); - List newProjects = updateExpressions(project.getProjects(), exprIdSlotMap); - return project.withProjects(newProjects); - })), - RuleType.ADJUST_NULLABLE_ON_REPEAT.build(logicalRepeat().then(repeat -> { - Map exprIdSlotMap = collectChildrenOutputMap(repeat); - List newOutputs = updateExpressions(repeat.getOutputExpressions(), exprIdSlotMap); - List> newGroupingSets = repeat.getGroupingSets().stream() - .map(l -> updateExpressions(l, exprIdSlotMap)) - .collect(ImmutableList.toImmutableList()); - return repeat.withGroupSetsAndOutput(newGroupingSets, newOutputs).recomputeLogicalProperties(); - })), - RuleType.ADJUST_NULLABLE_ON_SET_OPERATION.build(logicalSetOperation().then(setOperation -> { - List inputNullable = setOperation.child(0).getOutput().stream() - .map(ExpressionTrait::nullable).collect(Collectors.toList()); - for (int i = 1; i < setOperation.arity(); i++) { - List childOutput = setOperation.getChildOutput(i); - for (int j = 0; j < childOutput.size(); j++) { - if (childOutput.get(j).nullable()) { - inputNullable.set(j, true); - } - } - } - List outputs = setOperation.getOutputs(); - List newOutputs = Lists.newArrayListWithCapacity(outputs.size()); - for (int i = 0; i < inputNullable.size(); i++) { - SlotReference slotReference = (SlotReference) outputs.get(i); - if (inputNullable.get(i)) { - slotReference = slotReference.withNullable(true); - } - newOutputs.add(slotReference); - } - return setOperation.withNewOutputs(newOutputs).recomputeLogicalProperties(); - })), - RuleType.ADJUST_NULLABLE_ON_SORT.build(logicalSort().then(sort -> { - Map exprIdSlotMap = collectChildrenOutputMap(sort); - List newKeys = sort.getOrderKeys().stream() - .map(old -> old.withExpression(updateExpression(old.getExpr(), exprIdSlotMap))) - .collect(ImmutableList.toImmutableList()); - return sort.withOrderKeys(newKeys).recomputeLogicalProperties(); - })), - RuleType.ADJUST_NULLABLE_ON_TOP_N.build(logicalTopN().then(topN -> { - Map exprIdSlotMap = collectChildrenOutputMap(topN); - List newKeys = topN.getOrderKeys().stream() - .map(old -> old.withExpression(updateExpression(old.getExpr(), exprIdSlotMap))) - .collect(ImmutableList.toImmutableList()); - return topN.withOrderKeys(newKeys).recomputeLogicalProperties(); - })) - ); + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + return plan.accept(this, null); + } + + @Override + public Plan visit(Plan plan, Void context) { + LogicalPlan logicalPlan = (LogicalPlan) super.visit(plan, context); + return logicalPlan.recomputeLogicalProperties(); + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate aggregate, Void context) { + aggregate = (LogicalAggregate) super.visit(aggregate, context); + Map exprIdSlotMap = collectChildrenOutputMap(aggregate); + List newOutputs + = updateExpressions(aggregate.getOutputExpressions(), exprIdSlotMap); + List newGroupExpressions + = updateExpressions(aggregate.getGroupByExpressions(), exprIdSlotMap); + return aggregate.withGroupByAndOutput(newGroupExpressions, newOutputs); + } + + @Override + public Plan visitLogicalFilter(LogicalFilter filter, Void context) { + filter = (LogicalFilter) super.visit(filter, context); + Map exprIdSlotMap = collectChildrenOutputMap(filter); + Set conjuncts = updateExpressions(filter.getConjuncts(), exprIdSlotMap); + return filter.withConjuncts(conjuncts).recomputeLogicalProperties(); + } + + @Override + public Plan visitLogicalGenerate(LogicalGenerate generate, Void context) { + generate = (LogicalGenerate) super.visit(generate, context); + Map exprIdSlotMap = collectChildrenOutputMap(generate); + List newGenerators = updateExpressions(generate.getGenerators(), exprIdSlotMap); + return generate.withGenerators(newGenerators).recomputeLogicalProperties(); + } + + @Override + public Plan visitLogicalJoin(LogicalJoin join, Void context) { + join = (LogicalJoin) super.visit(join, context); + Map exprIdSlotMap = collectChildrenOutputMap(join); + List hashConjuncts = updateExpressions(join.getHashJoinConjuncts(), exprIdSlotMap); + List otherConjuncts = updateExpressions(join.getOtherJoinConjuncts(), exprIdSlotMap); + return join.withJoinConjuncts(hashConjuncts, otherConjuncts).recomputeLogicalProperties(); + } + + @Override + public Plan visitLogicalProject(LogicalProject project, Void context) { + project = (LogicalProject) super.visit(project, context); + Map exprIdSlotMap = collectChildrenOutputMap(project); + List newProjects = updateExpressions(project.getProjects(), exprIdSlotMap); + return project.withProjects(newProjects); + } + + @Override + public Plan visitLogicalRepeat(LogicalRepeat repeat, Void context) { + repeat = (LogicalRepeat) super.visit(repeat, context); + Map exprIdSlotMap = collectChildrenOutputMap(repeat); + List newOutputs = updateExpressions(repeat.getOutputExpressions(), exprIdSlotMap); + List> newGroupingSets = repeat.getGroupingSets().stream() + .map(l -> updateExpressions(l, exprIdSlotMap)) + .collect(ImmutableList.toImmutableList()); + return repeat.withGroupSetsAndOutput(newGroupingSets, newOutputs).recomputeLogicalProperties(); + } + + @Override + public Plan visitLogicalSetOperation(LogicalSetOperation setOperation, Void context) { + setOperation = (LogicalSetOperation) super.visit(setOperation, context); + List inputNullable = setOperation.child(0).getOutput().stream() + .map(ExpressionTrait::nullable).collect(Collectors.toList()); + for (int i = 1; i < setOperation.arity(); i++) { + List childOutput = setOperation.getChildOutput(i); + for (int j = 0; j < childOutput.size(); j++) { + if (childOutput.get(j).nullable()) { + inputNullable.set(j, true); + } + } + } + List outputs = setOperation.getOutputs(); + List newOutputs = Lists.newArrayListWithCapacity(outputs.size()); + for (int i = 0; i < inputNullable.size(); i++) { + SlotReference slotReference = (SlotReference) outputs.get(i); + if (inputNullable.get(i)) { + slotReference = slotReference.withNullable(true); + } + newOutputs.add(slotReference); + } + return setOperation.withNewOutputs(newOutputs).recomputeLogicalProperties(); + } + + @Override + public Plan visitLogicalSort(LogicalSort sort, Void context) { + sort = (LogicalSort) super.visit(sort, context); + Map exprIdSlotMap = collectChildrenOutputMap(sort); + List newKeys = sort.getOrderKeys().stream() + .map(old -> old.withExpression(updateExpression(old.getExpr(), exprIdSlotMap))) + .collect(ImmutableList.toImmutableList()); + return sort.withOrderKeys(newKeys).recomputeLogicalProperties(); + } + + @Override + public Plan visitLogicalTopN(LogicalTopN topN, Void context) { + topN = (LogicalTopN) super.visit(topN, context); + Map exprIdSlotMap = collectChildrenOutputMap(topN); + List newKeys = topN.getOrderKeys().stream() + .map(old -> old.withExpression(updateExpression(old.getExpr(), exprIdSlotMap))) + .collect(ImmutableList.toImmutableList()); + return topN.withOrderKeys(newKeys).recomputeLogicalProperties(); + } + + @Override + public Plan visitLogicalWindow(LogicalWindow window, Void context) { + window = (LogicalWindow) super.visit(window, context); + Map exprIdSlotMap = collectChildrenOutputMap(window); + List windowExpressions = + updateExpressions(window.getWindowExpressions(), exprIdSlotMap); + return window.withExpression(windowExpressions, window.child()); } private T updateExpression(T input, Map exprIdSlotMap) { @@ -153,7 +201,11 @@ public class AdjustNullable implements RewriteRuleFactory { @Override public Expression visitSlotReference(SlotReference slotReference, Map context) { - return context.getOrDefault(slotReference.getExprId(), slotReference); + if (context.containsKey(slotReference.getExprId())) { + return slotReference.withNullable(context.get(slotReference.getExprId()).nullable()); + } else { + return slotReference; + } } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java index 28bcccfe14..8ddcdca828 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java @@ -111,19 +111,23 @@ public class LogicalFilter extends LogicalUnary withConjuncts(Set conjuncts) { + return new LogicalFilter<>(conjuncts, Optional.empty(), Optional.of(getLogicalProperties()), child()); + } + @Override - public LogicalUnary withChildren(List children) { + public LogicalFilter withChildren(List children) { Preconditions.checkArgument(children.size() == 1); return new LogicalFilter<>(conjuncts, children.get(0)); } @Override - public Plan withGroupExpression(Optional groupExpression) { + public LogicalFilter withGroupExpression(Optional groupExpression) { return new LogicalFilter<>(conjuncts, groupExpression, Optional.of(getLogicalProperties()), child()); } @Override - public Plan withLogicalProperties(Optional logicalProperties) { + public LogicalFilter withLogicalProperties(Optional logicalProperties) { return new LogicalFilter<>(conjuncts, Optional.empty(), logicalProperties, child()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalWindow.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalWindow.java index c29f9aba7a..3e4a61b3f2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalWindow.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalWindow.java @@ -76,8 +76,14 @@ public class LogicalWindow extends LogicalUnary windowExpressions, Plan child) { - return new LogicalWindow(windowExpressions, true, Optional.empty(), Optional.empty(), child); + public LogicalWindow withExpression(List windowExpressions, Plan child) { + return new LogicalWindow<>(windowExpressions, isChecked, Optional.empty(), + Optional.empty(), child); + } + + public LogicalWindow withChecked(List windowExpressions, Plan child) { + return new LogicalWindow<>(windowExpressions, true, Optional.empty(), + Optional.of(getLogicalProperties()), child); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/PlanVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/PlanVisitor.java index 3cda38d6fb..1f7e97ab07 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/PlanVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/PlanVisitor.java @@ -227,9 +227,8 @@ public abstract class PlanVisitor { return visit(having, context); } - public R visitLogicalSetOperation( - LogicalSetOperation logicalSetOperation, C context) { - return visit(logicalSetOperation, context); + public R visitLogicalSetOperation(LogicalSetOperation setOperation, C context) { + return visit(setOperation, context); } public R visitLogicalUnion(LogicalUnion union, C context) { diff --git a/regression-test/suites/nereids_syntax_p0/window_function.groovy b/regression-test/suites/nereids_syntax_p0/window_function.groovy index bec8b93585..8833e526eb 100644 --- a/regression-test/suites/nereids_syntax_p0/window_function.groovy +++ b/regression-test/suites/nereids_syntax_p0/window_function.groovy @@ -31,6 +31,25 @@ suite("window_function") { "replication_allocation" = "tag.location.default: 1" ); """ + + sql """ + create table adj_nullable_1 ( + c1 int, + c2 int, + c3 int + ) distributed by hash(c1) + properties('replication_num'='1'); + """ + + sql """ + create table adj_nullable_2 ( + c4 int not null, + c5 int not null, + c6 int not null + ) distributed by hash(c4) + properties('replication_num'='1'); + """ + sql """INSERT INTO window_test VALUES(1, 1, 1)""" sql """INSERT INTO window_test VALUES(1, 2, 1)""" sql """INSERT INTO window_test VALUES(1, 3, 1)""" @@ -46,6 +65,9 @@ suite("window_function") { sql """INSERT INTO window_test VALUES(1, null, 3)""" sql """INSERT INTO window_test VALUES(2, null, 3)""" + sql """insert into adj_nullable_1 values(1, 1, 1);""" + sql """insert into adj_nullable_2 values(1, 1, 1);""" + sql "SET enable_fallback_to_original_planner=false" order_qt_empty_over "SELECT rank() over() FROM window_test" @@ -148,4 +170,15 @@ suite("window_function") { from window_test group by c1, c2 """ + + // test adjust nullable on window + sql """ + select + count(c1) over (partition by c4), + coalesce(c5, sum(c2) over (partition by c3)) + from + adj_nullable_1 + left join adj_nullable_2 on c1 = c4 + where c6 is not null; + """ }