[fix](Nereids)statistics calculator for Project and Aggregate lost some columns (#12196)

There are some bugs in Nereids' StatsCalculator.

1. Project: the child's column stats were returned directly, so parent operators could not look up column stats by the Project's output slots.
2. Aggregate: output columns that are Aliases were not returned, so parent operators could not find column stats for some of the Aggregate's output slots.
3. All: SlotReference is used as the key of the column-to-stats map, so SlotReference's equals and hashCode methods are changed to use only ExprId, as we discussed.
This commit is contained in:
morrySnow
2022-08-31 20:47:22 +08:00
committed by GitHub
parent 57051d3591
commit a49bde8a71
4 changed files with 49 additions and 37 deletions

View File

@ -59,6 +59,9 @@ import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.StatsDeriveResult;
import org.apache.doris.statistics.TableStats;
import com.google.common.collect.Maps;
import java.util.AbstractMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -250,46 +253,49 @@ public class StatsCalculator extends DefaultPlanVisitor<StatsDeriveResult, Void>
}
// Derives statistics for an Aggregate node.
// Row count: starts at 1 and is multiplied by the NDV of every group-by expression that
// references exactly one SlotReference; the NDV is looked up in a copy of the first child's stats.
// Column stats: every recorded output gets a fresh, empty ColumnStats (see the TODOs below).
// NOTE(review): this listing appears to show the commit's removed and added lines interleaved
// (each statement occurs once per local-variable naming scheme) — confirm against the real file.
private StatsDeriveResult computeAggregate(Aggregate aggregate) {
List<Expression> groupByExprList = aggregate.getGroupByExpressions();
List<Expression> groupByExpressions = aggregate.getGroupByExpressions();
StatsDeriveResult childStats = groupExpression.getCopyOfChildStats(0);
Map<Slot, ColumnStats> childSlotColumnStatsMap = childStats.getSlotToColumnStats();
Map<Slot, ColumnStats> childSlotToColumnStats = childStats.getSlotToColumnStats();
long resultSetCount = 1;
for (Expression expression : groupByExprList) {
List<SlotReference> slotRefList = expression.collect(SlotReference.class::isInstance);
for (Expression groupByExpression : groupByExpressions) {
List<SlotReference> slotReferences = groupByExpression.collect(SlotReference.class::isInstance);
// TODO: Support more complex group expr.
// For example:
// select max(col1+col3) from t1 group by col1+col3;
// Expressions with zero or several slot references are skipped, so they leave the estimate unchanged.
if (slotRefList.size() != 1) {
if (slotReferences.size() != 1) {
continue;
}
SlotReference slotRef = slotRefList.get(0);
ColumnStats columnStats = childSlotColumnStatsMap.get(slotRef);
SlotReference slotReference = slotReferences.get(0);
ColumnStats columnStats = childSlotToColumnStats.get(slotReference);
// NOTE(review): Map.get yields null when the child map has no entry for this slot, which would
// throw a NullPointerException on the next line — confirm the child always supplies one.
resultSetCount *= columnStats.getNdv();
}
Map<Slot, ColumnStats> slotColumnStatsMap = new HashMap<>();
List<NamedExpression> namedExpressionList = aggregate.getOutputExpressions();
Map<Slot, ColumnStats> slotToColumnStats = Maps.newHashMap();
List<NamedExpression> outputExpressions = aggregate.getOutputExpressions();
// TODO: 1. Estimate the output unit size by the type of corresponding AggregateFunction
// 2. Handle alias, literal in the output expression list
// The instanceof-guarded loop only records outputs that are SlotReference instances.
for (NamedExpression namedExpression : namedExpressionList) {
if (namedExpression instanceof SlotReference) {
slotColumnStatsMap.put((SlotReference) namedExpression, new ColumnStats());
}
// The toSlot()-based loop records an entry for every output expression, with no type check.
for (NamedExpression outputExpression : outputExpressions) {
slotToColumnStats.put(outputExpression.toSlot(), new ColumnStats());
}
StatsDeriveResult statsDeriveResult = new StatsDeriveResult(resultSetCount, slotColumnStatsMap);
StatsDeriveResult statsDeriveResult = new StatsDeriveResult(resultSetCount, slotToColumnStats);
// TODO: Update ColumnStats properly, add new mapping from output slot to ColumnStats
return statsDeriveResult;
}
// TODO: Update data size and min/max value.
// TODO: do real project on column stats
// Derives statistics for a Project node, starting from a copy of the first child's stats.
// NOTE(review): this listing appears to show the commit's removed and added lines interleaved
// (two bodies, each ending in its own return) — confirm against the real file.
private StatsDeriveResult computeProject(Project project) {
// First body: collect every SlotReference used by the projections, drop all other
// entries from the child's slot-to-stats map, and return the child's stats object.
List<NamedExpression> namedExpressionList = project.getProjects();
List<Slot> slotSet = namedExpressionList.stream().flatMap(namedExpression -> {
List<Slot> slotReferenceList = namedExpression.collect(SlotReference.class::isInstance);
return slotReferenceList.stream();
}).collect(Collectors.toList());
StatsDeriveResult stat = groupExpression.getCopyOfChildStats(0);
Map<Slot, ColumnStats> slotColumnStatsMap = stat.getSlotToColumnStats();
slotColumnStatsMap.entrySet().removeIf(entry -> !slotSet.contains(entry.getKey()));
return stat;
// Second body: build a new map keyed by each projection's toSlot() and install it on the copied stats.
List<NamedExpression> projections = project.getProjects();
StatsDeriveResult statsDeriveResult = groupExpression.getCopyOfChildStats(0);
Map<Slot, ColumnStats> childColumnStats = statsDeriveResult.getSlotToColumnStats();
Map<Slot, ColumnStats> columnsStats = projections.stream().map(projection -> {
List<SlotReference> slotReferences = projection.collect(SlotReference.class::isInstance);
if (slotReferences.isEmpty()) {
// A projection that references no slot gets the default column stats.
return new AbstractMap.SimpleEntry<>(projection.toSlot(), ColumnStats.createDefaultColumnStats());
} else {
// TODO: just a trick here, need to do real project on column stats
// Only the first referenced slot's child stats are used, even if the projection references several.
// NOTE(review): if the child map has no entry for that slot the value is null, and
// Collectors.toMap is expected to reject null values with a NullPointerException — confirm.
return new AbstractMap.SimpleEntry<>(projection.toSlot(), childColumnStats.get(slotReferences.get(0)));
}
// NOTE(review): the two-argument Collectors.toMap throws IllegalStateException on duplicate keys;
// presumably projection slots are unique per Project — verify.
}).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
statsDeriveResult.setSlotToColumnStats(columnsStats);
return statsDeriveResult;
}
}

View File

@ -116,16 +116,13 @@ public class SlotReference extends Slot {
return false;
}
SlotReference that = (SlotReference) o;
return nullable == that.nullable
&& dataType.equals(that.dataType)
&& exprId.equals(that.exprId)
&& name.equals(that.name)
&& qualifier.equals(that.qualifier);
return exprId.equals(that.exprId);
}
// Hash code for this SlotReference.
// NOTE(review): two return statements are listed here, apparently the removed and the added
// line of the commit; the exprId-only form matches the equals() above, which compares exprId alone.
// Hashing name, qualifier and nullable as well would let two equal references hash differently.
@Override
public int hashCode() {
return Objects.hash(exprId, name, qualifier, nullable);
return Objects.hash(exprId);
}
public Column getColumn() {

View File

@ -76,6 +76,15 @@ public class ColumnStats {
private LiteralExpr minValue;
private LiteralExpr maxValue;
/**
 * Creates a ColumnStats populated with placeholder values: average size 1,
 * maximum size 1, NDV 1 and zero nulls.
 *
 * @return a newly constructed ColumnStats carrying those values
 */
public static ColumnStats createDefaultColumnStats() {
    final ColumnStats defaults = new ColumnStats();
    // Setter order is kept as-is; the setters are not visible from here.
    defaults.setAvgSize(1);
    defaults.setMaxSize(1);
    defaults.setNdv(1);
    defaults.setNumNulls(0);
    return defaults;
}
public ColumnStats(ColumnStats other) {
this.ndv = other.ndv;
this.avgSize = other.avgSize;

View File

@ -65,7 +65,7 @@ public class PlanEqualsTest {
Assertions.assertEquals(expected, actual);
LogicalAggregate<Plan> unexpected = new LogicalAggregate<>(Lists.newArrayList(), ImmutableList.of(
new SlotReference(new ExprId(0), "b", BigIntType.INSTANCE, true, Lists.newArrayList())),
new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())),
child);
Assertions.assertNotEquals(unexpected, actual);
}
@ -95,8 +95,8 @@ public class PlanEqualsTest {
Assertions.assertEquals(expected, actual);
LogicalJoin<Plan, Plan> unexpected = new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(new EqualTo(
new SlotReference("a", BigIntType.INSTANCE, false, Lists.newArrayList()),
new SlotReference("b", BigIntType.INSTANCE, true, Lists.newArrayList()))),
new SlotReference(new ExprId(2), "a", BigIntType.INSTANCE, false, Lists.newArrayList()),
new SlotReference(new ExprId(3), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))),
Optional.empty(), left, right);
Assertions.assertNotEquals(unexpected, actual);
}
@ -151,7 +151,7 @@ public class PlanEqualsTest {
LogicalSort<Plan> unexpected = new LogicalSort<>(
ImmutableList.of(new OrderKey(
new SlotReference(new ExprId(1), "a", BigIntType.INSTANCE, true, Lists.newArrayList()), true,
new SlotReference(new ExprId(2), "a", BigIntType.INSTANCE, true, Lists.newArrayList()), true,
true)),
child);
Assertions.assertNotEquals(unexpected, actual);
@ -211,8 +211,8 @@ public class PlanEqualsTest {
PhysicalHashJoin<Plan, Plan> unexpected = new PhysicalHashJoin<>(JoinType.INNER_JOIN,
Lists.newArrayList(new EqualTo(
new SlotReference("a", BigIntType.INSTANCE, false, Lists.newArrayList()),
new SlotReference("b", BigIntType.INSTANCE, true, Lists.newArrayList()))),
new SlotReference(new ExprId(2), "a", BigIntType.INSTANCE, false, Lists.newArrayList()),
new SlotReference(new ExprId(3), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))),
Optional.empty(), logicalProperties, left, right);
Assertions.assertNotEquals(unexpected, actual);
}
@ -288,7 +288,7 @@ public class PlanEqualsTest {
PhysicalQuickSort<Plan> unexpected = new PhysicalQuickSort<>(
ImmutableList.of(new OrderKey(
new SlotReference(new ExprId(1), "a", BigIntType.INSTANCE, true, Lists.newArrayList()), true,
new SlotReference(new ExprId(2), "a", BigIntType.INSTANCE, true, Lists.newArrayList()), true,
true)),
logicalProperties,
child);