|
|
|
|
@ -17,10 +17,8 @@
|
|
|
|
|
|
|
|
|
|
package org.apache.doris.nereids.rules.rewrite.logical;
|
|
|
|
|
|
|
|
|
|
import org.apache.doris.nereids.jobs.JobContext;
|
|
|
|
|
import org.apache.doris.nereids.properties.OrderKey;
|
|
|
|
|
import org.apache.doris.nereids.rules.Rule;
|
|
|
|
|
import org.apache.doris.nereids.rules.RuleType;
|
|
|
|
|
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
|
|
|
|
|
import org.apache.doris.nereids.trees.expressions.ExprId;
|
|
|
|
|
import org.apache.doris.nereids.trees.expressions.Expression;
|
|
|
|
|
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
|
|
|
|
@ -30,8 +28,19 @@ import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
|
|
|
|
|
import org.apache.doris.nereids.trees.expressions.functions.Function;
|
|
|
|
|
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.Plan;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
|
|
|
|
|
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
|
|
|
|
|
|
|
|
|
|
import com.google.common.collect.ImmutableList;
|
|
|
|
|
import com.google.common.collect.ImmutableSet;
|
|
|
|
|
@ -46,87 +55,126 @@ import java.util.stream.Collectors;
|
|
|
|
|
* because some rule could change output's nullable.
|
|
|
|
|
* So, we need add a rule to adjust all expression's nullable attribute after rewrite.
|
|
|
|
|
*/
|
|
|
|
|
public class AdjustNullable implements RewriteRuleFactory {
|
|
|
|
|
public class AdjustNullable extends DefaultPlanRewriter<Void> implements CustomRewriter {
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public List<Rule> buildRules() {
|
|
|
|
|
return ImmutableList.of(
|
|
|
|
|
RuleType.ADJUST_NULLABLE_ON_AGGREGATE.build(logicalAggregate().then(aggregate -> {
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(aggregate);
|
|
|
|
|
List<NamedExpression> newOutputs
|
|
|
|
|
= updateExpressions(aggregate.getOutputExpressions(), exprIdSlotMap);
|
|
|
|
|
List<Expression> newGroupExpressions
|
|
|
|
|
= updateExpressions(aggregate.getGroupByExpressions(), exprIdSlotMap);
|
|
|
|
|
return aggregate.withGroupByAndOutput(newGroupExpressions, newOutputs).recomputeLogicalProperties();
|
|
|
|
|
})),
|
|
|
|
|
RuleType.ADJUST_NULLABLE_ON_ASSERT_NUM_ROWS.build(logicalAssertNumRows().then(
|
|
|
|
|
LogicalPlan::recomputeLogicalProperties)),
|
|
|
|
|
RuleType.ADJUST_NULLABLE_ON_FILTER.build(logicalFilter().then(filter -> {
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(filter);
|
|
|
|
|
Set<Expression> conjuncts = updateExpressions(filter.getConjuncts(), exprIdSlotMap);
|
|
|
|
|
return new LogicalFilter<>(conjuncts, filter.child());
|
|
|
|
|
})),
|
|
|
|
|
RuleType.ADJUST_NULLABLE_ON_GENERATE.build(logicalGenerate().then(generate -> {
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(generate);
|
|
|
|
|
List<Function> newGenerators = updateExpressions(generate.getGenerators(), exprIdSlotMap);
|
|
|
|
|
return generate.withGenerators(newGenerators).recomputeLogicalProperties();
|
|
|
|
|
})),
|
|
|
|
|
RuleType.ADJUST_NULLABLE_ON_JOIN.build(logicalJoin().then(join -> {
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(join);
|
|
|
|
|
List<Expression> hashConjuncts = updateExpressions(join.getHashJoinConjuncts(), exprIdSlotMap);
|
|
|
|
|
List<Expression> otherConjuncts = updateExpressions(join.getOtherJoinConjuncts(), exprIdSlotMap);
|
|
|
|
|
return join.withJoinConjuncts(hashConjuncts, otherConjuncts).recomputeLogicalProperties();
|
|
|
|
|
})),
|
|
|
|
|
RuleType.ADJUST_NULLABLE_ON_LIMIT.build(logicalLimit().then(LogicalPlan::recomputeLogicalProperties)),
|
|
|
|
|
RuleType.ADJUST_NULLABLE_ON_PROJECT.build(logicalProject().then(project -> {
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(project);
|
|
|
|
|
List<NamedExpression> newProjects = updateExpressions(project.getProjects(), exprIdSlotMap);
|
|
|
|
|
return project.withProjects(newProjects);
|
|
|
|
|
})),
|
|
|
|
|
RuleType.ADJUST_NULLABLE_ON_REPEAT.build(logicalRepeat().then(repeat -> {
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(repeat);
|
|
|
|
|
List<NamedExpression> newOutputs = updateExpressions(repeat.getOutputExpressions(), exprIdSlotMap);
|
|
|
|
|
List<List<Expression>> newGroupingSets = repeat.getGroupingSets().stream()
|
|
|
|
|
.map(l -> updateExpressions(l, exprIdSlotMap))
|
|
|
|
|
.collect(ImmutableList.toImmutableList());
|
|
|
|
|
return repeat.withGroupSetsAndOutput(newGroupingSets, newOutputs).recomputeLogicalProperties();
|
|
|
|
|
})),
|
|
|
|
|
RuleType.ADJUST_NULLABLE_ON_SET_OPERATION.build(logicalSetOperation().then(setOperation -> {
|
|
|
|
|
List<Boolean> inputNullable = setOperation.child(0).getOutput().stream()
|
|
|
|
|
.map(ExpressionTrait::nullable).collect(Collectors.toList());
|
|
|
|
|
for (int i = 1; i < setOperation.arity(); i++) {
|
|
|
|
|
List<Slot> childOutput = setOperation.getChildOutput(i);
|
|
|
|
|
for (int j = 0; j < childOutput.size(); j++) {
|
|
|
|
|
if (childOutput.get(j).nullable()) {
|
|
|
|
|
inputNullable.set(j, true);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
List<NamedExpression> outputs = setOperation.getOutputs();
|
|
|
|
|
List<NamedExpression> newOutputs = Lists.newArrayListWithCapacity(outputs.size());
|
|
|
|
|
for (int i = 0; i < inputNullable.size(); i++) {
|
|
|
|
|
SlotReference slotReference = (SlotReference) outputs.get(i);
|
|
|
|
|
if (inputNullable.get(i)) {
|
|
|
|
|
slotReference = slotReference.withNullable(true);
|
|
|
|
|
}
|
|
|
|
|
newOutputs.add(slotReference);
|
|
|
|
|
}
|
|
|
|
|
return setOperation.withNewOutputs(newOutputs).recomputeLogicalProperties();
|
|
|
|
|
})),
|
|
|
|
|
RuleType.ADJUST_NULLABLE_ON_SORT.build(logicalSort().then(sort -> {
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(sort);
|
|
|
|
|
List<OrderKey> newKeys = sort.getOrderKeys().stream()
|
|
|
|
|
.map(old -> old.withExpression(updateExpression(old.getExpr(), exprIdSlotMap)))
|
|
|
|
|
.collect(ImmutableList.toImmutableList());
|
|
|
|
|
return sort.withOrderKeys(newKeys).recomputeLogicalProperties();
|
|
|
|
|
})),
|
|
|
|
|
RuleType.ADJUST_NULLABLE_ON_TOP_N.build(logicalTopN().then(topN -> {
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(topN);
|
|
|
|
|
List<OrderKey> newKeys = topN.getOrderKeys().stream()
|
|
|
|
|
.map(old -> old.withExpression(updateExpression(old.getExpr(), exprIdSlotMap)))
|
|
|
|
|
.collect(ImmutableList.toImmutableList());
|
|
|
|
|
return topN.withOrderKeys(newKeys).recomputeLogicalProperties();
|
|
|
|
|
}))
|
|
|
|
|
);
|
|
|
|
|
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
|
|
|
|
|
return plan.accept(this, null);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Plan visit(Plan plan, Void context) {
|
|
|
|
|
LogicalPlan logicalPlan = (LogicalPlan) super.visit(plan, context);
|
|
|
|
|
return logicalPlan.recomputeLogicalProperties();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) {
|
|
|
|
|
aggregate = (LogicalAggregate<? extends Plan>) super.visit(aggregate, context);
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(aggregate);
|
|
|
|
|
List<NamedExpression> newOutputs
|
|
|
|
|
= updateExpressions(aggregate.getOutputExpressions(), exprIdSlotMap);
|
|
|
|
|
List<Expression> newGroupExpressions
|
|
|
|
|
= updateExpressions(aggregate.getGroupByExpressions(), exprIdSlotMap);
|
|
|
|
|
return aggregate.withGroupByAndOutput(newGroupExpressions, newOutputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, Void context) {
|
|
|
|
|
filter = (LogicalFilter<? extends Plan>) super.visit(filter, context);
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(filter);
|
|
|
|
|
Set<Expression> conjuncts = updateExpressions(filter.getConjuncts(), exprIdSlotMap);
|
|
|
|
|
return filter.withConjuncts(conjuncts).recomputeLogicalProperties();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Plan visitLogicalGenerate(LogicalGenerate<? extends Plan> generate, Void context) {
|
|
|
|
|
generate = (LogicalGenerate<? extends Plan>) super.visit(generate, context);
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(generate);
|
|
|
|
|
List<Function> newGenerators = updateExpressions(generate.getGenerators(), exprIdSlotMap);
|
|
|
|
|
return generate.withGenerators(newGenerators).recomputeLogicalProperties();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) {
|
|
|
|
|
join = (LogicalJoin<? extends Plan, ? extends Plan>) super.visit(join, context);
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(join);
|
|
|
|
|
List<Expression> hashConjuncts = updateExpressions(join.getHashJoinConjuncts(), exprIdSlotMap);
|
|
|
|
|
List<Expression> otherConjuncts = updateExpressions(join.getOtherJoinConjuncts(), exprIdSlotMap);
|
|
|
|
|
return join.withJoinConjuncts(hashConjuncts, otherConjuncts).recomputeLogicalProperties();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Plan visitLogicalProject(LogicalProject<? extends Plan> project, Void context) {
|
|
|
|
|
project = (LogicalProject<? extends Plan>) super.visit(project, context);
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(project);
|
|
|
|
|
List<NamedExpression> newProjects = updateExpressions(project.getProjects(), exprIdSlotMap);
|
|
|
|
|
return project.withProjects(newProjects);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Plan visitLogicalRepeat(LogicalRepeat<? extends Plan> repeat, Void context) {
|
|
|
|
|
repeat = (LogicalRepeat<? extends Plan>) super.visit(repeat, context);
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(repeat);
|
|
|
|
|
List<NamedExpression> newOutputs = updateExpressions(repeat.getOutputExpressions(), exprIdSlotMap);
|
|
|
|
|
List<List<Expression>> newGroupingSets = repeat.getGroupingSets().stream()
|
|
|
|
|
.map(l -> updateExpressions(l, exprIdSlotMap))
|
|
|
|
|
.collect(ImmutableList.toImmutableList());
|
|
|
|
|
return repeat.withGroupSetsAndOutput(newGroupingSets, newOutputs).recomputeLogicalProperties();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Plan visitLogicalSetOperation(LogicalSetOperation setOperation, Void context) {
|
|
|
|
|
setOperation = (LogicalSetOperation) super.visit(setOperation, context);
|
|
|
|
|
List<Boolean> inputNullable = setOperation.child(0).getOutput().stream()
|
|
|
|
|
.map(ExpressionTrait::nullable).collect(Collectors.toList());
|
|
|
|
|
for (int i = 1; i < setOperation.arity(); i++) {
|
|
|
|
|
List<Slot> childOutput = setOperation.getChildOutput(i);
|
|
|
|
|
for (int j = 0; j < childOutput.size(); j++) {
|
|
|
|
|
if (childOutput.get(j).nullable()) {
|
|
|
|
|
inputNullable.set(j, true);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
List<NamedExpression> outputs = setOperation.getOutputs();
|
|
|
|
|
List<NamedExpression> newOutputs = Lists.newArrayListWithCapacity(outputs.size());
|
|
|
|
|
for (int i = 0; i < inputNullable.size(); i++) {
|
|
|
|
|
SlotReference slotReference = (SlotReference) outputs.get(i);
|
|
|
|
|
if (inputNullable.get(i)) {
|
|
|
|
|
slotReference = slotReference.withNullable(true);
|
|
|
|
|
}
|
|
|
|
|
newOutputs.add(slotReference);
|
|
|
|
|
}
|
|
|
|
|
return setOperation.withNewOutputs(newOutputs).recomputeLogicalProperties();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Plan visitLogicalSort(LogicalSort<? extends Plan> sort, Void context) {
|
|
|
|
|
sort = (LogicalSort<? extends Plan>) super.visit(sort, context);
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(sort);
|
|
|
|
|
List<OrderKey> newKeys = sort.getOrderKeys().stream()
|
|
|
|
|
.map(old -> old.withExpression(updateExpression(old.getExpr(), exprIdSlotMap)))
|
|
|
|
|
.collect(ImmutableList.toImmutableList());
|
|
|
|
|
return sort.withOrderKeys(newKeys).recomputeLogicalProperties();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Plan visitLogicalTopN(LogicalTopN<? extends Plan> topN, Void context) {
|
|
|
|
|
topN = (LogicalTopN<? extends Plan>) super.visit(topN, context);
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(topN);
|
|
|
|
|
List<OrderKey> newKeys = topN.getOrderKeys().stream()
|
|
|
|
|
.map(old -> old.withExpression(updateExpression(old.getExpr(), exprIdSlotMap)))
|
|
|
|
|
.collect(ImmutableList.toImmutableList());
|
|
|
|
|
return topN.withOrderKeys(newKeys).recomputeLogicalProperties();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Plan visitLogicalWindow(LogicalWindow<? extends Plan> window, Void context) {
|
|
|
|
|
window = (LogicalWindow<? extends Plan>) super.visit(window, context);
|
|
|
|
|
Map<ExprId, Slot> exprIdSlotMap = collectChildrenOutputMap(window);
|
|
|
|
|
List<NamedExpression> windowExpressions =
|
|
|
|
|
updateExpressions(window.getWindowExpressions(), exprIdSlotMap);
|
|
|
|
|
return window.withExpression(windowExpressions, window.child());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private <T extends Expression> T updateExpression(T input, Map<ExprId, Slot> exprIdSlotMap) {
|
|
|
|
|
@ -153,7 +201,11 @@ public class AdjustNullable implements RewriteRuleFactory {
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Expression visitSlotReference(SlotReference slotReference, Map<ExprId, Slot> context) {
|
|
|
|
|
return context.getOrDefault(slotReference.getExprId(), slotReference);
|
|
|
|
|
if (context.containsKey(slotReference.getExprId())) {
|
|
|
|
|
return slotReference.withNullable(context.get(slotReference.getExprId()).nullable());
|
|
|
|
|
} else {
|
|
|
|
|
return slotReference;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|