[feature](Nereids) add predicates push down on all join type (#12571)

* [feature](Nereids) add predicates push down on all join type
This commit is contained in:
morrySnow
2022-09-15 15:18:42 +08:00
committed by GitHub
parent 5b6d48ed5b
commit 858e8234d7
12 changed files with 594 additions and 292 deletions

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
import org.apache.doris.nereids.rules.mv.SelectRollup;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
@ -27,14 +28,8 @@ import org.apache.doris.nereids.rules.rewrite.logical.EliminateFilter;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit;
import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin;
import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate;
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition;
import org.apache.doris.nereids.rules.rewrite.logical.PushPredicateThroughJoin;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownProjectThroughLimit;
import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
import com.google.common.collect.ImmutableList;
@ -64,15 +59,9 @@ public class RewriteJob extends BatchRulesJob {
.add(topDownBatch(ImmutableList.of(new ExpressionNormalization())))
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
.add(topDownBatch(ImmutableList.of(new ReorderJoin())))
.add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin())))
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
.add(topDownBatch(ImmutableList.of(new ColumnPruning())))
.add(topDownBatch(ImmutableList.of(new PushPredicateThroughJoin(),
new PushdownProjectThroughLimit(),
new PushdownFilterThroughProject(),
new MergeConsecutiveProjects(),
new MergeConsecutiveFilters(),
new MergeConsecutiveLimits())))
.add(topDownBatch(RuleSet.PUSH_DOWN_JOIN_CONDITION_RULES))
.add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin())))
.add(topDownBatch(ImmutableList.of(new AggregateDisassemble())))
.add(topDownBatch(ImmutableList.of(new LimitPushDown())))
.add(topDownBatch(ImmutableList.of(new EliminateLimit())))

View File

