Files
tidb/pkg/table/column_test.go

576 lines
15 KiB
Go

// Copyright 2016 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package table
import (
"fmt"
"math"
"testing"
"time"
"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/meta/model"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/collate"
"github.com/pingcap/tidb/pkg/util/mock"
"github.com/stretchr/testify/require"
)
// TestString walks a single Column through several MySQL types and checks the
// type description it reports after each mutation, along with the PK-handle
// predicate and the String method.
func TestString(t *testing.T) {
	column := ToColumn(&model.ColumnInfo{
		FieldType: *types.NewFieldType(mysql.TypeTiny),
		State:     model.StatePublic,
	})
	// requireTypeDesc asserts the type description of the column as it is
	// configured at the time of the call.
	requireTypeDesc := func(want string) {
		require.Equal(t, want, column.GetTypeDesc())
	}

	// tinyint carrying a set of flags.
	column.SetFlen(2)
	column.SetDecimal(1)
	column.SetCharset(mysql.DefaultCharset)
	column.SetCollate(mysql.DefaultCollationName)
	column.AddFlag(mysql.ZerofillFlag | mysql.UnsignedFlag | mysql.BinaryFlag | mysql.AutoIncrementFlag | mysql.NotNullFlag)
	requireTypeDesc("tinyint(2) unsigned zerofill")
	// Called without checking the result.
	column.ToInfo()

	// Not a PK handle for an empty TableInfo; it is one after PKIsHandle is
	// set on the table and PriKeyFlag is added to the column.
	tableInfo := &model.TableInfo{}
	require.False(t, column.IsPKHandleColumn(tableInfo))
	tableInfo.PKIsHandle = true
	column.AddFlag(mysql.PriKeyFlag)
	require.True(t, column.IsPKHandleColumn(tableInfo))

	require.Greater(t, len(column.String()), 0)

	// enum: the second case has an element that itself contains quotes.
	column.SetType(mysql.TypeEnum)
	column.SetFlag(0)
	column.SetElems([]string{"a", "b"})
	requireTypeDesc("enum('a','b')")
	column.SetElems([]string{"'a'", "b"})
	requireTypeDesc("enum('''a''','b')")

	// float: with decimal -1 and with decimal 1.
	column.SetType(mysql.TypeFloat)
	column.SetFlen(8)
	column.SetDecimal(-1)
	requireTypeDesc("float")
	column.SetDecimal(1)
	requireTypeDesc("float(8,1)")

	// datetime: decimal 6, 0 and -1.
	column.SetType(mysql.TypeDatetime)
	column.SetDecimal(6)
	requireTypeDesc("datetime(6)")
	column.SetDecimal(0)
	requireTypeDesc("datetime")
	column.SetDecimal(-1)
	requireTypeDesc("datetime")
}
// TestFind covers FindCols for a present and for a missing column name, and
// FindOnUpdateCols after one column has been given OnUpdateNowFlag.
func TestFind(t *testing.T) {
	columns := []*Column{newCol("a"), newCol("b"), newCol("c")}
	firstOnly := columns[:1]

	// A known name yields the matching column and an empty missing name.
	found, missing := FindCols(columns, []string{"a"}, true)
	require.Equal(t, firstOnly, found)
	require.Equal(t, "", missing)

	// An unknown name yields no columns and reports that name back.
	found, missing = FindCols(columns, []string{"d"}, true)
	require.Nil(t, found)
	require.Equal(t, "d", missing)

	columns[0].AddFlag(mysql.OnUpdateNowFlag)
	require.Equal(t, firstOnly, FindOnUpdateCols(columns))
}
// checkNotNull calls CheckNotNull on every column in cols, passing the datum
// found at that column's Offset in row, and returns the first error it gets.
// It returns nil when no column reports an error.
func checkNotNull(cols []*Column, row []types.Datum) error {
	for i := range cols {
		column := cols[i]
		err := column.CheckNotNull(&row[column.Offset], 0)
		if err != nil {
			return err
		}
	}
	return nil
}
// TestCheck covers CheckOnce (duplicate and empty column lists) and the
// checkNotNull helper before and after NotNullFlag is added to the column.
func TestCheck(t *testing.T) {
	column := newCol("a")
	column.SetFlag(mysql.AutoIncrementFlag)

	// The same column listed twice is expected to be rejected.
	require.Error(t, CheckOnce([]*Column{column, column}))

	single := []*Column{column}

	// A nil datum passes while the column has no NotNullFlag.
	require.NoError(t, checkNotNull(single, types.MakeDatums(nil)))

	// The same row fails once NotNullFlag is set.
	column.AddFlag(mysql.NotNullFlag)
	require.Error(t, checkNotNull(single, types.MakeDatums(nil)))

	// An empty column list is accepted.
	require.NoError(t, CheckOnce([]*Column{}))
}
// TestHandleBadNull checks HandleBadNull on a column without NotNullFlag, on
// the same column with NotNullFlag, and again after the statement context's
// ErrGroupBadNull level has been set to LevelWarn.
func TestHandleBadNull(t *testing.T) {
	column := newCol("a")
	stmtCtx := stmtctx.NewStmtCtx()

	// Without NotNullFlag the call succeeds and the datum still compares
	// equal to a zero-value Datum.
	var datum types.Datum
	require.NoError(t, column.HandleBadNull(stmtCtx.ErrCtx(), &datum, 0))
	order, err := datum.Compare(stmtCtx.TypeCtx(), &types.Datum{}, collate.GetBinaryCollator())
	require.NoError(t, err)
	require.Equal(t, 0, order)

	// With NotNullFlag the call reports an error.
	column.AddFlag(mysql.NotNullFlag)
	require.Error(t, column.HandleBadNull(stmtCtx.ErrCtx(), &types.Datum{}, 0))

	// After ErrGroupBadNull is set to LevelWarn the same call succeeds.
	var warnLevels errctx.LevelMap
	warnLevels[errctx.ErrGroupBadNull] = errctx.LevelWarn
	stmtCtx.SetErrLevels(warnLevels)
	require.NoError(t, column.HandleBadNull(stmtCtx.ErrCtx(), &types.Datum{}, 0))
}
// TestDesc checks the Extra field produced by NewColDesc for an
// on-update column and for stored/virtual generated columns. It also calls
// NewColDesc and ColDescFieldNames on a few more inputs without inspecting
// what they return.
func TestDesc(t *testing.T) {
	column := newCol("a")
	// extra returns the Extra field of a description built from the column
	// as it is configured at the time of the call.
	extra := func() string {
		return NewColDesc(column).Extra
	}

	// Results of these two calls are not checked.
	column.SetFlag(mysql.AutoIncrementFlag | mysql.NotNullFlag | mysql.PriKeyFlag)
	NewColDesc(column)
	column.SetFlag(mysql.MultipleKeyFlag)
	NewColDesc(column)

	column.SetFlag(mysql.UniqueKeyFlag | mysql.OnUpdateNowFlag)
	require.Equal(t, "DEFAULT_GENERATED on update CURRENT_TIMESTAMP", extra())

	// Generated column: stored first, then virtual.
	column.SetFlag(0)
	column.GeneratedExprString = "test"
	column.GeneratedStored = true
	require.Equal(t, "STORED GENERATED", extra())
	column.GeneratedStored = false
	require.Equal(t, "VIRTUAL GENERATED", extra())

	// Results of these two calls are not checked.
	ColDescFieldNames(false)
	ColDescFieldNames(true)
}
// TestGetZeroValue checks, for a range of field types, that GetZeroValue
// returns a datum of the expected kind that compares equal to the expected
// zero value under the field type's collation.
func TestGetZeroValue(t *testing.T) {
	unsignedBigint := &types.FieldType{}
	unsignedBigint.SetType(mysql.TypeLonglong)
	unsignedBigint.SetFlag(mysql.UnsignedFlag)

	// newChar2 builds a TypeString field type with flen 2, the given charset
	// and the binary collation.
	newChar2 := func(cs string) *types.FieldType {
		ft := &types.FieldType{}
		ft.SetType(mysql.TypeString)
		ft.SetFlen(2)
		ft.SetCharset(cs)
		ft.SetCollate(charset.CollationBin)
		return ft
	}
	binaryChar2 := newChar2(charset.CharsetBin)
	utf8mb4Char2 := newChar2(charset.CharsetUTF8MB4)

	cases := []struct {
		fieldType *types.FieldType
		want      types.Datum
	}{
		{fieldType: types.NewFieldType(mysql.TypeLong), want: types.NewIntDatum(0)},
		{fieldType: unsignedBigint, want: types.NewUintDatum(0)},
		{fieldType: types.NewFieldType(mysql.TypeFloat), want: types.NewFloat32Datum(0)},
		{fieldType: types.NewFieldType(mysql.TypeDouble), want: types.NewFloat64Datum(0)},
		{fieldType: types.NewFieldType(mysql.TypeNewDecimal), want: types.NewDecimalDatum(types.NewDecFromInt(0))},
		{fieldType: types.NewFieldType(mysql.TypeVarchar), want: types.NewStringDatum("")},
		{fieldType: types.NewFieldType(mysql.TypeBlob), want: types.NewStringDatum("")},
		{fieldType: types.NewFieldType(mysql.TypeDuration), want: types.NewDurationDatum(types.ZeroDuration)},
		{fieldType: types.NewFieldType(mysql.TypeDatetime), want: types.NewDatum(types.ZeroDatetime)},
		{fieldType: types.NewFieldType(mysql.TypeTimestamp), want: types.NewDatum(types.ZeroTimestamp)},
		{fieldType: types.NewFieldType(mysql.TypeDate), want: types.NewDatum(types.ZeroDate)},
		{fieldType: types.NewFieldType(mysql.TypeBit), want: types.NewMysqlBitDatum(types.ZeroBinaryLiteral)},
		{fieldType: types.NewFieldType(mysql.TypeSet), want: types.NewDatum(types.Set{})},
		{fieldType: types.NewFieldType(mysql.TypeEnum), want: types.NewDatum(types.Enum{})},
		// char(2) with the binary charset expects two zero bytes ...
		{fieldType: binaryChar2, want: types.NewDatum(make([]byte, 2))},
		// ... while char(2) with utf8mb4 expects the empty string.
		{fieldType: utf8mb4Char2, want: types.NewDatum("")},
		{fieldType: types.NewFieldType(mysql.TypeJSON), want: types.NewDatum(types.CreateBinaryJSON(nil))},
	}

	stmtCtx := stmtctx.NewStmtCtx()
	for _, tc := range cases {
		t.Run(fmt.Sprintf("%+v", tc.fieldType), func(t *testing.T) {
			zero := GetZeroValue(&model.ColumnInfo{FieldType: *tc.fieldType})
			require.Equal(t, tc.want.Kind(), zero.Kind())

			collator := collate.GetCollator(tc.fieldType.GetCollate())
			order, err := zero.Compare(stmtCtx.TypeCtx(), &tc.want, collator)
			require.NoError(t, err)
			require.Equal(t, 0, order)
		})
	}
}
// TestCastValue exercises CastValue against an integer column and a string
// column: plain conversions, byte sequences that are rejected for the
// column's charset, and the collation carried by the returned datum.
func TestCastValue(t *testing.T) {
	sctx := mock.NewContext()

	// Integer column: an empty datum casts cleanly; a non-numeric string is
	// reported as an error and the returned datum reads as 0.
	intCol := model.ColumnInfo{
		FieldType: *types.NewFieldType(mysql.TypeLong),
		State:     model.StatePublic,
	}
	intCol.SetCharset(mysql.UTF8Charset)
	got, err := CastValue(sctx, types.Datum{}, &intCol, false, false)
	require.NoError(t, err)
	require.Equal(t, int64(0), got.GetInt64())
	got, err = CastValue(sctx, types.NewDatum("test"), &intCol, false, false)
	require.Error(t, err)
	require.Equal(t, int64(0), got.GetInt64())

	// String column: a plain string casts cleanly.
	strCol := model.ColumnInfo{
		FieldType: *types.NewFieldType(mysql.TypeString),
		State:     model.StatePublic,
	}
	got, err = CastValue(sctx, types.NewDatum("test"), &strCol, false, false)
	require.NoError(t, err)
	require.NotNil(t, got)

	// Each byte sequence is expected to be rejected for its charset when the
	// last CastValue argument is false and accepted when it is true.
	rejected := []struct {
		cs  string
		raw []byte
	}{
		{cs: mysql.UTF8Charset, raw: []byte{0xf0, 0x9f, 0x8c, 0x80}},
		{cs: mysql.UTF8MB4Charset, raw: []byte{0xf0, 0x9f, 0x80}},
		{cs: charset.CharsetASCII, raw: []byte{0x32, 0xf0}},
	}
	for _, tc := range rejected {
		strCol.SetCharset(tc.cs)
		// Hand every call its own copy of the bytes, as separate literals would.
		_, err = CastValue(sctx, types.NewDatum(append([]byte(nil), tc.raw...)), &strCol, false, false)
		require.Error(t, err)
		_, err = CastValue(sctx, types.NewDatum(append([]byte(nil), tc.raw...)), &strCol, false, true)
		require.NoError(t, err)
	}

	// The returned datum carries the column's collation, on success and on
	// error alike.
	const (
		generalCI     = "utf8mb4_general_ci"
		wantIncorrect = "[table:1366]Incorrect string value '\\x81' for column ''"
	)
	strCol.SetCharset(charset.CharsetUTF8MB4)
	strCol.SetCollate(generalCI)
	got, err = CastValue(sctx, types.NewBinaryLiteralDatum([]byte{0xE5, 0xA5, 0xBD}), &strCol, false, false)
	require.NoError(t, err)
	require.Equal(t, generalCI, got.Collation())

	got, err = CastValue(sctx, types.NewBinaryLiteralDatum([]byte{0xE5, 0xA5, 0xBD, 0x81}), &strCol, false, false)
	require.ErrorContains(t, err, wantIncorrect)
	require.Equal(t, generalCI, got.Collation())

	got, err = CastValue(sctx, types.NewDatum([]byte{0xE5, 0xA5, 0xBD, 0x81}), &strCol, false, false)
	require.ErrorContains(t, err, wantIncorrect)
	require.Equal(t, generalCI, got.Collation())
}
// TestGetDefaultValue runs one table of column definitions through both
// GetColDefaultValue and GetColOriginDefaultValue, switching the session
// between strict and non-strict SQL mode per case. expression.EvalSimpleAst
// is replaced for the duration of the test with a stub that always yields the
// integer datum 1.
func TestGetDefaultValue(t *testing.T) {
	var nullDatum types.Datum
	nullDatum.SetNull()

	sctx := mock.NewContext()
	loc, err := time.LoadLocation("America/Los_Angeles")
	require.NoError(t, err)
	sctx.ResetSessionAndStmtTimeZone(loc)

	// localTS is the value expected back from the session; utcTS is the same
	// core time converted from the session location to UTC, and its string
	// form is what the column definition below stores.
	zeroTS := types.ZeroTimestamp
	localTS := types.NewTime(types.FromDate(2019, 5, 6, 12, 48, 49, 0), mysql.TypeTimestamp, types.DefaultFsp)
	utcTS := types.NewTime(localTS.CoreTime(), localTS.Type(), localTS.Fsp())
	require.NoError(t, utcTS.ConvertTimeZone(loc, time.UTC))

	nullableBigint := types.FieldType{}
	nullableBigint.SetType(mysql.TypeLonglong)

	notNullBigint := types.FieldType{}
	notNullBigint.SetType(mysql.TypeLonglong)
	notNullBigint.SetFlag(mysql.NotNullFlag)

	notNullEnum := types.FieldType{}
	notNullEnum.SetType(mysql.TypeEnum)
	notNullEnum.SetFlag(mysql.NotNullFlag)
	notNullEnum.SetElems([]string{"abc", "def"})
	notNullEnum.SetCollate(mysql.DefaultCollationName)

	timestampFT := types.FieldType{}
	timestampFT.SetType(mysql.TypeTimestamp)
	timestampFT.SetFlag(mysql.TimestampFlag)

	autoIncBigint := types.FieldType{}
	autoIncBigint.SetType(mysql.TypeLonglong)
	autoIncBigint.SetFlag(mysql.NotNullFlag | mysql.AutoIncrementFlag)

	cases := []struct {
		col     *model.ColumnInfo
		strict  bool
		want    types.Datum
		wantErr error
	}{
		{
			col: &model.ColumnInfo{
				FieldType:          notNullBigint,
				OriginDefaultValue: 1.0,
				DefaultValue:       1.0,
			},
			strict: false,
			want:   types.NewIntDatum(1),
		},
		{
			col:    &model.ColumnInfo{FieldType: notNullBigint},
			strict: false,
			want:   types.NewIntDatum(0),
		},
		{
			col:    &model.ColumnInfo{FieldType: nullableBigint},
			strict: false,
			want:   types.Datum{},
		},
		{
			col:    &model.ColumnInfo{FieldType: notNullEnum},
			strict: false,
			want:   types.NewMysqlEnumDatum(types.Enum{Name: "abc", Value: 1}),
		},
		{
			col: &model.ColumnInfo{
				FieldType:          timestampFT,
				OriginDefaultValue: "0000-00-00 00:00:00",
				DefaultValue:       "0000-00-00 00:00:00",
			},
			strict: false,
			want:   types.NewDatum(zeroTS),
		},
		{
			col: &model.ColumnInfo{
				FieldType:          timestampFT,
				OriginDefaultValue: utcTS.String(),
				DefaultValue:       utcTS.String(),
				Version:            model.ColumnInfoVersion2,
			},
			strict: true,
			want:   types.NewDatum(localTS),
		},
		{
			col: &model.ColumnInfo{
				FieldType:          timestampFT,
				OriginDefaultValue: "not valid date",
				DefaultValue:       "not valid date",
			},
			strict:  true,
			want:    types.NewDatum(zeroTS),
			wantErr: errGetDefaultFailed,
		},
		{
			col:     &model.ColumnInfo{FieldType: notNullBigint},
			strict:  true,
			want:    types.NewDatum(zeroTS),
			wantErr: ErrNoDefaultValue,
		},
		{
			col:    &model.ColumnInfo{FieldType: autoIncBigint},
			strict: true,
			want:   types.NewIntDatum(0),
		},
		{
			col: &model.ColumnInfo{
				FieldType:     notNullBigint,
				DefaultIsExpr: true,
				DefaultValue:  "1",
			},
			strict: false,
			want:   nullDatum,
		},
	}

	// Stub out expression evaluation; the original is put back on return.
	origEval := expression.EvalSimpleAst
	defer func() {
		expression.EvalSimpleAst = origEval
	}()
	expression.EvalSimpleAst = func(expression.BuildContext, ast.ExprNode) (types.Datum, error) {
		return types.NewIntDatum(1), nil
	}

	strictMode, err := mysql.GetSQLMode(mysql.DefaultSQLMode)
	require.NoError(t, err)
	require.True(t, strictMode.HasStrictMode())
	lenientMode := mysql.DelSQLMode(strictMode, mysql.ModeStrictAllTables|mysql.ModeStrictTransTables)

	// configure sets the session SQL mode to the default mode (strict) or to
	// the default mode with both strict bits removed, and sets the statement
	// context's ErrGroupNoDefault level from ResolveErrLevel(false, !strict).
	configure := func(strict bool) {
		vars := sctx.GetSessionVars()
		if strict {
			vars.SQLMode = strictMode
		} else {
			vars.SQLMode = lenientMode
		}
		stmtCtx := vars.StmtCtx
		levels := stmtCtx.ErrLevels()
		levels[errctx.ErrGroupNoDefault] = errctx.ResolveErrLevel(false, !strict)
		stmtCtx.SetErrLevels(levels)
	}

	// Current default value. Columns with DefaultIsExpr are compared against
	// the stub's result rather than the table's expected value.
	for _, tc := range cases {
		configure(tc.strict)
		got, err := GetColDefaultValue(sctx, tc.col)
		if err != nil {
			require.Errorf(t, tc.wantErr, "%v", err)
			continue
		}
		want := tc.want
		if tc.col.DefaultIsExpr {
			want = types.NewIntDatum(1)
		}
		require.Equal(t, want, got)
	}

	// Origin default value. Columns with DefaultIsExpr are not compared.
	for _, tc := range cases {
		configure(tc.strict)
		got, err := GetColOriginDefaultValue(sctx, tc.col)
		if err != nil {
			require.Errorf(t, tc.wantErr, "%v", err)
			continue
		}
		if tc.col.DefaultIsExpr {
			continue
		}
		require.Equal(t, tc.want, got)
	}
}
// newCol returns a Column built by ToColumn from a ColumnInfo that has the
// given name and the public state; no other fields are set.
func newCol(name string) *Column {
	info := model.ColumnInfo{
		Name:  ast.NewCIStr(name),
		State: model.StatePublic,
	}
	return ToColumn(&info)
}
// TestCastValueStrict feeds CastColumnValueWithStrictMode values that do and
// do not fit the target field type. For the rejected inputs it also pins the
// value that comes back alongside the error.
func TestCastValueStrict(t *testing.T) {
	// Negative signed value into an unsigned bigint: error, result reads 0.
	unsignedBigint := types.NewFieldType(mysql.TypeLonglong)
	unsignedBigint.AddFlag(mysql.UnsignedFlag)
	got, err := CastColumnValueWithStrictMode(types.NewIntDatum(-1), unsignedBigint)
	require.Error(t, err)
	require.EqualValues(t, 0, got.GetUint64())

	// Signed value into a signed bigint: unchanged.
	one := types.NewIntDatum(1)
	got, err = CastColumnValueWithStrictMode(one, types.NewFieldType(mysql.TypeLonglong))
	require.NoError(t, err)
	require.EqualValues(t, one.GetInt64(), got.GetInt64())

	// 1<<40 into an int column: error, result reads math.MaxInt32.
	tooWide := types.NewIntDatum(1 << 40)
	got, err = CastColumnValueWithStrictMode(tooWide, types.NewFieldType(mysql.TypeLong))
	require.Error(t, err)
	require.EqualValues(t, math.MaxInt32, got.GetInt64())

	// 1<<16 into a bigint column: unchanged.
	small := types.NewIntDatum(1 << 16)
	got, err = CastColumnValueWithStrictMode(small, types.NewFieldType(mysql.TypeLonglong))
	require.NoError(t, err)
	require.EqualValues(t, small.GetInt64(), got.GetInt64())

	// Four characters into char(2): error, result is the first two characters.
	char2 := types.NewFieldType(mysql.TypeString)
	char2.SetFlen(2)
	got, err = CastColumnValueWithStrictMode(types.NewStringDatum("abcd"), char2)
	require.Error(t, err)
	require.Equal(t, "ab", got.GetString())

	// "a" followed by a space into a fresh char(2): no error, result is "a".
	char2 = types.NewFieldType(mysql.TypeString)
	char2.SetFlen(2)
	got, err = CastColumnValueWithStrictMode(types.NewStringDatum("a "), char2)
	require.NoError(t, err)
	require.Equal(t, "a", got.GetString())
}