[feature](Nereids): enable convert CASE WHEN to IF (#24050)

enable rule to convert CASE WHEN to IF.
This commit is contained in:
jakevin
2023-09-08 16:58:33 +08:00
committed by GitHub
parent c0a41dc0f8
commit 161520feb4
20 changed files with 123 additions and 39 deletions

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.expression;
import org.apache.doris.nereids.rules.expression.rules.ArrayContainToArrayOverlap;
import org.apache.doris.nereids.rules.expression.rules.CaseWhenToIf;
import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule;
import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
import org.apache.doris.nereids.rules.expression.rules.OrToIn;
@ -42,8 +43,8 @@ public class ExpressionOptimization extends ExpressionRewrite {
SimplifyDecimalV3Comparison.INSTANCE,
SimplifyRange.INSTANCE,
OrToIn.INSTANCE,
ArrayContainToArrayOverlap.INSTANCE
ArrayContainToArrayOverlap.INSTANCE,
CaseWhenToIf.INSTANCE
);
private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);

View File

@ -59,6 +59,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.FromDays;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Hour;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursDiff;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursSub;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Least;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Minute;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MinutesAdd;
@ -134,13 +135,25 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta
// TODO: case-when estimation needs to be re-implemented.
/**
 * Produces placeholder column statistics for a CASE WHEN expression.
 *
 * <p>The result is built from fixed values rather than from the statistics of the
 * branch expressions: NDV is the number of WHEN clauses plus one, the value range is
 * [0, Double.MAX_VALUE], the average size is 8 bytes and the null count is 0.
 *
 * @param caseWhen the CASE WHEN expression; only its WHEN-clause count is read
 * @param context  input statistics; not consulted by this implementation
 * @return the placeholder statistics described above
 */
@Override
public ColumnStatistic visitCaseWhen(CaseWhen caseWhen, Statistics context) {
// NOTE(review): this view appears to be a diff hunk — the statement-style builder
// calls directly below are the pre-change form, and the chained return further
// down is its replacement. Both set exactly the same values.
ColumnStatisticBuilder columnStat = new ColumnStatisticBuilder();
// One distinct value per WHEN clause plus one more (presumably for the ELSE/default
// result — confirm).
columnStat.setNdv(caseWhen.getWhenClauses().size() + 1);
columnStat.setMinValue(0);
columnStat.setMaxValue(Double.MAX_VALUE);
columnStat.setAvgSizeByte(8);
columnStat.setNumNulls(0);
return columnStat.build();
return new ColumnStatisticBuilder()
.setNdv(caseWhen.getWhenClauses().size() + 1)
.setMinValue(0)
.setMaxValue(Double.MAX_VALUE)
.setAvgSizeByte(8)
.setNumNulls(0)
.build();
}
/**
 * Produces placeholder column statistics for an IF function call.
 *
 * <p>Uses the same fixed values as visitCaseWhen, with the NDV pinned to 2.
 * Neither the function's arguments nor the incoming statistics are read.
 *
 * @param function the IF expression being estimated (not inspected)
 * @param context  input statistics (not consulted)
 * @return statistics with NDV 2, range [0, Double.MAX_VALUE], average size 8, no nulls
 */
@Override
public ColumnStatistic visitIf(If function, Statistics context) {
    // TODO: copied from visitCaseWhen; refine both estimates together.
    ColumnStatisticBuilder ifStat = new ColumnStatisticBuilder();
    ifStat.setNdv(2);
    ifStat.setMinValue(0);
    ifStat.setMaxValue(Double.MAX_VALUE);
    ifStat.setAvgSizeByte(8);
    ifStat.setNumNulls(0);
    return ifStat.build();
}
@Override

View File

