diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java index fb1769f7c0..a7d958199b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java @@ -53,6 +53,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.ExtractAndNormalizeWindowE import org.apache.doris.nereids.rules.rewrite.logical.ExtractFilterFromCrossJoin; import org.apache.doris.nereids.rules.rewrite.logical.ExtractSingleTableExpressionFromDisjunction; import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin; +import org.apache.doris.nereids.rules.rewrite.logical.InferAggNotNull; import org.apache.doris.nereids.rules.rewrite.logical.InferFilterNotNull; import org.apache.doris.nereids.rules.rewrite.logical.InferJoinNotNull; import org.apache.doris.nereids.rules.rewrite.logical.InferPredicates; @@ -161,6 +162,7 @@ public class NereidsRewriter extends BatchRewriteJob { topic("Rewrite join", // infer not null filter, then push down filter, and then reorder join(cross join to inner join) topDown( + new InferAggNotNull(), new InferFilterNotNull(), new InferJoinNotNull() ), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 3113a120c2..0f97e67291 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -100,6 +100,7 @@ public enum RuleType { ELIMINATE_ORDER_BY_CONSTANT(RuleTypeClass.REWRITE), ELIMINATE_HINT(RuleTypeClass.REWRITE), INFER_PREDICATES(RuleTypeClass.REWRITE), + INFER_AGG_NOT_NULL(RuleTypeClass.REWRITE), INFER_FILTER_NOT_NULL(RuleTypeClass.REWRITE), INFER_JOIN_NOT_NULL(RuleTypeClass.REWRITE), // subquery analyze diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index 390d9e1fe5..a481aa6559 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -189,7 +189,7 @@ public class BindExpression implements AnalysisRuleFactory { for (int i = 0; i < size; i++) { hashEqExpr.add(new EqualTo(leftSlots.get(i), rightSlots.get(i))); } - return lj.withHashJoinConjuncts(hashEqExpr); + return lj.withJoinConjuncts(hashEqExpr, lj.getOtherJoinConjuncts()); }) ), RuleType.BINDING_JOIN_SLOT.build( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateNotNull.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateNotNull.java index 039e32ea5e..518e53d13e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateNotNull.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateNotNull.java @@ -17,9 +17,10 @@ package org.apache.doris.nereids.rules.rewrite.logical; +import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.Not; @@ -27,14 +28,17 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; import org.apache.doris.nereids.util.TypeUtils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import java.util.Collection; import java.util.List; import java.util.Optional; import java.util.Set; @@ -45,44 +49,60 @@ import java.util.stream.Collectors; * - redundant `is not null` predicate like `a > 0 and a is not null` -> `a > 0` * - `is not null` predicate is generated by `InferFilterNotNull` */ -public class EliminateNotNull extends OneRewriteRuleFactory { +public class EliminateNotNull implements RewriteRuleFactory { @Override - public Rule build() { - return logicalFilter() - .when(filter -> filter.getConjuncts().stream().anyMatch(expr -> expr.isGeneratedIsNotNull)) - .thenApply(ctx -> { - LogicalFilter filter = ctx.root; - // Progress Example: `id > 0 and id is not null and name is not null(generated)` - // predicatesNotContainIsNotNull: `id > 0` - // predicatesNotContainIsNotNull infer nonNullable slots: `id` - // slotsFromIsNotNull: `id`, `name` - // remove `name` (it's generated), remove `id` (because `id > 0` already contains it) - Set predicatesNotContainIsNotNull = Sets.newHashSet(); - List slotsFromIsNotNull = Lists.newArrayList(); - filter.getConjuncts().stream() - .filter(expr -> !expr.isGeneratedIsNotNull) // remove generated `is not null` - .forEach(expr -> { - Optional notNullSlot = TypeUtils.isNotNull(expr); - if (notNullSlot.isPresent()) { - slotsFromIsNotNull.add(notNullSlot.get()); - } else { - predicatesNotContainIsNotNull.add(expr); - } - }); - Set inferNonNotSlots = ExpressionUtils.inferNotNullSlots( - predicatesNotContainIsNotNull, ctx.cascadesContext); + public List buildRules() { + return ImmutableList.of( + logicalFilter() + .when(filter -> filter.getConjuncts().stream().anyMatch(expr -> expr.isGeneratedIsNotNull)) + .thenApply(ctx -> { + LogicalFilter filter = ctx.root; + List predicates = removeGeneratedNotNull(filter.getConjuncts(), + ctx.cascadesContext); + return PlanUtils.filterOrSelf(ImmutableSet.copyOf(predicates), filter.child()); + }).toRule(RuleType.ELIMINATE_NOT_NULL), + innerLogicalJoin() + .when(join -> join.getOtherJoinConjuncts().stream().anyMatch(expr -> expr.isGeneratedIsNotNull)) + .thenApply(ctx -> { + LogicalJoin join = ctx.root; + List newOtherJoinConjuncts = removeGeneratedNotNull( + join.getOtherJoinConjuncts(), ctx.cascadesContext); + return join.withJoinConjuncts(join.getHashJoinConjuncts(), newOtherJoinConjuncts); + }) + .toRule(RuleType.ELIMINATE_NOT_NULL) + ); + } - Set keepIsNotNull = slotsFromIsNotNull.stream() - .filter(ExpressionTrait::nullable) - .filter(slot -> !inferNonNotSlots.contains(slot)) - .map(slot -> new Not(new IsNull(slot))).collect(Collectors.toSet()); + private List removeGeneratedNotNull(Collection exprs, CascadesContext ctx) { + // Example: `id > 0 and id is not null and name is not null(generated)` + // predicatesNotContainIsNotNull: `id > 0` + // predicatesNotContainIsNotNull infer nonNullable slots: `id` + // slotsFromIsNotNull: `id`, `name` + // remove `name` (it's generated), remove `id` (because `id > 0` already contains it) + Set predicatesNotContainIsNotNull = Sets.newHashSet(); + List slotsFromIsNotNull = Lists.newArrayList(); + exprs.stream() + .filter(expr -> !expr.isGeneratedIsNotNull) // remove generated `is not null` + .forEach(expr -> { + Optional notNullSlot = TypeUtils.isNotNull(expr); + if (notNullSlot.isPresent()) { + slotsFromIsNotNull.add(notNullSlot.get()); + } else { + predicatesNotContainIsNotNull.add(expr); + } + }); + Set inferNonNotSlots = ExpressionUtils.inferNotNullSlots( + predicatesNotContainIsNotNull, ctx); - // merge predicatesNotContainIsNotNull and keepIsNotNull into a new ImmutableSet - Set newPredicates = ImmutableSet.builder() - .addAll(predicatesNotContainIsNotNull) - .addAll(keepIsNotNull) - .build(); - return PlanUtils.filterOrSelf(newPredicates, filter.child()); - }).toRule(RuleType.ELIMINATE_NOT_NULL); + Set keepIsNotNull = slotsFromIsNotNull.stream() + .filter(ExpressionTrait::nullable) + .filter(slot -> !inferNonNotSlots.contains(slot)) + .map(slot -> new Not(new IsNull(slot))).collect(Collectors.toSet()); + + // merge predicatesNotContainIsNotNull and keepIsNotNull into a new List + return ImmutableList.builder() + .addAll(predicatesNotContainIsNotNull) + .addAll(keepIsNotNull) + .build(); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InferAggNotNull.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InferAggNotNull.java new file mode 100644 index 0000000000..4b9833f07f --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InferAggNotNull.java @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.Filter; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.PlanUtils; + +import java.util.Set; +import java.util.stream.Collectors; + +/** + * InferNotNull from Agg count(distinct); + */ +public class InferAggNotNull extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalAggregate() + .when(agg -> agg.getGroupByExpressions().size() == 0) + .when(agg -> agg.getAggregateFunctions().size() == 1) + .when(agg -> { + Set funcs = agg.getAggregateFunctions(); + return funcs.stream().allMatch(f -> f instanceof Count) + || funcs.stream().allMatch(f -> f instanceof Avg) + || funcs.stream().allMatch(f -> f instanceof Sum) + || funcs.stream().allMatch(f -> f instanceof Max) + || funcs.stream().allMatch(f -> f instanceof Min); + }).thenApply(ctx -> { + LogicalAggregate agg = ctx.root; + Set exprs = agg.getAggregateFunctions().stream().flatMap(f -> f.children().stream()) + .collect(Collectors.toSet()); + Set isNotNull = ExpressionUtils.inferNotNull(exprs, ctx.cascadesContext); + if (isNotNull.size() == 0 || (agg.child() instanceof Filter && isNotNull.equals( + ((Filter) agg.child()).getConjuncts()))) { + return null; + } + return agg.withChildren(PlanUtils.filter(isNotNull, agg.child()).get()); + }).toRule(RuleType.INFER_AGG_NOT_NULL); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java index 12c0ab13c2..225e0fcb28 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java @@ -266,16 +266,11 @@ public class LogicalJoin withChildrenNoContext(Plan left, Plan right) { return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference, left, right); } - public LogicalJoin withHashJoinConjuncts(List hashJoinConjuncts) { - return new LogicalJoin<>(joinType, hashJoinConjuncts, this.otherJoinConjuncts, hint, markJoinSlotReference, - left(), right()); - } - public LogicalJoin withJoinConjuncts( List hashJoinConjuncts, List otherJoinConjuncts) { return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/InferAggNotNullTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/InferAggNotNullTest.java new file mode 100644 index 0000000000..6c9be21f6c --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/InferAggNotNullTest.java @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +class InferAggNotNullTest implements MemoPatternMatchSupported { + private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + + @Test + void testInfer() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .aggGroupUsingIndex(ImmutableList.of(), + ImmutableList.of(new Alias(new Count(true, scan1.getOutput().get(1)), "dnt"))) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new InferAggNotNull()) + .matches( + logicalAggregate( + logicalFilter().when(filter -> filter.getConjuncts().stream().allMatch(e -> e.isGeneratedIsNotNull)) + ) + ); + } + + @Test + void testCountStar() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .aggGroupUsingIndex(ImmutableList.of(), ImmutableList.of(new Alias(new Count(), "dnt"))) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new InferAggNotNull()) + .printlnTree() + .matches( + logicalAggregate( + logicalOlapScan() + ) + ); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java index 010a920023..1d412a3d5d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java @@ -88,4 +88,20 @@ public class InferTest extends SqlTestBase { ) ); } + + @Test + void aggEliminateOuterJoin() { + String sql = "select count(T2.score) from T1 left Join T2 on T1.id = T2.id"; + + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalAggregate( + logicalProject( + innerLogicalJoin() + ) + ) + ); + } }