// Source: tidb/pkg/executor/join/hash_table_v2_test.go (287 lines, 8.5 KiB, Go).


// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package join
import (
"math/rand"
"testing"
"unsafe"
"github.com/pingcap/tidb/pkg/executor/internal/testutil"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util"
"github.com/pingcap/tidb/pkg/util/mock"
"github.com/stretchr/testify/require"
)
// createMockRowTable builds a synthetic rowTable with segmentCount segments,
// for tests that only care about row locations and not real row contents.
// Each segment gets a rawData buffer with one byte per row. Row j of a segment
// starts at byte offset j and is listed in validJoinKeyPos.
//
// When fixedSize is true, every segment holds exactly maxRowsPerSeg rows.
// Otherwise each segment holds a random count in [1, maxRowsPerSeg], so no
// segment is empty. In that case maxRowsPerSeg must be > 0, because
// rand.Int31n panics for n <= 0.
func createMockRowTable(maxRowsPerSeg int, segmentCount int, fixedSize bool) *rowTable {
	table := &rowTable{}
	for segIdx := 0; segIdx < segmentCount; segIdx++ {
		rowsInSeg := maxRowsPerSeg
		if !fixedSize {
			// Int31n yields [0, n); shift by one so a random segment is never empty.
			rowsInSeg = int(rand.Int31n(int32(maxRowsPerSeg))) + 1
		}
		seg := newRowTableSegment()
		seg.rawData = make([]byte, rowsInSeg)
		for offset := 0; offset < rowsInSeg; offset++ {
			seg.rowStartOffset = append(seg.rowStartOffset, uint64(offset))
			seg.validJoinKeyPos = append(seg.validJoinKeyPos, offset)
		}
		table.segments = append(table.segments, seg)
	}
	return table
}
// createRowTable builds a real rowTable from `rows` randomly generated rows.
// The rows have a single NOT NULL TINYINT column, which is also the only join
// key on both the build and probe side. They are fed through a rowTableBuilder
// with Concurrency = 1.
//
// It returns the row table registered for worker 0 / partition 0. It also
// returns the minimum taggedBits over that table's segments, or maxTaggedBits
// if the table has no segments. Any error from processOneChunk is returned
// unchanged.
//
// NOTE(review): assumes SetupPartitionInfo yields a single partition when
// Concurrency is 1, so rowTables[0][0] holds every generated row. Confirm.
func createRowTable(rows int) (*rowTable, uint8, error) {
	tinyTp := types.NewFieldType(mysql.TypeTiny)
	tinyTp.AddFlag(mysql.NotNullFlag)
	buildKeyIndex := []int{0}
	buildTypes := []*types.FieldType{tinyTp}
	buildKeyTypes := []*types.FieldType{tinyTp}
	probeKeyTypes := []*types.FieldType{tinyTp}
	buildSchema := &expression.Schema{}
	for _, tp := range buildTypes {
		buildSchema.Append(&expression.Column{
			RetType: tp,
		})
	}
	meta := newTableMeta(buildKeyIndex, buildTypes, buildKeyTypes, probeKeyTypes, nil, []int{}, false)
	// True if any build key column lacks the NOT NULL flag; forwarded to the builder.
	hasNullableKey := false
	for _, buildKeyType := range buildKeyTypes {
		if !mysql.HasNotNullFlag(buildKeyType.GetFlag()) {
			hasNullableKey = true
			break
		}
	}
	chk := testutil.GenRandomChunks(buildTypes, rows)
	hashJoinCtx := &HashJoinCtxV2{
		hashTableMeta: meta,
	}
	hashJoinCtx.Concurrency = 1
	hashJoinCtx.SetupPartitionInfo()
	hashJoinCtx.initHashTableContext()
	hashJoinCtx.SessCtx = mock.NewContext()
	builder := createRowTableBuilder(buildKeyIndex, buildKeyTypes, hashJoinCtx.partitionNumber, hasNullableKey, false, false, meta.nullMapLength)
	err := builder.processOneChunk(chk, hashJoinCtx.SessCtx.GetSessionVars().StmtCtx.TypeCtx(), hashJoinCtx, 0)
	if err != nil {
		return nil, 0, err
	}
	// Flush the rows still buffered in the builder's current, partially filled
	// segment into hashTableContext. The build worker does the same after its
	// last chunk. Without this flush, trailing rows are missing from the
	// returned table. For small inputs rowTables[0][0] may not be populated at
	// all.
	builder.appendRemainingRowLocations(0, hashJoinCtx.hashTableContext)
	rt := hashJoinCtx.hashTableContext.rowTables[0][0]
	taggedBits := uint8(maxTaggedBits)
	for _, seg := range rt.segments {
		taggedBits = min(taggedBits, seg.taggedBits)
	}
	return rt, taggedBits, nil
}
func TestHashTableSize(t *testing.T) {
rowTable := createMockRowTable(2, 5, true)
subTable := newSubTable(rowTable)
// min hash table size is 32
require.Equal(t, 32, len(subTable.hashTable))
rowTable = createMockRowTable(32, 1, true)
subTable = newSubTable(rowTable)
require.Equal(t, 64, len(subTable.hashTable))
rowTable = createMockRowTable(33, 1, true)
subTable = newSubTable(rowTable)
require.Equal(t, 64, len(subTable.hashTable))
rowTable = createMockRowTable(64, 1, true)
subTable = newSubTable(rowTable)
require.Equal(t, 128, len(subTable.hashTable))
rowTable = createMockRowTable(65, 1, true)
subTable = newSubTable(rowTable)
require.Equal(t, 128, len(subTable.hashTable))
}
// TestBuild builds a subTable on a single thread. It then checks that walking
// every bucket chain visits each row of the row table exactly once.
func TestBuild(t *testing.T) {
	rt, taggedBits, err := createRowTable(1000000)
	require.NoError(t, err)
	tagHelper := &tagPtrHelper{}
	tagHelper.init(taggedBits)
	subTable := newSubTable(rt)
	// single thread build covering every segment
	subTable.build(0, len(rt.segments), tagHelper)

	// Record the address of every row; addresses must be unique.
	expected := make(map[unsafe.Pointer]struct{}, rt.rowCount())
	for _, seg := range rt.segments {
		for idx := range seg.rowStartOffset {
			ptr := seg.getRowPointer(idx)
			_, dup := expected[ptr]
			require.False(t, dup)
			expected[ptr] = struct{}{}
		}
	}

	visited := uint64(0)
	for _, head := range subTable.hashTable {
		for cur := head; cur != 0; {
			visited++
			ptr := tagHelper.toUnsafePointer(cur)
			_, ok := expected[ptr]
			require.True(t, ok)
			delete(expected, ptr)
			// hash value 0 keeps getNextRowAddress from stopping early
			cur = getNextRowAddress(ptr, tagHelper, 0)
		}
	}
	require.Equal(t, 0, len(expected))
	require.Equal(t, rt.rowCount(), visited)
}
// TestConcurrentBuild builds one subTable from three goroutines. Each
// goroutine inserts a contiguous range of segments, and the last one takes
// any remainder. The test then checks that the bucket chains contain every
// row exactly once.
func TestConcurrentBuild(t *testing.T) {
	rt, tagBits, err := createRowTable(3000000)
	require.NoError(t, err)
	subTable := newSubTable(rt)
	tagHelper := &tagPtrHelper{}
	tagHelper.init(tagBits)

	const buildThreads = 3
	segmentCount := len(rt.segments)
	segsPerThread := segmentCount / buildThreads
	var wg util.WaitGroupWrapper
	for worker := 0; worker < buildThreads; worker++ {
		start := segsPerThread * worker
		end := segsPerThread * (worker + 1)
		if worker == buildThreads-1 {
			end = segmentCount
		}
		wg.Run(func() {
			subTable.build(start, end, tagHelper)
		})
	}
	wg.Wait()

	// Record the address of every row; addresses must be unique.
	expected := make(map[unsafe.Pointer]struct{}, rt.rowCount())
	for _, seg := range rt.segments {
		for idx := range seg.rowStartOffset {
			ptr := seg.getRowPointer(idx)
			_, dup := expected[ptr]
			require.False(t, dup)
			expected[ptr] = struct{}{}
		}
	}

	for _, head := range subTable.hashTable {
		for cur := head; cur != 0; {
			ptr := tagHelper.toUnsafePointer(cur)
			_, ok := expected[ptr]
			require.True(t, ok)
			delete(expected, ptr)
			cur = getNextRowAddress(ptr, tagHelper, 0)
		}
	}
	require.Equal(t, 0, len(expected))
}
// TestLookup builds a subTable on a single thread. For every row, it checks
// that the row is reachable from the chain returned by lookup for that row's
// hash value.
func TestLookup(t *testing.T) {
	rt, tagBits, err := createRowTable(200000)
	require.NoError(t, err)
	tagHelper := &tagPtrHelper{}
	tagHelper.init(tagBits)
	subTable := newSubTable(rt)
	// single thread build covering every segment
	subTable.build(0, len(rt.segments), tagHelper)

	for _, seg := range rt.segments {
		for idx := range seg.rowStartOffset {
			hashValue := seg.hashValues[idx]
			cand := subTable.lookup(hashValue, tagHelper)
			want := seg.getRowPointer(idx)
			found := false
			for !found && cand != 0 {
				ptr := tagHelper.toUnsafePointer(cand)
				if ptr == want {
					found = true
				} else {
					cand = getNextRowAddress(ptr, tagHelper, hashValue)
				}
			}
			require.True(t, found)
		}
	}
}
// checkRowIter splits the rows of table into scanConcurrency contiguous index
// ranges; the last range absorbs any remainder. It creates one rowIter per
// range and asserts that, together, the iterators yield every row location in
// table exactly once.
func checkRowIter(t *testing.T, table *hashTableV2, scanConcurrency int) {
	totalRowCount := table.totalRowCount()

	// Record every row location across all partitions; locations must be unique.
	remaining := make(map[unsafe.Pointer]struct{}, totalRowCount)
	for _, sub := range table.tables {
		for _, seg := range sub.rowData.segments {
			for idx := range seg.rowStartOffset {
				ptr := seg.getRowPointer(idx)
				_, dup := remaining[ptr]
				require.False(t, dup)
				remaining[ptr] = struct{}{}
			}
		}
	}

	iters := make([]*rowIter, 0, scanConcurrency)
	scanners := uint64(scanConcurrency)
	step := totalRowCount / scanners
	for i := uint64(0); i < scanners; i++ {
		begin, end := step*i, step*(i+1)
		if i == uint64(scanConcurrency-1) {
			end = totalRowCount
		}
		iters = append(iters, table.createRowIter(begin, end))
	}

	visited := uint64(0)
	for _, it := range iters {
		for ; !it.isEnd(); it.next() {
			visited++
			ptr := it.getValue()
			_, ok := remaining[ptr]
			require.True(t, ok)
			delete(remaining, ptr)
		}
	}
	require.Equal(t, table.totalRowCount(), visited)
	require.Equal(t, 0, len(remaining))
}
// TestRowIter runs checkRowIter on joined hash tables with 1, 4 and 8
// partitions. It first uses tables where every partition has rows. It then
// repeats with each partition in turn replaced by an empty row table.
func TestRowIter(t *testing.T) {
	partitionNumbers := []int{1, 4, 8}

	// makeTables creates partitionNumber mock row tables. The partition at
	// emptyIdx (if in range) gets an empty row table instead.
	makeTables := func(partitionNumber int, emptyIdx int) []*rowTable {
		tables := make([]*rowTable, 0, partitionNumber)
		for idx := 0; idx < partitionNumber; idx++ {
			if idx == emptyIdx {
				tables = append(tables, createMockRowTable(0, 0, true))
				continue
			}
			tables = append(tables, createMockRowTable(1024, 16, false))
		}
		return tables
	}

	// normal case: no partition is empty
	for _, partitionNumber := range partitionNumbers {
		joined := newJoinHashTableForTest(makeTables(partitionNumber, -1))
		checkRowIter(t, joined, partitionNumber)
	}
	// case with empty row table: partition i is empty
	for _, partitionNumber := range partitionNumbers {
		for i := 0; i < partitionNumber; i++ {
			joined := newJoinHashTableForTest(makeTables(partitionNumber, i))
			checkRowIter(t, joined, partitionNumber)
		}
	}
}