[Enhancement](Nereids) add expression rewrite rules and plan rewrite rules that rewrite a plan's expressions (#10667)

# First: add two expression rewrite rules:
1. remove duplicate expr
a = 1 and a = 1 -> a = 1

2. extract common expr
(a or b) and (a or c) -> a or (b and c)

# Second: add plan rewrite rules that rewrite the expressions of plan operators
1. NormalizeExpressionOfPlan contains the normalization expression rewrite rules and uses them to rewrite the expressions of LogicalFilter, LogicalAggregate, LogicalProject and LogicalJoin.
2. OptimizeExpressionOfPlan contains the optimization expression rewrite rules and uses them to rewrite the expressions of LogicalFilter, LogicalAggregate, LogicalProject and LogicalJoin.
This commit is contained in:
shee
2022-07-21 12:35:28 +08:00
committed by GitHub
parent 072479fa21
commit f8ad2613cf
11 changed files with 451 additions and 11 deletions

View File

@ -84,6 +84,11 @@ public class UnboundSlot extends Slot implements Unbound {
return Objects.hash(nameParts);
}
@Override
public int hashCode() {
// Hash is computed from the contents of nameParts: the result of toArray() is handed to Objects.hash.
// NOTE(review): assumes nameParts is a collection, so toArray() yields an Object[] that is spread as
// the varargs of Objects.hash (element-wise hash) — confirm against the field declaration.
// NOTE(review): presumably added to keep hashCode consistent with equals(), which is not visible here.
return Objects.hash(nameParts.toArray());
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitUnboundSlot(this, context);

View File

@ -46,6 +46,11 @@ public enum RuleType {
COLUMN_PRUNE_FILTER_CHILD(RuleTypeClass.REWRITE),
COLUMN_PRUNE_SORT_CHILD(RuleTypeClass.REWRITE),
COLUMN_PRUNE_JOIN_CHILD(RuleTypeClass.REWRITE),
// Rules that rewrite the expressions held by a plan node (one rule type per node kind:
// project, aggregate, filter and join).
REWRITE_PROJECT_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_AGG_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_FILTER_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_JOIN_EXPRESSION(RuleTypeClass.REWRITE),
REORDER_JOIN(RuleTypeClass.REWRITE),

View File

@ -0,0 +1,121 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* expression of plan rewrite rule.
*/
/**
 * Rule factory that rewrites the expressions held by plan nodes with a given
 * {@link ExpressionRuleExecutor}.
 *
 * <p>One rule is built for each of project, aggregate, filter and join. Every rule returns the
 * matched node itself when rewriting leaves all of its expressions unchanged, and a new node
 * carrying the rewritten expressions otherwise.
 */
public class ExpressionOfPlanRewrite implements RewriteRuleFactory {
    private final ExpressionRuleExecutor rewriter;

    /**
     * Creates the factory.
     *
     * @param rewriter executor used to rewrite every expression; must not be null
     * @throws NullPointerException if {@code rewriter} is null
     */
    public ExpressionOfPlanRewrite(ExpressionRuleExecutor rewriter) {
        this.rewriter = Objects.requireNonNull(rewriter, "rewriter is null");
    }

    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                new ProjectExpressionRewrite().build(),
                new AggExpressionRewrite().build(),
                new FilterExpressionRewrite().build(),
                new JoinExpressionRewrite().build());
    }

    /** Rewrites each named expression and returns the results in the original order. */
    private List<NamedExpression> rewriteNamedExpressions(List<NamedExpression> exprs) {
        return exprs.stream()
                .map(expr -> (NamedExpression) rewriter.rewrite(expr))
                .collect(Collectors.toList());
    }

    /** Rewrites the projection list of a logical project. */
    private class ProjectExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalProject().then(project -> {
                List<NamedExpression> projects = project.getProjects();
                List<NamedExpression> newProjects = rewriteNamedExpressions(projects);
                // Compare position by position. A containsAll() check would report "unchanged"
                // when a rewritten projection merely equals some other existing projection.
                if (projects.equals(newProjects)) {
                    return project;
                }
                return new LogicalProject<>(newProjects, project.child());
            }).toRule(RuleType.REWRITE_PROJECT_EXPRESSION);
        }
    }

    /** Rewrites the predicate of a logical filter. */
    private class FilterExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalFilter().then(filter -> {
                Expression newExpr = rewriter.rewrite(filter.getPredicates());
                if (newExpr.equals(filter.getPredicates())) {
                    return filter;
                }
                return new LogicalFilter<>(newExpr, filter.child());
            }).toRule(RuleType.REWRITE_FILTER_EXPRESSION);
        }
    }

    /** Rewrites the group-by and output expressions of a logical aggregate. */
    private class AggExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalAggregate().then(agg -> {
                List<Expression> groupByExprs = agg.getGroupByExpressionList();
                List<Expression> newGroupByExprs = rewriter.rewrite(groupByExprs);
                List<NamedExpression> outputExpressions = agg.getOutputExpressionList();
                List<NamedExpression> newOutputExpressions = rewriteNamedExpressions(outputExpressions);
                // Both lists must be unchanged to keep the node; otherwise a rewrite that only
                // touches the group-by expressions would be silently dropped.
                if (groupByExprs.equals(newGroupByExprs) && outputExpressions.equals(newOutputExpressions)) {
                    return agg;
                }
                return new LogicalAggregate<>(newGroupByExprs, newOutputExpressions, agg.isDisassembled(),
                        agg.getAggPhase(), agg.child());
            }).toRule(RuleType.REWRITE_AGG_EXPRESSION);
        }
    }

    /** Rewrites the join condition of a logical join; joins without a condition are left alone. */
    private class JoinExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalJoin().then(join -> {
                Optional<Expression> condition = join.getCondition();
                if (!condition.isPresent()) {
                    return join;
                }
                Expression newCondition = rewriter.rewrite(condition.get());
                if (newCondition.equals(condition.get())) {
                    return join;
                }
                return new LogicalJoin<>(join.getJoinType(), Optional.of(newCondition), join.left(), join.right());
            }).toRule(RuleType.REWRITE_JOIN_EXPRESSION);
        }
    }
}

