[Enhancement](Nereids) push down predicate through join (#10462)

Add filter operator to join children according to the predicate of filter and join, in order to achieving  predicate push-down

Pattern: 
```
      filter
         |
       join
      /     \
child    child
```

Transform:
```
      filter
         |
       join
      /     \
filter     filter
 |            |  
child     child
```
This commit is contained in:
shee
2022-07-01 15:39:01 +08:00
committed by GitHub
parent f164d094e8
commit f998c0b044
17 changed files with 960 additions and 12 deletions

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.trees.TreeNode;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
@ -57,7 +58,7 @@ public abstract class Job<NODE_TYPE extends TreeNode<NODE_TYPE>> {
public List<Rule<NODE_TYPE>> getValidRules(GroupExpression groupExpression,
List<Rule<NODE_TYPE>> candidateRules) {
return candidateRules.stream()
.filter(rule -> rule.getPattern().matchOperator(groupExpression.getOperator())
.filter(rule -> Objects.nonNull(rule) && rule.getPattern().matchOperator(groupExpression.getOperator())
&& groupExpression.notApplied(rule)).collect(Collectors.toList());
}

View File

@ -72,7 +72,9 @@ public class RewriteTopDownJob extends Job<Plan> {
Preconditions.checkArgument(afters.size() == 1);
Plan after = afters.get(0);
if (after != before) {
context.getOptimizerContext().getMemo().copyIn(after, group, rule.isRewrite());
GroupExpression expression = context.getOptimizerContext().getMemo()
.copyIn(after, group, rule.isRewrite());
expression.setApplied(rule);
pushTask(new RewriteTopDownJob(group, rules, context));
return;
}
@ -80,7 +82,7 @@ public class RewriteTopDownJob extends Job<Plan> {
logicalExpression.setApplied(rule);
}
for (Group childGroup : logicalExpression.children()) {
for (Group childGroup : group.getLogicalExpression().children()) {
pushTask(new RewriteTopDownJob(childGroup, rules, context));
}
}

View File

@ -98,6 +98,9 @@ public class Memo {
childrenNode.add(groupToTreeNode(child));
}
Plan result = logicalExpression.getOperator().toTreeNode(logicalExpression);
if (result.children().size() == 0) {
return result;
}
return result.withChildren(childrenNode);
}

View File

@ -39,6 +39,7 @@ import org.apache.doris.nereids.DorisParser.MultipartIdentifierContext;
import org.apache.doris.nereids.DorisParser.NamedExpressionContext;
import org.apache.doris.nereids.DorisParser.NamedExpressionSeqContext;
import org.apache.doris.nereids.DorisParser.NullLiteralContext;
import org.apache.doris.nereids.DorisParser.ParenthesizedExpressionContext;
import org.apache.doris.nereids.DorisParser.PredicateContext;
import org.apache.doris.nereids.DorisParser.PredicatedContext;
import org.apache.doris.nereids.DorisParser.QualifiedNameContext;
@ -402,7 +403,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
}
@Override
public Expression visitParenthesizedExpression(DorisParser.ParenthesizedExpressionContext ctx) {
public Expression visitParenthesizedExpression(ParenthesizedExpressionContext ctx) {
return getExpression(ctx.expression());
}

View File

@ -53,4 +53,9 @@ public class OrderKey {
public boolean isNullFirst() {
return nullFirst;
}
@Override
public String toString() {
return expr.sql();
}
}

View File

@ -39,6 +39,7 @@ public enum RuleType {
// rewrite rules
COLUMN_PRUNE_PROJECTION(RuleTypeClass.REWRITE),
PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE),
// exploration rules
LOGICAL_JOIN_COMMUTATIVE(RuleTypeClass.EXPLORATION),

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeExpressi
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule;
import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.List;
@ -30,7 +31,7 @@ import java.util.List;
*/
public class ExpressionRuleExecutor {
public static final List<ExpressionRewriteRule> REWRITE_RULES = Lists.newArrayList(
public static final List<ExpressionRewriteRule> REWRITE_RULES = ImmutableList.of(
new SimplifyNotExprRule(),
new NormalizeExpressionRule()
);
@ -38,6 +39,11 @@ public class ExpressionRuleExecutor {
private final ExpressionRewriteContext ctx;
private final List<ExpressionRewriteRule> rules;
public ExpressionRuleExecutor() {
this.rules = REWRITE_RULES;
this.ctx = new ExpressionRewriteContext();
}
public ExpressionRuleExecutor(List<ExpressionRewriteRule> rules) {
this.rules = rules;
this.ctx = new ExpressionRewriteContext();

View File

@ -0,0 +1,165 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.operators.plans.logical.LogicalFilter;
import org.apache.doris.nereids.operators.plans.logical.LogicalJoin;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRuleExecutor;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Literal;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotExtractor;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalBinaryPlan;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
/**
* Push the predicate in the LogicalFilter or LogicalJoin to the join children.
* For example:
* select a.k1,b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5 where a.k1 > 1 and b.k1 > 2
* Logical plan tree:
* project
* |
* filter (a.k1 > 1 and b.k1 > 2)
* |
* join (a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5)
* / \
* scan scan
* transformed:
* project
* |
* join (a.k1 = b.k1)
* / \
* filter(a.k1 > 1 and a.k2 > 2 ) filter(b.k1 > 2 and b.k2 > 5)
* | |
* scan scan
* todo: Now, only support eq on condition for inner join, support other case later
*/
public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
@Override
public Rule<Plan> build() {
return logicalFilter(innerLogicalJoin()).then(filter -> {
LogicalJoin joinOp = filter.child().operator;
Expression wherePredicates = filter.operator.getPredicates();
Expression onPredicates = Literal.TRUE_LITERAL;
List<Expression> otherConditions = Lists.newArrayList();
List<Expression> eqConditions = Lists.newArrayList();
if (joinOp.getCondition().isPresent()) {
onPredicates = joinOp.getCondition().get();
}
List<Slot> leftInput = filter.child().left().getOutput();
List<Slot> rightInput = filter.child().right().getOutput();
ExpressionUtils.extractConjunct(ExpressionUtils.add(onPredicates, wherePredicates)).forEach(predicate -> {
if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))) {
eqConditions.add(predicate);
} else {
otherConditions.add(predicate);
}
});
List<Expression> leftPredicates = Lists.newArrayList();
List<Expression> rightPredicates = Lists.newArrayList();
for (Expression p : otherConditions) {
Set<Slot> slots = SlotExtractor.extractSlot(p);
if (slots.isEmpty()) {
leftPredicates.add(p);
rightPredicates.add(p);
}
if (leftInput.containsAll(slots)) {
leftPredicates.add(p);
}
if (rightInput.containsAll(slots)) {
rightPredicates.add(p);
}
}
otherConditions.removeAll(leftPredicates);
otherConditions.removeAll(rightPredicates);
otherConditions.addAll(eqConditions);
Expression joinConditions = ExpressionUtils.add(otherConditions);
return pushDownPredicate(filter.child(), joinConditions, leftPredicates, rightPredicates);
}).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_JOIN);
}
private Plan pushDownPredicate(LogicalBinaryPlan<LogicalJoin, GroupPlan, GroupPlan> joinPlan,
Expression joinConditions, List<Expression> leftPredicates, List<Expression> rightPredicates) {
Expression left = ExpressionUtils.add(leftPredicates);
Expression right = ExpressionUtils.add(rightPredicates);
//todo expr should optimize again using expr rewrite
ExpressionRuleExecutor exprRewriter = new ExpressionRuleExecutor();
Plan leftPlan = joinPlan.left();
Plan rightPlan = joinPlan.right();
if (!left.equals(Literal.TRUE_LITERAL)) {
leftPlan = plan(new LogicalFilter(exprRewriter.rewrite(left)), leftPlan);
}
if (!right.equals(Literal.TRUE_LITERAL)) {
rightPlan = plan(new LogicalFilter(exprRewriter.rewrite(right)), rightPlan);
}
return plan(new LogicalJoin(joinPlan.operator.getJoinType(), Optional.of(joinConditions)), leftPlan, rightPlan);
}
private Expression getJoinCondition(Expression predicate, List<Slot> leftOutputs, List<Slot> rightOutputs) {
if (!(predicate instanceof ComparisonPredicate)) {
return null;
}
ComparisonPredicate comparison = (ComparisonPredicate) predicate;
Set<Slot> leftSlots = SlotExtractor.extractSlot(comparison.left());
Set<Slot> rightSlots = SlotExtractor.extractSlot(comparison.right());
if (!(leftSlots.size() >= 1 && rightSlots.size() >= 1)) {
return null;
}
Set<Slot> left = Sets.newLinkedHashSet(leftOutputs);
Set<Slot> right = Sets.newLinkedHashSet(rightOutputs);
if ((left.containsAll(leftSlots) && right.containsAll(rightSlots)) || (left.containsAll(rightSlots)
&& right.containsAll(leftSlots))) {
return predicate;
}
return null;
}
}

View File

@ -48,4 +48,7 @@ public class Add<LEFT_CHILD_TYPE extends Expression, RIGHT_CHILD_TYPE extends Ex
}
public String toString() {
return sql();
}
}

