diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java index 53f360d5ff..3f30d10ebe 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.expression; +import org.apache.doris.nereids.rules.expression.rules.ArrayContainToArrayOverlap; import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule; import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule; import org.apache.doris.nereids.rules.expression.rules.OrToIn; @@ -40,7 +41,8 @@ public class ExpressionOptimization extends ExpressionRewrite { SimplifyInPredicate.INSTANCE, SimplifyDecimalV3Comparison.INSTANCE, SimplifyRange.INSTANCE, - OrToIn.INSTANCE + OrToIn.INSTANCE, + ArrayContainToArrayOverlap.INSTANCE ); private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ArrayContainToArrayOverlap.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ArrayContainToArrayOverlap.java new file mode 100644 index 0000000000..7309ef111c --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ArrayContainToArrayOverlap.java @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayContains; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraysOverlap; +import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * array_contains ( c_array, '1' ) + * OR array_contains ( c_array, '2' ) + * =========================================> + * array_overlap(c_array, ['1', '2']) + */ +public class ArrayContainToArrayOverlap extends DefaultExpressionRewriter implements + ExpressionRewriteRule { + + public static final ArrayContainToArrayOverlap INSTANCE = new ArrayContainToArrayOverlap(); + + private static final int REWRITE_PREDICATE_THRESHOLD = 2; + + @Override + public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { + return expr.accept(this, ctx); + } + + @Override + public Expression visitOr(Or or, ExpressionRewriteContext ctx) { + List disjuncts = ExpressionUtils.extractDisjunction(or); + Map> containFuncAndOtherFunc = disjuncts.stream() + .collect(Collectors.partitioningBy(this::isValidArrayContains)); + Map> containLiteralSet = new HashMap<>(); + List contains = containFuncAndOtherFunc.get(true); + List others = containFuncAndOtherFunc.get(false); + + contains.forEach(containFunc -> + containLiteralSet.computeIfAbsent(containFunc.child(0), k -> new HashSet<>()) + .add((Literal) containFunc.child(1))); + + Builder newDisjunctsBuilder = new ImmutableList.Builder<>(); + containLiteralSet.forEach((left, literalSet) -> { + if (literalSet.size() > REWRITE_PREDICATE_THRESHOLD) { + newDisjunctsBuilder.add( + new ArraysOverlap(left, + new ArrayLiteral(ImmutableList.copyOf(literalSet)))); + } + }); + + contains.stream() + .filter(e -> !canCovertToArrayOverlap(e, containLiteralSet)) + .forEach(newDisjunctsBuilder::add); + others.stream() + .map(e -> e.accept(this, null)) + .forEach(newDisjunctsBuilder::add); + return ExpressionUtils.or(newDisjunctsBuilder.build()); + } + + private boolean isValidArrayContains(Expression expression) { + return expression instanceof ArrayContains && expression.child(1) instanceof Literal; + } + + private boolean canCovertToArrayOverlap(Expression expression, Map> containLiteralSet) { + return expression instanceof ArrayContains + && containLiteralSet.getOrDefault(expression.child(0), + new HashSet<>()).size() > REWRITE_PREDICATE_THRESHOLD; + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ArrayContainsToArrayOverlapTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ArrayContainsToArrayOverlapTest.java new file mode 100644 index 0000000000..dfee1a7cae --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ArrayContainsToArrayOverlapTest.java @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraysOverlap; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class ArrayContainsToArrayOverlapTest extends ExpressionRewriteTestHelper { + + @Test + void testOr() { + String sql = "select array_contains([1], 1) or array_contains([1], 2) or array_contains([1], 3);"; + Plan plan = PlanChecker.from(MemoTestUtils.createConnectContext()) + .analyze(sql) + .rewrite() + .getPlan(); + Expression expression = plan.child(0).getExpressions().get(0).child(0); + Assertions.assertTrue(expression instanceof ArraysOverlap); + Assertions.assertEquals("array(1)", expression.child(0).toSql()); + Assertions.assertEquals("array(1, 2, 3)", expression.child(1).toSql()); + } + + @Test + void testAnd() { + String sql = "select array_contains([1], 1) " + + "or array_contains([1], 2) " + + "or array_contains([1], 3)" + + "or array_contains([1], 4) and array_contains([1], 5);"; + Plan plan = PlanChecker.from(MemoTestUtils.createConnectContext()) + .analyze(sql) + .rewrite() + .getPlan(); + Expression expression = plan.child(0).getExpressions().get(0).child(0); + Assertions.assertTrue(expression instanceof Or); + Assertions.assertTrue(expression.child(0) instanceof ArraysOverlap); + Assertions.assertTrue(expression.child(1) instanceof And); + } + + @Test + void testAndOther() { + String sql = "select bin(0) == 1 " + + "or array_contains([1], 1) " + + "or array_contains([1], 2) " + + "or array_contains([1], 3) " + + "or array_contains([1], 4) and array_contains([1], 5);"; + Plan plan = PlanChecker.from(MemoTestUtils.createConnectContext()) + .analyze(sql) + .rewrite() + .getPlan(); + Expression expression = plan.child(0).getExpressions().get(0).child(0); + Assertions.assertTrue(expression instanceof Or); + Assertions.assertTrue(expression.child(0) instanceof Or); + Assertions.assertTrue(expression.child(0).child(0) instanceof ArraysOverlap); + Assertions.assertTrue(expression.child(0).child(1) instanceof EqualTo); + Assertions.assertTrue(expression.child(1) instanceof And); + } + + @Test + void testAndOverlap() { + String sql = "select array_contains([1], 0) " + + "or (array_contains([1], 1) " + + "and (array_contains([1], 2) " + + "or array_contains([1], 3) " + + "or array_contains([1], 4)));"; + Plan plan = PlanChecker.from(MemoTestUtils.createConnectContext()) + .analyze(sql) + .rewrite() + .getPlan(); + Expression expression = plan.child(0).getExpressions().get(0).child(0); + Assertions.assertEquals("(array_contains(array(1), 0) OR " + + "(array_contains(array(1), 1) AND arrays_overlap(array(1), array(2, 3, 4))))", + expression.toSql()); + } +}