294 lines
7.9 KiB
Go
294 lines
7.9 KiB
Go
// Copyright 2015 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package expression
|
|
|
|
import (
|
|
"reflect"
|
|
"slices"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pingcap/errors"
|
|
"github.com/pingcap/tidb/pkg/parser/ast"
|
|
"github.com/pingcap/tidb/pkg/parser/charset"
|
|
"github.com/pingcap/tidb/pkg/parser/mysql"
|
|
"github.com/pingcap/tidb/pkg/types"
|
|
"github.com/pingcap/tidb/pkg/util"
|
|
"github.com/pingcap/tidb/pkg/util/chunk"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func evalBuiltinFuncConcurrent(f builtinFunc, ctx EvalContext, row chunk.Row) (d types.Datum, err error) {
|
|
var wg util.WaitGroupWrapper
|
|
concurrency := 10
|
|
var lock sync.Mutex
|
|
err = nil
|
|
for range concurrency {
|
|
wg.Run(func() {
|
|
di, erri := evalBuiltinFunc(f, ctx, chunk.Row{})
|
|
lock.Lock()
|
|
if err == nil {
|
|
d, err = di, erri
|
|
}
|
|
lock.Unlock()
|
|
})
|
|
}
|
|
wg.Wait()
|
|
return
|
|
}
|
|
|
|
func evalBuiltinFunc(f builtinFunc, ctx EvalContext, row chunk.Row) (d types.Datum, err error) {
|
|
ctx = wrapEvalAssert(ctx, f)
|
|
var (
|
|
res any
|
|
isNull bool
|
|
)
|
|
switch f.getRetTp().EvalType() {
|
|
case types.ETInt:
|
|
var intRes int64
|
|
intRes, isNull, err = f.evalInt(ctx, row)
|
|
if mysql.HasUnsignedFlag(f.getRetTp().GetFlag()) {
|
|
res = uint64(intRes)
|
|
} else {
|
|
res = intRes
|
|
}
|
|
case types.ETReal:
|
|
res, isNull, err = f.evalReal(ctx, row)
|
|
case types.ETDecimal:
|
|
res, isNull, err = f.evalDecimal(ctx, row)
|
|
case types.ETDatetime, types.ETTimestamp:
|
|
res, isNull, err = f.evalTime(ctx, row)
|
|
case types.ETDuration:
|
|
res, isNull, err = f.evalDuration(ctx, row)
|
|
case types.ETJson:
|
|
res, isNull, err = f.evalJSON(ctx, row)
|
|
case types.ETString:
|
|
res, isNull, err = f.evalString(ctx, row)
|
|
}
|
|
|
|
d.SetValue(res, f.getRetTp())
|
|
if isNull {
|
|
d.SetNull()
|
|
return d, err
|
|
}
|
|
return
|
|
}
|
|
|
|
// tblToDtbl is a utility function for test.
|
|
func tblToDtbl(i any) []map[string][]types.Datum {
|
|
l := reflect.ValueOf(i).Len()
|
|
tbl := make([]map[string][]types.Datum, l)
|
|
for j := range l {
|
|
v := reflect.ValueOf(i).Index(j).Interface()
|
|
val := reflect.ValueOf(v)
|
|
t := reflect.TypeOf(v)
|
|
item := make(map[string][]types.Datum, val.NumField())
|
|
for k := range val.NumField() {
|
|
tmp := val.Field(k).Interface()
|
|
item[t.Field(k).Name] = makeDatums(tmp)
|
|
}
|
|
tbl[j] = item
|
|
}
|
|
return tbl
|
|
}
|
|
|
|
func makeDatums(i any) []types.Datum {
|
|
if i != nil {
|
|
t := reflect.TypeOf(i)
|
|
val := reflect.ValueOf(i)
|
|
switch t.Kind() {
|
|
case reflect.Slice:
|
|
l := val.Len()
|
|
res := make([]types.Datum, l)
|
|
for j := range l {
|
|
res[j] = types.NewDatum(val.Index(j).Interface())
|
|
}
|
|
return res
|
|
}
|
|
}
|
|
return types.MakeDatums(i)
|
|
}
|
|
|
|
func TestIsNullFunc(t *testing.T) {
|
|
ctx := createContext(t)
|
|
fc := funcs[ast.IsNull]
|
|
f, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(1)))
|
|
require.NoError(t, err)
|
|
v, err := evalBuiltinFunc(f, ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(0), v.GetInt64())
|
|
|
|
f, err = fc.getFunction(ctx, datumsToConstants(types.MakeDatums(nil)))
|
|
require.NoError(t, err)
|
|
v, err = evalBuiltinFunc(f, ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(1), v.GetInt64())
|
|
}
|
|
|
|
func TestLock(t *testing.T) {
|
|
ctx := createContext(t)
|
|
lock := funcs[ast.GetLock]
|
|
f, err := lock.getFunction(ctx, datumsToConstants(types.MakeDatums("mylock", 1)))
|
|
require.NoError(t, err)
|
|
v, err := evalBuiltinFunc(f, ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(1), v.GetInt64())
|
|
|
|
releaseLock := funcs[ast.ReleaseLock]
|
|
f, err = releaseLock.getFunction(ctx, datumsToConstants(types.MakeDatums("mylock")))
|
|
require.NoError(t, err)
|
|
v, err = evalBuiltinFunc(f, ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(1), v.GetInt64())
|
|
}
|
|
|
|
func TestDisplayName(t *testing.T) {
|
|
require.Equal(t, "=", GetDisplayName(ast.EQ))
|
|
require.Equal(t, "<=>", GetDisplayName(ast.NullEQ))
|
|
require.Equal(t, "IS TRUE", GetDisplayName(ast.IsTruthWithoutNull))
|
|
require.Equal(t, "abs", GetDisplayName("abs"))
|
|
require.Equal(t, "other_unknown_func", GetDisplayName("other_unknown_func"))
|
|
}
|
|
|
|
func TestBuiltinFuncCacheConcurrency(t *testing.T) {
|
|
cache := builtinFuncCache[int]{}
|
|
ctx := createContext(t)
|
|
|
|
var invoked atomic.Int64
|
|
construct := func() (int, error) {
|
|
invoked.Add(1)
|
|
time.Sleep(time.Millisecond)
|
|
return 100 + int(invoked.Load()), nil
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
concurrency := 8
|
|
wg.Add(concurrency)
|
|
for range concurrency {
|
|
go func() {
|
|
defer wg.Done()
|
|
v, err := cache.getOrInitCache(ctx, construct)
|
|
// all goroutines should get the same value
|
|
require.NoError(t, err)
|
|
require.Equal(t, 101, v)
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
// construct will only be called once even in concurrency
|
|
require.Equal(t, int64(1), invoked.Load())
|
|
}
|
|
|
|
func TestBuiltinFuncCache(t *testing.T) {
|
|
cache := builtinFuncCache[int]{}
|
|
ctx := createContext(t)
|
|
|
|
// ok should be false when no cache present
|
|
v, ok := cache.getCache(ctx.GetSessionVars().StmtCtx.CtxID())
|
|
require.Equal(t, 0, v)
|
|
require.False(t, ok)
|
|
|
|
// getCache should not init cache
|
|
v, ok = cache.getCache(ctx.GetSessionVars().StmtCtx.CtxID())
|
|
require.Equal(t, 0, v)
|
|
require.False(t, ok)
|
|
|
|
var invoked atomic.Int64
|
|
returnError := false
|
|
construct := func() (int, error) {
|
|
invoked.Add(1)
|
|
if returnError {
|
|
return 128, errors.New("mockError")
|
|
}
|
|
return 100 + int(invoked.Load()), nil
|
|
}
|
|
|
|
// the first getOrInitCache should init cache
|
|
v, err := cache.getOrInitCache(ctx, construct)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 101, v)
|
|
require.Equal(t, int64(1), invoked.Load())
|
|
|
|
// get should return the cache
|
|
v, ok = cache.getCache(ctx.GetSessionVars().StmtCtx.CtxID())
|
|
require.Equal(t, 101, v)
|
|
require.True(t, ok)
|
|
|
|
// the second should use the cached one
|
|
v, err = cache.getOrInitCache(ctx, construct)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 101, v)
|
|
require.Equal(t, int64(1), invoked.Load())
|
|
|
|
// if ctxID changed, should re-init cache
|
|
ctx = createContext(t)
|
|
v, err = cache.getOrInitCache(ctx, construct)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 102, v)
|
|
require.Equal(t, int64(2), invoked.Load())
|
|
v, ok = cache.getCache(ctx.GetSessionVars().StmtCtx.CtxID())
|
|
require.Equal(t, 102, v)
|
|
require.True(t, ok)
|
|
|
|
// error should be returned
|
|
ctx = createContext(t)
|
|
returnError = true
|
|
v, err = cache.getOrInitCache(ctx, construct)
|
|
require.Equal(t, 0, v)
|
|
require.EqualError(t, err, "mockError")
|
|
|
|
// error should not be cached
|
|
returnError = false
|
|
v, err = cache.getOrInitCache(ctx, construct)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 104, v)
|
|
}
|
|
|
|
// newFunctionForTest creates a new ScalarFunction using funcName and arguments,
|
|
// it is different from expression.NewFunction which needs an additional retType argument.
|
|
func newFunctionForTest(ctx BuildContext, funcName string, args ...Expression) (Expression, error) {
|
|
fc, ok := funcs[funcName]
|
|
if !ok {
|
|
return nil, ErrFunctionNotExists.GenWithStackByArgs("FUNCTION", funcName)
|
|
}
|
|
funcArgs := slices.Clone(args)
|
|
f, err := fc.getFunction(ctx, funcArgs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &ScalarFunction{
|
|
FuncName: ast.NewCIStr(funcName),
|
|
RetType: f.getRetTp(),
|
|
Function: f,
|
|
}, nil
|
|
}
|
|
|
|
var (
|
|
// MySQL int8.
|
|
int8Con = &Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLonglong).SetCharset(charset.CharsetBin).SetCollate(charset.CollationBin).BuildP()}
|
|
// MySQL varchar.
|
|
varcharCon = &Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeVarchar).SetCharset(charset.CharsetUTF8).SetCollate(charset.CollationUTF8).BuildP()}
|
|
)
|
|
|
|
func getInt8Con() Expression {
|
|
return int8Con.Clone()
|
|
}
|
|
|
|
func getVarcharCon() Expression {
|
|
return varcharCon.Clone()
|
|
}
|