View File

@ -19,6 +19,9 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.trees.NodeType;
import java.util.List;
import java.util.Objects;
/**
* Compound predicate expression.
* Such as &&,||,AND,OR.
@ -53,4 +56,30 @@ public class CompoundPredicate<LEFT_CHILD_TYPE extends Expression, RIGHT_CHILD_T
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitCompoundPredicate(this, context);
}
public NodeType flip() {
if (getType() == NodeType.AND) {
return NodeType.OR;
}
return NodeType.AND;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
CompoundPredicate other = (CompoundPredicate) o;
return (type == other.getType()) && Objects.equals(left(), other.left())
&& Objects.equals(right(), other.right());
}
@Override
public Expression withChildren(List<Expression> children) {
return new CompoundPredicate<>(getType(), children.get(0), children.get(1));
}
}

View File

@ -0,0 +1,161 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
/**
* Iterative traversal of an expression.
*/
public abstract class IterationVisitor<C> extends DefaultExpressionVisitor<Void, C> {
@Override
public Void visit(Expression expr, C context) {
return expr.accept(this, context);
}
@Override
public Void visitNot(Not expr, C context) {
visit(expr.child(), context);
return null;
}
@Override
public Void visitCompoundPredicate(CompoundPredicate expr, C context) {
visit(expr.left(), context);
visit(expr.right(), context);
return null;
}
@Override
public Void visitLiteral(Literal literal, C context) {
return null;
}
@Override
public Void visitArithmetic(Arithmetic arithmetic, C context) {
visit(arithmetic.child(0), context);
if (arithmetic.getArithmeticOperator().isBinary()) {
visit(arithmetic.child(1), context);
}
return null;
}
@Override
public Void visitBetween(Between betweenPredicate, C context) {
visit(betweenPredicate.getCompareExpr(), context);
visit(betweenPredicate.getLowerBound(), context);
visit(betweenPredicate.getUpperBound(), context);
return null;
}
@Override
public Void visitAlias(Alias alias, C context) {
return visitNamedExpression(alias, context);
}
@Override
public Void visitComparisonPredicate(ComparisonPredicate cp, C context) {
visit(cp.left(), context);
visit(cp.right(), context);
return null;
}
@Override
public Void visitEqualTo(EqualTo equalTo, C context) {
return visitComparisonPredicate(equalTo, context);
}
@Override
public Void visitGreaterThan(GreaterThan greaterThan, C context) {
return visitComparisonPredicate(greaterThan, context);
}
@Override
public Void visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, C context) {
return visitComparisonPredicate(greaterThanEqual, context);
}
@Override
public Void visitLessThan(LessThan lessThan, C context) {
return visitComparisonPredicate(lessThan, context);
}
@Override
public Void visitLessThanEqual(LessThanEqual lessThanEqual, C context) {
return visitComparisonPredicate(lessThanEqual, context);
}
@Override
public Void visitNullSafeEqual(NullSafeEqual nullSafeEqual, C context) {
return visitComparisonPredicate(nullSafeEqual, context);
}
@Override
public Void visitSlot(Slot slot, C context) {
return null;
}
@Override
public Void visitNamedExpression(NamedExpression namedExpression, C context) {
for (Expression child : namedExpression.children()) {
visit(child, context);
}
return null;
}
@Override
public Void visitBoundFunction(BoundFunction boundFunction, C context) {
for (Expression argument : boundFunction.getArguments()) {
visit(argument, context);
}
return null;
}
@Override
public Void visitAggregateFunction(AggregateFunction aggregateFunction, C context) {
return visitBoundFunction(aggregateFunction, context);
}
@Override
public Void visitAdd(Add add, C context) {
return visitArithmetic(add, context);
}
@Override
public Void visitSubtract(Subtract subtract, C context) {
return visitArithmetic(subtract, context);
}
@Override
public Void visitMultiply(Multiply multiply, C context) {
return visitArithmetic(multiply, context);
}
@Override
public Void visitDivide(Divide divide, C context) {
return visitArithmetic(divide, context);
}
@Override
public Void visitMod(Mod mod, C context) {
return visitArithmetic(mod, context);
}
}

