[feature](Nereids): fold Cast(s as date/datetime) on FE (#24353)

cast("20210101" as Date) -> DateLiteral(2021, 1, 1)
This commit is contained in:
jakevin
2023-09-14 22:08:26 +08:00
committed by GitHub
parent f61e6483bf
commit d4756d3118
8 changed files with 71 additions and 39 deletions

View File

@ -73,6 +73,7 @@ import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.GlobalVariable;
@ -343,12 +344,19 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
return checkedExpr.get();
}
Expression child = cast.child();
DataType dataType = cast.getDataType();
// todo: process other null case
if (child.isNullLiteral()) {
return new NullLiteral(cast.getDataType());
return new NullLiteral(dataType);
} else if (child instanceof StringLikeLiteral && dataType instanceof DateLikeType) {
try {
return ((DateLikeType) dataType).fromString(((StringLikeLiteral) child).getStringValue());
} catch (AnalysisException t) {
return new NullLiteral(dataType);
}
}
try {
Expression castResult = child.checkedCastTo(cast.getDataType());
Expression castResult = child.checkedCastTo(dataType);
if (!Objects.equals(castResult, cast) && !Objects.equals(castResult, child)) {
castResult = rewrite(castResult, context);
}

View File

@ -17,6 +17,16 @@
package org.apache.doris.nereids.types.coercion;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import java.time.temporal.ChronoUnit;
import java.util.Calendar;
@ -43,4 +53,21 @@ public abstract class DateLikeType extends PrimitiveType {
Calendar from = toCalendar(low);
return ChronoUnit.DAYS.between(from.toInstant(), to.toInstant());
}
/**
* parse string to date like literal.
*/
public DateLiteral fromString(String s) {
if (this instanceof DateType) {
return new DateLiteral(s);
} else if (this instanceof DateV2Type) {
return new DateV2Literal(s);
} else if (this instanceof DateTimeType) {
return new DateTimeLiteral(s);
} else if (this instanceof DateTimeV2Type) {
return new DateTimeV2Literal((DateTimeV2Type) this, s);
} else {
throw new AnalysisException("unknown date like type");
}
}
}

View File

