[feature](Nereids) add predicates push down on all join type (#12571)
* [feature](Nereids) add predicates push down on all join type
This commit is contained in:
@ -19,6 +19,7 @@ package org.apache.doris.nereids.jobs.batch;
|
||||
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.jobs.Job;
|
||||
import org.apache.doris.nereids.rules.RuleSet;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
|
||||
import org.apache.doris.nereids.rules.mv.SelectRollup;
|
||||
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
|
||||
@ -27,14 +28,8 @@ import org.apache.doris.nereids.rules.rewrite.logical.EliminateFilter;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PushPredicateThroughJoin;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PushdownProjectThroughLimit;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -64,15 +59,9 @@ public class RewriteJob extends BatchRulesJob {
|
||||
.add(topDownBatch(ImmutableList.of(new ExpressionNormalization())))
|
||||
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
|
||||
.add(topDownBatch(ImmutableList.of(new ReorderJoin())))
|
||||
.add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin())))
|
||||
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
|
||||
.add(topDownBatch(ImmutableList.of(new ColumnPruning())))
|
||||
.add(topDownBatch(ImmutableList.of(new PushPredicateThroughJoin(),
|
||||
new PushdownProjectThroughLimit(),
|
||||
new PushdownFilterThroughProject(),
|
||||
new MergeConsecutiveProjects(),
|
||||
new MergeConsecutiveFilters(),
|
||||
new MergeConsecutiveLimits())))
|
||||
.add(topDownBatch(RuleSet.PUSH_DOWN_JOIN_CONDITION_RULES))
|
||||
.add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin())))
|
||||
.add(topDownBatch(ImmutableList.of(new AggregateDisassemble())))
|
||||
.add(topDownBatch(ImmutableList.of(new LimitPushDown())))
|
||||
.add(topDownBatch(ImmutableList.of(new EliminateLimit())))
|
||||
|
||||
@ -33,9 +33,13 @@ import org.apache.doris.nereids.rules.implementation.LogicalOneRowRelationToPhys
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalProject;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
|
||||
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PushDownJoinOtherCondition;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PushPredicatesThroughJoin;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PushdownProjectThroughLimit;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
@ -55,9 +59,14 @@ public class RuleSet {
|
||||
.add(new MergeConsecutiveProjects())
|
||||
.build();
|
||||
|
||||
public static final List<Rule> REWRITE_RULES = planRuleFactories()
|
||||
.add(new AggregateDisassemble())
|
||||
.build();
|
||||
public static final List<RuleFactory> PUSH_DOWN_JOIN_CONDITION_RULES = ImmutableList.of(
|
||||
new PushDownJoinOtherCondition(),
|
||||
new PushPredicatesThroughJoin(),
|
||||
new PushdownProjectThroughLimit(),
|
||||
new PushdownFilterThroughProject(),
|
||||
new MergeConsecutiveProjects(),
|
||||
new MergeConsecutiveFilters(),
|
||||
new MergeConsecutiveLimits());
|
||||
|
||||
public static final List<Rule> IMPLEMENTATION_RULES = planRuleFactories()
|
||||
.add(new LogicalAggToPhysicalHashAgg())
|
||||
|
||||
@ -79,6 +79,7 @@ public enum RuleType {
|
||||
EXISTS_APPLY_TO_JOIN(RuleTypeClass.REWRITE),
|
||||
// predicate push down rules
|
||||
PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE),
|
||||
PUSH_DOWN_JOIN_OTHER_CONDITION(RuleTypeClass.REWRITE),
|
||||
PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION(RuleTypeClass.REWRITE),
|
||||
// column prune rules,
|
||||
COLUMN_PRUNE_AGGREGATION_CHILD(RuleTypeClass.REWRITE),
|
||||
|
||||
@ -0,0 +1,99 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.PlanUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Push the other join conditions in LogicalJoin to children.
|
||||
*/
|
||||
public class PushDownJoinOtherCondition extends OneRewriteRuleFactory {
|
||||
private static final ImmutableList<JoinType> PUSH_DOWN_LEFT_VALID_TYPE = ImmutableList.of(
|
||||
JoinType.INNER_JOIN,
|
||||
JoinType.LEFT_SEMI_JOIN,
|
||||
JoinType.RIGHT_OUTER_JOIN,
|
||||
JoinType.RIGHT_ANTI_JOIN,
|
||||
JoinType.RIGHT_SEMI_JOIN,
|
||||
JoinType.CROSS_JOIN
|
||||
);
|
||||
|
||||
private static final ImmutableList<JoinType> PUSH_DOWN_RIGHT_VALID_TYPE = ImmutableList.of(
|
||||
JoinType.INNER_JOIN,
|
||||
JoinType.LEFT_OUTER_JOIN,
|
||||
JoinType.LEFT_ANTI_JOIN,
|
||||
JoinType.LEFT_SEMI_JOIN,
|
||||
JoinType.RIGHT_SEMI_JOIN,
|
||||
JoinType.CROSS_JOIN
|
||||
);
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalJoin().then(join -> {
|
||||
if (!join.getOtherJoinCondition().isPresent()) {
|
||||
return null;
|
||||
}
|
||||
List<Expression> otherConjuncts = ExpressionUtils.extractConjunction(join.getOtherJoinCondition().get());
|
||||
List<Expression> leftConjuncts = Lists.newArrayList();
|
||||
List<Expression> rightConjuncts = Lists.newArrayList();
|
||||
|
||||
for (Expression otherConjunct : otherConjuncts) {
|
||||
if (PUSH_DOWN_LEFT_VALID_TYPE.contains(join.getJoinType())
|
||||
&& allCoveredBy(otherConjunct, join.left().getOutputSet())) {
|
||||
leftConjuncts.add(otherConjunct);
|
||||
}
|
||||
if (PUSH_DOWN_RIGHT_VALID_TYPE.contains(join.getJoinType())
|
||||
&& allCoveredBy(otherConjunct, join.right().getOutputSet())) {
|
||||
rightConjuncts.add(otherConjunct);
|
||||
}
|
||||
}
|
||||
|
||||
if (leftConjuncts.isEmpty() && rightConjuncts.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
otherConjuncts.removeAll(leftConjuncts);
|
||||
otherConjuncts.removeAll(rightConjuncts);
|
||||
|
||||
Plan left = PlanUtils.filterOrSelf(leftConjuncts, join.left());
|
||||
Plan right = PlanUtils.filterOrSelf(rightConjuncts, join.right());
|
||||
|
||||
return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
|
||||
ExpressionUtils.optionalAnd(otherConjuncts), left, right);
|
||||
|
||||
}).toRule(RuleType.PUSH_DOWN_JOIN_OTHER_CONDITION);
|
||||
}
|
||||
|
||||
private boolean allCoveredBy(Expression predicate, Set<Slot> inputSlotSet) {
|
||||
return inputSlotSet.containsAll(predicate.getInputSlots());
|
||||
}
|
||||
}
|
||||
@ -21,15 +21,17 @@ import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.PlanUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.List;
|
||||
@ -37,17 +39,39 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Push the predicate in the LogicalFilter or LogicalJoin to the join children.
|
||||
* todo: Now, only support eq on condition for inner join, support other case later
|
||||
* Push the predicate in the LogicalFilter to the join children.
|
||||
*/
|
||||
public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
|
||||
public class PushPredicatesThroughJoin extends OneRewriteRuleFactory {
|
||||
|
||||
private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_LEFT = ImmutableList.of(
|
||||
JoinType.INNER_JOIN,
|
||||
JoinType.LEFT_OUTER_JOIN,
|
||||
JoinType.LEFT_SEMI_JOIN,
|
||||
JoinType.LEFT_ANTI_JOIN,
|
||||
JoinType.CROSS_JOIN
|
||||
);
|
||||
|
||||
private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_RIGHT = ImmutableList.of(
|
||||
JoinType.INNER_JOIN,
|
||||
JoinType.RIGHT_OUTER_JOIN,
|
||||
JoinType.RIGHT_SEMI_JOIN,
|
||||
JoinType.RIGHT_ANTI_JOIN,
|
||||
JoinType.CROSS_JOIN
|
||||
);
|
||||
|
||||
private static final ImmutableList<JoinType> COULD_PUSH_EQUAL_TO = ImmutableList.of(
|
||||
JoinType.INNER_JOIN
|
||||
);
|
||||
|
||||
/*
|
||||
* For example:
|
||||
* select a.k1,b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5 where a.k1 > 1 and b.k1 > 2
|
||||
* select a.k1, b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5
|
||||
* where a.k1 > 1 and b.k1 > 2 and a.k2 > b.k2
|
||||
*
|
||||
* Logical plan tree:
|
||||
* project
|
||||
* |
|
||||
* filter (a.k1 > 1 and b.k1 > 2)
|
||||
* filter (a.k1 > 1 and b.k1 > 2 and a.k2 > b.k2)
|
||||
* |
|
||||
* join (a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5)
|
||||
* / \
|
||||
@ -55,69 +79,72 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
|
||||
* transformed:
|
||||
* project
|
||||
* |
|
||||
* join (a.k1 = b.k1)
|
||||
* filter(a.k2 > b.k2)
|
||||
* |
|
||||
* join (otherConditions: a.k1 = b.k1)
|
||||
* / \
|
||||
* filter(a.k1 > 1 and a.k2 > 2 ) filter(b.k1 > 2 and b.k2 > 5)
|
||||
* filter(a.k1 > 1 and a.k2 > 2) filter(b.k1 > 2 and b.k2 > 5)
|
||||
* | |
|
||||
* scan scan
|
||||
*/
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalFilter(innerLogicalJoin()).then(filter -> {
|
||||
return logicalFilter(logicalJoin()).then(filter -> {
|
||||
|
||||
LogicalJoin<GroupPlan, GroupPlan> join = filter.child();
|
||||
|
||||
Expression wherePredicates = filter.getPredicates();
|
||||
Expression onPredicates = join.getOtherJoinCondition().orElse(BooleanLiteral.TRUE);
|
||||
Expression filterPredicates = filter.getPredicates();
|
||||
|
||||
List<Expression> otherConditions = Lists.newArrayList();
|
||||
List<Expression> eqConditions = Lists.newArrayList();
|
||||
List<Expression> filterConditions = Lists.newArrayList();
|
||||
List<Expression> joinConditions = Lists.newArrayList();
|
||||
|
||||
Set<Slot> leftInput = join.left().getOutputSet();
|
||||
Set<Slot> rightInput = join.right().getOutputSet();
|
||||
|
||||
ExpressionUtils.extractConjunction(ExpressionUtils.and(onPredicates, wherePredicates))
|
||||
ExpressionUtils.extractConjunction(filterPredicates)
|
||||
.forEach(predicate -> {
|
||||
if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))) {
|
||||
eqConditions.add(predicate);
|
||||
if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))
|
||||
&& COULD_PUSH_EQUAL_TO.contains(join.getJoinType())) {
|
||||
joinConditions.add(predicate);
|
||||
} else {
|
||||
otherConditions.add(predicate);
|
||||
filterConditions.add(predicate);
|
||||
}
|
||||
});
|
||||
|
||||
List<Expression> leftPredicates = Lists.newArrayList();
|
||||
List<Expression> rightPredicates = Lists.newArrayList();
|
||||
|
||||
for (Expression p : otherConditions) {
|
||||
for (Expression p : filterConditions) {
|
||||
Set<Slot> slots = p.getInputSlots();
|
||||
if (slots.isEmpty()) {
|
||||
leftPredicates.add(p);
|
||||
rightPredicates.add(p);
|
||||
continue;
|
||||
}
|
||||
if (leftInput.containsAll(slots)) {
|
||||
if (leftInput.containsAll(slots) && COULD_PUSH_THROUGH_LEFT.contains(join.getJoinType())) {
|
||||
leftPredicates.add(p);
|
||||
}
|
||||
if (rightInput.containsAll(slots)) {
|
||||
if (rightInput.containsAll(slots) && COULD_PUSH_THROUGH_RIGHT.contains(join.getJoinType())) {
|
||||
rightPredicates.add(p);
|
||||
}
|
||||
}
|
||||
|
||||
otherConditions.removeAll(leftPredicates);
|
||||
otherConditions.removeAll(rightPredicates);
|
||||
otherConditions.addAll(eqConditions);
|
||||
filterConditions.removeAll(leftPredicates);
|
||||
filterConditions.removeAll(rightPredicates);
|
||||
join.getOtherJoinCondition().map(joinConditions::add);
|
||||
|
||||
return pushDownPredicate(join, otherConditions, leftPredicates, rightPredicates);
|
||||
return PlanUtils.filterOrSelf(filterConditions,
|
||||
pushDownPredicate(join, joinConditions, leftPredicates, rightPredicates));
|
||||
}).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_JOIN);
|
||||
}
|
||||
|
||||
private Plan pushDownPredicate(LogicalJoin<GroupPlan, GroupPlan> joinPlan,
|
||||
private Plan pushDownPredicate(LogicalJoin<GroupPlan, GroupPlan> join,
|
||||
List<Expression> joinConditions, List<Expression> leftPredicates, List<Expression> rightPredicates) {
|
||||
// todo expr should optimize again using expr rewrite
|
||||
Plan leftPlan = PlanUtils.filterOrSelf(leftPredicates, joinPlan.left());
|
||||
Plan rightPlan = PlanUtils.filterOrSelf(rightPredicates, joinPlan.right());
|
||||
Plan leftPlan = PlanUtils.filterOrSelf(leftPredicates, join.left());
|
||||
Plan rightPlan = PlanUtils.filterOrSelf(rightPredicates, join.right());
|
||||
|
||||
return new LogicalJoin<>(joinPlan.getJoinType(), joinPlan.getHashJoinConjuncts(),
|
||||
return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
|
||||
ExpressionUtils.optionalAnd(joinConditions), leftPlan, rightPlan);
|
||||
}
|
||||
|
||||
@ -128,13 +155,13 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
|
||||
|
||||
ComparisonPredicate comparison = (ComparisonPredicate) predicate;
|
||||
|
||||
Set<Slot> leftSlots = comparison.left().getInputSlots();
|
||||
Set<Slot> rightSlots = comparison.right().getInputSlots();
|
||||
|
||||
if (!(leftSlots.size() >= 1 && rightSlots.size() >= 1)) {
|
||||
if (!(comparison instanceof EqualTo)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
Set<Slot> leftSlots = comparison.left().getInputSlots();
|
||||
Set<Slot> rightSlots = comparison.right().getInputSlots();
|
||||
|
||||
if ((leftOutputs.containsAll(leftSlots) && rightOutputs.containsAll(rightSlots))
|
||||
|| (leftOutputs.containsAll(rightSlots) && rightOutputs.containsAll(leftSlots))) {
|
||||
return predicate;
|
||||
@ -57,8 +57,8 @@ import java.util.Optional;
|
||||
class FindHashConditionForJoinTest {
|
||||
@Test
|
||||
public void testFindHashCondition() {
|
||||
Plan student = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.student, ImmutableList.of(""));
|
||||
Plan score = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.score, ImmutableList.of(""));
|
||||
Plan student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of(""));
|
||||
Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of(""));
|
||||
|
||||
Slot studentId = student.getOutput().get(0);
|
||||
Slot gender = student.getOutput().get(1);
|
||||
|
||||
@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.RelationId;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
@ -48,8 +49,8 @@ import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
class LimitPushDownTest extends TestWithFeService implements PatternMatchSupported {
|
||||
private Plan scanScore = new LogicalOlapScan(PlanConstructor.score);
|
||||
private Plan scanStudent = new LogicalOlapScan(PlanConstructor.student);
|
||||
private Plan scanScore = new LogicalOlapScan(new RelationId(0), PlanConstructor.score);
|
||||
private Plan scanStudent = new LogicalOlapScan(new RelationId(1), PlanConstructor.student);
|
||||
|
||||
@Override
|
||||
protected void runBeforeAll() throws Exception {
|
||||
@ -213,8 +214,8 @@ class LimitPushDownTest extends TestWithFeService implements PatternMatchSupport
|
||||
joinType,
|
||||
joinConditions,
|
||||
Optional.empty(),
|
||||
new LogicalOlapScan(PlanConstructor.score),
|
||||
new LogicalOlapScan(PlanConstructor.student)
|
||||
new LogicalOlapScan(new RelationId(0), PlanConstructor.score),
|
||||
new LogicalOlapScan(new RelationId(1), PlanConstructor.student)
|
||||
);
|
||||
|
||||
if (hasProject) {
|
||||
|
||||
@ -88,7 +88,7 @@ class PruneOlapScanPartitionTest {
|
||||
olapTable.getName();
|
||||
result = "tbl";
|
||||
}};
|
||||
LogicalOlapScan scan = new LogicalOlapScan(PlanConstructor.getNextId(), olapTable);
|
||||
LogicalOlapScan scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
|
||||
SlotReference slotRef = new SlotReference("col1", IntegerType.INSTANCE);
|
||||
Expression expression = new LessThan(slotRef, new IntegerLiteral(4));
|
||||
LogicalFilter<LogicalOlapScan> filter = new LogicalFilter<>(expression, scan);
|
||||
@ -104,7 +104,7 @@ class PruneOlapScanPartitionTest {
|
||||
Expression greaterThan6 = new GreaterThan(slotRef, new IntegerLiteral(6));
|
||||
Or lessThan0OrGreaterThan6 = new Or(lessThan0, greaterThan6);
|
||||
filter = new LogicalFilter<>(lessThan0OrGreaterThan6, scan);
|
||||
scan = new LogicalOlapScan(PlanConstructor.getNextId(), olapTable);
|
||||
scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
|
||||
cascadesContext = MemoTestUtils.createCascadesContext(filter);
|
||||
rules = Lists.newArrayList(new PruneOlapScanPartition().build());
|
||||
cascadesContext.topDownRewrite(rules);
|
||||
@ -118,7 +118,7 @@ class PruneOlapScanPartitionTest {
|
||||
Expression lessThanEqual5 =
|
||||
new LessThanEqual(slotRef, new IntegerLiteral(5));
|
||||
And greaterThanEqual0AndLessThanEqual5 = new And(greaterThanEqual0, lessThanEqual5);
|
||||
scan = new LogicalOlapScan(PlanConstructor.getNextId(), olapTable);
|
||||
scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
|
||||
filter = new LogicalFilter<>(greaterThanEqual0AndLessThanEqual5, scan);
|
||||
cascadesContext = MemoTestUtils.createCascadesContext(filter);
|
||||
rules = Lists.newArrayList(new PruneOlapScanPartition().build());
|
||||
@ -153,7 +153,7 @@ class PruneOlapScanPartitionTest {
|
||||
olapTable.getName();
|
||||
result = "tbl";
|
||||
}};
|
||||
LogicalOlapScan scan = new LogicalOlapScan(PlanConstructor.getNextId(), olapTable);
|
||||
LogicalOlapScan scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
|
||||
Expression left = new LessThan(new SlotReference("col1", IntegerType.INSTANCE), new IntegerLiteral(4));
|
||||
Expression right = new GreaterThan(new SlotReference("col2", IntegerType.INSTANCE), new IntegerLiteral(11));
|
||||
CompoundPredicate and = new And(left, right);
|
||||
|
||||
@ -0,0 +1,196 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.Memo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThan;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.nereids.util.PlanRewriter;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInstance;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
public class PushDownJoinOtherConditionTest {
|
||||
|
||||
private Plan rStudent;
|
||||
private Plan rScore;
|
||||
|
||||
/**
|
||||
* ut before.
|
||||
*/
|
||||
@BeforeAll
|
||||
public final void beforeAll() {
|
||||
rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of(""));
|
||||
rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of(""));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void oneSide() {
|
||||
oneSide(JoinType.CROSS_JOIN, false);
|
||||
oneSide(JoinType.INNER_JOIN, false);
|
||||
oneSide(JoinType.LEFT_OUTER_JOIN, true);
|
||||
oneSide(JoinType.LEFT_SEMI_JOIN, true);
|
||||
oneSide(JoinType.LEFT_ANTI_JOIN, true);
|
||||
oneSide(JoinType.RIGHT_OUTER_JOIN, false);
|
||||
oneSide(JoinType.RIGHT_SEMI_JOIN, false);
|
||||
oneSide(JoinType.RIGHT_ANTI_JOIN, false);
|
||||
}
|
||||
|
||||
private void oneSide(JoinType joinType, boolean testRight) {
|
||||
|
||||
Expression pushSide1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
|
||||
Expression pushSide2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50));
|
||||
Expression condition = ExpressionUtils.and(pushSide1, pushSide2);
|
||||
|
||||
Plan left = rStudent;
|
||||
Plan right = rScore;
|
||||
if (testRight) {
|
||||
left = rScore;
|
||||
right = rStudent;
|
||||
}
|
||||
|
||||
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), left, right);
|
||||
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
|
||||
|
||||
Memo memo = rewrite(root);
|
||||
Group rootGroup = memo.getRoot();
|
||||
|
||||
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().getPlan();
|
||||
Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(1).getLogicalExpression().getPlan();
|
||||
if (testRight) {
|
||||
shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(1).getLogicalExpression().getPlan();
|
||||
shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().getPlan();
|
||||
}
|
||||
|
||||
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
|
||||
Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
|
||||
Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
|
||||
LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
|
||||
|
||||
Assertions.assertEquals(condition, actualFilter.getPredicates());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void bothSideToBothSide() {
|
||||
bothSideToBothSide(JoinType.CROSS_JOIN);
|
||||
bothSideToBothSide(JoinType.INNER_JOIN);
|
||||
bothSideToBothSide(JoinType.LEFT_SEMI_JOIN);
|
||||
bothSideToBothSide(JoinType.RIGHT_SEMI_JOIN);
|
||||
}
|
||||
|
||||
private void bothSideToBothSide(JoinType joinType) {
|
||||
|
||||
Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
|
||||
Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
|
||||
Expression condition = ExpressionUtils.and(leftSide, rightSide);
|
||||
|
||||
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), rStudent, rScore);
|
||||
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
|
||||
|
||||
Memo memo = rewrite(root);
|
||||
Group rootGroup = memo.getRoot();
|
||||
|
||||
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
Plan leftFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().getPlan();
|
||||
Plan rightFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(1).getLogicalExpression().getPlan();
|
||||
|
||||
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
|
||||
Assertions.assertTrue(leftFilter instanceof LogicalFilter);
|
||||
Assertions.assertTrue(rightFilter instanceof LogicalFilter);
|
||||
LogicalFilter<Plan> actualLeft = (LogicalFilter<Plan>) leftFilter;
|
||||
LogicalFilter<Plan> actualRight = (LogicalFilter<Plan>) rightFilter;
|
||||
Assertions.assertEquals(leftSide, actualLeft.getPredicates());
|
||||
Assertions.assertEquals(rightSide, actualRight.getPredicates());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void bothSideToOneSide() {
|
||||
bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, true);
|
||||
bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, true);
|
||||
bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, false);
|
||||
bothSideToOneSide(JoinType.RIGHT_ANTI_JOIN, false);
|
||||
}
|
||||
|
||||
private void bothSideToOneSide(JoinType joinType, boolean testRight) {
|
||||
|
||||
Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
|
||||
Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
|
||||
Expression condition = ExpressionUtils.and(pushSide, reserveSide);
|
||||
|
||||
Plan left = rStudent;
|
||||
Plan right = rScore;
|
||||
if (testRight) {
|
||||
left = rScore;
|
||||
right = rStudent;
|
||||
}
|
||||
|
||||
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), left, right);
|
||||
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
|
||||
|
||||
Memo memo = rewrite(root);
|
||||
Group rootGroup = memo.getRoot();
|
||||
|
||||
Plan shouldJoin = rootGroup.getLogicalExpression()
|
||||
.child(0).getLogicalExpression().getPlan();
|
||||
Plan shouldFilter = rootGroup.getLogicalExpression()
|
||||
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
Plan shouldScan = rootGroup.getLogicalExpression()
|
||||
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
|
||||
if (testRight) {
|
||||
shouldFilter = rootGroup.getLogicalExpression()
|
||||
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
|
||||
shouldScan = rootGroup.getLogicalExpression()
|
||||
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
}
|
||||
|
||||
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
|
||||
Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
|
||||
Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
|
||||
LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
|
||||
Assertions.assertEquals(pushSide, actualFilter.getPredicates());
|
||||
}
|
||||
|
||||
private Memo rewrite(Plan plan) {
|
||||
return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushDownJoinOtherCondition());
|
||||
}
|
||||
}
|
||||
@ -1,228 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.Memo;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.And;
|
||||
import org.apache.doris.nereids.trees.expressions.Between;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThan;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.Subtract;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.types.DoubleType;
|
||||
import org.apache.doris.nereids.types.StringType;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.nereids.util.PlanRewriter;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInstance;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* plan rewrite ut.
|
||||
*/
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
public class PushDownPredicateTest {
|
||||
|
||||
private Plan rStudent;
|
||||
private Plan rScore;
|
||||
private Plan rCourse;
|
||||
|
||||
/**
|
||||
* ut before.
|
||||
*/
|
||||
@BeforeAll
|
||||
public final void beforeAll() {
|
||||
rStudent = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.student, ImmutableList.of(""));
|
||||
|
||||
rScore = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.score, ImmutableList.of(""));
|
||||
|
||||
rCourse = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.course, ImmutableList.of(""));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pushDownPredicateIntoScanTest1() {
|
||||
// select id,name,grade from student join score on student.id = score.sid and student.id > 1
|
||||
// and score.cid > 2 where student.age > 18 and score.grade > 60
|
||||
Expression onCondition1 = new EqualTo(rStudent.getOutput().get(0), rScore.getOutput().get(0));
|
||||
Expression onCondition2 = new GreaterThan(rStudent.getOutput().get(0), Literal.of(1));
|
||||
Expression onCondition3 = new GreaterThan(rScore.getOutput().get(0), Literal.of(2));
|
||||
Expression onCondition = ExpressionUtils.and(onCondition1, onCondition2, onCondition3);
|
||||
|
||||
Expression whereCondition1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
|
||||
Expression whereCondition2 = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
|
||||
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2);
|
||||
|
||||
Plan join = new LogicalJoin(JoinType.INNER_JOIN, new ArrayList<>(), Optional.of(onCondition), rStudent, rScore);
|
||||
Plan filter = new LogicalFilter(whereCondition, join);
|
||||
|
||||
Plan root = new LogicalProject(
|
||||
Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2)),
|
||||
filter
|
||||
);
|
||||
|
||||
Memo memo = rewrite(root);
|
||||
|
||||
Group rootGroup = memo.getRoot();
|
||||
|
||||
Plan op1 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
Plan op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
|
||||
.getPlan();
|
||||
Plan op3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(1).getLogicalExpression()
|
||||
.getPlan();
|
||||
|
||||
Assertions.assertTrue(op1 instanceof LogicalJoin);
|
||||
Assertions.assertTrue(op2 instanceof LogicalFilter);
|
||||
Assertions.assertTrue(op3 instanceof LogicalFilter);
|
||||
LogicalJoin join1 = (LogicalJoin) op1;
|
||||
LogicalFilter filter1 = (LogicalFilter) op2;
|
||||
LogicalFilter filter2 = (LogicalFilter) op3;
|
||||
|
||||
Assertions.assertEquals(onCondition1, join1.getOtherJoinCondition().get());
|
||||
Assertions.assertEquals(ExpressionUtils.and(onCondition2, whereCondition1), filter1.getPredicates());
|
||||
Assertions.assertEquals(ExpressionUtils.and(onCondition3,
|
||||
new GreaterThan(rScore.getOutput().get(2), new Cast(Literal.of(60), DoubleType.INSTANCE))),
|
||||
filter2.getPredicates());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pushDownPredicateIntoScanTest3() {
|
||||
//select id,name,grade from student left join score on student.id + 1 = score.sid - 2
|
||||
//where student.age > 18 and score.grade > 60
|
||||
Expression whereCondition1 = new EqualTo(new Add(rStudent.getOutput().get(0), Literal.of(1)),
|
||||
new Subtract(rScore.getOutput().get(0), Literal.of(2)));
|
||||
Expression whereCondition2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
|
||||
Expression whereCondition3 = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
|
||||
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2, whereCondition3);
|
||||
|
||||
Plan join = new LogicalJoin(JoinType.INNER_JOIN, new ArrayList<>(), Optional.empty(), rStudent, rScore);
|
||||
Plan filter = new LogicalFilter(whereCondition, join);
|
||||
|
||||
Plan root = new LogicalProject(
|
||||
Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2)),
|
||||
filter
|
||||
);
|
||||
|
||||
Memo memo = rewrite(root);
|
||||
Group rootGroup = memo.getRoot();
|
||||
|
||||
Plan op1 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
Plan op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
|
||||
.getPlan();
|
||||
Plan op3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(1).getLogicalExpression()
|
||||
.getPlan();
|
||||
|
||||
Assertions.assertTrue(op1 instanceof LogicalJoin);
|
||||
Assertions.assertTrue(op2 instanceof LogicalFilter);
|
||||
Assertions.assertTrue(op3 instanceof LogicalFilter);
|
||||
LogicalJoin join1 = (LogicalJoin) op1;
|
||||
LogicalFilter filter1 = (LogicalFilter) op2;
|
||||
LogicalFilter filter2 = (LogicalFilter) op3;
|
||||
Assertions.assertEquals(whereCondition1, join1.getOtherJoinCondition().get());
|
||||
Assertions.assertEquals(whereCondition2, filter1.getPredicates());
|
||||
Assertions.assertEquals(
|
||||
new GreaterThan(rScore.getOutput().get(2), new Cast(Literal.of(60), DoubleType.INSTANCE)),
|
||||
filter2.getPredicates());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pushDownPredicateIntoScanTest4() {
|
||||
/*
|
||||
select
|
||||
student.name,
|
||||
course.name,
|
||||
score.grade
|
||||
from student,score,course
|
||||
where on student.id = score.sid and student.age between 18 and 20 and score.grade > 60 and student.id = score.sid
|
||||
*/
|
||||
|
||||
// student.id = score.sid
|
||||
Expression whereCondition1 = new EqualTo(rStudent.getOutput().get(0), rScore.getOutput().get(0));
|
||||
// score.cid = course.cid
|
||||
Expression whereCondition2 = new EqualTo(rScore.getOutput().get(1), rCourse.getOutput().get(0));
|
||||
// student.age between 18 and 20
|
||||
Expression whereCondition3 = new Between(rStudent.getOutput().get(2), Literal.of(18), Literal.of(20));
|
||||
// student.age >= 18 and student.age <= 20
|
||||
Expression whereCondition3result = new And(
|
||||
new GreaterThanEqual(rStudent.getOutput().get(2), new Cast(Literal.of(18), StringType.INSTANCE)),
|
||||
new LessThanEqual(rStudent.getOutput().get(2), new Cast(Literal.of(20), StringType.INSTANCE)));
|
||||
|
||||
// score.grade > 60
|
||||
Expression whereCondition4 = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
|
||||
|
||||
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2, whereCondition3,
|
||||
whereCondition4);
|
||||
|
||||
Plan join = new LogicalJoin(JoinType.INNER_JOIN, ImmutableList.of(), Optional.empty(), rStudent, rScore);
|
||||
Plan join1 = new LogicalJoin(JoinType.INNER_JOIN, ImmutableList.of(), Optional.empty(), join, rCourse);
|
||||
Plan filter = new LogicalFilter(whereCondition, join1);
|
||||
|
||||
Plan root = new LogicalProject(
|
||||
Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2)),
|
||||
filter
|
||||
);
|
||||
|
||||
Memo memo = rewrite(root);
|
||||
Group rootGroup = memo.getRoot();
|
||||
Plan join2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
Plan join3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
|
||||
.getPlan();
|
||||
Plan op1 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().getPlan();
|
||||
Plan op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(1).getLogicalExpression().getPlan();
|
||||
|
||||
Assertions.assertTrue(join2 instanceof LogicalJoin);
|
||||
Assertions.assertTrue(join3 instanceof LogicalJoin);
|
||||
Assertions.assertTrue(op1 instanceof LogicalFilter);
|
||||
Assertions.assertTrue(op2 instanceof LogicalFilter);
|
||||
|
||||
Assertions.assertEquals(whereCondition2, ((LogicalJoin) join2).getOtherJoinCondition().get());
|
||||
Assertions.assertEquals(whereCondition1, ((LogicalJoin) join3).getOtherJoinCondition().get());
|
||||
Assertions.assertEquals(whereCondition3result.toSql(), ((LogicalFilter) op1).getPredicates().toSql());
|
||||
Assertions.assertEquals(
|
||||
new GreaterThan(rScore.getOutput().get(2), new Cast(Literal.of(60), DoubleType.INSTANCE)),
|
||||
((LogicalFilter) op2).getPredicates());
|
||||
}
|
||||
|
||||
private Memo rewrite(Plan plan) {
|
||||
Plan normalizedPlan = PlanRewriter.topDownRewrite(plan, new ConnectContext(), new ExpressionNormalization());
|
||||
return PlanRewriter.topDownRewriteMemo(normalizedPlan, new ConnectContext(), new PushPredicateThroughJoin());
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,208 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.Memo;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThan;
|
||||
import org.apache.doris.nereids.trees.expressions.Subtract;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.nereids.util.PlanRewriter;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInstance;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* plan rewrite ut.
|
||||
*/
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
public class PushPredicateThroughJoinTest {
|
||||
|
||||
private Plan rStudent;
|
||||
private Plan rScore;
|
||||
|
||||
/**
|
||||
* ut before.
|
||||
*/
|
||||
@BeforeAll
|
||||
public final void beforeAll() {
|
||||
rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of(""));
|
||||
rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of(""));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void oneSide() {
|
||||
oneSide(JoinType.CROSS_JOIN, false);
|
||||
oneSide(JoinType.INNER_JOIN, false);
|
||||
oneSide(JoinType.LEFT_OUTER_JOIN, false);
|
||||
oneSide(JoinType.LEFT_SEMI_JOIN, false);
|
||||
oneSide(JoinType.LEFT_ANTI_JOIN, false);
|
||||
oneSide(JoinType.RIGHT_OUTER_JOIN, true);
|
||||
oneSide(JoinType.RIGHT_SEMI_JOIN, true);
|
||||
oneSide(JoinType.RIGHT_ANTI_JOIN, true);
|
||||
}
|
||||
|
||||
private void oneSide(JoinType joinType, boolean testRight) {
|
||||
|
||||
Expression whereCondition1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
|
||||
Expression whereCondition2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50));
|
||||
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2);
|
||||
|
||||
Plan left = rStudent;
|
||||
Plan right = rScore;
|
||||
if (testRight) {
|
||||
left = rScore;
|
||||
right = rStudent;
|
||||
}
|
||||
|
||||
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), left, right);
|
||||
Plan filter = new LogicalFilter<>(whereCondition, join);
|
||||
Plan root = new LogicalProject<>(Lists.newArrayList(), filter);
|
||||
|
||||
Memo memo = rewrite(root);
|
||||
Group rootGroup = memo.getRoot();
|
||||
|
||||
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().getPlan();
|
||||
Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(1).getLogicalExpression().getPlan();
|
||||
if (testRight) {
|
||||
shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(1).getLogicalExpression().getPlan();
|
||||
shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().getPlan();
|
||||
}
|
||||
|
||||
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
|
||||
Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
|
||||
Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
|
||||
LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
|
||||
|
||||
Assertions.assertEquals(whereCondition, actualFilter.getPredicates());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void bothSideToBothSide() {
|
||||
bothSideToBothSide(JoinType.INNER_JOIN);
|
||||
}
|
||||
|
||||
private void bothSideToBothSide(JoinType joinType) {
|
||||
|
||||
Expression bothSideEqualTo = new EqualTo(new Add(rStudent.getOutput().get(0), Literal.of(1)),
|
||||
new Subtract(rScore.getOutput().get(0), Literal.of(2)));
|
||||
Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
|
||||
Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
|
||||
Expression whereCondition = ExpressionUtils.and(bothSideEqualTo, leftSide, rightSide);
|
||||
|
||||
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), rStudent, rScore);
|
||||
Plan filter = new LogicalFilter<>(whereCondition, join);
|
||||
Plan root = new LogicalProject<>(Lists.newArrayList(), filter);
|
||||
|
||||
Memo memo = rewrite(root);
|
||||
Group rootGroup = memo.getRoot();
|
||||
|
||||
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
Plan leftFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().getPlan();
|
||||
Plan rightFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(1).getLogicalExpression().getPlan();
|
||||
|
||||
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
|
||||
Assertions.assertTrue(leftFilter instanceof LogicalFilter);
|
||||
Assertions.assertTrue(rightFilter instanceof LogicalFilter);
|
||||
LogicalJoin<Plan, Plan> actualJoin = (LogicalJoin<Plan, Plan>) shouldJoin;
|
||||
LogicalFilter<Plan> actualLeft = (LogicalFilter<Plan>) leftFilter;
|
||||
LogicalFilter<Plan> actualRight = (LogicalFilter<Plan>) rightFilter;
|
||||
Assertions.assertEquals(bothSideEqualTo, actualJoin.getOtherJoinCondition().get());
|
||||
Assertions.assertEquals(leftSide, actualLeft.getPredicates());
|
||||
Assertions.assertEquals(rightSide, actualRight.getPredicates());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void bothSideToOneSide() {
|
||||
bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, false);
|
||||
bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, false);
|
||||
bothSideToOneSide(JoinType.LEFT_SEMI_JOIN, false);
|
||||
bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, true);
|
||||
bothSideToOneSide(JoinType.RIGHT_ANTI_JOIN, true);
|
||||
bothSideToOneSide(JoinType.RIGHT_SEMI_JOIN, true);
|
||||
}
|
||||
|
||||
private void bothSideToOneSide(JoinType joinType, boolean testRight) {
|
||||
|
||||
Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
|
||||
Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
|
||||
Expression whereCondition = ExpressionUtils.and(pushSide, reserveSide);
|
||||
|
||||
Plan left = rStudent;
|
||||
Plan right = rScore;
|
||||
if (testRight) {
|
||||
left = rScore;
|
||||
right = rStudent;
|
||||
}
|
||||
|
||||
Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), left, right);
|
||||
Plan filter = new LogicalFilter<>(whereCondition, join);
|
||||
Plan root = new LogicalProject<>(Lists.newArrayList(), filter);
|
||||
|
||||
Memo memo = rewrite(root);
|
||||
Group rootGroup = memo.getRoot();
|
||||
|
||||
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().getPlan();
|
||||
Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
|
||||
if (testRight) {
|
||||
shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
|
||||
shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
}
|
||||
|
||||
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
|
||||
Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
|
||||
Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
|
||||
LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
|
||||
Assertions.assertEquals(pushSide, actualFilter.getPredicates());
|
||||
}
|
||||
|
||||
private Memo rewrite(Plan plan) {
|
||||
return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushPredicatesThroughJoin());
|
||||
}
|
||||
}
|
||||
@ -37,7 +37,7 @@ public class PlanConstructor {
|
||||
public static OlapTable student;
|
||||
public static OlapTable score;
|
||||
public static OlapTable course;
|
||||
private static final IdGenerator<RelationId> GENERATOR = RelationId.createGenerator();
|
||||
private static final IdGenerator<RelationId> RELATION_ID_GENERATOR = RelationId.createGenerator();
|
||||
|
||||
static {
|
||||
student = new OlapTable(0L, "student",
|
||||
@ -102,14 +102,14 @@ public class PlanConstructor {
|
||||
// With OlapTable.
|
||||
// Warning: equals() of Table depends on tableId.
|
||||
public static LogicalOlapScan newLogicalOlapScan(long tableId, String tableName, int hashColumn) {
|
||||
return new LogicalOlapScan(GENERATOR.getNextId(), newOlapTable(tableId, tableName, hashColumn), ImmutableList.of("db"));
|
||||
return new LogicalOlapScan(RELATION_ID_GENERATOR.getNextId(), newOlapTable(tableId, tableName, hashColumn), ImmutableList.of("db"));
|
||||
}
|
||||
|
||||
public static LogicalOlapScan newLogicalOlapScanWithSameId(long tableId, String tableName, int hashColumn) {
|
||||
return new LogicalOlapScan(RelationId.createGenerator().getNextId(), newOlapTable(tableId, tableName, hashColumn), ImmutableList.of("db"));
|
||||
}
|
||||
|
||||
public static RelationId getNextId() {
|
||||
return GENERATOR.getNextId();
|
||||
public static RelationId getNextRelationId() {
|
||||
return RELATION_ID_GENERATOR.getNextId();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user