[enhancement](Nereids) push filter into join otherJoinCondition (#12842)

This commit is contained in:
jakevin
2022-09-29 16:19:30 +08:00
committed by GitHub
parent 1ae9454771
commit 42729786bf
20 changed files with 489 additions and 342 deletions

View File

@ -699,6 +699,10 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
crossJoinNode.setChild(0, leftFragment.getPlanRoot());
connectChildFragment(crossJoinNode, 1, leftFragment, rightFragment, context);
leftFragment.setPlanRoot(crossJoinNode);
if (nestedLoopJoin.getOtherJoinCondition().isPresent()) {
ExpressionUtils.extractConjunction(nestedLoopJoin.getOtherJoinCondition().get()).stream()
.map(e -> ExpressionTranslator.translate(e, context)).forEach(crossJoinNode::addConjunct);
}
return leftFragment;
} else {

View File

@ -32,6 +32,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin;
import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown;
import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate;
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition;
import org.apache.doris.nereids.rules.rewrite.logical.PushFilterInsideJoin;
import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
import com.google.common.collect.ImmutableList;
@ -66,6 +67,7 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob {
.add(topDownBatch(ImmutableList.of(new ReorderJoin())))
.add(topDownBatch(ImmutableList.of(new ColumnPruning())))
.add(topDownBatch(RuleSet.PUSH_DOWN_JOIN_CONDITION_RULES, false))
.add(topDownBatch(ImmutableList.of(PushFilterInsideJoin.INSTANCE)))
.add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin())))
.add(topDownBatch(ImmutableList.of(new LimitPushDown())))
.add(topDownBatch(ImmutableList.of(new EliminateLimit())))

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import com.google.common.base.Preconditions;
@ -53,7 +54,9 @@ public class Validator extends PlanPostProcessor {
Plan child = filter.child();
// Forbidden filter-project, we must make filter-project -> project-filter.
Preconditions.checkArgument(!(child instanceof PhysicalProject));
Preconditions.checkState(!(child instanceof PhysicalProject));
// Forbidden filter-cross join, because we put all filter on cross join into its other join condition.
Preconditions.checkState(!(child instanceof PhysicalNestedLoopJoin));
// Check filter is from child output.
Set<Slot> childOutputSet = child.getOutputSet();

View File

@ -42,10 +42,10 @@ import org.apache.doris.nereids.rules.rewrite.logical.EliminateOuter;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
import org.apache.doris.nereids.rules.rewrite.logical.PushDownExpressionsInHashCondition;
import org.apache.doris.nereids.rules.rewrite.logical.PushDownJoinOtherCondition;
import org.apache.doris.nereids.rules.rewrite.logical.PushPredicatesThroughJoin;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownExpressionsInHashCondition;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughJoin;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownJoinOtherCondition;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownProjectThroughLimit;
import com.google.common.collect.ImmutableList;
@ -72,9 +72,9 @@ public class RuleSet {
.build();
public static final List<RuleFactory> PUSH_DOWN_JOIN_CONDITION_RULES = ImmutableList.of(
new PushDownJoinOtherCondition(),
new PushPredicatesThroughJoin(),
new PushDownExpressionsInHashCondition(),
new PushdownJoinOtherCondition(),
new PushdownFilterThroughJoin(),
new PushdownExpressionsInHashCondition(),
new PushdownProjectThroughLimit(),
new PushdownFilterThroughProject(),
EliminateOuter.INSTANCE,

View File

@ -78,11 +78,15 @@ public enum RuleType {
IN_APPLY_TO_JOIN(RuleTypeClass.REWRITE),
EXISTS_APPLY_TO_JOIN(RuleTypeClass.REWRITE),
// predicate push down rules
PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSH_DOWN_JOIN_OTHER_CONDITION(RuleTypeClass.REWRITE),
PUSH_DOWN_PREDICATE_THROUGH_LEFT_SEMI_JOIN(RuleTypeClass.REWRITE),
PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION(RuleTypeClass.REWRITE),
PUSH_DOWN_EXPRESSIONS_IN_HASH_CONDITIONS(RuleTypeClass.REWRITE),
PUSHDOWN_JOIN_OTHER_CONDITION(RuleTypeClass.REWRITE),
PUSHDOWN_PREDICATE_THROUGH_AGGREGATION(RuleTypeClass.REWRITE),
PUSHDOWN_EXPRESSIONS_IN_HASH_CONDITIONS(RuleTypeClass.REWRITE),
// Pushdown filter
PUSHDOWN_FILTER_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSHDOWN_FILTER_THROUGH_LEFT_SEMI_JOIN(RuleTypeClass.REWRITE),
PUSH_FILTER_INSIDE_JOIN(RuleTypeClass.REWRITE),
PUSHDOWN_FILTER_THROUGH_PROJET(RuleTypeClass.REWRITE),
PUSHDOWN_PROJECT_THROUGHT_LIMIT(RuleTypeClass.REWRITE),
// column prune rules,
COLUMN_PRUNE_AGGREGATION_CHILD(RuleTypeClass.REWRITE),
COLUMN_PRUNE_FILTER_CHILD(RuleTypeClass.REWRITE),
@ -112,9 +116,6 @@ public enum RuleType {
ROLLUP_WITH_OUT_AGG(RuleTypeClass.REWRITE),
OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE),
EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION(RuleTypeClass.REWRITE),
// Pushdown filter
PUSHDOWN_FILTER_THROUGH_PROJET(RuleTypeClass.REWRITE),
PUSHDOWN_PROJECT_THROUGHT_LIMIT(RuleTypeClass.REWRITE),
REWRITE_SENTINEL(RuleTypeClass.REWRITE),
// limit push down

View File

@ -0,0 +1,52 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.Lists;
import java.util.List;
/**
* Push the predicate in the LogicalFilter to the join children.
*/
public class PushFilterInsideJoin extends OneRewriteRuleFactory {
public static final PushFilterInsideJoin INSTANCE = new PushFilterInsideJoin();
@Override
public Rule build() {
return logicalFilter(logicalJoin())
// TODO: current just handle cross/inner join.
.when(filter -> filter.child().getJoinType().isCrossJoin()
|| filter.child().getJoinType().isInnerJoin())
.then(filter -> {
List<Expression> otherConditions = Lists.newArrayList(filter.getPredicates());
LogicalJoin<GroupPlan, GroupPlan> join = filter.child();
join.getOtherJoinCondition().map(otherConditions::add);
return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
ExpressionUtils.optionalAnd(otherConditions), join.left(), join.right());
}).toRule(RuleType.PUSH_FILTER_INSIDE_JOIN);
}
}

View File

@ -45,7 +45,7 @@ import java.util.stream.Collectors;
/**
* push down expression which is not slot reference
*/
public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory {
public class PushdownExpressionsInHashCondition extends OneRewriteRuleFactory {
/*
* rewrite example:
* join(t1.a + 1 = t2.b + 2) join(c = d)
@ -94,7 +94,7 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory {
.collect(Collectors.toList()))
.addAll(getOutput(plan, join)).build(), plan))
.collect(Collectors.toList()));
}).toRule(RuleType.PUSH_DOWN_EXPRESSIONS_IN_HASH_CONDITIONS);
}).toRule(RuleType.PUSHDOWN_EXPRESSIONS_IN_HASH_CONDITIONS);
}
private List<Slot> getOutput(Plan plan, LogicalJoin join) {

View File

@ -62,7 +62,7 @@ import java.util.Set;
*
*/
public class PushPredicateThroughAggregation extends OneRewriteRuleFactory {
public class PushdownFilterThroughAggregation extends OneRewriteRuleFactory {
@Override
public Rule build() {
@ -86,7 +86,7 @@ public class PushPredicateThroughAggregation extends OneRewriteRuleFactory {
});
return pushDownPredicate(filter, aggregate, pushDownPredicates, filterPredicates);
}).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION);
}).toRule(RuleType.PUSHDOWN_PREDICATE_THROUGH_AGGREGATION);
}
private Plan pushDownPredicate(LogicalFilter filter, LogicalAggregate aggregate,

View File

@ -20,13 +20,12 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
@ -35,13 +34,13 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.Objects;
import java.util.Set;
/**
* Push the predicate in the LogicalFilter to the join children.
*/
public class PushPredicatesThroughJoin extends OneRewriteRuleFactory {
public class PushdownFilterThroughJoin extends OneRewriteRuleFactory {
public static final PushdownFilterThroughJoin INSTANCE = new PushdownFilterThroughJoin();
private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_LEFT = ImmutableList.of(
JoinType.INNER_JOIN,
@ -59,7 +58,7 @@ public class PushPredicatesThroughJoin extends OneRewriteRuleFactory {
JoinType.CROSS_JOIN
);
private static final ImmutableList<JoinType> COULD_PUSH_EQUAL_TO = ImmutableList.of(
private static final ImmutableList<JoinType> COULD_PUSH_INSIDE = ImmutableList.of(
JoinType.INNER_JOIN
);
@ -68,6 +67,9 @@ public class PushPredicatesThroughJoin extends OneRewriteRuleFactory {
* select a.k1, b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5
* where a.k1 > 1 and b.k1 > 2 and a.k2 > b.k2
*
* TODO(jakevin): following graph is wrong, we should add a new rule to extract
* a.k2 > 2 and b.k2 > 5, and pushdown.
*
* Logical plan tree:
* project
* |
@ -93,29 +95,28 @@ public class PushPredicatesThroughJoin extends OneRewriteRuleFactory {
LogicalJoin<GroupPlan, GroupPlan> join = filter.child();
Expression filterPredicates = filter.getPredicates();
List<Expression> predicates = ExpressionUtils.extractConjunction(filter.getPredicates());
List<Expression> filterConditions = Lists.newArrayList();
List<Expression> filterPredicates = Lists.newArrayList();
List<Expression> joinConditions = Lists.newArrayList();
Set<Slot> leftInput = join.left().getOutputSet();
Set<Slot> rightInput = join.right().getOutputSet();
ExpressionUtils.extractConjunction(filterPredicates)
.forEach(predicate -> {
if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))
&& COULD_PUSH_EQUAL_TO.contains(join.getJoinType())) {
joinConditions.add(predicate);
} else {
filterConditions.add(predicate);
}
});
// TODO: predicate slotReference should be not nullable.
for (Expression predicate : predicates) {
if (convertJoinCondition(predicate, leftInput, rightInput, join.getJoinType())) {
joinConditions.add(predicate);
} else {
filterPredicates.add(predicate);
}
}
List<Expression> leftPredicates = Lists.newArrayList();
List<Expression> rightPredicates = Lists.newArrayList();
for (Expression p : filterConditions) {
Set<Slot> slots = p.getInputSlots();
List<Expression> remainingPredicates = Lists.newArrayList();
for (Expression p : filterPredicates) {
Set<Slot> slots = p.collect(SlotReference.class::isInstance);
if (slots.isEmpty()) {
leftPredicates.add(p);
rightPredicates.add(p);
@ -123,50 +124,47 @@ public class PushPredicatesThroughJoin extends OneRewriteRuleFactory {
}
if (leftInput.containsAll(slots) && COULD_PUSH_THROUGH_LEFT.contains(join.getJoinType())) {
leftPredicates.add(p);
}
if (rightInput.containsAll(slots) && COULD_PUSH_THROUGH_RIGHT.contains(join.getJoinType())) {
} else if (rightInput.containsAll(slots) && COULD_PUSH_THROUGH_RIGHT.contains(join.getJoinType())) {
rightPredicates.add(p);
} else {
remainingPredicates.add(p);
}
}
filterConditions.removeAll(leftPredicates);
filterConditions.removeAll(rightPredicates);
join.getOtherJoinCondition().map(joinConditions::add);
return PlanUtils.filterOrSelf(filterConditions,
pushDownPredicate(join, joinConditions, leftPredicates, rightPredicates));
}).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_JOIN);
return PlanUtils.filterOrSelf(remainingPredicates,
new LogicalJoin<>(join.getJoinType(),
join.getHashJoinConjuncts(),
ExpressionUtils.optionalAnd(joinConditions),
PlanUtils.filterOrSelf(leftPredicates, join.left()),
PlanUtils.filterOrSelf(rightPredicates, join.right())));
}).toRule(RuleType.PUSHDOWN_FILTER_THROUGH_JOIN);
}
private Plan pushDownPredicate(LogicalJoin<GroupPlan, GroupPlan> join,
List<Expression> joinConditions, List<Expression> leftPredicates, List<Expression> rightPredicates) {
// todo expr should optimize again using expr rewrite
Plan leftPlan = PlanUtils.filterOrSelf(leftPredicates, join.left());
Plan rightPlan = PlanUtils.filterOrSelf(rightPredicates, join.right());
return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
ExpressionUtils.optionalAnd(joinConditions), leftPlan, rightPlan);
}
private Expression getJoinCondition(Expression predicate, Set<Slot> leftOutputs, Set<Slot> rightOutputs) {
if (!(predicate instanceof ComparisonPredicate)) {
return null;
private boolean convertJoinCondition(Expression predicate, Set<Slot> leftOutputs, Set<Slot> rightOutputs,
JoinType joinType) {
if (!COULD_PUSH_INSIDE.contains(joinType)) {
return false;
}
if (!(predicate instanceof EqualTo)) {
return false;
}
ComparisonPredicate comparison = (ComparisonPredicate) predicate;
EqualTo equalTo = (EqualTo) predicate;
if (!(comparison instanceof EqualTo)) {
return null;
Set<Slot> leftSlots = equalTo.left().collect(SlotReference.class::isInstance);
Set<Slot> rightSlots = equalTo.right().collect(SlotReference.class::isInstance);
if (leftSlots.size() == 0 || rightSlots.size() == 0) {
return false;
}
Set<Slot> leftSlots = comparison.left().getInputSlots();
Set<Slot> rightSlots = comparison.right().getInputSlots();
if ((leftOutputs.containsAll(leftSlots) && rightOutputs.containsAll(rightSlots))
|| (leftOutputs.containsAll(rightSlots) && rightOutputs.containsAll(leftSlots))) {
return predicate;
return true;
}
return null;
return false;
}
}

View File

@ -37,7 +37,7 @@ import java.util.Set;
/**
* Push the other join conditions in LogicalJoin to children.
*/
public class PushDownJoinOtherCondition extends OneRewriteRuleFactory {
public class PushdownJoinOtherCondition extends OneRewriteRuleFactory {
private static final ImmutableList<JoinType> PUSH_DOWN_LEFT_VALID_TYPE = ImmutableList.of(
JoinType.INNER_JOIN,
JoinType.LEFT_SEMI_JOIN,
@ -90,7 +90,7 @@ public class PushDownJoinOtherCondition extends OneRewriteRuleFactory {
return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
ExpressionUtils.optionalAnd(otherConjuncts), left, right);
}).toRule(RuleType.PUSH_DOWN_JOIN_OTHER_CONDITION);
}).toRule(RuleType.PUSHDOWN_JOIN_OTHER_CONDITION);
}
private boolean allCoveredBy(Expression predicate, Set<Slot> inputSlotSet) {

View File

@ -37,10 +37,9 @@ public class SlotReference extends Slot {
// TODO: we should distinguish the name is alias or column name, and the column name should contains
// `cluster:db`.`table`.`column`
private final String name;
private final List<String> qualifier;
private final DataType dataType;
private final boolean nullable;
private final List<String> qualifier;
private final Column column;
public SlotReference(String name, DataType dataType) {

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.postprocess;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.datasets.ssb.SSBTestBase;
import org.apache.doris.nereids.datasets.ssb.SSBUtils;
import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
@ -28,6 +28,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter;
import org.apache.doris.nereids.util.PlanChecker;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@ -44,99 +45,110 @@ public class RuntimeFilterTest extends SSBTestBase {
}
@Test
public void testGenerateRuntimeFilter() throws AnalysisException {
public void testGenerateRuntimeFilter() {
String sql = "SELECT * FROM lineorder JOIN customer on c_custkey = lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 1
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey"));
Assertions.assertEquals(1, filters.size());
checkRuntimeFilterExprs(filters, ImmutableList.of(
Pair.of("c_custkey", "lo_custkey")));
}
@Test
public void testGenerateRuntimeFilterByIllegalSrcExpr() throws AnalysisException {
public void testGenerateRuntimeFilterByIllegalSrcExpr() {
String sql = "SELECT * FROM lineorder JOIN customer on c_custkey = c_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertEquals(0, filters.size());
}
@Test
public void testComplexExpressionToRuntimeFilter() throws AnalysisException {
public void testComplexExpressionToRuntimeFilter() {
String sql
= "SELECT * FROM supplier JOIN customer on c_name = s_name and s_city = c_city and s_nation = c_nation";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 3
&& checkRuntimeFilterExprs(filters, "c_name", "s_name", "c_city", "s_city", "c_nation", "s_nation"));
Assertions.assertEquals(3, filters.size());
checkRuntimeFilterExprs(filters, ImmutableList.of(
Pair.of("c_name", "s_name"),
Pair.of("c_city", "s_city"),
Pair.of("c_nation", "s_nation")));
}
@Test
public void testNestedJoinGenerateRuntimeFilter() throws AnalysisException {
public void testNestedJoinGenerateRuntimeFilter() {
String sql = SSBUtils.Q4_1;
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 4
&& checkRuntimeFilterExprs(filters, "p_partkey", "lo_partkey", "s_suppkey", "lo_suppkey",
"c_custkey", "lo_custkey", "lo_orderdate", "d_datekey"));
Assertions.assertEquals(4, filters.size());
checkRuntimeFilterExprs(filters, ImmutableList.of(
Pair.of("p_partkey", "lo_partkey"), Pair.of("s_suppkey", "lo_suppkey"),
Pair.of("c_custkey", "lo_custkey"), Pair.of("lo_orderdate", "d_datekey")));
}
@Test
public void testSubTreeInUnsupportedJoinType() throws AnalysisException {
public void testSubTreeInUnsupportedJoinType() {
String sql = "select c_custkey"
+ " from (select lo_custkey from lineorder inner join dates on lo_orderdate = d_datekey) a"
+ " left outer join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b"
+ " on b.c_custkey = a.lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 2
&& checkRuntimeFilterExprs(filters, "d_datekey", "lo_orderdate", "s_suppkey", "c_custkey"));
Assertions.assertEquals(2, filters.size());
checkRuntimeFilterExprs(filters, ImmutableList.of(
Pair.of("d_datekey", "lo_orderdate"), Pair.of("s_suppkey", "c_custkey")));
}
@Test
public void testPushDownEncounterUnsupportedJoinType() throws AnalysisException {
public void testPushDownEncounterUnsupportedJoinType() {
String sql = "select c_custkey"
+ " from (select lo_custkey from lineorder left outer join dates on lo_orderdate = d_datekey) a"
+ " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b"
+ " on b.c_custkey = a.lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 1
&& checkRuntimeFilterExprs(filters, "s_suppkey", "c_custkey"));
Assertions.assertEquals(1, filters.size());
checkRuntimeFilterExprs(filters, ImmutableList.of(
Pair.of("s_suppkey", "c_custkey")));
}
@Test
public void testPushDownThroughAggNode() throws AnalysisException {
public void testPushDownThroughAggNode() {
String sql = "select profit"
+ " from (select lo_custkey, sum(lo_revenue - lo_supplycost) as profit from lineorder inner join dates"
+ " on lo_orderdate = d_datekey group by lo_custkey) a"
+ " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b"
+ " on b.c_custkey = a.lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 3
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey", "d_datekey", "lo_orderdate",
"s_suppkey", "c_custkey"));
Assertions.assertEquals(3, filters.size());
checkRuntimeFilterExprs(filters, ImmutableList.of(
Pair.of("c_custkey", "lo_custkey"), Pair.of("d_datekey", "lo_orderdate"),
Pair.of("s_suppkey", "c_custkey")));
}
@Test
public void testDoNotPushDownThroughAggFunction() throws AnalysisException {
public void testDoNotPushDownThroughAggFunction() {
String sql = "select profit"
+ " from (select lo_custkey, sum(lo_revenue - lo_supplycost) as profit from lineorder inner join dates"
+ " on lo_orderdate = d_datekey group by lo_custkey) a"
+ " inner join (select sum(c_custkey) c_custkey from customer inner join supplier on c_custkey = s_suppkey group by s_suppkey) b"
+ " on b.c_custkey = a.lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 2
&& checkRuntimeFilterExprs(filters, "d_datekey", "lo_orderdate", "s_suppkey", "c_custkey"));
Assertions.assertEquals(2, filters.size());
checkRuntimeFilterExprs(filters, ImmutableList.of(
Pair.of("d_datekey", "lo_orderdate"), Pair.of("s_suppkey", "c_custkey")));
}
@Test
public void testCrossJoin() throws AnalysisException {
public void testCrossJoin() {
String sql = "select c_custkey, lo_custkey from lineorder, customer where lo_custkey = c_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 1
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey"));
Assertions.assertEquals(1, filters.size());
checkRuntimeFilterExprs(filters, ImmutableList.of(
Pair.of("c_custkey", "lo_custkey")));
}
@Test
public void testSubQueryAlias() throws AnalysisException {
public void testSubQueryAlias() {
String sql = "select c_custkey, lo_custkey from lineorder l, customer c where c.c_custkey = l.lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 1
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey"));
Assertions.assertEquals(1, filters.size());
checkRuntimeFilterExprs(filters, ImmutableList.of(
Pair.of("c_custkey", "lo_custkey")));
}
@Test
@ -165,35 +177,39 @@ public class RuntimeFilterTest extends SSBTestBase {
+ " on t1.p_partkey = t2.lo_partkey\n"
+ " order by t1.lo_custkey, t1.p_partkey, t2.s_suppkey, t2.c_custkey, t2.lo_orderkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 4
&& checkRuntimeFilterExprs(filters, "lo_partkey", "p_partkey", "lo_partkey", "p_partkey",
"c_region", "s_region", "lo_custkey", "c_custkey"));
Assertions.assertEquals(4, filters.size());
checkRuntimeFilterExprs(filters, ImmutableList.of(
Pair.of("lo_partkey", "p_partkey"), Pair.of("lo_partkey", "p_partkey"),
Pair.of("c_region", "s_region"), Pair.of("lo_custkey", "c_custkey")));
}
@Test
public void testPushDownThroughJoin() throws AnalysisException {
public void testPushDownThroughJoin() {
String sql = "select c_custkey from (select c_custkey from (select lo_custkey from lineorder inner join dates"
+ " on lo_orderdate = d_datekey) a"
+ " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b"
+ " on b.c_custkey = a.lo_custkey) c inner join (select lo_custkey from customer inner join lineorder"
+ " on c_custkey = lo_custkey) d on c.c_custkey = d.lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 5
&& checkRuntimeFilterExprs(filters, "lo_custkey", "c_custkey", "c_custkey", "lo_custkey",
"d_datekey", "lo_orderdate", "s_suppkey", "c_custkey", "lo_custkey", "c_custkey"));
Assertions.assertEquals(5, filters.size());
checkRuntimeFilterExprs(filters, ImmutableList.of(
Pair.of("lo_custkey", "c_custkey"), Pair.of("c_custkey", "lo_custkey"),
Pair.of("d_datekey", "lo_orderdate"), Pair.of("s_suppkey", "c_custkey"),
Pair.of("lo_custkey", "c_custkey")));
}
@Test
public void testPushDownThroughUnsupportedJoinType() throws AnalysisException {
public void testPushDownThroughUnsupportedJoinType() {
String sql = "select c_custkey from (select c_custkey from (select lo_custkey from lineorder inner join dates"
+ " on lo_orderdate = d_datekey) a"
+ " inner join (select c_custkey from customer left outer join supplier on c_custkey = s_suppkey) b"
+ " on b.c_custkey = a.lo_custkey) c inner join (select lo_custkey from customer inner join lineorder"
+ " on c_custkey = lo_custkey) d on c.c_custkey = d.lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 3
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey", "d_datekey", "lo_orderdate",
"lo_custkey", "c_custkey"));
Assertions.assertEquals(3, filters.size());
checkRuntimeFilterExprs(filters, ImmutableList.of(
Pair.of("c_custkey", "lo_custkey"), Pair.of("d_datekey", "lo_orderdate"),
Pair.of("lo_custkey", "c_custkey")));
}
private Optional<List<RuntimeFilter>> getRuntimeFilters(String sql) {
@ -210,18 +226,11 @@ public class RuntimeFilterTest extends SSBTestBase {
return Optional.of(filters);
}
private boolean checkRuntimeFilterExprs(List<RuntimeFilter> filters, String... colNames) {
int idx = 0;
for (RuntimeFilter filter : filters) {
if (!checkRuntimeFilterExpr(filter, colNames[idx++], colNames[idx++])) {
return false;
}
private void checkRuntimeFilterExprs(List<RuntimeFilter> filters, List<Pair<String, String>> colNames) {
Assertions.assertEquals(filters.size(), colNames.size());
for (int i = 0; i < filters.size(); i++) {
Assertions.assertTrue(filters.get(i).getSrcExpr().toSql().equals(colNames.get(i).first)
&& filters.get(i).getTargetExpr().toSql().equals(colNames.get(i).second));
}
return true;
}
private boolean checkRuntimeFilterExpr(RuntimeFilter filter, String srcColName, String targetColName) {
return filter.getSrcExpr().toSql().equals(srcColName)
&& filter.getTargetExpr().toSql().equals(targetColName);
}
}

View File

@ -0,0 +1,54 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.junit.jupiter.api.Test;
class PushFilterInsideJoinTest implements PatternMatchSupported {
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
@Test
void testPushInside() {
Expression predicates = new GreaterThan(scan1.getOutput().get(1), scan2.getOutput().get(1));
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.hashJoinEmptyOn(scan2, JoinType.CROSS_JOIN)
.filter(predicates)
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(PushFilterInsideJoin.INSTANCE)
.printlnTree()
.matchesFromRoot(
logicalJoin().when(join -> join.getOtherJoinCondition().get().equals(predicates))
);
}
}

View File

@ -1,208 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.PlanRewriter;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import java.util.Optional;
/**
* plan rewrite ut.
*/
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class PushPredicateThroughJoinTest {
private Plan rStudent;
private Plan rScore;
/**
* ut before.
*/
@BeforeAll
public final void beforeAll() {
rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of(""));
rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of(""));
}
@Test
public void oneSide() {
oneSide(JoinType.CROSS_JOIN, false);
oneSide(JoinType.INNER_JOIN, false);
oneSide(JoinType.LEFT_OUTER_JOIN, false);
oneSide(JoinType.LEFT_SEMI_JOIN, false);
oneSide(JoinType.LEFT_ANTI_JOIN, false);
oneSide(JoinType.RIGHT_OUTER_JOIN, true);
oneSide(JoinType.RIGHT_SEMI_JOIN, true);
oneSide(JoinType.RIGHT_ANTI_JOIN, true);
}
private void oneSide(JoinType joinType, boolean testRight) {
Expression whereCondition1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression whereCondition2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50));
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2);
Plan left = rStudent;
Plan right = rScore;
if (testRight) {
left = rScore;
right = rStudent;
}
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), left, right);
Plan filter = new LogicalFilter<>(whereCondition, join);
Plan root = new LogicalProject<>(Lists.newArrayList(), filter);
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(1).getLogicalExpression().getPlan();
if (testRight) {
shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(1).getLogicalExpression().getPlan();
shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
}
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
Assertions.assertEquals(whereCondition, actualFilter.getPredicates());
}
@Test
public void bothSideToBothSide() {
bothSideToBothSide(JoinType.INNER_JOIN);
}
private void bothSideToBothSide(JoinType joinType) {
Expression bothSideEqualTo = new EqualTo(new Add(rStudent.getOutput().get(0), Literal.of(1)),
new Subtract(rScore.getOutput().get(0), Literal.of(2)));
Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.and(bothSideEqualTo, leftSide, rightSide);
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), rStudent, rScore);
Plan filter = new LogicalFilter<>(whereCondition, join);
Plan root = new LogicalProject<>(Lists.newArrayList(), filter);
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan leftFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
Plan rightFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(1).getLogicalExpression().getPlan();
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
Assertions.assertTrue(leftFilter instanceof LogicalFilter);
Assertions.assertTrue(rightFilter instanceof LogicalFilter);
LogicalJoin<Plan, Plan> actualJoin = (LogicalJoin<Plan, Plan>) shouldJoin;
LogicalFilter<Plan> actualLeft = (LogicalFilter<Plan>) leftFilter;
LogicalFilter<Plan> actualRight = (LogicalFilter<Plan>) rightFilter;
Assertions.assertEquals(bothSideEqualTo, actualJoin.getOtherJoinCondition().get());
Assertions.assertEquals(leftSide, actualLeft.getPredicates());
Assertions.assertEquals(rightSide, actualRight.getPredicates());
}
@Test
public void bothSideToOneSide() {
bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, false);
bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, false);
bothSideToOneSide(JoinType.LEFT_SEMI_JOIN, false);
bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, true);
bothSideToOneSide(JoinType.RIGHT_ANTI_JOIN, true);
bothSideToOneSide(JoinType.RIGHT_SEMI_JOIN, true);
}
private void bothSideToOneSide(JoinType joinType, boolean testRight) {
Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.and(pushSide, reserveSide);
Plan left = rStudent;
Plan right = rScore;
if (testRight) {
left = rScore;
right = rStudent;
}
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), left, right);
Plan filter = new LogicalFilter<>(whereCondition, join);
Plan root = new LogicalProject<>(Lists.newArrayList(), filter);
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
if (testRight) {
shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
}
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
Assertions.assertEquals(pushSide, actualFilter.getPredicates());
}
private Memo rewrite(Plan plan) {
return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushPredicatesThroughJoin());
}
}

