From 64e9eab0dd05231b6d7f46fb441e281e10a69a64 Mon Sep 17 00:00:00 2001 From: minghong Date: Thu, 29 Jun 2023 16:37:05 +0800 Subject: [PATCH] [fix](nereids)update Agg stats estimation #21300 Agg stats estimation should use the biggest groupby key's NDV as base, and multiply expansion factor, which is calculated by other groupby key' ndv. Before, we use the smallest ndv as base --- .../doris/nereids/stats/StatsCalculator.java | 3 +- .../shape/query24.out | 23 +++++----- .../shape/query31.out | 46 +++++++++---------- 3 files changed, 37 insertions(+), 35 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java index acf24961df..385c0227ea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java @@ -127,6 +127,7 @@ import org.apache.logging.log4j.Logger; import java.util.AbstractMap.SimpleEntry; import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -689,7 +690,7 @@ public class StatsCalculator extends DefaultPlanVisitor { if (groupByCount > 0) { List groupByNdvs = groupByColStats.values().stream() .map(colStats -> colStats.ndv) - .sorted().collect(Collectors.toList()); + .sorted(Comparator.reverseOrder()).collect(Collectors.toList()); rowCount = groupByNdvs.get(0); for (int groupByIndex = 1; groupByIndex < groupByCount; ++groupByIndex) { rowCount *= Math.max(1, groupByNdvs.get(groupByIndex) * Math.pow( diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query24.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query24.out index 70f424e95c..c887e96371 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query24.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query24.out @@ -32,21 +32,22 @@ CteAnchor[cteId= ( CTEId#0=] ) ----------------PhysicalProject ------------------PhysicalOlapScan[store_returns] --PhysicalQuickSort -----PhysicalQuickSort -------PhysicalProject ---------NestedLoopJoin[INNER_JOIN](cast(paid as DOUBLE) > cast((0.05 * avg(netpaid)) as DOUBLE)) -----------PhysicalAssertNumRows -------------PhysicalProject ---------------hashAgg[GLOBAL] -----------------PhysicalDistribute -------------------hashAgg[LOCAL] ---------------------PhysicalProject -----------------------CteConsumer[cteId= ( CTEId#0=] ) -----------PhysicalDistribute +----PhysicalDistribute +------PhysicalQuickSort +--------PhysicalProject +----------NestedLoopJoin[INNER_JOIN](cast(paid as DOUBLE) > cast((0.05 * avg(netpaid)) as DOUBLE)) ------------hashAgg[GLOBAL] --------------PhysicalDistribute ----------------hashAgg[LOCAL] ------------------PhysicalProject --------------------filter((cast(i_color as VARCHAR(*)) = 'beige')) ----------------------CteConsumer[cteId= ( CTEId#0=] ) +------------PhysicalDistribute +--------------PhysicalAssertNumRows +----------------PhysicalProject +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute +----------------------hashAgg[LOCAL] +------------------------PhysicalProject +--------------------------CteConsumer[cteId= ( CTEId#0=] ) diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query31.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query31.out index e835f410d8..12a7b7db31 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query31.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query31.out @@ -44,33 +44,33 @@ CteAnchor[cteId= ( CTEId#6=] ) --------PhysicalQuickSort ----------PhysicalProject ------------hashJoin[INNER_JOIN](ws1.ca_county = ws3.ca_county)(CASE WHEN (web_sales > 0.00) THEN (cast(web_sales as DECIMALV3(38, 8)) / web_sales) ELSE NULL END > CASE WHEN (store_sales > 0.00) THEN (cast(store_sales as DECIMALV3(38, 8)) / store_sales) ELSE NULL END) ---------------PhysicalProject -----------------filter((ws3.d_year = 2000)(ws3.d_qoy = 3)) -------------------CteConsumer[cteId= ( CTEId#7=] ) --------------PhysicalDistribute ----------------PhysicalProject -------------------hashJoin[INNER_JOIN](ss2.ca_county = ss3.ca_county) ---------------------PhysicalDistribute -----------------------PhysicalProject -------------------------filter((ss3.d_year = 2000)(ss3.d_qoy = 3)) ---------------------------CteConsumer[cteId= ( CTEId#6=] ) ---------------------hashJoin[INNER_JOIN](ss1.ca_county = ss2.ca_county)(CASE WHEN (web_sales > 0.00) THEN (cast(web_sales as DECIMALV3(38, 8)) / web_sales) ELSE NULL END > CASE WHEN (store_sales > 0.00) THEN (cast(store_sales as DECIMALV3(38, 8)) / store_sales) ELSE NULL END) -----------------------PhysicalDistribute -------------------------PhysicalProject ---------------------------filter((ss2.d_year = 2000)(ss2.d_qoy = 2)) -----------------------------CteConsumer[cteId= ( CTEId#6=] ) -----------------------hashJoin[INNER_JOIN](ss1.ca_county = ws1.ca_county) +------------------filter((ws3.d_year = 2000)(ws3.d_qoy = 3)) +--------------------CteConsumer[cteId= ( CTEId#7=] ) +--------------PhysicalProject +----------------hashJoin[INNER_JOIN](ss2.ca_county = ss3.ca_county) +------------------PhysicalDistribute +--------------------PhysicalProject +----------------------filter((ss3.d_year = 2000)(ss3.d_qoy = 3)) +------------------------CteConsumer[cteId= ( CTEId#6=] ) +------------------hashJoin[INNER_JOIN](ws1.ca_county = ws2.ca_county)(CASE WHEN (web_sales > 0.00) THEN (cast(web_sales as DECIMALV3(38, 8)) / web_sales) ELSE NULL END > CASE WHEN (store_sales > 0.00) THEN (cast(store_sales as DECIMALV3(38, 8)) / store_sales) ELSE NULL END) +--------------------hashJoin[INNER_JOIN](ss1.ca_county = ws1.ca_county) +----------------------hashJoin[INNER_JOIN](ss1.ca_county = ss2.ca_county) ------------------------PhysicalDistribute --------------------------PhysicalProject ----------------------------filter((ss1.d_year = 2000)(ss1.d_qoy = 1)) ------------------------------CteConsumer[cteId= ( CTEId#6=] ) -------------------------hashJoin[INNER_JOIN](ws1.ca_county = ws2.ca_county) ---------------------------PhysicalDistribute -----------------------------PhysicalProject -------------------------------filter((ws1.d_year = 2000)(ws1.d_qoy = 1)) ---------------------------------CteConsumer[cteId= ( CTEId#7=] ) ---------------------------PhysicalDistribute -----------------------------PhysicalProject -------------------------------filter((ws2.d_qoy = 2)(ws2.d_year = 2000)) ---------------------------------CteConsumer[cteId= ( CTEId#7=] ) +------------------------PhysicalDistribute +--------------------------PhysicalProject +----------------------------filter((ss2.d_year = 2000)(ss2.d_qoy = 2)) +------------------------------CteConsumer[cteId= ( CTEId#6=] ) +----------------------PhysicalDistribute +------------------------PhysicalProject +--------------------------filter((ws1.d_year = 2000)(ws1.d_qoy = 1)) +----------------------------CteConsumer[cteId= ( CTEId#7=] ) +--------------------PhysicalDistribute +----------------------PhysicalProject +------------------------filter((ws2.d_qoy = 2)(ws2.d_year = 2000)) +--------------------------CteConsumer[cteId= ( CTEId#7=] )