Files
tidb/pkg/testkit/testdata/testdata.go
2025-12-10 09:40:20 +00:00

323 lines
9.0 KiB
Go

// Copyright 2021 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !codes
package testdata
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"os"
"path/filepath"
"reflect"
"regexp"
"runtime"
"strings"
"testing"
contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/stretchr/testify/require"
)
// record reports whether the tests run in record mode, i.e. whether the
// result files should be regenerated instead of being loaded.
// Its value comes from the `-record` command-line flag registered in init.
var record bool
// init registers the boolean `-record` flag (default false) on the default
// flag set and binds it to the record variable.
func init() {
flag.BoolVar(&record, "record", false, "to generate test result")
}
// testCases is one entry of a test-suite JSON file: the cases that belong to
// a single named test.
type testCases struct {
// Name identifies the entry; it is the key of TestData.funcMap.
Name string
// Cases keeps the still-encoded JSON of the cases so that it can be decoded
// later into whatever type the test supplies.
Cases *json.RawMessage // Parsed lazily.
// decodedOut is the output value handed in by the test. It is unexported and
// therefore never encoded itself; when it is non-nil in record mode its
// encoding replaces Cases.
decodedOut any // Used to generate the output file.
}
// TestData stores all the data of a test suite.
type TestData struct {
// input holds the entries loaded from "<filePathPrefix>_in.json".
input []testCases
// output holds the entries loaded from "<filePathPrefix>_out.json".
output []testCases
// casout holds the entries loaded from "<filePathPrefix>_xut.json"; it is
// only populated when the suite is loaded with cascades enabled.
casout []testCases
// filePathPrefix is the suite directory joined with the suite name.
filePathPrefix string
// funcMap maps an entry name to its index in input; output (and casout,
// when loaded) use the same index.
funcMap map[string]int
}
// loadTestSuiteData loads the JSON files of the test suite suiteName located
// in dir and checks that they are consistent with each other.
//
// It always reads "<dir>/<suiteName>_in.json" and "<dir>/<suiteName>_out.json".
// When the first value of cascades is true it additionally reads
// "<dir>/<suiteName>_xut.json". All loaded files must contain the same number
// of entries, and the entries at the same position must carry the same name.
//
// On failure the returned TestData is only partially populated and must not
// be used.
func loadTestSuiteData(dir, suiteName string, cascades ...bool) (res TestData, err error) {
	inCascades := len(cascades) > 0 && cascades[0]
	res.filePathPrefix = filepath.Join(dir, suiteName)
	res.input, err = loadTestSuiteCases(fmt.Sprintf("%s_in.json", res.filePathPrefix))
	if err != nil {
		return res, err
	}
	// Load all test cases result in order to keep the unrelated test results.
	res.output, err = loadTestSuiteCases(fmt.Sprintf("%s_out.json", res.filePathPrefix))
	if err != nil {
		return res, err
	}
	if len(res.input) != len(res.output) {
		return res, fmt.Errorf("Number of test input cases %d does not match test output cases %d", len(res.input), len(res.output))
	}
	if inCascades {
		// Load all test cascades result in order to keep the unrelated test results.
		res.casout, err = loadTestSuiteCases(fmt.Sprintf("%s_xut.json", res.filePathPrefix))
		if err != nil {
			return res, err
		}
		// The length of casout must be validated here: the name check below
		// indexes res.casout with the indices of res.input.
		if len(res.input) != len(res.casout) {
			return res, fmt.Errorf("Number of test input cases %d does not match test casout cases %d", len(res.input), len(res.casout))
		}
	}
	res.funcMap = make(map[string]int, len(res.input))
	for i, test := range res.input {
		res.funcMap[test.Name] = i
		if test.Name != res.output[i].Name {
			return res, fmt.Errorf("Input name of the %d-case %s does not match output %s", i, test.Name, res.output[i].Name)
		}
		if inCascades && test.Name != res.casout[i].Name {
			return res, fmt.Errorf("Input name of the %d-case %s does not match casout %s", i, test.Name, res.casout[i].Name)
		}
	}
	return res, nil
}
// commentLinePattern matches a line that consists only of a `//` comment,
// including its terminating newline, so that such lines can be stripped
// before the content is handed to the JSON decoder.
// flag:
// s: let . match \n
// m: let ^/$ match begin/end of a line in addition to begin/end of the entire text
// ^: begin of line
// \s*: zero or more whitespace characters
// .*?: zero or more any characters, but match as few as possible
var commentLinePattern = regexp.MustCompile(`(?sm)^\s*//.*?\n`)