@ -33,9 +33,13 @@ import org.apache.doris.nereids.rules.implementation.LogicalOneRowRelationToPhys
import org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalProject;
import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort;
import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
import org.apache.doris.nereids.rules.rewrite.logical.PushDownJoinOtherCondition;
import org.apache.doris.nereids.rules.rewrite.logical.PushPredicatesThroughJoin;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownProjectThroughLimit;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
@ -55,9 +59,14 @@ public class RuleSet {
.add(new MergeConsecutiveProjects())
.build();
public static final List<Rule> REWRITE_RULES = planRuleFactories()
.add(new AggregateDisassemble())
.build();
public static final List<RuleFactory> PUSH_DOWN_JOIN_CONDITION_RULES = ImmutableList.of(
new PushDownJoinOtherCondition(),
new PushPredicatesThroughJoin(),
new PushdownProjectThroughLimit(),
new PushdownFilterThroughProject(),
new MergeConsecutiveProjects(),
new MergeConsecutiveFilters(),
new MergeConsecutiveLimits());
public static final List<Rule> IMPLEMENTATION_RULES = planRuleFactories()
.add(new LogicalAggToPhysicalHashAgg())

View File

@ -79,6 +79,7 @@ public enum RuleType {
EXISTS_APPLY_TO_JOIN(RuleTypeClass.REWRITE),
// predicate push down rules
PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSH_DOWN_JOIN_OTHER_CONDITION(RuleTypeClass.REWRITE),
PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION(RuleTypeClass.REWRITE),
// column prune rules,
COLUMN_PRUNE_AGGREGATION_CHILD(RuleTypeClass.REWRITE),

View File

@ -0,0 +1,99 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.Set;
/**
* Push the other join conditions in LogicalJoin to children.
*/
public class PushDownJoinOtherCondition extends OneRewriteRuleFactory {
private static final ImmutableList<JoinType> PUSH_DOWN_LEFT_VALID_TYPE = ImmutableList.of(
JoinType.INNER_JOIN,
JoinType.LEFT_SEMI_JOIN,
JoinType.RIGHT_OUTER_JOIN,
JoinType.RIGHT_ANTI_JOIN,
JoinType.RIGHT_SEMI_JOIN,
JoinType.CROSS_JOIN
);
private static final ImmutableList<JoinType> PUSH_DOWN_RIGHT_VALID_TYPE = ImmutableList.of(
JoinType.INNER_JOIN,
JoinType.LEFT_OUTER_JOIN,
JoinType.LEFT_ANTI_JOIN,
JoinType.LEFT_SEMI_JOIN,
JoinType.RIGHT_SEMI_JOIN,
JoinType.CROSS_JOIN
);
@Override
public Rule build() {
return logicalJoin().then(join -> {
if (!join.getOtherJoinCondition().isPresent()) {
return null;
}
List<Expression> otherConjuncts = ExpressionUtils.extractConjunction(join.getOtherJoinCondition().get());
List<Expression> leftConjuncts = Lists.newArrayList();
List<Expression> rightConjuncts = Lists.newArrayList();
for (Expression otherConjunct : otherConjuncts) {
if (PUSH_DOWN_LEFT_VALID_TYPE.contains(join.getJoinType())
&& allCoveredBy(otherConjunct, join.left().getOutputSet())) {
leftConjuncts.add(otherConjunct);
}
if (PUSH_DOWN_RIGHT_VALID_TYPE.contains(join.getJoinType())
&& allCoveredBy(otherConjunct, join.right().getOutputSet())) {
rightConjuncts.add(otherConjunct);
}
}
if (leftConjuncts.isEmpty() && rightConjuncts.isEmpty()) {
return null;
}
otherConjuncts.removeAll(leftConjuncts);
otherConjuncts.removeAll(rightConjuncts);
Plan left = PlanUtils.filterOrSelf(leftConjuncts, join.left());
Plan right = PlanUtils.filterOrSelf(rightConjuncts, join.right());
return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
ExpressionUtils.optionalAnd(otherConjuncts), left, right);
}).toRule(RuleType.PUSH_DOWN_JOIN_OTHER_CONDITION);
}
private boolean allCoveredBy(Expression predicate, Set<Slot> inputSlotSet) {
return inputSlotSet.containsAll(predicate.getInputSlots());
}
}

View File

@ -21,15 +21,17 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.List;
@ -37,17 +39,39 @@ import java.util.Objects;
import java.util.Set;
/**
* Push the predicate in the LogicalFilter or LogicalJoin to the join children.
* todo: Now, only support eq on condition for inner join, support other case later
* Push the predicate in the LogicalFilter to the join children.
*/
public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
public class PushPredicatesThroughJoin extends OneRewriteRuleFactory {
private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_LEFT = ImmutableList.of(
JoinType.INNER_JOIN,
JoinType.LEFT_OUTER_JOIN,
JoinType.LEFT_SEMI_JOIN,
JoinType.LEFT_ANTI_JOIN,
JoinType.CROSS_JOIN
);
private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_RIGHT = ImmutableList.of(
JoinType.INNER_JOIN,
JoinType.RIGHT_OUTER_JOIN,
JoinType.RIGHT_SEMI_JOIN,
JoinType.RIGHT_ANTI_JOIN,
JoinType.CROSS_JOIN
);
private static final ImmutableList<JoinType> COULD_PUSH_EQUAL_TO = ImmutableList.of(
JoinType.INNER_JOIN
);
/*
* For example:
* select a.k1,b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5 where a.k1 > 1 and b.k1 > 2
* select a.k1, b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5
* where a.k1 > 1 and b.k1 > 2 and a.k2 > b.k2
*
* Logical plan tree:
* project
* |
* filter (a.k1 > 1 and b.k1 > 2)
* filter (a.k1 > 1 and b.k1 > 2 and a.k2 > b.k2)
* |
* join (a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5)
* / \
@ -55,69 +79,72 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
* transformed:
* project
* |
* join (a.k1 = b.k1)
* filter(a.k2 > b.k2)
* |
* join (otherConditions: a.k1 = b.k1)
* / \
* filter(a.k1 > 1 and a.k2 > 2 ) filter(b.k1 > 2 and b.k2 > 5)
* filter(a.k1 > 1 and a.k2 > 2) filter(b.k1 > 2 and b.k2 > 5)
* | |
* scan scan
*/
@Override
public Rule build() {
return logicalFilter(innerLogicalJoin()).then(filter -> {
return logicalFilter(logicalJoin()).then(filter -> {
LogicalJoin<GroupPlan, GroupPlan> join = filter.child();
Expression wherePredicates = filter.getPredicates();
Expression onPredicates = join.getOtherJoinCondition().orElse(BooleanLiteral.TRUE);
Expression filterPredicates = filter.getPredicates();
List<Expression> otherConditions = Lists.newArrayList();
List<Expression> eqConditions = Lists.newArrayList();
List<Expression> filterConditions = Lists.newArrayList();
List<Expression> joinConditions = Lists.newArrayList();
Set<Slot> leftInput = join.left().getOutputSet();
Set<Slot> rightInput = join.right().getOutputSet();
ExpressionUtils.extractConjunction(ExpressionUtils.and(onPredicates, wherePredicates))
ExpressionUtils.extractConjunction(filterPredicates)
.forEach(predicate -> {
if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))) {
eqConditions.add(predicate);
if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))
&& COULD_PUSH_EQUAL_TO.contains(join.getJoinType())) {
joinConditions.add(predicate);
} else {
otherConditions.add(predicate);
filterConditions.add(predicate);
}
});
List<Expression> leftPredicates = Lists.newArrayList();
List<Expression> rightPredicates = Lists.newArrayList();
for (Expression p : otherConditions) {
for (Expression p : filterConditions) {
Set<Slot> slots = p.getInputSlots();
if (slots.isEmpty()) {
leftPredicates.add(p);
rightPredicates.add(p);
continue;
}
if (leftInput.containsAll(slots)) {
if (leftInput.containsAll(slots) && COULD_PUSH_THROUGH_LEFT.contains(join.getJoinType())) {
leftPredicates.add(p);
}
if (rightInput.containsAll(slots)) {
if (rightInput.containsAll(slots) && COULD_PUSH_THROUGH_RIGHT.contains(join.getJoinType())) {
rightPredicates.add(p);
}
}
otherConditions.removeAll(leftPredicates);
otherConditions.removeAll(rightPredicates);
otherConditions.addAll(eqConditions);
filterConditions.removeAll(leftPredicates);
filterConditions.removeAll(rightPredicates);
join.getOtherJoinCondition().map(joinConditions::add);
return pushDownPredicate(join, otherConditions, leftPredicates, rightPredicates);
return PlanUtils.filterOrSelf(filterConditions,
pushDownPredicate(join, joinConditions, leftPredicates, rightPredicates));
}).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_JOIN);
}
private Plan pushDownPredicate(LogicalJoin<GroupPlan, GroupPlan> joinPlan,
private Plan pushDownPredicate(LogicalJoin<GroupPlan, GroupPlan> join,
List<Expression> joinConditions, List<Expression> leftPredicates, List<Expression> rightPredicates) {
// todo expr should optimize again using expr rewrite
Plan leftPlan = PlanUtils.filterOrSelf(leftPredicates, joinPlan.left());
Plan rightPlan = PlanUtils.filterOrSelf(rightPredicates, joinPlan.right());
Plan leftPlan = PlanUtils.filterOrSelf(leftPredicates, join.left());
Plan rightPlan = PlanUtils.filterOrSelf(rightPredicates, join.right());
return new LogicalJoin<>(joinPlan.getJoinType(), joinPlan.getHashJoinConjuncts(),
return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
ExpressionUtils.optionalAnd(joinConditions), leftPlan, rightPlan);
}
@ -128,13 +155,13 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
ComparisonPredicate comparison = (ComparisonPredicate) predicate;
Set<Slot> leftSlots = comparison.left().getInputSlots();
Set<Slot> rightSlots = comparison.right().getInputSlots();
if (!(leftSlots.size() >= 1 && rightSlots.size() >= 1)) {
if (!(comparison instanceof EqualTo)) {
return null;
}
Set<Slot> leftSlots = comparison.left().getInputSlots();
Set<Slot> rightSlots = comparison.right().getInputSlots();
if ((leftOutputs.containsAll(leftSlots) && rightOutputs.containsAll(rightSlots))
|| (leftOutputs.containsAll(rightSlots) && rightOutputs.containsAll(leftSlots))) {
return predicate;

View File

@ -57,8 +57,8 @@ import java.util.Optional;
class FindHashConditionForJoinTest {
@Test
public void testFindHashCondition() {
Plan student = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.student, ImmutableList.of(""));
Plan score = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.score, ImmutableList.of(""));
Plan student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of(""));
Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of(""));
Slot studentId = student.getOutput().get(0);
Slot gender = student.getOutput().get(1);

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
@ -48,8 +49,8 @@ import java.util.function.Function;
import java.util.stream.Collectors;
class LimitPushDownTest extends TestWithFeService implements PatternMatchSupported {
private Plan scanScore = new LogicalOlapScan(PlanConstructor.score);
private Plan scanStudent = new LogicalOlapScan(PlanConstructor.student);
private Plan scanScore = new LogicalOlapScan(new RelationId(0), PlanConstructor.score);
private Plan scanStudent = new LogicalOlapScan(new RelationId(1), PlanConstructor.student);
@Override
protected void runBeforeAll() throws Exception {
@ -213,8 +214,8 @@ class LimitPushDownTest extends TestWithFeService implements PatternMatchSupport
joinType,
joinConditions,
Optional.empty(),
new LogicalOlapScan(PlanConstructor.score),
new LogicalOlapScan(PlanConstructor.student)
new LogicalOlapScan(new RelationId(0), PlanConstructor.score),
new LogicalOlapScan(new RelationId(1), PlanConstructor.student)
);
if (hasProject) {

View File

@ -88,7 +88,7 @@ class PruneOlapScanPartitionTest {
olapTable.getName();
result = "tbl";
}};
LogicalOlapScan scan = new LogicalOlapScan(PlanConstructor.getNextId(), olapTable);
LogicalOlapScan scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
SlotReference slotRef = new SlotReference("col1", IntegerType.INSTANCE);
Expression expression = new LessThan(slotRef, new IntegerLiteral(4));
LogicalFilter<LogicalOlapScan> filter = new LogicalFilter<>(expression, scan);
@ -104,7 +104,7 @@ class PruneOlapScanPartitionTest {
Expression greaterThan6 = new GreaterThan(slotRef, new IntegerLiteral(6));
Or lessThan0OrGreaterThan6 = new Or(lessThan0, greaterThan6);
filter = new LogicalFilter<>(lessThan0OrGreaterThan6, scan);
scan = new LogicalOlapScan(PlanConstructor.getNextId(), olapTable);
scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
cascadesContext = MemoTestUtils.createCascadesContext(filter);
rules = Lists.newArrayList(new PruneOlapScanPartition().build());
cascadesContext.topDownRewrite(rules);
@ -118,7 +118,7 @@ class PruneOlapScanPartitionTest {
Expression lessThanEqual5 =
new LessThanEqual(slotRef, new IntegerLiteral(5));
And greaterThanEqual0AndLessThanEqual5 = new And(greaterThanEqual0, lessThanEqual5);
scan = new LogicalOlapScan(PlanConstructor.getNextId(), olapTable);
scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
filter = new LogicalFilter<>(greaterThanEqual0AndLessThanEqual5, scan);
cascadesContext = MemoTestUtils.createCascadesContext(filter);
rules = Lists.newArrayList(new PruneOlapScanPartition().build());
@ -153,7 +153,7 @@ class PruneOlapScanPartitionTest {
olapTable.getName();
result = "tbl";
}};
LogicalOlapScan scan = new LogicalOlapScan(PlanConstructor.getNextId(), olapTable);
LogicalOlapScan scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
Expression left = new LessThan(new SlotReference("col1", IntegerType.INSTANCE), new IntegerLiteral(4));
Expression right = new GreaterThan(new SlotReference("col2", IntegerType.INSTANCE), new IntegerLiteral(11));
CompoundPredicate and = new And(left, right);

View File

@ -0,0 +1,196 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.PlanRewriter;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import java.util.Optional;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class PushDownJoinOtherConditionTest {
private Plan rStudent;
private Plan rScore;
/**
* ut before.
*/
@BeforeAll
public final void beforeAll() {
rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of(""));
rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of(""));
}
@Test
public void oneSide() {
oneSide(JoinType.CROSS_JOIN, false);
oneSide(JoinType.INNER_JOIN, false);
oneSide(JoinType.LEFT_OUTER_JOIN, true);
oneSide(JoinType.LEFT_SEMI_JOIN, true);
oneSide(JoinType.LEFT_ANTI_JOIN, true);
oneSide(JoinType.RIGHT_OUTER_JOIN, false);
oneSide(JoinType.RIGHT_SEMI_JOIN, false);
oneSide(JoinType.RIGHT_ANTI_JOIN, false);
}
private void oneSide(JoinType joinType, boolean testRight) {
Expression pushSide1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression pushSide2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50));
Expression condition = ExpressionUtils.and(pushSide1, pushSide2);
Plan left = rStudent;
Plan right = rScore;
if (testRight) {
left = rScore;
right = rStudent;
}
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), left, right);
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(1).getLogicalExpression().getPlan();
if (testRight) {
shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(1).getLogicalExpression().getPlan();
shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
}
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
Assertions.assertEquals(condition, actualFilter.getPredicates());
}
@Test
public void bothSideToBothSide() {
bothSideToBothSide(JoinType.CROSS_JOIN);
bothSideToBothSide(JoinType.INNER_JOIN);
bothSideToBothSide(JoinType.LEFT_SEMI_JOIN);
bothSideToBothSide(JoinType.RIGHT_SEMI_JOIN);
}
private void bothSideToBothSide(JoinType joinType) {
Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression condition = ExpressionUtils.and(leftSide, rightSide);
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), rStudent, rScore);
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan leftFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
Plan rightFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(1).getLogicalExpression().getPlan();
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
Assertions.assertTrue(leftFilter instanceof LogicalFilter);
Assertions.assertTrue(rightFilter instanceof LogicalFilter);
LogicalFilter<Plan> actualLeft = (LogicalFilter<Plan>) leftFilter;
LogicalFilter<Plan> actualRight = (LogicalFilter<Plan>) rightFilter;
Assertions.assertEquals(leftSide, actualLeft.getPredicates());
Assertions.assertEquals(rightSide, actualRight.getPredicates());
}
@Test
public void bothSideToOneSide() {
bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, true);
bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, true);
bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, false);
bothSideToOneSide(JoinType.RIGHT_ANTI_JOIN, false);
}
private void bothSideToOneSide(JoinType joinType, boolean testRight) {
Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression condition = ExpressionUtils.and(pushSide, reserveSide);
Plan left = rStudent;
Plan right = rScore;
if (testRight) {
left = rScore;
right = rStudent;
}
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), left, right);
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
Plan shouldJoin = rootGroup.getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
Plan shouldFilter = rootGroup.getLogicalExpression()
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan shouldScan = rootGroup.getLogicalExpression()
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
if (testRight) {
shouldFilter = rootGroup.getLogicalExpression()
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
shouldScan = rootGroup.getLogicalExpression()
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
}
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
Assertions.assertEquals(pushSide, actualFilter.getPredicates());
}
private Memo rewrite(Plan plan) {
return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushDownJoinOtherCondition());
}
}

