[feature](nereids): SimplifyCastRule (#11630)

Remove redundant cast like
```
cast(1 as int) -> 1
```
This commit is contained in:
jakevin
2022-08-15 12:41:36 +08:00
committed by GitHub
parent 74b0d0da88
commit 1e6b8cd1a9
4 changed files with 93 additions and 5 deletions

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.expression.rewrite;
import org.apache.doris.nereids.rules.expression.rewrite.rules.BetweenToCompoundRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPredicatesRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyCastRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule;
import com.google.common.collect.ImmutableList;
@ -33,7 +34,8 @@ public class ExpressionNormalization extends ExpressionRewrite {
public static final List<ExpressionRewriteRule> NORMALIZE_REWRITE_RULES = ImmutableList.of(
NormalizeBinaryPredicatesRule.INSTANCE,
BetweenToCompoundRule.INSTANCE,
SimplifyNotExprRule.INSTANCE
SimplifyNotExprRule.INSTANCE,
SimplifyCastRule.INSTANCE
);
private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES);

View File

@ -0,0 +1,60 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
/**
* Rewrite rule of simplify CAST expression.
* Remove redundant cast like
* - cast(1 as int) -> 1.
* Merge cast like
* - cast(cast(1 as bigint) as string) -> cast(1 as string).
*/
public class SimplifyCastRule extends AbstractExpressionRewriteRule {
public static SimplifyCastRule INSTANCE = new SimplifyCastRule();
@Override
public Expression visitCast(Cast origin, ExpressionRewriteContext context) {
return simplify(origin);
}
private Expression simplify(Cast cast) {
Expression source = cast.left();
// simplify inside
if (source instanceof Cast) {
source = simplify((Cast) source);
}
// remove redundant cast
// CAST(value as type), value is type
if (cast.getDataType().equals(source.getDataType())) {
return source;
}
if (source != cast.left()) {
return new Cast(source, cast.right());
}
return cast;
}
}

View File

@ -34,9 +34,18 @@ public class Cast extends Expression implements BinaryExpression {
super(child, new StringLiteral(type));
}
public Cast(Expression child, StringLiteral type) {
super(child, type);
}
@Override
public StringLiteral right() {
return (StringLiteral) BinaryExpression.super.right();
}
@Override
public DataType getDataType() {
StringLiteral type = (StringLiteral) right();
StringLiteral type = right();
return DataType.convertFromString(type.getValue());
}
@ -59,7 +68,7 @@ public class Cast extends Expression implements BinaryExpression {
@Override
public String toSql() throws UnboundException {
return "CAST(" + left().toSql() + " AS " + ((StringLiteral) right()).getValue() + ")";
return "CAST(" + left().toSql() + " AS " + right().getValue() + ")";
}
@Override

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.expression.rewrite.rules.BetweenToCompound
import org.apache.doris.nereids.rules.expression.rewrite.rules.DistinctPredicatesRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.ExtractCommonFactorRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPredicatesRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyCastRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -144,11 +145,27 @@ public class ExpressionRewriteTest {
public void testBetweenToCompoundRule() {
executor = new ExpressionRuleExecutor(ImmutableList.of(BetweenToCompoundRule.INSTANCE, SimplifyNotExprRule.INSTANCE));
assertRewrite(" a between c and d", "(a >= c) and (a <= d)");
assertRewrite(" a not between c and d)", "(a < c) or (a > d)");
assertRewrite("a between c and d", "(a >= c) and (a <= d)");
assertRewrite("a not between c and d)", "(a < c) or (a > d)");
}
@Test
public void testSimplifyCastRule() {
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE));
// deduplicate
assertRewrite("CAST(1 AS int)", "1");
assertRewrite("CAST(\"str\" AS string)", "\"str\"");
assertRewrite("CAST(CAST(1 AS int) AS int)", "1");
// deduplicate inside
assertRewrite("CAST(CAST(\"str\" AS string) AS double)", "CAST(\"str\" AS double)");
assertRewrite("CAST(CAST(1 AS int) AS double)", "CAST(1 AS double)");
}
private void assertRewrite(String expression, String expected) {
Expression needRewriteExpression = PARSER.parseExpression(expression);
Expression expectedExpression = PARSER.parseExpression(expected);