View File

@ -26,6 +26,7 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.stream.Collectors;
/**
* Expression rewrite entry, which contains all rewrite rules.
@ -56,6 +57,10 @@ public class ExpressionRuleExecutor {
this.ctx = new ExpressionRewriteContext();
}
/**
 * Rewrites every expression of the given list with the single-expression {@code rewrite}.
 *
 * @param exprs expressions to rewrite
 * @return a new list holding the rewritten expressions, in the same order as {@code exprs}
 */
public List<Expression> rewrite(List<Expression> exprs) {
    return exprs.stream()
            .map(expr -> rewrite(expr))
            .collect(Collectors.toList());
}
/**
* Given an expression, returns a rewritten expression.
*/

View File

@ -0,0 +1,41 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite;
import org.apache.doris.nereids.rules.expression.rewrite.rules.BetweenToCompoundRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPredicatesRule;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* normalize expression of plan rule set.
*/
public class NormalizeExpressionOfPlan extends ExpressionOfPlanRewrite {
    /** Normalization rules, in the order they are handed to the executor. */
    public static final List<ExpressionRewriteRule> NORMALIZE_REWRITE_RULES =
            ImmutableList.<ExpressionRewriteRule>builder()
                    .add(NormalizeBinaryPredicatesRule.INSTANCE)
                    .add(BetweenToCompoundRule.INSTANCE)
                    .build();

    /** Single executor instance shared by every {@code NormalizeExpressionOfPlan}. */
    private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES);

    /** Creates the rule set backed by the shared normalization executor. */
    public NormalizeExpressionOfPlan() {
        super(EXECUTOR);
    }
}

View File

@ -0,0 +1,41 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite;
import org.apache.doris.nereids.rules.expression.rewrite.rules.DistinctPredicatesRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.ExtractCommonFactorRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* optimize expression of plan rule set.
*/
public class OptimizeExpressionOfPlan extends ExpressionOfPlanRewrite {
    /** Optimization rules, in the order they are handed to the executor. */
    public static final List<ExpressionRewriteRule> OPTIMIZE_REWRITE_RULES =
            ImmutableList.<ExpressionRewriteRule>builder()
                    .add(SimplifyNotExprRule.INSTANCE)
                    .add(ExtractCommonFactorRule.INSTANCE)
                    .add(DistinctPredicatesRule.INSTANCE)
                    .build();