View File

@ -1,228 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Between;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.PlanRewriter;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import java.util.ArrayList;
import java.util.Optional;
/**
* plan rewrite ut.
*/
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class PushDownPredicateTest {
private Plan rStudent;
private Plan rScore;
private Plan rCourse;
/**
* ut before.
*/
@BeforeAll
public final void beforeAll() {
rStudent = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.student, ImmutableList.of(""));
rScore = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.score, ImmutableList.of(""));
rCourse = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.course, ImmutableList.of(""));
}
@Test
public void pushDownPredicateIntoScanTest1() {
// select id,name,grade from student join score on student.id = score.sid and student.id > 1
// and score.cid > 2 where student.age > 18 and score.grade > 60
Expression onCondition1 = new EqualTo(rStudent.getOutput().get(0), rScore.getOutput().get(0));
Expression onCondition2 = new GreaterThan(rStudent.getOutput().get(0), Literal.of(1));
Expression onCondition3 = new GreaterThan(rScore.getOutput().get(0), Literal.of(2));
Expression onCondition = ExpressionUtils.and(onCondition1, onCondition2, onCondition3);
Expression whereCondition1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression whereCondition2 = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2);
Plan join = new LogicalJoin(JoinType.INNER_JOIN, new ArrayList<>(), Optional.of(onCondition), rStudent, rScore);
Plan filter = new LogicalFilter(whereCondition, join);
Plan root = new LogicalProject(
Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2)),
filter
);
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
Plan op1 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
.getPlan();
Plan op3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(1).getLogicalExpression()
.getPlan();
Assertions.assertTrue(op1 instanceof LogicalJoin);
Assertions.assertTrue(op2 instanceof LogicalFilter);
Assertions.assertTrue(op3 instanceof LogicalFilter);
LogicalJoin join1 = (LogicalJoin) op1;
LogicalFilter filter1 = (LogicalFilter) op2;
LogicalFilter filter2 = (LogicalFilter) op3;
Assertions.assertEquals(onCondition1, join1.getOtherJoinCondition().get());
Assertions.assertEquals(ExpressionUtils.and(onCondition2, whereCondition1), filter1.getPredicates());
Assertions.assertEquals(ExpressionUtils.and(onCondition3,
new GreaterThan(rScore.getOutput().get(2), new Cast(Literal.of(60), DoubleType.INSTANCE))),
filter2.getPredicates());
}
@Test
public void pushDownPredicateIntoScanTest3() {
//select id,name,grade from student left join score on student.id + 1 = score.sid - 2
//where student.age > 18 and score.grade > 60
Expression whereCondition1 = new EqualTo(new Add(rStudent.getOutput().get(0), Literal.of(1)),
new Subtract(rScore.getOutput().get(0), Literal.of(2)));
Expression whereCondition2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression whereCondition3 = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2, whereCondition3);
Plan join = new LogicalJoin(JoinType.INNER_JOIN, new ArrayList<>(), Optional.empty(), rStudent, rScore);
Plan filter = new LogicalFilter(whereCondition, join);
Plan root = new LogicalProject(
Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2)),
filter
);
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
Plan op1 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
.getPlan();
Plan op3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(1).getLogicalExpression()
.getPlan();
Assertions.assertTrue(op1 instanceof LogicalJoin);
Assertions.assertTrue(op2 instanceof LogicalFilter);
Assertions.assertTrue(op3 instanceof LogicalFilter);
LogicalJoin join1 = (LogicalJoin) op1;
LogicalFilter filter1 = (LogicalFilter) op2;
LogicalFilter filter2 = (LogicalFilter) op3;
Assertions.assertEquals(whereCondition1, join1.getOtherJoinCondition().get());
Assertions.assertEquals(whereCondition2, filter1.getPredicates());
Assertions.assertEquals(
new GreaterThan(rScore.getOutput().get(2), new Cast(Literal.of(60), DoubleType.INSTANCE)),
filter2.getPredicates());
}
@Test
public void pushDownPredicateIntoScanTest4() {
/*
select
student.name,
course.name,
score.grade
from student,score,course
where on student.id = score.sid and student.age between 18 and 20 and score.grade > 60 and student.id = score.sid
*/
// student.id = score.sid
Expression whereCondition1 = new EqualTo(rStudent.getOutput().get(0), rScore.getOutput().get(0));
// score.cid = course.cid
Expression whereCondition2 = new EqualTo(rScore.getOutput().get(1), rCourse.getOutput().get(0));
// student.age between 18 and 20
Expression whereCondition3 = new Between(rStudent.getOutput().get(2), Literal.of(18), Literal.of(20));
// student.age >= 18 and student.age <= 20
Expression whereCondition3result = new And(
new GreaterThanEqual(rStudent.getOutput().get(2), new Cast(Literal.of(18), StringType.INSTANCE)),
new LessThanEqual(rStudent.getOutput().get(2), new Cast(Literal.of(20), StringType.INSTANCE)));
// score.grade > 60
Expression whereCondition4 = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2, whereCondition3,
whereCondition4);
Plan join = new LogicalJoin(JoinType.INNER_JOIN, ImmutableList.of(), Optional.empty(), rStudent, rScore);
Plan join1 = new LogicalJoin(JoinType.INNER_JOIN, ImmutableList.of(), Optional.empty(), join, rCourse);
Plan filter = new LogicalFilter(whereCondition, join1);
Plan root = new LogicalProject(
Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2)),
filter
);
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
Plan join2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan join3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
.getPlan();
Plan op1 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
Plan op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
.child(1).getLogicalExpression().getPlan();
Assertions.assertTrue(join2 instanceof LogicalJoin);
Assertions.assertTrue(join3 instanceof LogicalJoin);
Assertions.assertTrue(op1 instanceof LogicalFilter);
Assertions.assertTrue(op2 instanceof LogicalFilter);
Assertions.assertEquals(whereCondition2, ((LogicalJoin) join2).getOtherJoinCondition().get());
Assertions.assertEquals(whereCondition1, ((LogicalJoin) join3).getOtherJoinCondition().get());
Assertions.assertEquals(whereCondition3result.toSql(), ((LogicalFilter) op1).getPredicates().toSql());
Assertions.assertEquals(
new GreaterThan(rScore.getOutput().get(2), new Cast(Literal.of(60), DoubleType.INSTANCE)),
((LogicalFilter) op2).getPredicates());
}
private Memo rewrite(Plan plan) {
Plan normalizedPlan = PlanRewriter.topDownRewrite(plan, new ConnectContext(), new ExpressionNormalization());
return PlanRewriter.topDownRewriteMemo(normalizedPlan, new ConnectContext(), new PushPredicateThroughJoin());
}
}

