Files
tidb/pkg/store/driver/client_test.go

111 lines
3.1 KiB
Go

// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package driver
import (
"context"
"testing"
"time"
"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/tidb/pkg/util/tracing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/tikv"
"github.com/tikv/client-go/v2/tikvrpc"
)
// mockTiKVClient is a tikv.Client test double. SendRequest is routed through
// the embedded mock.Mock, so tests script its results with On(...).Return(...)
// and verify calls with AssertExpectations. Every other tikv.Client method is
// promoted from the embedded tikv.Client; when that field is left unset (as in
// TestInjectTracingClient) calling one of them panics on the nil interface.
type mockTiKVClient struct {
	tikv.Client
	mock.Mock
}

// SendRequest records the call with its four arguments on the embedded mock
// and returns the values registered for it. A nil first return value is
// reported as a nil *tikvrpc.Response; any other value must be a
// *tikvrpc.Response, otherwise the type assertion panics. The second return
// value is read via Arguments.Error.
func (c *mockTiKVClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) {
	ret := c.Called(ctx, addr, req, timeout)
	first := ret.Get(0)
	if first == nil {
		return nil, ret.Error(1)
	}
	resp := first.(*tikvrpc.Response)
	return resp, ret.Error(1)
}
// TestInjectTracingClient checks that injectTraceClient forwards requests to the
// wrapped client, and what the wrapped client observes in req.Context.SourceStmt:
//   - with no tracing.TraceInfo in the context, SourceStmt is nil;
//   - with trace info, SourceStmt is non-nil and its ConnectionId and
//     SessionAlias equal the trace's ConnectionID and SessionAlias.
//
// It also checks that the wrapped client's response and error are returned to
// the caller unchanged.
func TestInjectTracingClient(t *testing.T) {
	cases := []struct {
		name  string
		trace *tracing.TraceInfo
		// existSourceStmt, when non-nil, is installed on the request before it is
		// sent, so the case covers a request that already carries a SourceStmt.
		existSourceStmt *kvrpcpb.SourceStmt
	}{
		{
			name:  "trace is nil",
			trace: nil,
		},
		{
			name: "trace not nil",
			trace: &tracing.TraceInfo{
				ConnectionID: 123,
				SessionAlias: "alias123",
			},
		},
		{
			name: "only connection id in trace valid",
			trace: &tracing.TraceInfo{
				ConnectionID: 456,
			},
		},
		{
			name: "only session alias in trace valid and sourceStmt exists",
			trace: &tracing.TraceInfo{
				SessionAlias: "alias456",
			},
			existSourceStmt: &kvrpcpb.SourceStmt{},
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			// A fresh mock per case keeps expectations registered by earlier
			// cases from matching, or being asserted, here.
			cli := &mockTiKVClient{}
			inject := injectTraceClient{Client: cli}

			ctx := context.Background()
			if c.trace != nil {
				ctx = tracing.ContextWithTraceInfo(ctx, c.trace)
			}

			req := &tikvrpc.Request{}
			// Previously this field was declared but never used, so the
			// "sourceStmt exists" case did not actually send such a request.
			req.Context.SourceStmt = c.existSourceStmt

			// Runs inside the mocked SendRequest and inspects the request that
			// the injector passed down to the wrapped client.
			verifySendRequest := func(args mock.Arguments) {
				inj := args.Get(2).(*tikvrpc.Request)
				if c.trace == nil {
					require.Nil(t, inj.Context.SourceStmt)
					return
				}
				require.NotNil(t, inj.Context.SourceStmt)
				require.Equal(t, c.trace.ConnectionID, inj.Context.SourceStmt.ConnectionId)
				require.Equal(t, c.trace.SessionAlias, inj.Context.SourceStmt.SessionAlias)
			}

			// Success path: the wrapped client's response is returned as-is.
			expectedResp := &tikvrpc.Response{}
			cli.On("SendRequest", ctx, "addr1", req, time.Second).Return(expectedResp, nil).Once().Run(verifySendRequest)
			resp, err := inject.SendRequest(ctx, "addr1", req, time.Second)
			cli.AssertExpectations(t)
			require.NoError(t, err)
			require.Same(t, expectedResp, resp)

			// Error path: the wrapped client's error is returned as-is.
			expectedErr := errors.New("mockErr")
			cli.On("SendRequest", ctx, "addr2", req, time.Minute).Return(nil, expectedErr).Once().Run(verifySendRequest)
			resp, err = inject.SendRequest(ctx, "addr2", req, time.Minute)
			cli.AssertExpectations(t)
			require.Same(t, expectedErr, err)
			require.Nil(t, resp)
		})
	}
}