[refactor](nereids) Refine some code snippets (#10672)

Refine some code snippets:
1. Rename: ExpressionUtils::add -> ExpressionUtils::and
2. Reduce temporary objects when combing expressions.
This commit is contained in:
Adonis Ling
2022-07-11 16:31:38 +08:00
committed by GitHub
parent 51855633e4
commit deae728fc6
5 changed files with 31 additions and 68 deletions

View File

@ -84,7 +84,7 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
List<Slot> leftInput = filter.child().left().getOutput();
List<Slot> rightInput = filter.child().right().getOutput();
ExpressionUtils.extractConjunct(ExpressionUtils.add(onPredicates, wherePredicates)).forEach(predicate -> {
ExpressionUtils.extractConjunct(ExpressionUtils.and(onPredicates, wherePredicates)).forEach(predicate -> {
if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))) {
eqConditions.add(predicate);
} else {
@ -112,7 +112,7 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
otherConditions.removeAll(leftPredicates);
otherConditions.removeAll(rightPredicates);
otherConditions.addAll(eqConditions);
Expression joinConditions = ExpressionUtils.add(otherConditions);
Expression joinConditions = ExpressionUtils.and(otherConditions);
return pushDownPredicate(filter.child(), joinConditions, leftPredicates, rightPredicates);
}).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_JOIN);
@ -121,8 +121,8 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
private Plan pushDownPredicate(LogicalBinaryPlan<LogicalJoin, GroupPlan, GroupPlan> joinPlan,
Expression joinConditions, List<Expression> leftPredicates, List<Expression> rightPredicates) {
Expression left = ExpressionUtils.add(leftPredicates);
Expression right = ExpressionUtils.add(rightPredicates);
Expression left = ExpressionUtils.and(leftPredicates);
Expression right = ExpressionUtils.and(rightPredicates);
//todo expr should optimize again using expr rewrite
ExpressionRuleExecutor exprRewriter = new ExpressionRuleExecutor();
Plan leftPlan = joinPlan.left();

View File

@ -75,12 +75,7 @@ public abstract class Expression extends AbstractTreeNode<Expression> {
* Whether the expression is a constant.
*/
public boolean isConstant() {
for (Expression child : children()) {
if (child.isConstant()) {
return true;
}
}
return false;
return children().stream().anyMatch(Expression::isConstant);
}
@Override

View File

@ -22,13 +22,14 @@ import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Literal;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.Optional;
/**
* Expression rewrite helper class.
@ -43,7 +44,6 @@ public class ExpressionUtils {
return extract(NodeType.AND, expr);
}
public static List<Expression> extractDisjunct(Expression expr) {
return extract(NodeType.OR, expr);
}
@ -68,12 +68,11 @@ public class ExpressionUtils {
}
}
public static Expression add(List<Expression> expressions) {
public static Expression and(List<Expression> expressions) {
return combine(NodeType.AND, expressions);
}
public static Expression add(Expression... expressions) {
public static Expression and(Expression... expressions) {
return combine(NodeType.AND, Lists.newArrayList(expressions));
}
@ -89,51 +88,22 @@ public class ExpressionUtils {
* Use AND/OR to combine expressions together.
*/
public static Expression combine(NodeType op, List<Expression> expressions) {
Preconditions.checkArgument(op == NodeType.AND || op == NodeType.OR);
Objects.requireNonNull(expressions, "expressions is null");
if (expressions.size() == 0) {
if (op == NodeType.AND) {
return new Literal(true);
}
if (op == NodeType.OR) {
return new Literal(false);
Expression shortCircuit = (op == NodeType.AND ? Literal.FALSE_LITERAL : Literal.TRUE_LITERAL);
Expression skip = (op == NodeType.AND ? Literal.TRUE_LITERAL : Literal.FALSE_LITERAL);
LinkedHashSet<Expression> distinctExpressions = Sets.newLinkedHashSetWithExpectedSize(expressions.size());
for (Expression expression : expressions) {
if (expression.equals(shortCircuit)) {
return shortCircuit;
} else if (!expression.equals(skip)) {
distinctExpressions.add(expression);
}
}
if (expressions.size() == 1) {
return expressions.get(0);
}
List<Expression> distinctExpressions = Lists.newArrayList(new LinkedHashSet<>(expressions));
if (op == NodeType.AND) {
if (distinctExpressions.contains(Literal.FALSE_LITERAL)) {
return Literal.FALSE_LITERAL;
}
distinctExpressions = distinctExpressions.stream().filter(p -> !p.equals(Literal.TRUE_LITERAL))
.collect(Collectors.toList());
}
if (op == NodeType.OR) {
if (distinctExpressions.contains(Literal.TRUE_LITERAL)) {
return Literal.TRUE_LITERAL;
}
distinctExpressions = distinctExpressions.stream().filter(p -> !p.equals(Literal.FALSE_LITERAL))
.collect(Collectors.toList());
}
List<List<Expression>> partitions = Lists.partition(distinctExpressions, 2);
List<Expression> result = new LinkedList<>();
for (List<Expression> partition : partitions) {
if (partition.size() == 2) {
result.add(new CompoundPredicate(op, partition.get(0), partition.get(1)));
}
if (partition.size() == 1) {
result.add(partition.get(0));
}
}
return combine(op, result);
Optional<Expression> result =
distinctExpressions.stream().reduce((left, right) -> new CompoundPredicate<>(op, left, right));
return result.orElse(new Literal(op == NodeType.AND));
}
}

View File

@ -28,10 +28,8 @@ public class Utils {
* @return quoted string
*/
public static String quoteIfNeeded(String part) {
if (part.matches("[a-zA-Z0-9_]+") && !part.matches("\\d+")) {
return part;
} else {
return part.replace("`", "``");
}
// We quote strings except the ones which consist of digits only.
return part.matches("\\w*[\\w&&[^\\d]]+\\w*")
? part : part.replace("`", "``");
}
}

View File

@ -106,11 +106,11 @@ public class PushDownPredicateTest implements Plans {
Expression onCondition1 = new EqualTo<>(rStudent.getOutput().get(0), rScore.getOutput().get(0));
Expression onCondition2 = new GreaterThan<>(rStudent.getOutput().get(0), Literal.of(1));
Expression onCondition3 = new GreaterThan<>(rScore.getOutput().get(0), Literal.of(2));
Expression onCondition = ExpressionUtils.add(onCondition1, onCondition2, onCondition3);
Expression onCondition = ExpressionUtils.and(onCondition1, onCondition2, onCondition3);
Expression whereCondition1 = new GreaterThan<>(rStudent.getOutput().get(1), Literal.of(18));
Expression whereCondition2 = new GreaterThan<>(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.add(whereCondition1, whereCondition2);
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2);
Plan join = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.of(onCondition)), rStudent, rScore);
@ -149,8 +149,8 @@ public class PushDownPredicateTest implements Plans {
LogicalFilter filter2 = (LogicalFilter) op3;
Assertions.assertEquals(join1.getCondition().get(), onCondition1);
Assertions.assertEquals(filter1.getPredicates(), ExpressionUtils.add(onCondition2, whereCondition1));
Assertions.assertEquals(filter2.getPredicates(), ExpressionUtils.add(onCondition3, whereCondition2));
Assertions.assertEquals(filter1.getPredicates(), ExpressionUtils.and(onCondition2, whereCondition1));
Assertions.assertEquals(filter2.getPredicates(), ExpressionUtils.and(onCondition3, whereCondition2));
}
@Test
@ -161,7 +161,7 @@ public class PushDownPredicateTest implements Plans {
new Subtract<>(rScore.getOutput().get(0), Literal.of(2)));
Expression whereCondition2 = new GreaterThan<>(rStudent.getOutput().get(1), Literal.of(18));
Expression whereCondition3 = new GreaterThan<>(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.add(whereCondition1, whereCondition2, whereCondition3);
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2, whereCondition3);
Plan join = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.empty()), rStudent, rScore);
Plan filter = plan(new LogicalFilter(whereCondition), join);
@ -226,7 +226,7 @@ public class PushDownPredicateTest implements Plans {
// score.grade > 60
Expression whereCondition4 = new GreaterThan<>(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.add(whereCondition1, whereCondition2, whereCondition3, whereCondition4);
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2, whereCondition3, whereCondition4);
Plan join = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.empty()), rStudent, rScore);
Plan join1 = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.empty()), join, rCourse);