View File

@ -32,7 +32,7 @@ import org.junit.jupiter.api.Test;
import java.util.List;
public class PushDownExpressionsInHashConditionTest extends TestWithFeService implements PatternMatchSupported {
public class PushdownExpressionsInHashConditionTest extends TestWithFeService implements PatternMatchSupported {
private final List<String> testSql = ImmutableList.of(
"SELECT * FROM T1 JOIN T2 ON T1.ID + 1 = T2.ID + 2 AND T1.ID + 1 > 2",
"SELECT * FROM (SELECT * FROM T1) X JOIN (SELECT * FROM T2) Y ON X.ID + 1 = Y.ID + 2 AND X.ID + 1 > 2",
@ -98,7 +98,7 @@ public class PushDownExpressionsInHashConditionTest extends TestWithFeService im
PlanChecker.from(connectContext)
.analyze("SELECT * FROM T1 JOIN T2 ON T1.ID + 1 = T2.ID + 2 AND T1.ID + 1 > 2")
.applyTopDown(new FindHashConditionForJoin())
.applyTopDown(new PushDownExpressionsInHashCondition())
.applyTopDown(new PushdownExpressionsInHashCondition())
.matches(
logicalProject(
logicalJoin(
@ -119,7 +119,7 @@ public class PushDownExpressionsInHashConditionTest extends TestWithFeService im
.analyze(
"SELECT * FROM (SELECT * FROM T1) X JOIN (SELECT * FROM T2) Y ON X.ID + 1 = Y.ID + 2 AND X.ID + 1 > 2")
.applyTopDown(new FindHashConditionForJoin())
.applyTopDown(new PushDownExpressionsInHashCondition())
.applyTopDown(new PushdownExpressionsInHashCondition())
.matches(
logicalProject(
logicalJoin(
@ -144,7 +144,7 @@ public class PushDownExpressionsInHashConditionTest extends TestWithFeService im
.analyze(
"SELECT * FROM T1 JOIN (SELECT ID, SUM(SCORE) SCORE FROM T2 GROUP BY ID) T ON T1.ID + 1 = T.ID AND T.SCORE = T1.SCORE + 10")
.applyTopDown(new FindHashConditionForJoin())
.applyTopDown(new PushDownExpressionsInHashCondition())
.applyTopDown(new PushdownExpressionsInHashCondition())
.matches(
logicalProject(
logicalJoin(
@ -167,7 +167,7 @@ public class PushDownExpressionsInHashConditionTest extends TestWithFeService im
.analyze(
"SELECT * FROM T1 JOIN (SELECT ID, SUM(SCORE) SCORE FROM T2 GROUP BY ID ORDER BY ID) T ON T1.ID + 1 = T.ID AND T.SCORE = T1.SCORE + 10")
.applyTopDown(new FindHashConditionForJoin())
.applyTopDown(new PushDownExpressionsInHashCondition())
.applyTopDown(new PushdownExpressionsInHashCondition())
.matches(
logicalProject(
logicalJoin(

View File

@ -47,7 +47,7 @@ import org.junit.jupiter.api.Test;
import java.util.List;
public class PushDownPredicateThroughAggregationTest {
public class PushdownFilterThroughAggregationTest {
/**
* origin plan:
@ -187,6 +187,6 @@ public class PushDownPredicateThroughAggregationTest {
}
private Memo rewrite(Plan plan) {
return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushPredicateThroughAggregation());
return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushdownFilterThroughAggregation());
}
}

View File

@ -0,0 +1,218 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
/**
* PushdownFilterThroughJoinTest UT.
*/
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class PushdownFilterThroughJoinTest implements PatternMatchSupported {
private LogicalPlan rStudent;
private LogicalPlan rScore;
/**
* ut before.
*/
@BeforeAll
public final void beforeAll() {
rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student,
ImmutableList.of(""));
rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of(""));
}
public void testLeft(JoinType joinType) {
Expression whereCondition1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression whereCondition2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50));
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2);
LogicalPlan plan = new LogicalPlanBuilder(rStudent)
.hashJoinEmptyOn(rScore, joinType)
.filter(whereCondition)
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(PushdownFilterThroughJoin.INSTANCE)
.matchesFromRoot(
logicalJoin(
logicalFilter(logicalOlapScan())
.when(filter -> filter.getPredicates().equals(whereCondition)),
logicalOlapScan()
)
);
}
public void testRight(JoinType joinType) {
Expression whereCondition1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression whereCondition2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50));
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2);
LogicalPlan plan = new LogicalPlanBuilder(rScore)
.hashJoinEmptyOn(rStudent, joinType)
.filter(whereCondition)
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(PushdownFilterThroughJoin.INSTANCE)
.matchesFromRoot(
logicalJoin(
logicalOlapScan(),
logicalFilter(logicalOlapScan())
.when(filter -> filter.getPredicates().equals(whereCondition))
)
);
}
@Test
public void oneSide() {
testLeft(JoinType.CROSS_JOIN);
testLeft(JoinType.INNER_JOIN);
testLeft(JoinType.LEFT_OUTER_JOIN);
testLeft(JoinType.LEFT_SEMI_JOIN);
testLeft(JoinType.LEFT_ANTI_JOIN);
testRight(JoinType.RIGHT_OUTER_JOIN);
testRight(JoinType.RIGHT_SEMI_JOIN);
testRight(JoinType.RIGHT_ANTI_JOIN);
}
@Test
public void bothSideToBothSide() {
bothSideToBothSide(JoinType.INNER_JOIN);
bothSideToBothSide(JoinType.CROSS_JOIN);
}
private void bothSideToBothSide(JoinType joinType) {
Expression bothSideEqualTo = new EqualTo(new Add(rStudent.getOutput().get(0), Literal.of(1)),
new Subtract(rScore.getOutput().get(0), Literal.of(2)));
Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.and(bothSideEqualTo, leftSide, rightSide);
LogicalPlan plan = new LogicalPlanBuilder(rStudent)
.hashJoinEmptyOn(rScore, joinType)
.filter(whereCondition)
.build();
if (joinType.isInnerJoin()) {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(PushdownFilterThroughJoin.INSTANCE)
.printlnTree()
.matchesFromRoot(
logicalJoin(
logicalFilter(logicalOlapScan())
.when(filter -> filter.getPredicates().equals(leftSide)),
logicalFilter(logicalOlapScan())
.when(filter -> filter.getPredicates().equals(rightSide))
).when(join -> join.getOtherJoinCondition().get().equals(bothSideEqualTo))
);
}
if (joinType.isCrossJoin()) {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(PushdownFilterThroughJoin.INSTANCE)
.printlnTree()
.matchesFromRoot(
logicalFilter(
logicalJoin(
logicalFilter(logicalOlapScan())
.when(filter -> filter.getPredicates().equals(leftSide)),
logicalFilter(logicalOlapScan())
.when(filter -> filter.getPredicates().equals(rightSide))
)
).when(filter -> filter.getPredicates().equals(bothSideEqualTo))
);
}
}
@Test
public void bothSideToOneSide() {
bothSideToLeft(JoinType.LEFT_OUTER_JOIN);
bothSideToLeft(JoinType.LEFT_ANTI_JOIN);
bothSideToLeft(JoinType.LEFT_SEMI_JOIN);
bothSideToRight(JoinType.RIGHT_OUTER_JOIN);
bothSideToRight(JoinType.RIGHT_ANTI_JOIN);
bothSideToRight(JoinType.RIGHT_SEMI_JOIN);
}
private void bothSideToLeft(JoinType joinType) {
Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.and(pushSide, reserveSide);
LogicalPlan plan = new LogicalPlanBuilder(rStudent)
.hashJoinEmptyOn(rScore, joinType)
.filter(whereCondition)
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(PushdownFilterThroughJoin.INSTANCE)
.matchesFromRoot(
logicalFilter(
logicalJoin(
logicalFilter(logicalOlapScan())
.when(filter -> filter.getPredicates().equals(pushSide)),
logicalOlapScan()
)
)
);
}
private void bothSideToRight(JoinType joinType) {
Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.and(pushSide, reserveSide);
LogicalPlan plan = new LogicalPlanBuilder(rScore)
.hashJoinEmptyOn(rStudent, joinType)
.filter(whereCondition)
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(PushdownFilterThroughJoin.INSTANCE)
.matchesFromRoot(
logicalFilter(
logicalJoin(
logicalOlapScan(),
logicalFilter(logicalOlapScan()).when(
filter -> filter.getPredicates().equals(pushSide))
)
)
);
}
}

View File

@ -43,7 +43,7 @@ import org.junit.jupiter.api.TestInstance;
import java.util.Optional;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class PushDownJoinOtherConditionTest {
public class PushdownJoinOtherConditionTest {
private Plan rStudent;
private Plan rScore;
@ -191,6 +191,6 @@ public class PushDownJoinOtherConditionTest {
}
private Memo rewrite(Plan plan) {
return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushDownJoinOtherCondition());
return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushdownJoinOtherCondition());
}
}

View File

@ -79,6 +79,12 @@ public class LogicalPlanBuilder {
return from(join);
}
public LogicalPlanBuilder hashJoinEmptyOn(LogicalPlan right, JoinType joinType) {
LogicalJoin<LogicalPlan, LogicalPlan> join = new LogicalJoin<>(joinType, new ArrayList<>(),
Optional.empty(), this.plan, right);
return from(join);
}
public LogicalPlanBuilder limit(long limit, long offset) {
LogicalLimit<LogicalPlan> limitPlan = new LogicalLimit<>(limit, offset, this.plan);
return from(limitPlan);

View File

@ -107,6 +107,7 @@ public class PlanChecker {
/**
* apply a top down rewrite rule if you not care the ruleId
*
* @param patternMatcher the rule dsl, such as: logicalOlapScan().then(olapScan -> olapScan)
* @return this checker, for call chaining of follow-up check
*/
@ -129,6 +130,7 @@ public class PlanChecker {
/**
* apply a bottom up rewrite rule if you not care the ruleId
*
* @param patternMatcher the rule dsl, such as: logicalOlapScan().then(olapScan -> olapScan)
* @return this checker, for call chaining of follow-up check
*/
@ -163,7 +165,8 @@ public class PlanChecker {
PhysicalPlan current = null;
loop:
for (Rule rule : RuleSet.IMPLEMENTATION_RULES) {
GroupExpressionMatching matching = new GroupExpressionMatching(rule.getPattern(), group.getLogicalExpression());
GroupExpressionMatching matching = new GroupExpressionMatching(rule.getPattern(),
group.getLogicalExpression());
for (Plan plan : matching) {
Plan after = rule.transform(plan, cascadesContext).get(0);
if (after instanceof PhysicalPlan) {
@ -346,4 +349,10 @@ public class PlanChecker {
public Plan getPlan() {
return cascadesContext.getMemo().copyOut();
}
public PlanChecker printlnTree() {
System.out.println(cascadesContext.getMemo().copyOut().treeString());
return this;
}
}