Files
tidb/pkg/executor/join_pkg_test.go
2024-06-14 08:35:44 +00:00

165 lines
4.8 KiB
Go

// Copyright 2020 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package executor
import (
"context"
"testing"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/executor/internal/exec"
"github.com/pingcap/tidb/pkg/executor/internal/testutil"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/types"
"github.com/stretchr/testify/require"
)
// TestHashJoinV2UnderApply checks that one hash join v2 executor instance can
// be opened, drained, and closed repeatedly, which is how the executor is
// driven when it runs under an Apply.
func TestHashJoinV2UnderApply(t *testing.T) {
	colTypes := []*types.FieldType{
		types.NewFieldType(mysql.TypeLonglong),
		types.NewFieldType(mysql.TypeLonglong),
	}
	casTest := defaultHashJoinTestCase(colTypes, 0, false)
	opt1 := testutil.MockDataSourceParameters{
		Rows: casTest.rows,
		Ctx:  casTest.ctx,
		// Each generated cell carries its own row index, converted to the
		// column's type; any other column type is unsupported by this test.
		GenDataFunc: func(row int, typ *types.FieldType) any {
			switch typ.GetType() {
			case mysql.TypeLong, mysql.TypeLonglong:
				return int64(row)
			case mysql.TypeDouble:
				return float64(row)
			default:
				panic("not implement")
			}
		},
	}
	// Both sides share the same parameters; each gets its own schema object.
	opt2 := opt1
	opt1.DataSchema = expression.NewSchema(casTest.columns()...)
	opt2.DataSchema = expression.NewSchema(casTest.columns()...)
	dataSource1 := testutil.BuildMockDataSource(opt1)
	dataSource2 := testutil.BuildMockDataSource(opt2)
	dataSource1.PrepareChunks()
	dataSource2.PrepareChunks()
	executor := prepare4HashJoinV2(casTest, dataSource1, dataSource2)
	ctx := context.Background()
	for i := 0; i < 10; i++ {
		// when in apply, the same executor will be open/closed multiple times
		chk := exec.NewFirstChunk(executor)
		err := executor.Open(ctx)
		require.NoError(t, err)
		rows := 0
		for {
			err = executor.Next(ctx, chk)
			require.NoError(t, err)
			// An empty chunk marks the end of the join output.
			if chk.NumRows() == 0 {
				break
			}
			rows += chk.NumRows()
		}
		// GreaterOrEqual reports both counts on failure, unlike comparing a
		// pre-evaluated boolean against true.
		require.GreaterOrEqual(t, rows, opt1.Rows)
		err = executor.Close()
		require.NoError(t, err)
		// Prepare the mock sources again before the next open/close cycle.
		dataSource1.PrepareChunks()
		dataSource2.PrepareChunks()
	}
}
// TestJoinExec joins two identically generated mock data sources with the hash
// join executor and validates the output for every combination of
// concurrency, row count, and disk setting listed at the bottom of the test.
// The testRowContainerSpill failpoint stays enabled for the whole test.
func TestJoinExec(t *testing.T) {
	const spillFailpoint = "github.com/pingcap/tidb/pkg/executor/join/testRowContainerSpill"
	require.NoError(t, failpoint.Enable(spillFailpoint, "return(true)"))
	defer func() {
		require.NoError(t, failpoint.Disable(spillFailpoint))
	}()

	fieldTypes := []*types.FieldType{
		types.NewFieldType(mysql.TypeLonglong),
		types.NewFieldType(mysql.TypeDouble),
	}
	tc := defaultHashJoinTestCase(fieldTypes, 0, false)

	// genData fills each cell with its row index, converted to the column type.
	genData := func(row int, typ *types.FieldType) any {
		tp := typ.GetType()
		if tp == mysql.TypeLong || tp == mysql.TypeLonglong {
			return int64(row)
		}
		if tp == mysql.TypeDouble {
			return float64(row)
		}
		panic("not implement")
	}

	// runCase reads tc at call time, so the loops below reconfigure tc before
	// each invocation.
	runCase := func() {
		leftOpt := testutil.MockDataSourceParameters{
			Rows:        tc.rows,
			Ctx:         tc.ctx,
			GenDataFunc: genData,
		}
		rightOpt := leftOpt
		leftOpt.DataSchema = expression.NewSchema(tc.columns()...)
		rightOpt.DataSchema = expression.NewSchema(tc.columns()...)

		left := testutil.BuildMockDataSource(leftOpt)
		right := testutil.BuildMockDataSource(rightOpt)
		left.PrepareChunks()
		right.PrepareChunks()

		hashJoin := prepare4HashJoin(tc, left, right)
		joined := exec.NewFirstChunk(hashJoin)

		// Drain the executor into joined; an empty chunk signals the end.
		ctx := context.Background()
		buf := exec.NewFirstChunk(hashJoin)
		require.NoError(t, hashJoin.Open(ctx))
		for {
			require.NoError(t, hashJoin.Next(ctx, buf))
			if buf.NumRows() == 0 {
				break
			}
			joined.Append(buf, 0, buf.NumRows())
		}
		// The spill state must match the disk setting of this case.
		require.Equal(t, tc.disk, hashJoin.RowContainer.AlreadySpilledSafeForTest())
		require.NoError(t, hashJoin.Close())

		require.Equal(t, 4, joined.NumCols())
		require.Equal(t, tc.rows, joined.NumRows())

		// Every output row must hold the same value in all four columns.
		col0 := joined.Column(0).Int64s()
		col1 := joined.Column(1).Float64s()
		col2 := joined.Column(2).Int64s()
		col3 := joined.Column(3).Float64s()
		seen := make(map[int64]struct{}, tc.rows)
		for i := 0; i < tc.rows; i++ {
			key := col0[i]
			require.Equal(t, float64(key), col1[i])
			require.Equal(t, key, col2[i])
			require.Equal(t, float64(key), col3[i])
			seen[key] = struct{}{}
		}
		// Every generated row index must appear in the output.
		for i := 0; i < tc.rows; i++ {
			_, ok := seen[int64(i)]
			require.True(t, ok)
		}
	}

	for _, workers := range []int{1, 4} {
		for _, rowCount := range []int{3, 1024, 4096} {
			for _, useDisk := range []bool{false, true} {
				tc.concurrency = workers
				tc.rows = rowCount
				tc.disk = useDisk
				runCase()
			}
		}
	}
}