// Copyright 2020 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package executor import ( "sync" "testing" "time" "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/domain" "github.com/pingcap/tidb/pkg/executor/internal/exec" "github.com/pingcap/tidb/pkg/extension" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/auth" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/planner/core" "github.com/pingcap/tidb/pkg/types" "github.com/stretchr/testify/require" ) func TestBatchRetrieverHelper(t *testing.T) { rangeStarts := make([]int, 0) rangeEnds := make([]int, 0) collect := func(start, end int) error { rangeStarts = append(rangeStarts, start) rangeEnds = append(rangeEnds, end) return nil } r := &batchRetrieverHelper{} err := r.nextBatch(collect) require.NoError(t, err) require.Equal(t, rangeStarts, []int{}) require.Equal(t, rangeEnds, []int{}) r = &batchRetrieverHelper{ retrieved: true, batchSize: 3, totalRows: 10, } err = r.nextBatch(collect) require.NoError(t, err) require.Equal(t, rangeStarts, []int{}) require.Equal(t, rangeEnds, []int{}) r = &batchRetrieverHelper{ batchSize: 3, totalRows: 10, } err = r.nextBatch(func(start, end int) error { return errors.New("some error") }) require.Error(t, err) require.True(t, r.retrieved) r = &batchRetrieverHelper{ batchSize: 3, totalRows: 10, } for !r.retrieved { err = r.nextBatch(collect) require.NoError(t, err) } require.Equal(t, rangeStarts, []int{0, 3, 6, 9}) require.Equal(t, rangeEnds, []int{3, 6, 9, 10}) rangeStarts = rangeStarts[:0] rangeEnds = rangeEnds[:0] r = &batchRetrieverHelper{ batchSize: 3, totalRows: 9, } for !r.retrieved { err = r.nextBatch(collect) require.NoError(t, err) } require.Equal(t, rangeStarts, []int{0, 3, 6}) require.Equal(t, rangeEnds, []int{3, 6, 9}) rangeStarts = rangeStarts[:0] rangeEnds = rangeEnds[:0] r = &batchRetrieverHelper{ batchSize: 100, totalRows: 10, } for !r.retrieved { err = r.nextBatch(collect) require.NoError(t, err) } require.Equal(t, rangeStarts, []int{0}) require.Equal(t, rangeEnds, []int{10}) } func TestEqualDatumsAsBinary(t *testing.T) { tests := []struct { a []any b []any same bool }{ // Positive cases {[]any{1}, []any{1}, true}, {[]any{1, "aa"}, []any{1, "aa"}, true}, {[]any{1, "aa", 1}, []any{1, "aa", 1}, true}, // negative cases {[]any{1}, []any{2}, false}, {[]any{1, "a"}, []any{1, "aaaaaa"}, false}, {[]any{1, "aa", 3}, []any{1, "aa", 2}, false}, // Corner cases {[]any{}, []any{}, true}, {[]any{nil}, []any{nil}, true}, {[]any{}, []any{1}, false}, {[]any{1}, []any{1, 1}, false}, {[]any{nil}, []any{1}, false}, } ctx := core.MockContext() base := exec.NewBaseExecutor(ctx, nil, 0) defer func() { domain.GetDomain(ctx).StatsHandle().Close() }() e := &InsertValues{BaseExecutor: base} for _, tt := range tests { res, err := e.equalDatumsAsBinary(types.MakeDatums(tt.a...), types.MakeDatums(tt.b...)) require.NoError(t, err) require.Equal(t, tt.same, res) } } func TestEncodePasswordWithPlugin(t *testing.T) { hashString := "*3D56A309CD04FA2EEF181462E59011F075C89548" u := &ast.UserSpec{ User: &auth.UserIdentity{ Username: "test", }, AuthOpt: &ast.AuthOption{ ByAuthString: false, AuthString: "xxx", HashString: hashString, }, } p := &extension.AuthPlugin{ ValidateAuthString: func(s string) bool { return false }, GenerateAuthString: func(s string) (string, bool) { if s == "xxx" { return "xxxxxxx", true } return "", false }, } u.AuthOpt.ByAuthString = false _, ok := encodePasswordWithPlugin(*u, p, "") require.False(t, ok) u.AuthOpt.AuthString = "xxx" u.AuthOpt.ByAuthString = true pwd, ok := encodePasswordWithPlugin(*u, p, "") require.True(t, ok) require.Equal(t, "xxxxxxx", pwd) u.AuthOpt = nil pwd, ok = encodePasswordWithPlugin(*u, p, "") require.True(t, ok) require.Equal(t, "", pwd) } func TestWorkerPool(t *testing.T) { var ( list []int lock sync.Mutex ) push := func(i int) { lock.Lock() list = append(list, i) lock.Unlock() } clean := func() { lock.Lock() list = list[:0] lock.Unlock() } sleep := func(ms int) { time.Sleep(time.Duration(ms) * time.Millisecond) } t.Run("SingleWorker", func(t *testing.T) { clean() pool := &workerPool{ needSpawn: func(workers, tasks uint32) bool { return workers < 1 && tasks > 0 }, } wg := sync.WaitGroup{} wg.Add(1) pool.submit(func() { push(1) wg.Add(1) pool.submit(func() { push(3) sleep(10) push(4) wg.Done() }) sleep(1) push(2) wg.Done() }) wg.Wait() require.Equal(t, []int{1, 2, 3, 4}, list) }) t.Run("TwoWorkers", func(t *testing.T) { clean() pool := &workerPool{ needSpawn: func(workers, tasks uint32) bool { return workers < 2 && tasks > 0 }, } wg := sync.WaitGroup{} wg.Add(1) pool.submit(func() { push(1) wg.Add(1) pool.submit(func() { push(3) sleep(10) push(4) wg.Done() }) sleep(1) push(2) wg.Done() }) wg.Wait() require.Equal(t, []int{1, 3, 2, 4}, list) }) t.Run("TolerateOnePendingTask", func(t *testing.T) { clean() pool := &workerPool{ needSpawn: func(workers, tasks uint32) bool { return workers < 2 && tasks > 1 }, } wg := sync.WaitGroup{} wg.Add(1) pool.submit(func() { push(1) wg.Add(1) pool.submit(func() { push(3) sleep(10) push(4) wg.Done() }) sleep(1) push(2) wg.Done() }) wg.Wait() require.Equal(t, []int{1, 2, 3, 4}, list) }) } func TestEncodedPassword(t *testing.T) { hashString := "*3D56A309CD04FA2EEF181462E59011F075C89548" hashCachingString := "0123456789012345678901234567890123456789012345678901234567890123456789" u := ast.UserSpec{ User: &auth.UserIdentity{ Username: "test", }, AuthOpt: &ast.AuthOption{ ByAuthString: false, AuthString: "xxx", HashString: hashString, }, } pwd, ok := encodedPassword(&u, "") require.True(t, ok) require.Equal(t, u.AuthOpt.HashString, pwd) u.AuthOpt.HashString = "not-good-password-format" _, ok = encodedPassword(&u, "") require.False(t, ok) u.AuthOpt.ByAuthString = true // mysql_native_password pwd, ok = encodedPassword(&u, "") require.True(t, ok) require.Equal(t, hashString, pwd) // caching_sha2_password u.AuthOpt.HashString = hashCachingString pwd, ok = encodedPassword(&u, mysql.AuthCachingSha2Password) require.True(t, ok) require.Len(t, pwd, mysql.SHAPWDHashLen) u.AuthOpt.AuthString = "" pwd, ok = encodedPassword(&u, "") require.True(t, ok) require.Equal(t, "", pwd) }