@ -204,13 +204,6 @@ public class ColumnStatistic {
return rowCount * 0.9 < ndv && ndv < rowCount * 1.1;
}
public ColumnStatistic copy() {
return new ColumnStatisticBuilder().setCount(count).setNdv(ndv).setAvgSizeByte(avgSizeByte)
.setNumNulls(numNulls).setDataSize(dataSize).setMinValue(minValue)
.setMaxValue(maxValue).setMinExpr(minExpr).setMaxExpr(maxExpr)
.setIsUnknown(isUnKnown).build();
}
public ColumnStatistic updateByLimit(long limit, double rowCount) {
double ratio = 0;
if (rowCount != 0) {

View File

@ -144,13 +144,6 @@ public class StatsDeriveResult {
return statsDeriveResult;
}
public StatsDeriveResult merge(StatsDeriveResult other) {
for (Entry<Id, ColumnStatistic> entry : other.getSlotIdToColumnStats().entrySet()) {
this.slotIdToColumnStats.put(entry.getKey(), entry.getValue().copy());
}
return this;
}
public StatsDeriveResult copy() {
return new StatsDeriveResult(this);
}

View File

@ -53,10 +53,10 @@ import org.junit.jupiter.api.Test;
import java.util.Locale;
public class FoldConstantTest extends ExpressionRewriteTestHelper {
class FoldConstantTest extends ExpressionRewriteTestHelper {
@Test
public void testCaseWhenFold() {
void testCaseWhenFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
// assertRewriteAfterTypeCoercion("case when 1 = 2 then 1 when '1' < 2 then 2 else 3 end", "2");
// assertRewriteAfterTypeCoercion("case when 1 = 2 then 1 when '1' > 2 then 2 end", "null");
@ -73,7 +73,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testInFold() {
void testInFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
assertRewriteAfterTypeCoercion("1 in (1,2,3,4)", "true");
// Type Coercion trans all to string.
@ -86,7 +86,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testLogicalFold() {
void testLogicalFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
assertRewriteAfterTypeCoercion("10 + 1 > 1 and 1 > 2", "false");
assertRewriteAfterTypeCoercion("10 + 1 > 1 and 1 < 2", "true");
@ -124,7 +124,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testIsNullFold() {
void testIsNullFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
assertRewriteAfterTypeCoercion("100 is null", "false");
assertRewriteAfterTypeCoercion("null is null", "true");
@ -135,7 +135,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testNotPredicateFold() {
void testNotPredicateFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
assertRewriteAfterTypeCoercion("not 1 > 2", "true");
assertRewriteAfterTypeCoercion("not null + 1 > 2", "null");
@ -143,7 +143,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testCastFold() {
void testCastFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
// cast '1' as tinyint
@ -154,7 +154,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testCompareFold() {
void testCompareFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
assertRewriteAfterTypeCoercion("'1' = 2", "false");
assertRewriteAfterTypeCoercion("1 = 2", "false");
@ -171,7 +171,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testArithmeticFold() {
void testArithmeticFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
assertRewrite("1 + 1", Literal.of((short) 2));
assertRewrite("1 - 1", Literal.of((short) 0));
@ -204,7 +204,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testTimestampFold() {
void testTimestampFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
String interval = "'1991-05-01' - interval 1 day";
Expression e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
@ -290,7 +290,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
assertRewrite(process, process);
}
public Expression process(TimestampArithmetic arithmetic) {
Expression process(TimestampArithmetic arithmetic) {
String funcOpName;
if (arithmetic.getFuncName() == null) {
funcOpName = String.format("%sS_%s", arithmetic.getTimeUnit(),
@ -302,7 +302,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testDateTypeDateTimeArithmeticFunctions() {
void testDateTypeDateTimeArithmeticFunctions() {
DateLiteral dateLiteral = new DateLiteral("1999-12-31");
IntegerLiteral integerLiteral = new IntegerLiteral(30);
VarcharLiteral format = new VarcharLiteral("%Y-%m-%d");
@ -339,7 +339,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testDateTimeTypeDateTimeArithmeticFunctions() {
void testDateTimeTypeDateTimeArithmeticFunctions() {
DateTimeLiteral dateLiteral = new DateTimeLiteral("1999-12-31 23:59:59");
IntegerLiteral integerLiteral = new IntegerLiteral(30);
VarcharLiteral format = new VarcharLiteral("%Y-%m-%d");
@ -394,7 +394,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testDateV2TypeDateTimeArithmeticFunctions() {
void testDateV2TypeDateTimeArithmeticFunctions() {
DateV2Literal dateLiteral = new DateV2Literal("1999-12-31");
IntegerLiteral integerLiteral = new IntegerLiteral(30);
VarcharLiteral format = new VarcharLiteral("%Y-%m-%d");
@ -431,7 +431,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testDateTimeV2TypeDateTimeArithmeticFunctions() {
void testDateTimeV2TypeDateTimeArithmeticFunctions() {
DateTimeV2Literal dateLiteral = new DateTimeV2Literal(DateTimeV2Type.SYSTEM_DEFAULT, "1999-12-31 23:59:59");
IntegerLiteral integerLiteral = new IntegerLiteral(30);
VarcharLiteral format = new VarcharLiteral("%Y-%m-%d");
@ -496,7 +496,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testDateDiff() {
void testDateDiff() {
DateTimeLiteral dateTimeLiteral = new DateTimeLiteral("2001-12-31 00:00:01");
DateV2Literal dateV2Literal = new DateV2Literal("2001-12-31");
DateTimeV2Literal dateTimeV2Literal = new DateTimeV2Literal("2001-12-31 00:00:01");
@ -518,7 +518,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testDateTrunc() {
void testDateTrunc() {
DateTimeLiteral dateTimeLiteral = new DateTimeLiteral("2001-12-31 01:01:01");
DateTimeV2Literal dateTimeV2Literal = new DateTimeV2Literal("2001-12-31 01:01:01");
@ -556,7 +556,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testDateConstructFunction() {
void testDateConstructFunction() {
String[] answer = {
"2001-07-19", "6411-08-17", "0000-01-01", "'1977-06-03 17:57:24'",
"'1977-06-03'", "1008909293", "1008864000"
@ -590,7 +590,7 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testFoldNestedExpression() {
void testFoldNestedExpression() {
assertRewriteExpression("makedate(year('2010-04-10'), dayofyear('2010-04-11'))", "2010-04-11");
assertRewriteExpression("null in ('d', null)", "NULL");
assertRewriteExpression("null not in ('d', null)", "NULL");
@ -604,7 +604,18 @@ public class FoldConstantTest extends ExpressionRewriteTestHelper {
}
@Test
public void testFoldTypeOfNullLiteral() {
void testFoldCastStringToDate() {
assertRewriteExpression("cast('2021-01-01' as date)", "2021-01-01");
assertRewriteExpression("cast('20210101' as date)", "2021-01-01");
assertRewriteExpression("cast('2021-01-01T00:00:00' as date)", "2021-01-01");
assertRewriteExpression("cast('2021-01-01' as datetime)", "2021-01-01 00:00:00");
assertRewriteExpression("cast('20210101' as datetime)", "2021-01-01 00:00:00");
assertRewriteExpression("cast('2021-01-01T00:00:00' as datetime)", "2021-01-01 00:00:00");
assertRewriteExpression("cast ('2022-12-02 22:23:24.999999' as datetimev2(3))", "2022-12-02 22:23:24.999");
}
@Test
void testFoldTypeOfNullLiteral() {
String actualExpression = "append_trailing_char_if_absent(cast(version() as varchar), cast(null as varchar))";
ExpressionRewriteContext context = new ExpressionRewriteContext(
MemoTestUtils.createCascadesContext(new UnboundRelation(new RelationId(1), ImmutableList.of("test_table"))));