[feature](Nereids) multi array contains to array overlap (#23864)
transform ``` array_contains ( c_array, '1' ) OR array_contains ( c_array, '2' ) ``` to ``` array_overlap(c_array, ['1', '2']) ```
This commit is contained in:
@ -17,6 +17,7 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.rules.ArrayContainToArrayOverlap;
|
||||
import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule;
|
||||
import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
|
||||
import org.apache.doris.nereids.rules.expression.rules.OrToIn;
|
||||
@ -40,7 +41,8 @@ public class ExpressionOptimization extends ExpressionRewrite {
|
||||
SimplifyInPredicate.INSTANCE,
|
||||
SimplifyDecimalV3Comparison.INSTANCE,
|
||||
SimplifyRange.INSTANCE,
|
||||
OrToIn.INSTANCE
|
||||
OrToIn.INSTANCE,
|
||||
ArrayContainToArrayOverlap.INSTANCE
|
||||
|
||||
);
|
||||
private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);
|
||||
|
||||
@ -0,0 +1,99 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Or;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayContains;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraysOverlap;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* array_contains ( c_array, '1' )
|
||||
* OR array_contains ( c_array, '2' )
|
||||
* =========================================>
|
||||
* array_overlap(c_array, ['1', '2'])
|
||||
*/
|
||||
public class ArrayContainToArrayOverlap extends DefaultExpressionRewriter<ExpressionRewriteContext> implements
|
||||
ExpressionRewriteRule<ExpressionRewriteContext> {
|
||||
|
||||
public static final ArrayContainToArrayOverlap INSTANCE = new ArrayContainToArrayOverlap();
|
||||
|
||||
private static final int REWRITE_PREDICATE_THRESHOLD = 2;
|
||||
|
||||
@Override
|
||||
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
|
||||
return expr.accept(this, ctx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitOr(Or or, ExpressionRewriteContext ctx) {
|
||||
List<Expression> disjuncts = ExpressionUtils.extractDisjunction(or);
|
||||
Map<Boolean, List<Expression>> containFuncAndOtherFunc = disjuncts.stream()
|
||||
.collect(Collectors.partitioningBy(this::isValidArrayContains));
|
||||
Map<Expression, Set<Literal>> containLiteralSet = new HashMap<>();
|
||||
List<Expression> contains = containFuncAndOtherFunc.get(true);
|
||||
List<Expression> others = containFuncAndOtherFunc.get(false);
|
||||
|
||||
contains.forEach(containFunc ->
|
||||
containLiteralSet.computeIfAbsent(containFunc.child(0), k -> new HashSet<>())
|
||||
.add((Literal) containFunc.child(1)));
|
||||
|
||||
Builder<Expression> newDisjunctsBuilder = new ImmutableList.Builder<>();
|
||||
containLiteralSet.forEach((left, literalSet) -> {
|
||||
if (literalSet.size() > REWRITE_PREDICATE_THRESHOLD) {
|
||||
newDisjunctsBuilder.add(
|
||||
new ArraysOverlap(left,
|
||||
new ArrayLiteral(ImmutableList.copyOf(literalSet))));
|
||||
}
|
||||
});
|
||||
|
||||
contains.stream()
|
||||
.filter(e -> !canCovertToArrayOverlap(e, containLiteralSet))
|
||||
.forEach(newDisjunctsBuilder::add);
|
||||
others.stream()
|
||||
.map(e -> e.accept(this, null))
|
||||
.forEach(newDisjunctsBuilder::add);
|
||||
return ExpressionUtils.or(newDisjunctsBuilder.build());
|
||||
}
|
||||
|
||||
private boolean isValidArrayContains(Expression expression) {
|
||||
return expression instanceof ArrayContains && expression.child(1) instanceof Literal;
|
||||
}
|
||||
|
||||
private boolean canCovertToArrayOverlap(Expression expression, Map<Expression, Set<Literal>> containLiteralSet) {
|
||||
return expression instanceof ArrayContains
|
||||
&& containLiteralSet.getOrDefault(expression.child(0),
|
||||
new HashSet<>()).size() > REWRITE_PREDICATE_THRESHOLD;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,99 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
|
||||
import org.apache.doris.nereids.trees.expressions.And;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Or;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraysOverlap;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class ArrayContainsToArrayOverlapTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testOr() {
|
||||
String sql = "select array_contains([1], 1) or array_contains([1], 2) or array_contains([1], 3);";
|
||||
Plan plan = PlanChecker.from(MemoTestUtils.createConnectContext())
|
||||
.analyze(sql)
|
||||
.rewrite()
|
||||
.getPlan();
|
||||
Expression expression = plan.child(0).getExpressions().get(0).child(0);
|
||||
Assertions.assertTrue(expression instanceof ArraysOverlap);
|
||||
Assertions.assertEquals("array(1)", expression.child(0).toSql());
|
||||
Assertions.assertEquals("array(1, 2, 3)", expression.child(1).toSql());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAnd() {
|
||||
String sql = "select array_contains([1], 1) "
|
||||
+ "or array_contains([1], 2) "
|
||||
+ "or array_contains([1], 3)"
|
||||
+ "or array_contains([1], 4) and array_contains([1], 5);";
|
||||
Plan plan = PlanChecker.from(MemoTestUtils.createConnectContext())
|
||||
.analyze(sql)
|
||||
.rewrite()
|
||||
.getPlan();
|
||||
Expression expression = plan.child(0).getExpressions().get(0).child(0);
|
||||
Assertions.assertTrue(expression instanceof Or);
|
||||
Assertions.assertTrue(expression.child(0) instanceof ArraysOverlap);
|
||||
Assertions.assertTrue(expression.child(1) instanceof And);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAndOther() {
|
||||
String sql = "select bin(0) == 1 "
|
||||
+ "or array_contains([1], 1) "
|
||||
+ "or array_contains([1], 2) "
|
||||
+ "or array_contains([1], 3) "
|
||||
+ "or array_contains([1], 4) and array_contains([1], 5);";
|
||||
Plan plan = PlanChecker.from(MemoTestUtils.createConnectContext())
|
||||
.analyze(sql)
|
||||
.rewrite()
|
||||
.getPlan();
|
||||
Expression expression = plan.child(0).getExpressions().get(0).child(0);
|
||||
Assertions.assertTrue(expression instanceof Or);
|
||||
Assertions.assertTrue(expression.child(0) instanceof Or);
|
||||
Assertions.assertTrue(expression.child(0).child(0) instanceof ArraysOverlap);
|
||||
Assertions.assertTrue(expression.child(0).child(1) instanceof EqualTo);
|
||||
Assertions.assertTrue(expression.child(1) instanceof And);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAndOverlap() {
|
||||
String sql = "select array_contains([1], 0) "
|
||||
+ "or (array_contains([1], 1) "
|
||||
+ "and (array_contains([1], 2) "
|
||||
+ "or array_contains([1], 3) "
|
||||
+ "or array_contains([1], 4)));";
|
||||
Plan plan = PlanChecker.from(MemoTestUtils.createConnectContext())
|
||||
.analyze(sql)
|
||||
.rewrite()
|
||||
.getPlan();
|
||||
Expression expression = plan.child(0).getExpressions().get(0).child(0);
|
||||
Assertions.assertEquals("(array_contains(array(1), 0) OR "
|
||||
+ "(array_contains(array(1), 1) AND arrays_overlap(array(1), array(2, 3, 4))))",
|
||||
expression.toSql());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user