@ -42,7 +42,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------PhysicalDistribute
--------PhysicalTopN
----------PhysicalProject
------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_secyear.customer_id)(CASE WHEN (year_total > 0.00) THEN (cast(year_total as DOUBLE) / cast(year_total as DOUBLE)) ELSE 0 END > CASE WHEN (year_total > 0.00) THEN (cast(year_total as DOUBLE) / cast(year_total as DOUBLE)) ELSE 0 END)
------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_secyear.customer_id)(if((year_total > 0.00), (cast(year_total as DOUBLE) / cast(year_total as DOUBLE)), 0) > if((year_total > 0.00), (cast(year_total as DOUBLE) / cast(year_total as DOUBLE)), 0))
--------------hashJoin[INNER_JOIN](t_s_secyear.customer_id = t_s_firstyear.customer_id)
----------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_firstyear.customer_id)
------------------PhysicalDistribute

View File

@ -3,24 +3,25 @@
PhysicalResultSink
--PhysicalTopN
----PhysicalDistribute
------filter((CASE WHEN (inv_before > 0) THEN (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)) ELSE NULL END >= cast((2.000000 / 3.0) as DOUBLE))(CASE WHEN (inv_before > 0) THEN (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)) ELSE NULL END <= 1.5))
--------hashAgg[GLOBAL]
----------PhysicalDistribute
------------hashAgg[LOCAL]
--------------PhysicalProject
----------------hashJoin[INNER_JOIN](inventory.inv_warehouse_sk = warehouse.w_warehouse_sk)
------------------hashJoin[INNER_JOIN](inventory.inv_date_sk = date_dim.d_date_sk)
--------------------hashJoin[INNER_JOIN](item.i_item_sk = inventory.inv_item_sk)
----------------------PhysicalOlapScan[inventory]
------PhysicalTopN
--------filter((if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) <= 1.5)(if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE)))
----------hashAgg[GLOBAL]
------------PhysicalDistribute
--------------hashAgg[LOCAL]
----------------PhysicalProject
------------------hashJoin[INNER_JOIN](inventory.inv_warehouse_sk = warehouse.w_warehouse_sk)
--------------------hashJoin[INNER_JOIN](inventory.inv_date_sk = date_dim.d_date_sk)
----------------------hashJoin[INNER_JOIN](item.i_item_sk = inventory.inv_item_sk)
------------------------PhysicalOlapScan[inventory]
------------------------PhysicalDistribute
--------------------------PhysicalProject
----------------------------filter((item.i_current_price <= 1.49)(item.i_current_price >= 0.99))
------------------------------PhysicalOlapScan[item]
----------------------PhysicalDistribute
------------------------PhysicalProject
--------------------------filter((item.i_current_price <= 1.49)(item.i_current_price >= 0.99))
----------------------------PhysicalOlapScan[item]
--------------------------filter((date_dim.d_date >= 2002-01-28)(date_dim.d_date <= 2002-03-29))
----------------------------PhysicalOlapScan[date_dim]
--------------------PhysicalDistribute
----------------------PhysicalProject
------------------------filter((date_dim.d_date >= 2002-01-28)(date_dim.d_date <= 2002-03-29))
--------------------------PhysicalOlapScan[date_dim]
------------------PhysicalDistribute
--------------------PhysicalProject
----------------------PhysicalOlapScan[warehouse]
------------------------PhysicalOlapScan[warehouse]

View File

