// Files
// tidb/pkg/executor/join/inner_join_probe_test.go
//
// 701 lines
// 31 KiB
// Go

// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package join
import (
"slices"
"sort"
"strconv"
"testing"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/executor/internal/testutil"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/planner/core/base"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/codec"
"github.com/pingcap/tidb/pkg/util/mock"
"github.com/pingcap/tidb/pkg/util/sqlkiller"
"github.com/stretchr/testify/require"
)
// toNullableTypes returns a new slice holding a clone of every field type in
// tps with the NOT NULL flag removed. The input types are not modified.
func toNullableTypes(tps []*types.FieldType) []*types.FieldType {
	nullable := make([]*types.FieldType, len(tps))
	for i := range tps {
		cloned := tps[i].Clone()
		cloned.DelFlag(mysql.NotNullFlag)
		nullable[i] = cloned
	}
	return nullable
}
// evalOtherCondition evaluates otherCondition against leftRow and rightRow laid
// out side by side. shallowRow is caller-provided scratch space: the left row is
// shallow-copied starting at column 0 and the right row right after the left
// row's columns. The "is null" result of the evaluation is discarded.
func evalOtherCondition(sessCtx sessionctx.Context, leftRow chunk.Row, rightRow chunk.Row, shallowRow chunk.MutRow, otherCondition expression.CNFExprs) (bool, error) {
	rightOffset := leftRow.Len()
	shallowRow.ShallowCopyPartialRow(0, leftRow)
	shallowRow.ShallowCopyPartialRow(rightOffset, rightRow)

	evalCtx := sessCtx.GetExprCtx().GetEvalCtx()
	matched, _, err := expression.EvalBool(evalCtx, otherCondition, shallowRow.ToRow())
	return matched, err
}
// appendToResultChk appends one joined row to resultChunk. The columns listed in
// leftUsedColumns come first, followed by the columns listed in rightUsedColumns.
// An empty leftRow or rightRow is written as NULL in every output column of that
// side.
func appendToResultChk(leftRow chunk.Row, rightRow chunk.Row, leftUsedColumns []int, rightUsedColumns []int, resultChunk *chunk.Chunk) {
	// rightStart is the offset of the first right-side output column.
	var rightStart int
	if !leftRow.IsEmpty() {
		rightStart = resultChunk.AppendRowByColIdxs(leftRow, leftUsedColumns)
	} else {
		for colPos := range leftUsedColumns {
			resultChunk.Column(colPos).AppendNull()
		}
		// The NULL padding bypasses the row-append helper, so the virtual row
		// counter is bumped by hand here.
		resultChunk.SetNumVirtualRows(resultChunk.NumRows() + 1)
		rightStart = len(leftUsedColumns)
	}

	if !rightRow.IsEmpty() {
		resultChunk.AppendPartialRowByColIdxs(rightStart, rightRow, rightUsedColumns)
		return
	}
	for colPos := range rightUsedColumns {
		resultChunk.Column(rightStart + colPos).AppendNull()
	}
}
// containsNullKey reports whether row is NULL in at least one of the columns
// listed in keyIndex.
func containsNullKey(row chunk.Row, keyIndex []int) bool {
	isNullAt := func(colIdx int) bool {
		return row.IsNull(colIdx)
	}
	return slices.ContainsFunc(keyIndex, isNullAt)
}
// genInnerJoinResult builds the reference inner join result with a plain nested
// loop over every (left row, right row) pair. A pair is emitted when neither row
// has a NULL join key, the join keys compare equal and, if otherConditions is
// not nil, the other conditions evaluate to true. Output rows are packed into
// chunks of resultTypes; a new chunk is started whenever the current one is full.
func genInnerJoinResult(t *testing.T, sessCtx sessionctx.Context, leftChunks []*chunk.Chunk, rightChunks []*chunk.Chunk, leftKeyIndex []int, rightKeyIndex []int,
	leftTypes []*types.FieldType, rightTypes []*types.FieldType, leftKeyTypes []*types.FieldType, rightKeyTypes []*types.FieldType, leftUsedColumns []int,
	rightUsedColumns []int, otherConditions expression.CNFExprs, resultTypes []*types.FieldType) []*chunk.Chunk {
	newResultChunk := func() *chunk.Chunk {
		maxSize := sessCtx.GetSessionVars().MaxChunkSize
		return chunk.New(resultTypes, maxSize, maxSize)
	}

	finished := make([]*chunk.Chunk, 0, 1)
	current := newResultChunk()

	// Scratch row wide enough for one left row followed by one right row; it is
	// only used to evaluate the other conditions.
	concatTypes := make([]*types.FieldType, 0, len(leftTypes)+len(rightTypes))
	concatTypes = append(concatTypes, leftTypes...)
	concatTypes = append(concatTypes, rightTypes...)
	scratchRow := chunk.MutRowFromTypes(concatTypes)

	for _, lChk := range leftChunks {
		for lIdx := range lChk.NumRows() {
			lRow := lChk.GetRow(lIdx)
			for _, rChk := range rightChunks {
				for rIdx := range rChk.NumRows() {
					// Rotate the output chunk before looking at the pair.
					if current.IsFull() {
						finished = append(finished, current)
						current = newResultChunk()
					}
					rRow := rChk.GetRow(rIdx)

					// A NULL join key never matches anything.
					if containsNullKey(lRow, leftKeyIndex) || containsNullKey(rRow, rightKeyIndex) {
						continue
					}
					keysEqual, err := codec.EqualChunkRow(sessCtx.GetSessionVars().StmtCtx.TypeCtx(), lRow, leftKeyTypes, leftKeyIndex,
						rRow, rightKeyTypes, rightKeyIndex)
					require.NoError(t, err)
					if !keysEqual {
						continue
					}
					if otherConditions != nil {
						// Keys match; the pair must still pass the other conditions.
						passed, err := evalOtherCondition(sessCtx, lRow, rRow, scratchRow, otherConditions)
						require.NoError(t, err)
						if !passed {
							continue
						}
					}
					appendToResultChk(lRow, rRow, leftUsedColumns, rightUsedColumns, current)
				}
			}
		}
	}

	// Keep the trailing, partially filled chunk.
	if current.NumRows() > 0 {
		finished = append(finished, current)
	}
	return finished
}
// checkVirtualRows asserts, for every result chunk, that the chunk is not
// reported as in-complete and that each of its columns holds exactly as many
// rows as the chunk's virtual row counter says.
func checkVirtualRows(t *testing.T, resultChunks []*chunk.Chunk) {
	for _, resultChk := range resultChunks {
		require.Equal(t, false, resultChk.IsInCompleteChunk())
		wantRows := resultChk.GetNumVirtualRows()
		for colIdx := range resultChk.NumCols() {
			require.Equal(t, wantRows, resultChk.Column(colIdx).Rows())
		}
	}
}
// checkChunksEqual asserts that expectedChunks and resultChunks contain the same
// multiset of rows. Row order is irrelevant: both sides are flattened, sorted
// column by column with the compare functions derived from schema, and then
// compared position by position. When there are no rows or schema is empty only
// the total row counts are compared.
func checkChunksEqual(t *testing.T, expectedChunks []*chunk.Chunk, resultChunks []*chunk.Chunk, schema []*types.FieldType) {
	countRows := func(chks []*chunk.Chunk) int {
		total := 0
		for _, chk := range chks {
			total += chk.NumRows()
		}
		return total
	}
	expectedNum := countRows(expectedChunks)
	require.Equal(t, expectedNum, countRows(resultChunks))
	if expectedNum == 0 || len(schema) == 0 {
		return
	}

	cmpFuncs := make([]chunk.CompareFunc, len(schema))
	for col, colType := range schema {
		cmpFuncs[col] = chunk.GetCompareFunc(colType)
	}
	// compareRows orders two rows by their first differing column.
	compareRows := func(lhs, rhs chunk.Row) int {
		for col, cmpFunc := range cmpFuncs {
			if order := cmpFunc(lhs, col, rhs, col); order != 0 {
				return order
			}
		}
		return 0
	}

	// sortedRows flattens the chunks into a single row slice and sorts it.
	sortedRows := func(chks []*chunk.Chunk) []chunk.Row {
		rows := make([]chunk.Row, 0, expectedNum)
		for _, chk := range chks {
			iter := chunk.NewIterator4Chunk(chk)
			for row := iter.Begin(); row != iter.End(); row = iter.Next() {
				rows = append(rows, row)
			}
		}
		sort.Slice(rows, func(i, j int) bool {
			return compareRows(rows[i], rows[j]) < 0
		})
		return rows
	}
	expectedRows := sortedRows(expectedChunks)
	resultRows := sortedRows(resultChunks)

	for i := range expectedRows {
		diff := compareRows(expectedRows[i], resultRows[i])
		if diff != 0 {
			// Repeated on purpose: a convenient place for a breakpoint when
			// debugging a mismatch.
			diff = compareRows(expectedRows[i], resultRows[i])
		}
		require.Equal(t, 0, diff, "result index = "+strconv.Itoa(i))
	}
}
// copySelectedRows appends the rows of src flagged in selected to dst, column by
// column, and reports whether at least one row was added. Only the columns that
// exist in both chunks are copied; the caller should ensure that
// src.NumCols() >= dst.NumCols(). An empty src is a no-op, and an error is
// returned when either chunk carries a selection vector.
func copySelectedRows(src *chunk.Chunk, dst *chunk.Chunk, selected []bool) (bool, error) {
	switch {
	case src.NumRows() == 0:
		return false, nil
	case src.Sel() != nil || dst.Sel() != nil:
		return false, errors.New("copy with sel")
	}

	var copied int
	if src.NumCols() == 0 {
		// No column data to move: just count the selected rows.
		for _, keep := range selected {
			if keep {
				copied++
			}
		}
	} else {
		rowsBefore := dst.NumRows()
		for colIdx := range min(src.NumCols(), dst.NumCols()) {
			chunk.CopySelectedRows(dst.Column(colIdx), src.Column(colIdx), selected)
		}
		copied = dst.NumRows() - rowsBefore
	}
	dst.SetNumVirtualRows(dst.GetNumVirtualRows() + copied)
	return copied > 0, nil
}
// testJoinProbe runs one build + probe cycle of the hash join on randomly
// generated chunks and compares the output with a nested-loop reference result
// for joinType.
//
//   - withSel: when true, a selection vector that skips every 9th row is set on
//     every build and probe chunk.
//   - leftUsed/rightUsed: output columns of each side; nil means all columns.
//   - leftFilter/rightFilter: per-side filters; which ones are allowed depends on
//     joinType and is asserted below.
//   - partitionNumber: used as the join concurrency; the effective partition
//     number is re-read after SetupPartitionInfo.
//   - inputRowNumber: rows per generated chunk on the side that has at least as
//     many columns as the other one.
func testJoinProbe(t *testing.T, withSel bool, leftKeyIndex []int, rightKeyIndex []int, leftKeyTypes []*types.FieldType, rightKeyTypes []*types.FieldType,
leftTypes []*types.FieldType, rightTypes []*types.FieldType, rightAsBuildSide bool, leftUsed []int, rightUsed []int,
leftUsedByOtherCondition []int, rightUsedByOtherCondition []int, leftFilter expression.CNFExprs, rightFilter expression.CNFExprs,
otherCondition expression.CNFExprs, partitionNumber int, joinType base.JoinType, inputRowNumber int) {
// A nil leftUsed/rightUsed means "select all columns" of that side.
if leftUsed == nil {
leftUsed = make([]int, 0)
for index := range leftTypes {
leftUsed = append(leftUsed, index)
}
}
if rightUsed == nil {
rightUsed = make([]int, 0)
for index := range rightTypes {
rightUsed = append(rightUsed, index)
}
}
// Default mapping: the left side is the build side and the right side is the
// probe side. The mapping is swapped below when rightAsBuildSide is set.
buildKeyIndex, probeKeyIndex := leftKeyIndex, rightKeyIndex
buildKeyTypes, probeKeyTypes := leftKeyTypes, rightKeyTypes
buildTypes, probeTypes := leftTypes, rightTypes
buildUsed := leftUsed
buildUsedByOtherCondition := leftUsedByOtherCondition
buildFilter, probeFilter := leftFilter, rightFilter
// needUsedFlag is forwarded to newTableMeta; it is set for the join type /
// build side combinations listed below.
needUsedFlag := false
if rightAsBuildSide {
probeKeyIndex, buildKeyIndex = leftKeyIndex, rightKeyIndex
probeKeyTypes, buildKeyTypes = leftKeyTypes, rightKeyTypes
probeTypes, buildTypes = leftTypes, rightTypes
buildUsed = rightUsed
buildUsedByOtherCondition = rightUsedByOtherCondition
buildFilter, probeFilter = rightFilter, leftFilter
if joinType == base.RightOuterJoin {
needUsedFlag = true
}
} else {
switch joinType {
case base.LeftOuterJoin, base.SemiJoin, base.AntiSemiJoin:
needUsedFlag = true
case base.LeftOuterSemiJoin, base.AntiLeftOuterSemiJoin:
// require.NoError on a non-nil error always fails: this combination is rejected.
require.NoError(t, errors.New("left semi/anti join does not support use left as build side"))
}
}
// Reject filters on a side where the join type does not accept them.
switch joinType {
case base.InnerJoin:
require.Equal(t, 0, len(leftFilter), "inner join does not support left filter")
require.Equal(t, 0, len(rightFilter), "inner join does not support right filter")
case base.LeftOuterJoin:
require.Equal(t, 0, len(rightFilter), "left outer join does not support right filter")
case base.RightOuterJoin:
require.Equal(t, 0, len(leftFilter), "right outer join does not support left filter")
case base.SemiJoin, base.AntiSemiJoin:
require.Equal(t, 0, len(leftFilter), "semi/anti join does not support left filter")
require.Equal(t, 0, len(rightFilter), "semi/anti join does not support right filter")
case base.LeftOuterSemiJoin, base.AntiLeftOuterSemiJoin:
require.Equal(t, 0, len(rightFilter), "left outer semi/anti join does not support right filter")
}
// joinedTypes is the full left row followed by the full right row.
joinedTypes := make([]*types.FieldType, 0, len(leftTypes)+len(rightTypes))
joinedTypes = append(joinedTypes, leftTypes...)
joinedTypes = append(joinedTypes, rightTypes...)
// resultTypes holds clones of the used columns only. For an outer join the
// NOT NULL flag is dropped on the columns of the side that may be NULL-padded
// (left columns for right outer join, right columns for left outer join).
resultTypes := make([]*types.FieldType, 0, len(leftUsed)+len(rightUsed))
for _, colIndex := range leftUsed {
resultTypes = append(resultTypes, leftTypes[colIndex].Clone())
if joinType == base.RightOuterJoin {
resultTypes[len(resultTypes)-1].DelFlag(mysql.NotNullFlag)
}
}
for _, colIndex := range rightUsed {
resultTypes = append(resultTypes, rightTypes[colIndex].Clone())
if joinType == base.LeftOuterJoin {
resultTypes[len(resultTypes)-1].DelFlag(mysql.NotNullFlag)
}
}
// Left outer (anti) semi joins output one extra tiny-int column.
if joinType == base.LeftOuterSemiJoin || joinType == base.AntiLeftOuterSemiJoin {
resultTypes = append(resultTypes, types.NewFieldType(mysql.TypeTiny))
}
// Assemble the hash join context by hand instead of going through an executor.
meta := newTableMeta(buildKeyIndex, buildTypes, buildKeyTypes, probeKeyTypes, buildUsedByOtherCondition, buildUsed, needUsedFlag)
hashJoinCtx := &HashJoinCtxV2{
hashTableMeta: meta,
BuildFilter: buildFilter,
ProbeFilter: probeFilter,
OtherCondition: otherCondition,
BuildKeyTypes: buildKeyTypes,
ProbeKeyTypes: probeKeyTypes,
RightAsBuildSide: rightAsBuildSide,
LUsed: leftUsed,
RUsed: rightUsed,
LUsedInOtherCondition: leftUsedByOtherCondition,
RUsedInOtherCondition: rightUsedByOtherCondition,
}
hashJoinCtx.SessCtx = mock.NewContext()
hashJoinCtx.JoinType = joinType
hashJoinCtx.Concurrency = uint(partitionNumber)
hashJoinCtx.SetupPartitionInfo()
// Use the partition number chosen by SetupPartitionInfo from here on.
partitionNumber = int(hashJoinCtx.partitionNumber)
hashJoinCtx.spillHelper = newHashJoinSpillHelper(nil, partitionNumber, nil, "")
hashJoinCtx.initHashTableContext()
joinProbe := NewJoinProbe(hashJoinCtx, 0, joinType, probeKeyIndex, joinedTypes, probeKeyTypes, rightAsBuildSide)
buildSchema := &expression.Schema{}
for _, tp := range buildTypes {
buildSchema.Append(&expression.Column{
RetType: tp,
})
}
// hasNullableKey is true as soon as one build key type lacks the NOT NULL flag.
hasNullableKey := false
for _, buildKeyType := range buildKeyTypes {
if !mysql.HasNotNullFlag(buildKeyType.GetFlag()) {
hasNullableKey = true
break
}
}
builder := createRowTableBuilder(buildKeyIndex, buildKeyTypes, hashJoinCtx.partitionNumber, hasNullableKey, buildFilter != nil, joinProbe.NeedScanRowTable(), meta.nullMapLength)
chunkNumber := 3
buildChunks := make([]*chunk.Chunk, 0, chunkNumber)
probeChunks := make([]*chunk.Chunk, 0, chunkNumber)
// selected marks every third row; those rows are copied from one side to the
// other so that the join has matches.
selected := make([]bool, 0, inputRowNumber)
for i := range inputRowNumber {
if i%3 == 0 {
selected = append(selected, true)
} else {
selected = append(selected, false)
}
}
// Check that build columns can be copied into probe columns directly: the
// columns shared by position must have the same fixed length.
for i := range min(len(buildTypes), len(probeTypes)) {
buildLength := chunk.GetFixedLen(buildTypes[i])
probeLength := chunk.GetFixedLen(probeTypes[i])
require.Equal(t, buildLength, probeLength, "build type and probe type is not compatible")
}
// The side with at least as many columns gets inputRowNumber random rows; the
// other side gets inputRowNumber*2/3 random rows plus the selected rows copied
// from the first side.
for i := range chunkNumber {
if len(buildTypes) >= len(probeTypes) {
buildChunks = append(buildChunks, testutil.GenRandomChunks(buildTypes, inputRowNumber))
probeChunk := testutil.GenRandomChunks(probeTypes, inputRowNumber*2/3)
// copy some build data to the probe side, to make sure there are some matched rows
_, err := copySelectedRows(buildChunks[i], probeChunk, selected)
require.NoError(t, err)
probeChunks = append(probeChunks, probeChunk)
} else {
probeChunks = append(probeChunks, testutil.GenRandomChunks(probeTypes, inputRowNumber))
buildChunk := testutil.GenRandomChunks(buildTypes, inputRowNumber*2/3)
// copy some probe data to the build side, to make sure there are some matched rows
_, err := copySelectedRows(probeChunks[i], buildChunk, selected)
require.NoError(t, err)
buildChunks = append(buildChunks, buildChunk)
}
}
// Optionally set a selection vector that leaves out every 9th row. It is set
// only after the copy above, because copySelectedRows rejects chunks with sel.
if withSel {
sel := make([]int, 0, inputRowNumber)
for i := range inputRowNumber {
if i%9 == 0 {
continue
}
sel = append(sel, i)
}
for _, chk := range buildChunks {
chk.SetSel(sel)
}
for _, chk := range probeChunks {
chk.SetSel(sel)
}
}
// Map build/probe back to left/right for the reference result generators.
leftChunks, rightChunks := probeChunks, buildChunks
if !rightAsBuildSide {
leftChunks, rightChunks = buildChunks, probeChunks
}
// Feed every build chunk to the row table builder.
for i := range chunkNumber {
err := builder.processOneChunk(buildChunks[i], hashJoinCtx.SessCtx.GetSessionVars().StmtCtx.TypeCtx(), hashJoinCtx, 0)
require.NoError(t, err)
}
checkRowLocationAlignment(t, hashJoinCtx.hashTableContext.rowTables[0])
hashJoinCtx.hashTableContext.mergeRowTablesToHashTable(hashJoinCtx.partitionNumber, nil)
// build the hash table, one partition at a time
for i := range partitionNumber {
hashJoinCtx.hashTableContext.build(&buildTask{partitionIdx: i, segStartIdx: 0, segEndIdx: len(hashJoinCtx.hashTableContext.hashTable.tables[i].rowData.segments)})
}
// probe: every full output chunk is moved to resultChunks and replaced by a new one
resultChunks := make([]*chunk.Chunk, 0)
joinResult := &hashjoinWorkerResult{
chk: chunk.New(resultTypes, hashJoinCtx.SessCtx.GetSessionVars().MaxChunkSize, hashJoinCtx.SessCtx.GetSessionVars().MaxChunkSize),
}
for _, probeChunk := range probeChunks {
err := joinProbe.SetChunkForProbe(probeChunk)
require.NoError(t, err, "unexpected error during SetChunkForProbe")
for !joinProbe.IsCurrentChunkProbeDone() {
_, joinResult = joinProbe.Probe(joinResult, &sqlkiller.SQLKiller{})
require.NoError(t, joinResult.err, "unexpected error during join probe")
if joinResult.chk.IsFull() {
resultChunks = append(resultChunks, joinResult.chk)
joinResult.chk = chunk.New(resultTypes, hashJoinCtx.SessCtx.GetSessionVars().MaxChunkSize, hashJoinCtx.SessCtx.GetSessionVars().MaxChunkSize)
}
}
}
// When the probe reports that the row table must be scanned afterwards, run
// that scan with one prober per worker id in [0, Concurrency).
if joinProbe.NeedScanRowTable() {
joinProbes := make([]ProbeV2, 0, hashJoinCtx.Concurrency)
for i := uint(0); i < hashJoinCtx.Concurrency; i++ {
joinProbes = append(joinProbes, NewJoinProbe(hashJoinCtx, i, joinType, probeKeyIndex, joinedTypes, probeKeyTypes, rightAsBuildSide))
}
for _, prober := range joinProbes {
prober.InitForScanRowTable()
for !prober.IsScanRowTableDone() {
joinResult = prober.ScanRowTable(joinResult, &sqlkiller.SQLKiller{})
require.NoError(t, joinResult.err, "unexpected error during scan row table")
if joinResult.chk.IsFull() {
resultChunks = append(resultChunks, joinResult.chk)
joinResult.chk = chunk.New(resultTypes, hashJoinCtx.SessCtx.GetSessionVars().MaxChunkSize, hashJoinCtx.SessCtx.GetSessionVars().MaxChunkSize)
}
}
}
}
// Keep the last, partially filled output chunk.
if joinResult.chk.NumRows() > 0 {
resultChunks = append(resultChunks, joinResult.chk)
}
checkVirtualRows(t, resultChunks)
// Compare against the reference result of the matching join type.
switch joinType {
case base.InnerJoin:
expectedChunks := genInnerJoinResult(t, hashJoinCtx.SessCtx, leftChunks, rightChunks, leftKeyIndex, rightKeyIndex, leftTypes, rightTypes,
leftKeyTypes, rightKeyTypes, leftUsed, rightUsed, otherCondition, resultTypes)
checkChunksEqual(t, expectedChunks, resultChunks, resultTypes)
case base.LeftOuterJoin:
expectedChunks := genLeftOuterJoinResult(t, hashJoinCtx.SessCtx, leftFilter, leftChunks, rightChunks, leftKeyIndex, rightKeyIndex, leftTypes,
rightTypes, leftKeyTypes, rightKeyTypes, leftUsed, rightUsed, otherCondition, resultTypes)
checkChunksEqual(t, expectedChunks, resultChunks, resultTypes)
case base.RightOuterJoin:
expectedChunks := genRightOuterJoinResult(t, hashJoinCtx.SessCtx, rightFilter, leftChunks, rightChunks, leftKeyIndex, rightKeyIndex, leftTypes,
rightTypes, leftKeyTypes, rightKeyTypes, leftUsed, rightUsed, otherCondition, resultTypes)
checkChunksEqual(t, expectedChunks, resultChunks, resultTypes)
case base.LeftOuterSemiJoin:
expectedChunks := genLeftOuterSemiJoinResult(t, hashJoinCtx.SessCtx, leftFilter, leftChunks, rightChunks, leftKeyIndex, rightKeyIndex, leftTypes,
rightTypes, leftKeyTypes, rightKeyTypes, leftUsed, otherCondition, resultTypes)
checkChunksEqual(t, expectedChunks, resultChunks, resultTypes)
case base.SemiJoin:
expectedChunks := genSemiJoinResult(t, hashJoinCtx.SessCtx, leftFilter, leftChunks, rightChunks, leftKeyIndex, rightKeyIndex, leftTypes,
rightTypes, leftKeyTypes, rightKeyTypes, leftUsed, otherCondition, resultTypes)
checkChunksEqual(t, expectedChunks, resultChunks, resultTypes)
case base.AntiSemiJoin:
expectedChunks := genAntiSemiJoinResult(t, hashJoinCtx.SessCtx, leftChunks, rightChunks, leftKeyIndex, rightKeyIndex, leftTypes,
rightTypes, leftKeyTypes, rightKeyTypes, leftUsed, otherCondition, resultTypes)
checkChunksEqual(t, expectedChunks, resultChunks, resultTypes)
case base.AntiLeftOuterSemiJoin:
expectedChunks := genLeftOuterAntiSemiJoinResult(t, hashJoinCtx.SessCtx, leftFilter, leftChunks, rightChunks, leftKeyIndex, rightKeyIndex, leftTypes,
rightTypes, leftKeyTypes, rightKeyTypes, leftUsed, otherCondition, resultTypes)
checkChunksEqual(t, expectedChunks, resultChunks, resultTypes)
default:
// require.NoError on a non-nil error always fails: unknown join types are rejected.
require.NoError(t, errors.New("not supported join type"))
}
}
// testCase bundles the per-case inputs handed to testJoinProbe.
//
// The field order is significant: the test tables in this file are allowed to
// use positional (unkeyed) literals of this struct.
type testCase struct {
	// Offsets of the join key columns within each side's row.
	leftKeyIndex, rightKeyIndex []int
	// Types of the join keys on each side.
	leftKeyTypes, rightKeyTypes []*types.FieldType
	// Full column types of each side.
	leftTypes, rightTypes []*types.FieldType
	// Columns of each side that appear in the join output; nil selects all columns.
	leftUsed, rightUsed []int
	// Extra non-equal join condition, nil when there is none.
	otherCondition expression.CNFExprs
	// Columns of each side that are referenced by otherCondition.
	leftUsedByOtherCondition, rightUsedByOtherCondition []int
}
// TestInnerJoinProbeBasic runs the inner join probe over a set of basic key and
// output-column layouts. Every case is executed with either side as the build
// side, once with NOT NULL types and once with their nullable counterparts.
func TestInnerJoinProbeBasic(t *testing.T) {
	tinyTp := types.NewFieldType(mysql.TypeTiny)
	tinyTp.AddFlag(mysql.NotNullFlag)
	intTp := types.NewFieldType(mysql.TypeLonglong)
	intTp.AddFlag(mysql.NotNullFlag)
	uintTp := types.NewFieldType(mysql.TypeLonglong)
	uintTp.AddFlag(mysql.NotNullFlag)
	uintTp.AddFlag(mysql.UnsignedFlag)
	stringTp := types.NewFieldType(mysql.TypeVarString)
	stringTp.AddFlag(mysql.NotNullFlag)

	lTypes := []*types.FieldType{intTp, stringTp, uintTp, stringTp, tinyTp}
	rTypes := []*types.FieldType{intTp, stringTp, uintTp, stringTp, tinyTp}
	// NOTE(review): retTypes is not declared in the visible part of this file;
	// presumably it is a package-level slice from another test file. Confirm.
	rTypes = append(rTypes, retTypes...)
	// Same layout as rTypes but with the int and uint columns swapped, then doubled.
	rTypes1 := []*types.FieldType{uintTp, stringTp, intTp, stringTp, tinyTp}
	rTypes1 = append(rTypes1, rTypes1...)

	testCases := []testCase{
		// normal case
		{
			leftKeyIndex: []int{0}, rightKeyIndex: []int{0},
			leftKeyTypes: []*types.FieldType{intTp}, rightKeyTypes: []*types.FieldType{intTp},
			leftTypes: lTypes, rightTypes: rTypes,
		},
		// rightUsed is empty
		{
			leftKeyIndex: []int{0}, rightKeyIndex: []int{0},
			leftKeyTypes: []*types.FieldType{intTp}, rightKeyTypes: []*types.FieldType{intTp},
			leftTypes: lTypes, rightTypes: rTypes,
			leftUsed: []int{0, 1, 2, 3}, rightUsed: []int{},
		},
		// leftUsed is empty
		{
			leftKeyIndex: []int{0}, rightKeyIndex: []int{0},
			leftKeyTypes: []*types.FieldType{intTp}, rightKeyTypes: []*types.FieldType{intTp},
			leftTypes: lTypes, rightTypes: rTypes,
			leftUsed: []int{}, rightUsed: []int{0, 1, 2, 3},
		},
		// both left/right Used are empty
		{
			leftKeyIndex: []int{0}, rightKeyIndex: []int{0},
			leftKeyTypes: []*types.FieldType{intTp}, rightKeyTypes: []*types.FieldType{intTp},
			leftTypes: lTypes, rightTypes: rTypes,
			leftUsed: []int{}, rightUsed: []int{},
		},
		// both left/right used is part of all columns
		{
			leftKeyIndex: []int{0}, rightKeyIndex: []int{0},
			leftKeyTypes: []*types.FieldType{intTp}, rightKeyTypes: []*types.FieldType{intTp},
			leftTypes: lTypes, rightTypes: rTypes,
			leftUsed: []int{0, 2}, rightUsed: []int{1, 3},
		},
		// int join uint
		{
			leftKeyIndex: []int{0}, rightKeyIndex: []int{0},
			leftKeyTypes: []*types.FieldType{intTp}, rightKeyTypes: []*types.FieldType{uintTp},
			leftTypes: lTypes, rightTypes: rTypes1,
			leftUsed: []int{0, 1, 2, 3}, rightUsed: []int{0, 1, 2, 3},
		},
		// multiple join keys
		{
			leftKeyIndex: []int{0, 1}, rightKeyIndex: []int{0, 1},
			leftKeyTypes: []*types.FieldType{intTp, stringTp}, rightKeyTypes: []*types.FieldType{intTp, stringTp},
			leftTypes: lTypes, rightTypes: rTypes,
			leftUsed: []int{0, 1, 2, 3}, rightUsed: []int{0, 1, 2, 3},
		},
	}

	partitionNumber := 4
	for _, tc := range testCases {
		// inner join does not have left/right Filter
		for _, rightAsBuild := range []bool{true, false} {
			// NOT NULL variant.
			testJoinProbe(t, false, tc.leftKeyIndex, tc.rightKeyIndex, tc.leftKeyTypes, tc.rightKeyTypes, tc.leftTypes, tc.rightTypes, rightAsBuild,
				tc.leftUsed, tc.rightUsed, tc.leftUsedByOtherCondition, tc.rightUsedByOtherCondition, nil, nil, tc.otherCondition,
				partitionNumber, base.InnerJoin, 200)
			// Nullable variant of the same case.
			testJoinProbe(t, false, tc.leftKeyIndex, tc.rightKeyIndex, toNullableTypes(tc.leftKeyTypes), toNullableTypes(tc.rightKeyTypes),
				toNullableTypes(tc.leftTypes), toNullableTypes(tc.rightTypes), rightAsBuild,
				tc.leftUsed, tc.rightUsed, tc.leftUsedByOtherCondition, tc.rightUsedByOtherCondition, nil, nil, tc.otherCondition,
				partitionNumber, base.InnerJoin, 200)
		}
	}
}
// TestInnerJoinProbeAllJoinKeys runs the inner join probe with every column type
// of the test schema as a single join key, and then with a handful of two-column
// keys. Each key is tested with either side as the build side, in a NOT NULL and
// in a nullable variant.
func TestInnerJoinProbeAllJoinKeys(t *testing.T) {
	// notNull builds a NOT NULL field type of tp carrying the extra flags.
	notNull := func(tp byte, extraFlags ...uint) *types.FieldType {
		ft := types.NewFieldType(tp)
		for _, flag := range extraFlags {
			ft.AddFlag(flag)
		}
		ft.AddFlag(mysql.NotNullFlag)
		return ft
	}
	tinyTp := notNull(mysql.TypeTiny)
	intTp := notNull(mysql.TypeLonglong)
	uintTp := notNull(mysql.TypeLonglong, mysql.UnsignedFlag)
	yearTp := notNull(mysql.TypeYear)
	durationTp := notNull(mysql.TypeDuration)
	enumTp := notNull(mysql.TypeEnum)
	enumWithIntFlag := notNull(mysql.TypeEnum, mysql.EnumSetAsIntFlag)
	setTp := notNull(mysql.TypeSet)
	bitTp := notNull(mysql.TypeBit)
	jsonTp := notNull(mysql.TypeJSON)
	floatTp := notNull(mysql.TypeFloat)
	doubleTp := notNull(mysql.TypeDouble)
	stringTp := notNull(mysql.TypeVarString)
	datetimeTp := notNull(mysql.TypeDatetime)
	decimalTp := notNull(mysql.TypeNewDecimal)
	timestampTp := notNull(mysql.TypeTimestamp)
	dateTp := notNull(mysql.TypeDate)
	binaryStringTp := notNull(mysql.TypeBlob)

	// Both sides share the same 18-column schema.
	lTypes := []*types.FieldType{tinyTp, intTp, uintTp, yearTp, durationTp, enumTp, enumWithIntFlag, setTp, bitTp, jsonTp, floatTp, doubleTp, stringTp, datetimeTp, decimalTp, timestampTp, dateTp, binaryStringTp}
	rTypes := lTypes
	nullableLTypes := toNullableTypes(lTypes)
	nullableRTypes := toNullableTypes(rTypes)

	// Every column is part of the output on both sides.
	lUsed := make([]int, len(lTypes))
	for col := range lUsed {
		lUsed[col] = col
	}
	rUsed := lUsed

	partitionNumber := 4
	// runKey probes with the given key columns for both build sides, first with
	// the NOT NULL types and then with the nullable ones.
	runKey := func(keyCols []int, keyTypes []*types.FieldType) {
		for _, rightAsBuild := range []bool{true, false} {
			lKeyTypes, rKeyTypes := keyTypes, slices.Clone(keyTypes)
			testJoinProbe(t, false, keyCols, slices.Clone(keyCols), lKeyTypes, rKeyTypes, lTypes, rTypes, rightAsBuild,
				lUsed, rUsed, nil, nil, nil, nil, nil, partitionNumber, base.InnerJoin, 100)
			testJoinProbe(t, false, keyCols, slices.Clone(keyCols), toNullableTypes(lKeyTypes), toNullableTypes(rKeyTypes), nullableLTypes, nullableRTypes, rightAsBuild,
				lUsed, rUsed, nil, nil, nil, nil, nil, partitionNumber, base.InnerJoin, 100)
		}
	}

	// single key
	for col := range lTypes {
		runKey([]int{col}, []*types.FieldType{lTypes[col]})
	}

	// composed key
	composedKeys := []struct {
		keyCols  []int
		keyTypes []*types.FieldType
	}{
		// fixed size, inlined
		{[]int{1, 2}, []*types.FieldType{intTp, uintTp}},
		// variable size, inlined
		{[]int{1, 17}, []*types.FieldType{intTp, binaryStringTp}},
		// fixed size, not inlined
		{[]int{1, 13}, []*types.FieldType{intTp, datetimeTp}},
		// variable size, not inlined
		{[]int{1, 14}, []*types.FieldType{intTp, decimalTp}},
	}
	for _, key := range composedKeys {
		runKey(key.keyCols, key.keyTypes)
	}
}
// TestInnerJoinProbeOtherCondition runs the inner join probe with an extra
// "greater than" condition between joined column 1 and joined column 8, for both
// build sides, different output column selections, and NOT NULL as well as
// nullable types.
func TestInnerJoinProbeOtherCondition(t *testing.T) {
	intTp := types.NewFieldType(mysql.TypeLonglong)
	intTp.AddFlag(mysql.NotNullFlag)
	nullableIntTp := types.NewFieldType(mysql.TypeLonglong)
	uintTp := types.NewFieldType(mysql.TypeLonglong)
	uintTp.AddFlag(mysql.NotNullFlag)
	uintTp.AddFlag(mysql.UnsignedFlag)
	stringTp := types.NewFieldType(mysql.TypeVarString)
	stringTp.AddFlag(mysql.NotNullFlag)

	lTypes := []*types.FieldType{intTp, intTp, stringTp, uintTp, stringTp}
	rTypes := []*types.FieldType{intTp, intTp, stringTp, uintTp, stringTp}
	rTypes = append(rTypes, rTypes...)

	// other condition: column 1 > column 8 of the joined (left ++ right) row
	tinyTp := types.NewFieldType(mysql.TypeTiny)
	a := &expression.Column{Index: 1, RetType: nullableIntTp}
	b := &expression.Column{Index: 8, RetType: nullableIntTp}
	gt, err := expression.NewFunction(mock.NewContext(), ast.GT, tinyTp, a, b)
	require.NoError(t, err, "error when create other condition")
	otherCondition := expression.CNFExprs{gt}

	partitionNumber := 4
	// probe joins on column 0 of both sides with keyTp as the key type; the
	// other condition reads left column 1 and right column 3.
	probe := func(rightAsBuild bool, keyTp *types.FieldType, lTps, rTps []*types.FieldType, lUsed, rUsed []int) {
		testJoinProbe(t, false, []int{0}, []int{0}, []*types.FieldType{keyTp}, []*types.FieldType{keyTp}, lTps, rTps, rightAsBuild,
			lUsed, rUsed, []int{1}, []int{3}, nil, nil, otherCondition, partitionNumber, base.InnerJoin, 200)
	}
	for _, rightAsBuild := range []bool{true, false} {
		probe(rightAsBuild, intTp, lTypes, rTypes, []int{1, 2, 4}, []int{0})
		probe(rightAsBuild, intTp, lTypes, rTypes, []int{}, []int{})
		probe(rightAsBuild, nullableIntTp, toNullableTypes(lTypes), toNullableTypes(rTypes), []int{1, 2, 4}, []int{0})
		probe(rightAsBuild, nullableIntTp, toNullableTypes(lTypes), toNullableTypes(rTypes), nil, nil)
	}
}
// TestInnerJoinProbeWithSel runs the inner join probe on chunks that carry a
// selection vector (withSel = true), with and without an extra other condition,
// for both build sides, different output column selections, and NOT NULL as
// well as nullable types.
func TestInnerJoinProbeWithSel(t *testing.T) {
	intTp := types.NewFieldType(mysql.TypeLonglong)
	intTp.AddFlag(mysql.NotNullFlag)
	nullableIntTp := types.NewFieldType(mysql.TypeLonglong)
	uintTp := types.NewFieldType(mysql.TypeLonglong)
	uintTp.AddFlag(mysql.NotNullFlag)
	uintTp.AddFlag(mysql.UnsignedFlag)
	nullableUIntTp := types.NewFieldType(mysql.TypeLonglong)
	// The unsigned flag belongs to the unsigned type. It used to be added to
	// nullableIntTp, which left nullableUIntTp signed and turned the signed
	// key/operand type into an unsigned one.
	nullableUIntTp.AddFlag(mysql.UnsignedFlag)
	stringTp := types.NewFieldType(mysql.TypeVarString)
	stringTp.AddFlag(mysql.NotNullFlag)

	lTypes := []*types.FieldType{intTp, intTp, stringTp, uintTp, stringTp}
	rTypes := []*types.FieldType{intTp, intTp, stringTp, uintTp, stringTp}
	rTypes = append(rTypes, rTypes...)

	// other condition: joined column 1 (left column 1, signed) > joined column 8
	// (right column 3, unsigned)
	tinyTp := types.NewFieldType(mysql.TypeTiny)
	a := &expression.Column{Index: 1, RetType: nullableIntTp}
	b := &expression.Column{Index: 8, RetType: nullableUIntTp}
	sf, err := expression.NewFunction(mock.NewContext(), ast.GT, tinyTp, a, b)
	require.NoError(t, err, "error when create other condition")
	otherCondition := make(expression.CNFExprs, 0)
	otherCondition = append(otherCondition, sf)
	// Each layout is run once with the other condition and once without it.
	otherConditions := []expression.CNFExprs{otherCondition, nil}

	partitionNumber := 4
	rightAsBuildSide := []bool{true, false}
	for _, rightAsBuild := range rightAsBuildSide {
		for _, oc := range otherConditions {
			testJoinProbe(t, true, []int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, rightAsBuild, []int{1, 2, 4}, []int{0}, []int{1}, []int{3}, nil, nil, oc, partitionNumber, base.InnerJoin, 500)
			testJoinProbe(t, true, []int{0}, []int{0}, []*types.FieldType{intTp}, []*types.FieldType{intTp}, lTypes, rTypes, rightAsBuild, []int{}, []int{}, []int{1}, []int{3}, nil, nil, oc, partitionNumber, base.InnerJoin, 500)
			testJoinProbe(t, true, []int{0}, []int{0}, []*types.FieldType{nullableIntTp}, []*types.FieldType{nullableIntTp}, toNullableTypes(lTypes), toNullableTypes(rTypes), rightAsBuild, []int{1, 2, 4}, []int{0}, []int{1}, []int{3}, nil, nil, oc, partitionNumber, base.InnerJoin, 500)
			testJoinProbe(t, true, []int{0}, []int{0}, []*types.FieldType{nullableIntTp}, []*types.FieldType{nullableIntTp}, toNullableTypes(lTypes), toNullableTypes(rTypes), rightAsBuild, nil, nil, []int{1}, []int{3}, nil, nil, oc, partitionNumber, base.InnerJoin, 500)
		}
	}
}