Files
tidb/util/testutil/testutil.go

423 lines
11 KiB
Go

// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !codes
package testutil
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"os"
"path/filepath"
"reflect"
"regexp"
"runtime"
"sort"
"strings"
"github.com/pingcap/check"
"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/logutil"
"go.uber.org/zap"
)
// CompareUnorderedStringSlice compares two string slices as multisets.
// It returns true if a and b contain exactly the same elements with the same
// number of occurrences, regardless of order; otherwise it returns false.
// A nil slice is only equal to another nil slice, so a nil slice and an empty
// non-nil slice compare as different.
func CompareUnorderedStringSlice(a []string, b []string) bool {
	if a == nil || b == nil {
		// Equal only when both are nil.
		return a == nil && b == nil
	}
	if len(a) != len(b) {
		return false
	}
	// Count occurrences in a. A missing key reads as 0, so no existence
	// check is needed before incrementing.
	counts := make(map[string]int, len(a))
	for _, s := range a {
		counts[s]++
	}
	// Consume one occurrence per element of b. Since len(a) == len(b), never
	// running out of an element means every count ends at exactly zero.
	for _, s := range b {
		if counts[s] == 0 {
			return false
		}
		counts[s]--
	}
	return true
}
// datumEqualsChecker implements check.Checker for comparing types.Datum
// values; it is exposed through DatumEquals.
type datumEqualsChecker struct {
	*check.CheckerInfo
}

// DatumEquals checker verifies that the obtained value is equal to
// the expected value.
// For example:
//     c.Assert(value, DatumEquals, NewDatum(42))
var DatumEquals check.Checker = &datumEqualsChecker{
	CheckerInfo: &check.CheckerInfo{
		Name:   "DatumEquals",
		Params: []string{"obtained", "expected"},
	},
}

// Check reports whether params[0] and params[1], both types.Datum, compare
// equal using a zero-value StatementContext. A non-Datum argument, a compare
// error, or any other panic raised while comparing yields a failed result
// whose message is the panic value; such panics are also logged with a stack
// trace.
func (checker *datumEqualsChecker) Check(params []interface{}, names []string) (result bool, error string) {
	// Every failure path below panics; this recover turns the panic into a
	// (false, message) result instead of crashing the test binary.
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		result, error = false, fmt.Sprint(r)
		logutil.BgLogger().Error("panic in datumEqualsChecker.Check",
			zap.Reflect("r", r),
			zap.Stack("stack trace"))
	}()

	obtained, isDatum := params[0].(types.Datum)
	if !isDatum {
		panic("the first param should be datum")
	}
	expected, isDatum := params[1].(types.Datum)
	if !isDatum {
		panic("the second param should be datum")
	}
	cmp, err := obtained.CompareDatum(new(stmtctx.StatementContext), &expected)
	if err != nil {
		panic(err)
	}
	return cmp == 0, ""
}
// MustNewCommonHandle creates a common handle from the given values. The
// values are converted with types.MakeDatums and key-encoded with a
// zero-value StatementContext; any encoding or construction error fails the
// test through c.Assert.
func MustNewCommonHandle(c *check.C, values ...interface{}) kv.Handle {
	datums := types.MakeDatums(values...)
	sc := new(stmtctx.StatementContext)
	encoded, err := codec.EncodeKey(sc, nil, datums...)
	c.Assert(err, check.IsNil)

	handle, err := kv.NewCommonHandle(encoded)
	c.Assert(err, check.IsNil)
	return handle
}
// CommonHandleSuite is used to adapt kv.CommonHandle to existing kv.IntHandle tests.
// Usage:
//  type MyTestSuite struct {
//      CommonHandleSuite
//  }
//  func (s *MyTestSuite) TestSomething(c *C) {
//      // ...
//      s.RerunWithCommonHandleEnabled(c, s.TestSomething)
//  }
type CommonHandleSuite struct {
	// IsCommonHandle selects which kind of handle NewHandle builds: common
	// handles when true, int handles when false.
	IsCommonHandle bool
}

// RerunWithCommonHandleEnabled runs f with IsCommonHandle set to true.
// It does nothing when IsCommonHandle is already true. This is what stops the
// rerun of a test from triggering yet another rerun when the test calls this
// method on itself, as in the usage example above the type.
//
// IsCommonHandle is restored to false however f ends: normal return, panic,
// or an early stop caused by a check.C failure. gocheck's FailNow (used by a
// failed c.Assert) ends the test goroutine via runtime.Goexit, which still
// runs deferred calls. Without the defer, a failing rerun would leave the
// flag stuck at true for later tests run on the same suite value.
func (chs *CommonHandleSuite) RerunWithCommonHandleEnabled(c *check.C, f func(*check.C)) {
	if chs.IsCommonHandle {
		return
	}
	chs.IsCommonHandle = true
	defer func() { chs.IsCommonHandle = false }()
	f(c)
}