View File

@ -0,0 +1,208 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.PlanRewriter;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import java.util.Optional;
/**
* plan rewrite ut.
*/
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class PushPredicateThroughJoinTest {
private Plan rStudent;
private Plan rScore;
/**
* ut before.
*/
@BeforeAll
public final void beforeAll() {
rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of(""));
rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of(""));
}
@Test
public void oneSide() {
oneSide(JoinType.CROSS_JOIN, false);
oneSide(JoinType.INNER_JOIN, false);
oneSide(JoinType.LEFT_OUTER_JOIN, false);
oneSide(JoinType.LEFT_SEMI_JOIN, false);
oneSide(JoinType.LEFT_ANTI_JOIN, false);
oneSide(JoinType.RIGHT_OUTER_JOIN, true);
oneSide(JoinType.RIGHT_SEMI_JOIN, true);
oneSide(JoinType.RIGHT_ANTI_JOIN, true);
}
private void oneSide(JoinType joinType, boolean testRight) {
Expression whereCondition1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression whereCondition2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50));
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2);
Plan left = rStudent;
Plan right = rScore;
if (testRight) {
left = rScore;
right = rStudent;
}
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), left, right);
Plan filter = new LogicalFilter<>(whereCondition, join);
Plan root = new LogicalProject<>(Lists.newArrayList(), filter);
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(1).getLogicalExpression().getPlan();
if (testRight) {
shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(1).getLogicalExpression().getPlan();
shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
}
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
Assertions.assertEquals(whereCondition, actualFilter.getPredicates());
}
@Test
public void bothSideToBothSide() {
bothSideToBothSide(JoinType.INNER_JOIN);
}
private void bothSideToBothSide(JoinType joinType) {
Expression bothSideEqualTo = new EqualTo(new Add(rStudent.getOutput().get(0), Literal.of(1)),
new Subtract(rScore.getOutput().get(0), Literal.of(2)));
Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.and(bothSideEqualTo, leftSide, rightSide);
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), rStudent, rScore);
Plan filter = new LogicalFilter<>(whereCondition, join);
Plan root = new LogicalProject<>(Lists.newArrayList(), filter);
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan leftFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
Plan rightFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(1).getLogicalExpression().getPlan();
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
Assertions.assertTrue(leftFilter instanceof LogicalFilter);
Assertions.assertTrue(rightFilter instanceof LogicalFilter);
LogicalJoin<Plan, Plan> actualJoin = (LogicalJoin<Plan, Plan>) shouldJoin;
LogicalFilter<Plan> actualLeft = (LogicalFilter<Plan>) leftFilter;
LogicalFilter<Plan> actualRight = (LogicalFilter<Plan>) rightFilter;
Assertions.assertEquals(bothSideEqualTo, actualJoin.getOtherJoinCondition().get());
Assertions.assertEquals(leftSide, actualLeft.getPredicates());
Assertions.assertEquals(rightSide, actualRight.getPredicates());
}
@Test
public void bothSideToOneSide() {
bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, false);
bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, false);
bothSideToOneSide(JoinType.LEFT_SEMI_JOIN, false);
bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, true);
bothSideToOneSide(JoinType.RIGHT_ANTI_JOIN, true);
bothSideToOneSide(JoinType.RIGHT_SEMI_JOIN, true);
}
private void bothSideToOneSide(JoinType joinType, boolean testRight) {
Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.and(pushSide, reserveSide);
Plan left = rStudent;
Plan right = rScore;
if (testRight) {
left = rScore;
right = rStudent;
}
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), left, right);
Plan filter = new LogicalFilter<>(whereCondition, join);
Plan root = new LogicalProject<>(Lists.newArrayList(), filter);
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
if (testRight) {
shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
}
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
Assertions.assertEquals(pushSide, actualFilter.getPredicates());
}
private Memo rewrite(Plan plan) {
return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushPredicatesThroughJoin());
}
}

