Files
tidb/pkg/executor/internal/vecgroupchecker/vec_group_checker_test.go
2025-05-07 14:32:26 +00:00

281 lines
8.0 KiB
Go

// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package vecgroupchecker
import (
"fmt"
"math/rand"
"testing"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/mock"
"github.com/stretchr/testify/require"
)
// TestVecGroupCheckerDATARACE checks that the first/last row datums recorded by
// SplitIntoGroups keep their values when the column that was handed out as the
// evaluation buffer is overwritten afterwards. It is exercised for string,
// decimal and JSON columns.
func TestVecGroupCheckerDATARACE(t *testing.T) {
	sctx := mock.NewContext()

	// jsonOf builds the BinaryJSON document {"<n>": <n>}.
	jsonOf := func(n int) types.BinaryJSON {
		var bj types.BinaryJSON
		require.NoError(t, bj.UnmarshalJSON(fmt.Appendf(nil, `{"%v":%v}`, n, n)))
		return bj
	}

	for _, tp := range []byte{mysql.TypeVarString, mysql.TypeNewDecimal, mysql.TypeJSON} {
		groupBy := []expression.Expression{
			&expression.Column{
				RetType: types.NewFieldTypeBuilder().SetType(tp).BuildP(),
				Index:   0,
			},
		}
		vgc := NewVecGroupChecker(sctx, sctx.GetSessionVars().EnableVectorizedExpression, groupBy)
		chk := chunk.New([]*types.FieldType{types.NewFieldType(tp)}, 1, 1)

		// The checker is given the chunk's own column as its scratch buffer, so
		// any later write to that column also rewrites the buffer's contents.
		vgc.allocateBuffer = func(types.EvalType, int) (*chunk.Column, error) {
			return chk.Column(0), nil
		}
		vgc.releaseBuffer = func(*chunk.Column) {}

		// seed writes the initial single row, overwrite replaces it after the
		// split, and verify asserts that the recorded datums still hold the
		// seeded value.
		var seed, overwrite, verify func()
		switch tp {
		case mysql.TypeVarString:
			seed = func() {
				chk.Column(0).ReserveString(1)
				chk.Column(0).AppendString("abc")
			}
			overwrite = func() {
				chk.Column(0).ReserveString(1)
				chk.Column(0).AppendString("edf")
			}
			verify = func() {
				require.Equal(t, "abc", vgc.firstRowDatums[0].GetString())
				require.Equal(t, "abc", vgc.lastRowDatums[0].GetString())
			}
		case mysql.TypeNewDecimal:
			seed = func() {
				chk.Column(0).ResizeDecimal(1, false)
				chk.Column(0).Decimals()[0] = *types.NewDecFromInt(123)
			}
			overwrite = func() {
				chk.Column(0).ResizeDecimal(1, false)
				chk.Column(0).Decimals()[0] = *types.NewDecFromInt(456)
			}
			verify = func() {
				require.Equal(t, "123", vgc.firstRowDatums[0].GetMysqlDecimal().String())
				require.Equal(t, "123", vgc.lastRowDatums[0].GetMysqlDecimal().String())
			}
		case mysql.TypeJSON:
			seed = func() {
				chk.Column(0).ReserveJSON(1)
				chk.Column(0).AppendJSON(jsonOf(123))
			}
			overwrite = func() {
				chk.Column(0).ReserveJSON(1)
				chk.Column(0).AppendJSON(jsonOf(456))
			}
			verify = func() {
				require.Equal(t, `{"123": 123}`, vgc.firstRowDatums[0].GetMysqlJSON().String())
				require.Equal(t, `{"123": 123}`, vgc.lastRowDatums[0].GetMysqlJSON().String())
			}
		}

		seed()
		_, err := vgc.SplitIntoGroups(chk)
		require.NoError(t, err)

		// The datums must read the same both before and after the column is
		// overwritten.
		verify()
		overwrite()
		verify()
	}
}
// genTestChunk4VecGroupChecker builds the input for the group-count tests.
//
// It creates one single-column BIGINT chunk per entry of chkRows, where entry i
// is the number of rows in chunk i. Rows are filled in order across all chunks:
// every run of sameNum consecutive rows shares one random value, and one
// randomly chosen run is set to NULL instead. The returned expr is a single
// column expression referring to column 0 of those chunks.
func genTestChunk4VecGroupChecker(chkRows []int, sameNum int) (expr []expression.Expression, inputs []*chunk.Chunk) {
	fieldTypes := []*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}

	inputs = make([]*chunk.Chunk, len(chkRows))
	totalRows := 0
	for idx, rows := range chkRows {
		inputs[idx] = chunk.New(fieldTypes, rows, rows)
		totalRows += rows
	}

	// Number of runs, rounding up so a trailing partial run is counted.
	groupTotal := totalRows / sameNum
	if totalRows%sameNum != 0 {
		groupTotal++
	}

	// runsUntilNull counts down at every run boundary; the run during which it
	// equals zero is the NULL run.
	runsUntilNull := rand.Intn(groupTotal)
	filledInRun := 0
	current := rand.Int63()
	for idx, rows := range chkRows {
		col := inputs[idx].Column(0)
		col.ResizeInt64(rows, false)
		vals := col.Int64s()
		for r := 0; r < rows; r++ {
			if filledInRun == sameNum {
				// Start a new run with a fresh value.
				current, filledInRun = rand.Int63(), 0
				runsUntilNull--
			}
			if runsUntilNull != 0 {
				vals[r] = current
			} else {
				col.SetNull(r, true)
			}
			filledInRun++
		}
	}

	expr = []expression.Expression{
		&expression.Column{
			RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLonglong).SetFlen(mysql.MaxIntWidth).BuildP(),
			Index:   0,
		},
	}
	return expr, inputs
}
// TestVecGroupChecker4GroupCount feeds generated chunks through a checker one
// at a time and verifies both the flag returned by SplitIntoGroups for each
// chunk and the total number of distinct groups seen across all chunks.
func TestVecGroupChecker4GroupCount(t *testing.T) {
	// chunkRows gives the row count of each chunk, sameNum the number of
	// consecutive rows sharing a value, expectedFlag the flag expected from
	// SplitIntoGroups per chunk, and expectedGroups the overall group total.
	cases := []struct {
		chunkRows      []int
		expectedGroups int
		expectedFlag   []bool
		sameNum        int
	}{
		{chunkRows: []int{1024, 1}, sameNum: 1, expectedGroups: 1025, expectedFlag: []bool{false, false}},
		{chunkRows: []int{1024, 1}, sameNum: 1025, expectedGroups: 1, expectedFlag: []bool{false, true}},
		{chunkRows: []int{1, 1}, sameNum: 2, expectedGroups: 1, expectedFlag: []bool{false, true}},
		{chunkRows: []int{1, 1}, sameNum: 1, expectedGroups: 2, expectedFlag: []bool{false, false}},
		{chunkRows: []int{2, 2}, sameNum: 2, expectedGroups: 2, expectedFlag: []bool{false, false}},
		{chunkRows: []int{2, 2}, sameNum: 4, expectedGroups: 1, expectedFlag: []bool{false, true}},
	}

	sctx := mock.NewContext()
	for _, tc := range cases {
		groupBy, chunks := genTestChunk4VecGroupChecker(tc.chunkRows, tc.sameNum)
		checker := NewVecGroupChecker(sctx, sctx.GetSessionVars().EnableVectorizedExpression, groupBy)

		total := 0
		for idx := range chunks {
			continued, err := checker.SplitIntoGroups(chunks[idx])
			require.NoError(t, err)
			require.Equal(t, tc.expectedFlag[idx], continued)

			total += checker.GroupCount()
			if continued {
				// The cases expect the flag to be set when a chunk's first
				// group carries on from the previous chunk, so that group has
				// already been counted once.
				total--
			}
		}
		require.Equal(t, tc.expectedGroups, total)
	}
}
// TestVecGroupChecker verifies how a varchar column is split into groups under
// different collations, and that trailing-space padding does not separate
// otherwise equal values under utf8_bin.
func TestVecGroupChecker(t *testing.T) {
	tp := types.NewFieldTypeBuilder().SetType(mysql.TypeVarchar).BuildP()
	col0 := &expression.Column{
		RetType: tp,
		Index:   0,
	}
	ctx := mock.NewContext()
	groupChecker := NewVecGroupChecker(ctx, ctx.GetSessionVars().EnableVectorizedExpression, []expression.Expression{col0})

	const totalRows = 6
	chk := chunk.New([]*types.FieldType{tp}, totalRows, totalRows)
	chk.Reset()
	chk.Column(0).AppendString("aaa")
	chk.Column(0).AppendString("AAA")
	chk.Column(0).AppendString("😜")
	chk.Column(0).AppendString("😃")
	chk.Column(0).AppendString("À")
	chk.Column(0).AppendString("A")

	// assertUniformGroups switches the shared field type to the given collation,
	// re-splits the chunk, and expects the rows to fall into consecutive groups
	// of exactly groupSize rows each, after which the checker is exhausted.
	assertUniformGroups := func(collation string, groupSize int) {
		tp.SetCollate(collation)
		groupChecker.Reset()
		_, err := groupChecker.SplitIntoGroups(chk)
		require.NoError(t, err)
		for i := range totalRows / groupSize {
			b, e := groupChecker.GetNextGroup()
			// testify's Equal takes the expected value first, then the actual one.
			require.Equal(t, i*groupSize, b)
			require.Equal(t, (i+1)*groupSize, e)
		}
		require.True(t, groupChecker.IsExhausted())
	}

	// Under "bin" every row is its own group.
	assertUniformGroups("bin", 1)
	// Under the case-insensitive collations each adjacent pair forms one group.
	assertUniformGroups("utf8_general_ci", 2)
	assertUniformGroups("utf8_unicode_ci", 2)

	// test padding: the three values are expected to land in a single group.
	tp.SetCollate("utf8_bin")
	tp.SetFlen(6)
	chk.Reset()
	chk.Column(0).AppendString("a")
	chk.Column(0).AppendString("a ")
	chk.Column(0).AppendString("a ")
	groupChecker.Reset()
	_, err := groupChecker.SplitIntoGroups(chk)
	require.NoError(t, err)
	b, e := groupChecker.GetNextGroup()
	require.Equal(t, 0, b)
	require.Equal(t, 3, e)
	require.True(t, groupChecker.IsExhausted())
}
// TestIssue53867 checks that Reset returns a checker that still has pending
// groups to the exhausted state.
func TestIssue53867(t *testing.T) {
	vgc := NewVecGroupChecker(nil, true, nil)

	// Put the checker part-way through its groups: 10 consumed out of 15.
	vgc.groupOffset = make([]int, 20)
	vgc.nextGroupID, vgc.groupCount = 10, 15
	require.False(t, vgc.IsExhausted())

	vgc.Reset()
	require.True(t, vgc.IsExhausted())
}