diff --git a/planner/core/stats.go b/planner/core/stats.go index 050e9e8a14..ad1c480fdd 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -16,7 +16,6 @@ package core import ( "math" - "github.com/pingcap/errors" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/planner/property" log "github.com/sirupsen/logrus" @@ -27,6 +26,9 @@ func (p *basePhysicalPlan) StatsCount() float64 { } func (p *LogicalTableDual) deriveStats() (*property.StatsInfo, error) { + if p.stats != nil { + return p.stats, nil + } profile := &property.StatsInfo{ RowCount: float64(p.RowCount), Cardinality: make([]float64, p.Schema().Len()), @@ -39,6 +41,9 @@ func (p *LogicalTableDual) deriveStats() (*property.StatsInfo, error) { } func (p *baseLogicalPlan) deriveStats() (*property.StatsInfo, error) { + if p.stats != nil { + return p.stats, nil + } if len(p.children) > 1 { panic("LogicalPlans with more than one child should implement their own deriveStats().") } @@ -46,7 +51,7 @@ func (p *baseLogicalPlan) deriveStats() (*property.StatsInfo, error) { if len(p.children) == 1 { var err error p.stats, err = p.children[0].deriveStats() - return p.stats, errors.Trace(err) + return p.stats, err } profile := &property.StatsInfo{ @@ -86,6 +91,9 @@ func (ds *DataSource) getStatsByFilter(conds expression.CNFExprs) *property.Stat } func (ds *DataSource) deriveStats() (*property.StatsInfo, error) { + if ds.stats != nil { + return ds.stats, nil + } // PushDownNot here can convert query 'not (a != 1)' to 'a = 1'. for i, expr := range ds.pushedDownConds { ds.pushedDownConds[i] = expression.PushDownNot(nil, expr, false) @@ -95,7 +103,7 @@ func (ds *DataSource) deriveStats() (*property.StatsInfo, error) { if path.isTablePath { noIntervalRanges, err := ds.deriveTablePathStats(path) if err != nil { - return nil, errors.Trace(err) + return nil, err } // If we have point or empty range, just remove other possible paths. if noIntervalRanges || len(path.ranges) == 0 { @@ -107,7 +115,7 @@ func (ds *DataSource) deriveStats() (*property.StatsInfo, error) { } noIntervalRanges, err := ds.deriveIndexPathStats(path) if err != nil { - return nil, errors.Trace(err) + return nil, err } // If we have empty range, or point range on unique index, just remove other possible paths. if (noIntervalRanges && path.index.Unique) || len(path.ranges) == 0 { @@ -120,22 +128,28 @@ func (ds *DataSource) deriveStats() (*property.StatsInfo, error) { } func (p *LogicalSelection) deriveStats() (*property.StatsInfo, error) { + if p.stats != nil { + return p.stats, nil + } childProfile, err := p.children[0].deriveStats() if err != nil { - return nil, errors.Trace(err) + return nil, err } p.stats = childProfile.Scale(selectionFactor) return p.stats, nil } func (p *LogicalUnionAll) deriveStats() (*property.StatsInfo, error) { + if p.stats != nil { + return p.stats, nil + } p.stats = &property.StatsInfo{ Cardinality: make([]float64, p.Schema().Len()), } for _, child := range p.children { childProfile, err := child.deriveStats() if err != nil { - return nil, errors.Trace(err) + return nil, err } p.stats.RowCount += childProfile.RowCount for i := range p.stats.Cardinality { @@ -146,9 +160,12 @@ func (p *LogicalUnionAll) deriveStats() (*property.StatsInfo, error) { } func (p *LogicalLimit) deriveStats() (*property.StatsInfo, error) { + if p.stats != nil { + return p.stats, nil + } childProfile, err := p.children[0].deriveStats() if err != nil { - return nil, errors.Trace(err) + return nil, err } p.stats = &property.StatsInfo{ RowCount: math.Min(float64(p.Count), childProfile.RowCount), @@ -161,9 +178,12 @@ func (p *LogicalLimit) deriveStats() (*property.StatsInfo, error) { } func (lt *LogicalTopN) deriveStats() (*property.StatsInfo, error) { + if lt.stats != nil { + return lt.stats, nil + } childProfile, err := lt.children[0].deriveStats() if err != nil { - return nil, errors.Trace(err) + return nil, err } lt.stats = &property.StatsInfo{ RowCount: math.Min(float64(lt.Count), childProfile.RowCount), @@ -192,9 +212,12 @@ func getCardinality(cols []*expression.Column, schema *expression.Schema, profil } func (p *LogicalProjection) deriveStats() (*property.StatsInfo, error) { + if p.stats != nil { + return p.stats, nil + } childProfile, err := p.children[0].deriveStats() if err != nil { - return nil, errors.Trace(err) + return nil, err } p.stats = &property.StatsInfo{ RowCount: childProfile.RowCount, @@ -208,9 +231,12 @@ func (p *LogicalProjection) deriveStats() (*property.StatsInfo, error) { } func (la *LogicalAggregation) deriveStats() (*property.StatsInfo, error) { + if la.stats != nil { + return la.stats, nil + } childProfile, err := la.children[0].deriveStats() if err != nil { - return nil, errors.Trace(err) + return nil, err } gbyCols := make([]*expression.Column, 0, len(la.GroupByItems)) for _, gbyExpr := range la.GroupByItems { @@ -238,13 +264,16 @@ func (la *LogicalAggregation) deriveStats() (*property.StatsInfo, error) { // This is a quite simple strategy: We assume every bucket of relation which will participate join has the same number of rows, and apply cross join for // every matched bucket. func (p *LogicalJoin) deriveStats() (*property.StatsInfo, error) { + if p.stats != nil { + return p.stats, nil + } leftProfile, err := p.children[0].deriveStats() if err != nil { - return nil, errors.Trace(err) + return nil, err } rightProfile, err := p.children[1].deriveStats() if err != nil { - return nil, errors.Trace(err) + return nil, err } if p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin { p.stats = &property.StatsInfo{ @@ -300,13 +329,16 @@ func (p *LogicalJoin) deriveStats() (*property.StatsInfo, error) { } func (la *LogicalApply) deriveStats() (*property.StatsInfo, error) { + if la.stats != nil { + return la.stats, nil + } leftProfile, err := la.children[0].deriveStats() if err != nil { - return nil, errors.Trace(err) + return nil, err } _, err = la.children[1].deriveStats() if err != nil { - return nil, errors.Trace(err) + return nil, err } la.stats = &property.StatsInfo{ RowCount: leftProfile.RowCount, @@ -336,9 +368,12 @@ func getSingletonStats(len int) *property.StatsInfo { } func (p *LogicalMaxOneRow) deriveStats() (*property.StatsInfo, error) { + if p.stats != nil { + return p.stats, nil + } _, err := p.children[0].deriveStats() if err != nil { - return nil, errors.Trace(err) + return nil, err } p.stats = getSingletonStats(p.Schema().Len()) return p.stats, nil