@ -45,7 +45,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------PhysicalDistribute
----------PhysicalQuickSort
------------PhysicalProject
--------------hashJoin[INNER_JOIN](ws1.ca_county = ws3.ca_county)(CASE WHEN (web_sales > 0.00) THEN (cast(web_sales as DOUBLE) / cast(web_sales as DOUBLE)) ELSE NULL END > CASE WHEN (store_sales > 0.00) THEN (cast(store_sales as DOUBLE) / cast(store_sales as DOUBLE)) ELSE NULL END)
--------------hashJoin[INNER_JOIN](ws1.ca_county = ws3.ca_county)(if((web_sales > 0.00), (cast(web_sales as DOUBLE) / cast(web_sales as DOUBLE)), NULL) > if((store_sales > 0.00), (cast(store_sales as DOUBLE) / cast(store_sales as DOUBLE)), NULL))
----------------PhysicalDistribute
------------------PhysicalProject
--------------------filter((ws3.d_year = 2000)(ws3.d_qoy = 3))
@ -56,7 +56,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
----------------------PhysicalProject
------------------------filter((ss3.d_year = 2000)(ss3.d_qoy = 3))
--------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
--------------------hashJoin[INNER_JOIN](ws1.ca_county = ws2.ca_county)(CASE WHEN (web_sales > 0.00) THEN (cast(web_sales as DOUBLE) / cast(web_sales as DOUBLE)) ELSE NULL END > CASE WHEN (store_sales > 0.00) THEN (cast(store_sales as DOUBLE) / cast(store_sales as DOUBLE)) ELSE NULL END)
--------------------hashJoin[INNER_JOIN](ws1.ca_county = ws2.ca_county)(if((web_sales > 0.00), (cast(web_sales as DOUBLE) / cast(web_sales as DOUBLE)), NULL) > if((store_sales > 0.00), (cast(store_sales as DOUBLE) / cast(store_sales as DOUBLE)), NULL))
----------------------hashJoin[INNER_JOIN](ss1.ca_county = ws1.ca_county)
------------------------hashJoin[INNER_JOIN](ss1.ca_county = ss2.ca_county)
--------------------------PhysicalDistribute

View File

@ -25,7 +25,7 @@ PhysicalResultSink
------------------------------------PhysicalOlapScan[date_dim]
----------------------------PhysicalDistribute
------------------------------PhysicalProject
--------------------------------filter(((cast(hd_buy_potential as VARCHAR(*)) = '1001-5000') OR (cast(hd_buy_potential as VARCHAR(*)) = '0-500'))(household_demographics.hd_vehicle_count > 0)(CASE WHEN (hd_vehicle_count > 0) THEN (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) ELSE NULL END > 1.2))
--------------------------------filter(((cast(hd_buy_potential as VARCHAR(*)) = '1001-5000') OR (cast(hd_buy_potential as VARCHAR(*)) = '0-500'))(household_demographics.hd_vehicle_count > 0)(if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.2))
----------------------------------PhysicalOlapScan[household_demographics]
--------------------------PhysicalDistribute
----------------------------PhysicalProject

View File

@ -3,7 +3,7 @@
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalProject
------filter((CASE WHEN (mean = 0) THEN 0 ELSE (stdev / mean) END > 1))
------filter((if((mean = 0), 0, (stdev / mean)) > 1))
--------hashAgg[GLOBAL]
----------PhysicalDistribute
------------hashAgg[LOCAL]

View File

@ -62,11 +62,11 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------PhysicalDistribute
--------PhysicalTopN
----------PhysicalProject
------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_secyear.customer_id)(CASE WHEN (year_total > 0.000000) THEN (cast(year_total as DOUBLE) / cast(year_total as DOUBLE)) ELSE NULL END > CASE WHEN (year_total > 0.000000) THEN (cast(year_total as DOUBLE) / cast(year_total as DOUBLE)) ELSE NULL END)
------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_secyear.customer_id)(if((year_total > 0.000000), (cast(year_total as DOUBLE) / cast(year_total as DOUBLE)), NULL) > if((year_total > 0.000000), (cast(year_total as DOUBLE) / cast(year_total as DOUBLE)), NULL))
--------------PhysicalProject
----------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_firstyear.customer_id)
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_c_secyear.customer_id)(CASE WHEN (year_total > 0.000000) THEN (cast(year_total as DOUBLE) / cast(year_total as DOUBLE)) ELSE NULL END > CASE WHEN (year_total > 0.000000) THEN (cast(year_total as DOUBLE) / cast(year_total as DOUBLE)) ELSE NULL END)
--------------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_c_secyear.customer_id)(if((year_total > 0.000000), (cast(year_total as DOUBLE) / cast(year_total as DOUBLE)), NULL) > if((year_total > 0.000000), (cast(year_total as DOUBLE) / cast(year_total as DOUBLE)), NULL))
----------------------PhysicalProject
------------------------hashJoin[INNER_JOIN](t_s_secyear.customer_id = t_s_firstyear.customer_id)
--------------------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_c_firstyear.customer_id)