    /** Single executor instance shared by every {@code OptimizeExpressionOfPlan}. */
    private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);

    /** Creates the rule set backed by the shared optimization executor. */
    public OptimizeExpressionOfPlan() {
        super(EXECUTOR);
    }
}

View File

@ -0,0 +1,52 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.Lists;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
/**
* Remove redundant expr for 'CompoundPredicate'.
* for example:
* transform (a = 1) and (b > 2) and (a = 1) to (a = 1) and (b > 2)
* transform (a = 1) or (a = 1) to (a = 1)
*/
public class DistinctPredicatesRule extends AbstractExpressionRewriteRule {
    public static final DistinctPredicatesRule INSTANCE = new DistinctPredicatesRule();

    /**
     * Drops repeated operands from a compound predicate.
     *
     * @param expr the compound predicate to inspect
     * @param context rewrite context (not used by this rule)
     * @return {@code expr} itself when its extracted operands are all distinct; otherwise the
     *         distinct operands, in order of first occurrence, recombined with the type of {@code expr}
     */
    @Override
    public Expression visitCompoundPredicate(CompoundPredicate expr, ExpressionRewriteContext context) {
        List<Expression> operands = ExpressionUtils.extract(expr);
        // A LinkedHashSet removes duplicates while keeping the first-occurrence order.
        Set<Expression> uniqueOperands = new LinkedHashSet<>(operands);
        boolean hasDuplicates = uniqueOperands.size() < operands.size();
        if (!hasDuplicates) {
            return expr;
        }
        return ExpressionUtils.combine(expr.getType(), Lists.newArrayList(uniqueOperands));
    }
}

View File

@ -0,0 +1,76 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Extract common expr for `CompoundPredicate`.
* for example:
* transform (a or b) and (a or c) to a or (b and c)
* transform (a and b) or (a and c) to a and (b or c)
*/
public class ExtractCommonFactorRule extends AbstractExpressionRewriteRule {
    public static final ExtractCommonFactorRule INSTANCE = new ExtractCommonFactorRule();

    /**
     * Pulls the operands shared by every branch of a compound predicate out of the branches.
     *
     * @param expr the compound predicate to rewrite
     * @param context rewrite context, passed on when the operands are rewritten
     * @return the recombined operands if they no longer form a compound predicate; otherwise the
     *         common factors combined (with the flipped type) with the remainder of every branch
     */
    @Override
    public Expression visitCompoundPredicate(CompoundPredicate expr, ExpressionRewriteContext context) {
        // Rewrite the operands first, then recombine them with the original type.
        Expression rewrittenChildren = ExpressionUtils.combine(expr.getType(), ExpressionUtils.extract(expr).stream()
                .map(predicate -> rewrite(predicate, context)).collect(Collectors.toList()));
        if (!(rewrittenChildren instanceof CompoundPredicate)) {
            return rewrittenChildren;
        }
        CompoundPredicate compoundPredicate = (CompoundPredicate) rewrittenChildren;

        // One partition per operand: the extracted parts of a compound operand, or the operand itself.
        List<List<Expression>> partitions = ExpressionUtils.extract(compoundPredicate).stream()
                .map(predicate -> predicate instanceof CompoundPredicate ? ExpressionUtils.extract(
                        (CompoundPredicate) predicate) : Lists.newArrayList(predicate)).collect(Collectors.toList());

        // Expressions present in every partition. Each partition becomes a linked hash set so that
        // Sets.intersection, which iterates in the order of its first argument, yields the common
        // factors in the order they appear in the first partition instead of an arbitrary hash order.
        Set<Expression> commons = partitions.stream()
                .<Set<Expression>>map(predicates -> Sets.newLinkedHashSet(predicates))
                .reduce(Sets::intersection).orElse(Collections.emptySet());

        // What is left of every partition once the common factors are removed.
        List<List<Expression>> uncorrelated = partitions.stream()
                .map(predicates -> predicates.stream().filter(p -> !commons.contains(p)).collect(Collectors.toList()))
                .collect(Collectors.toList());

        // Remainders are joined with the flipped type inside a partition and the original type across partitions.
        Expression combineUncorrelated = ExpressionUtils.combine(compoundPredicate.getType(),
                uncorrelated.stream().map(predicates -> ExpressionUtils.combine(compoundPredicate.flip(), predicates))
                        .collect(Collectors.toList()));

        // Result: common factors followed by the combined remainder, joined with the flipped type.
        List<Expression> finalCompound = Lists.newArrayList(commons);
        finalCompound.add(combineUncorrelated);
        return ExpressionUtils.combine(compoundPredicate.flip(), finalCompound);
    }
}