// NewHandle creates a handle builder that produces a common handle when
// IsCommonHandle is true and an int handle otherwise. The mode is captured
// when NewHandle is called.
func (chs *CommonHandleSuite) NewHandle() *commonHandleSuiteNewHandleBuilder {
	return &commonHandleSuiteNewHandleBuilder{isCommon: chs.IsCommonHandle}
}
// commonHandleSuiteNewHandleBuilder builds either an int handle or a common
// handle, depending on isCommon. isCommon is set when the builder is created
// by CommonHandleSuite.NewHandle.
type commonHandleSuiteNewHandleBuilder struct {
	isCommon   bool
	intVal     int64
	commonVals []interface{}
}

// Int records the value used when an int handle is built and returns the
// builder so a Common or Build call can be chained.
func (b *commonHandleSuiteNewHandleBuilder) Int(v int64) *commonHandleSuiteNewHandleBuilder {
	b.intVal = v
	return b
}

// Common records the values used when a common handle is built, then builds
// and returns the handle immediately.
func (b *commonHandleSuiteNewHandleBuilder) Common(vs ...interface{}) kv.Handle {
	b.commonVals = vs
	return b.Build()
}

// Build returns kv.IntHandle(intVal) when isCommon is false. Otherwise it
// key-encodes commonVals and wraps the result in a common handle, panicking on
// any encoding or construction error.
func (b *commonHandleSuiteNewHandleBuilder) Build() kv.Handle {
	if !b.isCommon {
		return kv.IntHandle(b.intVal)
	}
	datums := types.MakeDatums(b.commonVals...)
	encoded, err := codec.EncodeKey(new(stmtctx.StatementContext), nil, datums...)
	if err != nil {
		panic(err)
	}
	handle, err := kv.NewCommonHandle(encoded)
	if err != nil {
		panic(err)
	}
	return handle
}
// handleEqualsChecker implements check.Checker for comparing kv.Handle
// values; it is exposed through HandleEquals.
type handleEqualsChecker struct {
	*check.CheckerInfo
}

// HandleEquals checker verifies that the obtained handle is equal to
// the expected handle.
// For example:
//     c.Assert(value, HandleEquals, kv.IntHandle(42))
var HandleEquals = &handleEqualsChecker{
	&check.CheckerInfo{Name: "HandleEquals", Params: []string{"obtained", "expected"}},
}

// Compile-time check that handleEqualsChecker satisfies check.Checker.
var _ check.Checker = (*handleEqualsChecker)(nil)

// Check reports whether params[0] and params[1] are equal handles.
// Two nil arguments are equal. Otherwise both must be kv.Handle values of the
// same kind (both int, or both non-int), and they are compared by their
// String() form. A mismatch in argument type or handle kind produces a failed
// result with an explanatory message.
func (checker *handleEqualsChecker) Check(params []interface{}, names []string) (result bool, error string) {
	if params[0] == nil && params[1] == nil {
		return true, ""
	}
	obtained, ok1 := params[0].(kv.Handle)
	expected, ok2 := params[1].(kv.Handle)
	if !ok1 || !ok2 {
		return false, "Argument to " + checker.Name + " must be kv.Handle"
	}
	if obtained.IsInt() != expected.IsInt() {
		return false, "Two handle types arguments to " + checker.Name + " must be same"
	}
	return obtained.String() == expected.String(), ""
}
// RowsWithSep is a convenient function to wrap args to a slice of []interface.
// Each arg is one row. It is split on sep, and every resulting field becomes a
// string element of that row. The result has exactly one row per arg.
func RowsWithSep(sep string, args ...string) [][]interface{} {
	rows := make([][]interface{}, 0, len(args))
	for _, line := range args {
		fields := strings.Split(line, sep)
		row := make([]interface{}, 0, len(fields))
		for _, field := range fields {
			row = append(row, field)
		}
		rows = append(rows, row)
	}
	return rows
}
// record is set by the -record command-line flag. When true, the TestData
// helpers in this package produce "<suite>_out.json" from the results the
// tests compute, instead of reading expected results from it.
var record bool

