[refactor](Nereids) refactor adjust nullable rule as a custom rewriter (#19702)

use custom rewriter to do adjust nullable to avoid nullable changed in expression but not changed in output
This commit is contained in:
Zhang Wenxin
2023-05-17 19:24:42 +08:00
committed by GitHub
parent 272a7565b8
commit bee2e2964f
7 changed files with 192 additions and 107 deletions

View File

@ -260,10 +260,11 @@ public class NereidsRewriter extends BatchRewriteJob {
),
// this rule batch must keep at the end of rewrite to do some plan check
topic("Final rewrite and check", bottomUp(
new AdjustNullable(),
new ExpressionRewrite(CheckLegalityAfterRewrite.INSTANCE),
new CheckAfterRewrite()
topic("Final rewrite and check",
custom(RuleType.ADJUST_NULLABLE, AdjustNullable::new),
bottomUp(
new ExpressionRewrite(CheckLegalityAfterRewrite.INSTANCE),
new CheckAfterRewrite()
))
);

View File

@ -206,17 +206,7 @@ public enum RuleType {
PUSH_LIMIT_THROUGH_UNION(RuleTypeClass.REWRITE),
PUSH_LIMIT_INTO_SORT(RuleTypeClass.REWRITE),
// adjust nullable
ADJUST_NULLABLE_ON_AGGREGATE(RuleTypeClass.REWRITE),
ADJUST_NULLABLE_ON_ASSERT_NUM_ROWS(RuleTypeClass.REWRITE),
ADJUST_NULLABLE_ON_FILTER(RuleTypeClass.REWRITE),
ADJUST_NULLABLE_ON_GENERATE(RuleTypeClass.REWRITE),
ADJUST_NULLABLE_ON_JOIN(RuleTypeClass.REWRITE),
ADJUST_NULLABLE_ON_LIMIT(RuleTypeClass.REWRITE),
ADJUST_NULLABLE_ON_PROJECT(RuleTypeClass.REWRITE),
ADJUST_NULLABLE_ON_REPEAT(RuleTypeClass.REWRITE),
ADJUST_NULLABLE_ON_SET_OPERATION(RuleTypeClass.REWRITE),
ADJUST_NULLABLE_ON_SORT(RuleTypeClass.REWRITE),
ADJUST_NULLABLE_ON_TOP_N(RuleTypeClass.REWRITE),
ADJUST_NULLABLE(RuleTypeClass.REWRITE),
REWRITE_SENTINEL(RuleTypeClass.REWRITE),

View File

@ -17,10 +17,8 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@ -30,8 +28,19 @@ import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
@ -46,87 +55,126 @@ import java.util.stream.Collectors;
* because some rule could change output's nullable.
* So, we need add a rule to adjust all expression's nullable attribute after rewrite.
*/
public class AdjustNullable implements RewriteRuleFactory {
public class AdjustNullable extends DefaultPlanRewriter<Void> implements CustomRewriter {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.ADJUST_NULLABLE_ON_AGGREGATE.build(logicalAggregate().then(aggregate -> {
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(aggregate);
List<NamedExpression> newOutputs
= updateExpressions(aggregate.getOutputExpressions(), exprIdSlotMap);
List<Expression> newGroupExpressions
= updateExpressions(aggregate.getGroupByExpressions(), exprIdSlotMap);
return aggregate.withGroupByAndOutput(newGroupExpressions, newOutputs).recomputeLogicalProperties();
})),
RuleType.ADJUST_NULLABLE_ON_ASSERT_NUM_ROWS.build(logicalAssertNumRows().then(
LogicalPlan::recomputeLogicalProperties)),
RuleType.ADJUST_NULLABLE_ON_FILTER.build(logicalFilter().then(filter -> {
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(filter);
Set<Expression> conjuncts = updateExpressions(filter.getConjuncts(), exprIdSlotMap);
return new LogicalFilter<>(conjuncts, filter.child());
})),
RuleType.ADJUST_NULLABLE_ON_GENERATE.build(logicalGenerate().then(generate -> {
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(generate);
List<Function> newGenerators = updateExpressions(generate.getGenerators(), exprIdSlotMap);
return generate.withGenerators(newGenerators).recomputeLogicalProperties();
})),
RuleType.ADJUST_NULLABLE_ON_JOIN.build(logicalJoin().then(join -> {
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(join);
List<Expression> hashConjuncts = updateExpressions(join.getHashJoinConjuncts(), exprIdSlotMap);
List<Expression> otherConjuncts = updateExpressions(join.getOtherJoinConjuncts(), exprIdSlotMap);
return join.withJoinConjuncts(hashConjuncts, otherConjuncts).recomputeLogicalProperties();
})),
RuleType.ADJUST_NULLABLE_ON_LIMIT.build(logicalLimit().then(LogicalPlan::recomputeLogicalProperties)),
RuleType.ADJUST_NULLABLE_ON_PROJECT.build(logicalProject().then(project -> {
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(project);
List<NamedExpression> newProjects = updateExpressions(project.getProjects(), exprIdSlotMap);
return project.withProjects(newProjects);
})),
RuleType.ADJUST_NULLABLE_ON_REPEAT.build(logicalRepeat().then(repeat -> {
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(repeat);
List<NamedExpression> newOutputs = updateExpressions(repeat.getOutputExpressions(), exprIdSlotMap);
List<List<Expression>> newGroupingSets = repeat.getGroupingSets().stream()
.map(l -> updateExpressions(l, exprIdSlotMap))
.collect(ImmutableList.toImmutableList());
return repeat.withGroupSetsAndOutput(newGroupingSets, newOutputs).recomputeLogicalProperties();
})),
RuleType.ADJUST_NULLABLE_ON_SET_OPERATION.build(logicalSetOperation().then(setOperation -> {
List<Boolean> inputNullable = setOperation.child(0).getOutput().stream()
.map(ExpressionTrait::nullable).collect(Collectors.toList());
for (int i = 1; i < setOperation.arity(); i++) {
List<Slot> childOutput = setOperation.getChildOutput(i);
for (int j = 0; j < childOutput.size(); j++) {
if (childOutput.get(j).nullable()) {
inputNullable.set(j, true);
}
}
}
List<NamedExpression> outputs = setOperation.getOutputs();
List<NamedExpression> newOutputs = Lists.newArrayListWithCapacity(outputs.size());
for (int i = 0; i < inputNullable.size(); i++) {
SlotReference slotReference = (SlotReference) outputs.get(i);
if (inputNullable.get(i)) {
slotReference = slotReference.withNullable(true);
}
newOutputs.add(slotReference);
}
return setOperation.withNewOutputs(newOutputs).recomputeLogicalProperties();
})),
RuleType.ADJUST_NULLABLE_ON_SORT.build(logicalSort().then(sort -> {
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(sort);
List<OrderKey> newKeys = sort.getOrderKeys().stream()
.map(old -> old.withExpression(updateExpression(old.getExpr(), exprIdSlotMap)))
.collect(ImmutableList.toImmutableList());
return sort.withOrderKeys(newKeys).recomputeLogicalProperties();
})),
RuleType.ADJUST_NULLABLE_ON_TOP_N.build(logicalTopN().then(topN -> {
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(topN);
List<OrderKey> newKeys = topN.getOrderKeys().stream()
.map(old -> old.withExpression(updateExpression(old.getExpr(), exprIdSlotMap)))
.collect(ImmutableList.toImmutableList());
return topN.withOrderKeys(newKeys).recomputeLogicalProperties();
}))
);
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
return plan.accept(this, null);
}
@Override
public Plan visit(Plan plan, Void context) {
LogicalPlan logicalPlan = (LogicalPlan) super.visit(plan, context);
return logicalPlan.recomputeLogicalProperties();
}
@Override
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) {
aggregate = (LogicalAggregate<? extends Plan>) super.visit(aggregate, context);
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(aggregate);
List<NamedExpression> newOutputs
= updateExpressions(aggregate.getOutputExpressions(), exprIdSlotMap);
List<Expression> newGroupExpressions
= updateExpressions(aggregate.getGroupByExpressions(), exprIdSlotMap);
return aggregate.withGroupByAndOutput(newGroupExpressions, newOutputs);
}
@Override
public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, Void context) {
filter = (LogicalFilter<? extends Plan>) super.visit(filter, context);
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(filter);
Set<Expression> conjuncts = updateExpressions(filter.getConjuncts(), exprIdSlotMap);
return filter.withConjuncts(conjuncts).recomputeLogicalProperties();
}
@Override
public Plan visitLogicalGenerate(LogicalGenerate<? extends Plan> generate, Void context) {
generate = (LogicalGenerate<? extends Plan>) super.visit(generate, context);
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(generate);
List<Function> newGenerators = updateExpressions(generate.getGenerators(), exprIdSlotMap);
return generate.withGenerators(newGenerators).recomputeLogicalProperties();
}
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) {
join = (LogicalJoin<? extends Plan, ? extends Plan>) super.visit(join, context);
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(join);
List<Expression> hashConjuncts = updateExpressions(join.getHashJoinConjuncts(), exprIdSlotMap);
List<Expression> otherConjuncts = updateExpressions(join.getOtherJoinConjuncts(), exprIdSlotMap);
return join.withJoinConjuncts(hashConjuncts, otherConjuncts).recomputeLogicalProperties();
}
@Override
public Plan visitLogicalProject(LogicalProject<? extends Plan> project, Void context) {
project = (LogicalProject<? extends Plan>) super.visit(project, context);
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(project);
List<NamedExpression> newProjects = updateExpressions(project.getProjects(), exprIdSlotMap);
return project.withProjects(newProjects);
}
@Override
public Plan visitLogicalRepeat(LogicalRepeat<? extends Plan> repeat, Void context) {
repeat = (LogicalRepeat<? extends Plan>) super.visit(repeat, context);
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(repeat);
List<NamedExpression> newOutputs = updateExpressions(repeat.getOutputExpressions(), exprIdSlotMap);
List<List<Expression>> newGroupingSets = repeat.getGroupingSets().stream()
.map(l -> updateExpressions(l, exprIdSlotMap))
.collect(ImmutableList.toImmutableList());
return repeat.withGroupSetsAndOutput(newGroupingSets, newOutputs).recomputeLogicalProperties();
}
@Override
public Plan visitLogicalSetOperation(LogicalSetOperation setOperation, Void context) {
setOperation = (LogicalSetOperation) super.visit(setOperation, context);
List<Boolean> inputNullable = setOperation.child(0).getOutput().stream()
.map(ExpressionTrait::nullable).collect(Collectors.toList());
for (int i = 1; i < setOperation.arity(); i++) {
List<Slot> childOutput = setOperation.getChildOutput(i);
for (int j = 0; j < childOutput.size(); j++) {
if (childOutput.get(j).nullable()) {
inputNullable.set(j, true);
}
}
}
List<NamedExpression> outputs = setOperation.getOutputs();
List<NamedExpression> newOutputs = Lists.newArrayListWithCapacity(outputs.size());
for (int i = 0; i < inputNullable.size(); i++) {
SlotReference slotReference = (SlotReference) outputs.get(i);
if (inputNullable.get(i)) {
slotReference = slotReference.withNullable(true);
}
newOutputs.add(slotReference);
}
return setOperation.withNewOutputs(newOutputs).recomputeLogicalProperties();
}
@Override
public Plan visitLogicalSort(LogicalSort<? extends Plan> sort, Void context) {
sort = (LogicalSort<? extends Plan>) super.visit(sort, context);
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(sort);
List<OrderKey> newKeys = sort.getOrderKeys().stream()
.map(old -> old.withExpression(updateExpression(old.getExpr(), exprIdSlotMap)))
.collect(ImmutableList.toImmutableList());
return sort.withOrderKeys(newKeys).recomputeLogicalProperties();
}
@Override
public Plan visitLogicalTopN(LogicalTopN<? extends Plan> topN, Void context) {
topN = (LogicalTopN<? extends Plan>) super.visit(topN, context);
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(topN);
List<OrderKey> newKeys = topN.getOrderKeys().stream()
.map(old -> old.withExpression(updateExpression(old.getExpr(), exprIdSlotMap)))
.collect(ImmutableList.toImmutableList());
return topN.withOrderKeys(newKeys).recomputeLogicalProperties();
}
@Override
public Plan visitLogicalWindow(LogicalWindow<? extends Plan> window, Void context) {
window = (LogicalWindow<? extends Plan>) super.visit(window, context);
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(window);
List<NamedExpression> windowExpressions =
updateExpressions(window.getWindowExpressions(), exprIdSlotMap);
return window.withExpression(windowExpressions, window.child());
}
private <T extends Expression> T updateExpression(T input, Map<ExprId, Slot> exprIdSlotMap) {
@ -153,7 +201,11 @@ public class AdjustNullable implements RewriteRuleFactory {
@Override
public Expression visitSlotReference(SlotReference slotReference, Map<ExprId, Slot> context) {
return context.getOrDefault(slotReference.getExprId(), slotReference);
if (context.containsKey(slotReference.getExprId())) {
return slotReference.withNullable(context.get(slotReference.getExprId()).nullable());
} else {
return slotReference;
}
}
}
}