View File

@ -34,6 +34,8 @@ import java.util.Objects;
* TODO: Increase the implementation of sub expression. such as Integer.
*/
public class Literal extends Expression implements LeafExpression {
public static final Literal TRUE_LITERAL = new Literal(true);
public static final Literal FALSE_LITERAL = new Literal(false);
private final DataType dataType;
private final Object value;
@ -97,6 +99,10 @@ public class Literal extends Expression implements LeafExpression {
return value == null;
}
public static Literal of(Object value) {
return new Literal(value);
}
@Override
public boolean isConstant() {
return true;

View File

@ -0,0 +1,68 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions;
import com.clearspring.analytics.util.Lists;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.List;
import java.util.Set;
/**
* Extracts the SlotReference contained in the expression.
*/
public class SlotExtractor extends IterationVisitor<List<Slot>> {
/**
* extract slot reference.
*/
public static Set<Slot> extractSlot(Collection<Expression> expressions) {
Set<Slot> slots = Sets.newLinkedHashSet();
for (Expression expression : expressions) {
slots.addAll(extractSlot(expression));
}
return slots;
}
/**
* extract slot reference.
*/
public static Set<Slot> extractSlot(Expression... expressions) {
Set<Slot> slots = Sets.newLinkedHashSet();
for (Expression expression : expressions) {
slots.addAll(extractSlot(expression));
}
return slots;
}
private static List<Slot> extractSlot(Expression expression) {
List<Slot> slots = Lists.newArrayList();
new SlotExtractor().visit(expression, slots);
return slots;
}
@Override
public Void visitSlotReference(SlotReference slotReference, List<Slot> context) {
context.add(slotReference);
return null;
}
}

View File

@ -0,0 +1,140 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.trees.NodeType;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Literal;
import com.google.common.collect.Lists;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* Expression rewrite helper class.
*/
public class ExpressionUtils {
public static boolean isConstant(Expression expr) {
return expr.isConstant();
}
public static List<Expression> extractConjunct(Expression expr) {
return extract(NodeType.AND, expr);
}
public static List<Expression> extractDisjunct(Expression expr) {
return extract(NodeType.OR, expr);
}
public static List<Expression> extract(CompoundPredicate expr) {
return extract(expr.getType(), expr);
}
private static List<Expression> extract(NodeType op, Expression expr) {
List<Expression> result = Lists.newArrayList();
extract(op, expr, result);
return result;
}
private static void extract(NodeType op, Expression expr, List<Expression> result) {
if (expr instanceof CompoundPredicate && expr.getType() == op) {
CompoundPredicate predicate = (CompoundPredicate) expr;
extract(op, predicate.left(), result);
extract(op, predicate.right(), result);
} else {
result.add(expr);
}
}
public static Expression add(List<Expression> expressions) {
return combine(NodeType.AND, expressions);
}
public static Expression add(Expression... expressions) {
return combine(NodeType.AND, Lists.newArrayList(expressions));
}
public static Expression or(Expression... expressions) {
return combine(NodeType.OR, Lists.newArrayList(expressions));
}
public static Expression or(List<Expression> expressions) {
return combine(NodeType.OR, expressions);
}
/**
* Use AND/OR to combine expressions together.
*/
public static Expression combine(NodeType op, List<Expression> expressions) {
Objects.requireNonNull(expressions, "expressions is null");
if (expressions.size() == 0) {
if (op == NodeType.AND) {
return new Literal(true);
}
if (op == NodeType.OR) {
return new Literal(false);
}
}
if (expressions.size() == 1) {
return expressions.get(0);
}
List<Expression> distinctExpressions = Lists.newArrayList(new LinkedHashSet<>(expressions));
if (op == NodeType.AND) {
if (distinctExpressions.contains(Literal.FALSE_LITERAL)) {
return Literal.FALSE_LITERAL;
}
distinctExpressions = distinctExpressions.stream().filter(p -> !p.equals(Literal.TRUE_LITERAL))
.collect(Collectors.toList());
}
if (op == NodeType.OR) {
if (distinctExpressions.contains(Literal.TRUE_LITERAL)) {
return Literal.TRUE_LITERAL;
}
distinctExpressions = distinctExpressions.stream().filter(p -> !p.equals(Literal.FALSE_LITERAL))
.collect(Collectors.toList());
}
List<List<Expression>> partitions = Lists.partition(distinctExpressions, 2);
List<Expression> result = new LinkedList<>();
for (List<Expression> partition : partitions) {
if (partition.size() == 2) {
result.add(new CompoundPredicate(op, partition.get(0), partition.get(1)));
}
if (partition.size() == 1) {
result.add(partition.get(0));
}
}
return combine(op, result);
}
}

View File

@ -0,0 +1,256 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.OptimizerContext;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.rewrite.RewriteTopDownJob;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.operators.Operator;
import org.apache.doris.nereids.operators.plans.JoinType;
import org.apache.doris.nereids.operators.plans.logical.LogicalFilter;
import org.apache.doris.nereids.operators.plans.logical.LogicalJoin;
import org.apache.doris.nereids.operators.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.operators.plans.logical.LogicalProject;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Between;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.Literal;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.Plans;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import java.util.List;
import java.util.Optional;
/**
* plan rewrite ut.
*/
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class PushDownPredicateTest implements Plans {
private Table student;
private Table score;
private Table course;
private Plan rStudent;
private Plan rScore;
private Plan rCourse;
/**
* ut before.
*/
@BeforeAll
public final void beforeAll() {
student = new Table(0L, "student", Table.TableType.OLAP,
ImmutableList.<Column>of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "", ""),
new Column("age", Type.INT, true, AggregateType.NONE, "", "")));
score = new Table(0L, "score", Table.TableType.OLAP,
ImmutableList.<Column>of(new Column("sid", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("cid", Type.INT, true, AggregateType.NONE, "", ""),
new Column("grade", Type.DOUBLE, true, AggregateType.NONE, "", "")));
course = new Table(0L, "course", Table.TableType.OLAP,
ImmutableList.<Column>of(new Column("cid", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "", ""),
new Column("teacher", Type.STRING, true, AggregateType.NONE, "", "")));
rStudent = plan(new LogicalOlapScan(student, ImmutableList.of("student")));
rScore = plan(new LogicalOlapScan(score, ImmutableList.of("score")));
rCourse = plan(new LogicalOlapScan(course, ImmutableList.of("course")));
}
@Test
public void pushDownPredicateIntoScanTest1() {
// select id,name,grade from student join score on student.id = score.sid and student.id > 1
// and score.cid > 2 where student.age > 18 and score.grade > 60
Expression onCondition1 = new EqualTo<>(rStudent.getOutput().get(0), rScore.getOutput().get(0));
Expression onCondition2 = new GreaterThan<>(rStudent.getOutput().get(0), Literal.of(1));
Expression onCondition3 = new GreaterThan<>(rScore.getOutput().get(0), Literal.of(2));
Expression onCondition = ExpressionUtils.add(onCondition1, onCondition2, onCondition3);
Expression whereCondition1 = new GreaterThan<>(rStudent.getOutput().get(1), Literal.of(18));
Expression whereCondition2 = new GreaterThan<>(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.add(whereCondition1, whereCondition2);
Plan join = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.of(onCondition)), rStudent, rScore);
Plan filter = plan(new LogicalFilter(whereCondition), join);
Plan root = plan(new LogicalProject(
Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2))),
filter);
Memo memo = new Memo();
memo.initialize(root);
System.out.println(memo.copyOut().treeString());
OptimizerContext optimizerContext = new OptimizerContext(memo);
PlannerContext plannerContext = new PlannerContext(optimizerContext, null, new PhysicalProperties());
RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(),
ImmutableList.of(new PushPredicateThroughJoin().build()), plannerContext);
plannerContext.getOptimizerContext().pushJob(rewriteTopDownJob);
plannerContext.getOptimizerContext().getJobScheduler().executeJobPool(plannerContext);
Group rootGroup = memo.getRoot();
System.out.println(memo.copyOut().treeString());
System.out.println(11);
Operator op1 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getOperator();
Operator op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression().getOperator();
Operator op3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(1).getLogicalExpression().getOperator();
Assertions.assertTrue(op1 instanceof LogicalJoin);
Assertions.assertTrue(op2 instanceof LogicalFilter);
Assertions.assertTrue(op3 instanceof LogicalFilter);
LogicalJoin join1 = (LogicalJoin) op1;
LogicalFilter filter1 = (LogicalFilter) op2;
LogicalFilter filter2 = (LogicalFilter) op3;
Assertions.assertEquals(join1.getCondition().get(), onCondition1);
Assertions.assertEquals(filter1.getPredicates(), ExpressionUtils.add(onCondition2, whereCondition1));
Assertions.assertEquals(filter2.getPredicates(), ExpressionUtils.add(onCondition3, whereCondition2));
}
@Test
public void pushDownPredicateIntoScanTest3() {
//select id,name,grade from student left join score on student.id + 1 = score.sid - 2
//where student.age > 18 and score.grade > 60
Expression whereCondition1 = new EqualTo<>(new Add<>(rStudent.getOutput().get(0), Literal.of(1)),
new Subtract<>(rScore.getOutput().get(0), Literal.of(2)));
Expression whereCondition2 = new GreaterThan<>(rStudent.getOutput().get(1), Literal.of(18));
Expression whereCondition3 = new GreaterThan<>(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.add(whereCondition1, whereCondition2, whereCondition3);
Plan join = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.empty()), rStudent, rScore);
Plan filter = plan(new LogicalFilter(whereCondition), join);
Plan root = plan(new LogicalProject(
Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2))),
filter);
Memo memo = new Memo();
memo.initialize(root);
System.out.println(memo.copyOut().treeString());
OptimizerContext optimizerContext = new OptimizerContext(memo);
PlannerContext plannerContext = new PlannerContext(optimizerContext, null, new PhysicalProperties());
RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(),
ImmutableList.of(new PushPredicateThroughJoin().build()), plannerContext);
plannerContext.getOptimizerContext().pushJob(rewriteTopDownJob);
plannerContext.getOptimizerContext().getJobScheduler().executeJobPool(plannerContext);
Group rootGroup = memo.getRoot();
System.out.println(memo.copyOut().treeString());
Operator op1 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getOperator();
Operator op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression().getOperator();
Operator op3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(1).getLogicalExpression().getOperator();
Assertions.assertTrue(op1 instanceof LogicalJoin);
Assertions.assertTrue(op2 instanceof LogicalFilter);
Assertions.assertTrue(op3 instanceof LogicalFilter);
LogicalJoin join1 = (LogicalJoin) op1;
LogicalFilter filter1 = (LogicalFilter) op2;
LogicalFilter filter2 = (LogicalFilter) op3;
Assertions.assertEquals(join1.getCondition().get(), whereCondition1);
Assertions.assertEquals(filter1.getPredicates(), whereCondition2);
Assertions.assertEquals(filter2.getPredicates(), whereCondition3);
}
@Test
public void pushDownPredicateIntoScanTest4() {
/*
select
student.name,
course.name,
score.grade
from student,score,course
where on student.id = score.sid and student.age between 18 and 20 and score.grade > 60 and student.id = score.sid
*/
// student.id = score.sid
Expression whereCondition1 = new EqualTo<>(rStudent.getOutput().get(0), rScore.getOutput().get(0));
// score.cid = course.cid
Expression whereCondition2 = new EqualTo<>(rScore.getOutput().get(1), rCourse.getOutput().get(0));
// student.age between 18 and 20
Expression whereCondition3 = new Between<>(rStudent.getOutput().get(2), Literal.of(18), Literal.of(20));
// score.grade > 60
Expression whereCondition4 = new GreaterThan<>(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.add(whereCondition1, whereCondition2, whereCondition3, whereCondition4);
Plan join = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.empty()), rStudent, rScore);
Plan join1 = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.empty()), join, rCourse);
Plan filter = plan(new LogicalFilter(whereCondition), join1);
Plan root = plan(new LogicalProject(
Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2))),
filter);
Memo memo = new Memo();
memo.initialize(root);
System.out.println(memo.copyOut().treeString());
OptimizerContext optimizerContext = new OptimizerContext(memo);
PlannerContext plannerContext = new PlannerContext(optimizerContext, null, new PhysicalProperties());
List<Rule<Plan>> fakeRules = Lists.newArrayList(new PushPredicateThroughJoin().build());
RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(), fakeRules, plannerContext);
plannerContext.getOptimizerContext().pushJob(rewriteTopDownJob);
plannerContext.getOptimizerContext().getJobScheduler().executeJobPool(plannerContext);
Group rootGroup = memo.getRoot();
System.out.println(memo.copyOut().treeString());
Operator join2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getOperator();
Operator join3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression().getOperator();
Operator op1 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression().getOperator();
Operator op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression().child(1).getLogicalExpression().getOperator();
Assertions.assertTrue(join2 instanceof LogicalJoin);
Assertions.assertTrue(join3 instanceof LogicalJoin);
Assertions.assertTrue(op1 instanceof LogicalFilter);
Assertions.assertTrue(op2 instanceof LogicalFilter);
Assertions.assertEquals(((LogicalJoin) join2).getCondition().get(), whereCondition2);
Assertions.assertEquals(((LogicalJoin) join3).getCondition().get(), whereCondition1);
Assertions.assertEquals(((LogicalFilter) op1).getPredicates(), whereCondition3);
Assertions.assertEquals(((LogicalFilter) op2).getPredicates(), whereCondition4);
}
}