func init() {
	const usage = "to generate test result"
	flag.BoolVar(&record, "record", false, usage)
}
// testCases holds the cases of one test function as stored in a suite's
// "_in.json" or "_out.json" file.
type testCases struct {
	Name       string
	Cases      *json.RawMessage // Left raw; decoded once the caller supplies a target.
	decodedOut interface{}      // Output target filled by the test; encoded when recording.
}

// TestData stores all the data of a test suite.
type TestData struct {
	input          []testCases
	output         []testCases
	filePathPrefix string         // "<dir>/<suiteName>"; "_in.json"/"_out.json" is appended.
	funcMap        map[string]int // Test function name -> index into input and output.
}
// LoadTestSuiteData loads test suite data from file.
//
// The inputs are read from "<dir>/<suiteName>_in.json".
//
// In record mode (the -record flag), the outputs are created empty, with their
// names copied from the inputs; the JSON itself is written later by
// GenerateOutputIfNeeded. Otherwise the outputs are read from
// "<dir>/<suiteName>_out.json".
//
// An error is returned in any of these cases:
//   - either file cannot be loaded;
//   - the files contain a different number of test functions;
//   - the names at the same position differ.
//
// res is returned as filled so far even on error.
func LoadTestSuiteData(dir, suiteName string) (res TestData, err error) {
	res.filePathPrefix = filepath.Join(dir, suiteName)
	res.input, err = loadTestSuiteCases(res.filePathPrefix + "_in.json")
	if err != nil {
		return res, err
	}
	if record {
		res.output = make([]testCases, len(res.input))
		for i := range res.input {
			res.output[i].Name = res.input[i].Name
		}
	} else {
		res.output, err = loadTestSuiteCases(res.filePathPrefix + "_out.json")
		if err != nil {
			return res, err
		}
		if len(res.input) != len(res.output) {
			return res, errors.Errorf("Number of test input cases %d does not match test output cases %d", len(res.input), len(res.output))
		}
	}
	res.funcMap = make(map[string]int, len(res.input))
	for i, test := range res.input {
		// If a name appears more than once, the last occurrence wins.
		res.funcMap[test.Name] = i
		if test.Name != res.output[i].Name {
			return res, errors.Errorf("Input name of the %d-case %s does not match output %s", i, test.Name, res.output[i].Name)
		}
	}
	return res, nil
}
// jsonLineCommentRe matches a "//" comment up to, but not including, the end of
// its line. It also matches a comment on the last line of a file that has no
// trailing newline. It is compiled once at package init rather than on every
// load.
// NOTE(review): the pattern is not string-aware, so a "//" inside a JSON
// string value is also treated as the start of a comment.
var jsonLineCommentRe = regexp.MustCompile(`//[^\n]*`)

