From a49bde8a718ec96f107ac7d44dda8a9017f7ebc4 Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Wed, 31 Aug 2022 20:47:22 +0800 Subject: [PATCH] [fix](Nereids)statistics calculator for Project and Aggregate lost some columns (#12196) There are some bugs in Nereids' StatsCalculator. 1. Project: it returns the child's column stats directly, so its parents cannot find column stats from the project's slots. 2. Aggregate: it does not return columns that are Alias, so its parents cannot find some column stats from the Aggregate's slots. 3. All: SlotReference is used as the key of the column-to-stats map, so we need to change SlotReference's equals and hashCode methods to use only ExprId, as we discussed. --- .../doris/nereids/stats/StatsCalculator.java | 54 ++++++++++--------- .../trees/expressions/SlotReference.java | 9 ++-- .../apache/doris/statistics/ColumnStats.java | 9 ++++ .../nereids/trees/plans/PlanEqualsTest.java | 14 ++--- 4 files changed, 49 insertions(+), 37 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java index 286dc12eff..50052a959b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java @@ -59,6 +59,9 @@ import org.apache.doris.statistics.Statistics; import org.apache.doris.statistics.StatsDeriveResult; import org.apache.doris.statistics.TableStats; +import com.google.common.collect.Maps; + +import java.util.AbstractMap; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -250,46 +253,49 @@ public class StatsCalculator extends DefaultPlanVisitor } private StatsDeriveResult computeAggregate(Aggregate aggregate) { - List groupByExprList = aggregate.getGroupByExpressions(); + List groupByExpressions = aggregate.getGroupByExpressions(); StatsDeriveResult childStats = 
groupExpression.getCopyOfChildStats(0); - Map childSlotColumnStatsMap = childStats.getSlotToColumnStats(); + Map childSlotToColumnStats = childStats.getSlotToColumnStats(); long resultSetCount = 1; - for (Expression expression : groupByExprList) { - List slotRefList = expression.collect(SlotReference.class::isInstance); + for (Expression groupByExpression : groupByExpressions) { + List slotReferences = groupByExpression.collect(SlotReference.class::isInstance); // TODO: Support more complex group expr. // For example: // select max(col1+col3) from t1 group by col1+col3; - if (slotRefList.size() != 1) { + if (slotReferences.size() != 1) { continue; } - SlotReference slotRef = slotRefList.get(0); - ColumnStats columnStats = childSlotColumnStatsMap.get(slotRef); + SlotReference slotReference = slotReferences.get(0); + ColumnStats columnStats = childSlotToColumnStats.get(slotReference); resultSetCount *= columnStats.getNdv(); } - Map slotColumnStatsMap = new HashMap<>(); - List namedExpressionList = aggregate.getOutputExpressions(); + Map slotToColumnStats = Maps.newHashMap(); + List outputExpressions = aggregate.getOutputExpressions(); // TODO: 1. Estimate the output unit size by the type of corresponding AggregateFunction // 2. Handle alias, literal in the output expression list - for (NamedExpression namedExpression : namedExpressionList) { - if (namedExpression instanceof SlotReference) { - slotColumnStatsMap.put((SlotReference) namedExpression, new ColumnStats()); - } + for (NamedExpression outputExpression : outputExpressions) { + slotToColumnStats.put(outputExpression.toSlot(), new ColumnStats()); } - StatsDeriveResult statsDeriveResult = new StatsDeriveResult(resultSetCount, slotColumnStatsMap); + StatsDeriveResult statsDeriveResult = new StatsDeriveResult(resultSetCount, slotToColumnStats); // TODO: Update ColumnStats properly, add new mapping from output slot to ColumnStats return statsDeriveResult; } - // TODO: Update data size and min/max value. 
+ // TODO: do real project on column stats private StatsDeriveResult computeProject(Project project) { - List namedExpressionList = project.getProjects(); - List slotSet = namedExpressionList.stream().flatMap(namedExpression -> { - List slotReferenceList = namedExpression.collect(SlotReference.class::isInstance); - return slotReferenceList.stream(); - }).collect(Collectors.toList()); - StatsDeriveResult stat = groupExpression.getCopyOfChildStats(0); - Map slotColumnStatsMap = stat.getSlotToColumnStats(); - slotColumnStatsMap.entrySet().removeIf(entry -> !slotSet.contains(entry.getKey())); - return stat; + List projections = project.getProjects(); + StatsDeriveResult statsDeriveResult = groupExpression.getCopyOfChildStats(0); + Map childColumnStats = statsDeriveResult.getSlotToColumnStats(); + Map columnsStats = projections.stream().map(projection -> { + List slotReferences = projection.collect(SlotReference.class::isInstance); + if (slotReferences.isEmpty()) { + return new AbstractMap.SimpleEntry<>(projection.toSlot(), ColumnStats.createDefaultColumnStats()); + } else { + // TODO: just a trick here, need to do real project on column stats + return new AbstractMap.SimpleEntry<>(projection.toSlot(), childColumnStats.get(slotReferences.get(0))); + } + }).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + statsDeriveResult.setSlotToColumnStats(columnsStats); + return statsDeriveResult; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java index 24a601a2e5..cc77e75f7f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java @@ -116,16 +116,13 @@ public class SlotReference extends Slot { return false; } SlotReference that = (SlotReference) o; - return nullable == that.nullable - && 
dataType.equals(that.dataType) - && exprId.equals(that.exprId) - && name.equals(that.name) - && qualifier.equals(that.qualifier); + return exprId.equals(that.exprId); + } @Override public int hashCode() { - return Objects.hash(exprId, name, qualifier, nullable); + return Objects.hash(exprId); } public Column getColumn() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/ColumnStats.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/ColumnStats.java index 30fb619630..cd28ea263d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/statistics/ColumnStats.java +++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/ColumnStats.java @@ -76,6 +76,15 @@ public class ColumnStats { private LiteralExpr minValue; private LiteralExpr maxValue; + public static ColumnStats createDefaultColumnStats() { + ColumnStats columnStats = new ColumnStats(); + columnStats.setAvgSize(1); + columnStats.setMaxSize(1); + columnStats.setNdv(1); + columnStats.setNumNulls(0); + return columnStats; + } + public ColumnStats(ColumnStats other) { this.ndv = other.ndv; this.avgSize = other.avgSize; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java index a9e9c6c094..1fd9680a03 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java @@ -65,7 +65,7 @@ public class PlanEqualsTest { Assertions.assertEquals(expected, actual); LogicalAggregate unexpected = new LogicalAggregate<>(Lists.newArrayList(), ImmutableList.of( - new SlotReference(new ExprId(0), "b", BigIntType.INSTANCE, true, Lists.newArrayList())), + new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())), child); Assertions.assertNotEquals(unexpected, actual); } @@ -95,8 +95,8 @@ public class PlanEqualsTest { 
Assertions.assertEquals(expected, actual); LogicalJoin unexpected = new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(new EqualTo( - new SlotReference("a", BigIntType.INSTANCE, false, Lists.newArrayList()), - new SlotReference("b", BigIntType.INSTANCE, true, Lists.newArrayList()))), + new SlotReference(new ExprId(2), "a", BigIntType.INSTANCE, false, Lists.newArrayList()), + new SlotReference(new ExprId(3), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))), Optional.empty(), left, right); Assertions.assertNotEquals(unexpected, actual); } @@ -151,7 +151,7 @@ public class PlanEqualsTest { LogicalSort unexpected = new LogicalSort<>( ImmutableList.of(new OrderKey( - new SlotReference(new ExprId(1), "a", BigIntType.INSTANCE, true, Lists.newArrayList()), true, + new SlotReference(new ExprId(2), "a", BigIntType.INSTANCE, true, Lists.newArrayList()), true, true)), child); Assertions.assertNotEquals(unexpected, actual); @@ -211,8 +211,8 @@ public class PlanEqualsTest { PhysicalHashJoin unexpected = new PhysicalHashJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(new EqualTo( - new SlotReference("a", BigIntType.INSTANCE, false, Lists.newArrayList()), - new SlotReference("b", BigIntType.INSTANCE, true, Lists.newArrayList()))), + new SlotReference(new ExprId(2), "a", BigIntType.INSTANCE, false, Lists.newArrayList()), + new SlotReference(new ExprId(3), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))), Optional.empty(), logicalProperties, left, right); Assertions.assertNotEquals(unexpected, actual); } @@ -288,7 +288,7 @@ public class PlanEqualsTest { PhysicalQuickSort unexpected = new PhysicalQuickSort<>( ImmutableList.of(new OrderKey( - new SlotReference(new ExprId(1), "a", BigIntType.INSTANCE, true, Lists.newArrayList()), true, + new SlotReference(new ExprId(2), "a", BigIntType.INSTANCE, true, Lists.newArrayList()), true, true)), logicalProperties, child);