[feature](nereids): SimplifyCastRule (#11630)
Remove redundant cast like ``` cast(1 as int) -> 1 ```
This commit is contained in:
@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.expression.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.BetweenToCompoundRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPredicatesRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyCastRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -33,7 +34,8 @@ public class ExpressionNormalization extends ExpressionRewrite {
|
||||
public static final List<ExpressionRewriteRule> NORMALIZE_REWRITE_RULES = ImmutableList.of(
|
||||
NormalizeBinaryPredicatesRule.INSTANCE,
|
||||
BetweenToCompoundRule.INSTANCE,
|
||||
SimplifyNotExprRule.INSTANCE
|
||||
SimplifyNotExprRule.INSTANCE,
|
||||
SimplifyCastRule.INSTANCE
|
||||
);
|
||||
private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES);
|
||||
|
||||
|
||||
@ -0,0 +1,60 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rewrite.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
|
||||
/**
|
||||
* Rewrite rule of simplify CAST expression.
|
||||
* Remove redundant cast like
|
||||
* - cast(1 as int) -> 1.
|
||||
* Merge cast like
|
||||
* - cast(cast(1 as bigint) as string) -> cast(1 as string).
|
||||
*/
|
||||
public class SimplifyCastRule extends AbstractExpressionRewriteRule {
|
||||
|
||||
public static SimplifyCastRule INSTANCE = new SimplifyCastRule();
|
||||
|
||||
@Override
|
||||
public Expression visitCast(Cast origin, ExpressionRewriteContext context) {
|
||||
return simplify(origin);
|
||||
}
|
||||
|
||||
private Expression simplify(Cast cast) {
|
||||
Expression source = cast.left();
|
||||
// simplify inside
|
||||
if (source instanceof Cast) {
|
||||
source = simplify((Cast) source);
|
||||
}
|
||||
|
||||
// remove redundant cast
|
||||
// CAST(value as type), value is type
|
||||
if (cast.getDataType().equals(source.getDataType())) {
|
||||
return source;
|
||||
}
|
||||
|
||||
if (source != cast.left()) {
|
||||
return new Cast(source, cast.right());
|
||||
}
|
||||
return cast;
|
||||
}
|
||||
}
|
||||
@ -34,9 +34,18 @@ public class Cast extends Expression implements BinaryExpression {
|
||||
super(child, new StringLiteral(type));
|
||||
}
|
||||
|
||||
public Cast(Expression child, StringLiteral type) {
|
||||
super(child, type);
|
||||
}
|
||||
|
||||
@Override
|
||||
public StringLiteral right() {
|
||||
return (StringLiteral) BinaryExpression.super.right();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType() {
|
||||
StringLiteral type = (StringLiteral) right();
|
||||
StringLiteral type = right();
|
||||
return DataType.convertFromString(type.getValue());
|
||||
}
|
||||
|
||||
@ -59,7 +68,7 @@ public class Cast extends Expression implements BinaryExpression {
|
||||
|
||||
@Override
|
||||
public String toSql() throws UnboundException {
|
||||
return "CAST(" + left().toSql() + " AS " + ((StringLiteral) right()).getValue() + ")";
|
||||
return "CAST(" + left().toSql() + " AS " + right().getValue() + ")";
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.expression.rewrite.rules.BetweenToCompound
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.DistinctPredicatesRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.ExtractCommonFactorRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPredicatesRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyCastRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
@ -144,11 +145,27 @@ public class ExpressionRewriteTest {
|
||||
public void testBetweenToCompoundRule() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(BetweenToCompoundRule.INSTANCE, SimplifyNotExprRule.INSTANCE));
|
||||
|
||||
assertRewrite(" a between c and d", "(a >= c) and (a <= d)");
|
||||
assertRewrite(" a not between c and d)", "(a < c) or (a > d)");
|
||||
assertRewrite("a between c and d", "(a >= c) and (a <= d)");
|
||||
assertRewrite("a not between c and d)", "(a < c) or (a > d)");
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSimplifyCastRule() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE));
|
||||
|
||||
// deduplicate
|
||||
assertRewrite("CAST(1 AS int)", "1");
|
||||
assertRewrite("CAST(\"str\" AS string)", "\"str\"");
|
||||
assertRewrite("CAST(CAST(1 AS int) AS int)", "1");
|
||||
|
||||
// deduplicate inside
|
||||
assertRewrite("CAST(CAST(\"str\" AS string) AS double)", "CAST(\"str\" AS double)");
|
||||
assertRewrite("CAST(CAST(1 AS int) AS double)", "CAST(1 AS double)");
|
||||
|
||||
}
|
||||
|
||||
|
||||
private void assertRewrite(String expression, String expected) {
|
||||
Expression needRewriteExpression = PARSER.parseExpression(expression);
|
||||
Expression expectedExpression = PARSER.parseExpression(expected);
|
||||
|
||||
Reference in New Issue
Block a user