Files
tidb/pkg/expression/builtin_other_test.go

374 lines
12 KiB
Go

// Copyright 2017 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"math"
"testing"
"time"
"github.com/pingcap/tidb/pkg/expression/exprctx"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/collate"
"github.com/pingcap/tidb/pkg/util/hack"
"github.com/stretchr/testify/require"
)
// TestBitCount checks BIT_COUNT over signed and unsigned integers, floats
// (converted to integers before counting), a non-numeric string, and NULL.
func TestBitCount(t *testing.T) {
	ctx := createContext(t)
	stmtCtx := ctx.GetSessionVars().StmtCtx
	oldTypeFlags := stmtCtx.TypeFlags()
	defer func() {
		stmtCtx.SetTypeFlags(oldTypeFlags)
	}()
	// Ignore truncation errors for the duration of the test; presumably this
	// is what lets the non-numeric "xxx" case evaluate without an error.
	stmtCtx.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true))
	fc := funcs[ast.BitCount]
	bitCountCases := []struct {
		origin any
		count  any
	}{
		{int64(8), int64(1)},
		{int64(29), int64(4)},
		{int64(0), int64(0)},
		{int64(-1), int64(64)},
		{int64(-11), int64(62)},
		{int64(-1000), int64(56)},
		{float64(1.1), int64(1)},
		{float64(3.1), int64(2)},
		{float64(-1.1), int64(64)},
		{float64(-3.1), int64(63)},
		{uint64(math.MaxUint64), int64(64)},
		{"xxx", int64(0)},
		{nil, nil},
	}
	// Type context used only to convert the result datum to int64. It does not
	// depend on the loop variable, and its distinct name avoids shadowing the
	// session ctx that the function is built and evaluated with.
	convCtx := types.DefaultStmtNoWarningContext.WithFlags(types.DefaultStmtFlags.WithIgnoreTruncateErr(true))
	for _, test := range bitCountCases {
		in := types.NewDatum(test.origin)
		f, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{in}))
		require.NoError(t, err)
		require.NotNil(t, f)
		count, err := evalBuiltinFunc(f, ctx, chunk.Row{})
		require.NoError(t, err)
		// A NULL result must correspond to a case whose expected count is nil.
		if count.IsNull() {
			require.Nil(t, test.count)
			continue
		}
		res, err := count.ToInt64(convCtx)
		require.NoError(t, err)
		require.Equal(t, test.count, res)
	}
}
// TestRowFunc verifies that a ROW(...) function can be built from constant
// arguments of mixed types (string, float, bool and int).
func TestRowFunc(t *testing.T) {
	sctx := createContext(t)
	rowFn := funcs[ast.RowFunc]
	rowArgs := datumsToConstants(types.MakeDatums("1", 1.2, true, 120))
	_, err := rowFn.getFunction(sctx, rowArgs)
	require.NoError(t, err)
}
// TestSetVar checks that SET @var := value returns the assigned value and,
// for non-NULL values, that the same value is stored as the session's user
// variable under the given name.
func TestSetVar(t *testing.T) {
	sctx := createContext(t)
	setVarFn := funcs[ast.SetVar]
	five := types.NewDecFromInt(5)
	ts := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 0)
	cases := []struct {
		args []any
		res  any
	}{
		{[]any{"a", "12"}, "12"},
		{[]any{"b", "34"}, "34"},
		{[]any{"c", nil}, nil},
		{[]any{"c", "ABC"}, "ABC"},
		{[]any{"c", "dEf"}, "dEf"},
		{[]any{"d", int64(3)}, int64(3)},
		{[]any{"e", float64(2.5)}, float64(2.5)},
		{[]any{"f", five}, five},
		{[]any{"g", ts}, ts},
	}
	for _, c := range cases {
		fn, err := setVarFn.getFunction(sctx, datumsToConstants(types.MakeDatums(c.args...)))
		require.NoError(t, err)
		row := chunk.MutRowFromDatums(types.MakeDatums(c.args...)).ToRow()
		got, err := evalBuiltinFunc(fn, sctx, row)
		require.NoError(t, err)
		require.Equal(t, c.res, got.GetValue())

		// For a NULL assignment only the return value is checked.
		if c.args[1] == nil {
			continue
		}
		name, isStr := c.args[0].(string)
		require.True(t, isStr)
		stored, found := sctx.GetSessionVars().GetUserVarVal(name)
		require.True(t, found)
		require.Equal(t, c.res, stored.GetValue())
	}
}
// TestGetVar registers user variables of several kinds, each with a field
// type inferred from its value, and then checks that GetVar reads them back.
// The unset variable "d" yields NULL, and the datetime variable "h" is
// expected back in its string form.
func TestGetVar(t *testing.T) {
	sctx := createContext(t)
	five := types.NewDecFromInt(5)
	ts := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 0)
	userVars := []struct {
		key string
		val any
	}{
		{"a", "中"},
		{"b", "文字符chuan"},
		{"c", ""},
		{"e", int64(3)},
		{"f", float64(2.5)},
		{"g", five},
		{"h", ts},
	}
	for _, uv := range userVars {
		sctx.GetSessionVars().SetUserVarVal(uv.key, types.NewDatum(uv.val))
		// Time values start from a DATETIME type; everything else from VARSTRING.
		var ft *types.FieldType
		switch uv.val.(type) {
		case types.Time:
			ft = types.NewFieldType(mysql.TypeDatetime)
		default:
			ft = types.NewFieldType(mysql.TypeVarString)
		}
		types.InferParamTypeFromUnderlyingValue(uv.val, ft)
		sctx.GetSessionVars().SetUserVarType(uv.key, ft)
	}

	cases := []struct {
		args []any
		res  any
	}{
		{[]any{"a"}, "中"},
		{[]any{"b"}, "文字符chuan"},
		{[]any{"c"}, ""},
		{[]any{"d"}, nil},
		{[]any{"e"}, int64(3)},
		{[]any{"f"}, float64(2.5)},
		{[]any{"g"}, five},
		{[]any{"h"}, ts.String()},
	}
	for _, c := range cases {
		ft, ok := sctx.GetSessionVars().GetUserVarType(c.args[0].(string))
		if !ok {
			// Unregistered variables fall back to a VARSTRING type.
			ft = types.NewFieldType(mysql.TypeVarString)
		}
		nameArg := datumsToConstants(types.MakeDatums(c.args...))[0]
		fn, err := BuildGetVarFunction(sctx, nameArg, ft)
		require.NoError(t, err)
		got, err := fn.Eval(sctx, chunk.Row{})
		require.NoError(t, err)
		require.Equal(t, c.res, got.GetValue())
	}
}
// TestTypeConversion stores int64(3) in user variable @a with a LONGLONG
// field type, then checks that GetVar converts it when asked for a DECIMAL
// and for a DOUBLE result.
func TestTypeConversion(t *testing.T) {
	sctx := createContext(t)
	const name = "a"
	sctx.GetSessionVars().SetUserVarVal(name, types.NewDatum(int64(3)))
	sctx.GetSessionVars().SetUserVarType(name, types.NewFieldType(mysql.TypeLonglong))

	targets := []struct {
		ft   *types.FieldType
		want any
	}{
		{types.NewFieldType(mysql.TypeNewDecimal), types.NewDecFromInt(3)},
		{types.NewFieldType(mysql.TypeDouble), float64(3)},
	}
	for _, target := range targets {
		nameArg := datumsToConstants(types.MakeDatums(name))[0]
		fn, err := BuildGetVarFunction(sctx, nameArg, target.ft)
		require.NoError(t, err)
		got, err := fn.Eval(sctx, chunk.Row{})
		require.NoError(t, err)
		require.Equal(t, target.want, got.GetValue())
	}
}
// TestValues exercises the VALUES() function class:
//   - a call with the wrong number of arguments is rejected;
//   - evaluation yields NULL before any insert row is set;
//   - an insert row too short for the referenced column is an error;
//   - otherwise the referenced value is returned.
func TestValues(t *testing.T) {
	ctx := createContext(t)
	// NOTE(review): the literal 1 appears to be the column offset into the
	// current insert row. The assertions below expect a 1-column row to fail
	// and a 2-column row to return its element at index 1.
	fc := &valuesFunctionClass{baseFunctionClass{ast.Values, 0, 0}, 1, types.NewFieldType(mysql.TypeVarchar)}
	_, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums("")))
	require.Error(t, err)
	require.Regexp(t, "Incorrect parameter count in the call to native function 'values'$", err.Error())
	sig, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums()))
	require.NoError(t, err)
	ret, err := evalBuiltinFunc(sig, ctx, chunk.Row{})
	require.NoError(t, err)
	require.True(t, ret.IsNull())

	ctx.GetSessionVars().CurrInsertValues = chunk.MutRowFromDatums(types.MakeDatums("1")).ToRow()
	// Only the error matters here. The result is discarded instead of being
	// assigned to ret, which would be overwritten unread below.
	_, err = evalBuiltinFunc(sig, ctx, chunk.Row{})
	require.Error(t, err)
	require.Regexp(t, "^Session current insert values len", err.Error())

	currInsertValues := types.MakeDatums("1", "2")
	ctx.GetSessionVars().CurrInsertValues = chunk.MutRowFromDatums(currInsertValues).ToRow()
	ret, err = evalBuiltinFunc(sig, ctx, chunk.Row{})
	require.NoError(t, err)
	cmp, err := ret.Compare(types.DefaultStmtNoWarningContext, &currInsertValues[1], collate.GetBinaryCollator())
	require.NoError(t, err)
	require.Equal(t, 0, cmp)
}
// TestSetVarFromColumn evaluates SET @a := <column 0> and then overwrites the
// input chunk. The stored user variable is expected to keep the original
// value "a", i.e. it must not alias the chunk's memory.
func TestSetVarFromColumn(t *testing.T) {
	sctx := createContext(t)
	var strTp types.FieldType
	strTp.SetType(mysql.TypeVarString)
	strTp.SetFlen(20)
	colTp := strTp.Clone()
	retTp := strTp.Clone()

	// Build @a := <column 0>.
	varName := &Constant{
		Value:   types.NewStringDatum("a"),
		RetType: &strTp,
	}
	col := &Column{
		RetType: colTp,
		Index:   0,
	}
	setVar, err := NewFunction(sctx, ast.SetVar, retTp, varName, col)
	require.NoError(t, err)

	in := chunk.NewChunkWithCapacity([]*types.FieldType{col.RetType}, 1)
	in.AppendString(0, "a")
	out := chunk.NewChunkWithCapacity([]*types.FieldType{col.RetType}, 1)
	require.NoError(t, evalOneCell(sctx, setVar, in.GetRow(0), out, 0))
	require.Equal(t, "a", out.GetRow(0).GetString(0))

	// Reuse the input chunk's storage for a different value.
	in.Reset()
	in.AppendString(0, "b")

	stored, found := sctx.GetSessionVars().GetUserVarVal("a")
	require.True(t, found)
	require.Equal(t, "a", stored.GetString())
}
// TestInFunc evaluates IN over ints (including signed/unsigned mixes),
// floats, decimals, strings and byte slices, datetimes, durations and JSON,
// including NULL operands. It then checks collation-aware matching ("a" IN
// ("Á") under utf8_general_ci) through both row and vectorized evaluation.
func TestInFunc(t *testing.T) {
	ctx := createContext(t)
	fc := funcs[ast.In]
	decimal1 := types.NewDecFromFloatForTest(123.121)
	decimal2 := types.NewDecFromFloatForTest(123.122)
	decimal3 := types.NewDecFromFloatForTest(123.123)
	decimal4 := types.NewDecFromFloatForTest(123.124)
	time1 := types.NewTime(types.FromGoTime(time.Date(2017, 1, 1, 1, 1, 1, 1, time.UTC)), mysql.TypeDatetime, 6)
	time2 := types.NewTime(types.FromGoTime(time.Date(2017, 1, 2, 1, 1, 1, 1, time.UTC)), mysql.TypeDatetime, 6)
	time3 := types.NewTime(types.FromGoTime(time.Date(2017, 1, 3, 1, 1, 1, 1, time.UTC)), mysql.TypeDatetime, 6)
	time4 := types.NewTime(types.FromGoTime(time.Date(2017, 1, 4, 1, 1, 1, 1, time.UTC)), mysql.TypeDatetime, 6)
	duration1 := types.Duration{Duration: 12*time.Hour + 1*time.Minute + 1*time.Second}
	duration2 := types.Duration{Duration: 12*time.Hour + 1*time.Minute}
	duration3 := types.Duration{Duration: 12*time.Hour + 1*time.Second}
	duration4 := types.Duration{Duration: 12 * time.Hour}
	json1 := types.CreateBinaryJSON("123")
	json2 := types.CreateBinaryJSON("123.1")
	json3 := types.CreateBinaryJSON("123.2")
	json4 := types.CreateBinaryJSON("123.3")
	// args[0] is the probed value; args[1:] is the IN list.
	testCases := []struct {
		args []any
		res  any
	}{
		{[]any{1, 1, 2, 3}, int64(1)},
		{[]any{1, 0, 2, 3}, int64(0)},
		{[]any{1, nil, 2, 3}, nil},
		{[]any{nil, nil, 2, 3}, nil},
		{[]any{uint64(0), 0, 2, 3}, int64(1)},
		{[]any{uint64(math.MaxUint64), uint64(math.MaxUint64), 2, 3}, int64(1)},
		{[]any{-1, uint64(math.MaxUint64), 2, 3}, int64(0)},
		{[]any{uint64(math.MaxUint64), -1, 2, 3}, int64(0)},
		// Match in the last position of the list.
		{[]any{3, 1, 2, 3}, int64(1)},
		{[]any{1.1, 1.2, 1.3}, int64(0)},
		{[]any{1.1, 1.1, 1.2, 1.3}, int64(1)},
		{[]any{decimal1, decimal2, decimal3, decimal4}, int64(0)},
		{[]any{decimal1, decimal2, decimal3, decimal1}, int64(1)},
		{[]any{"1.1", "1.1", "1.2", "1.3"}, int64(1)},
		{[]any{"1.1", hack.Slice("1.1"), "1.2", "1.3"}, int64(1)},
		{[]any{hack.Slice("1.1"), "1.1", "1.2", "1.3"}, int64(1)},
		{[]any{time1, time2, time3, time1}, int64(1)},
		{[]any{time1, time2, time3, time4}, int64(0)},
		{[]any{duration1, duration2, duration3, duration4}, int64(0)},
		{[]any{duration1, duration2, duration1, duration4}, int64(1)},
		{[]any{json1, json2, json3, json4}, int64(0)},
		{[]any{json1, json1, json3, json4}, int64(1)},
	}
	for _, tc := range testCases {
		fn, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(tc.args...)))
		require.NoError(t, err)
		d, err := evalBuiltinFunc(fn, ctx, chunk.MutRowFromDatums(types.MakeDatums(tc.args...)).ToRow())
		require.NoError(t, err)
		// Print the raw arguments. Passing tc.args to the variadic MakeDatums
		// without "..." would wrap the whole slice as a single value.
		require.Equalf(t, tc.res, d.GetValue(), "%v", tc.args)
	}

	// Under utf8_general_ci, "a" and "Á" are expected to match, so IN returns 1.
	strD1 := types.NewCollationStringDatum("a", "utf8_general_ci")
	strD2 := types.NewCollationStringDatum("Á", "utf8_general_ci")
	fn, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{strD1, strD2}))
	require.NoError(t, err)
	d, err := evalBuiltinFunc(fn, ctx, chunk.Row{})
	require.NoError(t, err)
	require.False(t, d.IsNull())
	require.Equal(t, types.KindInt64, d.Kind())
	require.Equalf(t, int64(1), d.GetInt64(), "%v, %v", strD1, strD2)

	// Repeat the same check through vectorized evaluation over one virtual row.
	chk1 := chunk.NewChunkWithCapacity(nil, 1)
	chk1.SetNumVirtualRows(1)
	chk2 := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeTiny)}, 1)
	require.True(t, fn.vectorized() && fn.isChildrenVectorized())
	err = vecEvalType(ctx, fn, types.ETInt, chk1, chk2.Column(0))
	require.NoError(t, err)
	require.Equal(t, int64(1), chk2.Column(0).GetInt64(0))
}
// TestGetParam checks GetParam against the plan-cache parameters appended to
// the session:
//   - each in-range index returns that parameter rendered as a string;
//   - an index beyond the stored parameters returns
//     exprctx.ErrParamIndexExceedParamCounts together with a NULL datum.
func TestGetParam(t *testing.T) {
	ctx := createContext(t)
	params := []types.Datum{
		types.NewIntDatum(123),
		types.NewStringDatum("abc"),
	}
	ctx.GetSessionVars().PlanCacheParams.Append(params...)
	fc := funcs[ast.GetParam]
	for i := range params {
		fn, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(i)))
		require.NoError(t, err)
		d, err := evalBuiltinFunc(fn, ctx, chunk.Row{})
		require.NoError(t, err)
		want, err := params[i].ToString()
		require.NoError(t, err)
		// testify reports the first value as "expected", so the constant goes first.
		require.Equal(t, types.KindString, d.Kind())
		require.Equal(t, want, d.GetString())
	}
	fn, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(len(params)+1)))
	require.NoError(t, err)
	d, err := evalBuiltinFunc(fn, ctx, chunk.Row{})
	require.ErrorIs(t, err, exprctx.ErrParamIndexExceedParamCounts)
	require.True(t, d.IsNull())
}