View File

@ -37,7 +37,7 @@ public class PlanConstructor {
public static OlapTable student;
public static OlapTable score;
public static OlapTable course;
private static final IdGenerator<RelationId> GENERATOR = RelationId.createGenerator();
private static final IdGenerator<RelationId> RELATION_ID_GENERATOR = RelationId.createGenerator();
static {
student = new OlapTable(0L, "student",
@ -102,14 +102,14 @@ public class PlanConstructor {
// With OlapTable.
// Warning: equals() of Table depends on tableId.
public static LogicalOlapScan newLogicalOlapScan(long tableId, String tableName, int hashColumn) {
return new LogicalOlapScan(GENERATOR.getNextId(), newOlapTable(tableId, tableName, hashColumn), ImmutableList.of("db"));
return new LogicalOlapScan(RELATION_ID_GENERATOR.getNextId(), newOlapTable(tableId, tableName, hashColumn), ImmutableList.of("db"));
}
public static LogicalOlapScan newLogicalOlapScanWithSameId(long tableId, String tableName, int hashColumn) {
return new LogicalOlapScan(RelationId.createGenerator().getNextId(), newOlapTable(tableId, tableName, hashColumn), ImmutableList.of("db"));
}
public static RelationId getNextId() {
return GENERATOR.getNextId();
public static RelationId getNextRelationId() {
return RELATION_ID_GENERATOR.getNextId();
}
}