87 lines
2.1 KiB
Go
87 lines
2.1 KiB
Go
package testkit
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/pingcap/check"
|
|
"github.com/pingcap/tidb"
|
|
"github.com/pingcap/tidb/kv"
|
|
"github.com/pingcap/tidb/rset"
|
|
)
|
|
|
|
// TestKit is a utility to run sql test.
|
|
type TestKit struct {
|
|
c *check.C
|
|
store kv.Storage
|
|
Se tidb.Session
|
|
}
|
|
|
|
// Result is the result returned by MustQuery.
|
|
type Result struct {
|
|
rows [][]interface{}
|
|
comment check.CommentInterface
|
|
c *check.C
|
|
}
|
|
|
|
// NewTestKit returns a new *TestKit.
|
|
func NewTestKit(c *check.C, store kv.Storage) *TestKit {
|
|
return &TestKit{
|
|
c: c,
|
|
store: store,
|
|
}
|
|
}
|
|
|
|
// Exec executes a sql statement.
|
|
func (tk *TestKit) Exec(sql string, args ...interface{}) (rset.Recordset, error) {
|
|
var err error
|
|
if tk.Se == nil {
|
|
tk.Se, err = tidb.CreateSession(tk.store)
|
|
tk.c.Assert(err, check.IsNil)
|
|
}
|
|
if len(args) == 0 {
|
|
var rss []rset.Recordset
|
|
rss, err = tk.Se.Execute(sql)
|
|
if err == nil && len(rss) > 0 {
|
|
return rss[0], nil
|
|
}
|
|
return nil, err
|
|
}
|
|
stmtID, _, _, err := tk.Se.PrepareStmt(sql)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
rs, err := tk.Se.ExecutePreparedStmt(stmtID, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = tk.Se.DropPreparedStmt(stmtID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return rs, nil
|
|
}
|
|
|
|
// MustExec executes a sql statement and asserts nil error.
|
|
func (tk *TestKit) MustExec(sql string, args ...interface{}) {
|
|
_, err := tk.Exec(sql, args...)
|
|
tk.c.Assert(err, check.IsNil, check.Commentf("sql:%s, %v", sql, args))
|
|
}
|
|
|
|
// MustQuery query the statements and returns result rows.
|
|
// If expected result is set it asserts the query result equals expected result.
|
|
func (tk *TestKit) MustQuery(sql string, args ...interface{}) *Result {
|
|
comment := check.Commentf("sql:%s, %v", sql, args)
|
|
rs, err := tk.Exec(sql, args...)
|
|
tk.c.Assert(err, check.IsNil, comment)
|
|
tk.c.Assert(rs, check.NotNil, comment)
|
|
rows, err := rs.Rows(-1, 0)
|
|
tk.c.Assert(err, check.IsNil, comment)
|
|
return &Result{rows: rows, c: tk.c, comment: comment}
|
|
}
|
|
|
|
// Check asserts the result equals the expected results.
|
|
func (res *Result) Check(expected [][]interface{}) {
|
|
got := fmt.Sprintf("%v", res.rows)
|
|
need := fmt.Sprintf("%v", expected)
|
|
res.c.Assert(got, check.Equals, need, res.comment)
|
|
}
|