diff --git a/planner/core/cache.go b/planner/core/cache.go index af5c574395..ac89e1d397 100644 --- a/planner/core/cache.go +++ b/planner/core/cache.go @@ -140,4 +140,5 @@ func NewPSTMTPlanCacheValue(plan Plan, names []*types.FieldName) *PSTMTPlanCache type CachedPrepareStmt struct { PreparedAst *ast.Prepared VisitInfos []visitInfo + ColumnInfos interface{} } diff --git a/server/driver_tidb.go b/server/driver_tidb.go index 42fff81390..93ad2f9182 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" @@ -80,7 +81,8 @@ func (ts *TiDBStatement) Execute(ctx context.Context, args []types.Datum) (rs Re return } rs = &tidbResultSet{ - recordSet: tidbRecordset, + recordSet: tidbRecordset, + preparedStmt: ts.ctx.GetSessionVars().PreparedStmts[ts.id].(*core.CachedPrepareStmt), } return } @@ -351,10 +353,11 @@ func (tc *TiDBContext) GetSessionVars() *variable.SessionVars { } type tidbResultSet struct { - recordSet sqlexec.RecordSet - columns []*ColumnInfo - rows []chunk.Row - closed int32 + recordSet sqlexec.RecordSet + columns []*ColumnInfo + rows []chunk.Row + closed int32 + preparedStmt *core.CachedPrepareStmt } func (trs *tidbResultSet) NewChunk() *chunk.Chunk { @@ -391,11 +394,26 @@ func (trs *tidbResultSet) OnFetchReturned() { } func (trs *tidbResultSet) Columns() []*ColumnInfo { + if trs.columns != nil { + return trs.columns + } + // for prepare statement, try to get cached columnInfo array + if trs.preparedStmt != nil { + ps := trs.preparedStmt + if colInfos, ok := ps.ColumnInfos.([]*ColumnInfo); ok { + trs.columns = colInfos + } + } if trs.columns == nil { fields := trs.recordSet.Fields() for _, v := range fields { trs.columns = append(trs.columns, convertColumnInfo(v)) } + if trs.preparedStmt != nil { + // if ColumnInfo struct has allocated object, + // here maybe we need deep copy ColumnInfo to do caching + trs.preparedStmt.ColumnInfos = trs.columns + } } return trs.columns }