[enhancement](Nereids) refactor expression rewriter to pattern match (#32617)
this pr can improve the performance of the nereids planner, in plan stage. 1. refactor expression rewriter to pattern match, so the lots of expression rewrite rules can criss-crossed apply in a big bottom-up iteration, and rewrite until the expression became stable. now we can process more cases because original there has no loop, and sometimes only process the top expression, like `SimplifyArithmeticRule`. 2. replace `Collection.stream()` to `ImmutableXxx.Builder` to avoid useless method call 3. loop unrolling some codes, like `Expression.<init>`, `PlanTreeRewriteBottomUpJob.pushChildrenJobs` 4. use type/arity specified-code, like `OneRangePartitionEvaluator.toNereidsLiterals()`, `PartitionRangeExpander.tryExpandRange()`, `PartitionRangeExpander.enumerableCount()` 5. refactor `ExtractCommonFactorRule`, now we can extract more cases, and I fix the deed loop when use `ExtractCommonFactorRule` and `SimplifyRange` in one iterative, because `SimplifyRange` generate right deep tree, but `ExtractCommonFactorRule` generate left deep tree 6. refactor `FoldConstantRuleOnFE`, support visitor/pattern match mode, in ExpressionNormalization, pattern match can criss-crossed apply with other rules; in PartitionPruner, visitor can evaluate expression faster 7. lazy compute and cache some operation 8. use int field to compare date 9. use BitSet to find disableNereidsRules 10. two level loop usually faster then build Multimap when bind slot in Scope, so I revert the code 11. `PlanTreeRewriteBottomUpJob` don't need to clearStatePhase any more ### test case 100 threads parallel continuous send this sql which query an empty table, test in my mac machine(m2 chip, 8 core), enable sql cache ```sql select count(1),date_format(time_col,'%Y%m%d'),varchar_col1 from tbl where partition_date>'2024-02-15' and (varchar_col2 ='73130' or varchar_col3='73130') and time_col>'2024-03-04' and time_col<'2024-03-05' group by date_format(time_col,'%Y%m%d'),varchar_col1 order by date_format(time_col,'%Y%m%d') desc, varchar_col1 desc,count(1) asc limit 1000 ``` before this pr: 3100 peak QPS, about 2700 avg QPS after this pr: 4800 peak QPS, about 4400 avg QPS (cherry picked from commit 7338683fdbdf77711f2ce61e580c19f4ea100723)
This commit is contained in:
@ -56,7 +56,7 @@ public class HyperGraphTest {
|
||||
+ "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN4 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "}\n";
|
||||
|
||||
Assertions.assertEquals(dottyGraph, target);
|
||||
Assertions.assertEquals(target, dottyGraph);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -85,12 +85,12 @@ public class HyperGraphTest {
|
||||
+ " LOGICAL_OLAP_SCAN3 [label=\"LOGICAL_OLAP_SCAN3 \n"
|
||||
+ " rowCount=40.00\"];\n"
|
||||
+ "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN1 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "LOGICAL_OLAP_SCAN1 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "LOGICAL_OLAP_SCAN2 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "LOGICAL_OLAP_SCAN1 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "LOGICAL_OLAP_SCAN2 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "}\n";
|
||||
Assertions.assertEquals(dottyGraph, target);
|
||||
Assertions.assertEquals(target, dottyGraph);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -101,8 +101,8 @@ public class HyperGraphTest {
|
||||
for (int i = 0; i < 10; i++) {
|
||||
HyperGraphBuilder hyperGraphBuilder = new HyperGraphBuilder();
|
||||
HyperGraph hyperGraph = hyperGraphBuilder.randomBuildWith(tableNum, edgeNum);
|
||||
Assertions.assertEquals(hyperGraph.getNodes().size(), tableNum);
|
||||
Assertions.assertEquals(hyperGraph.getJoinEdges().size(), edgeNum);
|
||||
Assertions.assertEquals(tableNum, hyperGraph.getNodes().size());
|
||||
Assertions.assertEquals(edgeNum, hyperGraph.getJoinEdges().size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,9 +22,11 @@ import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
|
||||
import org.apache.doris.nereids.rules.expression.rules.InPredicateDedup;
|
||||
import org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule;
|
||||
import org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule;
|
||||
import org.apache.doris.nereids.rules.expression.rules.OrToIn;
|
||||
import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule;
|
||||
import org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison;
|
||||
import org.apache.doris.nereids.rules.expression.rules.SimplifyNotExprRule;
|
||||
import org.apache.doris.nereids.rules.expression.rules.SimplifyRange;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
@ -54,7 +56,9 @@ class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testNotRewrite() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyNotExprRule.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
ExpressionRewrite.bottomUp(SimplifyNotExprRule.INSTANCE)
|
||||
));
|
||||
|
||||
assertRewrite("not x", "not x");
|
||||
assertRewrite("not not x", "x");
|
||||
@ -79,7 +83,9 @@ class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testNormalizeExpressionRewrite() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(NormalizeBinaryPredicatesRule.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
ExpressionRewrite.bottomUp(NormalizeBinaryPredicatesRule.INSTANCE)
|
||||
));
|
||||
|
||||
assertRewrite("1 = 1", "1 = 1");
|
||||
assertRewrite("2 > x", "x < 2");
|
||||
@ -91,7 +97,9 @@ class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testDistinctPredicatesRewrite() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(DistinctPredicatesRule.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(DistinctPredicatesRule.INSTANCE)
|
||||
));
|
||||
|
||||
assertRewrite("a = 1", "a = 1");
|
||||
assertRewrite("a = 1 and a = 1", "a = 1");
|
||||
@ -103,7 +111,9 @@ class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testExtractCommonFactorRewrite() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(ExtractCommonFactorRule.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(ExtractCommonFactorRule.INSTANCE)
|
||||
));
|
||||
|
||||
assertRewrite("a", "a");
|
||||
|
||||
@ -112,22 +122,24 @@ class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
assertRewrite("a = 1 and b > 2", "a = 1 and b > 2");
|
||||
|
||||
assertRewrite("(a and b) or (c and d)", "(a and b) or (c and d)");
|
||||
assertRewrite("(a and b) and (c and d)", "((a and b) and c) and d");
|
||||
assertRewrite("(a and b) and (c and d)", "((a and b) and (c and d))");
|
||||
assertRewrite("(a and (b and c)) and (b or c)", "((b and c) and a)");
|
||||
|
||||
assertRewrite("(a or b) and (a or c)", "a or (b and c)");
|
||||
assertRewrite("(a and b) or (a and c)", "a and (b or c)");
|
||||
|
||||
assertRewrite("(a or b) and (a or c) and (a or d)", "a or (b and c and d)");
|
||||
assertRewrite("(a and b) or (a and c) or (a and d)", "a and (b or c or d)");
|
||||
assertRewrite("(a and b) or (a or c) or (a and d)", "((((a and b) or a) or c) or (a and d))");
|
||||
assertRewrite("(a and b) or (a and c) or (a or d)", "(((a and b) or (a and c) or a) or d))");
|
||||
assertRewrite("(a or b) or (a and c) or (a and d)", "(a or b) or (a and c) or (a and d)");
|
||||
assertRewrite("(a or b) or (a and c) or (a or d)", "(((a or b) or (a and c)) or d)");
|
||||
assertRewrite("(a or b) or (a or c) or (a and d)", "((a or b) or c) or (a and d)");
|
||||
assertRewrite("(a or b) and (a or d)", "a or (b and d)");
|
||||
assertRewrite("(a and b) or (a or c) or (a and d)", "a or c");
|
||||
assertRewrite("(a and b) or (a and c) or (a or d)", "(a or d)");
|
||||
assertRewrite("(a or b) or (a and c) or (a and d)", "(a or b)");
|
||||
assertRewrite("(a or b) or (a and c) or (a or d)", "((a or b) or d)");
|
||||
assertRewrite("(a or b) or (a or c) or (a and d)", "((a or b) or c)");
|
||||
assertRewrite("(a or b) or (a or c) or (a or d)", "(((a or b) or c) or d)");
|
||||
|
||||
assertRewrite("(a and b) or (d and c) or (d and e)", "(a and b) or (d and c) or (d and e)");
|
||||
assertRewrite("(a or b) and (d or c) and (d or e)", "(a or b) and (d or c) and (d or e)");
|
||||
assertRewrite("(a and b) or (d and c) or (d and e)", "((d and (c or e)) or (a and b))");
|
||||
assertRewrite("(a or b) and (d or c) and (d or e)", "((d or (c and e)) and (a or b))");
|
||||
|
||||
assertRewrite("(a and b) or ((d and c) and (d and e))", "(a and b) or (d and c and e)");
|
||||
assertRewrite("(a or b) and ((d or c) or (d or e))", "(a or b) and (d or c or e)");
|
||||
@ -152,11 +164,29 @@ class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
assertRewrite("(a or b) and (a or true)", "a or b");
|
||||
|
||||
assertRewrite("a and (b or ((a and e) or (a and f))) and (b or d)", "(b or ((a and (e or f)) and d)) and a");
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void testTpcdsCase() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(
|
||||
SimplifyRange.INSTANCE,
|
||||
OrToIn.INSTANCE,
|
||||
ExtractCommonFactorRule.INSTANCE
|
||||
)
|
||||
));
|
||||
assertRewrite(
|
||||
"(((((customer_address.ca_country = 'United States') AND ca_state IN ('DE', 'FL', 'TX')) OR ((customer_address.ca_country = 'United States') AND ca_state IN ('ID', 'IN', 'ND'))) OR ((customer_address.ca_country = 'United States') AND ca_state IN ('IL', 'MT', 'OH'))))",
|
||||
"((customer_address.ca_country = 'United States') AND ca_state IN ('DE', 'FL', 'TX', 'ID', 'IN', 'ND', 'IL', 'MT', 'OH'))");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testInPredicateToEqualToRule() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(InPredicateToEqualToRule.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(InPredicateToEqualToRule.INSTANCE)
|
||||
));
|
||||
|
||||
assertRewrite("a in (1)", "a = 1");
|
||||
assertRewrite("a not in (1)", "not a = 1");
|
||||
@ -172,14 +202,18 @@ class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testInPredicateDedup() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(InPredicateDedup.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(InPredicateDedup.INSTANCE)
|
||||
));
|
||||
|
||||
assertRewrite("a in (1, 2, 1, 2)", "a in (1, 2)");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSimplifyCastRule() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(SimplifyCastRule.INSTANCE)
|
||||
));
|
||||
|
||||
// deduplicate
|
||||
assertRewrite("CAST(1 AS tinyint)", "1");
|
||||
@ -211,7 +245,9 @@ class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testSimplifyDecimalV3Comparison() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(SimplifyDecimalV3Comparison.INSTANCE)
|
||||
));
|
||||
|
||||
// do rewrite
|
||||
Expression left = new DecimalV3Literal(new BigDecimal("12345.67"));
|
||||
@ -226,4 +262,16 @@ class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
comparison = new EqualTo(new DecimalV3Literal(new BigDecimal("12345.67")), new DecimalV3Literal(new BigDecimal("76543.21")));
|
||||
assertRewrite(comparison, comparison);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testDeadLoop() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(
|
||||
SimplifyRange.INSTANCE,
|
||||
ExtractCommonFactorRule.INSTANCE
|
||||
)
|
||||
));
|
||||
|
||||
assertRewrite("a and (b > 0 and b < 10)", "a and (b > 0 and b < 10)");
|
||||
}
|
||||
}
|
||||
|
||||
@ -46,7 +46,7 @@ import org.junit.jupiter.api.Assertions;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public abstract class ExpressionRewriteTestHelper {
|
||||
public abstract class ExpressionRewriteTestHelper extends ExpressionRewrite {
|
||||
protected static final NereidsParser PARSER = new NereidsParser();
|
||||
protected ExpressionRuleExecutor executor;
|
||||
|
||||
|
||||
@ -58,7 +58,9 @@ class FoldConstantTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testCaseWhenFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE)
|
||||
));
|
||||
// assertRewriteAfterTypeCoercion("case when 1 = 2 then 1 when '1' < 2 then 2 else 3 end", "2");
|
||||
// assertRewriteAfterTypeCoercion("case when 1 = 2 then 1 when '1' > 2 then 2 end", "null");
|
||||
assertRewriteAfterTypeCoercion("case when (1 + 5) / 2 > 2 then 4 when '1' < 2 then 2 else 3 end", "4");
|
||||
@ -75,7 +77,9 @@ class FoldConstantTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testInFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE)
|
||||
));
|
||||
assertRewriteAfterTypeCoercion("1 in (1,2,3,4)", "true");
|
||||
// Type Coercion trans all to string.
|
||||
assertRewriteAfterTypeCoercion("3 in ('1', 2 + 8 / 2, 3, 4)", "true");
|
||||
@ -88,7 +92,9 @@ class FoldConstantTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testLogicalFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE)
|
||||
));
|
||||
assertRewriteAfterTypeCoercion("10 + 1 > 1 and 1 > 2", "false");
|
||||
assertRewriteAfterTypeCoercion("10 + 1 > 1 and 1 < 2", "true");
|
||||
assertRewriteAfterTypeCoercion("null + 1 > 1 and 1 < 2", "null");
|
||||
@ -126,7 +132,9 @@ class FoldConstantTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testIsNullFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE)
|
||||
));
|
||||
assertRewriteAfterTypeCoercion("100 is null", "false");
|
||||
assertRewriteAfterTypeCoercion("null is null", "true");
|
||||
assertRewriteAfterTypeCoercion("null is not null", "false");
|
||||
@ -137,7 +145,9 @@ class FoldConstantTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testNotPredicateFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE)
|
||||
));
|
||||
assertRewriteAfterTypeCoercion("not 1 > 2", "true");
|
||||
assertRewriteAfterTypeCoercion("not null + 1 > 2", "null");
|
||||
assertRewriteAfterTypeCoercion("not (1 + 5) / 2 + (10 - 1) * 3 > 3 * 5 + 1", "false");
|
||||
@ -145,7 +155,9 @@ class FoldConstantTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testCastFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE)
|
||||
));
|
||||
|
||||
// cast '1' as tinyint
|
||||
Cast c = new Cast(Literal.of("1"), TinyIntType.INSTANCE);
|
||||
@ -156,7 +168,9 @@ class FoldConstantTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testCompareFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE)
|
||||
));
|
||||
assertRewriteAfterTypeCoercion("'1' = 2", "false");
|
||||
assertRewriteAfterTypeCoercion("1 = 2", "false");
|
||||
assertRewriteAfterTypeCoercion("1 != 2", "true");
|
||||
@ -173,7 +187,9 @@ class FoldConstantTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testArithmeticFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE)
|
||||
));
|
||||
assertRewrite("1 + 1", Literal.of((short) 2));
|
||||
assertRewrite("1 - 1", Literal.of((short) 0));
|
||||
assertRewrite("100 + 100", Literal.of((short) 200));
|
||||
@ -206,7 +222,9 @@ class FoldConstantTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testTimestampFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE)
|
||||
));
|
||||
String interval = "'1991-05-01' - interval 1 day";
|
||||
Expression e7 = process((TimestampArithmetic) PARSER.parseExpression(interval));
|
||||
Expression e8 = Config.enable_date_conversion
|
||||
|
||||
@ -48,7 +48,7 @@ public class PredicatesSplitterTest extends ExpressionRewriteTestHelper {
|
||||
"c = d or a = 10");
|
||||
assetEquals("a = b and c + d = e and a > 7 and 10 > d",
|
||||
"a = b",
|
||||
"10 > d and a > 7",
|
||||
"a > 7 and 10 > d",
|
||||
"c + d = e");
|
||||
assetEquals("a = b and c + d = e or a > 7 and 10 > d",
|
||||
"",
|
||||
|
||||
@ -29,9 +29,11 @@ class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper {
|
||||
@Test
|
||||
void testSimplifyArithmetic() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
SimplifyArithmeticRule.INSTANCE,
|
||||
bottomUp(SimplifyArithmeticRule.INSTANCE),
|
||||
FunctionBinder.INSTANCE,
|
||||
FoldConstantRule.INSTANCE
|
||||
bottomUp(
|
||||
FoldConstantRule.INSTANCE
|
||||
)
|
||||
));
|
||||
assertRewriteAfterTypeCoercion("IA", "IA");
|
||||
assertRewriteAfterTypeCoercion("IA + 1", "IA + 1");
|
||||
@ -55,7 +57,7 @@ class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper {
|
||||
@Test
|
||||
void testSimplifyArithmeticRuleOnly() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
SimplifyArithmeticRule.INSTANCE
|
||||
bottomUp(SimplifyArithmeticRule.INSTANCE)
|
||||
));
|
||||
|
||||
// add and subtract
|
||||
@ -67,39 +69,43 @@ class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper {
|
||||
assertRewriteAfterTypeCoercion("IA - 2 - ((-IB - 1) - (3 + (IC + 4)))", "(((IA + IB) + IC) - ((((2 + 0) - 1) - 3) - 4))");
|
||||
|
||||
// multiply and divide
|
||||
assertRewriteAfterTypeCoercion("2 / IA / ((1 / IB) / (3 * IC))", "((((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(IA as DOUBLE)) * cast(IB as DOUBLE)) * cast((3 * IC) as DOUBLE))");
|
||||
assertRewriteAfterTypeCoercion("2 / IA / ((1 / IB) / (3 * IC))", "((((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(IA as DOUBLE)) * cast(IB as DOUBLE)) * cast((IC * 3) as DOUBLE))");
|
||||
assertRewriteAfterTypeCoercion("IA / 2 / ((IB * 1) / (3 / (IC / 4)))", "(((cast(IA as DOUBLE) / cast((IB * 1) as DOUBLE)) / cast(IC as DOUBLE)) / ((cast(2 as DOUBLE) / cast(3 as DOUBLE)) / cast(4 as DOUBLE)))");
|
||||
assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 / (IC * 4)))", "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) / cast((IC * 4) as DOUBLE)) / ((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(3 as DOUBLE)))");
|
||||
assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 * (IC * 4)))", "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) * cast((3 * (IC * 4)) as DOUBLE)) / (cast(2 as DOUBLE) / cast(1 as DOUBLE)))");
|
||||
assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 * (IC * 4)))", "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) * cast((IC * (3 * 4)) as DOUBLE)) / (cast(2 as DOUBLE) / cast(1 as DOUBLE)))");
|
||||
|
||||
// hybrid
|
||||
// root is subtract
|
||||
assertRewriteAfterTypeCoercion("-2 - IA * ((1 - IB) - (3 / IC))", "(cast(-2 as DOUBLE) - (cast(IA as DOUBLE) * (cast((1 - IB) as DOUBLE) - (cast(3 as DOUBLE) / cast(IC as DOUBLE)))))");
|
||||
assertRewriteAfterTypeCoercion("-IA - 2 - ((IB * 1) - (3 * (IC / 4)))", "((cast(((0 - IA) - 2) as DOUBLE) - cast((IB * 1) as DOUBLE)) + (cast(3 as DOUBLE) * (cast(IC as DOUBLE) / cast(4 as DOUBLE))))");
|
||||
assertRewriteAfterTypeCoercion("-IA - 2 - ((IB * 1) - (3 * (IC / 4)))", "((cast(((0 - 2) - IA) as DOUBLE) - cast((IB * 1) as DOUBLE)) + (cast(3 as DOUBLE) * (cast(IC as DOUBLE) / cast(4 as DOUBLE))))");
|
||||
// root is add
|
||||
assertRewriteAfterTypeCoercion("-IA * 2 + ((IB - 1) / (3 - (IC + 4)))", "(cast(((0 - IA) * 2) as DOUBLE) + (cast((IB - 1) as DOUBLE) / cast((3 - (IC + 4)) as DOUBLE)))");
|
||||
assertRewriteAfterTypeCoercion("-IA * 2 + ((IB - 1) / (3 - (IC + 4)))", "(cast(((0 - IA) * 2) as DOUBLE) + (cast((IB - 1) as DOUBLE) / cast(((3 - 4) - IC) as DOUBLE)))");
|
||||
assertRewriteAfterTypeCoercion("-IA + 2 + ((IB - 1) - (3 * (IC + 4)))", "(((((0 + 2) - 1) - IA) + IB) - (3 * (IC + 4)))");
|
||||
// root is multiply
|
||||
assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) * cast((((0 - IB) - 1) - (3 + (IC + 4))) as DOUBLE)) / cast(2 as DOUBLE))");
|
||||
assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) * (3 / (IC + 4)))", "(((cast((0 - IA) as DOUBLE) * cast(((0 - IB) - 1) as DOUBLE)) / cast((IC + 4) as DOUBLE)) / (cast(2 as DOUBLE) / cast(3 as DOUBLE)))");
|
||||
assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) * cast((((((0 - 1) - 3) - 4) - IB) - IC) as DOUBLE)) / cast(2 as DOUBLE))");
|
||||
assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) * (3 / (IC + 4)))", "(((cast((0 - IA) as DOUBLE) * cast(((0 - 1) - IB) as DOUBLE)) / cast((IC + 4) as DOUBLE)) / (cast(2 as DOUBLE) / cast(3 as DOUBLE)))");
|
||||
// root is divide
|
||||
assertRewriteAfterTypeCoercion("(-IA / 2) / ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) / cast((((0 - IB) - 1) - (3 + (IC + 4))) as DOUBLE)) / cast(2 as DOUBLE))");
|
||||
assertRewriteAfterTypeCoercion("(-IA / 2) / ((-IB - 1) / (3 + (IC * 4)))", "(((cast((0 - IA) as DOUBLE) / cast(((0 - IB) - 1) as DOUBLE)) * cast((3 + (IC * 4)) as DOUBLE)) / cast(2 as DOUBLE))");
|
||||
assertRewriteAfterTypeCoercion("(-IA / 2) / ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) / cast((((((0 - 1) - 3) - 4) - IB) - IC) as DOUBLE)) / cast(2 as DOUBLE))");
|
||||
assertRewriteAfterTypeCoercion("(-IA / 2) / ((-IB - 1) / (3 + (IC * 4)))", "(((cast((0 - IA) as DOUBLE) / cast(((0 - 1) - IB) as DOUBLE)) * cast(((IC * 4) + 3) as DOUBLE)) / cast(2 as DOUBLE))");
|
||||
|
||||
// unsupported decimal
|
||||
assertRewriteAfterTypeCoercion("-2 - MA - ((1 - IB) - (3 + IC))", "((cast(-2 as DECIMALV3(38, 9)) - MA) - cast(((1 - IB) - (3 + IC)) as DECIMALV3(38, 9)))");
|
||||
assertRewriteAfterTypeCoercion("-IA / 2.0 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DECIMALV3(25, 5)) / 2.0) * cast((((0 - IB) - 1) - (3 + (IC + 4))) as DECIMALV3(20, 0)))");
|
||||
assertRewriteAfterTypeCoercion("-2 - MA - ((1 - IB) - (3 + IC))", "((cast(-2 as DECIMALV3(38, 9)) - MA) - cast((((1 - 3) - IB) - IC) as DECIMALV3(38, 9)))");
|
||||
assertRewriteAfterTypeCoercion("-IA / 2.0 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DECIMALV3(25, 5)) / 2.0) * cast((((((0 - 1) - 3) - 4) - IB) - IC) as DECIMALV3(20, 0)))");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSimplifyArithmeticComparison() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
SimplifyArithmeticRule.INSTANCE,
|
||||
FoldConstantRule.INSTANCE,
|
||||
SimplifyArithmeticComparisonRule.INSTANCE,
|
||||
SimplifyArithmeticRule.INSTANCE,
|
||||
bottomUp(
|
||||
SimplifyArithmeticRule.INSTANCE,
|
||||
FoldConstantRule.INSTANCE,
|
||||
SimplifyArithmeticComparisonRule.INSTANCE,
|
||||
SimplifyArithmeticRule.INSTANCE
|
||||
),
|
||||
FunctionBinder.INSTANCE,
|
||||
FoldConstantRule.INSTANCE
|
||||
bottomUp(
|
||||
FoldConstantRule.INSTANCE
|
||||
)
|
||||
));
|
||||
assertRewriteAfterTypeCoercion("IA", "IA");
|
||||
assertRewriteAfterTypeCoercion("IA > IB", "IA > IB");
|
||||
@ -134,12 +140,16 @@ class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper {
|
||||
@Test
|
||||
void testSimplifyDateTimeComparison() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
SimplifyArithmeticRule.INSTANCE,
|
||||
FoldConstantRule.INSTANCE,
|
||||
SimplifyArithmeticComparisonRule.INSTANCE,
|
||||
SimplifyArithmeticRule.INSTANCE,
|
||||
bottomUp(
|
||||
SimplifyArithmeticRule.INSTANCE,
|
||||
FoldConstantRule.INSTANCE,
|
||||
SimplifyArithmeticComparisonRule.INSTANCE,
|
||||
SimplifyArithmeticRule.INSTANCE
|
||||
),
|
||||
FunctionBinder.INSTANCE,
|
||||
FoldConstantRule.INSTANCE
|
||||
bottomUp(
|
||||
FoldConstantRule.INSTANCE
|
||||
)
|
||||
));
|
||||
assertRewriteAfterTypeCoercion("years_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-01-01 00:00:00')");
|
||||
assertRewriteAfterTypeCoercion("years_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2022-01-01 00:00:00')");
|
||||
|
||||
@ -34,8 +34,10 @@ public class SimplifyInPredicateTest extends ExpressionRewriteTestHelper {
|
||||
@Test
|
||||
public void test() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
FoldConstantRule.INSTANCE,
|
||||
SimplifyInPredicate.INSTANCE
|
||||
bottomUp(
|
||||
FoldConstantRule.INSTANCE,
|
||||
SimplifyInPredicate.INSTANCE
|
||||
)
|
||||
));
|
||||
Map<String, Slot> mem = Maps.newHashMap();
|
||||
Expression rewrittenExpression = PARSER.parseExpression("cast(CA as DATETIME) in ('1992-01-31 00:00:00', '1992-02-01 00:00:00')");
|
||||
@ -48,7 +50,9 @@ public class SimplifyInPredicateTest extends ExpressionRewriteTestHelper {
|
||||
Expression expectedExpression = PARSER.parseExpression("CA in (cast('1992-01-31' as date), cast('1992-02-01' as date))");
|
||||
expectedExpression = replaceUnboundSlot(expectedExpression, mem);
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(
|
||||
FoldConstantRule.INSTANCE
|
||||
)
|
||||
));
|
||||
expectedExpression = executor.rewrite(expectedExpression, context);
|
||||
Assertions.assertEquals(expectedExpression, rewrittenExpression);
|
||||
|
||||
@ -45,7 +45,7 @@ import org.junit.jupiter.api.Test;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class SimplifyRangeTest {
|
||||
public class SimplifyRangeTest extends ExpressionRewrite {
|
||||
|
||||
private static final NereidsParser PARSER = new NereidsParser();
|
||||
private ExpressionRuleExecutor executor;
|
||||
@ -59,7 +59,9 @@ public class SimplifyRangeTest {
|
||||
|
||||
@Test
|
||||
public void testSimplify() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyRange.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(SimplifyRange.INSTANCE)
|
||||
));
|
||||
assertRewrite("TA", "TA");
|
||||
assertRewrite("TA > 3 or TA > null", "TA > 3");
|
||||
assertRewrite("TA > 3 or TA < null", "TA > 3");
|
||||
@ -100,7 +102,7 @@ public class SimplifyRangeTest {
|
||||
assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))");
|
||||
assertRewrite("TA in (1,2,3) and TA > 10", "FALSE");
|
||||
assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)");
|
||||
assertRewrite("TA in (1,2,3) and TA > 1", "((TA = 2) OR (TA = 3))");
|
||||
assertRewrite("TA in (1,2,3) and TA > 1", "TA IN (2, 3)");
|
||||
assertRewrite("TA in (1,2,3) or TA >= 1", "TA >= 1");
|
||||
assertRewrite("TA in (1)", "TA in (1)");
|
||||
assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)");
|
||||
@ -147,7 +149,7 @@ public class SimplifyRangeTest {
|
||||
assertRewrite("((TA + TC > 10 or TA + TC > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA + TC > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))");
|
||||
assertRewrite("TA + TC in (1,2,3) and TA + TC > 10", "FALSE");
|
||||
assertRewrite("TA + TC in (1,2,3) and TA + TC >= 1", "TA + TC in (1,2,3)");
|
||||
assertRewrite("TA + TC in (1,2,3) and TA + TC > 1", "((TA + TC = 2) OR (TA + TC = 3))");
|
||||
assertRewrite("TA + TC in (1,2,3) and TA + TC > 1", "(TA + TC) IN (2, 3)");
|
||||
assertRewrite("TA + TC in (1,2,3) or TA + TC >= 1", "TA + TC >= 1");
|
||||
assertRewrite("TA + TC in (1)", "TA + TC in (1)");
|
||||
assertRewrite("TA + TC in (1,2,3) and TA + TC < 10", "TA + TC in (1,2,3)");
|
||||
@ -171,8 +173,10 @@ public class SimplifyRangeTest {
|
||||
|
||||
@Test
|
||||
public void testSimplifyDate() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyRange.INSTANCE));
|
||||
// assertRewrite("TA", "TA");
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(SimplifyRange.INSTANCE)
|
||||
));
|
||||
assertRewrite("TA", "TA");
|
||||
assertRewrite(
|
||||
"(TA >= date '2024-01-01' and TA <= date '2024-01-03') or (TA > date '2024-01-05' and TA < date '2024-01-07')",
|
||||
"(TA >= date '2024-01-01' and TA <= date '2024-01-03') or (TA > date '2024-01-05' and TA < date '2024-01-07')");
|
||||
@ -213,7 +217,7 @@ public class SimplifyRangeTest {
|
||||
assertRewrite("TA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') and TA >= date '2024-01-01'",
|
||||
"TA in (date '2024-01-01',date '2024-01-02',date '2024-01-03')");
|
||||
assertRewrite("TA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') and TA > date '2024-01-01'",
|
||||
"((TA = date '2024-01-02') OR (TA = date '2024-01-03'))");
|
||||
"TA IN (date '2024-01-02', date '2024-01-03')");
|
||||
assertRewrite("TA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') or TA >= date '2024-01-01'",
|
||||
"TA >= date '2024-01-01'");
|
||||
assertRewrite("TA in (date '2024-01-01')", "TA in (date '2024-01-01')");
|
||||
@ -237,8 +241,10 @@ public class SimplifyRangeTest {
|
||||
|
||||
@Test
|
||||
public void testSimplifyDateTime() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyRange.INSTANCE));
|
||||
// assertRewrite("TA", "TA");
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(SimplifyRange.INSTANCE)
|
||||
));
|
||||
assertRewrite("TA", "TA");
|
||||
assertRewrite(
|
||||
"(TA >= timestamp '2024-01-01 00:00:00' and TA <= timestamp '2024-01-03 00:00:00') or (TA > timestamp '2024-01-05 00:00:00' and TA < timestamp '2024-01-07 00:00:00')",
|
||||
"(TA >= timestamp '2024-01-01 00:00:00' and TA <= timestamp '2024-01-03 00:00:00') or (TA > timestamp '2024-01-05 00:00:00' and TA < timestamp '2024-01-07 00:00:00')");
|
||||
@ -279,7 +285,7 @@ public class SimplifyRangeTest {
|
||||
assertRewrite("TA in (timestamp '2024-01-01 01:00:00',timestamp '2024-01-02 01:50:00',timestamp '2024-01-03 02:00:00') and TA >= timestamp '2024-01-01'",
|
||||
"TA in (timestamp '2024-01-01 01:00:00',timestamp '2024-01-02 01:50:00',timestamp '2024-01-03 02:00:00')");
|
||||
assertRewrite("TA in (timestamp '2024-01-01 02:00:00',timestamp '2024-01-02 02:00:00',timestamp '2024-01-03 02:00:00') and TA > timestamp '2024-01-01 02:10:00'",
|
||||
"((TA = timestamp '2024-01-02 02:00:00') OR (TA = timestamp '2024-01-03 02:00:00'))");
|
||||
"TA IN (timestamp '2024-01-02 02:00:00', timestamp '2024-01-03 02:00:00')");
|
||||
assertRewrite("TA in (timestamp '2024-01-01 02:00:00',timestamp '2024-01-02 02:00:00',timestamp '2024-01-03 02:00:00') or TA >= timestamp '2024-01-01 01:00:00'",
|
||||
"TA >= timestamp '2024-01-01 01:00:00'");
|
||||
assertRewrite("TA in (timestamp '2024-01-01 02:00:00')", "TA in (timestamp '2024-01-01 02:00:00')");
|
||||
|
||||
@ -35,7 +35,9 @@ class NullSafeEqualToEqualTest extends ExpressionRewriteTestHelper {
|
||||
// "A<=> Null" to "A is null"
|
||||
@Test
|
||||
void testNullSafeEqualToIsNull() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(NullSafeEqualToEqual.INSTANCE)
|
||||
));
|
||||
SlotReference slot = new SlotReference("a", StringType.INSTANCE, true);
|
||||
assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), new IsNull(slot));
|
||||
}
|
||||
@ -43,7 +45,9 @@ class NullSafeEqualToEqualTest extends ExpressionRewriteTestHelper {
|
||||
// "A<=> Null" to "False", when A is not nullable
|
||||
@Test
|
||||
void testNullSafeEqualToFalse() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(NullSafeEqualToEqual.INSTANCE)
|
||||
));
|
||||
SlotReference slot = new SlotReference("a", StringType.INSTANCE, false);
|
||||
assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), BooleanLiteral.FALSE);
|
||||
}
|
||||
@ -51,7 +55,9 @@ class NullSafeEqualToEqualTest extends ExpressionRewriteTestHelper {
|
||||
// "A(nullable)<=>B" not changed
|
||||
@Test
|
||||
void testNullSafeEqualNotChangedLeft() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(NullSafeEqualToEqual.INSTANCE)
|
||||
));
|
||||
SlotReference a = new SlotReference("a", StringType.INSTANCE, true);
|
||||
SlotReference b = new SlotReference("b", StringType.INSTANCE, false);
|
||||
assertRewrite(new NullSafeEqual(a, b), new NullSafeEqual(a, b));
|
||||
@ -60,7 +66,9 @@ class NullSafeEqualToEqualTest extends ExpressionRewriteTestHelper {
|
||||
// "A<=>B(nullable)" not changed
|
||||
@Test
|
||||
void testNullSafeEqualNotChangedRight() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(NullSafeEqualToEqual.INSTANCE)
|
||||
));
|
||||
SlotReference a = new SlotReference("a", StringType.INSTANCE, false);
|
||||
SlotReference b = new SlotReference("b", StringType.INSTANCE, true);
|
||||
assertRewrite(new NullSafeEqual(a, b), new NullSafeEqual(a, b));
|
||||
@ -69,7 +77,9 @@ class NullSafeEqualToEqualTest extends ExpressionRewriteTestHelper {
|
||||
// "A<=>B" changed
|
||||
@Test
|
||||
void testNullSafeEqualToEqual() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(NullSafeEqualToEqual.INSTANCE)
|
||||
));
|
||||
SlotReference a = new SlotReference("a", StringType.INSTANCE, false);
|
||||
SlotReference b = new SlotReference("b", StringType.INSTANCE, false);
|
||||
assertRewrite(new NullSafeEqual(a, b), new EqualTo(a, b));
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
@ -37,7 +38,9 @@ class SimplifyArithmeticComparisonRuleTest extends ExpressionRewriteTestHelper {
|
||||
public void testProcess() {
|
||||
Map<String, Slot> nameToSlot = new HashMap<>();
|
||||
nameToSlot.put("a", new SlotReference("a", IntegerType.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyArithmeticComparisonRule.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
ExpressionRewrite.bottomUp(SimplifyArithmeticComparisonRule.INSTANCE)
|
||||
));
|
||||
assertRewriteAfterSimplify("a + 1 > 1", "a > cast((1 - 1) as INT)", nameToSlot);
|
||||
assertRewriteAfterSimplify("a - 1 > 1", "a > cast((1 + 1) as INT)", nameToSlot);
|
||||
assertRewriteAfterSimplify("a / -2 > 1", "cast((1 * -2) as INT) > a", nameToSlot);
|
||||
@ -82,7 +85,7 @@ class SimplifyArithmeticComparisonRuleTest extends ExpressionRewriteTestHelper {
|
||||
if (slotNameToSlot != null) {
|
||||
needRewriteExpression = replaceUnboundSlot(needRewriteExpression, slotNameToSlot);
|
||||
}
|
||||
Expression rewritten = SimplifyArithmeticComparisonRule.INSTANCE.rewrite(needRewriteExpression, context);
|
||||
Expression rewritten = executor.rewrite(needRewriteExpression, context);
|
||||
Expression expectedExpression = PARSER.parseExpression(expected);
|
||||
Assertions.assertEquals(expectedExpression.toSql(), rewritten.toSql());
|
||||
}
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
@ -50,7 +51,9 @@ class SimplifyCastRuleTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
public void testSimplify() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
ExpressionRewrite.bottomUp(SimplifyCastRule.INSTANCE))
|
||||
);
|
||||
assertRewriteAfterSimplify("CAST('1' AS STRING)", "'1'", StringType.INSTANCE);
|
||||
assertRewriteAfterSimplify("CAST('1' AS VARCHAR)", "'1'",
|
||||
VarcharType.createVarcharType(-1));
|
||||
@ -186,7 +189,7 @@ class SimplifyCastRuleTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
private void assertRewriteAfterSimplify(String expr, String expected, DataType expectedType) {
|
||||
Expression needRewriteExpression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = SimplifyCastRule.INSTANCE.rewrite(needRewriteExpression, context);
|
||||
Expression rewritten = executor.rewrite(needRewriteExpression, context);
|
||||
Expression expectedExpression = PARSER.parseExpression(expected);
|
||||
Assertions.assertEquals(expectedExpression.toSql(), rewritten.toSql());
|
||||
Assertions.assertEquals(expectedType, rewritten.getDataType());
|
||||
|
||||
@ -41,8 +41,12 @@ import org.junit.jupiter.api.Test;
|
||||
class SimplifyComparisonPredicateTest extends ExpressionRewriteTestHelper {
|
||||
@Test
|
||||
void testSimplifyComparisonPredicateRule() {
|
||||
executor = new ExpressionRuleExecutor(
|
||||
ImmutableList.of(SimplifyCastRule.INSTANCE, SimplifyComparisonPredicate.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(
|
||||
SimplifyCastRule.INSTANCE,
|
||||
SimplifyComparisonPredicate.INSTANCE
|
||||
)
|
||||
));
|
||||
|
||||
Expression dtv2 = new DateTimeV2Literal(1, 1, 1, 1, 1, 1, 0);
|
||||
Expression dt = new DateTimeLiteral(1, 1, 1, 1, 1, 1);
|
||||
@ -87,8 +91,12 @@ class SimplifyComparisonPredicateTest extends ExpressionRewriteTestHelper {
|
||||
|
||||
@Test
|
||||
void testDateTimeV2CmpDateTimeV2() {
|
||||
executor = new ExpressionRuleExecutor(
|
||||
ImmutableList.of(SimplifyCastRule.INSTANCE, SimplifyComparisonPredicate.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(
|
||||
SimplifyCastRule.INSTANCE,
|
||||
SimplifyComparisonPredicate.INSTANCE
|
||||
)
|
||||
));
|
||||
|
||||
Expression dt = new DateTimeLiteral(1, 1, 1, 1, 1, 1);
|
||||
|
||||
@ -100,18 +108,22 @@ class SimplifyComparisonPredicateTest extends ExpressionRewriteTestHelper {
|
||||
// (cast(0001-01-01 01:01:01 as DATETIMEV2(0)) > 2021-01-01 00:00:00.001)
|
||||
Expression expression = new GreaterThan(left, right);
|
||||
Expression rewrittenExpression = executor.rewrite(typeCoercion(expression), context);
|
||||
Assertions.assertEquals(left.getDataType(), rewrittenExpression.child(0).getDataType());
|
||||
Assertions.assertEquals(dt.getDataType(), rewrittenExpression.child(0).getDataType());
|
||||
|
||||
// (cast(0001-01-01 01:01:01 as DATETIMEV2(0)) < 2021-01-01 00:00:00.001)
|
||||
expression = new GreaterThan(left, right);
|
||||
rewrittenExpression = executor.rewrite(typeCoercion(expression), context);
|
||||
Assertions.assertEquals(left.getDataType(), rewrittenExpression.child(0).getDataType());
|
||||
Assertions.assertEquals(dt.getDataType(), rewrittenExpression.child(0).getDataType());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testRound() {
|
||||
executor = new ExpressionRuleExecutor(
|
||||
ImmutableList.of(SimplifyCastRule.INSTANCE, SimplifyComparisonPredicate.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(
|
||||
SimplifyCastRule.INSTANCE,
|
||||
SimplifyComparisonPredicate.INSTANCE
|
||||
)
|
||||
));
|
||||
|
||||
Expression left = new Cast(new DateTimeLiteral("2021-01-02 00:00:00.00"), DateTimeV2Type.of(1));
|
||||
Expression right = new DateTimeV2Literal("2021-01-01 23:59:59.99");
|
||||
@ -120,13 +132,14 @@ class SimplifyComparisonPredicateTest extends ExpressionRewriteTestHelper {
|
||||
Expression rewrittenExpression = executor.rewrite(typeCoercion(expression), context);
|
||||
|
||||
// right should round to be 2021-01-02 00:00:00.00
|
||||
Assertions.assertEquals(new DateTimeV2Literal("2021-01-02 00:00:00"), rewrittenExpression.child(1));
|
||||
Assertions.assertEquals(new DateTimeLiteral("2021-01-02 00:00:00"), rewrittenExpression.child(1));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testDoubleLiteral() {
|
||||
executor = new ExpressionRuleExecutor(
|
||||
ImmutableList.of(SimplifyComparisonPredicate.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(SimplifyComparisonPredicate.INSTANCE)
|
||||
));
|
||||
|
||||
Expression leftChild = new BigIntLiteral(999);
|
||||
Expression left = new Cast(leftChild, DoubleType.INSTANCE);
|
||||
|
||||
@ -39,7 +39,9 @@ class SimplifyDecimalV3ComparisonTest extends ExpressionRewriteTestHelper {
|
||||
Config.enable_decimal_conversion = false;
|
||||
Map<String, Slot> nameToSlot = new HashMap<>();
|
||||
nameToSlot.put("col1", new SlotReference("col1", DecimalV3Type.createDecimalV3Type(15, 2)));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(SimplifyDecimalV3Comparison.INSTANCE)
|
||||
));
|
||||
assertRewriteAfterSimplify("cast(col1 as decimalv3(27, 9)) > 0.6", "cast(col1 as decimalv3(27, 9)) > 0.6", nameToSlot);
|
||||
}
|
||||
|
||||
@ -48,7 +50,7 @@ class SimplifyDecimalV3ComparisonTest extends ExpressionRewriteTestHelper {
|
||||
if (slotNameToSlot != null) {
|
||||
needRewriteExpression = replaceUnboundSlot(needRewriteExpression, slotNameToSlot);
|
||||
}
|
||||
Expression rewritten = SimplifyDecimalV3Comparison.INSTANCE.rewrite(needRewriteExpression, context);
|
||||
Expression rewritten = executor.rewrite(needRewriteExpression, context);
|
||||
Expression expectedExpression = PARSER.parseExpression(expected);
|
||||
Assertions.assertEquals(expectedExpression.toSql(), rewritten.toSql());
|
||||
}
|
||||
|
||||
@ -32,7 +32,9 @@ import org.junit.jupiter.api.Test;
|
||||
class TopnToMaxTest extends ExpressionRewriteTestHelper {
|
||||
@Test
|
||||
void testSimplifyComparisonPredicateRule() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(TopnToMax.INSTANCE));
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
bottomUp(TopnToMax.INSTANCE)
|
||||
));
|
||||
|
||||
Slot slot = new SlotReference("a", StringType.INSTANCE);
|
||||
assertRewrite(new TopN(slot, Literal.of(1)), new Max(slot));
|
||||
|
||||
@ -127,6 +127,7 @@ class EliminateJoinByFkTest extends TestWithFeService implements MemoPatternMatc
|
||||
void testNullWithPredicate() throws Exception {
|
||||
String sql = "select pri.id1 from pri inner join foreign_null on pri.id1 = foreign_null.id3\n"
|
||||
+ "where pri.id1 = 1";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
.rewrite()
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
|
||||
import org.apache.doris.nereids.rules.expression.rules.OrToIn;
|
||||
import org.apache.doris.nereids.trees.expressions.And;
|
||||
@ -39,7 +38,7 @@ class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
void test1() {
|
||||
String expr = "col1 = 1 or col1 = 2 or col1 = 3 and (col2 = 4)";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context);
|
||||
Set<InPredicate> inPredicates = rewritten.collect(e -> e instanceof InPredicate);
|
||||
Assertions.assertEquals(1, inPredicates.size());
|
||||
InPredicate inPredicate = inPredicates.iterator().next();
|
||||
@ -62,7 +61,7 @@ class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
void test2() {
|
||||
String expr = "col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context);
|
||||
Assertions.assertEquals("((((col1 = 1) AND (col1 = 3)) AND (col2 = 3)) OR (col2 = 4))",
|
||||
rewritten.toSql());
|
||||
}
|
||||
@ -71,7 +70,7 @@ class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
void test3() {
|
||||
String expr = "(col1 = 1 or col1 = 2) and (col2 = 3 or col2 = 4)";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context);
|
||||
List<InPredicate> inPredicates = rewritten.collectToList(e -> e instanceof InPredicate);
|
||||
Assertions.assertEquals(2, inPredicates.size());
|
||||
InPredicate in1 = inPredicates.get(0);
|
||||
@ -95,7 +94,7 @@ class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
String expr = "case when col = 1 or col = 2 or col = 3 then 1"
|
||||
+ " when col = 4 or col = 5 or col = 6 then 1 else 0 end";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context);
|
||||
Assertions.assertEquals("CASE WHEN col IN (1, 2, 3) THEN 1 WHEN col IN (4, 5, 6) THEN 1 ELSE 0 END",
|
||||
rewritten.toSql());
|
||||
}
|
||||
@ -104,7 +103,7 @@ class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
void test5() {
|
||||
String expr = "col = 1 or (col = 2 and (col = 3 or col = 4 or col = 5))";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context);
|
||||
Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN (3, 4, 5)))",
|
||||
rewritten.toSql());
|
||||
}
|
||||
@ -113,7 +112,7 @@ class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
void test6() {
|
||||
String expr = "col = 1 or col = 2 or col in (1, 2, 3)";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context);
|
||||
Assertions.assertEquals("col IN (1, 2, 3)", rewritten.toSql());
|
||||
}
|
||||
|
||||
@ -121,7 +120,7 @@ class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
void test7() {
|
||||
String expr = "A = 1 or A = 2 or abs(A)=5 or A in (1, 2, 3) or B = 1 or B = 2 or B in (1, 2, 3) or B+1 in (4, 5, 7)";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context);
|
||||
Assertions.assertEquals("(((A IN (1, 2, 3) OR B IN (1, 2, 3)) OR (abs(A) = 5)) OR (B + 1) IN (4, 5, 7))", rewritten.toSql());
|
||||
}
|
||||
|
||||
@ -129,7 +128,7 @@ class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
void test8() {
|
||||
String expr = "col = 1 or (col = 2 and (col = 3 or col = '4' or col = 5.0))";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context);
|
||||
Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN ('4', 3, 5.0)))",
|
||||
rewritten.toSql());
|
||||
}
|
||||
@ -139,7 +138,7 @@ class OrToInTest extends ExpressionRewriteTestHelper {
|
||||
// ensure not rewrite to col2 in (1, 2) or cor 1 in (1, 2)
|
||||
String expr = "col1 IN (1, 2) OR col2 IN (1, 2)";
|
||||
Expression expression = PARSER.parseExpression(expr);
|
||||
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
|
||||
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context);
|
||||
Assertions.assertEquals("(col1 IN (1, 2) OR col2 IN (1, 2))",
|
||||
rewritten.toSql());
|
||||
}
|
||||
|
||||
@ -47,8 +47,8 @@ import org.junit.jupiter.api.Test;
|
||||
import java.util.Optional;
|
||||
|
||||
public class PushDownFilterThroughAggregationTest implements MemoPatternMatchSupported {
|
||||
private final LogicalOlapScan scan = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.student,
|
||||
ImmutableList.of(""));
|
||||
private final LogicalOlapScan scan = new LogicalOlapScan(
|
||||
StatementScopeIdGenerator.newRelationId(), PlanConstructor.student, ImmutableList.of(""));
|
||||
|
||||
/*-
|
||||
* origin plan:
|
||||
|
||||
@ -49,6 +49,7 @@ import org.junit.jupiter.api.Test;
|
||||
import java.math.BigDecimal;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
public class ComputeSignatureHelperTest {
|
||||
|
||||
@ -419,6 +420,16 @@ public class ComputeSignatureHelperTest {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> Optional<T> getMutableState(String key) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setMutableState(String key, Object value) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression withChildren(List<Expression> children) {
|
||||
return null;
|
||||
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.nereids.util;
|
||||
import org.apache.doris.analysis.ExplainOptions;
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.NereidsPlanner;
|
||||
import org.apache.doris.nereids.PlanProcess;
|
||||
import org.apache.doris.nereids.StatementContext;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
|
||||
@ -161,6 +162,25 @@ public class PlanChecker {
|
||||
return this;
|
||||
}
|
||||
|
||||
public PlanChecker printPlanProcess(String sql) {
|
||||
List<PlanProcess> planProcesses = explainPlanProcess(sql);
|
||||
for (PlanProcess row : planProcesses) {
|
||||
System.out.println("RULE: " + row.ruleName + "\nBEFORE:\n"
|
||||
+ row.beforeShape + "\nafter:\n" + row.afterShape);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public List<PlanProcess> explainPlanProcess(String sql) {
|
||||
NereidsParser parser = new NereidsParser();
|
||||
LogicalPlan command = parser.parseSingle(sql);
|
||||
NereidsPlanner planner = new NereidsPlanner(
|
||||
new StatementContext(connectContext, new OriginStatement(sql, 0)));
|
||||
planner.plan(command, PhysicalProperties.ANY, ExplainLevel.ALL_PLAN, true);
|
||||
this.cascadesContext = planner.getCascadesContext();
|
||||
return cascadesContext.getPlanProcesses();
|
||||
}
|
||||
|
||||
public PlanChecker applyTopDown(RuleFactory ruleFactory) {
|
||||
return applyTopDown(ruleFactory.buildRules());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user