[fix](Nereids) fix or to in rule (#23940)
or expression context can't propagation cross or expression. for example: ``` select (a = 1 or a = 2 or a = 3) + (a = 4 or a = 5 or a = 6) = select a in [1, 2, 3] + a in [4,5,6] != select a in [1, 2, 3] + a in [1, 2, 3, 4, 5, 6] ```
This commit is contained in:
@ -19,13 +19,11 @@ package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.rules.OrToIn.OrToInContext;
|
||||
import org.apache.doris.nereids.trees.expressions.And;
|
||||
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Or;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
@ -57,7 +55,7 @@ import java.util.Set;
|
||||
* adding any additional rule-specific fields to the default ExpressionRewriteContext. However, the entire expression
|
||||
* rewrite framework always passes an ExpressionRewriteContext of type context to all rules.
|
||||
*/
|
||||
public class OrToIn extends DefaultExpressionRewriter<OrToInContext> implements
|
||||
public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext> implements
|
||||
ExpressionRewriteRule<ExpressionRewriteContext> {
|
||||
|
||||
public static final OrToIn INSTANCE = new OrToIn();
|
||||
@ -66,25 +64,20 @@ public class OrToIn extends DefaultExpressionRewriter<OrToInContext> implements
|
||||
|
||||
@Override
|
||||
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
|
||||
return expr.accept(this, new OrToInContext());
|
||||
return expr.accept(this, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitCompoundPredicate(CompoundPredicate compoundPredicate, OrToInContext context) {
|
||||
if (compoundPredicate instanceof And) {
|
||||
return compoundPredicate.withChildren(compoundPredicate.child(0).accept(new OrToIn(),
|
||||
new OrToInContext()),
|
||||
compoundPredicate.child(1).accept(new OrToIn(),
|
||||
new OrToInContext()));
|
||||
}
|
||||
List<Expression> expressions = ExpressionUtils.extractDisjunction(compoundPredicate);
|
||||
public Expression visitOr(Or or, ExpressionRewriteContext ctx) {
|
||||
Map<NamedExpression, Set<Literal>> slotNameToLiteral = new HashMap<>();
|
||||
List<Expression> expressions = ExpressionUtils.extractDisjunction(or);
|
||||
for (Expression expression : expressions) {
|
||||
if (expression instanceof EqualTo) {
|
||||
addSlotToLiteralMap((EqualTo) expression, context);
|
||||
addSlotToLiteralMap((EqualTo) expression, slotNameToLiteral);
|
||||
}
|
||||
}
|
||||
List<Expression> rewrittenOr = new ArrayList<>();
|
||||
for (Map.Entry<NamedExpression, Set<Literal>> entry : context.slotNameToLiteral.entrySet()) {
|
||||
for (Map.Entry<NamedExpression, Set<Literal>> entry : slotNameToLiteral.entrySet()) {
|
||||
Set<Literal> literals = entry.getValue();
|
||||
if (literals.size() >= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
|
||||
InPredicate inPredicate = new InPredicate(entry.getKey(), ImmutableList.copyOf(entry.getValue()));
|
||||
@ -92,26 +85,26 @@ public class OrToIn extends DefaultExpressionRewriter<OrToInContext> implements
|
||||
}
|
||||
}
|
||||
for (Expression expression : expressions) {
|
||||
if (!ableToConvertToIn(expression, context)) {
|
||||
rewrittenOr.add(expression);
|
||||
if (!ableToConvertToIn(expression, slotNameToLiteral)) {
|
||||
rewrittenOr.add(expression.accept(this, null));
|
||||
}
|
||||
}
|
||||
|
||||
return ExpressionUtils.or(rewrittenOr);
|
||||
}
|
||||
|
||||
private void addSlotToLiteralMap(EqualTo equal, OrToInContext context) {
|
||||
private void addSlotToLiteralMap(EqualTo equal, Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
|
||||
Expression left = equal.left();
|
||||
Expression right = equal.right();
|
||||
if (left instanceof NamedExpression && right instanceof Literal) {
|
||||
addSlotToLiteral((NamedExpression) left, (Literal) right, context);
|
||||
addSlotToLiteral((NamedExpression) left, (Literal) right, slotNameToLiteral);
|
||||
}
|
||||
if (right instanceof NamedExpression && left instanceof Literal) {
|
||||
addSlotToLiteral((NamedExpression) right, (Literal) left, context);
|
||||
addSlotToLiteral((NamedExpression) right, (Literal) left, slotNameToLiteral);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean ableToConvertToIn(Expression expression, OrToInContext context) {
|
||||
private boolean ableToConvertToIn(Expression expression, Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
|
||||
if (!(expression instanceof EqualTo)) {
|
||||
return false;
|
||||
}
|
||||
@ -126,24 +119,18 @@ public class OrToIn extends DefaultExpressionRewriter<OrToInContext> implements
|
||||
namedExpression = (NamedExpression) right;
|
||||
}
|
||||
return namedExpression != null
|
||||
&& findSizeOfLiteralThatEqualToSameSlotInOr(namedExpression, context)
|
||||
&& findSizeOfLiteralThatEqualToSameSlotInOr(namedExpression, slotNameToLiteral)
|
||||
>= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD;
|
||||
}
|
||||
|
||||
public void addSlotToLiteral(NamedExpression namedExpression, Literal literal, OrToInContext context) {
|
||||
Set<Literal> literals = context.slotNameToLiteral.computeIfAbsent(namedExpression, k -> new HashSet<>());
|
||||
public void addSlotToLiteral(NamedExpression namedExpression, Literal literal,
|
||||
Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
|
||||
Set<Literal> literals = slotNameToLiteral.computeIfAbsent(namedExpression, k -> new HashSet<>());
|
||||
literals.add(literal);
|
||||
}
|
||||
|
||||
public int findSizeOfLiteralThatEqualToSameSlotInOr(NamedExpression namedExpression, OrToInContext context) {
|
||||
return context.slotNameToLiteral.getOrDefault(namedExpression, Collections.emptySet()).size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Context of OrToIn
|
||||
*/
|
||||
public static class OrToInContext {
|
||||
public final Map<NamedExpression, Set<Literal>> slotNameToLiteral = new HashMap<>();
|
||||
|
||||
public int findSizeOfLiteralThatEqualToSameSlotInOr(NamedExpression namedExpression,
|
||||
Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
|
||||
return slotNameToLiteral.getOrDefault(namedExpression, Collections.emptySet()).size();
|
||||
}
|
||||
}
|
||||
|
||||
@ -33,10 +33,10 @@ import org.junit.jupiter.api.Test;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
public class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
public void test1() {
|
||||
void test1() {
|
||||
String expr = "col1 = 1 or col1 = 2 or col1 = 3 and (col2 = 4)";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
@ -59,7 +59,7 @@ public class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test2() {
|
||||
void test2() {
|
||||
String expr = "col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
@ -68,7 +68,7 @@ public class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test3() {
|
||||
void test3() {
|
||||
String expr = "(col1 = 1 or col1 = 2) and (col2 = 3 or col2 = 4)";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
@ -90,4 +90,23 @@ public class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void test4() {
|
||||
String expr = "case when col = 1 or col = 2 or col = 3 then 1"
|
||||
+ " when col = 4 or col = 5 or col = 6 then 1 else 0 end";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
Assertions.assertEquals("CASE WHEN col IN (1, 2, 3) THEN 1 WHEN col IN (4, 5, 6) THEN 1 ELSE 0 END",
|
||||
rewritten.toSql());
|
||||
}
|
||||
|
||||
@Test
|
||||
void test5() {
|
||||
String expr = "col = 1 or (col = 2 and (col = 3 or col = 4 or col = 5))";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN (3, 4, 5)))",
|
||||
rewritten.toSql());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user