[refactor](nereids) Refine some code snippets (#10672)
Refine some code snippets: 1. Rename: ExpressionUtils::add -> ExpressionUtils::and 2. Reduce temporary objects when combing expressions.
This commit is contained in:
@ -84,7 +84,7 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
|
||||
List<Slot> leftInput = filter.child().left().getOutput();
|
||||
List<Slot> rightInput = filter.child().right().getOutput();
|
||||
|
||||
ExpressionUtils.extractConjunct(ExpressionUtils.add(onPredicates, wherePredicates)).forEach(predicate -> {
|
||||
ExpressionUtils.extractConjunct(ExpressionUtils.and(onPredicates, wherePredicates)).forEach(predicate -> {
|
||||
if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))) {
|
||||
eqConditions.add(predicate);
|
||||
} else {
|
||||
@ -112,7 +112,7 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
|
||||
otherConditions.removeAll(leftPredicates);
|
||||
otherConditions.removeAll(rightPredicates);
|
||||
otherConditions.addAll(eqConditions);
|
||||
Expression joinConditions = ExpressionUtils.add(otherConditions);
|
||||
Expression joinConditions = ExpressionUtils.and(otherConditions);
|
||||
|
||||
return pushDownPredicate(filter.child(), joinConditions, leftPredicates, rightPredicates);
|
||||
}).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_JOIN);
|
||||
@ -121,8 +121,8 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
|
||||
private Plan pushDownPredicate(LogicalBinaryPlan<LogicalJoin, GroupPlan, GroupPlan> joinPlan,
|
||||
Expression joinConditions, List<Expression> leftPredicates, List<Expression> rightPredicates) {
|
||||
|
||||
Expression left = ExpressionUtils.add(leftPredicates);
|
||||
Expression right = ExpressionUtils.add(rightPredicates);
|
||||
Expression left = ExpressionUtils.and(leftPredicates);
|
||||
Expression right = ExpressionUtils.and(rightPredicates);
|
||||
//todo expr should optimize again using expr rewrite
|
||||
ExpressionRuleExecutor exprRewriter = new ExpressionRuleExecutor();
|
||||
Plan leftPlan = joinPlan.left();
|
||||
|
||||
@ -75,12 +75,7 @@ public abstract class Expression extends AbstractTreeNode<Expression> {
|
||||
* Whether the expression is a constant.
|
||||
*/
|
||||
public boolean isConstant() {
|
||||
for (Expression child : children()) {
|
||||
if (child.isConstant()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return children().stream().anyMatch(Expression::isConstant);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -22,13 +22,14 @@ import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Literal;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Expression rewrite helper class.
|
||||
@ -43,7 +44,6 @@ public class ExpressionUtils {
|
||||
return extract(NodeType.AND, expr);
|
||||
}
|
||||
|
||||
|
||||
public static List<Expression> extractDisjunct(Expression expr) {
|
||||
return extract(NodeType.OR, expr);
|
||||
}
|
||||
@ -68,12 +68,11 @@ public class ExpressionUtils {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public static Expression add(List<Expression> expressions) {
|
||||
public static Expression and(List<Expression> expressions) {
|
||||
return combine(NodeType.AND, expressions);
|
||||
}
|
||||
|
||||
public static Expression add(Expression... expressions) {
|
||||
public static Expression and(Expression... expressions) {
|
||||
return combine(NodeType.AND, Lists.newArrayList(expressions));
|
||||
}
|
||||
|
||||
@ -89,51 +88,22 @@ public class ExpressionUtils {
|
||||
* Use AND/OR to combine expressions together.
|
||||
*/
|
||||
public static Expression combine(NodeType op, List<Expression> expressions) {
|
||||
|
||||
Preconditions.checkArgument(op == NodeType.AND || op == NodeType.OR);
|
||||
Objects.requireNonNull(expressions, "expressions is null");
|
||||
|
||||
if (expressions.size() == 0) {
|
||||
if (op == NodeType.AND) {
|
||||
return new Literal(true);
|
||||
}
|
||||
if (op == NodeType.OR) {
|
||||
return new Literal(false);
|
||||
Expression shortCircuit = (op == NodeType.AND ? Literal.FALSE_LITERAL : Literal.TRUE_LITERAL);
|
||||
Expression skip = (op == NodeType.AND ? Literal.TRUE_LITERAL : Literal.FALSE_LITERAL);
|
||||
LinkedHashSet<Expression> distinctExpressions = Sets.newLinkedHashSetWithExpectedSize(expressions.size());
|
||||
for (Expression expression : expressions) {
|
||||
if (expression.equals(shortCircuit)) {
|
||||
return shortCircuit;
|
||||
} else if (!expression.equals(skip)) {
|
||||
distinctExpressions.add(expression);
|
||||
}
|
||||
}
|
||||
|
||||
if (expressions.size() == 1) {
|
||||
return expressions.get(0);
|
||||
}
|
||||
|
||||
List<Expression> distinctExpressions = Lists.newArrayList(new LinkedHashSet<>(expressions));
|
||||
if (op == NodeType.AND) {
|
||||
if (distinctExpressions.contains(Literal.FALSE_LITERAL)) {
|
||||
return Literal.FALSE_LITERAL;
|
||||
}
|
||||
distinctExpressions = distinctExpressions.stream().filter(p -> !p.equals(Literal.TRUE_LITERAL))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
if (op == NodeType.OR) {
|
||||
if (distinctExpressions.contains(Literal.TRUE_LITERAL)) {
|
||||
return Literal.TRUE_LITERAL;
|
||||
}
|
||||
distinctExpressions = distinctExpressions.stream().filter(p -> !p.equals(Literal.FALSE_LITERAL))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
List<List<Expression>> partitions = Lists.partition(distinctExpressions, 2);
|
||||
List<Expression> result = new LinkedList<>();
|
||||
|
||||
for (List<Expression> partition : partitions) {
|
||||
if (partition.size() == 2) {
|
||||
result.add(new CompoundPredicate(op, partition.get(0), partition.get(1)));
|
||||
}
|
||||
if (partition.size() == 1) {
|
||||
result.add(partition.get(0));
|
||||
}
|
||||
}
|
||||
|
||||
return combine(op, result);
|
||||
Optional<Expression> result =
|
||||
distinctExpressions.stream().reduce((left, right) -> new CompoundPredicate<>(op, left, right));
|
||||
return result.orElse(new Literal(op == NodeType.AND));
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,10 +28,8 @@ public class Utils {
|
||||
* @return quoted string
|
||||
*/
|
||||
public static String quoteIfNeeded(String part) {
|
||||
if (part.matches("[a-zA-Z0-9_]+") && !part.matches("\\d+")) {
|
||||
return part;
|
||||
} else {
|
||||
return part.replace("`", "``");
|
||||
}
|
||||
// We quote strings except the ones which consist of digits only.
|
||||
return part.matches("\\w*[\\w&&[^\\d]]+\\w*")
|
||||
? part : part.replace("`", "``");
|
||||
}
|
||||
}
|
||||
|
||||
@ -106,11 +106,11 @@ public class PushDownPredicateTest implements Plans {
|
||||
Expression onCondition1 = new EqualTo<>(rStudent.getOutput().get(0), rScore.getOutput().get(0));
|
||||
Expression onCondition2 = new GreaterThan<>(rStudent.getOutput().get(0), Literal.of(1));
|
||||
Expression onCondition3 = new GreaterThan<>(rScore.getOutput().get(0), Literal.of(2));
|
||||
Expression onCondition = ExpressionUtils.add(onCondition1, onCondition2, onCondition3);
|
||||
Expression onCondition = ExpressionUtils.and(onCondition1, onCondition2, onCondition3);
|
||||
|
||||
Expression whereCondition1 = new GreaterThan<>(rStudent.getOutput().get(1), Literal.of(18));
|
||||
Expression whereCondition2 = new GreaterThan<>(rScore.getOutput().get(2), Literal.of(60));
|
||||
Expression whereCondition = ExpressionUtils.add(whereCondition1, whereCondition2);
|
||||
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2);
|
||||
|
||||
|
||||
Plan join = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.of(onCondition)), rStudent, rScore);
|
||||
@ -149,8 +149,8 @@ public class PushDownPredicateTest implements Plans {
|
||||
LogicalFilter filter2 = (LogicalFilter) op3;
|
||||
|
||||
Assertions.assertEquals(join1.getCondition().get(), onCondition1);
|
||||
Assertions.assertEquals(filter1.getPredicates(), ExpressionUtils.add(onCondition2, whereCondition1));
|
||||
Assertions.assertEquals(filter2.getPredicates(), ExpressionUtils.add(onCondition3, whereCondition2));
|
||||
Assertions.assertEquals(filter1.getPredicates(), ExpressionUtils.and(onCondition2, whereCondition1));
|
||||
Assertions.assertEquals(filter2.getPredicates(), ExpressionUtils.and(onCondition3, whereCondition2));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -161,7 +161,7 @@ public class PushDownPredicateTest implements Plans {
|
||||
new Subtract<>(rScore.getOutput().get(0), Literal.of(2)));
|
||||
Expression whereCondition2 = new GreaterThan<>(rStudent.getOutput().get(1), Literal.of(18));
|
||||
Expression whereCondition3 = new GreaterThan<>(rScore.getOutput().get(2), Literal.of(60));
|
||||
Expression whereCondition = ExpressionUtils.add(whereCondition1, whereCondition2, whereCondition3);
|
||||
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2, whereCondition3);
|
||||
|
||||
Plan join = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.empty()), rStudent, rScore);
|
||||
Plan filter = plan(new LogicalFilter(whereCondition), join);
|
||||
@ -226,7 +226,7 @@ public class PushDownPredicateTest implements Plans {
|
||||
// score.grade > 60
|
||||
Expression whereCondition4 = new GreaterThan<>(rScore.getOutput().get(2), Literal.of(60));
|
||||
|
||||
Expression whereCondition = ExpressionUtils.add(whereCondition1, whereCondition2, whereCondition3, whereCondition4);
|
||||
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2, whereCondition3, whereCondition4);
|
||||
|
||||
Plan join = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.empty()), rStudent, rScore);
|
||||
Plan join1 = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.empty()), join, rCourse);
|
||||
|
||||
Reference in New Issue
Block a user