diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 60e31fc904..63fb58beae 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -699,6 +699,10 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor ExpressionTranslator.translate(e, context)).forEach(crossJoinNode::addConjunct); + } return leftFragment; } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java index 3f0f8be01e..91ede3e80c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java @@ -32,6 +32,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin; import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown; import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate; import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition; +import org.apache.doris.nereids.rules.rewrite.logical.PushFilterInsideJoin; import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin; import com.google.common.collect.ImmutableList; @@ -66,6 +67,7 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob { .add(topDownBatch(ImmutableList.of(new ReorderJoin()))) .add(topDownBatch(ImmutableList.of(new ColumnPruning()))) .add(topDownBatch(RuleSet.PUSH_DOWN_JOIN_CONDITION_RULES, false)) + .add(topDownBatch(ImmutableList.of(PushFilterInsideJoin.INSTANCE))) .add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin()))) .add(topDownBatch(ImmutableList.of(new LimitPushDown()))) .add(topDownBatch(ImmutableList.of(new EliminateLimit()))) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java index d73a93e308..af878129c4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; +import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; import com.google.common.base.Preconditions; @@ -53,7 +54,9 @@ public class Validator extends PlanPostProcessor { Plan child = filter.child(); // Forbidden filter-project, we must make filter-project -> project-filter. - Preconditions.checkArgument(!(child instanceof PhysicalProject)); + Preconditions.checkState(!(child instanceof PhysicalProject)); + // Forbidden filter-cross join, because we put all filter on cross join into its other join condition. + Preconditions.checkState(!(child instanceof PhysicalNestedLoopJoin)); // Check filter is from child output. Set childOutputSet = child.getOutputSet(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index 8456b73d64..d88b81ffed 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -42,10 +42,10 @@ import org.apache.doris.nereids.rules.rewrite.logical.EliminateOuter; import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters; import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits; import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects; -import org.apache.doris.nereids.rules.rewrite.logical.PushDownExpressionsInHashCondition; -import org.apache.doris.nereids.rules.rewrite.logical.PushDownJoinOtherCondition; -import org.apache.doris.nereids.rules.rewrite.logical.PushPredicatesThroughJoin; +import org.apache.doris.nereids.rules.rewrite.logical.PushdownExpressionsInHashCondition; +import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughJoin; import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject; +import org.apache.doris.nereids.rules.rewrite.logical.PushdownJoinOtherCondition; import org.apache.doris.nereids.rules.rewrite.logical.PushdownProjectThroughLimit; import com.google.common.collect.ImmutableList; @@ -72,9 +72,9 @@ public class RuleSet { .build(); public static final List PUSH_DOWN_JOIN_CONDITION_RULES = ImmutableList.of( - new PushDownJoinOtherCondition(), - new PushPredicatesThroughJoin(), - new PushDownExpressionsInHashCondition(), + new PushdownJoinOtherCondition(), + new PushdownFilterThroughJoin(), + new PushdownExpressionsInHashCondition(), new PushdownProjectThroughLimit(), new PushdownFilterThroughProject(), EliminateOuter.INSTANCE, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 5dd74b5df6..2e85180097 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -78,11 +78,15 @@ public enum RuleType { IN_APPLY_TO_JOIN(RuleTypeClass.REWRITE), EXISTS_APPLY_TO_JOIN(RuleTypeClass.REWRITE), // predicate push down rules - PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE), - PUSH_DOWN_JOIN_OTHER_CONDITION(RuleTypeClass.REWRITE), - PUSH_DOWN_PREDICATE_THROUGH_LEFT_SEMI_JOIN(RuleTypeClass.REWRITE), - PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION(RuleTypeClass.REWRITE), - PUSH_DOWN_EXPRESSIONS_IN_HASH_CONDITIONS(RuleTypeClass.REWRITE), + PUSHDOWN_JOIN_OTHER_CONDITION(RuleTypeClass.REWRITE), + PUSHDOWN_PREDICATE_THROUGH_AGGREGATION(RuleTypeClass.REWRITE), + PUSHDOWN_EXPRESSIONS_IN_HASH_CONDITIONS(RuleTypeClass.REWRITE), + // Pushdown filter + PUSHDOWN_FILTER_THROUGH_JOIN(RuleTypeClass.REWRITE), + PUSHDOWN_FILTER_THROUGH_LEFT_SEMI_JOIN(RuleTypeClass.REWRITE), + PUSH_FILTER_INSIDE_JOIN(RuleTypeClass.REWRITE), + PUSHDOWN_FILTER_THROUGH_PROJET(RuleTypeClass.REWRITE), + PUSHDOWN_PROJECT_THROUGHT_LIMIT(RuleTypeClass.REWRITE), // column prune rules, COLUMN_PRUNE_AGGREGATION_CHILD(RuleTypeClass.REWRITE), COLUMN_PRUNE_FILTER_CHILD(RuleTypeClass.REWRITE), @@ -112,9 +116,6 @@ public enum RuleType { ROLLUP_WITH_OUT_AGG(RuleTypeClass.REWRITE), OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE), EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION(RuleTypeClass.REWRITE), - // Pushdown filter - PUSHDOWN_FILTER_THROUGH_PROJET(RuleTypeClass.REWRITE), - PUSHDOWN_PROJECT_THROUGHT_LIMIT(RuleTypeClass.REWRITE), REWRITE_SENTINEL(RuleTypeClass.REWRITE), // limit push down diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushFilterInsideJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushFilterInsideJoin.java new file mode 100644 index 0000000000..dc84e663c7 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushFilterInsideJoin.java @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.Lists; + +import java.util.List; + +/** + * Push the predicate in the LogicalFilter to the join children. + */ +public class PushFilterInsideJoin extends OneRewriteRuleFactory { + public static final PushFilterInsideJoin INSTANCE = new PushFilterInsideJoin(); + + @Override + public Rule build() { + return logicalFilter(logicalJoin()) + // TODO: current just handle cross/inner join. + .when(filter -> filter.child().getJoinType().isCrossJoin() + || filter.child().getJoinType().isInnerJoin()) + .then(filter -> { + List otherConditions = Lists.newArrayList(filter.getPredicates()); + LogicalJoin join = filter.child(); + join.getOtherJoinCondition().map(otherConditions::add); + return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(), + ExpressionUtils.optionalAnd(otherConditions), join.left(), join.right()); + }).toRule(RuleType.PUSH_FILTER_INSIDE_JOIN); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownExpressionsInHashCondition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownExpressionsInHashCondition.java similarity index 97% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownExpressionsInHashCondition.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownExpressionsInHashCondition.java index cdb307e3b2..f6e39ea32a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownExpressionsInHashCondition.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownExpressionsInHashCondition.java @@ -45,7 +45,7 @@ import java.util.stream.Collectors; /** * push down expression which is not slot reference */ -public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory { +public class PushdownExpressionsInHashCondition extends OneRewriteRuleFactory { /* * rewrite example: * join(t1.a + 1 = t2.b + 2) join(c = d) @@ -94,7 +94,7 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory { .collect(Collectors.toList())) .addAll(getOutput(plan, join)).build(), plan)) .collect(Collectors.toList())); - }).toRule(RuleType.PUSH_DOWN_EXPRESSIONS_IN_HASH_CONDITIONS); + }).toRule(RuleType.PUSHDOWN_EXPRESSIONS_IN_HASH_CONDITIONS); } private List getOutput(Plan plan, LogicalJoin join) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughAggregation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregation.java similarity index 96% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughAggregation.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregation.java index 9ed04483e2..afc36383f8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughAggregation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregation.java @@ -62,7 +62,7 @@ import java.util.Set; * */ -public class PushPredicateThroughAggregation extends OneRewriteRuleFactory { +public class PushdownFilterThroughAggregation extends OneRewriteRuleFactory { @Override public Rule build() { @@ -86,7 +86,7 @@ public class PushPredicateThroughAggregation extends OneRewriteRuleFactory { }); return pushDownPredicate(filter, aggregate, pushDownPredicates, filterPredicates); - }).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION); + }).toRule(RuleType.PUSHDOWN_PREDICATE_THROUGH_AGGREGATION); } private Plan pushDownPredicate(LogicalFilter filter, LogicalAggregate aggregate, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicatesThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughJoin.java similarity index 63% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicatesThroughJoin.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughJoin.java index 3cdfac4918..2b5cf9065b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicatesThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughJoin.java @@ -20,13 +20,12 @@ package org.apache.doris.nereids.rules.rewrite.logical; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; -import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; @@ -35,13 +34,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import java.util.List; -import java.util.Objects; import java.util.Set; /** * Push the predicate in the LogicalFilter to the join children. */ -public class PushPredicatesThroughJoin extends OneRewriteRuleFactory { +public class PushdownFilterThroughJoin extends OneRewriteRuleFactory { + public static final PushdownFilterThroughJoin INSTANCE = new PushdownFilterThroughJoin(); private static final ImmutableList COULD_PUSH_THROUGH_LEFT = ImmutableList.of( JoinType.INNER_JOIN, @@ -59,7 +58,7 @@ public class PushPredicatesThroughJoin extends OneRewriteRuleFactory { JoinType.CROSS_JOIN ); - private static final ImmutableList COULD_PUSH_EQUAL_TO = ImmutableList.of( + private static final ImmutableList COULD_PUSH_INSIDE = ImmutableList.of( JoinType.INNER_JOIN ); @@ -68,6 +67,9 @@ public class PushPredicatesThroughJoin extends OneRewriteRuleFactory { * select a.k1, b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5 * where a.k1 > 1 and b.k1 > 2 and a.k2 > b.k2 * + * TODO(jakevin): following graph is wrong, we should add a new rule to extract + * a.k2 > 2 and b.k2 > 5, and pushdown. + * * Logical plan tree: * project * | @@ -93,29 +95,28 @@ public class PushPredicatesThroughJoin extends OneRewriteRuleFactory { LogicalJoin join = filter.child(); - Expression filterPredicates = filter.getPredicates(); + List predicates = ExpressionUtils.extractConjunction(filter.getPredicates()); - List filterConditions = Lists.newArrayList(); + List filterPredicates = Lists.newArrayList(); List joinConditions = Lists.newArrayList(); Set leftInput = join.left().getOutputSet(); Set rightInput = join.right().getOutputSet(); - ExpressionUtils.extractConjunction(filterPredicates) - .forEach(predicate -> { - if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput)) - && COULD_PUSH_EQUAL_TO.contains(join.getJoinType())) { - joinConditions.add(predicate); - } else { - filterConditions.add(predicate); - } - }); + // TODO: predicate slotReference should be not nullable. + for (Expression predicate : predicates) { + if (convertJoinCondition(predicate, leftInput, rightInput, join.getJoinType())) { + joinConditions.add(predicate); + } else { + filterPredicates.add(predicate); + } + } List leftPredicates = Lists.newArrayList(); List rightPredicates = Lists.newArrayList(); - - for (Expression p : filterConditions) { - Set slots = p.getInputSlots(); + List remainingPredicates = Lists.newArrayList(); + for (Expression p : filterPredicates) { + Set slots = p.collect(SlotReference.class::isInstance); if (slots.isEmpty()) { leftPredicates.add(p); rightPredicates.add(p); @@ -123,50 +124,47 @@ public class PushPredicatesThroughJoin extends OneRewriteRuleFactory { } if (leftInput.containsAll(slots) && COULD_PUSH_THROUGH_LEFT.contains(join.getJoinType())) { leftPredicates.add(p); - } - if (rightInput.containsAll(slots) && COULD_PUSH_THROUGH_RIGHT.contains(join.getJoinType())) { + } else if (rightInput.containsAll(slots) && COULD_PUSH_THROUGH_RIGHT.contains(join.getJoinType())) { rightPredicates.add(p); + } else { + remainingPredicates.add(p); } } - filterConditions.removeAll(leftPredicates); - filterConditions.removeAll(rightPredicates); join.getOtherJoinCondition().map(joinConditions::add); - return PlanUtils.filterOrSelf(filterConditions, - pushDownPredicate(join, joinConditions, leftPredicates, rightPredicates)); - }).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_JOIN); + return PlanUtils.filterOrSelf(remainingPredicates, + new LogicalJoin<>(join.getJoinType(), + join.getHashJoinConjuncts(), + ExpressionUtils.optionalAnd(joinConditions), + PlanUtils.filterOrSelf(leftPredicates, join.left()), + PlanUtils.filterOrSelf(rightPredicates, join.right()))); + }).toRule(RuleType.PUSHDOWN_FILTER_THROUGH_JOIN); } - private Plan pushDownPredicate(LogicalJoin join, - List joinConditions, List leftPredicates, List rightPredicates) { - // todo expr should optimize again using expr rewrite - Plan leftPlan = PlanUtils.filterOrSelf(leftPredicates, join.left()); - Plan rightPlan = PlanUtils.filterOrSelf(rightPredicates, join.right()); - - return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(), - ExpressionUtils.optionalAnd(joinConditions), leftPlan, rightPlan); - } - - private Expression getJoinCondition(Expression predicate, Set leftOutputs, Set rightOutputs) { - if (!(predicate instanceof ComparisonPredicate)) { - return null; + private boolean convertJoinCondition(Expression predicate, Set leftOutputs, Set rightOutputs, + JoinType joinType) { + if (!COULD_PUSH_INSIDE.contains(joinType)) { + return false; + } + if (!(predicate instanceof EqualTo)) { + return false; } - ComparisonPredicate comparison = (ComparisonPredicate) predicate; + EqualTo equalTo = (EqualTo) predicate; - if (!(comparison instanceof EqualTo)) { - return null; + Set leftSlots = equalTo.left().collect(SlotReference.class::isInstance); + Set rightSlots = equalTo.right().collect(SlotReference.class::isInstance); + + if (leftSlots.size() == 0 || rightSlots.size() == 0) { + return false; } - Set leftSlots = comparison.left().getInputSlots(); - Set rightSlots = comparison.right().getInputSlots(); - if ((leftOutputs.containsAll(leftSlots) && rightOutputs.containsAll(rightSlots)) || (leftOutputs.containsAll(rightSlots) && rightOutputs.containsAll(leftSlots))) { - return predicate; + return true; } - return null; + return false; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherCondition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherCondition.java similarity index 96% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherCondition.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherCondition.java index 88744fbe65..52e60aa102 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherCondition.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherCondition.java @@ -37,7 +37,7 @@ import java.util.Set; /** * Push the other join conditions in LogicalJoin to children. */ -public class PushDownJoinOtherCondition extends OneRewriteRuleFactory { +public class PushdownJoinOtherCondition extends OneRewriteRuleFactory { private static final ImmutableList PUSH_DOWN_LEFT_VALID_TYPE = ImmutableList.of( JoinType.INNER_JOIN, JoinType.LEFT_SEMI_JOIN, @@ -90,7 +90,7 @@ public class PushDownJoinOtherCondition extends OneRewriteRuleFactory { return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(), ExpressionUtils.optionalAnd(otherConjuncts), left, right); - }).toRule(RuleType.PUSH_DOWN_JOIN_OTHER_CONDITION); + }).toRule(RuleType.PUSHDOWN_JOIN_OTHER_CONDITION); } private boolean allCoveredBy(Expression predicate, Set inputSlotSet) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java index 64f5fd8f58..326750d1cb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java @@ -37,10 +37,9 @@ public class SlotReference extends Slot { // TODO: we should distinguish the name is alias or column name, and the column name should contains // `cluster:db`.`table`.`column` private final String name; - private final List qualifier; private final DataType dataType; private final boolean nullable; - + private final List qualifier; private final Column column; public SlotReference(String name, DataType dataType) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/RuntimeFilterTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/RuntimeFilterTest.java index ce2674a854..7b0fe88971 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/RuntimeFilterTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/RuntimeFilterTest.java @@ -17,7 +17,7 @@ package org.apache.doris.nereids.postprocess; -import org.apache.doris.common.AnalysisException; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.datasets.ssb.SSBTestBase; import org.apache.doris.nereids.datasets.ssb.SSBUtils; import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator; @@ -28,6 +28,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter; import org.apache.doris.nereids.util.PlanChecker; +import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -44,99 +45,110 @@ public class RuntimeFilterTest extends SSBTestBase { } @Test - public void testGenerateRuntimeFilter() throws AnalysisException { + public void testGenerateRuntimeFilter() { String sql = "SELECT * FROM lineorder JOIN customer on c_custkey = lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 1 - && checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey")); + Assertions.assertEquals(1, filters.size()); + checkRuntimeFilterExprs(filters, ImmutableList.of( + Pair.of("c_custkey", "lo_custkey"))); } @Test - public void testGenerateRuntimeFilterByIllegalSrcExpr() throws AnalysisException { + public void testGenerateRuntimeFilterByIllegalSrcExpr() { String sql = "SELECT * FROM lineorder JOIN customer on c_custkey = c_custkey"; List filters = getRuntimeFilters(sql).get(); Assertions.assertEquals(0, filters.size()); } @Test - public void testComplexExpressionToRuntimeFilter() throws AnalysisException { + public void testComplexExpressionToRuntimeFilter() { String sql = "SELECT * FROM supplier JOIN customer on c_name = s_name and s_city = c_city and s_nation = c_nation"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 3 - && checkRuntimeFilterExprs(filters, "c_name", "s_name", "c_city", "s_city", "c_nation", "s_nation")); + Assertions.assertEquals(3, filters.size()); + checkRuntimeFilterExprs(filters, ImmutableList.of( + Pair.of("c_name", "s_name"), + Pair.of("c_city", "s_city"), + Pair.of("c_nation", "s_nation"))); } @Test - public void testNestedJoinGenerateRuntimeFilter() throws AnalysisException { + public void testNestedJoinGenerateRuntimeFilter() { String sql = SSBUtils.Q4_1; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 4 - && checkRuntimeFilterExprs(filters, "p_partkey", "lo_partkey", "s_suppkey", "lo_suppkey", - "c_custkey", "lo_custkey", "lo_orderdate", "d_datekey")); + Assertions.assertEquals(4, filters.size()); + checkRuntimeFilterExprs(filters, ImmutableList.of( + Pair.of("p_partkey", "lo_partkey"), Pair.of("s_suppkey", "lo_suppkey"), + Pair.of("c_custkey", "lo_custkey"), Pair.of("lo_orderdate", "d_datekey"))); } @Test - public void testSubTreeInUnsupportedJoinType() throws AnalysisException { + public void testSubTreeInUnsupportedJoinType() { String sql = "select c_custkey" + " from (select lo_custkey from lineorder inner join dates on lo_orderdate = d_datekey) a" + " left outer join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b" + " on b.c_custkey = a.lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 2 - && checkRuntimeFilterExprs(filters, "d_datekey", "lo_orderdate", "s_suppkey", "c_custkey")); + Assertions.assertEquals(2, filters.size()); + checkRuntimeFilterExprs(filters, ImmutableList.of( + Pair.of("d_datekey", "lo_orderdate"), Pair.of("s_suppkey", "c_custkey"))); } @Test - public void testPushDownEncounterUnsupportedJoinType() throws AnalysisException { + public void testPushDownEncounterUnsupportedJoinType() { String sql = "select c_custkey" + " from (select lo_custkey from lineorder left outer join dates on lo_orderdate = d_datekey) a" + " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b" + " on b.c_custkey = a.lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 1 - && checkRuntimeFilterExprs(filters, "s_suppkey", "c_custkey")); + Assertions.assertEquals(1, filters.size()); + checkRuntimeFilterExprs(filters, ImmutableList.of( + Pair.of("s_suppkey", "c_custkey"))); } @Test - public void testPushDownThroughAggNode() throws AnalysisException { + public void testPushDownThroughAggNode() { String sql = "select profit" + " from (select lo_custkey, sum(lo_revenue - lo_supplycost) as profit from lineorder inner join dates" + " on lo_orderdate = d_datekey group by lo_custkey) a" + " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b" + " on b.c_custkey = a.lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 3 - && checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey", "d_datekey", "lo_orderdate", - "s_suppkey", "c_custkey")); + Assertions.assertEquals(3, filters.size()); + checkRuntimeFilterExprs(filters, ImmutableList.of( + Pair.of("c_custkey", "lo_custkey"), Pair.of("d_datekey", "lo_orderdate"), + Pair.of("s_suppkey", "c_custkey"))); } @Test - public void testDoNotPushDownThroughAggFunction() throws AnalysisException { + public void testDoNotPushDownThroughAggFunction() { String sql = "select profit" + " from (select lo_custkey, sum(lo_revenue - lo_supplycost) as profit from lineorder inner join dates" + " on lo_orderdate = d_datekey group by lo_custkey) a" + " inner join (select sum(c_custkey) c_custkey from customer inner join supplier on c_custkey = s_suppkey group by s_suppkey) b" + " on b.c_custkey = a.lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 2 - && checkRuntimeFilterExprs(filters, "d_datekey", "lo_orderdate", "s_suppkey", "c_custkey")); + Assertions.assertEquals(2, filters.size()); + checkRuntimeFilterExprs(filters, ImmutableList.of( + Pair.of("d_datekey", "lo_orderdate"), Pair.of("s_suppkey", "c_custkey"))); } @Test - public void testCrossJoin() throws AnalysisException { + public void testCrossJoin() { String sql = "select c_custkey, lo_custkey from lineorder, customer where lo_custkey = c_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 1 - && checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey")); + Assertions.assertEquals(1, filters.size()); + checkRuntimeFilterExprs(filters, ImmutableList.of( + Pair.of("c_custkey", "lo_custkey"))); } @Test - public void testSubQueryAlias() throws AnalysisException { + public void testSubQueryAlias() { String sql = "select c_custkey, lo_custkey from lineorder l, customer c where c.c_custkey = l.lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 1 - && checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey")); + Assertions.assertEquals(1, filters.size()); + checkRuntimeFilterExprs(filters, ImmutableList.of( + Pair.of("c_custkey", "lo_custkey"))); } @Test @@ -165,35 +177,39 @@ public class RuntimeFilterTest extends SSBTestBase { + " on t1.p_partkey = t2.lo_partkey\n" + " order by t1.lo_custkey, t1.p_partkey, t2.s_suppkey, t2.c_custkey, t2.lo_orderkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 4 - && checkRuntimeFilterExprs(filters, "lo_partkey", "p_partkey", "lo_partkey", "p_partkey", - "c_region", "s_region", "lo_custkey", "c_custkey")); + Assertions.assertEquals(4, filters.size()); + checkRuntimeFilterExprs(filters, ImmutableList.of( + Pair.of("lo_partkey", "p_partkey"), Pair.of("lo_partkey", "p_partkey"), + Pair.of("c_region", "s_region"), Pair.of("lo_custkey", "c_custkey"))); } @Test - public void testPushDownThroughJoin() throws AnalysisException { + public void testPushDownThroughJoin() { String sql = "select c_custkey from (select c_custkey from (select lo_custkey from lineorder inner join dates" + " on lo_orderdate = d_datekey) a" + " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b" + " on b.c_custkey = a.lo_custkey) c inner join (select lo_custkey from customer inner join lineorder" + " on c_custkey = lo_custkey) d on c.c_custkey = d.lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 5 - && checkRuntimeFilterExprs(filters, "lo_custkey", "c_custkey", "c_custkey", "lo_custkey", - "d_datekey", "lo_orderdate", "s_suppkey", "c_custkey", "lo_custkey", "c_custkey")); + Assertions.assertEquals(5, filters.size()); + checkRuntimeFilterExprs(filters, ImmutableList.of( + Pair.of("lo_custkey", "c_custkey"), Pair.of("c_custkey", "lo_custkey"), + Pair.of("d_datekey", "lo_orderdate"), Pair.of("s_suppkey", "c_custkey"), + Pair.of("lo_custkey", "c_custkey"))); } @Test - public void testPushDownThroughUnsupportedJoinType() throws AnalysisException { + public void testPushDownThroughUnsupportedJoinType() { String sql = "select c_custkey from (select c_custkey from (select lo_custkey from lineorder inner join dates" + " on lo_orderdate = d_datekey) a" + " inner join (select c_custkey from customer left outer join supplier on c_custkey = s_suppkey) b" + " on b.c_custkey = a.lo_custkey) c inner join (select lo_custkey from customer inner join lineorder" + " on c_custkey = lo_custkey) d on c.c_custkey = d.lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 3 - && checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey", "d_datekey", "lo_orderdate", - "lo_custkey", "c_custkey")); + Assertions.assertEquals(3, filters.size()); + checkRuntimeFilterExprs(filters, ImmutableList.of( + Pair.of("c_custkey", "lo_custkey"), Pair.of("d_datekey", "lo_orderdate"), + Pair.of("lo_custkey", "c_custkey"))); } private Optional> getRuntimeFilters(String sql) { @@ -210,18 +226,11 @@ public class RuntimeFilterTest extends SSBTestBase { return Optional.of(filters); } - private boolean checkRuntimeFilterExprs(List filters, String... colNames) { - int idx = 0; - for (RuntimeFilter filter : filters) { - if (!checkRuntimeFilterExpr(filter, colNames[idx++], colNames[idx++])) { - return false; - } + private void checkRuntimeFilterExprs(List filters, List> colNames) { + Assertions.assertEquals(filters.size(), colNames.size()); + for (int i = 0; i < filters.size(); i++) { + Assertions.assertTrue(filters.get(i).getSrcExpr().toSql().equals(colNames.get(i).first) + && filters.get(i).getTargetExpr().toSql().equals(colNames.get(i).second)); } - return true; - } - - private boolean checkRuntimeFilterExpr(RuntimeFilter filter, String srcColName, String targetColName) { - return filter.getSrcExpr().toSql().equals(srcColName) - && filter.getTargetExpr().toSql().equals(targetColName); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushFilterInsideJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushFilterInsideJoinTest.java new file mode 100644 index 0000000000..842523ef4e --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushFilterInsideJoinTest.java @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import org.junit.jupiter.api.Test; + +class PushFilterInsideJoinTest implements PatternMatchSupported { + + private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + + @Test + void testPushInside() { + Expression predicates = new GreaterThan(scan1.getOutput().get(1), scan2.getOutput().get(1)); + + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .hashJoinEmptyOn(scan2, JoinType.CROSS_JOIN) + .filter(predicates) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(PushFilterInsideJoin.INSTANCE) + .printlnTree() + .matchesFromRoot( + logicalJoin().when(join -> join.getOtherJoinCondition().get().equals(predicates)) + ); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoinTest.java deleted file mode 100644 index 3374613f73..0000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoinTest.java +++ /dev/null @@ -1,208 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite.logical; - -import org.apache.doris.nereids.memo.Group; -import org.apache.doris.nereids.memo.Memo; -import org.apache.doris.nereids.trees.expressions.Add; -import org.apache.doris.nereids.trees.expressions.EqualTo; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.GreaterThan; -import org.apache.doris.nereids.trees.expressions.Subtract; -import org.apache.doris.nereids.trees.expressions.literal.Literal; -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import org.apache.doris.nereids.util.ExpressionUtils; -import org.apache.doris.nereids.util.PlanConstructor; -import org.apache.doris.nereids.util.PlanRewriter; -import org.apache.doris.qe.ConnectContext; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.TestInstance; - -import java.util.Optional; - -/** - * plan rewrite ut. - */ -@TestInstance(TestInstance.Lifecycle.PER_CLASS) -public class PushPredicateThroughJoinTest { - - private Plan rStudent; - private Plan rScore; - - /** - * ut before. - */ - @BeforeAll - public final void beforeAll() { - rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of("")); - rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of("")); - } - - @Test - public void oneSide() { - oneSide(JoinType.CROSS_JOIN, false); - oneSide(JoinType.INNER_JOIN, false); - oneSide(JoinType.LEFT_OUTER_JOIN, false); - oneSide(JoinType.LEFT_SEMI_JOIN, false); - oneSide(JoinType.LEFT_ANTI_JOIN, false); - oneSide(JoinType.RIGHT_OUTER_JOIN, true); - oneSide(JoinType.RIGHT_SEMI_JOIN, true); - oneSide(JoinType.RIGHT_ANTI_JOIN, true); - } - - private void oneSide(JoinType joinType, boolean testRight) { - - Expression whereCondition1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); - Expression whereCondition2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50)); - Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2); - - Plan left = rStudent; - Plan right = rScore; - if (testRight) { - left = rScore; - right = rStudent; - } - - Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), left, right); - Plan filter = new LogicalFilter<>(whereCondition, join); - Plan root = new LogicalProject<>(Lists.newArrayList(), filter); - - Memo memo = rewrite(root); - Group rootGroup = memo.getRoot(); - - Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan(); - Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(0).getLogicalExpression().getPlan(); - Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(1).getLogicalExpression().getPlan(); - if (testRight) { - shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(1).getLogicalExpression().getPlan(); - shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(0).getLogicalExpression().getPlan(); - } - - Assertions.assertTrue(shouldJoin instanceof LogicalJoin); - Assertions.assertTrue(shouldFilter instanceof LogicalFilter); - Assertions.assertTrue(shouldScan instanceof LogicalOlapScan); - LogicalFilter actualFilter = (LogicalFilter) shouldFilter; - - Assertions.assertEquals(whereCondition, actualFilter.getPredicates()); - } - - @Test - public void bothSideToBothSide() { - bothSideToBothSide(JoinType.INNER_JOIN); - } - - private void bothSideToBothSide(JoinType joinType) { - - Expression bothSideEqualTo = new EqualTo(new Add(rStudent.getOutput().get(0), Literal.of(1)), - new Subtract(rScore.getOutput().get(0), Literal.of(2))); - Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); - Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); - Expression whereCondition = ExpressionUtils.and(bothSideEqualTo, leftSide, rightSide); - - Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), rStudent, rScore); - Plan filter = new LogicalFilter<>(whereCondition, join); - Plan root = new LogicalProject<>(Lists.newArrayList(), filter); - - Memo memo = rewrite(root); - Group rootGroup = memo.getRoot(); - - Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan(); - Plan leftFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(0).getLogicalExpression().getPlan(); - Plan rightFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(1).getLogicalExpression().getPlan(); - - Assertions.assertTrue(shouldJoin instanceof LogicalJoin); - Assertions.assertTrue(leftFilter instanceof LogicalFilter); - Assertions.assertTrue(rightFilter instanceof LogicalFilter); - LogicalJoin actualJoin = (LogicalJoin) shouldJoin; - LogicalFilter actualLeft = (LogicalFilter) leftFilter; - LogicalFilter actualRight = (LogicalFilter) rightFilter; - Assertions.assertEquals(bothSideEqualTo, actualJoin.getOtherJoinCondition().get()); - Assertions.assertEquals(leftSide, actualLeft.getPredicates()); - Assertions.assertEquals(rightSide, actualRight.getPredicates()); - } - - @Test - public void bothSideToOneSide() { - bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, false); - bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, false); - bothSideToOneSide(JoinType.LEFT_SEMI_JOIN, false); - bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, true); - bothSideToOneSide(JoinType.RIGHT_ANTI_JOIN, true); - bothSideToOneSide(JoinType.RIGHT_SEMI_JOIN, true); - } - - private void bothSideToOneSide(JoinType joinType, boolean testRight) { - - Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); - Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); - Expression whereCondition = ExpressionUtils.and(pushSide, reserveSide); - - Plan left = rStudent; - Plan right = rScore; - if (testRight) { - left = rScore; - right = rStudent; - } - - Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), left, right); - Plan filter = new LogicalFilter<>(whereCondition, join); - Plan root = new LogicalProject<>(Lists.newArrayList(), filter); - - Memo memo = rewrite(root); - Group rootGroup = memo.getRoot(); - - Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(0).getLogicalExpression().getPlan(); - Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan(); - Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan(); - if (testRight) { - shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan(); - shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan(); - } - - Assertions.assertTrue(shouldJoin instanceof LogicalJoin); - Assertions.assertTrue(shouldFilter instanceof LogicalFilter); - Assertions.assertTrue(shouldScan instanceof LogicalOlapScan); - LogicalFilter actualFilter = (LogicalFilter) shouldFilter; - Assertions.assertEquals(pushSide, actualFilter.getPredicates()); - } - - private Memo rewrite(Plan plan) { - return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushPredicatesThroughJoin()); - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownExpressionsInHashConditionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownExpressionsInHashConditionTest.java similarity index 96% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownExpressionsInHashConditionTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownExpressionsInHashConditionTest.java index 2033b68668..b60d1acad1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownExpressionsInHashConditionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownExpressionsInHashConditionTest.java @@ -32,7 +32,7 @@ import org.junit.jupiter.api.Test; import java.util.List; -public class PushDownExpressionsInHashConditionTest extends TestWithFeService implements PatternMatchSupported { +public class PushdownExpressionsInHashConditionTest extends TestWithFeService implements PatternMatchSupported { private final List testSql = ImmutableList.of( "SELECT * FROM T1 JOIN T2 ON T1.ID + 1 = T2.ID + 2 AND T1.ID + 1 > 2", "SELECT * FROM (SELECT * FROM T1) X JOIN (SELECT * FROM T2) Y ON X.ID + 1 = Y.ID + 2 AND X.ID + 1 > 2", @@ -98,7 +98,7 @@ public class PushDownExpressionsInHashConditionTest extends TestWithFeService im PlanChecker.from(connectContext) .analyze("SELECT * FROM T1 JOIN T2 ON T1.ID + 1 = T2.ID + 2 AND T1.ID + 1 > 2") .applyTopDown(new FindHashConditionForJoin()) - .applyTopDown(new PushDownExpressionsInHashCondition()) + .applyTopDown(new PushdownExpressionsInHashCondition()) .matches( logicalProject( logicalJoin( @@ -119,7 +119,7 @@ public class PushDownExpressionsInHashConditionTest extends TestWithFeService im .analyze( "SELECT * FROM (SELECT * FROM T1) X JOIN (SELECT * FROM T2) Y ON X.ID + 1 = Y.ID + 2 AND X.ID + 1 > 2") .applyTopDown(new FindHashConditionForJoin()) - .applyTopDown(new PushDownExpressionsInHashCondition()) + .applyTopDown(new PushdownExpressionsInHashCondition()) .matches( logicalProject( logicalJoin( @@ -144,7 +144,7 @@ public class PushDownExpressionsInHashConditionTest extends TestWithFeService im .analyze( "SELECT * FROM T1 JOIN (SELECT ID, SUM(SCORE) SCORE FROM T2 GROUP BY ID) T ON T1.ID + 1 = T.ID AND T.SCORE = T1.SCORE + 10") .applyTopDown(new FindHashConditionForJoin()) - .applyTopDown(new PushDownExpressionsInHashCondition()) + .applyTopDown(new PushdownExpressionsInHashCondition()) .matches( logicalProject( logicalJoin( @@ -167,7 +167,7 @@ public class PushDownExpressionsInHashConditionTest extends TestWithFeService im .analyze( "SELECT * FROM T1 JOIN (SELECT ID, SUM(SCORE) SCORE FROM T2 GROUP BY ID ORDER BY ID) T ON T1.ID + 1 = T.ID AND T.SCORE = T1.SCORE + 10") .applyTopDown(new FindHashConditionForJoin()) - .applyTopDown(new PushDownExpressionsInHashCondition()) + .applyTopDown(new PushdownExpressionsInHashCondition()) .matches( logicalProject( logicalJoin( diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateThroughAggregationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregationTest.java similarity index 98% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateThroughAggregationTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregationTest.java index bca195270b..208ffb5b44 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateThroughAggregationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregationTest.java @@ -47,7 +47,7 @@ import org.junit.jupiter.api.Test; import java.util.List; -public class PushDownPredicateThroughAggregationTest { +public class PushdownFilterThroughAggregationTest { /** * origin plan: @@ -187,6 +187,6 @@ public class PushDownPredicateThroughAggregationTest { } private Memo rewrite(Plan plan) { - return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushPredicateThroughAggregation()); + return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushdownFilterThroughAggregation()); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughJoinTest.java new file mode 100644 index 0000000000..58e9760e9c --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughJoinTest.java @@ -0,0 +1,218 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.Subtract; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +/** + * PushdownFilterThroughJoinTest UT. + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class PushdownFilterThroughJoinTest implements PatternMatchSupported { + + private LogicalPlan rStudent; + private LogicalPlan rScore; + + /** + * ut before. + */ + @BeforeAll + public final void beforeAll() { + rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, + ImmutableList.of("")); + rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of("")); + } + + public void testLeft(JoinType joinType) { + Expression whereCondition1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); + Expression whereCondition2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50)); + Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2); + + LogicalPlan plan = new LogicalPlanBuilder(rStudent) + .hashJoinEmptyOn(rScore, joinType) + .filter(whereCondition) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(PushdownFilterThroughJoin.INSTANCE) + .matchesFromRoot( + logicalJoin( + logicalFilter(logicalOlapScan()) + .when(filter -> filter.getPredicates().equals(whereCondition)), + logicalOlapScan() + ) + ); + } + + public void testRight(JoinType joinType) { + Expression whereCondition1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); + Expression whereCondition2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50)); + Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2); + + LogicalPlan plan = new LogicalPlanBuilder(rScore) + .hashJoinEmptyOn(rStudent, joinType) + .filter(whereCondition) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(PushdownFilterThroughJoin.INSTANCE) + .matchesFromRoot( + logicalJoin( + logicalOlapScan(), + logicalFilter(logicalOlapScan()) + .when(filter -> filter.getPredicates().equals(whereCondition)) + ) + ); + } + + @Test + public void oneSide() { + testLeft(JoinType.CROSS_JOIN); + testLeft(JoinType.INNER_JOIN); + testLeft(JoinType.LEFT_OUTER_JOIN); + testLeft(JoinType.LEFT_SEMI_JOIN); + testLeft(JoinType.LEFT_ANTI_JOIN); + testRight(JoinType.RIGHT_OUTER_JOIN); + testRight(JoinType.RIGHT_SEMI_JOIN); + testRight(JoinType.RIGHT_ANTI_JOIN); + } + + @Test + public void bothSideToBothSide() { + bothSideToBothSide(JoinType.INNER_JOIN); + bothSideToBothSide(JoinType.CROSS_JOIN); + } + + private void bothSideToBothSide(JoinType joinType) { + + Expression bothSideEqualTo = new EqualTo(new Add(rStudent.getOutput().get(0), Literal.of(1)), + new Subtract(rScore.getOutput().get(0), Literal.of(2))); + Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); + Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); + Expression whereCondition = ExpressionUtils.and(bothSideEqualTo, leftSide, rightSide); + + LogicalPlan plan = new LogicalPlanBuilder(rStudent) + .hashJoinEmptyOn(rScore, joinType) + .filter(whereCondition) + .build(); + + if (joinType.isInnerJoin()) { + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(PushdownFilterThroughJoin.INSTANCE) + .printlnTree() + .matchesFromRoot( + logicalJoin( + logicalFilter(logicalOlapScan()) + .when(filter -> filter.getPredicates().equals(leftSide)), + logicalFilter(logicalOlapScan()) + .when(filter -> filter.getPredicates().equals(rightSide)) + ).when(join -> join.getOtherJoinCondition().get().equals(bothSideEqualTo)) + ); + } + if (joinType.isCrossJoin()) { + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(PushdownFilterThroughJoin.INSTANCE) + .printlnTree() + .matchesFromRoot( + logicalFilter( + logicalJoin( + logicalFilter(logicalOlapScan()) + .when(filter -> filter.getPredicates().equals(leftSide)), + logicalFilter(logicalOlapScan()) + .when(filter -> filter.getPredicates().equals(rightSide)) + ) + ).when(filter -> filter.getPredicates().equals(bothSideEqualTo)) + ); + } + } + + @Test + public void bothSideToOneSide() { + bothSideToLeft(JoinType.LEFT_OUTER_JOIN); + bothSideToLeft(JoinType.LEFT_ANTI_JOIN); + bothSideToLeft(JoinType.LEFT_SEMI_JOIN); + bothSideToRight(JoinType.RIGHT_OUTER_JOIN); + bothSideToRight(JoinType.RIGHT_ANTI_JOIN); + bothSideToRight(JoinType.RIGHT_SEMI_JOIN); + } + + private void bothSideToLeft(JoinType joinType) { + Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); + Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); + Expression whereCondition = ExpressionUtils.and(pushSide, reserveSide); + + LogicalPlan plan = new LogicalPlanBuilder(rStudent) + .hashJoinEmptyOn(rScore, joinType) + .filter(whereCondition) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(PushdownFilterThroughJoin.INSTANCE) + .matchesFromRoot( + logicalFilter( + logicalJoin( + logicalFilter(logicalOlapScan()) + .when(filter -> filter.getPredicates().equals(pushSide)), + logicalOlapScan() + ) + ) + ); + } + + private void bothSideToRight(JoinType joinType) { + Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); + Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); + Expression whereCondition = ExpressionUtils.and(pushSide, reserveSide); + + LogicalPlan plan = new LogicalPlanBuilder(rScore) + .hashJoinEmptyOn(rStudent, joinType) + .filter(whereCondition) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(PushdownFilterThroughJoin.INSTANCE) + .matchesFromRoot( + logicalFilter( + logicalJoin( + logicalOlapScan(), + logicalFilter(logicalOlapScan()).when( + filter -> filter.getPredicates().equals(pushSide)) + ) + ) + ); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherConditionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java similarity index 98% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherConditionTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java index f6f5d664fe..0d7d711c42 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherConditionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java @@ -43,7 +43,7 @@ import org.junit.jupiter.api.TestInstance; import java.util.Optional; @TestInstance(TestInstance.Lifecycle.PER_CLASS) -public class PushDownJoinOtherConditionTest { +public class PushdownJoinOtherConditionTest { private Plan rStudent; private Plan rScore; @@ -191,6 +191,6 @@ public class PushDownJoinOtherConditionTest { } private Memo rewrite(Plan plan) { - return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushDownJoinOtherCondition()); + return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushdownJoinOtherCondition()); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java index 338116f612..cfaec47b77 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java @@ -79,6 +79,12 @@ public class LogicalPlanBuilder { return from(join); } + public LogicalPlanBuilder hashJoinEmptyOn(LogicalPlan right, JoinType joinType) { + LogicalJoin join = new LogicalJoin<>(joinType, new ArrayList<>(), + Optional.empty(), this.plan, right); + return from(join); + } + public LogicalPlanBuilder limit(long limit, long offset) { LogicalLimit limitPlan = new LogicalLimit<>(limit, offset, this.plan); return from(limitPlan); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java index 54efa439c8..12bcaeff0d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java @@ -107,6 +107,7 @@ public class PlanChecker { /** * apply a top down rewrite rule if you not care the ruleId + * * @param patternMatcher the rule dsl, such as: logicalOlapScan().then(olapScan -> olapScan) * @return this checker, for call chaining of follow-up check */ @@ -129,6 +130,7 @@ public class PlanChecker { /** * apply a bottom up rewrite rule if you not care the ruleId + * * @param patternMatcher the rule dsl, such as: logicalOlapScan().then(olapScan -> olapScan) * @return this checker, for call chaining of follow-up check */ @@ -163,7 +165,8 @@ public class PlanChecker { PhysicalPlan current = null; loop: for (Rule rule : RuleSet.IMPLEMENTATION_RULES) { - GroupExpressionMatching matching = new GroupExpressionMatching(rule.getPattern(), group.getLogicalExpression()); + GroupExpressionMatching matching = new GroupExpressionMatching(rule.getPattern(), + group.getLogicalExpression()); for (Plan plan : matching) { Plan after = rule.transform(plan, cascadesContext).get(0); if (after instanceof PhysicalPlan) { @@ -346,4 +349,10 @@ public class PlanChecker { public Plan getPlan() { return cascadesContext.getMemo().copyOut(); } + + public PlanChecker printlnTree() { + System.out.println(cascadesContext.getMemo().copyOut().treeString()); + return this; + } + }