309 lines
6.7 KiB
Go
309 lines
6.7 KiB
Go
// Copyright 2025 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package testkit
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"fmt"
|
|
"io"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/pingcap/errors"
|
|
"github.com/pingcap/tidb/pkg/expression"
|
|
"github.com/pingcap/tidb/pkg/parser/mysql"
|
|
"github.com/pingcap/tidb/pkg/session/sessionapi"
|
|
"github.com/pingcap/tidb/pkg/types"
|
|
"github.com/pingcap/tidb/pkg/util/chunk"
|
|
"github.com/pingcap/tidb/pkg/util/sqlexec"
|
|
)
|
|
|
|
var (
|
|
tkMapMu sync.RWMutex
|
|
tkMap = make(map[string]*TestKit)
|
|
tkIDSeq int64
|
|
)
|
|
|
|
func init() {
|
|
sql.Register("testkit", &testKitDriver{})
|
|
}
|
|
|
|
// CreateMockDB creates a *sql.DB that uses the TestKit's store to create sessions.
|
|
func CreateMockDB(tk *TestKit) *sql.DB {
|
|
id := strconv.FormatInt(atomic.AddInt64(&tkIDSeq, 1), 10)
|
|
tkMapMu.Lock()
|
|
tkMap[id] = tk
|
|
tkMapMu.Unlock()
|
|
|
|
db, err := sql.Open("testkit", id)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return db
|
|
}
|
|
|
|
type testKitDriver struct{}
|
|
|
|
func (d *testKitDriver) Open(name string) (driver.Conn, error) {
|
|
tkMapMu.RLock()
|
|
tk, ok := tkMap[name]
|
|
tkMapMu.RUnlock()
|
|
if !ok {
|
|
return nil, fmt.Errorf("testkit not found for %s", name)
|
|
}
|
|
se := NewSession(tk.t, tk.store)
|
|
return &testKitConn{se: se}, nil
|
|
}
|
|
|
|
type testKitConn struct {
|
|
se sessionapi.Session
|
|
}
|
|
|
|
func (c *testKitConn) Prepare(query string) (driver.Stmt, error) {
|
|
return &testKitStmt{c: c, query: query}, nil
|
|
}
|
|
|
|
func (c *testKitConn) Close() error {
|
|
c.se.Close()
|
|
return nil
|
|
}
|
|
|
|
func (c *testKitConn) Begin() (driver.Tx, error) {
|
|
_, err := c.Exec("BEGIN", nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &testKitTxn{c: c}, nil
|
|
}
|
|
|
|
func (c *testKitConn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
|
return nil, driver.ErrSkip
|
|
}
|
|
|
|
func (c *testKitConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
|
qArgs := make([]any, len(args))
|
|
for i, a := range args {
|
|
qArgs[i] = a.Value
|
|
}
|
|
|
|
rs, err := c.execute(ctx, query, qArgs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if rs != nil {
|
|
if err := rs.Close(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return &testKitResult{
|
|
lastInsertID: int64(c.se.LastInsertID()),
|
|
rowsAffected: int64(c.se.AffectedRows()),
|
|
}, nil
|
|
}
|
|
|
|
func (c *testKitConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
|
qArgs := make([]any, len(args))
|
|
for i, a := range args {
|
|
qArgs[i] = a.Value
|
|
}
|
|
|
|
rs, err := c.execute(ctx, query, qArgs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if rs == nil {
|
|
return nil, nil
|
|
}
|
|
return &testKitRows{rs: rs}, nil
|
|
}
|
|
|
|
func (c *testKitConn) execute(ctx context.Context, sql string, args []any) (sqlexec.RecordSet, error) {
|
|
// Set the command value to ComQuery, so that the process info can be updated correctly
|
|
c.se.SetCommandValue(mysql.ComQuery)
|
|
defer c.se.SetCommandValue(mysql.ComSleep)
|
|
|
|
if len(args) == 0 {
|
|
rss, err := c.se.Execute(ctx, sql)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(rss) == 0 {
|
|
return nil, nil
|
|
}
|
|
return rss[0], nil
|
|
}
|
|
|
|
stmtID, _, _, err := c.se.PrepareStmt(sql)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
params := expression.Args2Expressions4Test(args...)
|
|
rs, err := c.se.ExecutePreparedStmt(ctx, stmtID, params)
|
|
if err != nil {
|
|
return rs, errors.Trace(err)
|
|
}
|
|
err = c.se.DropPreparedStmt(stmtID)
|
|
if err != nil {
|
|
return rs, errors.Trace(err)
|
|
}
|
|
return rs, nil
|
|
}
|
|
|
|
type testKitStmt struct {
|
|
c *testKitConn
|
|
query string
|
|
}
|
|
|
|
func (s *testKitStmt) Close() error {
|
|
return nil
|
|
}
|
|
|
|
func (s *testKitStmt) NumInput() int {
|
|
return -1
|
|
}
|
|
|
|
func (s *testKitStmt) Exec(args []driver.Value) (driver.Result, error) {
|
|
return nil, driver.ErrSkip
|
|
}
|
|
|
|
func (s *testKitStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
|
return s.c.ExecContext(ctx, s.query, args)
|
|
}
|
|
|
|
func (s *testKitStmt) Query(args []driver.Value) (driver.Rows, error) {
|
|
return nil, driver.ErrSkip
|
|
}
|
|
|
|
func (s *testKitStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
|
return s.c.QueryContext(ctx, s.query, args)
|
|
}
|
|
|
|
type testKitResult struct {
|
|
lastInsertID int64
|
|
rowsAffected int64
|
|
}
|
|
|
|
func (r *testKitResult) LastInsertId() (int64, error) {
|
|
return r.lastInsertID, nil
|
|
}
|
|
|
|
func (r *testKitResult) RowsAffected() (int64, error) {
|
|
return r.rowsAffected, nil
|
|
}
|
|
|
|
type testKitRows struct {
|
|
rs sqlexec.RecordSet
|
|
chunk *chunk.Chunk
|
|
it *chunk.Iterator4Chunk
|
|
}
|
|
|
|
func (r *testKitRows) Columns() []string {
|
|
fields := r.rs.Fields()
|
|
cols := make([]string, len(fields))
|
|
for i, f := range fields {
|
|
cols[i] = f.Column.Name.O
|
|
}
|
|
return cols
|
|
}
|
|
|
|
func (r *testKitRows) Close() error {
|
|
return r.rs.Close()
|
|
}
|
|
|
|
func (r *testKitRows) Next(dest []driver.Value) error {
|
|
if r.chunk == nil {
|
|
r.chunk = r.rs.NewChunk(nil)
|
|
}
|
|
|
|
var row chunk.Row
|
|
if r.it == nil {
|
|
err := r.rs.Next(context.Background(), r.chunk)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if r.chunk.NumRows() == 0 {
|
|
return io.EOF
|
|
}
|
|
r.it = chunk.NewIterator4Chunk(r.chunk)
|
|
row = r.it.Begin()
|
|
} else {
|
|
row = r.it.Next()
|
|
if row.IsEmpty() {
|
|
err := r.rs.Next(context.Background(), r.chunk)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if r.chunk.NumRows() == 0 {
|
|
return io.EOF
|
|
}
|
|
r.it = chunk.NewIterator4Chunk(r.chunk)
|
|
row = r.it.Begin()
|
|
}
|
|
}
|
|
|
|
for i := range row.Len() {
|
|
d := row.GetDatum(i, &r.rs.Fields()[i].Column.FieldType)
|
|
// Handle NULL
|
|
if d.IsNull() {
|
|
dest[i] = nil
|
|
} else {
|
|
// Convert to appropriate type if needed, or just return string/bytes/int/float
|
|
// driver.Value allows int64, float64, bool, []byte, string, time.Time
|
|
// Datum.GetValue() returns interface{} which might be compatible.
|
|
v := d.GetValue()
|
|
switch x := v.(type) {
|
|
case []byte:
|
|
dest[i] = x
|
|
case string:
|
|
dest[i] = x
|
|
case int64:
|
|
dest[i] = x
|
|
case uint64:
|
|
dest[i] = x
|
|
case float64:
|
|
dest[i] = x
|
|
case float32:
|
|
dest[i] = x
|
|
case types.Time:
|
|
dest[i] = x.String()
|
|
case types.Duration:
|
|
dest[i] = x.String()
|
|
case *types.MyDecimal:
|
|
dest[i] = x.String()
|
|
default:
|
|
dest[i] = x
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type testKitTxn struct {
|
|
c *testKitConn
|
|
}
|
|
|
|
func (t *testKitTxn) Commit() error {
|
|
_, err := t.c.Exec("COMMIT", nil)
|
|
return err
|
|
}
|
|
|
|
func (t *testKitTxn) Rollback() error {
|
|
_, err := t.c.Exec("ROLLBACK", nil)
|
|
return err
|
|
}
|