[fix](Nereids) let OrToIn rewritten result have stable order (#31731)

This commit is contained in:
morrySnow
2024-03-05 19:39:24 +08:00
committed by yiguolei
parent c43bc8349f
commit d94d2c65f6
4 changed files with 19 additions and 10 deletions

View File

@ -32,8 +32,7 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -60,7 +59,7 @@ public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext>
public static final OrToIn INSTANCE = new OrToIn();
private static final int REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 2;
public static final int REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 2;
@Override
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
@ -69,9 +68,11 @@ public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext>
@Override
public Expression visitOr(Or or, ExpressionRewriteContext ctx) {
Map<NamedExpression, Set<Literal>> slotNameToLiteral = new HashMap<>();
// NOTICE: use linked hash maps to keep the order of entries stable.
// An unstable entry order leads to a dead loop, since the returned expression would never equal the original one.
Map<NamedExpression, Set<Literal>> slotNameToLiteral = Maps.newLinkedHashMap();
Map<Expression, NamedExpression> disConjunctToSlot = Maps.newLinkedHashMap();
List<Expression> expressions = ExpressionUtils.extractDisjunction(or);
Map<Expression, NamedExpression> disConjunctToSlot = Maps.newHashMap();
for (Expression expression : expressions) {
if (expression instanceof EqualTo) {
handleEqualTo((EqualTo) expression, slotNameToLiteral, disConjunctToSlot);
@ -128,7 +129,7 @@ public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext>
/**
 * Adds {@code literal} to the set of literals stored under {@code namedExpression}
 * in {@code slotNameToLiteral}, creating that set first when the key is absent.
 *
 * @param namedExpression key whose literal set is updated
 * @param literal literal to add to that set
 * @param slotNameToLiteral map from expression to its collected literals; modified in place
 */
public void addSlotToLiteral(NamedExpression namedExpression, Literal literal,
Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
// NOTE(review): this is a diff view, so two versions of the same declaration appear below.
// Presumably the HashSet line is the removed one and the LinkedHashSet line is its
// replacement (LinkedHashSet iterates in insertion order) -- confirm against the commit.
Set<Literal> literals = slotNameToLiteral.computeIfAbsent(namedExpression, k -> new HashSet<>());
Set<Literal> literals = slotNameToLiteral.computeIfAbsent(namedExpression, k -> new LinkedHashSet<>());
literals.add(literal);
}

View File

@ -417,7 +417,7 @@ public class SimplifyRange extends AbstractExpressionRewriteRule {
// They are the same process, so they must be changed in sync.
if (values.size() == 1) {
return new EqualTo(reference, values.iterator().next());
} else if (values.size() == 2) {
} else if (values.size() <= OrToIn.REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
Iterator<Literal> iterator = values.iterator();
return new Or(new EqualTo(reference, iterator.next()), new EqualTo(reference, iterator.next()));
} else {

View File

@ -20,8 +20,6 @@ package org.apache.doris.nereids.trees.plans.algebra;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Suppliers;
import java.util.Set;
/**
@ -31,6 +29,6 @@ public interface Filter {
Set<Expression> getConjuncts();
/**
 * Builds a single predicate expression from the conjuncts returned by
 * {@link #getConjuncts()}, combining them through {@code ExpressionUtils.and}.
 *
 * @return the combined expression produced by {@code ExpressionUtils.and}
 */
default Expression getPredicate() {
// NOTE(review): this is a diff view, so two versions of the return statement appear below.
// Presumably the Suppliers.memoize line is the removed one -- confirm against the commit.
// That variant creates a new memoizing supplier on every call and reads it immediately,
// so nothing is reused between calls.
return Suppliers.memoize(() -> ExpressionUtils.and(getConjuncts().toArray(new Expression[0]))).get();
return ExpressionUtils.and(getConjuncts());
}
}

View File

@ -133,4 +133,14 @@ class OrToInTest extends ExpressionRewriteTestHelper {
Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN ('4', 3, 5.0)))",
rewritten.toSql());
}
@Test
void testEnsureOrder() {
    // The disjuncts must keep their original order: the result must never become
    // "col2 IN (1, 2) OR col1 IN (1, 2)".
    Expression original = PARSER.parseExpression("col1 IN (1, 2) OR col2 IN (1, 2)");
    String actualSql = new OrToIn()
            .rewrite(original, new ExpressionRewriteContext(null))
            .toSql();
    Assertions.assertEquals("(col1 IN (1, 2) OR col2 IN (1, 2))", actualSql);
}
}