252 lines
6.8 KiB
Go
252 lines
6.8 KiB
Go
// Copyright 2021 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
//go:build !codes
|
|
// +build !codes
|
|
|
|
package testdata
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"regexp"
|
|
"runtime"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/pingcap/tidb/sessionctx/stmtctx"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// record is a flag used to generate test results.
|
|
// record is bound to the "-record" command-line flag in init and defaults to
// false. When it is true, the loaders prepare empty outputs instead of decoding
// the stored ones, and generateOutputIfNeeded rewrites the "_out.json" file.
var record bool
|
|
|
|
func init() {
|
|
flag.BoolVar(&record, "record", false, "to generate test result")
|
|
}
|
|
|
|
// testCases is one named entry of a suite's "_in.json" / "_out.json" file.
type testCases struct {
	// Name is the key used to look the entry up (see TestData.funcMap).
	Name string
	// Cases holds the raw, still-encoded JSON payload; decoding is deferred
	// until a test asks for it.
	Cases *json.RawMessage
	// decodedOut is the output value handed in by the test. It is unexported,
	// so encoding/json ignores it; it is only used to regenerate the output
	// file in record mode.
	decodedOut interface{}
}
|
|
|
|
// TestData stores all the data of a test suite.
|
|
type TestData struct {
|
|
input []testCases
|
|
output []testCases
|
|
filePathPrefix string
|
|
funcMap map[string]int
|
|
}
|
|
|
|
func loadTestSuiteData(dir, suiteName string) (res TestData, err error) {
|
|
res.filePathPrefix = filepath.Join(dir, suiteName)
|
|
res.input, err = loadTestSuiteCases(fmt.Sprintf("%s_in.json", res.filePathPrefix))
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
// Load all test cases result in order to keep the unrelated test results.
|
|
res.output, err = loadTestSuiteCases(fmt.Sprintf("%s_out.json", res.filePathPrefix))
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
if len(res.input) != len(res.output) {
|
|
return res, fmt.Errorf("Number of test input cases %d does not match test output cases %d", len(res.input), len(res.output))
|
|
}
|
|
res.funcMap = make(map[string]int, len(res.input))
|
|
for i, test := range res.input {
|
|
res.funcMap[test.Name] = i
|
|
if test.Name != res.output[i].Name {
|
|
return res, fmt.Errorf("Input name of the %d-case %s does not match output %s", i, test.Name, res.output[i].Name)
|
|
}
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func loadTestSuiteCases(filePath string) (res []testCases, err error) {
|
|
//nolint: gosec
|
|
jsonFile, err := os.Open(filePath)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
defer func() {
|
|
if err1 := jsonFile.Close(); err == nil && err1 != nil {
|
|
err = err1
|
|
}
|
|
}()
|
|
byteValue, err := io.ReadAll(jsonFile)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
// Remove comments, since they are not allowed in json.
|
|
re := regexp.MustCompile("(?s)//.*?\n")
|
|
err = json.Unmarshal(re.ReplaceAll(byteValue, nil), &res)
|
|
return res, err
|
|
}
|
|
|
|
// OnRecord execute the function to update result.
|
|
func OnRecord(updateFunc func()) {
|
|
if record {
|
|
updateFunc()
|
|
}
|
|
}
|
|
|
|
// ConvertRowsToStrings renders every row as a single string: the row is
// formatted with the "%v" verb and the surrounding square brackets are removed,
// leaving the cells separated by single spaces. It returns nil when rows is
// empty.
func ConvertRowsToStrings(rows [][]interface{}) (rs []string) {
	for idx := range rows {
		// "%v" on a slice always yields "[...]", so there are at least two
		// bytes to cut off.
		bracketed := fmt.Sprintf("%v", rows[idx])
		inner := bracketed[1 : len(bracketed)-1]
		rs = append(rs, inner)
	}
	return rs
}
|
|
|
|
// ConvertSQLWarnToStrings converts []SQLWarn to []string.
|
|
func ConvertSQLWarnToStrings(warns []stmtctx.SQLWarn) (rs []string) {
|
|
for _, warn := range warns {
|
|
rs = append(rs, fmt.Sprint(warn.Err.Error()))
|
|
}
|
|
return rs
|
|
}
|
|
|
|
// LoadTestCases Loads the test cases for a test function.
|
|
func (td *TestData) LoadTestCases(t *testing.T, in interface{}, out interface{}) {
|
|
// Extract caller's name.
|
|
pc, _, _, ok := runtime.Caller(1)
|
|
require.True(t, ok)
|
|
details := runtime.FuncForPC(pc)
|
|
funcNameIdx := strings.LastIndex(details.Name(), ".")
|
|
funcName := details.Name()[funcNameIdx+1:]
|
|
|
|
casesIdx, ok := td.funcMap[funcName]
|
|
require.Truef(t, ok, "Must get test %s", funcName)
|
|
err := json.Unmarshal(*td.input[casesIdx].Cases, in)
|
|
require.NoError(t, err)
|
|
if !record {
|
|
err = json.Unmarshal(*td.output[casesIdx].Cases, out)
|
|
require.NoError(t, err)
|
|
} else {
|
|
// Init for generate output file.
|
|
inputLen := reflect.ValueOf(in).Elem().Len()
|
|
v := reflect.ValueOf(out).Elem()
|
|
if v.Kind() == reflect.Slice {
|
|
v.Set(reflect.MakeSlice(v.Type(), inputLen, inputLen))
|
|
}
|
|
}
|
|
td.output[casesIdx].decodedOut = out
|
|
}
|
|
|
|
// LoadTestCasesByName loads the test cases for a test function by its name.
|
|
func (td *TestData) LoadTestCasesByName(caseName string, t *testing.T, in interface{}, out interface{}) {
|
|
casesIdx, ok := td.funcMap[caseName]
|
|
require.Truef(t, ok, "Case name: %s", caseName)
|
|
require.NoError(t, json.Unmarshal(*td.input[casesIdx].Cases, in))
|
|
|
|
if Record() {
|
|
inputLen := reflect.ValueOf(in).Elem().Len()
|
|
v := reflect.ValueOf(out).Elem()
|
|
if v.Kind() == reflect.Slice {
|
|
v.Set(reflect.MakeSlice(v.Type(), inputLen, inputLen))
|
|
}
|
|
} else {
|
|
require.NoError(t, json.Unmarshal(*td.output[casesIdx].Cases, out))
|
|
}
|
|
|
|
td.output[casesIdx].decodedOut = out
|
|
}
|
|
|
|
func (td *TestData) generateOutputIfNeeded() error {
|
|
if !record {
|
|
return nil
|
|
}
|
|
|
|
buf := new(bytes.Buffer)
|
|
enc := json.NewEncoder(buf)
|
|
enc.SetEscapeHTML(false)
|
|
enc.SetIndent("", " ")
|
|
isRecord4ThisSuite := false
|
|
for i, test := range td.output {
|
|
if test.decodedOut != nil {
|
|
// Only update the results for the related test cases.
|
|
isRecord4ThisSuite = true
|
|
err := enc.Encode(test.decodedOut)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
res := make([]byte, len(buf.Bytes()))
|
|
copy(res, buf.Bytes())
|
|
buf.Reset()
|
|
rm := json.RawMessage(res)
|
|
td.output[i].Cases = &rm
|
|
}
|
|
}
|
|
// Skip the record for the unrelated test files.
|
|
if !isRecord4ThisSuite {
|
|
return nil
|
|
}
|
|
err := enc.Encode(td.output)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
file, err := os.Create(fmt.Sprintf("%s_out.json", td.filePathPrefix))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if err1 := file.Close(); err == nil && err1 != nil {
|
|
err = err1
|
|
}
|
|
}()
|
|
_, err = file.Write(buf.Bytes())
|
|
return err
|
|
}
|
|
|
|
// BookKeeper does TestData suite bookkeeping.
|
|
// The map is keyed by the suite name passed to LoadTestSuiteData; each value is
// the TestData loaded for that suite.
type BookKeeper map[string]TestData
|
|
|
|
// LoadTestSuiteData loads test suite data from file and bookkeeping in the map.
|
|
func (m *BookKeeper) LoadTestSuiteData(dir, suiteName string) {
|
|
testData, err := loadTestSuiteData(dir, suiteName)
|
|
if err != nil {
|
|
_, _ = fmt.Fprintf(os.Stderr, "testdata: Errors on loading test data from file: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
(*m)[suiteName] = testData
|
|
}
|
|
|
|
// GenerateOutputIfNeeded generate the output file from data bookkeeping in the map.
|
|
func (m *BookKeeper) GenerateOutputIfNeeded() {
|
|
for _, testData := range *m {
|
|
err := testData.generateOutputIfNeeded()
|
|
if err != nil {
|
|
_, _ = fmt.Fprintf(os.Stderr, "testdata: Errors on generating output: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Record is a temporary method for testutil to avoid "flag redefined: record" error,
|
|
// After we migrate all tests based on former testdata, we should remove testutil and this method.
|
|
func Record() bool {
|
|
return record
|
|
}
|