View File

@ -43,7 +43,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
--------------------PhysicalDistribute
----------------------PhysicalProject
------------------------filter((CASE WHEN (avg_monthly_sales > 0.0000) THEN (abs((cast(sum_sales as DOUBLE) - cast(avg_monthly_sales as DOUBLE))) / cast(avg_monthly_sales as DOUBLE)) ELSE NULL END > 0.1)(v2.d_year = 2001)(v2.avg_monthly_sales > 0.0000))
------------------------filter((if((avg_monthly_sales > 0.0000), (abs((cast(sum_sales as DOUBLE) - cast(avg_monthly_sales as DOUBLE))) / cast(avg_monthly_sales as DOUBLE)), NULL) > 0.1)(v2.d_year = 2001)(v2.avg_monthly_sales > 0.0000))
--------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
----------------PhysicalDistribute
------------------PhysicalProject

View File

@ -5,7 +5,7 @@ PhysicalResultSink
----PhysicalDistribute
------PhysicalTopN
--------PhysicalProject
----------filter((CASE WHEN (avg_quarterly_sales > 0.0000) THEN (abs((cast(sum_sales as DOUBLE) - cast(avg_quarterly_sales as DOUBLE))) / cast(avg_quarterly_sales as DOUBLE)) ELSE NULL END > 0.1))
----------filter((if((avg_quarterly_sales > 0.0000), (abs((cast(sum_sales as DOUBLE) - cast(avg_quarterly_sales as DOUBLE))) / cast(avg_quarterly_sales as DOUBLE)), NULL) > 0.1))
------------PhysicalWindow
--------------PhysicalQuickSort
----------------PhysicalDistribute

View File

@ -43,7 +43,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
----------------------PhysicalCteConsumer ( cteId=CTEId#0 )
------------------PhysicalDistribute
--------------------PhysicalProject
----------------------filter((CASE WHEN (avg_monthly_sales > 0.0000) THEN (abs((cast(sum_sales as DOUBLE) - cast(avg_monthly_sales as DOUBLE))) / cast(avg_monthly_sales as DOUBLE)) ELSE NULL END > 0.1)(v2.d_year = 1999)(v2.avg_monthly_sales > 0.0000))
----------------------filter((v2.d_year = 1999)(if((avg_monthly_sales > 0.0000), (abs((cast(sum_sales as DOUBLE) - cast(avg_monthly_sales as DOUBLE))) / cast(avg_monthly_sales as DOUBLE)), NULL) > 0.1)(v2.avg_monthly_sales > 0.0000))
------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
----------------PhysicalDistribute
------------------PhysicalProject

View File

@ -5,7 +5,7 @@ PhysicalResultSink
----PhysicalDistribute
------PhysicalTopN
--------PhysicalProject
----------filter((CASE WHEN (avg_monthly_sales > 0.0000) THEN (abs((cast(sum_sales as DOUBLE) - cast(avg_monthly_sales as DOUBLE))) / cast(avg_monthly_sales as DOUBLE)) ELSE NULL END > 0.1))
----------filter((if((avg_monthly_sales > 0.0000), (abs((cast(sum_sales as DOUBLE) - cast(avg_monthly_sales as DOUBLE))) / cast(avg_monthly_sales as DOUBLE)), NULL) > 0.1))
------------PhysicalWindow
--------------PhysicalQuickSort
----------------PhysicalDistribute

View File

@ -25,7 +25,7 @@ PhysicalResultSink
------------------------------------PhysicalOlapScan[date_dim]
----------------------------PhysicalDistribute
------------------------------PhysicalProject
--------------------------------filter(((cast(hd_buy_potential as VARCHAR(*)) = '501-1000') OR (cast(hd_buy_potential as VARCHAR(*)) = 'Unknown'))(household_demographics.hd_vehicle_count > 0)(CASE WHEN (hd_vehicle_count > 0) THEN (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) ELSE NULL END > 1))
--------------------------------filter(((cast(hd_buy_potential as VARCHAR(*)) = '501-1000') OR (cast(hd_buy_potential as VARCHAR(*)) = 'Unknown'))(household_demographics.hd_vehicle_count > 0)(if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1))
----------------------------------PhysicalOlapScan[household_demographics]
--------------------------PhysicalDistribute
----------------------------PhysicalProject

View File

@ -42,7 +42,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------PhysicalDistribute
--------PhysicalTopN
----------PhysicalProject
------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_firstyear.customer_id)(CASE WHEN (year_total > 0) THEN (year_total / year_total) ELSE NULL END > CASE WHEN (year_total > 0) THEN (year_total / year_total) ELSE NULL END)
------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_firstyear.customer_id)(if((year_total > 0), (year_total / year_total), NULL) > if((year_total > 0), (year_total / year_total), NULL))
--------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_secyear.customer_id)
----------------hashJoin[INNER_JOIN](t_s_secyear.customer_id = t_s_firstyear.customer_id)
------------------PhysicalDistribute

