Files
tidb/pkg/executor/join/hash_join_test_util.go

322 lines
9.5 KiB
Go

// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package join
import (
"context"
"math/rand"
"sort"
"strconv"
"testing"
"github.com/pingcap/tidb/pkg/executor/internal/exec"
"github.com/pingcap/tidb/pkg/executor/internal/testutil"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/planner/core/base"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/stretchr/testify/require"
)
// letterBytes is the alphabet (ASCII lowercase letters followed by ASCII
// uppercase letters) from which getRandString picks each byte.
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
// hashJoinInfo bundles everything buildHashJoinV2Exec needs to assemble a
// HashJoinV2Exec for a test.
type hashJoinInfo struct {
// ctx is handed to exec.NewBaseExecutor and stored as HashJoinCtxV2.SessCtx.
ctx sessionctx.Context
// schema is the schema passed to exec.NewBaseExecutor for the join executor.
schema *expression.Schema
// leftExec and rightExec are the two child executors; which one is the build
// side and which is the probe side is decided by rightAsBuildSide.
leftExec, rightExec exec.Executor
// joinType is copied to HashJoinCtxV2.JoinType and passed to NewJoinProbe.
joinType base.JoinType
// rightAsBuildSide selects rightExec as the build side (and leftExec as the
// probe side) when true, and the reverse when false.
rightAsBuildSide bool
// buildKeys and probeKeys are the join key columns of the build and probe
// side; their Index fields become the key column indexes and their RetType
// flags are copied onto the derived key types.
buildKeys []*expression.Column
probeKeys []*expression.Column
// conditions is assigned to HashJoinCtxV2.BuildFilter.
conditions expression.CNFExprs
// lUsed and rUsed are copied verbatim to the executor's LUsed and RUsed.
// NOTE(review): presumably the used column offsets of each side — confirm.
lUsed []int
rUsed []int
// otherCondition is assigned to HashJoinCtxV2.OtherCondition.
otherCondition expression.CNFExprs
// lUsedInOtherCondition and rUsedInOtherCondition are copied verbatim to the
// executor fields of the same (exported) names.
lUsedInOtherCondition []int
rUsedInOtherCondition []int
// equalConditions supplies, position by position, the charset and collation
// applied to the left and right join key types.
equalConditions []*expression.ScalarFunction
// fileNamePrefixForTest is copied to HashJoinV2Exec.FileNamePrefixForTest.
fileNamePrefixForTest string
}
// buildHashJoinV2Exec assembles a HashJoinV2Exec from the description in info:
// it creates the executor with its shared HashJoinCtxV2, picks the build and
// probe side, derives the join key column indexes and key types, and creates
// one probe worker and one build worker per unit of concurrency.
//
// The concurrency (3) and partitionNumber (4) are hard-coded by this helper.
// The returned executor has not been opened.
func buildHashJoinV2Exec(info *hashJoinInfo) *HashJoinV2Exec {
concurrency := 3
e := &HashJoinV2Exec{
BaseExecutor: exec.NewBaseExecutor(info.ctx, info.schema, 0, info.leftExec, info.rightExec),
ProbeSideTupleFetcher: &ProbeSideTupleFetcherV2{},
ProbeWorkers: make([]*ProbeWorkerV2, concurrency),
BuildWorkers: make([]*BuildWorkerV2, concurrency),
HashJoinCtxV2: &HashJoinCtxV2{
OtherCondition: info.otherCondition,
partitionNumber: 4,
},
FileNamePrefixForTest: info.fileNamePrefixForTest,
}
// Fill in the rest of the shared join context. SetupPartitionInfo is called
// only after Concurrency has been assigned; keep this order.
e.HashJoinCtxV2.SessCtx = info.ctx
e.HashJoinCtxV2.JoinType = info.joinType
e.HashJoinCtxV2.Concurrency = uint(concurrency)
// Note: info.conditions ends up as the build-side filter.
e.HashJoinCtxV2.BuildFilter = info.conditions
e.HashJoinCtxV2.SetupPartitionInfo()
// NOTE(review): AllocPool is presumably promoted from the embedded
// BaseExecutor created above — confirm.
e.ChunkAllocPool = e.AllocPool
e.RightAsBuildSide = info.rightAsBuildSide
// joinedTypes is the left child's column types followed by the right child's.
lhsTypes, rhsTypes := exec.RetTypes(info.leftExec), exec.RetTypes(info.rightExec)
joinedTypes := make([]*types.FieldType, 0, len(lhsTypes)+len(rhsTypes))
joinedTypes = append(joinedTypes, lhsTypes...)
joinedTypes = append(joinedTypes, rhsTypes...)
// Choose the build-side executor; the other child feeds the probe fetcher.
var buildSideExec exec.Executor
if e.RightAsBuildSide {
buildSideExec = info.rightExec
e.ProbeSideTupleFetcher.ProbeSideExec = info.leftExec
} else {
buildSideExec = info.leftExec
e.ProbeSideTupleFetcher.ProbeSideExec = info.rightExec
}
// Key column indexes are taken from the Index field of each key column.
probeKeyColIdx := make([]int, len(info.probeKeys))
buildKeyColIdx := make([]int, len(info.buildKeys))
for i := range info.buildKeys {
buildKeyColIdx[i] = info.buildKeys[i].Index
}
for i := range info.probeKeys {
probeKeyColIdx[i] = info.probeKeys[i].Index
}
e.LUsed = info.lUsed
e.RUsed = info.rUsed
e.LUsedInOtherCondition = info.lUsedInOtherCondition
e.RUsedInOtherCondition = info.rUsedInOtherCondition
// Map build/probe keys back to left/right keys so the key types below can be
// derived per child executor.
var leftJoinKeys, rightJoinKeys []*expression.Column
if e.RightAsBuildSide {
rightJoinKeys = info.buildKeys
leftJoinKeys = info.probeKeys
} else {
rightJoinKeys = info.probeKeys
leftJoinKeys = info.buildKeys
}
// NOTE(review): these are the same exec.RetTypes calls that produced
// lhsTypes/rhsTypes above.
leftExecTypes, rightExecTypes := exec.RetTypes(info.leftExec), exec.RetTypes(info.rightExec)
leftTypes, rightTypes := make([]*types.FieldType, 0, len(leftJoinKeys)), make([]*types.FieldType, 0, len(rightJoinKeys))
// Each key type is a clone of the child's column type (so the child's own
// type is not mutated) with its flag replaced by the key column's flag.
for i, col := range leftJoinKeys {
leftTypes = append(leftTypes, leftExecTypes[col.Index].Clone())
leftTypes[i].SetFlag(col.RetType.GetFlag())
}
for i, col := range rightJoinKeys {
rightTypes = append(rightTypes, rightExecTypes[col.Index].Clone())
rightTypes[i].SetFlag(col.RetType.GetFlag())
}
// Apply the charset/collation of the i-th equal condition to the i-th key
// type on both sides. This indexes leftTypes/rightTypes directly, so it
// requires len(info.equalConditions) not to exceed the number of join keys.
for i := range info.equalConditions {
chs, coll := info.equalConditions[i].CharsetAndCollation()
leftTypes[i].SetCharset(chs)
leftTypes[i].SetCollate(coll)
rightTypes[i].SetCharset(chs)
rightTypes[i].SetCollate(coll)
}
if e.RightAsBuildSide {
e.BuildKeyTypes, e.ProbeKeyTypes = rightTypes, leftTypes
} else {
e.BuildKeyTypes, e.ProbeKeyTypes = leftTypes, rightTypes
}
// Create one probe worker and one build worker per worker ID, all sharing
// the same HashJoinCtxV2.
for i := range concurrency {
e.ProbeWorkers[i] = &ProbeWorkerV2{
HashJoinCtx: e.HashJoinCtxV2,
JoinProbe: NewJoinProbe(e.HashJoinCtxV2, uint(i), info.joinType, probeKeyColIdx, joinedTypes, e.ProbeKeyTypes, e.RightAsBuildSide),
}
e.ProbeWorkers[i].WorkerID = uint(i)
e.BuildWorkers[i] = NewJoinBuildWorkerV2(e.HashJoinCtxV2, uint(i), buildSideExec, buildKeyColIdx, exec.RetTypes(buildSideExec))
}
return e
}
// generateCMPFunc returns a row comparator for rows laid out as fieldTypes.
//
// The comparator walks the columns in order, comparing column i of both rows
// with the compare function chunk.GetCompareFunc yields for fieldTypes[i]. It
// returns the first non-zero column result, or 0 when every column compares
// equal.
func generateCMPFunc(fieldTypes []*types.FieldType) func(chunk.Row, chunk.Row) int {
	// Resolve the per-column compare functions once, up front.
	columnCmps := make([]chunk.CompareFunc, len(fieldTypes))
	for idx, ft := range fieldTypes {
		columnCmps[idx] = chunk.GetCompareFunc(ft)
	}

	return func(lhs, rhs chunk.Row) int {
		for col := range columnCmps {
			if res := columnCmps[col](lhs, col, rhs, col); res != 0 {
				return res
			}
		}
		return 0
	}
}
// sortRows flattens every row of chunks into a single slice and sorts that
// slice in ascending order using the comparator generateCMPFunc builds for
// fieldTypes. The sort is performed with sort.Slice, so it is not stable.
func sortRows(chunks []*chunk.Chunk, fieldTypes []*types.FieldType) []chunk.Row {
	rowCmp := generateCMPFunc(fieldTypes)

	// Count the rows first so the flattened slice is allocated exactly once.
	total := 0
	for idx := range chunks {
		total += chunks[idx].NumRows()
	}

	flattened := make([]chunk.Row, 0, total)
	for idx := range chunks {
		it := chunk.NewIterator4Chunk(chunks[idx])
		row := it.Begin()
		for row != it.End() {
			flattened = append(flattened, row)
			row = it.Next()
		}
	}

	less := func(a, b int) bool {
		return rowCmp(flattened[a], flattened[b]) < 0
	}
	sort.Slice(flattened, less)
	return flattened
}
// buildJoinKeyIntDatums returns num distinct pseudo-random int64 values, each
// boxed as any, in the order they were first drawn.
//
// Values are drawn from [0, 100000) whenever num is at most 100000. A larger
// num cannot be satisfied from that range (the loop would never collect enough
// distinct values), so the range is widened to [0, 2*num) in that case.
func buildJoinKeyIntDatums(num int) []any {
	valueRange := int64(100000)
	if int64(num) > valueRange {
		// Twice the requested count keeps the expected number of rejected
		// duplicates small while still guaranteeing termination.
		valueRange = int64(num) * 2
	}

	seen := make(map[int64]struct{}, num)
	datums := make([]any, 0, num)
	for len(datums) < num {
		val := rand.Int63n(valueRange)
		if _, dup := seen[val]; dup {
			continue
		}
		seen[val] = struct{}{}
		datums = append(datums, val)
	}
	return datums
}
// getRandString returns a 10-byte string whose bytes are each picked
// pseudo-randomly (via math/rand) from letterBytes.
func getRandString() string {
	var buf [10]byte
	for pos := 0; pos < len(buf); pos++ {
		pick := rand.Intn(len(letterBytes))
		buf[pos] = letterBytes[pick]
	}
	return string(buf[:])
}
// buildJoinKeyStringDatums returns num distinct strings produced by
// getRandString, each boxed as any, in the order they were first generated.
// Strings that have already been produced are discarded and regenerated.
func buildJoinKeyStringDatums(num int) []any {
	seen := make(map[string]struct{}, num)
	uniq := make([]any, 0, num)
	for len(uniq) < num {
		candidate := getRandString()
		if _, dup := seen[candidate]; !dup {
			seen[candidate] = struct{}{}
			uniq = append(uniq, candidate)
		}
	}
	return uniq
}
// buildLeftAndRightDataSource builds the left and right mock data sources for
// a join test, returned in that order.
//
// Both sides are configured with 50000 rows and share one pool of 20000
// distinct int keys and one pool of 2 distinct string keys. The left side
// receives the int pool for column 1 and the string pool for column 3; the
// right side receives them for columns 0 and 2. Those columns carry Ndvs -1
// and every other column carries Ndvs 0 with no datums.
// NOTE(review): presumably -1 tells the mock source to draw from the supplied
// datums — confirm against testutil. The five-entry Ndvs/Datums slices assume
// five-column schemas.
func buildLeftAndRightDataSource(ctx sessionctx.Context, leftCols []*expression.Column, rightCols []*expression.Column, hasSel bool) (_, _ *testutil.MockDataSource) {
	const rowsPerSide = 50000

	leftSchema := expression.NewSchema(leftCols...)
	rightSchema := expression.NewSchema(rightCols...)
	intKeys := buildJoinKeyIntDatums(20000)
	strKeys := buildJoinKeyStringDatums(2)

	leftParams := testutil.MockDataSourceParameters{
		DataSchema: leftSchema,
		Ctx:        ctx,
		Rows:       rowsPerSide,
		Ndvs:       []int{0, -1, 0, -1, 0},
		Datums:     [][]any{nil, intKeys, nil, strKeys, nil},
		HasSel:     hasSel,
	}
	rightParams := testutil.MockDataSourceParameters{
		DataSchema: rightSchema,
		Ctx:        ctx,
		Rows:       rowsPerSide,
		Ndvs:       []int{-1, 0, -1, 0, 0},
		Datums:     [][]any{intKeys, nil, strKeys, nil, nil},
		HasSel:     hasSel,
	}

	leftSrc := testutil.BuildMockDataSource(leftParams)
	rightSrc := testutil.BuildMockDataSource(rightParams)
	return leftSrc, rightSrc
}
// buildSchema returns a schema with one column per entry of schemaTypes, in
// the same order. Each column has only its RetType set.
func buildSchema(schemaTypes []*types.FieldType) *expression.Schema {
	result := new(expression.Schema)
	for idx := range schemaTypes {
		col := &expression.Column{RetType: schemaTypes[idx]}
		result.Append(col)
	}
	return result
}
// executeHashJoinExec opens hashJoinExec, drains it, closes it, and returns
// every non-empty chunk it produced, in order. A fresh chunk is allocated for
// each Next call so the returned chunks do not share storage.
//
// It sets isMemoryClearedForTest to true before opening, and fails the test
// immediately if Open, Next, or Close returns an error.
func executeHashJoinExec(t *testing.T, hashJoinExec *HashJoinV2Exec) []*chunk.Chunk {
	ctx := context.Background()
	hashJoinExec.isMemoryClearedForTest = true
	require.NoError(t, hashJoinExec.Open(ctx))

	collected := []*chunk.Chunk{}
	for {
		batch := exec.NewFirstChunk(hashJoinExec)
		require.NoError(t, hashJoinExec.Next(ctx, batch))
		// An empty chunk signals that the executor is exhausted.
		if batch.NumRows() == 0 {
			break
		}
		collected = append(collected, batch)
	}

	require.NoError(t, hashJoinExec.Close())
	return collected
}
// executeHashJoinExecAndGetError opens hashJoinExec and keeps calling Next,
// reusing one chunk, until Next returns an error or an empty chunk. It then
// closes the executor and returns the result of the last Next call (nil when
// the executor was drained without error).
//
// Open and Close are required to succeed; only the Next error is returned.
func executeHashJoinExecAndGetError(t *testing.T, hashJoinExec *HashJoinV2Exec) error {
	ctx := context.Background()
	require.NoError(t, hashJoinExec.Open(ctx))

	buf := exec.NewFirstChunk(hashJoinExec)
	nextErr := hashJoinExec.Next(ctx, buf)
	for nextErr == nil && buf.NumRows() != 0 {
		buf.Reset()
		nextErr = hashJoinExec.Next(ctx, buf)
	}

	closeErr := hashJoinExec.Close()
	require.NoError(t, closeErr)
	return nextErr
}
// executeHashJoinExecForRandomFailTest opens hashJoinExec and keeps calling
// Next, reusing one chunk, until Next returns an error or an empty chunk, then
// closes the executor.
//
// Only Open is required to succeed: an error from Next merely ends the loop,
// and the error from Close is deliberately discarded.
func executeHashJoinExecForRandomFailTest(t *testing.T, hashJoinExec *HashJoinV2Exec) {
	ctx := context.Background()
	require.NoError(t, hashJoinExec.Open(ctx))

	buf := exec.NewFirstChunk(hashJoinExec)
	for finished := false; !finished; {
		nextErr := hashJoinExec.Next(ctx, buf)
		switch {
		case nextErr != nil:
			finished = true
		case buf.NumRows() == 0:
			finished = true
		default:
			buf.Reset()
		}
	}

	_ = hashJoinExec.Close()
}
// getSortedResults runs hashJoinExec to completion with executeHashJoinExec
// and returns all of its result rows ordered by sortRows using resultTypes.
func getSortedResults(t *testing.T, hashJoinExec *HashJoinV2Exec, resultTypes []*types.FieldType) []chunk.Row {
	return sortRows(executeHashJoinExec(t, hashJoinExec), resultTypes)
}
// checkResults requires actualResult and expectedResult to have the same
// length and to compare equal row by row (at the same index) under the
// comparator generateCMPFunc builds for fieldTypes. A mismatch fails the test
// with the offending row index in the message.
func checkResults(t *testing.T, fieldTypes []*types.FieldType, actualResult []chunk.Row, expectedResult []chunk.Row) {
	require.Equal(t, len(expectedResult), len(actualResult))

	rowCmp := generateCMPFunc(fieldTypes)
	for idx, got := range actualResult {
		diff := rowCmp(got, expectedResult[idx])
		require.Equal(t, 0, diff, "result index = "+strconv.Itoa(idx))
	}
}