Files
tidb/expression/builtin_other_test.go

316 lines
11 KiB
Go

// Copyright 2017 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"math"
"testing"
"time"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/hack"
"github.com/stretchr/testify/require"
)
// TestBitCount runs the ast.BitCount builtin over a table of signed and
// unsigned integers, floats, a non-numeric string and NULL, and compares the
// evaluated result (converted to int64) with the expected bit count.
//
// IgnoreTruncate is switched on for the session's statement context while the
// test runs and put back to its previous value when the test returns.
func TestBitCount(t *testing.T) {
	ctx := createContext(t)

	sessStmtCtx := ctx.GetSessionVars().StmtCtx
	savedIgnoreTruncate := sessStmtCtx.IgnoreTruncate
	sessStmtCtx.IgnoreTruncate = true
	defer func() {
		sessStmtCtx.IgnoreTruncate = savedIgnoreTruncate
	}()

	bitCount := funcs[ast.BitCount]
	cases := []struct {
		input    interface{}
		expected interface{}
	}{
		{int64(8), int64(1)},
		{int64(29), int64(4)},
		{int64(0), int64(0)},
		{int64(-1), int64(64)},
		{int64(-11), int64(62)},
		{int64(-1000), int64(56)},
		{float64(1.1), int64(1)},
		{float64(3.1), int64(2)},
		{float64(-1.1), int64(64)},
		{float64(-3.1), int64(63)},
		{uint64(math.MaxUint64), int64(64)},
		{"xxx", int64(0)},
		{nil, nil},
	}

	for _, tc := range cases {
		arg := types.NewDatum(tc.input)
		sig, err := bitCount.getFunction(ctx, datumsToConstants([]types.Datum{arg}))
		require.NoError(t, err)
		require.NotNil(t, sig)

		got, err := evalBuiltinFunc(sig, chunk.Row{})
		require.NoError(t, err)

		// A NULL result is only acceptable when the table expects nil.
		if got.IsNull() {
			require.Nil(t, tc.expected)
			continue
		}

		// Convert with a fresh statement context that also ignores truncation.
		convCtx := new(stmtctx.StatementContext)
		convCtx.IgnoreTruncate = true
		n, err := got.ToInt64(convCtx)
		require.NoError(t, err)
		require.Equal(t, tc.expected, n)
	}
}
// TestRowFunc checks that a function signature for ast.RowFunc can be built
// from a mixed list of string, float, bool and int constants without error.
func TestRowFunc(t *testing.T) {
	ctx := createContext(t)
	rowFunc := funcs[ast.RowFunc]

	rowArgs := datumsToConstants(types.MakeDatums("1", 1.2, true, 120))
	_, err := rowFunc.getFunction(ctx, rowArgs)
	require.NoError(t, err)
}
// TestSetVar evaluates the ast.SetVar builtin for string, NULL, integer,
// float, decimal and time values. For every case it checks the value returned
// by the function and, when the assigned value is not nil, that the same value
// is stored under the variable name in the session's Users map.
func TestSetVar(t *testing.T) {
	ctx := createContext(t)
	setVar := funcs[ast.SetVar]
	decVal := types.NewDecFromInt(5)
	tsVal := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 0)

	cases := []struct {
		args []interface{}
		want interface{}
	}{
		{[]interface{}{"a", "12"}, "12"},
		{[]interface{}{"b", "34"}, "34"},
		{[]interface{}{"c", nil}, nil},
		{[]interface{}{"c", "ABC"}, "ABC"},
		{[]interface{}{"c", "dEf"}, "dEf"},
		{[]interface{}{"d", int64(3)}, int64(3)},
		{[]interface{}{"e", float64(2.5)}, float64(2.5)},
		{[]interface{}{"f", decVal}, decVal},
		{[]interface{}{"g", tsVal}, tsVal},
	}

	for _, c := range cases {
		constArgs := datumsToConstants(types.MakeDatums(c.args...))
		sig, err := setVar.getFunction(ctx, constArgs)
		require.NoError(t, err)

		inputRow := chunk.MutRowFromDatums(types.MakeDatums(c.args...)).ToRow()
		got, err := evalBuiltinFunc(sig, inputRow)
		require.NoError(t, err)
		require.Equal(t, c.want, got.GetValue())

		// The stored variable is not inspected when nil was assigned.
		if c.args[1] == nil {
			continue
		}

		name, isString := c.args[0].(string)
		require.True(t, isString)
		stored, found := ctx.GetSessionVars().Users[name]
		require.True(t, found)
		require.Equal(t, c.want, stored.GetValue())
	}
}
// TestGetVar seeds the session with user variables of several kinds (strings,
// including an empty one, an integer, a float, a decimal and a time), records a
// field type for each in UserVarTypes, and then reads them back through
// BuildGetVarFunction.
//
// The variable "d" is never set and is expected to read as nil; the time
// variable "h" is expected to come back as its string form.
func TestGetVar(t *testing.T) {
	ctx := createContext(t)
	decVal := types.NewDecFromInt(5)
	tsVal := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 0)

	seeds := []struct {
		key string
		val interface{}
	}{
		{"a", "中"},
		{"b", "文字符chuan"},
		{"c", ""},
		{"e", int64(3)},
		{"f", float64(2.5)},
		{"g", decVal},
		{"h", tsVal},
	}
	for _, seed := range seeds {
		ctx.GetSessionVars().Users[seed.key] = types.NewDatum(seed.val)

		// Time values start from a datetime type, everything else from a
		// var-string type, before being passed to DefaultParamTypeForValue.
		var varType *types.FieldType
		switch seed.val.(type) {
		case types.Time:
			varType = types.NewFieldType(mysql.TypeDatetime)
		default:
			varType = types.NewFieldType(mysql.TypeVarString)
		}
		types.DefaultParamTypeForValue(seed.val, varType)
		ctx.GetSessionVars().UserVarTypes[seed.key] = varType
	}

	cases := []struct {
		args []interface{}
		want interface{}
	}{
		{[]interface{}{"a"}, "中"},
		{[]interface{}{"b"}, "文字符chuan"},
		{[]interface{}{"c"}, ""},
		{[]interface{}{"d"}, nil},
		{[]interface{}{"e"}, int64(3)},
		{[]interface{}{"f"}, float64(2.5)},
		{[]interface{}{"g"}, decVal},
		{[]interface{}{"h"}, tsVal.String()},
	}
	for _, c := range cases {
		// Variables without a recorded type fall back to var-string.
		retType, known := ctx.GetSessionVars().UserVarTypes[c.args[0].(string)]
		if !known {
			retType = types.NewFieldType(mysql.TypeVarString)
		}

		nameArg := datumsToConstants(types.MakeDatums(c.args...))[0]
		getVar, err := BuildGetVarFunction(ctx, nameArg, retType)
		require.NoError(t, err)

		got, err := getVar.Eval(chunk.Row{})
		require.NoError(t, err)
		require.Equal(t, c.want, got.GetValue())
	}
}
// TestValues exercises a valuesFunctionClass built for ast.Values with 1 as
// its second field (presumably the column offset — the final assertion compares
// against the insert value at index 1).
//
// It covers four situations:
//   - building the function with an argument is rejected with a parameter
//     count error;
//   - with no current insert values in the session, evaluation yields NULL;
//   - with a single current insert value, evaluation fails with an error
//     starting with "Session current insert values len";
//   - with two current insert values, evaluation returns a datum equal to the
//     second one.
func TestValues(t *testing.T) {
	ctx := createContext(t)
	fc := &valuesFunctionClass{baseFunctionClass{ast.Values, 0, 0}, 1, types.NewFieldType(mysql.TypeVarchar)}

	// One argument is too many for this function class.
	_, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums("")))
	require.Error(t, err)
	require.Regexp(t, "Incorrect parameter count in the call to native function 'values'$", err.Error())

	sig, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums()))
	require.NoError(t, err)

	// No current insert values have been set on the session yet.
	ret, err := evalBuiltinFunc(sig, chunk.Row{})
	require.NoError(t, err)
	require.True(t, ret.IsNull())

	// A single insert value: evaluation must fail. Only the error is
	// inspected here, so the returned datum is discarded.
	ctx.GetSessionVars().CurrInsertValues = chunk.MutRowFromDatums(types.MakeDatums("1")).ToRow()
	_, err = evalBuiltinFunc(sig, chunk.Row{})
	require.Error(t, err)
	require.Regexp(t, "^Session current insert values len", err.Error())

	// Two insert values: the result must compare equal to the second one.
	currInsertValues := types.MakeDatums("1", "2")
	ctx.GetSessionVars().CurrInsertValues = chunk.MutRowFromDatums(currInsertValues).ToRow()
	ret, err = evalBuiltinFunc(sig, chunk.Row{})
	require.NoError(t, err)

	cmp, err := ret.Compare(nil, &currInsertValues[1], collate.GetBinaryCollator())
	require.NoError(t, err)
	require.Equal(t, 0, cmp)
}
// TestSetVarFromColumn builds a SetVar expression whose value argument is a
// column, evaluates it against a one-row chunk containing "a", and then
// overwrites the input chunk with "b". The user variable "a" must still hold
// "a" afterwards, i.e. the stored value must not change when the underlying
// chunk content changes.
func TestSetVarFromColumn(t *testing.T) {
	ctx := createContext(t)

	// Each expression gets its own FieldType instance.
	newStrType := func() *types.FieldType {
		return &types.FieldType{Tp: mysql.TypeVarString, Flen: 20}
	}

	// Arguments: the variable name as a constant, the value as column 0.
	varName := &Constant{
		Value:   types.NewStringDatum("a"),
		RetType: newStrType(),
	}
	valueCol := &Column{
		RetType: newStrType(),
		Index:   0,
	}

	setVarExpr, err := NewFunction(ctx, ast.SetVar, newStrType(), varName, valueCol)
	require.NoError(t, err)

	// Input holds a single row "a"; output receives the evaluated cell.
	in := chunk.NewChunkWithCapacity([]*types.FieldType{valueCol.RetType}, 1)
	in.AppendString(0, "a")
	out := chunk.NewChunkWithCapacity([]*types.FieldType{valueCol.RetType}, 1)

	err = evalOneCell(ctx, setVarExpr, in.GetRow(0), out, 0)
	require.NoError(t, err)
	require.Equal(t, "a", out.GetRow(0).GetString(0))

	// Overwrite the chunk that backed the evaluated row.
	in.Reset()
	in.AppendString(0, "b")

	// The user variable must still carry the originally assigned value.
	vars := ctx.GetSessionVars()
	vars.UsersLock.RLock()
	defer vars.UsersLock.RUnlock()
	stored, found := vars.Users["a"]
	require.True(t, found)
	require.Equal(t, "a", stored.GetString())
}
// TestInFunc checks the ast.In builtin. In every table case the first argument
// is the value searched for and the remaining arguments are the candidate
// list; the expected result is int64(1) for a match, int64(0) for no match and
// nil when the table expects a NULL result.
//
// The table covers signed and unsigned integers (including -1 against
// math.MaxUint64), floats, decimals, strings and byte slices, times, durations
// and JSON values. After the table, the test checks that "a" matches "Á" under
// the utf8_general_ci collation, through both evalInt and vecEvalInt.
func TestInFunc(t *testing.T) {
	ctx := createContext(t)
	fc := funcs[ast.In]
	decimal1 := types.NewDecFromFloatForTest(123.121)
	decimal2 := types.NewDecFromFloatForTest(123.122)
	decimal3 := types.NewDecFromFloatForTest(123.123)
	decimal4 := types.NewDecFromFloatForTest(123.124)
	time1 := types.NewTime(types.FromGoTime(time.Date(2017, 1, 1, 1, 1, 1, 1, time.UTC)), mysql.TypeDatetime, 6)
	time2 := types.NewTime(types.FromGoTime(time.Date(2017, 1, 2, 1, 1, 1, 1, time.UTC)), mysql.TypeDatetime, 6)
	time3 := types.NewTime(types.FromGoTime(time.Date(2017, 1, 3, 1, 1, 1, 1, time.UTC)), mysql.TypeDatetime, 6)
	time4 := types.NewTime(types.FromGoTime(time.Date(2017, 1, 4, 1, 1, 1, 1, time.UTC)), mysql.TypeDatetime, 6)
	duration1 := types.Duration{Duration: 12*time.Hour + 1*time.Minute + 1*time.Second}
	duration2 := types.Duration{Duration: 12*time.Hour + 1*time.Minute}
	duration3 := types.Duration{Duration: 12*time.Hour + 1*time.Second}
	duration4 := types.Duration{Duration: 12 * time.Hour}
	json1 := json.CreateBinary("123")
	json2 := json.CreateBinary("123.1")
	json3 := json.CreateBinary("123.2")
	json4 := json.CreateBinary("123.3")
	testCases := []struct {
		args []interface{}
		res  interface{}
	}{
		{[]interface{}{1, 1, 2, 3}, int64(1)},
		{[]interface{}{1, 0, 2, 3}, int64(0)},
		{[]interface{}{1, nil, 2, 3}, nil},
		{[]interface{}{nil, nil, 2, 3}, nil},
		{[]interface{}{uint64(0), 0, 2, 3}, int64(1)},
		{[]interface{}{uint64(math.MaxUint64), uint64(math.MaxUint64), 2, 3}, int64(1)},
		{[]interface{}{-1, uint64(math.MaxUint64), 2, 3}, int64(0)},
		{[]interface{}{uint64(math.MaxUint64), -1, 2, 3}, int64(0)},
		{[]interface{}{1, 0, 2, 3}, int64(0)},
		{[]interface{}{1.1, 1.2, 1.3}, int64(0)},
		{[]interface{}{1.1, 1.1, 1.2, 1.3}, int64(1)},
		{[]interface{}{decimal1, decimal2, decimal3, decimal4}, int64(0)},
		{[]interface{}{decimal1, decimal2, decimal3, decimal1}, int64(1)},
		{[]interface{}{"1.1", "1.1", "1.2", "1.3"}, int64(1)},
		{[]interface{}{"1.1", hack.Slice("1.1"), "1.2", "1.3"}, int64(1)},
		{[]interface{}{hack.Slice("1.1"), "1.1", "1.2", "1.3"}, int64(1)},
		{[]interface{}{time1, time2, time3, time1}, int64(1)},
		{[]interface{}{time1, time2, time3, time4}, int64(0)},
		{[]interface{}{duration1, duration2, duration3, duration4}, int64(0)},
		{[]interface{}{duration1, duration2, duration1, duration4}, int64(1)},
		{[]interface{}{json1, json2, json3, json4}, int64(0)},
		{[]interface{}{json1, json1, json3, json4}, int64(1)},
	}
	for _, tc := range testCases {
		fn, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(tc.args...)))
		require.NoError(t, err)
		d, err := evalBuiltinFunc(fn, chunk.MutRowFromDatums(types.MakeDatums(tc.args...)).ToRow())
		require.NoError(t, err)
		// Spread the arguments so a failure message lists each datum rather
		// than one datum wrapping the whole slice.
		require.Equalf(t, tc.res, d.GetValue(), "%v", types.MakeDatums(tc.args...))
	}

	// Collation-aware comparison: "a" and "Á" are expected to match under
	// utf8_general_ci.
	strD1 := types.NewCollationStringDatum("a", "utf8_general_ci")
	strD2 := types.NewCollationStringDatum("Á", "utf8_general_ci")
	fn, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{strD1, strD2}))
	require.NoError(t, err)
	d, isNull, err := fn.evalInt(chunk.Row{})
	// Check the error first so a failing evaluation reports its real cause.
	require.NoError(t, err)
	require.False(t, isNull)
	require.Equalf(t, int64(1), d, "%v, %v", strD1, strD2)

	// Same check through the vectorized path: the input chunk has no columns
	// and one virtual row, and the result is written to column 0 of chk2.
	chk1 := chunk.NewChunkWithCapacity(nil, 1)
	chk1.SetNumVirtualRows(1)
	chk2 := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeTiny)}, 1)
	err = fn.vecEvalInt(chk1, chk2.Column(0))
	require.NoError(t, err)
	require.Equal(t, int64(1), chk2.Column(0).GetInt64(0))
}