View File

@ -6,7 +6,7 @@ PhysicalResultSink
------PhysicalDistribute
--------PhysicalTopN
----------PhysicalProject
------------filter((CASE WHEN ( not (avg_monthly_sales = 0.0000)) THEN (abs((cast(sum_sales as DOUBLE) - cast(avg_monthly_sales as DOUBLE))) / cast(avg_monthly_sales as DOUBLE)) ELSE NULL END > 0.1))
------------filter((if(( not (avg_monthly_sales = 0.0000)), (abs((cast(sum_sales as DOUBLE) - cast(avg_monthly_sales as DOUBLE))) / cast(avg_monthly_sales as DOUBLE)), NULL) > 0.1))
--------------PhysicalWindow
----------------PhysicalQuickSort
------------------PhysicalDistribute

View File

@ -0,0 +1,69 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Regression test for the Nereids CaseWhenToIf rewrite: each single-branch CASE WHEN
// below is first executed (to prove the query still runs) and then explained, and the
// rewritten plan must show the expression as an if(...) function call.
//
// The assertions match "if(" rather than "if": the table name test_case_when_to_if
// itself contains "if", so a bare "if" check would pass even if the rewrite never fired.
suite("test_case_when_to_if") {
sql 'set enable_nereids_planner=true'
sql 'set enable_fallback_to_original_planner=false'
sql 'drop table if exists test_case_when_to_if;'
sql '''create table test_case_when_to_if (k1 int, k2 int) distributed by hash(k1) buckets 3 properties('replication_num' = '1');'''
// else is empty
sql '''
select k2,
sum(case when (k1=1) then 1 end) sum1
from test_case_when_to_if
group by k2;
'''
def res = sql '''
explain rewritten plan select k2,
sum(case when (k1=1) then 1 end) sum1
from test_case_when_to_if
group by k2;
'''
assertTrue(res.toString().contains("if("))
// else is null
sql '''
select k2,
sum(case when (k1=1) then 1 else null end) sum1
from test_case_when_to_if
group by k2;
'''
res = sql '''
explain rewritten plan select k2,
sum(case when (k1=1) then 1 else null end) sum1
from test_case_when_to_if
group by k2;
'''
assertTrue(res.toString().contains("if("))
// else is a non-null expression
sql '''
select k2,
sum(case when (k1>0) then k1 else abs(k1) end) sum1
from test_case_when_to_if
group by k2;
'''
res = sql '''
explain rewritten plan select k2,
sum(case when (k1>0) then k1 else abs(k1) end) sum1
from test_case_when_to_if
group by k2;
'''
assertTrue(res.toString().contains("if("))
}