[fix](Nereids) let OrToIn rewritten result have stable order (#31731)
This commit is contained in:
@ -32,8 +32,7 @@ import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Maps;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
@ -60,7 +59,7 @@ public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext>
|
||||
|
||||
public static final OrToIn INSTANCE = new OrToIn();
|
||||
|
||||
private static final int REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 2;
|
||||
public static final int REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 2;
|
||||
|
||||
@Override
|
||||
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
|
||||
@ -69,9 +68,11 @@ public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext>
|
||||
|
||||
@Override
|
||||
public Expression visitOr(Or or, ExpressionRewriteContext ctx) {
|
||||
Map<NamedExpression, Set<Literal>> slotNameToLiteral = new HashMap<>();
|
||||
// NOTICE: use linked hash map to avoid unstable order or entry.
|
||||
// unstable order entry lead to dead loop since return expression always un-equals to original one.
|
||||
Map<NamedExpression, Set<Literal>> slotNameToLiteral = Maps.newLinkedHashMap();
|
||||
Map<Expression, NamedExpression> disConjunctToSlot = Maps.newLinkedHashMap();
|
||||
List<Expression> expressions = ExpressionUtils.extractDisjunction(or);
|
||||
Map<Expression, NamedExpression> disConjunctToSlot = Maps.newHashMap();
|
||||
for (Expression expression : expressions) {
|
||||
if (expression instanceof EqualTo) {
|
||||
handleEqualTo((EqualTo) expression, slotNameToLiteral, disConjunctToSlot);
|
||||
@ -128,7 +129,7 @@ public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext>
|
||||
|
||||
public void addSlotToLiteral(NamedExpression namedExpression, Literal literal,
|
||||
Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
|
||||
Set<Literal> literals = slotNameToLiteral.computeIfAbsent(namedExpression, k -> new HashSet<>());
|
||||
Set<Literal> literals = slotNameToLiteral.computeIfAbsent(namedExpression, k -> new LinkedHashSet<>());
|
||||
literals.add(literal);
|
||||
}
|
||||
|
||||
|
||||
@ -417,7 +417,7 @@ public class SimplifyRange extends AbstractExpressionRewriteRule {
|
||||
// They are same processes, so must change synchronously.
|
||||
if (values.size() == 1) {
|
||||
return new EqualTo(reference, values.iterator().next());
|
||||
} else if (values.size() == 2) {
|
||||
} else if (values.size() <= OrToIn.REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
|
||||
Iterator<Literal> iterator = values.iterator();
|
||||
return new Or(new EqualTo(reference, iterator.next()), new EqualTo(reference, iterator.next()));
|
||||
} else {
|
||||
|
||||
@ -20,8 +20,6 @@ package org.apache.doris.nereids.trees.plans.algebra;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.base.Suppliers;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
@ -31,6 +29,6 @@ public interface Filter {
|
||||
Set<Expression> getConjuncts();
|
||||
|
||||
default Expression getPredicate() {
|
||||
return Suppliers.memoize(() -> ExpressionUtils.and(getConjuncts().toArray(new Expression[0]))).get();
|
||||
return ExpressionUtils.and(getConjuncts());
|
||||
}
|
||||
}
|
||||
|
||||
@ -133,4 +133,14 @@ class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN ('4', 3, 5.0)))",
|
||||
rewritten.toSql());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testEnsureOrder() {
|
||||
// ensure not rewrite to col2 in (1, 2) or cor 1 in (1, 2)
|
||||
String expr = "col1 IN (1, 2) OR col2 IN (1, 2)";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
Assertions.assertEquals("(col1 IN (1, 2) OR col2 IN (1, 2))",
|
||||
rewritten.toSql());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user