View File

@ -111,19 +111,23 @@ public class LogicalFilter<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T
return visitor.visitLogicalFilter(this, context);
}
public LogicalFilter<Plan> withConjuncts(Set<Expression> conjuncts) {
return new LogicalFilter<>(conjuncts, Optional.empty(), Optional.of(getLogicalProperties()), child());
}
@Override
public LogicalUnary<Plan> withChildren(List<Plan> children) {
public LogicalFilter<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new LogicalFilter<>(conjuncts, children.get(0));
}
@Override
public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
public LogicalFilter<Plan> withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalFilter<>(conjuncts, groupExpression, Optional.of(getLogicalProperties()), child());
}
@Override
public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
public LogicalFilter<Plan> withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
return new LogicalFilter<>(conjuncts, Optional.empty(), logicalProperties, child());
}
}

View File

@ -76,8 +76,14 @@ public class LogicalWindow<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T
return windowExpressions;
}
public LogicalWindow withChecked(List<NamedExpression> windowExpressions, Plan child) {
return new LogicalWindow(windowExpressions, true, Optional.empty(), Optional.empty(), child);
public LogicalWindow<Plan> withExpression(List<NamedExpression> windowExpressions, Plan child) {
return new LogicalWindow<>(windowExpressions, isChecked, Optional.empty(),
Optional.empty(), child);
}
public LogicalWindow<Plan> withChecked(List<NamedExpression> windowExpressions, Plan child) {
return new LogicalWindow<>(windowExpressions, true, Optional.empty(),
Optional.of(getLogicalProperties()), child);
}
@Override

View File

@ -227,9 +227,8 @@ public abstract class PlanVisitor<R, C> {
return visit(having, context);
}
public R visitLogicalSetOperation(
LogicalSetOperation logicalSetOperation, C context) {
return visit(logicalSetOperation, context);
public R visitLogicalSetOperation(LogicalSetOperation setOperation, C context) {
return visit(setOperation, context);
}
public R visitLogicalUnion(LogicalUnion union, C context) {

View File

@ -31,6 +31,25 @@ suite("window_function") {
"replication_allocation" = "tag.location.default: 1"
);
"""
sql """
create table adj_nullable_1 (
c1 int,
c2 int,
c3 int
) distributed by hash(c1)
properties('replication_num'='1');
"""
sql """
create table adj_nullable_2 (
c4 int not null,
c5 int not null,
c6 int not null
) distributed by hash(c4)
properties('replication_num'='1');
"""
sql """INSERT INTO window_test VALUES(1, 1, 1)"""
sql """INSERT INTO window_test VALUES(1, 2, 1)"""
sql """INSERT INTO window_test VALUES(1, 3, 1)"""
@ -46,6 +65,9 @@ suite("window_function") {
sql """INSERT INTO window_test VALUES(1, null, 3)"""
sql """INSERT INTO window_test VALUES(2, null, 3)"""
sql """insert into adj_nullable_1 values(1, 1, 1);"""
sql """insert into adj_nullable_2 values(1, 1, 1);"""
sql "SET enable_fallback_to_original_planner=false"
order_qt_empty_over "SELECT rank() over() FROM window_test"
@ -148,4 +170,15 @@ suite("window_function") {
from window_test
group by c1, c2
"""
// test adjust nullable on window
sql """
select
count(c1) over (partition by c4),
coalesce(c5, sum(c2) over (partition by c3))
from
adj_nullable_1
left join adj_nullable_2 on c1 = c4
where c6 is not null;
"""
}