// loadTestSuiteCases reads filePath and strips "//" line comments, which plain
// JSON does not allow. It then decodes the remaining content as a JSON array
// of testCases.
func loadTestSuiteCases(filePath string) (res []testCases, err error) {
	byteValue, err := os.ReadFile(filePath)
	if err != nil {
		return res, err
	}
	err = json.Unmarshal(jsonLineCommentRe.ReplaceAll(byteValue, nil), &res)
	return res, err
}
// GetTestCasesByName gets the test cases for a test function by its name.
// It decodes the registered input cases into in and, unless recording, the
// expected results into out; in and out must be pointers json.Unmarshal can
// decode into.
//
// In record mode, *in must have a length (e.g. a slice). If out points to a
// slice, that slice is replaced with one of the same length, holding zero
// values, so the test can fill one result per input.
//
// out is stored on the matching output entry in every mode.
func (t *TestData) GetTestCasesByName(caseName string, c *check.C, in interface{}, out interface{}) {
	idx, found := t.funcMap[caseName]
	c.Assert(found, check.IsTrue, check.Commentf("Must get test %s", caseName))

	c.Assert(json.Unmarshal(*t.input[idx].Cases, in), check.IsNil)

	if record {
		// Init for generate output file.
		n := reflect.ValueOf(in).Elem().Len()
		if target := reflect.ValueOf(out).Elem(); target.Kind() == reflect.Slice {
			target.Set(reflect.MakeSlice(target.Type(), n, n))
		}
	} else {
		c.Assert(json.Unmarshal(*t.output[idx].Cases, out), check.IsNil)
	}
	t.output[idx].decodedOut = out
}
// GetTestCases gets the test cases for a test function. The case name is taken
// from the function that called GetTestCases: it is the text after the last
// "." of that function's fully qualified name, so a suite method
// (*testSuite).TestFoo uses "TestFoo". Decoding of in and out is delegated to
// GetTestCasesByName.
func (t *TestData) GetTestCases(c *check.C, in interface{}, out interface{}) {
	// Extract caller's name. Skip 1 frame: the function that called us.
	pc, _, _, ok := runtime.Caller(1)
	c.Assert(ok, check.IsTrue)
	details := runtime.FuncForPC(pc)
	// FuncForPC returns nil when pc is not inside a known function. Fail the
	// assertion instead of panicking on a nil *runtime.Func.
	c.Assert(details, check.NotNil)
	fullName := details.Name()
	funcName := fullName[strings.LastIndex(fullName, ".")+1:]
	t.GetTestCasesByName(funcName, c, in, out)
}
// OnRecord executes updateFunc to update results, but only in record mode
// (the -record flag); otherwise it does nothing.
func (t *TestData) OnRecord(updateFunc func()) {
	if !record {
		return
	}
	updateFunc()
}
// ConvertRowsToStrings converts [][]interface{} to []string.
// Each row is formatted with %v, and the enclosing "[" and "]" that fmt prints
// for a slice are dropped, so a row {1, "a"} becomes "1 a". It returns nil for
// empty input.
func (t *TestData) ConvertRowsToStrings(rows [][]interface{}) (rs []string) {
	for i := range rows {
		formatted := fmt.Sprintf("%v", rows[i])
		rs = append(rs, formatted[1:len(formatted)-1])
	}
	return rs
}
// ConvertSQLWarnToStrings converts []SQLWarn to []string, with one entry per
// warning taken from warn.Err.Error(). It returns nil for empty input.
// Every warning is expected to carry a non-nil Err; a nil Err would panic.
func (t *TestData) ConvertSQLWarnToStrings(warns []stmtctx.SQLWarn) (rs []string) {
	for _, warn := range warns {
		// Error() already returns a string; no fmt formatting is needed.
		rs = append(rs, warn.Err.Error())
	}
	return rs
}
// GenerateOutputIfNeeded writes the recorded outputs to
// "<filePathPrefix>_out.json" in record mode (the -record flag); otherwise it
// does nothing and returns nil.
//
// Each test function's decodedOut is first encoded into its Cases field. The
// whole output list is then written as indented JSON with HTML escaping
// disabled.
//
// Any encoding, write or close error is returned.
func (t *TestData) GenerateOutputIfNeeded() error {
	if !record {
		return nil
	}
	buf := new(bytes.Buffer)
	enc := json.NewEncoder(buf)
	enc.SetEscapeHTML(false)
	enc.SetIndent("", " ")
	for i, test := range t.output {
		if err := enc.Encode(test.decodedOut); err != nil {
			return err
		}
		// Copy before Reset: buf's backing array is reused by the next
		// Encode, so the RawMessage must not alias it.
		res := make([]byte, buf.Len())
		copy(res, buf.Bytes())
		buf.Reset()
		rm := json.RawMessage(res)
		t.output[i].Cases = &rm
	}
	if err := enc.Encode(t.output); err != nil {
		return err
	}
	// os.WriteFile creates/truncates the file like os.Create. Unlike a bare
	// deferred Close on an unnamed-result function, it also returns the error
	// from closing the file, which can report a failed write.
	return os.WriteFile(t.filePathPrefix+"_out.json", buf.Bytes(), 0666)
}
// MaskSortHandles keeps the lowest (typeBits - 1 - shardBitsCount) bits of each
// handle and returns those masked values in ascending order. typeBits is the
// mysql.DefaultLengthOfMysqlTypes length for fieldType times 8.
//
// The returned values are the masked values, not the original handles; the
// input slice is not modified. The masking uses an arithmetic right shift on
// int64, so a value whose highest kept bit is set comes out negative
// (sign-extended).
func MaskSortHandles(handles []int64, shardBitsCount int, fieldType byte) []int64 {
	const signBitCount = 1
	typeBits := mysql.DefaultLengthOfMysqlTypes[fieldType] * 8
	// Shifting left and back by this amount discards the int64 bits above the
	// type width plus the shard and sign bits.
	shift := 64 - typeBits + shardBitsCount + signBitCount
	masked := make([]int64, 0, len(handles))
	for _, h := range handles {
		masked = append(masked, (h<<shift)>>shift)
	}
	sort.Slice(masked, func(a, b int) bool { return masked[a] < masked[b] })
	return masked
}