View File

@ -27,7 +27,7 @@ import java.util.Objects;
* Compound predicate expression.
* Such as &&,||,AND,OR.
*/
public class CompoundPredicate extends Expression implements BinaryExpression {
public abstract class CompoundPredicate extends Expression implements BinaryExpression {
/**
* Desc: Constructor for CompoundPredicate.
@ -58,7 +58,7 @@ public class CompoundPredicate extends Expression implements BinaryExpression {
@Override
public Expression withChildren(List<Expression> children) {
return new CompoundPredicate(getType(), children.get(0), children.get(1));
throw new RuntimeException("The withChildren() method is not implemented");
}
@Override

View File

@ -17,10 +17,12 @@
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.ExpressionType;
import org.apache.doris.nereids.trees.expressions.Or;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
@ -98,8 +100,8 @@ public class ExpressionUtils {
}
}
Optional<Expression> result =
distinctExpressions.stream().reduce((left, right) -> new CompoundPredicate(op, left, right));
Optional<Expression> result = distinctExpressions.stream()
.reduce(op == ExpressionType.AND ? And::new : Or::new);
return result.orElse(new BooleanLiteral(op == ExpressionType.AND));
}
}

View File

@ -18,10 +18,14 @@
package org.apache.doris.nereids.rules.expression.rewrite;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.rules.expression.rewrite.rules.BetweenToCompoundRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.DistinctPredicatesRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.ExtractCommonFactorRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPredicatesRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule;
import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.collect.ImmutableList;
import org.junit.Assert;
import org.junit.Test;
@ -36,13 +40,24 @@ public class ExpressionRewriteTest {
public void testNotRewrite() {
executor = new ExpressionRuleExecutor(SimplifyNotExprRule.INSTANCE);
assertRewrite("not x > y", "x <= y");
assertRewrite("not x < y", "x >= y");
assertRewrite("not x >= y", "x < y");
assertRewrite("not x <= y", "x > y");
assertRewrite("not x = y", "not x = y");
assertRewrite("not not x > y", "x > y");
assertRewrite("not not not x > y", "x <= y");
assertRewrite("not x", "not x");
assertRewrite("not not x", "x");
assertRewrite("not not not x", "not x");
assertRewrite("not (x > y)", "x <= y");
assertRewrite("not (x < y)", "x >= y");
assertRewrite("not (x >= y)", "x < y");
assertRewrite("not (x <= y)", "x > y");
assertRewrite("not (x = y)", "not (x = y)");
assertRewrite("not not (x > y)", "x > y");
assertRewrite("not not not (x > y)", "x <= y");
assertRewrite("not not not (x > (not not y))", "x <= y");
assertRewrite("not (x > (not not y))", "x <= y");
assertRewrite("not (a and b)", "(not a) or (not b)");
assertRewrite("not (a or b)", "(not a) and (not b)");
assertRewrite("not (a and b and (c or d))", "(not a) or (not b) or ((not c) and (not d))");
}
@Test
@ -56,6 +71,83 @@ public class ExpressionRewriteTest {
assertRewrite("2 = x", "x = 2");
}
@Test
public void testDistinctPredicatesRewrite() {
    executor = new ExpressionRuleExecutor(DistinctPredicatesRule.INSTANCE);

    // Each row is {expression to rewrite, expected result}; rows are checked in order.
    String[][] cases = {
            {"a = 1", "a = 1"},
            {"a = 1 and a = 1", "a = 1"},
            {"a = 1 and b > 2 and a = 1", "a = 1 and b > 2"},
            {"a = 1 and a = 1 and b > 2 and a = 1 and a = 1", "a = 1 and b > 2"},
            {"a = 1 or a = 1", "a = 1"},
            {"a = 1 or a = 1 or b >= 1", "a = 1 or b >= 1"},
    };
    for (String[] rewriteCase : cases) {
        assertRewrite(rewriteCase[0], rewriteCase[1]);
    }
}
@Test
public void testExtractCommonFactorRewrite() {
    executor = new ExpressionRuleExecutor(ExtractCommonFactorRule.INSTANCE);

    // Expressions without a common factor.
    assertRewrite("a", "a");
    assertRewrite("a = 1", "a = 1");
    assertRewrite("a and b", "a and b");
    assertRewrite("a = 1 and b > 2", "a = 1 and b > 2");
    assertRewrite("(a and b) or (c and d)", "(a and b) or (c and d)");
    assertRewrite("(a and b) and (c and d)", "((a and b) and c) and d");

    // A factor shared by every branch is extracted.
    assertRewrite("(a or b) and (a or c)", "a or (b and c)");
    assertRewrite("(a and b) or (a and c)", "a and (b or c)");
    assertRewrite("(a or b) and (a or c) and (a or d)", "a or (b and c and d)");
    assertRewrite("(a and b) or (a and c) or (a and d)", "a and (b or c or d)");

    // Branches mixing AND and OR.
    assertRewrite("(a and b) or (a or c) or (a and d)", "((((a and b) or a) or c) or (a and d))");
    // The expected string used to end with an unbalanced ')'.
    assertRewrite("(a and b) or (a and c) or (a or d)", "(((a and b) or (a and c) or a) or d)");
    assertRewrite("(a or b) or (a and c) or (a and d)", "(a or b) or (a and c) or (a and d)");
    assertRewrite("(a or b) or (a and c) or (a or d)", "(((a or b) or (a and c)) or d)");
    assertRewrite("(a or b) or (a or c) or (a and d)", "((a or b) or c) or (a and d)");
    assertRewrite("(a or b) or (a or c) or (a or d)", "(((a or b) or c) or d)");

    // A factor shared by only some of the branches.
    assertRewrite("(a and b) or (d and c) or (d and e)", "(a and b) or (d and c) or (d and e)");
    assertRewrite("(a or b) and (d or c) and (d or e)", "(a or b) and (d or c) and (d or e)");

    assertRewrite("(a and b) or ((d and c) and (d and e))", "(a and b) or (d and c and e)");
    assertRewrite("(a or b) and ((d or c) or (d or e))", "(a or b) and (d or c or e)");

    // One branch is fully contained in another.
    assertRewrite("(a and b) or (a and b and c)", "a and b");
    assertRewrite("(a or b) and (a or b or c)", "a or b");

    // Boolean literals.
    assertRewrite("a and true", "a");
    assertRewrite("a or false", "a");
    assertRewrite("a and false", "false");
    assertRewrite("a or true", "true");
    assertRewrite("a or false or false or false", "a");
    assertRewrite("a and true and true and true", "a");

    assertRewrite("(a and b) or a ", "a");
    assertRewrite("(a or b) and a ", "a");
    assertRewrite("(a and b) or (a and true)", "a");
    assertRewrite("(a or b) and (a and true)", "a");
    assertRewrite("(a or b) and (a or true)", "a or b");
}
@Test
public void testBetweenToCompoundRule() {
    // BetweenToCompoundRule runs first; SimplifyNotExprRule then handles the NOT of "not between".
    executor = new ExpressionRuleExecutor(ImmutableList.of(BetweenToCompoundRule.INSTANCE, SimplifyNotExprRule.INSTANCE));
    assertRewrite(" a between c and d", "(a >= c) and (a <= d)");
    // The input used to end with a stray, unmatched ')'.
    assertRewrite(" a not between c and d", "(a < c) or (a > d)");
}
private void assertRewrite(String expression, String expected) {
Expression needRewriteExpression = PARSER.parseExpression(expression);
Expression expectedExpression = PARSER.parseExpression(expected);