// loadTestSuiteCases reads the file at filePath, removes the comment lines
// and decodes the remainder as a JSON array of testCases.
//
// Comments are not allowed in JSON, so they are removed first. Note that this
// logic is not always correct: only a comment on a line of its own is
// recognized. A comment appended to a line with other content is not removed,
// and neither is a comment on the last line if that line has no trailing
// newline.
//
// An error from opening, reading or closing the file is returned as is; a
// decoding error is wrapped with the file path so that the malformed file can
// be identified.
func loadTestSuiteCases(filePath string) (res []testCases, err error) {
	//nolint: gosec
	jsonFile, err := os.Open(filePath)
	if err != nil {
		return res, err
	}
	// err is the named result, so a failing Close is reported to the caller
	// unless an earlier error is already being returned.
	defer func() {
		if err1 := jsonFile.Close(); err == nil && err1 != nil {
			err = err1
		}
	}()
	byteValue, err := io.ReadAll(jsonFile)
	if err != nil {
		return res, err
	}
	if err = json.Unmarshal(commentLinePattern.ReplaceAll(byteValue, nil), &res); err != nil {
		return res, fmt.Errorf("parse test suite file %s: %w", filePath, err)
	}
	return res, nil
}
// OnRecord calls updateFunc when the tests run in record mode and does
// nothing otherwise.
func OnRecord(updateFunc func()) {
	if !record {
		return
	}
	updateFunc()
}
// ConvertRowsToStrings turns every row into one string: the row's values in
// their default fmt format, separated by single spaces.
// It returns nil when rows is empty.
func ConvertRowsToStrings(rows [][]any) (rs []string) {
	for idx := range rows {
		text := fmt.Sprintf("%v", rows[idx])
		// fmt renders a slice as "[v1 v2 ...]"; keep only what is between
		// the outermost brackets.
		last := len(text) - 1
		rs = append(rs, text[1:last])
	}
	return rs
}
// ConvertSQLWarnToStrings collects the error message of every warning, in
// order. It returns nil when warns is empty.
func ConvertSQLWarnToStrings(warns []contextutil.SQLWarn) (rs []string) {
	for i := range warns {
		rs = append(rs, warns[i].Err.Error())
	}
	return rs
}
// LoadTestCases loads the test cases of one test function into in and out.
//
// Without opts, the test is looked up by the unqualified name of the function
// that calls LoadTestCases. With opts, both values are required:
// opts[0] is the cascades switch ("on" selects the cascades results, any
// other value the regular ones) and opts[1] is the test name to look up,
// since cascades is called in a macro wrapper and the caller name is not the
// test name.
//
// in and out must be pointers. The input cases are always decoded into in.
// Outside record mode the stored results are decoded into out. In record
// mode, if out points to a slice, it is replaced by a zero-valued slice with
// as many elements as in. In both modes out is remembered so that it can be
// written back when the output file is generated.
func (td *TestData) LoadTestCases(t *testing.T, in any, out any, opts ...string) {
	hasOpts := len(opts) > 0
	inCascades := hasOpts && opts[0] == "on"

	var funcName string
	if hasOpts {
		// since cascades is called in macro wrapper, use the passed caller name
		require.True(t, len(opts) > 1)
		funcName = opts[1]
	} else {
		// Extract caller's name.
		pc, _, _, ok := runtime.Caller(1)
		require.True(t, ok)
		qualified := runtime.FuncForPC(pc).Name()
		funcName = qualified[strings.LastIndex(qualified, ".")+1:]
	}

	casesIdx, found := td.funcMap[funcName]
	require.Truef(t, found, "Must get test %s", funcName)
	require.NoError(t, json.Unmarshal(*td.input[casesIdx].Cases, in))

	// The slice header is copied, but the elements are shared with td, so
	// writing through results updates the TestData itself.
	results := td.output
	if inCascades {
		results = td.casout
	}

	if record {
		// Init for generate output file.
		inputLen := reflect.ValueOf(in).Elem().Len()
		if target := reflect.ValueOf(out).Elem(); target.Kind() == reflect.Slice {
			target.Set(reflect.MakeSlice(target.Type(), inputLen, inputLen))
		}
	} else {
		require.NoError(t, json.Unmarshal(*results[casesIdx].Cases, out))
	}
	results[casesIdx].decodedOut = out
}
// LoadTestCasesByName loads the test cases registered under caseName into in
// and out.
//
// opts[0] == "on" selects the cascades results; otherwise the regular results
// are used. in and out must be pointers. The input cases are always decoded
// into in. Outside record mode the stored results are decoded into out. In
// record mode, if out points to a slice, it is replaced by a zero-valued
// slice with as many elements as in. In both modes out is remembered so that
// it can be written back when the output file is generated.
func (td *TestData) LoadTestCasesByName(caseName string, t *testing.T, in any, out any, opts ...string) {
	casesIdx, found := td.funcMap[caseName]
	require.Truef(t, found, "Case name: %s", caseName)
	require.NoError(t, json.Unmarshal(*td.input[casesIdx].Cases, in))

	// The slice header is copied, but the elements are shared with td, so
	// writing through results updates the TestData itself.
	results := td.output
	if len(opts) > 0 && opts[0] == "on" {
		results = td.casout
	}

	switch {
	case Record():
		inputLen := reflect.ValueOf(in).Elem().Len()
		if target := reflect.ValueOf(out).Elem(); target.Kind() == reflect.Slice {
			target.Set(reflect.MakeSlice(target.Type(), inputLen, inputLen))
		}
	default:
		require.NoError(t, json.Unmarshal(*results[casesIdx].Cases, out))
	}
	results[casesIdx].decodedOut = out
}
// generateOutputIfNeeded rewrites the result files of the suite when the
// tests run in record mode; otherwise it does nothing.
//
// The regular results go to "<filePathPrefix>_out.json" and the cascades
// results to "<filePathPrefix>_xut.json". A file is only rewritten when at
// least one of its entries was loaded by a test (its decodedOut is set).
// Entries that were not loaded keep the raw JSON read from the original file.
//
// It returns the first error hit while encoding, creating, writing or closing
// a file.
func (td *TestData) generateOutputIfNeeded() error {
	if !record {
		return nil
	}
	// writeResults regenerates one result file. Its result is named so that
	// the deferred Close below can report a failure to flush the written file
	// instead of losing it.
	writeResults := func(cascades bool) (err error) {
		buf := new(bytes.Buffer)
		enc := json.NewEncoder(buf)
		enc.SetEscapeHTML(false)
		enc.SetIndent("", " ")

		output, suffix := td.output, "_out.json"
		if cascades {
			output, suffix = td.casout, "_xut.json"
		}

		isRecord4ThisSuite := false
		for i := range output {
			if output[i].decodedOut == nil {
				continue
			}
			// Only update the results for the related test cases.
			isRecord4ThisSuite = true
			if err = enc.Encode(output[i].decodedOut); err != nil {
				return err
			}
			// buf is reused for the next entry, so the encoded bytes have to
			// be copied out before it is reset.
			res := make([]byte, buf.Len())
			copy(res, buf.Bytes())
			buf.Reset()
			rm := json.RawMessage(res)
			output[i].Cases = &rm
		}
		// Skip the record for the unrelated test files.
		if !isRecord4ThisSuite {
			return nil
		}
		if err = enc.Encode(output); err != nil {
			return err
		}

		file, createErr := os.Create(td.filePathPrefix + suffix)
		if createErr != nil {
			return createErr
		}
		defer func() {
			if closeErr := file.Close(); err == nil && closeErr != nil {
				err = closeErr
			}
		}()
		_, err = file.Write(buf.Bytes())
		return err
	}
	if err := writeResults(false); err != nil {
		return err
	}
	return writeResults(true)
}
// BookKeeper does TestData suite bookkeeping.
// It maps a suite name to the TestData loaded for that suite: entries are
// added by LoadTestSuiteData and visited by GenerateOutputIfNeeded.
type BookKeeper map[string]TestData
// LoadTestSuiteData loads the data of the test suite suiteName from dir and
// stores it in the map under suiteName. The optional cascades values are
// passed through to the loader unchanged.
//
// If loading fails, the error is printed to standard error and the process
// exits with status 1.
func (m *BookKeeper) LoadTestSuiteData(dir, suiteName string, cascades ...bool) {
	suite, loadErr := loadTestSuiteData(dir, suiteName, cascades...)
	if loadErr == nil {
		(*m)[suiteName] = suite
		return
	}
	_, _ = fmt.Fprintf(os.Stderr, "testdata: Errors on loading test data from file: %v\n", loadErr)
	os.Exit(1)
}
// GenerateOutputIfNeeded asks every suite kept in the map to generate its
// output files.
//
// On the first failure the error is printed to standard error and the process
// exits with status 1.
func (m *BookKeeper) GenerateOutputIfNeeded() {
	for _, suite := range *m {
		if genErr := suite.generateOutputIfNeeded(); genErr != nil {
			_, _ = fmt.Fprintf(os.Stderr, "testdata: Errors on generating output: %v\n", genErr)
			os.Exit(1)
		}
	}
}
// Record reports whether the tests run in record mode, i.e. the value of the
// `-record` flag.
// It is a temporary method for testutil to avoid the "flag redefined: record" error.
// After we migrate all tests based on the former testdata, we should remove testutil and this method.
func Record() bool {
return record
}