View File

@ -84,7 +84,7 @@ public class SSBUtils {
+ " s_nation,\n"
+ " d_year,\n"
+ " SUM(lo_revenue) AS REVENUE\n"
+ "FROM customer, lineorder, supplier, dates\n"
+ "FROM lineorder, customer, supplier, dates\n"
+ "WHERE\n"
+ " lo_custkey = c_custkey\n"
+ " AND lo_suppkey = s_suppkey\n"
@ -101,7 +101,7 @@ public class SSBUtils {
+ " s_city,\n"
+ " d_year,\n"
+ " SUM(lo_revenue) AS REVENUE\n"
+ "FROM customer, lineorder, supplier, dates\n"
+ "FROM lineorder, customer , supplier, dates\n"
+ "WHERE\n"
+ " lo_custkey = c_custkey\n"
+ " AND lo_suppkey = s_suppkey\n"
@ -118,7 +118,7 @@ public class SSBUtils {
+ " s_city,\n"
+ " d_year,\n"
+ " SUM(lo_revenue) AS REVENUE\n"
+ "FROM customer, lineorder, supplier, dates\n"
+ "FROM lineorder, customer, supplier, dates\n"
+ "WHERE\n"
+ " lo_custkey = c_custkey\n"
+ " AND lo_suppkey = s_suppkey\n"
@ -141,7 +141,7 @@ public class SSBUtils {
+ " s_city,\n"
+ " d_year,\n"
+ " SUM(lo_revenue) AS REVENUE\n"
+ "FROM customer, lineorder, supplier, dates\n"
+ "FROM lineorder, customer, supplier, dates\n"
+ "WHERE\n"
+ " lo_custkey = c_custkey\n"
+ " AND lo_suppkey = s_suppkey\n"
@ -162,7 +162,7 @@ public class SSBUtils {
+ " d_year,\n"
+ " c_nation,\n"
+ " SUM(lo_revenue - lo_supplycost) AS PROFIT\n"
+ "FROM dates, customer, supplier, part, lineorder\n"
+ "FROM lineorder, dates, customer, supplier, part\n"
+ "WHERE\n"
+ " lo_custkey = c_custkey\n"
+ " AND lo_suppkey = s_suppkey\n"
@ -182,7 +182,7 @@ public class SSBUtils {
+ " s_nation,\n"
+ " p_category,\n"
+ " SUM(lo_revenue - lo_supplycost) AS PROFIT\n"
+ "FROM dates, customer, supplier, part, lineorder\n"
+ "FROM lineorder, dates, customer, supplier, part\n"
+ "WHERE\n"
+ " lo_custkey = c_custkey\n"
+ " AND lo_suppkey = s_suppkey\n"
@ -206,7 +206,7 @@ public class SSBUtils {
+ " s_city,\n"
+ " p_brand,\n"
+ " SUM(lo_revenue - lo_supplycost) AS PROFIT\n"
+ "FROM dates, customer, supplier, part, lineorder\n"
+ "FROM lineorder, dates, customer, supplier, part\n"
+ "WHERE\n"
+ " lo_custkey = c_custkey\n"
+ " AND lo_suppkey = s_suppkey\n"

View File

@ -0,0 +1,101 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
/**
* ExpressionUtils ut.
*/
public class ExpressionUtilsTest {
private static final NereidsParser PARSER = new NereidsParser();
@Test
public void extractConjunctsTest() {
List<Expression> expressions;
Expression expr;
expr = PARSER.createExpression("a");
expressions = ExpressionUtils.extractConjunct(expr);
Assertions.assertEquals(expressions.size(), 1);
Assertions.assertEquals(expressions.get(0), expr);
expr = PARSER.createExpression("a and b and c");
Expression a = PARSER.createExpression("a");
Expression b = PARSER.createExpression("b");
Expression c = PARSER.createExpression("c");
expressions = ExpressionUtils.extractConjunct(expr);
Assertions.assertEquals(expressions.size(), 3);
Assertions.assertEquals(expressions.get(0), a);
Assertions.assertEquals(expressions.get(1), b);
Assertions.assertEquals(expressions.get(2), c);
expr = PARSER.createExpression("(a or b) and c and (e or f)");
expressions = ExpressionUtils.extractConjunct(expr);
Expression aOrb = PARSER.createExpression("a or b");
Expression eOrf = PARSER.createExpression("e or f");
Assertions.assertEquals(expressions.size(), 3);
Assertions.assertEquals(expressions.get(0), aOrb);
Assertions.assertEquals(expressions.get(1), c);
Assertions.assertEquals(expressions.get(2), eOrf);
}
@Test
public void extractDisjunctsTest() {
List<Expression> expressions;
Expression expr;
expr = PARSER.createExpression("a");
expressions = ExpressionUtils.extractDisjunct(expr);
Assertions.assertEquals(expressions.size(), 1);
Assertions.assertEquals(expressions.get(0), expr);
expr = PARSER.createExpression("a or b or c");
Expression a = PARSER.createExpression("a");
Expression b = PARSER.createExpression("b");
Expression c = PARSER.createExpression("c");
expressions = ExpressionUtils.extractDisjunct(expr);
Assertions.assertEquals(expressions.size(), 3);
Assertions.assertEquals(expressions.get(0), a);
Assertions.assertEquals(expressions.get(1), b);
Assertions.assertEquals(expressions.get(2), c);
expr = PARSER.createExpression("(a and b) or c or (e and f)");
expressions = ExpressionUtils.extractDisjunct(expr);
Expression aAndb = PARSER.createExpression("a and b");
Expression eAndf = PARSER.createExpression("e and f");
Assertions.assertEquals(expressions.size(), 3);
Assertions.assertEquals(expressions.get(0), aAndb);
Assertions.assertEquals(expressions.get(1), c);
Assertions.assertEquals(expressions.get(2), eAndf);
}
}